Esempio n. 1
0
    def on_batch_end(self, losses: dict, imgs_list=None):
        """Record this batch's losses, advance the progress bar and, every 5th
        batch, log an (input, ground truth, prediction) image grid.

        Args:
            losses: mapping of loss name -> value. List-valued entries only
                register the name in ``self.losses``; they are neither stored
                in the history nor shown on the progress bar.
            imgs_list: optional ``[inputs, ground_truths, predictions]``
                tensors with the batch on dim 0 (presumably (B, C, H, W)
                inputs -- TODO confirm).
        """
        bar_values = []
        for loss_name, loss in losses.items():
            history = self.losses.setdefault(loss_name, [])
            if not isinstance(loss, list):
                history.append(loss)
                bar_values.append((loss_name, loss))
        self.bar.update(self.batch_num, values=bar_values)

        if imgs_list is not None and self.batch_num % 5 == 0:
            ori_imgs, gts, preds = imgs_list[0], imgs_list[1], imgs_list[2]
            # NOTE(review): confirm UnNormalize handles a whole batch tensor.
            un_norm = UnNormalize(self.args.mean, self.args.std)
            ori_imgs = un_norm(ori_imgs).float()

            def _like_inputs(t):
                """Bring a label/prediction batch to the inputs' layout so the
                three groups can be concatenated on dim 0."""
                t = t.float()
                if ori_imgs.dim() != 4:
                    return t
                if t.dim() == 3:  # (B, H, W) -> (B, 1, H, W)
                    t = t.unsqueeze(1)
                if t.size(1) == 1 and ori_imgs.size(1) != 1:
                    # Repeat a single channel to match multi-channel inputs;
                    # otherwise torch.cat fails on the channel dimension.
                    t = t.expand(-1, ori_imgs.size(1), -1, -1)
                return t

            imgs = torch.cat([ori_imgs, _like_inputs(gts), _like_inputs(preds)],
                             dim=0)
            # One grid row per group: use the real batch length, which is
            # smaller than self.batch_size on a final partial batch.
            nrow = ori_imgs.size(0) if ori_imgs.dim() == 4 else self.batch_size
            imgs = make_grid(imgs, nrow=nrow)
            self.writer.add_image('segmentation',
                                  imgs,
                                  global_step=self.batch_num)

        self.batch_num += 1
Esempio n. 2
0
    def update_board(self, loss, dice, iou, images, masks):
        """Log epoch metrics and the un-normalized input batch to TensorBoard.

        Tags are prefixed with ``Train/`` when ``self.phase == 'train'`` and
        ``Val/`` otherwise; everything is logged at ``global_step=self.epoch``.

        Args:
            loss, dice, iou: scalar metric values for this epoch.
            images: batch of normalized images indexed on dim 0, each in CHW
                layout.
            masks: currently unused.
        """
        import torch

        prefix = 'Train/' if self.phase == 'train' else 'Val/'
        for metric_name, value in (("Loss", loss), ("Dice", dice), ("Iou", iou)):
            self.writer.add_scalar(tag=prefix + metric_name,
                                   scalar_value=value,
                                   global_step=self.epoch)

        if images.shape[0] > 0:
            unnorm = UnNormalize()
            # clone(): UnNormalize may modify its input in place.
            restored = torch.stack([unnorm(img.clone()) for img in images])
            # Repeated add_image() calls with one tag and one step collide in
            # TensorBoard (one image per step is shown), so log the whole
            # batch at once under the same tag.
            self.writer.add_images(tag=prefix + "Img",
                                   img_tensor=restored,
                                   global_step=self.epoch,
                                   dataformats='NCHW')
        self.writer.flush()
Esempio n. 3
0
def explain_pair(alg, pair: List[torch.tensor], labels: List[int], **kwargs):
    """
    Use a Captum explanation algorithm on a pair of images and plot side-by-side.

    Parameters:
        alg: Captum algorithm, e.g. Saliency()
        pair: list of 2 images as torch tensors (CHW, normalized with
            ``img_means`` / ``img_stds``)
        labels: the labels for each image, used as explanation targets
        **kwargs: additional arguments for Captum algorithm

    Returns:
        ``(fig, (ax1, ax2))``. The figure is not shown here; the caller
        decides whether to display or save it.
    """

    def _prepare_explainer_input(img):
        # Add a batch dimension, track gradients and move to the GPU.
        batch = img.unsqueeze(0)
        batch.requires_grad = True
        return batch.cuda()

    inputs = [_prepare_explainer_input(img) for img in pair]
    grads = [explainer(alg, inp, lab, **kwargs) for inp, lab in zip(inputs, labels)]

    unorm = UnNormalize(img_means, img_stds)
    # The visualizer expects HWC NumPy images.
    org_images = [
        np.transpose(unorm(img).cpu().detach().numpy(), (1, 2, 0))
        for img in pair
    ]

    fig, (ax1, ax2) = plt.subplots(1, 2)
    panels = ((ax1, "Predicted"), (ax2, "Nearest neighbor"))
    for grad, org_img, (ax, title) in zip(grads, org_images, panels):
        viz.visualize_image_attr(
            grad,
            org_img,
            method="blended_heat_map",
            sign="absolute_value",
            show_colorbar=True,
            title=title,
            plt_fig_axis=(fig, ax),
            # use_pyplot=False for BOTH panels: otherwise viz calls
            # plt.show() and pops the figure up as a side effect.
            use_pyplot=False,
        )

    return fig, (ax1, ax2)
