Source code for pipelinex.mlflow_on_kedro.hooks.mlflow.mlflow_catalog_logger

import tempfile
from importlib.util import find_spec
from logging import getLogger
from typing import Any, Dict, Union  # NOQA

from kedro.io import DataCatalog
from kedro_datasets._io import AbstractDataset as AbstractDataSet
from kedro.pipeline import Pipeline
from kedro.pipeline.node import Node

from .mlflow_utils import (
    hook_impl,
    mlflow_log_artifacts,
    mlflow_log_metrics,
    mlflow_log_params,
    mlflow_log_values,
)

log = getLogger(__name__)


def get_kedro_runner():
    import inspect

    from kedro.runner import AbstractRunner

    return next(
        caller[0].f_locals.get("runner")
        for caller in inspect.stack()
        if isinstance(caller[0].f_locals.get("runner"), AbstractRunner)
    )

def running_parallel():
    from kedro.runner import ParallelRunner

    return isinstance(get_kedro_runner(), ParallelRunner)

datasets_dict = {}

try:
    from kedro.extras.datasets.json import JSONDataSet

    datasets_dict["json"] = JSONDataSet
except ImportError:
    pass

try:
    from kedro.extras.datasets.pandas import CSVDataSet, ExcelDataSet, ParquetDataSet

    datasets_dict["csv"] = CSVDataSet
    datasets_dict["xls"] = ExcelDataSet
    datasets_dict["parquet"] = ParquetDataSet
except ImportError:
    pass

try:
    from kedro.extras.datasets.pickle import PickleDataSet

    datasets_dict["pkl"] = PickleDataSet
    datasets_dict["pickle"] = PickleDataSet
except ImportError:
    pass

try:
    from kedro.extras.datasets.pillow import ImageDataSet

    datasets_dict["png"] = ImageDataSet
    datasets_dict["jpg"] = ImageDataSet
    datasets_dict["jpeg"] = ImageDataSet
    datasets_dict["img"] = ImageDataSet
except ImportError:
    pass

try:
    from kedro.extras.datasets.text import TextDataSet

    datasets_dict["txt"] = TextDataSet
except ImportError:
    pass

try:
    from kedro.extras.datasets.yaml import YAMLDataSet

    datasets_dict["yml"] = YAMLDataSet
    datasets_dict["yaml"] = YAMLDataSet
except ImportError:
    pass

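# Illustrative sketch (not part of this module): an extension string given in
# ``mlflow_catalog`` is resolved to one of the optional backend dataset classes
# collected above, so the value is written to a temporary file and logged as an
# MLflow artifact. Assumes kedro.extras.datasets.pandas and mlflow are installed.
#
#     import pandas as pd
#
#     ds_class = datasets_dict["csv"]  # -> CSVDataSet
#     fp = tempfile.gettempdir() + "/example.csv"
#     ds_class(filepath=fp).save(pd.DataFrame({"a": [1, 2]}))
#     mlflow_log_artifacts([fp])
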
def mlflow_log_dataset(dataset, enable_mlflow=True):
    fp = getattr(dataset, "_filepath", None)
    if not fp:
        low_ds = getattr(dataset, "_dataset", None)
        if low_ds:
            fp = getattr(low_ds, "_filepath", None)
    if not fp:
        log.warning("_filepath of '{}' was not found.".format(dataset))
        return
    mlflow_log_artifacts([fp], enable_mlflow=enable_mlflow)

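# Illustrative usage (hypothetical path and dataframe; assumes
# kedro.extras.datasets.pandas and mlflow are installed): log the file behind a
# Kedro dataset that has already been saved.
#
#     from kedro.extras.datasets.pandas import CSVDataSet
#
#     ds = CSVDataSet(filepath="data/07_model_output/scores.csv")
#     ds.save(scores_df)      # write the file first
#     mlflow_log_dataset(ds)  # then upload it as an MLflow artifact
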
class MLflowCatalogLoggerHook:
    """Logs datasets to MLflow"""

    _logged_set = set()

    def __init__(
        self,
        auto: bool = True,
        mlflow_catalog: Dict[str, Union[str, AbstractDataSet]] = {},
        enable_mlflow: bool = True,
    ):
        """
        Args:
            auto: If True, each dataset (Python func input/output) not listed in
                the catalog will be logged following the same rule as the "a"
                option below.
            mlflow_catalog: [Deprecated in favor of MLflowDataSet] Specify how to
                log each dataset (Python func input/output).

                - If set to "p", the value will be saved/loaded as an MLflow
                  parameter (string).
                - If set to "m", the value will be saved/loaded as an MLflow
                  metric (numeric).
                - If set to "a", the value will be saved/loaded based on the
                  data type.

                  - If the data type is either {float, int}, the value will be
                    saved/loaded as an MLflow metric.
                  - If the data type is either {str, list, tuple, set}, the
                    value will be saved/loaded as an MLflow parameter.
                  - If the data type is dict, the value will be flattened with
                    dot (".") as the separator and then saved/loaded as either
                    an MLflow metric or parameter based on each data type as
                    explained above.

                - If set to either {"json", "csv", "xls", "parquet", "png",
                  "jpg", "jpeg", "img", "pkl", "txt", "yml", "yaml"}, the
                  backend dataset instance will be created accordingly to
                  save/load as an MLflow artifact.
                - If set to a Kedro DataSet object or a dictionary, it will be
                  used as the backend dataset to save/load as an MLflow
                  artifact.
                - If set to None (default), MLflow logging will be skipped.

                See the registration example at the end of this module.
            enable_mlflow: Enable logging to MLflow.
        """
        self.enable_mlflow = find_spec("mlflow") and enable_mlflow
        self.mlflow_catalog = mlflow_catalog
        self.auto = auto

    @hook_impl
    def before_pipeline_run(
        self, run_params: Dict[str, Any], pipeline: Pipeline, catalog: DataCatalog
    ) -> None:
        for dataset_name in catalog._data_sets:
            if catalog._data_sets[dataset_name].__class__.__name__ == "MLflowDataSet":
                setattr(
                    catalog._data_sets[dataset_name], "_dataset_name", dataset_name
                )
                setattr(
                    catalog._data_sets[dataset_name],
                    "_running_parallel",
                    running_parallel(),
                )
                catalog._data_sets[dataset_name]._init_dataset()

    @hook_impl
    def after_node_run(
        self,
        node: Node,
        catalog: DataCatalog,
        inputs: Dict[str, Any],
        outputs: Dict[str, Any],
    ):
        for name, value in inputs.items():
            if name not in self._logged_set:
                self._logged_set.add(name)
                self._log_dataset(name, value)
        for name, value in outputs.items():
            if name not in self._logged_set:
                self._logged_set.add(name)
                self._log_dataset(name, value)

    def _log_dataset(self, name: str, value: Any):
        if name not in self.mlflow_catalog:
            if not self.auto:
                return
            mlflow_log_values({name: value}, enable_mlflow=self.enable_mlflow)
            return

        catalog_instance = self.mlflow_catalog.get(name, None)
        if not catalog_instance:
            return
        elif isinstance(catalog_instance, str):
            if catalog_instance in {"p"}:
                mlflow_log_params({name: value}, enable_mlflow=self.enable_mlflow)
            elif catalog_instance in {"m"}:
                mlflow_log_metrics({name: value}, enable_mlflow=self.enable_mlflow)
            elif catalog_instance in {"a"}:
                mlflow_log_values({name: value}, enable_mlflow=self.enable_mlflow)
            elif catalog_instance in datasets_dict:
                ds = datasets_dict.get(catalog_instance)
                fp = tempfile.gettempdir() + "/" + name + "." + catalog_instance
                ds(filepath=fp).save(value)
                mlflow_log_artifacts([fp], enable_mlflow=self.enable_mlflow)
            else:
                log.warning(
                    "'{}' is not supported as mlflow_catalog entry and ignored.".format(
                        catalog_instance
                    )
                )
                return
        else:
            if not hasattr(catalog_instance, "save"):
                log.warning("'save' attr is not found in mlflow_catalog instance.")
                return
            catalog_instance.save(value)
            mlflow_log_dataset(catalog_instance, enable_mlflow=self.enable_mlflow)
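

# Example registration (a minimal sketch, not part of this module): in a recent
# Kedro project the hook is typically added to ``HOOKS`` in
# ``src/<package_name>/settings.py`` (the exact location depends on the Kedro
# version). The dataset names below are hypothetical.
#
#     from pipelinex import MLflowCatalogLoggerHook
#
#     HOOKS = (
#         MLflowCatalogLoggerHook(
#             auto=True,  # unlisted inputs/outputs are logged by data type ("a" rule)
#             mlflow_catalog={
#                 "score": "m",    # log as an MLflow metric
#                 "run_id": "p",   # log as an MLflow parameter
#                 "model": "pkl",  # save with PickleDataSet and log as an artifact
#             },
#         ),
#     )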