import copy
from pathlib import Path
from typing import Any, Dict, Union
from PIL import Image
import logging
import numpy as np
from ..core import AbstractVersionedDataSet, DataSetError, Version
from ...ops.numpy_ops import to_channel_first_arr, to_channel_last_arr, ReverseChannel
# Module-level logger; scale() logs the min/max of every array it rescales here.
log = logging.getLogger(__name__)
class ImagesLocalDataSet(AbstractVersionedDataSet):
    """Loads/saves a dict of numpy 3-D or 2-D arrays from/to a folder containing images.

    Works like ``kedro.extras.datasets.pillow.ImageDataSet`` and
    ``kedro.io.PartitionedDataSet`` with conversion between numpy arrays and Pillow images.

    Saving accepts:
        - a single Pillow image or a 2-D/3-D array, written to ``path`` itself;
        - a 4-D array, a list of images/arrays, a dict ``{name: image}``, or a dict
          ``{"images": [...], "names": [...]}``, written as one file per image in
          the folder ``path``. Files are named by dict key / ``names`` entry, or
          by zero-padded index (``00000``, ``00001``, ...) otherwise.
    """

    def __init__(
        self,
        path: str,
        load_args: Dict[str, Any] = None,
        save_args: Dict[str, Any] = None,
        channel_first=False,
        reverse_color=False,
        version: Version = None,
    ) -> None:
        """
        Args:
            path: The folder path containing images (or the path of a single image file).
            load_args: Args fed to
                https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.open
                except for these dataset-level keys, which are consumed here:
                - `dict_structure`: how a folder is returned. ``None`` gives a list of
                  images, ``"sep_names"`` gives ``dict(images=[...], names=[...])``,
                  anything else (default ``True``) gives a dict mapping file stem to image.
                - `as_numpy`: if true (default), images are returned as numpy arrays;
                  otherwise as Pillow images.
            save_args:
                Args, e.g.
                - `suffix`: file suffix such as ".jpg" (default ".jpg")
                - `upper`: optionally used as the upper pixel value corresponding to 0xFF (255)
                  for linear scaling to ensure the pixel value is between 0 and 255.
                - `lower`: optionally used as the lower pixel value corresponding to 0x00 (0)
                  for linear scaling to ensure the pixel value is between 0 and 255.
                - `mode`: fed to
                  https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.fromarray
                - Any other args fed to
                  https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.save
                Defaults to ``{"suffix": ".jpg"}``.
            channel_first: If true, the first dimension of 3-D array is
                treated as channel (color) as in PyTorch.
                If false, the last dimension of the 3-D array is
                treated as channel (color) as in TensorFlow, Pillow, and OpenCV.
            reverse_color: If true, the order of channel (color) is reversed
                (RGB to BGR when loading, BGR to RGB when saving).
                Set true to use packages such as OpenCV which uses BGR order natively.
            version: If specified, should be an instance of
                ``kedro.io.core.Version``. If its ``load`` attribute is
                None, the latest version will be loaded. If its ``save``
                attribute is None, save version will be autogenerated.
        """
        super().__init__(
            filepath=Path(path),
            version=version,
            exists_function=self._exists,
        )
        self._load_args = load_args
        # Fresh dict per instance instead of a shared mutable default argument.
        self._save_args = {"suffix": ".jpg"} if save_args is None else save_args
        self._channel_first = channel_first
        self._reverse_color = reverse_color

    def _load(self) -> Any:
        load_path = Path(self._get_load_path())
        load_args = copy.deepcopy(self._load_args) or dict()
        # Dataset-level options; whatever remains is passed to PIL.Image.open.
        dict_structure = load_args.pop("dict_structure", True)
        as_numpy = load_args.pop("as_numpy", True)

        if not load_path.is_dir():
            return load_image(
                load_path,
                load_args,
                as_numpy=as_numpy,
                channel_first=self._channel_first,
                reverse_color=self._reverse_color,
            )

        images_dict = {}
        # Sorted so list/"sep_names" outputs have a deterministic order;
        # subdirectories cannot be opened as images and are skipped.
        for p in sorted(load_path.glob("*")):
            if not p.is_file():
                continue
            images_dict[p.stem] = load_image(
                p,
                load_args,
                as_numpy=as_numpy,
                channel_first=self._channel_first,
                reverse_color=self._reverse_color,
            )
        if dict_structure is None:
            return list(images_dict.values())
        if dict_structure == "sep_names":
            return dict(
                images=list(images_dict.values()), names=list(images_dict.keys())
            )
        return images_dict

    def _save(self, data: Union[dict, list, np.ndarray, Image.Image]) -> None:
        save_path = Path(self._get_save_path())
        save_path.parent.mkdir(parents=True, exist_ok=True)

        save_args = copy.deepcopy(self._save_args) or dict()
        suffix = save_args.pop("suffix", ".jpg")
        mode = save_args.pop("mode", None)
        upper = save_args.pop("upper", None)
        lower = save_args.pop("lower", None)
        to_scale = (upper is not None) or (lower is not None)
        # scale() returns its input unchanged when neither bound is given.
        scaler = scale(lower=lower, upper=upper)

        images, names = self._split_images_and_names(data)

        if hasattr(images, "save"):
            # A single Pillow image: save it directly unless scaling is requested.
            if to_scale:
                self._write_array(save_path, scaler(np.asarray(images)), mode, save_args)
            else:
                images.save(save_path, **save_args)
            return None

        if isinstance(images, np.ndarray):
            if images.ndim in {2, 3}:
                arr = scaler(self._to_channel_last(images))
                self._write_array(save_path, arr, mode, save_args)
                return None
            if images.ndim != 4:
                raise ValueError(
                    "Unsupported number of dimensions: {}".format(images.ndim)
                )
            # Scale the batch with one shared min/max, then convert each image
            # exactly once (converting the batch and then every image again
            # would transpose and reverse colors twice).
            items = (self._to_channel_last(a) for a in scaler(images))
        elif hasattr(images, "__getitem__") and hasattr(images, "__len__"):
            items = (self._prepare_list_item(img, scaler, to_scale) for img in images)
        else:
            raise ValueError("Unsupported data type: {}".format(type(images)))

        save_path.mkdir(parents=True, exist_ok=True)
        for i, img in enumerate(items):
            name = names[i] if names else "{:05d}".format(i)
            target = save_path / "{}{}".format(name, suffix)
            if hasattr(img, "save"):
                img.save(target, **save_args)
            else:
                self._write_array(target, img, mode, save_args)
        return None

    @staticmethod
    def _split_images_and_names(data):
        """Return ``(images, names)``; ``names`` is None unless ``data`` is a dict."""
        if not isinstance(data, dict):
            return data, None
        if "names" in data and "images" in data:
            return data["images"], data["names"]
        return list(data.values()), list(data.keys())

    def _to_channel_last(self, img: np.ndarray) -> np.ndarray:
        """Undo load_image's layout changes so Pillow gets a channel-last array.

        Colors are reversed while the array is still in the caller's layout
        (so the ``channel_first`` flag given to ``ReverseChannel`` matches the
        data), and only then are channels moved last -- the inverse of
        ``load_image``.
        """
        if self._reverse_color:
            img = ReverseChannel(channel_first=self._channel_first)(img)
        if self._channel_first:
            img = to_channel_last_arr(img)
        return img

    def _prepare_list_item(self, item, scaler, to_scale):
        """Return a Pillow image to save as-is, or a channel-last array to encode."""
        if hasattr(item, "save"):
            # np.asarray of a Pillow image is already channel-last in Pillow's own
            # color order, so only scaling (if requested) is applied to it.
            return scaler(np.asarray(item)) if to_scale else item
        return scaler(self._to_channel_last(np.asarray(item)))

    @staticmethod
    def _write_array(path, arr, mode, save_args):
        """Encode ``arr`` (singleton axes squeezed) with Pillow and save it to ``path``."""
        Image.fromarray(np.squeeze(arr), mode=mode).save(path, **save_args)

    def _describe(self) -> Dict[str, Any]:
        return dict(
            filepath=self._filepath,
            load_args=self._load_args,
            save_args=self._save_args,
            channel_first=self._channel_first,
            reverse_color=self._reverse_color,
            version=self._version,
        )

    def _exists(self) -> bool:
        try:
            path = self._get_load_path()
        except DataSetError:
            return False
        return Path(path).exists()
