Source code for pipelinex.extras.datasets.core

# Copyright 2020 QuantumBlack Visual Analytics Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND
# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS
# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo
# (either separately or in combination, "QuantumBlack Trademarks") are
# trademarks of QuantumBlack. The License does not grant you any right or
# license to the QuantumBlack Trademarks. You may not use the QuantumBlack
# Trademarks or any confusingly similar mark as a trademark for your product,
# or use the QuantumBlack Trademarks in any other manner that might cause
# confusion in the marketplace, including but not limited to in advertising,
# on websites, or on software.
#
# See the License for the specific language governing permissions and
# limitations under the License.


from importlib.util import find_spec

if find_spec("kedro_datasets"):
    from kedro_datasets._io import (
        AbstractDataset as AbstractDataSet,
        AbstractVersionedDataset as AbstractVersionedDataSet,
    )
else:

    """This module provides a set of classes which underpin the data loading and
    saving functionality provided by ``kedro.io``.
    """

    import abc
    import copy
    import logging
    import re
    import warnings
    from collections import namedtuple
    from datetime import datetime, timezone
    from functools import lru_cache
    from glob import iglob
    from pathlib import Path, PurePath
    from typing import Any, Callable, Dict, List, Optional, Tuple, Type
    from urllib.parse import urlsplit

    from pipelinex.hatch_dict.hatch_dict import load_obj

    warnings.simplefilter("default", DeprecationWarning)

    VERSION_FORMAT = "%Y-%m-%dT%H.%M.%S.%fZ"
    VERSIONED_FLAG_KEY = "versioned"
    VERSION_KEY = "version"
    HTTP_PROTOCOLS = ("http", "https")
    PROTOCOL_DELIMITER = "://"
    CLOUD_PROTOCOLS = ("s3", "gcs", "gs", "adl", "abfs")

    class DataSetError(Exception):
        """``DataSetError`` raised by ``AbstractDataSet`` implementations
        in case of failure of input/output methods.

        ``AbstractDataSet`` implementations should provide instructive
        information in case of failure.
        """

        pass

    class DataSetNotFoundError(DataSetError):
        """``DataSetNotFoundError`` raised by ``DataCatalog`` class in case
        of trying to use a non-existing data set.
        """

        pass

    class DataSetAlreadyExistsError(DataSetError):
        """``DataSetAlreadyExistsError`` raised by ``DataCatalog`` class in case
        of trying to add a data set which already exists in the ``DataCatalog``.
        """

        pass

    class VersionNotFoundError(DataSetError):
        """``VersionNotFoundError`` raised by ``AbstractVersionedDataSet`` implementations
        in case of no load versions available for the data set.
        """

        pass

    class AbstractDataSet(abc.ABC):
        """``AbstractDataSet`` is the base class for all data set implementations.
        All data set implementations should extend this abstract class
        and implement the methods marked as abstract.

        Example:
        ::

            >>> from kedro.io import AbstractDataSet
            >>> import pandas as pd
            >>>
            >>> class MyOwnDataSet(AbstractDataSet):
            >>>     def __init__(self, param1, param2):
            >>>         self._param1 = param1
            >>>         self._param2 = param2
            >>>
            >>>     def _load(self) -> pd.DataFrame:
            >>>         print("Dummy load: {}".format(self._param1))
            >>>         return pd.DataFrame()
            >>>
            >>>     def _save(self, df: pd.DataFrame) -> None:
            >>>         print("Dummy save: {}".format(self._param2))
            >>>
            >>>     def _describe(self):
            >>>         return dict(param1=self._param1, param2=self._param2)
        """
        @classmethod
        def from_config(
            cls: Type,
            name: str,
            config: Dict[str, Any],
            load_version: str = None,
            save_version: str = None,
        ) -> "AbstractDataSet":
            """Create a data set instance using the configuration provided.

            Args:
                name: Data set name.
                config: Data set config dictionary.
                load_version: Version string to be used for ``load`` operation if
                    the data set is versioned. Has no effect on the data set
                    if versioning was not enabled.
                save_version: Version string to be used for ``save`` operation if
                    the data set is versioned. Has no effect on the data set
                    if versioning was not enabled.

            Returns:
                An instance of an ``AbstractDataSet`` subclass.

            Raises:
                DataSetError: When the function fails to create the data set
                    from its config.
            """
            try:
                class_obj, config = parse_dataset_definition(
                    config, load_version, save_version
                )
            except Exception as ex:
                raise DataSetError(
                    "An exception occurred when parsing config "
                    "for DataSet `{}`:\n{}".format(name, str(ex))
                )

            try:
                data_set = class_obj(**config)  # type: ignore
            except TypeError as err:
                raise DataSetError(
                    "\n{}.\nDataSet '{}' must only contain "
                    "arguments valid for the constructor "
                    "of `{}.{}`.".format(
                        str(err), name, class_obj.__module__, class_obj.__qualname__
                    )
                )
            except Exception as err:
                raise DataSetError(
                    "\n{}.\nFailed to instantiate DataSet "
                    "'{}' of type `{}.{}`.".format(
                        str(err), name, class_obj.__module__, class_obj.__qualname__
                    )
                )
            return data_set
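        # A minimal usage sketch for ``from_config`` (kept as a comment so the
        # module stays unchanged at import time). The dataset name, the
        # ``pandas.CSVDataSet`` type string and the file path are assumed
        # examples; the type string is resolved against ``_DEFAULT_PACKAGES``
        # by ``parse_dataset_definition`` below.
        #
        #     cars = AbstractDataSet.from_config(
        #         name="cars",
        #         config={"type": "pandas.CSVDataSet", "filepath": "data/cars.csv"},
        #     )
        #     df = cars.load()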
        @property
        def _logger(self) -> logging.Logger:
            return logging.getLogger(__name__)

        def load(self) -> Any:
            """Loads data by delegation to the provided load method.

            Returns:
                Data returned by the provided load method.

            Raises:
                DataSetError: When underlying load method raises error.
            """

            self._logger.debug("Loading %s", str(self))

            try:
                return self._load()
            except DataSetError:
                raise
            except Exception as exc:
                # This exception handling is by design as the composed data sets
                # can throw any type of exception.
                message = "Failed while loading data from data set {}.\n{}".format(
                    str(self), str(exc)
                )
                raise DataSetError(message) from exc

        def save(self, data: Any) -> None:
            """Saves data by delegation to the provided save method.

            Args:
                data: the value to be saved by provided save method.

            Raises:
                DataSetError: when underlying save method raises error.
            """
            if data is None:
                raise DataSetError("Saving `None` to a `DataSet` is not allowed")

            try:
                self._logger.debug("Saving %s", str(self))
                self._save(data)
            except DataSetError:
                raise
            except Exception as exc:
                message = "Failed while saving data to data set {}.\n{}".format(
                    str(self), str(exc)
                )
                raise DataSetError(message) from exc

        def __str__(self):
            def _to_str(obj, is_root=False):
                """Returns a string representation where
                1. The root level (i.e. the DataSet.__init__ arguments) are
                formatted like DataSet(key=value).
                2. Dictionaries have the keys alphabetically sorted recursively.
                3. Empty dictionaries and None values are not shown.
                """

                fmt = "{}={}" if is_root else "'{}': {}"  # 1

                if isinstance(obj, dict):
                    sorted_dict = sorted(obj.items(), key=lambda pair: str(pair[0]))  # 2

                    text = ", ".join(
                        fmt.format(key, _to_str(value))  # 2
                        for key, value in sorted_dict
                        if value or isinstance(value, bool)
                    )  # 3

                    return text if is_root else "{" + text + "}"  # 1

                # not a dictionary
                return str(obj)

            return "{}({})".format(type(self).__name__, _to_str(self._describe(), True))

        @abc.abstractmethod
        def _load(self) -> Any:
            raise NotImplementedError(
                "`{}` is a subclass of AbstractDataSet and"
                "it must implement the `_load` method".format(self.__class__.__name__)
            )

        @abc.abstractmethod
        def _save(self, data: Any) -> None:
            raise NotImplementedError(
                "`{}` is a subclass of AbstractDataSet and"
                "it must implement the `_save` method".format(self.__class__.__name__)
            )

        @abc.abstractmethod
        def _describe(self) -> Dict[str, Any]:
            raise NotImplementedError(
                "`{}` is a subclass of AbstractDataSet and"
                "it must implement the `_describe` method".format(
                    self.__class__.__name__
                )
            )

        def exists(self) -> bool:
            """Checks whether a data set's output already exists by calling
            the provided _exists() method.

            Returns:
                Flag indicating whether the output already exists.

            Raises:
                DataSetError: when underlying exists method raises error.
            """
            try:
                self._logger.debug("Checking whether target of %s exists", str(self))
                return self._exists()
            except Exception as exc:
                message = "Failed during exists check for data set {}.\n{}".format(
                    str(self), str(exc)
                )
                raise DataSetError(message) from exc

        def _exists(self) -> bool:
            self._logger.warning(
                "`exists()` not implemented for `%s`. Assuming output does not exist.",
                self.__class__.__name__,
            )
            return False

        def release(self) -> None:
            """Release any cached data.

            Raises:
                DataSetError: when underlying release method raises error.
            """
            try:
                self._logger.debug("Releasing %s", str(self))
                self._release()
            except Exception as exc:
                message = "Failed during release for data set {}.\n{}".format(
                    str(self), str(exc)
                )
                raise DataSetError(message) from exc

        def _release(self) -> None:
            pass

        def _copy(self, **overwrite_params) -> "AbstractDataSet":
            dataset_copy = copy.deepcopy(self)
            for name, value in overwrite_params.items():
                setattr(dataset_copy, name, value)
            return dataset_copy
    def generate_timestamp() -> str:
        """Generate the timestamp to be used by versioning.

        Returns:
            String representation of the current timestamp.
        """
        current_ts = datetime.now(tz=timezone.utc).strftime(VERSION_FORMAT)
        return current_ts[:-4] + current_ts[-1:]  # Don't keep microseconds
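    # A sketch of the value ``generate_timestamp`` produces: ``VERSION_FORMAT``
    # yields microsecond precision, and the truncation above keeps only
    # milliseconds, so the result matches ``YYYY-MM-DDTHH.MM.SS.sssZ``
    # (the concrete value below is an assumed example, not real output):
    #
    #     generate_timestamp()  # e.g. "2020-01-02T03.04.05.678Z"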
    class Version(namedtuple("Version", ["load", "save"])):
        """This namedtuple is used to provide load and save versions for versioned
        data sets. If ``Version.load`` is None, then the latest available version
        is loaded. If ``Version.save`` is None, then save version is formatted as
        YYYY-MM-DDThh.mm.ss.sssZ of the current timestamp.
        """

        __slots__ = ()
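    # A brief illustrative sketch of the most common ``Version`` value: both
    # attributes left as None, so the latest version is resolved at load time
    # and a save version is autogenerated.
    #
    #     version = Version(load=None, save=None)
    #     version.load  # None -> latest available version is used on load
    #     version.save  # None -> generate_timestamp() is used on save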
    _CONSISTENCY_WARNING = (
        "Save version `{}` did not match load version `{}` for {}. This is strongly "
        "discouraged due to inconsistencies it may cause between `save` and "
        "`load` operations. Please refrain from setting exact load version for "
        "intermediate data sets where possible to avoid this warning."
    )

    _DEFAULT_PACKAGES = ["kedro.io.", "kedro.extras.datasets.", ""]
    def parse_dataset_definition(
        config: Dict[str, Any], load_version: str = None, save_version: str = None
    ) -> Tuple[Type[AbstractDataSet], Dict[str, Any]]:
        """Parse and instantiate a dataset class using the configuration provided.

        Args:
            config: Data set config dictionary. It *must* contain the `type` key
                with fully qualified class name.
            load_version: Version string to be used for ``load`` operation if
                the data set is versioned. Has no effect on the data set
                if versioning was not enabled.
            save_version: Version string to be used for ``save`` operation if
                the data set is versioned. Has no effect on the data set
                if versioning was not enabled.

        Raises:
            DataSetError: If the function fails to parse the configuration provided.

        Returns:
            2-tuple: (Dataset class object, configuration dictionary)
        """
        save_version = save_version or generate_timestamp()
        config = copy.deepcopy(config)

        if "type" not in config:
            raise DataSetError("`type` is missing from DataSet catalog configuration")

        class_obj = config.pop("type")
        if isinstance(class_obj, str):
            if len(class_obj.strip(".")) != len(class_obj):
                raise DataSetError(
                    "`type` class path does not support relative "
                    "paths or paths ending with a dot."
                )

            class_paths = (prefix + class_obj for prefix in _DEFAULT_PACKAGES)
            trials = (_load_obj(class_path) for class_path in class_paths)
            try:
                class_obj = next(obj for obj in trials if obj is not None)
            except StopIteration:
                raise DataSetError("Class `{}` not found.".format(class_obj))

        if not issubclass(class_obj, AbstractDataSet):
            raise DataSetError(
                "DataSet type `{}.{}` is invalid: all data set types must extend "
                "`AbstractDataSet`.".format(
                    class_obj.__module__, class_obj.__qualname__
                )
            )

        if VERSION_KEY in config:
            # remove "version" key so that it's not passed
            # to the "unversioned" data set constructor
            message = (
                "`%s` attribute removed from data set configuration since it is a "
                "reserved word and cannot be directly specified"
            )
            logging.getLogger(__name__).warning(message, VERSION_KEY)
            del config[VERSION_KEY]

        if config.pop(VERSIONED_FLAG_KEY, False):  # data set is versioned
            config[VERSION_KEY] = Version(load_version, save_version)

        return class_obj, config
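    # A hedged sketch of how ``parse_dataset_definition`` handles a versioned
    # catalog entry (the ``pandas.CSVDataSet`` type and the file path are
    # assumed examples): the ``versioned`` flag is popped from the config and
    # replaced by a ``Version`` object under the ``version`` key.
    #
    #     cls, kwargs = parse_dataset_definition(
    #         {"type": "pandas.CSVDataSet", "filepath": "data/cars.csv", "versioned": True}
    #     )
    #     # cls     -> the resolved data set class
    #     # kwargs  -> {"filepath": "data/cars.csv",
    #     #             "version": Version(load_version, save_version)}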
    def _load_obj(class_path: str) -> Optional[object]:
        try:
            class_obj = load_obj(class_path)
        except ModuleNotFoundError as error:
            if error.name is None or error.name in class_path:
                return None
            # class_obj was successfully loaded, but some dependencies are missing.
            raise DataSetError(
                f"{error} for {class_path}. Please see the documentation on how to "
                f"install relevant dependencies for {class_path}:\n"
                f"https://kedro.readthedocs.io/en/stable/02_getting_started/"
                f"02_install.html#optional-dependencies"
            )
        except (AttributeError, ValueError):
            return None

        return class_obj

    def _local_exists(filepath: str) -> bool:  # SKIP_IF_NO_SPARK
        filepath = Path(filepath)
        return filepath.exists() or any(par.is_file() for par in filepath.parents)

    class AbstractVersionedDataSet(AbstractDataSet, abc.ABC):
        """
        ``AbstractVersionedDataSet`` is the base class for all versioned data set
        implementations. All data sets that implement versioning should extend this
        abstract class and implement the methods marked as abstract.

        Example:
        ::

            >>> from kedro.io import AbstractVersionedDataSet
            >>> import pandas as pd
            >>>
            >>>
            >>> class MyOwnDataSet(AbstractVersionedDataSet):
            >>>     def __init__(self, param1, param2, filepath, version):
            >>>         super().__init__(filepath, version)
            >>>         self._param1 = param1
            >>>         self._param2 = param2
            >>>
            >>>     def _load(self) -> pd.DataFrame:
            >>>         load_path = self._get_load_path()
            >>>         return pd.read_csv(load_path)
            >>>
            >>>     def _save(self, df: pd.DataFrame) -> None:
            >>>         save_path = self._get_save_path()
            >>>         df.to_csv(str(save_path))
            >>>
            >>>     def _exists(self) -> bool:
            >>>         path = self._get_load_path()
            >>>         return path.is_file()
            >>>
            >>>     def _describe(self):
            >>>         return dict(version=self._version, param1=self._param1, param2=self._param2)
        """

        # pylint: disable=abstract-method

        def __init__(
            self,
            filepath: PurePath,
            version: Optional[Version],
            exists_function: Callable[[str], bool] = None,
            glob_function: Callable[[str], List[str]] = None,
        ):
            """Creates a new instance of ``AbstractVersionedDataSet``.

            Args:
                filepath: Path to file.
                version: If specified, should be an instance of
                    ``kedro.io.core.Version``. If its ``load`` attribute is
                    None, the latest version will be loaded. If its ``save``
                    attribute is None, save version will be autogenerated.
                exists_function: Function that is used for determining whether
                    a path exists in a filesystem.
                glob_function: Function that is used for finding all paths
                    in a filesystem, which match a given pattern.
            """
            self._filepath = filepath
            self._version = version
            self._exists_function = exists_function or _local_exists
            self._glob_function = glob_function or iglob

        @lru_cache(maxsize=None)
        def resolve_load_version(self) -> Optional[str]:
            """Compute and cache the version the dataset should be loaded with."""
            if not self._version:
                return None
            if self._version.load:
                return self._version.load

            # When load version is unpinned, fetch the most recent existing
            # version from the given path
            pattern = str(self._get_versioned_path("*"))
            version_paths = sorted(self._glob_function(pattern), reverse=True)
            most_recent = next(
                (path for path in version_paths if self._exists_function(path)), None
            )

            if not most_recent:
                raise VersionNotFoundError(
                    "Did not find any versions for {}".format(str(self))
                )

            return PurePath(most_recent).parent.name

        def _get_load_path(self) -> PurePath:
            if not self._version:
                # When versioning is disabled, load from original filepath
                return self._filepath

            load_version = self.resolve_load_version()
            return self._get_versioned_path(load_version)  # type: ignore

        @lru_cache(maxsize=None)
        def resolve_save_version(self) -> Optional[str]:
            """Compute and cache the version the dataset should be saved with."""
            if not self._version:
                return None
            return self._version.save or generate_timestamp()
        def _get_save_path(self) -> PurePath:
            if not self._version:
                # When versioning is disabled, return original filepath
                return self._filepath

            save_version = self.resolve_save_version()
            versioned_path = self._get_versioned_path(save_version)  # type: ignore

            if self._exists_function(str(versioned_path)):
                raise DataSetError(
                    "Save path `{}` for {} must not exist if versioning "
                    "is enabled.".format(versioned_path, str(self))
                )

            return versioned_path

        def _get_versioned_path(self, version: str) -> PurePath:
            return self._filepath / version / self._filepath.name
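        # A short sketch of the layout ``_get_versioned_path`` produces (the
        # file path and version string are assumed examples): the version is
        # inserted as a directory between the original path and the file name.
        #
        #     self._filepath = PurePath("data/01_raw/cars.csv")
        #     self._get_versioned_path("2020-01-02T03.04.05.678Z")
        #     # -> data/01_raw/cars.csv/2020-01-02T03.04.05.678Z/cars.csv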
        def load(self) -> Any:
            self.resolve_load_version()  # Make sure last load version is set
            return super().load()

        def save(self, data: Any) -> None:
            save_version = (
                self.resolve_save_version()
            )  # Make sure last save version is set
            super().save(data)

            load_version = self.resolve_load_version()
            if load_version != save_version:
                warnings.warn(
                    _CONSISTENCY_WARNING.format(save_version, load_version, str(self))
                )

        def exists(self) -> bool:
            """Checks whether a data set's output already exists by calling
            the provided _exists() method.

            Returns:
                Flag indicating whether the output already exists.

            Raises:
                DataSetError: when underlying exists method raises error.
            """
            self._logger.debug("Checking whether target of %s exists", str(self))
            try:
                return self._exists()
            except VersionNotFoundError:
                return False
            except Exception as exc:  # SKIP_IF_NO_SPARK
                message = "Failed during exists check for data set {}.\n{}".format(
                    str(self), str(exc)
                )
                raise DataSetError(message) from exc

        def _release(self) -> None:
            super()._release()
            self.resolve_load_version.cache_clear()
            self.resolve_save_version.cache_clear()
    def _parse_filepath(filepath: str) -> Dict[str, str]:
        """Split filepath on protocol and path.

        Based on `fsspec.utils.infer_storage_options`.

        Args:
            filepath: Either local absolute file path or URL (s3://bucket/file.csv)

        Returns:
            Parsed filepath.
        """
        if (
            re.match(r"^[a-zA-Z]:[\\/]", filepath)
            or re.match(r"^[a-zA-Z0-9]+://", filepath) is None
        ):
            return {"protocol": "file", "path": filepath}

        parsed_path = urlsplit(filepath)
        protocol = parsed_path.scheme or "file"

        if protocol in HTTP_PROTOCOLS:
            return {"protocol": protocol, "path": filepath}

        path = parsed_path.path
        if protocol == "file":
            windows_path = re.match(r"^/([a-zA-Z])[:|]([\\/].*)$", path)
            if windows_path:
                path = "{}:{}".format(*windows_path.groups())

        options = {"protocol": protocol, "path": path}

        if parsed_path.netloc:
            if protocol in CLOUD_PROTOCOLS:
                host_with_port = parsed_path.netloc.rsplit("@", 1)[-1]
                host = host_with_port.rsplit(":", 1)[0]
                options["path"] = host + options["path"]

        return options
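    # A small sketch of how ``_parse_filepath`` splits a few representative
    # inputs (the concrete paths and bucket names are assumed examples):
    #
    #     _parse_filepath("/tmp/cars.csv")
    #     # -> {"protocol": "file", "path": "/tmp/cars.csv"}
    #     _parse_filepath("s3://bucket/cars.csv")
    #     # -> {"protocol": "s3", "path": "bucket/cars.csv"}
    #     _parse_filepath("https://example.com/cars.csv")
    #     # -> {"protocol": "https", "path": "https://example.com/cars.csv"}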
    def get_protocol_and_path(filepath: str, version: Version = None) -> Tuple[str, str]:
        """Parses filepath on protocol and path.

        Args:
            filepath: raw filepath e.g.: `gcs://bucket/test.json`.
            version: instance of ``kedro.io.core.Version`` or None.

        Returns:
            Protocol and path.

        Raises:
            DataSetError: when protocol is http(s) and version is not None.
                Note: HTTP(s) dataset doesn't support versioning.
        """
        options_dict = _parse_filepath(filepath)
        path = options_dict["path"]
        protocol = options_dict["protocol"]

        if protocol in HTTP_PROTOCOLS:
            if version:
                raise DataSetError(
                    "HTTP(s) DataSet doesn't support versioning. "
                    "Please remove version flag from the dataset configuration."
                )
            path = path.split(PROTOCOL_DELIMITER, 1)[-1]

        return protocol, path
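    # An illustrative call, using the example path from the docstring above:
    #
    #     get_protocol_and_path("gcs://bucket/test.json")
    #     # -> ("gcs", "bucket/test.json")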
    def get_filepath_str(path: PurePath, protocol: str) -> str:
        """Returns filepath. Returns full filepath (with protocol) if protocol is HTTP(s).

        Args:
            path: filepath without protocol.
            protocol: protocol.

        Returns:
            Filepath string.
        """
        path = str(path)
        if protocol in HTTP_PROTOCOLS:
            path = "".join((protocol, PROTOCOL_DELIMITER, path))
        return path
    def validate_on_forbidden_chars(**kwargs):
        """Validate that string values do not include white-spaces or ;"""
        for key, value in kwargs.items():
            if " " in value or ";" in value:
                raise DataSetError(
                    "Neither white-space nor semicolon are allowed in `{}`.".format(key)
                )
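    # A brief sketch of the validation behaviour (the keyword name and values
    # are assumed examples): any value containing a space or ";" raises
    # ``DataSetError``; clean values pass silently.
    #
    #     validate_on_forbidden_chars(table_name="my table")  # raises DataSetError
    #     validate_on_forbidden_chars(table_name="my_table")  # returns None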