Source code for pipelinex.mlflow_on_kedro.hooks.mlflow.mlflow_artifacts_logger

from importlib.util import find_spec
from logging import getLogger
from typing import Any, Dict, List  # NOQA

from kedro.io import DataCatalog
from kedro.pipeline import Pipeline
from kedro.pipeline.node import Node

from .mlflow_utils import hook_impl, mlflow_log_artifacts

log = getLogger(__name__)


class MLflowArtifactsLoggerHook:
    """Logs artifacts of specified file paths and dataset names to MLflow"""

    def __init__(
        self,
        filepaths_before_pipeline_run: List[str] = None,
        filepaths_after_pipeline_run: List[str] = None,
        datasets_after_node_run: List[str] = None,
        enable_mlflow: bool = True,
    ):
        """
        Args:
            filepaths_before_pipeline_run: The file paths of artifacts to log
                before the pipeline is run.
            filepaths_after_pipeline_run: The file paths of artifacts to log
                after the pipeline is run.
            datasets_after_node_run: The dataset names to log after the node is run.
            enable_mlflow: Enable logging to MLflow.
        """
        # find_spec() returns a ModuleSpec or None; coerce so the attribute is
        # a real bool (True only when mlflow is importable AND logging is on).
        self.enable_mlflow = bool(find_spec("mlflow") and enable_mlflow)
        # Normalize omitted (None) arguments to fresh empty lists.
        self.filepaths_before_pipeline_run = filepaths_before_pipeline_run or []
        self.datasets_after_node_run = datasets_after_node_run or []
        self.filepaths_after_pipeline_run = filepaths_after_pipeline_run or []

    def _log_artifacts(self, artifacts=None):
        """Forward the given artifact paths to ``mlflow_log_artifacts``.

        Args:
            artifacts: Artifact file paths to log. ``None`` is treated as an
                empty list.
        """
        # A None default avoids sharing one mutable list across calls.
        mlflow_log_artifacts(artifacts or [], enable_mlflow=self.enable_mlflow)

    @hook_impl
    def before_pipeline_run(
        self, run_params: Dict[str, Any], pipeline: Pipeline, catalog: DataCatalog
    ):
        """Log ``filepaths_before_pipeline_run`` before the pipeline is run."""
        self._log_artifacts(self.filepaths_before_pipeline_run)

    @hook_impl
    def after_pipeline_run(
        self, run_params: Dict[str, Any], pipeline: Pipeline, catalog: DataCatalog
    ):
        """Log ``filepaths_after_pipeline_run`` after the pipeline is run."""
        self._log_artifacts(self.filepaths_after_pipeline_run)

    def _log_datasets(self, catalog, datasets):
        """Log the backing file of each named dataset.

        Each name is looked up as an attribute of ``catalog.datasets``. The
        file path is read from the dataset's ``_filepath`` attribute or,
        failing that, from the ``_filepath`` of the object stored in its
        ``_dataset`` attribute. A dataset that cannot be resolved is reported
        with a warning and skipped; the remaining datasets are still handled.

        Args:
            catalog: The catalog whose ``datasets`` attribute holds the datasets.
            datasets: Names of the datasets to log.
        """
        for name in datasets:
            dataset = getattr(catalog.datasets, name, None)
            if not dataset:
                log.warning("%s was not found in catalog.datasets.", name)
                # Skip only this dataset; do not abort the remaining ones.
                continue

            filepath = getattr(dataset, "_filepath", None)
            if not filepath:
                # Fall back to the object held in the ``_dataset`` attribute.
                inner = getattr(dataset, "_dataset", None)
                if inner:
                    filepath = getattr(inner, "_filepath", None)
            if not filepath:
                log.warning("_filepath of '%s' was not found.", name)
                continue

            # NOTE(review): a single path is passed here, whereas
            # _log_artifacts passes a list -- confirm mlflow_log_artifacts
            # accepts both forms.
            mlflow_log_artifacts(filepath, enable_mlflow=self.enable_mlflow)

    @hook_impl
    def after_node_run(
        self,
        node: Node,
        catalog: DataCatalog,
        inputs: Dict[str, Any],
        outputs: Dict[str, Any],
    ):
        """Log the node outputs whose names are in ``datasets_after_node_run``."""
        datasets = [d for d in outputs if d in self.datasets_after_node_run]
        self._log_datasets(catalog, datasets)