"""Source code for pipelinex.extras.ops.shap_ops."""

import numpy as np
import torch
from torch.utils.data import DataLoader

from .numpy_ops import to_channel_last_arr


class ExplainModel:
    """Configurable callable that explains a PyTorch model with SHAP.

    Keyword options (read lazily in ``__call__``):
        train_data_loader_params: dict passed to the training ``DataLoader``
            (default ``dict(batch_size=100)``).
        val_data_loader_params: dict passed to the validation ``DataLoader``
            (default ``dict(batch_size=3)``).
        train_size / val_size: convenience overrides; when truthy they win
            over the respective loader params' ``batch_size``.
        output_transform: optional module appended after ``model`` via
            ``torch.nn.Sequential`` before explaining.
    """

    def __init__(self, **kwargs):
        # Store options; they are interpreted per call in __call__.
        self.kwargs = kwargs

    def __call__(self, model, train_dataset, val_dataset):
        """Draw one batch from each dataset and run SHAP explanation.

        Returns whatever ``_explain_pytorch_model`` returns (an images dict,
        or an error tuple when shap is unavailable).
        """
        opts = self.kwargs

        # Copy before mutating: .get() would otherwise hand back the dict
        # stored in self.kwargs (caller-owned), and assigning batch_size
        # below would silently rewrite the caller's configuration.
        train_loader_params = dict(
            opts.get("train_data_loader_params", dict(batch_size=100))
        )
        val_loader_params = dict(
            opts.get("val_data_loader_params", dict(batch_size=3))
        )

        train_size = opts.get("train_size")
        if train_size:
            train_loader_params["batch_size"] = train_size
        val_size = opts.get("val_size")
        if val_size:
            val_loader_params["batch_size"] = val_size

        # Only the first batch of each loader is used for the explanation.
        train_batch = next(iter(DataLoader(train_dataset, **train_loader_params)))
        val_batch = next(iter(DataLoader(val_dataset, **val_loader_params)))

        output_transform = opts.get("output_transform")
        if output_transform is not None:
            model = torch.nn.Sequential(model, output_transform)

        return _explain_pytorch_model(model, train_batch, val_batch)
def _explain_pytorch_model(
    model,  # type: torch.nn.Module
    train_batch,
    val_batch,
):
    """Explain `model` with shap.DeepExplainer and collect result images.

    Returns a dict with keys ``images`` (a single stacked array of SHAP
    images followed by the validation images, both scaled to 0-255 uint8)
    and ``names`` (matching per-image name strings). If shap cannot be
    imported, returns the tuple ``("_", "Failed to import shap.")`` instead.

    NOTE(review): assumes image batches are channel-first (NCHW) tensors —
    `to_channel_last_arr` is applied before plotting; confirm with callers.
    """
    try:
        import shap
    except ImportError:
        # Best-effort degradation: signal failure via the return value
        # instead of raising, so pipelines can continue.
        print("Failed to import shap.")
        return "_", "Failed to import shap."
    # Batches may be (images, labels) pairs or bare image tensors.
    if isinstance(train_batch, (tuple, list)):
        train_images_tt, _ = train_batch
    else:
        train_images_tt = train_batch
    if isinstance(val_batch, (tuple, list)):
        val_images_tt, val_labels = val_batch
    else:
        val_images_tt = val_batch
        # No labels available: use empty strings so name formatting works.
        val_labels = [""] * val_images_tt.shape[0]
    e = shap.DeepExplainer(model, train_images_tt)
    # One SHAP array per output class, each shaped like the input batch.
    shap_nchw_arr_list = e.shap_values(val_images_tt)
    shap_nhwc_arr_list = [to_channel_last_arr(s) for s in shap_nchw_arr_list]
    val_images_nhwc_arr = to_channel_last_arr(val_images_tt.numpy())
    # Negated images follow shap's image_plot convention for display.
    shap.image_plot(shap_nhwc_arr_list, -val_images_nhwc_arr, show=True)
    shap_nhwc_arr = np.concatenate(shap_nhwc_arr_list, axis=0)
    # Names encode true label (y), sample index (x), and predicted class (p);
    # ordering matches the class-major concatenation above.
    shap_image_names = [
        "y_{}_x_{}_p_{}".format(val_labels[samp_i], samp_i, class_i)
        for class_i in range(len(shap_nchw_arr_list))
        for samp_i in range(len(val_images_tt))
    ]
    val_image_names = [
        "y_{}_x_{}".format(val_labels[samp_i], samp_i)
        for samp_i in range(len(val_images_tt))
    ]
    # Normalize both stacks to displayable 0-255 uint8 (Scale defined below).
    shap_nhwc_arr = Scale(lower=0, upper=255)(shap_nhwc_arr)
    val_images_nhwc_arr = Scale(lower=0, upper=255)(val_images_nhwc_arr)
    all_images_nhwc_arr = np.concatenate([shap_nhwc_arr, val_images_nhwc_arr], axis=0)
    all_names = shap_image_names + val_image_names
    assert all_images_nhwc_arr.shape[0] == len(all_names)
    images_dict = dict(images=all_images_nhwc_arr, names=all_names)
    return images_dict
class Scale:
    """Linearly rescale an array into ``[lower, upper]`` and cast to uint8.

    Keyword options:
        lower: target minimum (defaults to the array's own min).
        upper: target maximum (defaults to the array's own max).

    If neither bound is given, the input is returned unchanged.
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def __call__(self, a):
        """Return `a` rescaled to the configured bounds as uint8.

        BUG FIX: the bounds are now checked with ``is None`` instead of
        ``or`` — ``lower=0`` is falsy, so ``lower or min_val`` silently
        replaced an explicit 0 bound with the data minimum (and this file
        itself uses ``Scale(lower=0, upper=255)``).
        """
        lower = self.kwargs.get("lower")
        upper = self.kwargs.get("upper")
        if lower is None and upper is None:
            return a

        min_val = a.min()
        max_val = a.max()
        if upper is None:
            upper = max_val
        if lower is None:
            lower = min_val

        span = max_val - min_val
        if span == 0:
            # Constant input: every element maps to the lower bound rather
            # than producing a 0/0 division.
            return np.full_like(a, lower).astype(np.uint8)

        a = (((a - min_val) / span) * (upper - lower) + lower).astype(np.uint8)
        return a