Esempio n. 4
0
    def __init__(self, args):
        # Keep the run configuration; most settings below are read from it.
        self.args = args

        # --- backbone -----------------------------------------------------
        # Network modules are imported lazily, only for the selected net.
        net_kwargs = dict(num_classes=self.args.num_class)
        if args.net == "wideresnet282":
            from nets.wideresnet import WRN28_2
            model = WRN28_2(**net_kwargs)
        elif args.net == "resnet18":
            from nets.resnet import ResNet18
            model = ResNet18(pretrained=args.load == 'imagenet', **net_kwargs)
        elif args.net == "resnet50":
            from nets.resnet import ResNet50
            model = ResNet50(pretrained=args.load == 'imagenet', **net_kwargs)
        else:
            raise NotImplementedError

        n_trainable = sum(param.numel() for param in model.parameters()
                          if param.requires_grad)
        print("Number of parameters", n_trainable)

        # --- optimisation -------------------------------------------------
        self.model = nn.DataParallel(model).cuda()
        self.optimizer = torch.optim.SGD(
            self.model.parameters(),
            lr=self.args.lr,
            momentum=0.9,
            nesterov=True,
            weight_decay=5e-4,
        )
        # Targets equal to -1 are ignored by the mean-reduced criterion; the
        # "nored" variant keeps per-element losses (reduction="none").
        self.criterion = nn.CrossEntropyLoss(ignore_index=-1)
        self.criterion_nored = nn.CrossEntropyLoss(reduction="none")

        # --- data ---------------------------------------------------------
        self.kwargs = dict(num_workers=12, pin_memory=False)
        self.train_loader, self.val_loader = make_data_loader(args, **self.kwargs)

        # --- bookkeeping --------------------------------------------------
        self.best = 0
        self.best_epoch = 0
        # History lists, presumably appended to during training -- TODO confirm.
        self.acc, self.train_acc = [], []
        self.med_clean, self.med_noisy = [], []
        self.perc_clean, self.perc_noisy = [], []

        # --- visualisation helpers ----------------------------------------
        self.reductor_plot = umap.UMAP(n_components=2)
        self.toPIL = torchvision.transforms.ToPILImage()
        self.unorm = UnNormalize(mean=(0.485, 0.456, 0.406),
                                 std=(0.229, 0.224, 0.225))
Esempio n. 5
0
def show_pair(pair, labels, pred, savename=None):
    """Plot a pair of images side by side.

    Args:
        pair: two normalized CHW image tensors (the query image and its
            nearest neighbour); they may live on the GPU or require grad.
        labels: class indices (into ``classes``) for the two images.
        pred: predicted class index for the first image.
        savename: if given, the figure is saved to ``savename + ".png"``.

    Returns:
        ``(fig, (ax1, ax2))``.
    """
    unorm = UnNormalize(img_means, img_stds)
    # .cpu().detach() so GPU tensors and tensors requiring grad can be
    # converted to NumPy.
    images = [unorm(img).cpu().detach().numpy() for img in pair]
    text_labels = [classes[label] for label in labels]

    fig, (ax1, ax2) = plt.subplots(1, 2)
    ax1.imshow(np.transpose(images[0], (1, 2, 0)))
    ax1.set_title(f"true: {text_labels[0]}, pred: {classes[pred]}")
    remove_axis_ticks(ax1)

    ax2.imshow(np.transpose(images[1], (1, 2, 0)))
    ax2.set_title(f"nn: {text_labels[1]}")
    remove_axis_ticks(ax2)

    if savename is not None:
        # Save this figure explicitly rather than pyplot's "current" figure.
        fig.savefig(savename + ".png")

    return fig, (ax1, ax2)