def load_image(
    load_path, load_args, as_numpy=False, channel_first=False, reverse_color=False
):
    """Load one image file with Pillow, optionally converting it to a numpy array.

    Args:
        load_path: Path (or path string) of the image file.
        load_args: Keyword args fed to ``PIL.Image.open``; ``None`` means no args.
        as_numpy: If true, return ``np.asarray`` of the image instead of the
            Pillow image.
        channel_first: Only used when ``as_numpy`` is true: move the channel
            axis to the front with ``to_channel_first_arr``.
        reverse_color: Only used when ``as_numpy`` is true: reverse the channel
            order with ``ReverseChannel`` (e.g. RGB to BGR).

    Returns:
        A fully loaded Pillow image, or a numpy array when ``as_numpy`` is true.
    """
    with Path(load_path).open("rb") as local_file:
        img = Image.open(local_file, **(load_args or {}))
        # Image.open is lazy; read pixel data now, while the file is still open,
        # so the returned image remains usable after this block closes the file.
        img.load()
    if not as_numpy:
        return img
    arr = np.asarray(img)
    if channel_first:
        arr = to_channel_first_arr(arr)
    if reverse_color:
        arr = ReverseChannel(channel_first=channel_first)(arr)
    return arr
def scale(**kwargs):
    """Build a function that linearly rescales an array into ``[lower, upper]`` as uint8.

    Keyword Args:
        lower: Output value for the array minimum; defaults to the array minimum.
        upper: Output value for the array maximum; defaults to the array maximum.

    Returns:
        A callable ``f(a)``. When neither ``lower`` nor ``upper`` was given, ``f``
        returns ``a`` unchanged. Otherwise it maps ``a.min()`` to ``lower`` and
        ``a.max()`` to ``upper`` and casts to ``np.uint8``. A constant array
        (max == min) maps entirely to ``lower``.
    """
    lower_arg = kwargs.get("lower")
    upper_arg = kwargs.get("upper")

    def _scale(a):
        if lower_arg is None and upper_arg is None:
            return a
        max_val = a.max()
        min_val = a.min()
        stat_dict = dict(max_val=max_val, min_val=min_val)
        log.info(stat_dict)
        # Explicit None checks: an explicit 0 bound must not fall back to the data range.
        # Floats avoid unsigned-integer wraparound when the data is e.g. uint8.
        upper = float(upper_arg if upper_arg is not None else max_val)
        lower = float(lower_arg if lower_arg is not None else min_val)
        span = float(max_val) - float(min_val)
        if span == 0:
            # Constant input: avoid 0/0 (NaN) and map every pixel to ``lower``.
            return np.full(np.shape(a), lower).astype(np.uint8)
        return (((a - min_val) / span) * (upper - lower) + lower).astype(np.uint8)

    return _scale
class Np3DArrDataset:
    """Map-style dataset exposing the samples along the first axis of an array."""

    def __init__(self, a):
        # The wrapped array; kept public as ``a`` for callers.
        self.a = a

    def __len__(self):
        return len(self.a)

    def __getitem__(self, idx):
        # The trailing Ellipsis keeps every remaining axis of the selected sample.
        sample = self.a[idx, ...]
        return sample
class Np3DArrDatasetFromList:
    """Map-style dataset over a sequence whose items are converted to arrays on access."""

    def __init__(self, a, transform=None):
        # Underlying sequence and an optional per-item callable.
        self.a = a
        self.transform = transform

    def __len__(self):
        return len(self.a)

    def __getitem__(self, idx):
        arr = np.asarray(self.a[idx])
        return self.transform(arr) if self.transform else arr