Source code for pipelinex.mlflow_on_kedro.datasets.mlflow.mlflow_dataset

import logging
import tempfile
from importlib.util import find_spec
from pathlib import Path
from typing import Any, Dict, Union

from kedro_datasets.pickle import PickleDataset as PickleDataSet
from kedro.io import MemoryDataset as MemoryDataSet
from kedro_datasets._io import AbstractDataset as AbstractDataSet


from pipelinex.mlflow_on_kedro.hooks.mlflow.mlflow_utils import (
    mlflow_log_artifacts,
    mlflow_log_metrics,
    mlflow_log_params,
    mlflow_log_values,
)

log = logging.getLogger(__name__)

dataset_dicts = {
    "json": {"type": "json.JSONDataset"},
    "csv": {"type": "pandas.CSVDataset"},
    "xls": {"type": "pandas.ExcelDataset"},
    "parquet": {"type": "pandas.ParquetDataset"},
    "pkl": {"type": "pickle.PickleDataset"},
    "png": {"type": "pillow.ImageDataset"},
    "jpg": {"type": "pillow.ImageDataset"},
    "jpeg": {"type": "pillow.ImageDataset"},
    "img": {"type": "pillow.ImageDataset"},
    "txt": {"type": "text.TextDataset"},
    "yaml": {"type": "yaml.YAMLDataset"},
    "yml": {"type": "yaml.YAMLDataset"},
}
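# A string ``dataset`` argument is looked up in this mapping to build the backend
# dataset config; for example, ``dataset="csv"`` becomes
# ``{"type": "pandas.CSVDataset", "filepath": "<temp dir>/<dataset_name>.csv"}``
# once ``_init_dataset`` fills in the filepath. The "p", "m", and "a" flags are
# not file extensions, so they fall back to ``pickle.PickleDataset`` there.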


class MLflowDataSet(AbstractDataSet):
    """``MLflowDataSet`` saves data to, and loads data from, MLflow.

    You can also specify a ``MLflowDataSet`` in catalog.yml

    Example:
    ::

        >>> test_ds:
        >>>     type: MLflowDataSet
        >>>     dataset: pkl
    """
    def __init__(
        self,
        dataset: Union[AbstractDataSet, Dict, str] = None,
        filepath: str = None,
        dataset_name: str = None,
        saving_tracking_uri: str = None,
        saving_experiment_name: str = None,
        saving_run_id: str = None,
        loading_tracking_uri: str = None,
        loading_run_id: str = None,
        caching: bool = True,
        copy_mode: str = None,
        file_caching: bool = True,
    ):
        """
        Args:
            dataset: Specify how to treat the dataset as an MLflow metric,
                parameter, or artifact.

                - If set to "p", the value will be saved/loaded as an MLflow
                  parameter (string).

                - If set to "m", the value will be saved/loaded as an MLflow
                  metric (numeric).

                - If set to "a", the value will be saved/loaded based on its
                  data type.

                  - If the data type is either {float, int}, the value will be
                    saved/loaded as an MLflow metric.

                  - If the data type is either {str, list, tuple, set}, the
                    value will be saved/loaded as an MLflow parameter.

                  - If the data type is dict, the value will be flattened with
                    dot (".") as the separator and then saved/loaded as either
                    an MLflow metric or parameter based on each data type as
                    explained above.

                - If set to either {"json", "csv", "xls", "parquet", "png",
                  "jpg", "jpeg", "img", "pkl", "txt", "yml", "yaml"}, the
                  backend dataset instance will be created accordingly to
                  save/load the value as an MLflow artifact.

                - If set to a Kedro DataSet object or a dictionary, it will be
                  used as the backend dataset to save/load the value as an
                  MLflow artifact.

                - If set to None (default), MLflow logging will be skipped.
            filepath: File path, usually in the local file system, to save to
                and load from. Used only if the `dataset` arg is a string.
                If None (default),
                ``<temp directory>/<dataset_name arg>.<dataset arg>`` is used.
            dataset_name: Used only if the `dataset` arg is a string and the
                `filepath` arg is None. If None (default), the Python object ID
                is used, but it will be overwritten by MLflowCatalogLoggerHook.
            saving_tracking_uri: MLflow Tracking URI to save to.
                If None (default), the MLFLOW_TRACKING_URI environment
                variable is used.
            saving_experiment_name: MLflow experiment name to save to.
                If None (default), a new experiment will not be created or
                started. Ignored if `saving_run_id` is set.
            saving_run_id: An existing MLflow experiment run ID to save to.
                If None (default), no existing experiment run will be resumed.
            loading_tracking_uri: MLflow Tracking URI to load from.
                If None (default), the MLFLOW_TRACKING_URI environment
                variable is used.
            loading_run_id: MLflow experiment run ID to load from.
                If None (default), the current active run ID will be used if
                available.
            caching: Enable caching if the parallel runner is not used.
                True by default.
            copy_mode: The copy mode used to copy the data. Possible values are
                "deepcopy", "copy", and "assign". If not provided, it is
                inferred based on the data type. Ignored if the `caching` arg
                is False.
            file_caching: Attempt to use the file at `filepath` when loading if
                no cache is found in memory. True by default.
        """
        self.dataset = dataset or MemoryDataSet()
        self.filepath = filepath
        self.dataset_name = dataset_name
        self.saving_tracking_uri = saving_tracking_uri
        self.saving_experiment_name = saving_experiment_name
        self.saving_run_id = saving_run_id
        self.loading_tracking_uri = loading_tracking_uri
        self.loading_run_id = loading_run_id
        self.caching = caching
        self.file_caching = file_caching
        self.copy_mode = copy_mode
        self._dataset_name = str(id(self))
        if isinstance(dataset, str):
            if (dataset not in {"p", "m", "a"}) and (dataset not in dataset_dicts):
                raise ValueError(
                    "`dataset`: {} not supported. \n"
                    "Specify one of {}.".format(dataset, list(dataset_dicts.keys()))
                )
        self._ready = False
        self._running_parallel = None
        self._cache = None
    def _init_dataset(self):
        if not getattr(self, "_ready", None):
            self._ready = True
            self.dataset_name = self.dataset_name or self._dataset_name
            _dataset = self.dataset
            if isinstance(self.dataset, str):
                dataset_dict = dataset_dicts.get(
                    self.dataset, {"type": "pickle.PickleDataset"}
                )
                dataset_dict["filepath"] = self.filepath = self.filepath or (
                    tempfile.gettempdir() + "/" + self.dataset_name + "." + self.dataset
                )
                _dataset = dataset_dict
            if isinstance(_dataset, dict):
                self._dataset = AbstractDataSet.from_config(
                    self._dataset_name, _dataset
                )
            elif isinstance(_dataset, AbstractDataSet):
                self._dataset = _dataset
            else:
                raise ValueError(
                    "The argument type of `dataset` should be either a dict/YAML "
                    "representation of the dataset, or the actual dataset object."
                )
            _filepath = getattr(self._dataset, "_filepath", None)
            if _filepath:
                self.filepath = str(_filepath)
            if self.caching and (not self._running_parallel):
                self._cache = MemoryDataSet(copy_mode=self.copy_mode)

    def _release(self) -> None:
        self._init_dataset()
        self._dataset.release()
        if self._cache:
            self._cache.release()

    def _describe(self) -> Dict[str, Any]:
        return {
            "Dataset": self._dataset._describe()
            if getattr(self, "_ready", None)
            else self.dataset,  # pylint: disable=protected-access
            "filepath": self.filepath,
            "saving_tracking_uri": self.saving_tracking_uri,
            "saving_experiment_name": self.saving_experiment_name,
            "saving_run_id": self.saving_run_id,
            "loading_tracking_uri": self.loading_tracking_uri,
            "loading_run_id": self.loading_run_id,
        }

    def _load(self):
        self._init_dataset()

        if self._cache and self._cache.exists():
            return self._cache.load()

        if self.file_caching and self._dataset.exists():
            return self._dataset.load()

        import mlflow

        client = mlflow.tracking.MlflowClient(tracking_uri=self.loading_tracking_uri)
        self.loading_run_id = self.loading_run_id or mlflow.active_run().info.run_id
        if self.dataset in {"p"}:
            run = client.get_run(self.loading_run_id)
            value = run.data.params.get(self.dataset_name, None)
            if value is None:
                raise KeyError(
                    "param '{}' not found in run_id '{}'.".format(
                        self.dataset_name, self.loading_run_id
                    )
                )
            PickleDataSet(filepath=self.filepath).save(value)
        elif self.dataset in {"m"}:
            run = client.get_run(self.loading_run_id)
            value = run.data.metrics.get(self.dataset_name, None)
            if value is None:
                raise KeyError(
                    "metric '{}' not found in run_id '{}'.".format(
                        self.dataset_name, self.loading_run_id
                    )
                )
            PickleDataSet(filepath=self.filepath).save(value)
        else:
            p = Path(self.filepath)
            dst_path = tempfile.gettempdir()
            downloaded_path = client.download_artifacts(
                run_id=self.loading_run_id,
                path=p.name,
                dst_path=dst_path,
            )
            if Path(downloaded_path) != p:
                Path(downloaded_path).rename(p)

        return self._dataset.load()

    def _save(self, data: Any) -> None:
        self._init_dataset()
        self._dataset.save(data)
        if self._cache:
            self._cache.save(data)

        if find_spec("mlflow"):
            import mlflow

            if self.saving_tracking_uri:
                mlflow.set_tracking_uri(self.saving_tracking_uri)

            if self.saving_run_id:
                mlflow.start_run(run_id=self.saving_run_id)
            elif self.saving_experiment_name:
                experiment_id = mlflow.get_experiment_by_name(
                    self.saving_experiment_name
                ).experiment_id
                mlflow.start_run(run_id=self.saving_run_id, experiment_id=experiment_id)

            if self.dataset in {"p"}:
                mlflow_log_params({self.dataset_name: data})
            elif self.dataset in {"m"}:
                mlflow_log_metrics({self.dataset_name: data})
            elif self.dataset in {"a"}:
                mlflow_log_values({self.dataset_name: data})
            else:
                mlflow_log_artifacts([self.filepath])

            if self.saving_run_id or self.saving_experiment_name:
                mlflow.end_run()

    def _exists(self) -> bool:
        self._init_dataset()
        if self._cache:
            return self._cache.exists()
        else:
            return False

    def __getstate__(self):
        return self.__dict__
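

# Usage sketch (illustrative, not part of the module): saving and loading a
# Python object through the "pkl" backend, assuming MLFLOW_TRACKING_URI points
# to a reachable tracking server and an MLflow run is active (or
# ``saving_experiment_name`` is set so a run can be started):
#
#     from pipelinex import MLflowDataSet
#
#     data_set = MLflowDataSet(dataset="pkl", dataset_name="model_input")
#     data_set.save({"features": [1, 2, 3]})  # pickles to <temp dir>/model_input.pkl
#                                             # and logs the file as an MLflow artifact
#     reloaded = data_set.load()              # returns the cached copy if caching is on,
#                                             # otherwise downloads the artifact and loads it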