Esempio n. 6
0
def show_mask_image(imgs,
                    mean=(0.485, 0.456, 0.406),
                    std=(0.229, 0.224, 0.225),
                    unnormalize=True,
                    masks=None):
    """Display every image of a batch with its defect masks painted over it.

    Mask is 4 dimension because we have mask for 4 different classes: each
    mask is channel-first with one channel per defect class, and the title
    lists the classes (1-4) whose channel is non-empty.

    Args:
        imgs: batch of CHW image tensors, indexed on dim 0.
        mean, std: per-channel statistics passed to ``UnNormalize``.
        unnormalize: undo the normalization before plotting.
        masks: a single ``np.ndarray`` (C, H, W) used for every image, or a
            batch of mask tensors indexed like ``imgs``. When None, the images
            are shown without overlays.
    """
    # (channel, value) pairs written into the image where a class mask is 1.
    # NOTE(review): the colour names come from the original comments, but
    # classes 2 and 3 only touch channel 0, so the visible hue depends on the
    # channel order and value range of `img` -- TODO confirm intended colours.
    overlays = (
        ((0, 235), (1, 235)),  # yellow
        ((0, 210),),           # green
        ((0, 255),),           # blue
        ((0, 255), (2, 255)),  # magenta
    )
    unnorm = UnNormalize(mean=mean, std=std) if unnormalize else None

    for index in range(imgs.shape[0]):
        # clone(): the overlay is painted into a NumPy view of this tensor
        # (and UnNormalize may work in place); without it the caller's
        # `imgs` would be modified.
        img = imgs[index].clone()
        if unnorm is not None:
            # reverse normalize it
            img = unnorm(img)
        img = img.numpy().transpose((1, 2, 0))

        extra = 'Has defect type:'
        if masks is not None:
            if isinstance(masks, np.ndarray):
                mask = masks.transpose((1, 2, 0))
            else:
                mask = masks[index].numpy().transpose((1, 2, 0))
            for j, channel_values in enumerate(overlays):
                msk = mask[:, :, j]
                if np.sum(msk) != 0:
                    extra += ' ' + str(j + 1)
                for channel, value in channel_values:
                    img[msk == 1, channel] = value

        plt.subplots(figsize=(15, 50))
        plt.axis('off')
        plt.title(extra)
        plt.imshow(img)
        plt.show()
Esempio n. 7
0
def visualise_mask(imgs, masks, unnormalize=True):
    """Open each image and draw clear mask contours, so we don't lose sight
    of the interesting features hiding underneath.

    One figure is created per (image, defect class) whose mask channel has a
    positive value; figures are not shown here (no ``plt.show()``).

    Args:
        imgs: batch of CHW image tensors, indexed on dim 0.
        masks: batch of channel-first mask tensors, one channel per defect
            class (the original comment mentions (256, 1600, 4) after
            transposing -- TODO confirm).
        unnormalize: undo normalization (``UnNormalize`` defaults) first.
    """
    unnorm = UnNormalize() if unnormalize else None

    for index in range(imgs.shape[0]):
        # clone(): keep the caller's tensor untouched (UnNormalize or the
        # contour drawing may write into the array in place).
        img = imgs[index].clone()
        if unnorm is not None:
            img = unnorm(img)
        img = img.numpy().transpose((1, 2, 0))
        mask = masks[index].numpy().transpose((1, 2, 0))

        for i in range(mask.shape[-1]):
            layer = mask[:, :, i]
            if not np.amax(layer) > 0.0:
                continue
            # indices are [0, 1, 2, 3], corresponding classes are [1, 2, 3, 4]
            label = i + 1
            # add the contours, layer per layer
            image = mask_to_contours(img, layer, color=palette[label])

            plt.figure(figsize=(16, 30))
            # Report the 1-based defect class, not the channel index.
            plt.title("In image {} for defect {}".format(index, label))
            plt.imshow(image.get())
Esempio n. 8
0
# Attack / evaluation settings.
momentum = 0.9
# Ten counters, presumably one per CIFAR-10 class -- TODO confirm.
count = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

batch_size = 1
attack_method = "ITER"
dataset = "CIFAR10"
shuffle = False
save_pics = False

# Set CUDA
# NOTE(review): `use_cuda` is not read in the visible code; `device` is
# chosen from torch.cuda.is_available() regardless.
use_cuda = True
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model_url = './trained_models/VGG19.pth'

# Uses the same mean/std as the Normalize() below (presumably its inverse,
# for visualisation) -- keep the two in sync.
unnorm = UnNormalize(mean=(0.4914, 0.4822, 0.4465),
                     std=(0.2023, 0.1994, 0.2010))
