Source code for pipelinex.extras.hooks.mlflow.mlflow_catalog_logger

from importlib.util import find_spec
from logging import getLogger
import tempfile
from typing import Any, Dict, Union  # NOQA

from kedro.io import AbstractDataSet, DataCatalog
from kedro.pipeline.node import Node
from kedro.pipeline import Pipeline

from .mlflow_utils import (
    hook_impl,
    mlflow_log_metrics,
    mlflow_log_params,
    mlflow_log_artifacts,
)


log = getLogger(__name__)


def get_kedro_runner():
    """Return the runner instance of the Kedro run in progress.

    Searches the call stack for a local variable named ``runner`` that is an
    ``AbstractRunner`` instance.
    """
    import inspect

    from kedro.runner import AbstractRunner

    return next(
        caller[0].f_locals.get("runner")
        for caller in inspect.stack()
        if isinstance(caller[0].f_locals.get("runner"), AbstractRunner)
    )


def running_parallel():
    """Return True if the Kedro run in progress uses ParallelRunner."""
    from kedro.runner import ParallelRunner

    return isinstance(get_kedro_runner(), ParallelRunner)
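

# A minimal usage sketch (not part of the original module): running_parallel()
# can be called from any hook while a run is in progress, e.g. to avoid
# keeping process-local state under ParallelRunner.
#
#     if running_parallel():
#         log.warning("ParallelRunner detected: avoid process-local state.")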


# Map file extensions to the Kedro datasets used to save artifact files.
# Each import is wrapped in try/except because the dependencies are optional.
datasets_dict = {}
try:
    from kedro.extras.datasets.json import JSONDataSet

    datasets_dict["json"] = JSONDataSet
except ImportError:
    pass
try:
    from kedro.extras.datasets.pandas import CSVDataSet, ExcelDataSet, ParquetDataSet

    datasets_dict["csv"] = CSVDataSet
    datasets_dict["xls"] = ExcelDataSet
    datasets_dict["parquet"] = ParquetDataSet
except ImportError:
    pass
try:
    from kedro.extras.datasets.pickle import PickleDataSet

    datasets_dict["pkl"] = PickleDataSet
    datasets_dict["pickle"] = PickleDataSet
except ImportError:
    pass
try:
    from kedro.extras.datasets.pillow import ImageDataSet

    datasets_dict["png"] = ImageDataSet
    datasets_dict["jpg"] = ImageDataSet
    datasets_dict["jpeg"] = ImageDataSet
    datasets_dict["img"] = ImageDataSet
except ImportError:
    pass
try:
    from kedro.extras.datasets.text import TextDataSet

    datasets_dict["txt"] = TextDataSet
except ImportError:
    pass
try:
    from kedro.extras.datasets.yaml import YAMLDataSet

    datasets_dict["yml"] = YAMLDataSet
    datasets_dict["yaml"] = YAMLDataSet
except ImportError:
    pass


def mlflow_log_dataset(dataset, enable_mlflow=True):
    """Log the file behind a Kedro dataset to MLflow as an artifact."""
    fp = getattr(dataset, "_filepath", None)
    if not fp:
        low_ds = getattr(dataset, "_dataset", None)
        if low_ds:
            fp = getattr(low_ds, "_filepath", None)
    if not fp:
        log.warning("_filepath of '{}' was not found.".format(dataset))
        return
    mlflow_log_artifacts([fp], enable_mlflow=enable_mlflow)
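

# A minimal usage sketch (hypothetical filepath, not part of the original
# module): mlflow_log_dataset() uploads the file behind an already-saved
# Kedro dataset.
#
#     from kedro.extras.datasets.pandas import CSVDataSet
#
#     ds = CSVDataSet(filepath="data/01_raw/iris.csv")
#     ds.save(df)  # df: a pandas DataFrame produced earlier
#     mlflow_log_dataset(ds)  # logs the CSV file as an MLflow artifact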


class MLflowCatalogLoggerHook:
    """Logs datasets to MLflow."""

    _logged_set = set()

    def __init__(
        self,
        auto: bool = True,
        mlflow_catalog: Dict[str, Union[str, AbstractDataSet]] = {},
        enable_mlflow: bool = True,
    ):
        """
        Args:
            auto: If True, each dataset (Python func input/output) not listed
                in mlflow_catalog is logged as a metric if it is a float or
                int, and as a param if it is a str, list, tuple, dict, or set.
            mlflow_catalog: Specifies how to log each dataset (Python func
                input/output) in dict format. Use "p" to log as a parameter,
                "m" to log as a metric, a file extension ("json", "csv",
                "xls", "parquet", "png", "jpg", "jpeg", "img", "pkl", "txt",
                "yml", or "yaml") to log as a file artifact saved by the
                corresponding Kedro dataset, or a Kedro DataSet instance to
                log as a file artifact saved by that instance.
            enable_mlflow: Enable logging to MLflow.
        """
        self.enable_mlflow = find_spec("mlflow") and enable_mlflow
        self.mlflow_catalog = mlflow_catalog
        self.auto = auto

    @hook_impl
    def before_pipeline_run(
        self, run_params: Dict[str, Any], pipeline: Pipeline, catalog: DataCatalog
    ) -> None:
        for dataset_name in catalog._data_sets:
            if catalog._data_sets[dataset_name].__class__.__name__ == "MLflowDataSet":
                setattr(
                    catalog._data_sets[dataset_name], "_dataset_name", dataset_name
                )
                setattr(
                    catalog._data_sets[dataset_name],
                    "_running_parallel",
                    running_parallel(),
                )
                catalog._data_sets[dataset_name]._init_dataset()

    @hook_impl
    def after_node_run(
        self,
        node: Node,
        catalog: DataCatalog,
        inputs: Dict[str, Any],
        outputs: Dict[str, Any],
    ):
        # Log each dataset only once across the whole run.
        for name, value in inputs.items():
            if name not in self._logged_set:
                self._logged_set.add(name)
                self._log_dataset(name, value)
        for name, value in outputs.items():
            if name not in self._logged_set:
                self._logged_set.add(name)
                self._log_dataset(name, value)

    def _log_dataset(self, name: str, value: Any):
        if name not in self.mlflow_catalog:
            if not self.auto:
                return
            # Auto mode: log numeric values as metrics and basic Python
            # containers/strings as params; other types are ignored.
            if isinstance(value, (float, int)):
                mlflow_log_metrics({name: value}, enable_mlflow=self.enable_mlflow)
            elif isinstance(value, (str, list, tuple, dict, set)):
                mlflow_log_params({name: value}, enable_mlflow=self.enable_mlflow)
            return

        catalog_instance = self.mlflow_catalog.get(name, None)
        if not catalog_instance:
            return
        if isinstance(catalog_instance, str):
            if catalog_instance in {"param", "p", "$"}:
                mlflow_log_params({name: value}, enable_mlflow=self.enable_mlflow)
            elif catalog_instance in {"metric", "m", "#"}:
                mlflow_log_metrics({name: value}, enable_mlflow=self.enable_mlflow)
            elif catalog_instance in datasets_dict:
                # Save the value to a temporary file using the Kedro dataset
                # mapped to the extension, then log the file as an artifact.
                ds = datasets_dict.get(catalog_instance)
                fp = tempfile.gettempdir() + "/" + name + "." + catalog_instance
                ds(filepath=fp).save(value)
                mlflow_log_artifacts([fp], enable_mlflow=self.enable_mlflow)
            else:
                log.warning(
                    "'{}' is not supported as mlflow_catalog entry and ignored.".format(
                        catalog_instance
                    )
                )
            return

        # Otherwise, treat the entry as a Kedro dataset instance.
        if not hasattr(catalog_instance, "save"):
            log.warning("'save' attr is not found in mlflow_catalog instance.")
            return
        catalog_instance.save(value)
        mlflow_log_dataset(catalog_instance, enable_mlflow=self.enable_mlflow)
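

# A minimal registration sketch (hypothetical dataset names, not part of the
# original module): in a Kedro project, the hook might be registered via the
# HOOKS tuple in settings.py, assuming datasets named "learning_rate",
# "score", and "model" flow through the pipeline.
#
#     HOOKS = (
#         MLflowCatalogLoggerHook(
#             auto=True,
#             mlflow_catalog={
#                 "learning_rate": "p",  # log as an MLflow param
#                 "score": "m",  # log as an MLflow metric
#                 "model": "pkl",  # save as pickle, then log as an artifact
#             },
#         ),
#     )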