# Transform
# NOTE(review): named "train" but used for the test split below; it has no
# augmentation (ToTensor + Normalize only).
transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Dataloader
if dataset == 'CIFAR10':
    # Test unseen
    # NOTE(review): the DataLoader call below is truncated in this view;
    # batch_size is hard-coded to 1 instead of using `batch_size` above.
    test_set = torchvision.datasets.CIFAR10(root='./data',
                                            train=False,
                                            download=False,
                                            transform=transform_train)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=1,
        parameters['model_name'] = model_name
        # Take the base estimator of the first calibrated classifier.
        calibrated_pipeline = _load_model(**parameters)
        model = calibrated_pipeline.calibrated_classifiers_[0].base_estimator

        # Presumably applies the pipeline's preprocessing steps -- TODO confirm.
        examples_transformed, target_values_transformed = just_transforms(
            model, examples, target_values)
        # Explain only the final pipeline step on the transformed examples.
        myInterpreter = InterpretToolkit(model=[model.steps[-1][1]],
                                         model_names=[model_name],
                                         examples=examples_transformed,
                                         targets=target_values_transformed,
                                         feature_names=feature_names)

        # XGBoost gets a single job; other models one job per important feature.
        njobs = 1 if model_name == 'XGBoost' else len(important_vars)

        # NOTE(review): prefer `is not None`; also `unnormalize_func` is not
        # passed to calc_ale() in the visible code.
        if normalize_method != None:
            unnormalize_func = UnNormalize(model.steps[1][1], feature_names)
        else:
            unnormalize_func = None

        result_dict = myInterpreter.calc_ale(
            features=important_vars,
            nbootstrap=100,
            subsample=1.0,
            njobs=njobs,
            nbins=30,
        )
        ale_results.append(result_dict)

    # `myInterpreter` is the one created in the last loop iteration; its
    # model names are replaced by the full model_set.
    myInterpreter.model_names = model_set

    ale_results = merge_nested_dict(ale_results)
    def __getitem__(self, index):
        """Load sample ``index``: the image masked by its segmentation mask,
        plus metadata and/or target when configured.

        The image and its mask go through ``self.transforms`` in two separate
        calls. ``random`` and ``torch`` are re-seeded with the same seed
        before each call so that seed-driven random augmentations line up.

        Returns:
            ``data``, which is the masked CHW float32 image tensor, or
            ``(image, meta)`` when ``self.meta_features`` is set. When
            ``self.train`` is true, returns ``(data, y)`` with ``y`` the
            float32 target.

        Raises:
            FileNotFoundError: if the image or mask cannot be read
                (``cv2.imread`` returns None instead of raising).
        """
        norm_mean = [0.485, 0.456, 0.406]
        norm_std = [0.229, 0.224, 0.225]

        row = self.df.iloc[index]
        mask_name = row['image_name'] + ".png"
        im_path = os.path.join(self.imfolder, row['image_name'] + ".jpg")
        x = cv2.imread(im_path)
        if x is None:
            raise FileNotFoundError(f"Could not read image: {im_path}")
        # OpenCV decodes to BGR; swap to RGB (the same swap as COLOR_RGB2BGR).
        x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)

        mask_path = train_mask_path if self.type == "train" else test_mask_path

        seed = np.random.randint(2147483647)  # make a seed with numpy generator
        random.seed(seed)  # apply this seed to img transforms
        torch.manual_seed(seed)  # needed for torchvision 0.7
        # Same condition as for the mask below, so image and mask are
        # always transformed together.
        if self.transforms is not None:
            # for albumentations transforms
            x = self.transforms(image=x)['image'].astype(np.float32)
        x = albumentations.Normalize(mean=norm_mean, std=norm_std)(image=x)['image']
        # Normalized, not-yet-masked image; only needed for the debug preview.
        normalized_x = x

        mask_file = mask_path + mask_name
        mask = cv2.imread(mask_file)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_file}")
        original_mask = mask
        random.seed(seed)  # apply this seed to target transforms
        torch.manual_seed(seed)  # needed for torchvision 0.7
        if self.transforms is not None:
            mask = self.transforms(image=mask)['image'].astype(np.float32)
        # mask should be np.uint8 type or will be an error
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY).astype(np.uint8)
        _, mask = cv2.threshold(mask, 10, 255, cv2.THRESH_BINARY)

        x = cv2.bitwise_and(x, x, mask=mask)
        # albumentations works in HWC; convert to CHW.
        x = x.transpose(2, 0, 1)

        self.count += 1
        if self.count <= 2:
            # Debug previews (visdom) for the first two samples only.
            original_image = UnNormalize(mean=norm_mean, std=norm_std)(
                torch.tensor(normalized_x.transpose(2, 0, 1), dtype=torch.float32))
            viz.image(original_image)
            viz2.image(mask, env="mask")
            viz3.image(x, env="after mask")
            viz4.image(original_mask, env="original mask")

        image_tensor = torch.tensor(x, dtype=torch.float32)
        if self.meta_features:
            # Go through a float32 array so object-dtype rows convert too.
            meta = np.asarray(row[self.meta_features].values, dtype=np.float32)
            data = (image_tensor, torch.tensor(meta, dtype=torch.float32))
        else:
            data = image_tensor

        if self.train:
            y = torch.tensor(row['target'], dtype=torch.float32)
            return data, y
        return data
Esempio n. 11
0
def plot_original_and_explained_pair(
    pair: List[torch.tensor],
    labels: List[int],
    alg,
    pred: int,
    savename: str = None,
    method: str = "blended_heat_map",
    sign: str = "absolute_value",
):
    """
    Plot a 2x3 grid of images.

    The columns are:
    - column 0: the original images;
    - column 1: explanations w.r.t. the predicted class;
    - column 2: explanations w.r.t. each image's true label.

    Row 0 is the (wrongly) predicted image and row 1 its nearest neighbor.

    Args:
        pair (List[torch.tensor]): List of 2 images as torch tensors
        labels (List[int]): the true labels for the images
        alg ([type]): a Captum algorithm
        pred (int): the prediction for the original image
        savename (str, optional): If given, saves the image to disk. Defaults to None.
        method (str, optional): which visualization method to use (heat_map, blended_heat_map, original_image, masked_image, alpha_scaling)
        sign (str, optional): sign of attributions to visualize (positive, absolute_value, negative, all)

    Returns:
        ``(fig, axes)``. The figure is closed in pyplot before returning, so
        it is not shown by a later ``plt.show()``.
    """

    def _prepare_explainer_input(img):
        # Add a batch dimension, track gradients and move to the GPU.
        batch = img.unsqueeze(0)
        batch.requires_grad = True
        return batch.cuda()

    inputs = [_prepare_explainer_input(img) for img in pair]
    # Explaining the target (each image's own true label)
    grads_target = [explainer(alg, inp, lab) for inp, lab in zip(inputs, labels)]
    # Explaining the actual prediction (the same class for both images)
    grads_pred = [explainer(alg, inp, pred) for inp in inputs]

    unorm = UnNormalize(img_means, img_stds)
    org_images = [
        np.transpose(unorm(img).cpu().detach().numpy(), (1, 2, 0))
        for img in pair
    ]

    text_labels = [classes[label] for label in labels]  # get text label

    fig, axes = plt.subplots(2, 3)

    def _draw(grad, org_img, ax, title, vis_method, **extra):
        # use_pyplot=False on every panel: with True, viz calls plt.show()
        # before the remaining panels are drawn.
        viz.visualize_image_attr(
            grad,
            org_img,
            method=vis_method,
            title=title,
            plt_fig_axis=(fig, ax),
            use_pyplot=False,
            **extra,
        )

    ### Plot original images: wrongly predicted, then nearest neighbor
    _draw(grads_target[0], org_images[0], axes[0, 0],
          f"true: {text_labels[0]}, pred: {classes[pred]}", "original_image")
    _draw(grads_target[1], org_images[1], axes[1, 0],
          f"nn: {text_labels[1]}", "original_image")

    heat = dict(sign=sign, show_colorbar=True)
    ### Gradient explanations for predicted
    _draw(grads_pred[0], org_images[0], axes[0, 1],
          f"Exp. wrt. {classes[pred]}", method, **heat)
    _draw(grads_pred[1], org_images[1], axes[1, 1], "", method, **heat)
    ### Gradient explanations for target
    _draw(grads_target[0], org_images[0], axes[0, 2],
          f"Exp. wrt. {text_labels[0]}", method, **heat)
    _draw(grads_target[1], org_images[1], axes[1, 2], "", method, **heat)

    if savename is not None:
        fig.savefig(savename + ".png")

    plt.close(fig)
    return fig, axes
Esempio n. 12
0
    writer.add_scalar('data/test_d_loss', test_d_loss, epoch)

    # NOTE(review): both summaries print the same "Average loss" label. The
    # first reports test_r_loss and the second test_d_loss, and both reuse
    # the same `correct` count. Consider distinct labels.
    print(
        '\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_r_loss, correct, len(test_loader.dataset),
            100. * correct / len(test_loader.dataset)))
    print(
        '\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_d_loss, correct, len(test_loader.dataset),
            100. * correct / len(test_loader.dataset)))


# Input normalization (these are the usual ImageNet mean/std values).
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# Same mean/std as `normalize`, presumably to invert it for display.
unnormalize = UnNormalize(mean=[0.485, 0.456, 0.406],
                          std=[0.229, 0.224, 0.225])

# Training images: resize, random horizontal flip, tensor, normalize.
trainset = ImageFolder(
    'data/gender_images/train',
    transforms.Compose([
        transforms.Resize(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]))

# Test images: the visible steps have no random flip.
# NOTE(review): this definition is truncated in this view.
testset = ImageFolder(
    'data/gender_images/test',
    transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
Esempio n. 13
0
def run(device):
    """Train a small MNIST classifier for one epoch, checkpoint it, then
    evaluate it on the test set while displaying each test batch with OpenCV.

    Args:
        device: torch device (or device string) for the model and data.

    Side effects:
        Downloads MNIST into ./data for the training split if needed, writes
        checkpoints under checkpoint/mnist and opens cv2 windows. cv2.waitKey()
        blocks for a key press after every test batch.
    """
    mnist_mean, mnist_std = (0.1307,), (0.3081,)
    mnist_transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mnist_mean, mnist_std),
    ])

    # prepare training and testing data
    training_dataset = torchvision.datasets.MNIST('./data', train=True, download=True,
                                                  transform=mnist_transform)
    testing_dataset = torchvision.datasets.MNIST('./data', train=False,
                                                 transform=mnist_transform)
    train_loader = torch.utils.data.DataLoader(training_dataset, batch_size=128, shuffle=True)
    test_loader = torch.utils.data.DataLoader(testing_dataset, batch_size=3)

    # create checkpoint dir
    save_checkpoint = True
    save_dir = "checkpoint/mnist"
    os.makedirs(save_dir, exist_ok=True)

    # model and optimizer
    model = Net().to(device)
    model.train()
    optimizer = optim.Adam(model.parameters())

    # training
    for epoch in range(1):
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            logits = model(data)
            loss = F.cross_entropy(logits, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % 5 == 0:
                print('\rTrain Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                           100. * batch_idx / len(train_loader), loss.item()), end='')

        if save_checkpoint:
            torch.save({"model": model.state_dict(), "optimizer": optimizer.state_dict()},
                       os.path.join(save_dir, "epoch%d.pth" % epoch))

    print()

    # testing
    print("Testing...")
    correct = 0
    model.eval()
    un = UnNormalize(mnist_mean, mnist_std)

    # Evaluation needs no autograd bookkeeping.
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            logits = model(data)

            pred = logits.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

            # visualization
            pred = pred.view(-1).cpu().numpy()
            gt = target.cpu().numpy()
            print("The prediction result: ", end="")
            print(pred, end="")
            print(". Groundtruth:", end="")
            print(gt)

            for i, image in enumerate(data.cpu()):
                # Undo Normalize() before display: cv2.imshow maps float
                # images from [0, 1] to [0, 255], so normalized values would
                # render wrongly. clone(): UnNormalize may work in place.
                grey_layer = un(image.clone())[0].numpy()
                cv2.imshow("%d" % i, np.stack([grey_layer] * 3, axis=2))
            cv2.waitKey()

    accuracy = 100. * correct / len(test_loader.dataset)
    print('Accuracy: {}/{} ({:.5f}%)'.format(correct, len(test_loader.dataset), accuracy))
    print("Done!")
Esempio n. 14
0
    plt.title("Intermediate Experts")
    # Expert-branch outputs as one image grid; make_grid(normalize=True)
    # rescales to [0, 1] and the transpose turns CHW into HWC for imshow.
    plt.imshow(np.transpose(vutils.make_grid(res_experts.detach().to(device), padding=2, normalize=True).cpu(),(1,2,0)))
    
    #theRest = cv2.cvtColor(t_res.transpose((1,2,0)).astype(np.uint8 ),cv2.COLOR_RGB2BGR)
    plt.figure()
    plt.title("Intermediate Direct")
    # Same grid view for the general (direct) branch outputs.
    plt.imshow(np.transpose(vutils.make_grid(res_general.detach().to(device), padding=2, normalize=True).cpu(),(1,2,0)))
    
    
    plt.figure()
    plt.title("Original")
    # Input images `li` for comparison.
    plt.imshow(np.transpose(vutils.make_grid(li.to(device), padding=2, normalize=True).cpu(),(1,2,0)))
    plt.show()

# Now landmarking: turn each restored image back into an 8-bit BGR image and
# run the landmark detector `f` on it.
unorm = UnNormalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))

# Choose which intermediate output to landmark.
resTemp = res_general.cpu() if useGeneral else res_experts.cpu()

for img_ori, img_de in zip(li.cpu(), resTemp):
    # clone(): UnNormalize may modify its input in place.
    img_dex = unorm(img_de.clone()).numpy() * 255
    # Clip before the uint8 cast: out-of-range floats would otherwise wrap
    # around (e.g. 256 -> 0) or convert in a platform-dependent way.
    img_dex = np.clip(img_dex, 0, 255).transpose((1, 2, 0))
    img_dex = cv2.cvtColor(img_dex.astype(np.uint8), cv2.COLOR_RGB2BGR)
    cvImageDe.append(img_dex)
    resDe = f.forward(img_dex, bb=bb)
    ldmarkDe.append(resDe)
Esempio n. 15
0
 def test(self, test_dataloader, out_dir=None,
          val=False, vis=True, eval=False):
     """Run the model over test data, optionally visualising and scoring it.

     Args:
         test_dataloader: an iterable of ``(x, y)`` batches, or
             ``(x, meta, y)`` batches when ``self.args.n_colors == 4``.
             Alternatively a directory path (str), which switches to
             test-only mode: images are loaded through ``TestSet`` and
             nothing is scored. ``y`` is then passed to ``_vis`` as
             ``name_list`` (presumably sample names -- TODO confirm).
         out_dir: forwarded to ``_vis``.
         val: skip reloading the checkpoint and return ``(acc, dice...)``.
         vis: collect inputs, ground truths and predictions, then call
             ``_vis``.
         eval: collect predictions without scoring and return
             ``(preds, gts, images)``.

     Returns:
         None in test-only mode.
         ``(preds, gts, images)`` when ``eval``.
         ``(acc, dice_0[, dice_1])`` when ``val``.
         Otherwise ``([acc, dice_0[, dice_1]], preds, gts)``.
         Dice is reported for targets 0 and 2 when ``n_classes == 3``, else
         for target 1. Metrics with nothing scored are NaN.
     """
     if not self.args.test and not val:
         self.load(force=True)
     target_list = [0, 2] if self.args.n_classes == 3 else [1]

     self.model.eval()
     try:
         # no_grad() restores the caller's grad mode on exit, even on error.
         with torch.no_grad():
             (only_test, images, gts, preds,
              correct, total, dice_list) = self._run_test_batches(
                  test_dataloader, target_list, vis, eval)
         print('')
     finally:
         # Back to training mode even if evaluation raised.
         self.model.train()

     # Accuracy is computed once, after the loop, and only when something
     # was scored (test-only and eval modes score nothing).
     acc = round(correct / total, 4) if total else float('nan')
     dice_scores = [np.mean(scores) if scores else float('nan')
                    for scores in dice_list]

     if only_test:
         self._vis(images, gts=gts, preds=preds, val=val, out_dir=out_dir,
                   name_list=gts)
         return None
     if vis:
         self._vis(images, gts=gts, preds=preds, val=val, out_dir=out_dir)
     if eval:
         return preds, gts, images
     if val:
         return (acc, *dice_scores)
     # np.mean already yields a NumPy scalar (it has no .detach()/.cpu()).
     return [acc, *dice_scores], preds, gts

 def _run_test_batches(self, test_dataloader, target_list, vis, eval_mode):
     """Run one pass over the test data (the caller wraps it in no_grad).

     Returns ``(only_test, images, gts, preds, correct, total, dice_list)``:
     - the lists collected for visualisation/evaluation;
     - the accuracy accumulators ``correct`` and ``total``;
     - one list of per-batch Dice values per entry of ``target_list``.
     """
     images, preds, gts = [], [], []
     dice_list = [[] for _ in target_list]
     total = 0
     correct = 0

     # A path instead of a loader means: predict only, without evaluation.
     only_test = isinstance(test_dataloader, str)
     if only_test:
         print('Load images from {}'.format(test_dataloader))
         transformation = transforms.Compose([transforms.Resize((self.args.size, self.args.size)),
                                              transforms.ToTensor(),
                                              transforms.Normalize(self.args.mean, self.args.std),
                                              ])
         dataset = TestSet(self.args, test_dataloader, transformation)
         test_dataloader = DataLoader(dataset, batch_size=self.args.batch_size,
                                      shuffle=False, num_workers=self.args.n_threads)

     un = UnNormalize(self.args.mean, self.args.std) if vis else None
     for test_data in tqdm(test_dataloader):
         if self.args.n_colors == 4:
             x, meta, y = test_data
             x = x.cuda()
             meta = meta.cuda()
             y_pred = self.model(torch.cat((x, meta), dim=1))
         else:
             x, y = test_data
             x = x.cuda()
             # Forward
             y_pred = self.model(x)
         if not only_test:
             # Only scored batches carry label tensors; in test-only mode `y`
             # is kept as-is because it becomes _vis's name list.
             y = y.long().cuda()
         n = x.size(0)

         if vis:
             temp_x = x.clone()
             images += [un(temp_x[i, ...]).permute(1, 2, 0).cpu().numpy()
                        for i in range(n)]
             if only_test:
                 gts += [y[i] for i in range(n)]
             else:
                 temp_y = y.clone()
                 if len(y.shape) == 3:
                     # (B, H, W) labels -> (B, 1, H, W) for the permute below.
                     temp_y = torch.unsqueeze(temp_y, dim=1)
                 gts += [temp_y[i, ...].permute(1, 2, 0).cpu().numpy()
                         for i in range(n)]

         if vis or eval_mode:
             if self.args.n_classes > 0:  # Regularize to [0, 1]
                 softmax_pred = nn.Softmax(dim=1)(y_pred)
             else:
                 softmax_pred = y_pred.clone()
             preds += [softmax_pred.detach()[i, ...].permute(1, 2, 0).cpu().numpy()
                       for i in range(n)]

         if not only_test and not eval_mode:
             total += n
             correct += seg_accuracy(y_pred, y, regression=self.args.regression) * n
             for i, target in enumerate(target_list):
                 _, _, _, _, _dice = dice(y, y_pred, supervised=True, target=target)
                 dice_list[i].append(_dice.detach().cpu().numpy())

     return only_test, images, gts, preds, correct, total, dice_list
    # Human-readable display name per feature (passed through
    # _fix_long_names, presumably to shorten long labels -- TODO confirm).
    display_feature_names = {
        f: to_readable_names([f])[0]
        for f in feature_names
    }
    display_feature_names = _fix_long_names(display_feature_names)
    # NOTE(review): `feature_units` is not used in the visible code.
    feature_units = {f: get_units(f) for f in feature_names}

    # Append the matching date_col entries as an extra last column so the
    # frame can be labelled with original_feature_names.
    date_subset = date_col[:len(examples_subset)].reshape(
        len(examples_subset), 1)

    examples_subset = np.concatenate((examples_subset, date_subset), axis=1)
    examples_subset = pd.DataFrame(examples_subset,
                                   columns=original_feature_names)

    # Feature values for the plot's x-axis: inverse-transformed when the
    # pipeline normalized them (model.steps[1][1] is presumably the fitted
    # scaler -- TODO confirm), raw values otherwise.
    # NOTE(review): prefer `is not None`; `unnormalize` is computed but
    # plot_shap() below receives unnormalize=None.
    if normalize_method != None:
        unnormalize = UnNormalize(model.steps[1][1], feature_names)
        feature_values = unnormalize._full_inverse_transform(examples_subset)
    else:
        unnormalize = None
        feature_values = examples_subset.values

    myInterpreter.plot_shap(plot_type='dependence',
                            display_feature_names=display_feature_names,
                            features=important_vars,
                            data_for_shap=examples_transformed,
                            subsample_size=100,
                            unnormalize=None,
                            feature_values=feature_values)

    fname = f'shap_dependence_{model_name}_{target}_{time}{drop_opt}.png'
    base_plot.save_figure(fname=fname)
    print('First load of the data...')
    examples, target_values = _load_train_data(**parameters)
    original_examples = examples.copy()
    feature_names = list(examples.columns)
    # 'Run Date' is dropped from the feature-name list only; the column
    # itself stays in `examples`.
    feature_names.remove('Run Date')

    results = []
    for model_name in model_set:
        print(model_name)
        parameters['model_name'] = model_name
        #if model_name == "LogisticRegression":
        #    parameters['normalize'] = 'standard'

        # Take the base estimator of the first calibrated classifier.
        calibrated_pipeline = _load_model(**parameters)
        model = calibrated_pipeline.calibrated_classifiers_[0].base_estimator
        # model.steps[1][1] presumably holds the fitted scaler -- TODO confirm.
        unnormalize = UnNormalize(model.steps[1][1], feature_names)
        examples_transformed, target_values_transformed = just_transforms(
            model, examples, target_values)
        examples_transformed = pd.DataFrame(examples_transformed,
                                            columns=feature_names)

        # Map selected transformed columns back to their original values.
        cape = unnormalize.inverse_transform(
            examples_transformed['cape_ml_ens_mean_spatial_mean'].values,
            'cape_ml_ens_mean_spatial_mean')
        shear_u = unnormalize.inverse_transform(
            examples_transformed['shear_u_0to6_ens_mean_spatial_mean'].values,
            'shear_u_0to6_ens_mean_spatial_mean')
        shear_v = unnormalize.inverse_transform(
            examples_transformed['shear_v_0to6_ens_mean_spatial_mean'].values,
            'shear_v_0to6_ens_mean_spatial_mean')
        # NOTE(review): this statement is truncated in this view.
        cin = unnormalize.inverse_transform(