Code example #1
0
 def get_train_dataloader(self, data_args) -> DataLoader:
     """Return a DataLoader over a snapshot of this dataset in 'training' mode.

     The dataset is deep-copied so that later changes to ``self.mode``
     cannot influence the returned loader.
     """
     self.mode = 'training'
     frozen_dataset = deepcopy(self)
     return torch.utils.data.DataLoader(
         frozen_dataset,
         batch_size=data_args['batch_size'],
         shuffle=data_args['shuffle'],
         num_workers=get_cores_count(),
     )
Code example #2
0
 def get_test_dataloader(self, data_args):
     """Return a DataLoader over a deep-copied snapshot of this dataset
     with both child datasets switched to 'test' mode.
     """
     # Flip both underlying datasets to test mode before snapshotting.
     self.image_dataset.mode = 'test'
     self.attribution_dataset.mode = 'test'
     snapshot = deepcopy(self)
     return torch.utils.data.DataLoader(
         snapshot,
         batch_size=data_args['batch_size'],
         shuffle=data_args['shuffle'],
         num_workers=get_cores_count(),
     )
Code example #3
0
 def test_dataloader(self):
     """Build the test-set DataLoader using the validation loader settings."""
     loader_kwargs = dict(
         batch_size=self.val_data_args['batch_size'],
         shuffle=self.val_data_args['shuffle'],
         pin_memory=True,
         num_workers=get_cores_count(),
     )
     return torch.utils.data.DataLoader(self.testset, **loader_kwargs)
Code example #4
0
 def validation_dataloader(self) -> DataLoader:
     """Build the validation-set DataLoader.

     NOTE(review): batch_size/shuffle are read from ``train_data_args``
     here, whereas ``test_dataloader`` reads ``val_data_args`` — confirm
     this asymmetry is intended.
     """
     args = self.train_data_args
     loader = torch.utils.data.DataLoader(
         self.validationset,
         batch_size=args['batch_size'],
         shuffle=args['shuffle'],
         pin_memory=True,
         num_workers=get_cores_count(),
     )
     return loader
Code example #5
0
File: CIFAR10.py  Project: YaNgZhAnG-V5/RoarTorch
    def __init__(
        self,
        dataset_args,
        train_data_args,
        val_data_args,
        device="cuda:1",
    ):
        """Build CIFAR-10 train/validation/test datasets with their transforms.

        Args:
            dataset_args: dict with 'dataset_dir' and optional 'split_ratio'
                (fraction of the official train split kept for training;
                defaults to 7/8, must be < 1.0).
            train_data_args: loader settings; only 'enable_augmentation'
                (default False) is read here, the dict itself is stored.
            val_data_args: validation loader settings; stored for later use.
            device: CUDA device made current via torch.cuda.set_device.

        Note:
            use_random_flip not used.
        """

        torch.cuda.set_device(device)
        # Worker count used later when building the dataloaders.
        self.cpu_count = get_cores_count()
        self.train_data_args = train_data_args
        self.val_data_args = val_data_args

        dataset_dir = dataset_args['dataset_dir']
        split_ratio = dataset_args.get('split_ratio', 7.0 / 8.0)
        assert split_ratio < 1.0, 'CIFAR train set should be split into train and cross-validation set.'

        # Use augmentations for training models but not during generating dataset.
        self.train_transform = CIFAR10.get_train_transform(
            enable_augmentation=train_data_args.get('enable_augmentation',
                                                    False))
        self.validation_transform = CIFAR10.get_validation_transform()
        # Normalization transform does (x - mean) / std
        # To denormalize use mean* = (-mean/std) and std* = (1/std)
        # NOTE(review): self.mean / self.std are not set in this __init__ —
        # presumably class-level attributes; confirm they exist on CIFAR10.
        self.demean = [-m / s for m, s in zip(self.mean, self.std)]
        self.destd = [1 / s for s in self.std]
        self.denormalization_transform = torchvision.transforms.Normalize(
            self.demean, self.destd, inplace=False)

        # Two views of the same official train split: one with train-time
        # transforms, one with validation transforms.
        self.trainset = torchvision.datasets.CIFAR10(
            root=dataset_dir,
            train=True,
            download=True,
            transform=self.train_transform)
        self.validationset = torchvision.datasets.CIFAR10(
            root=dataset_dir,
            train=True,
            download=True,
            transform=self.validation_transform)

        # Split the official train data into training and cross-validation
        # subsets according to split_ratio (default 7/8 train, 1/8 validation).
        training_indices, validation_indices = self._uniform_train_val_split(
            self.trainset.targets, split_ratio)
        self.trainset = torch.utils.data.Subset(self.trainset,
                                                training_indices)
        self.validationset = torch.utils.data.Subset(self.validationset,
                                                     validation_indices)

        # Official test split, evaluated with the validation transform.
        self.testset = torchvision.datasets.CIFAR10(
            root=dataset_dir,
            train=False,
            download=True,
            transform=self.validation_transform)
Code example #6
0
def perform_perturbation_analysis(arguments):
    """Run pixel-perturbation analysis for a trained model.

    For each randomly sampled test image and each configured attribution
    method, the top/bottom percentiles of pixels (ranked by the attribution
    map) are removed, the perturbed images are re-classified, and the
    relative deviation of the softmax score for the originally predicted
    class is accumulated. Mean deviations (in percent) are printed
    periodically and saved to ``<outdir>/<timestamp>/pixel_perturbation.npy``.

    Args:
        arguments: configuration dict with keys 'outdir', 'data' and
            'pixel_perturbation_analysis' (which holds 'model',
            'test_samples', 'attribution_methods' and 'percentiles').
    """

    # One image per batch, fixed order — keeps per-image bookkeeping simple.
    val_data_args = dict(batch_size=1, shuffle=False)
    """ Setup result directory """
    outdir = arguments['outdir']
    os.makedirs(outdir, exist_ok=True)
    print('Arguments:\n{}'.format(pformat(arguments)))
    """ Set random seed throughout python"""
    random_seed = random.randint(0, 1000)
    utils.set_random_seed(random_seed=random_seed)
    """ Set device - cpu or gpu """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f'Using device - {device}')
    """ Load Model with weights(if available) """
    dataset_args = arguments['data']
    model_args = arguments['pixel_perturbation_analysis']['model']
    model: torch.nn.Module = models_utils.get_model(model_args, device,
                                                    dataset_args).to(device)
    """ Load parameters for the Dataset """
    if dataset_args['dataset'] == 'ImageNet':
        # Standard ImageNet evaluation preprocessing.
        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                            std=[0.229, 0.224, 0.225])
        ])
        testset = datasets.ImageNet(dataset_args['dataset_dir'],
                                    split='val',
                                    transform=transform)
    else:
        # ToDo - Make uniform api for loading different datasets. Birdsnap/Imagenet/Food101 supports needs to be added.
        dataset = create_dataset(
            dataset_args,
            val_data_args,  # Just use val_data_args as train_data_args
            val_data_args)  # Split doesnt matter, we use test dataset
        testset = dataset.testset

    # Cap the sample count at the dataset size.
    num_samples = min(arguments['pixel_perturbation_analysis']['test_samples'],
                      len(testset))
    print(
        f'Test dataset has {len(testset)} samples. We are using randomly selected {num_samples} samples for testing.'
    )

    # Random subset of the test set; sampled without replacement.
    testset = torch.utils.data.Subset(
        testset, random.sample(range(0, len(testset)), num_samples))
    dataloader = torch.utils.data.DataLoader(testset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=get_cores_count())

    # Attribution method and percentiles at which to test.
    attribution_methods = arguments['pixel_perturbation_analysis'][
        'attribution_methods']
    print('Running pixel perturbation analysis for: ', attribution_methods)

    # Step sizes to remove top k or bottom k
    percentiles = arguments['pixel_perturbation_analysis']['percentiles']

    # Save plots in outdir in outdir/DATASET_[train/test]/MODEL_ATTRIBUTIONMETHOD/ImageIndex.png
    model.eval()
    timestamp = datetime.datetime.now().isoformat()

    # To save sum of delta output change for each attribution method, percentile and
    # remove top and bottom percentile pixels.
    # Layout: [method, percentile, 0=top-removed / 1=bottom-removed].
    # ToDo - Use a numpy dictionary with key attribution names.
    #  Pros - Easy plotting. Load npy files without need to remember which index mapped to which attribution method.
    output_deviation_sum = np.zeros(
        (len(attribution_methods), len(percentiles), 2), dtype=float)

    # Save results in corresponding directory
    attribution_output_dir = os.path.join(outdir,
                                          timestamp)  # E.g. outdir/timestamp/
    os.makedirs(attribution_output_dir, exist_ok=True)

    for counter, data in enumerate(tqdm(dataloader, total=num_samples)):
        if counter == num_samples:
            break
        inputs, labels = data
        inputs = inputs.to(device)
        # Reference prediction on the unmodified image; deviations below are
        # measured against this softmax score at the predicted class.
        outputs = model(inputs).detach().cpu()
        _, max_prob_indices = torch.max(outputs.data, 1)
        outputs = torch.nn.functional.softmax(outputs, dim=1)
        outputs = outputs.numpy()

        for attribution_method_index, attribution_method in enumerate(
                attribution_methods):

            # batch_size is 1, so this inner zip iterates exactly once.
            for preprocessed_image, max_prob_index, output in zip(
                    inputs, max_prob_indices, outputs):
                attribution_map = attribution_loader.generate_attribution(
                    model, preprocessed_image.unsqueeze(0),
                    max_prob_index.to(device), attribution_method)

                # To take absolute value for each pixel channel for each attribution method.
                attribution_map = np.max(attribution_map, axis=0)

                preprocessed_image = preprocessed_image.cpu().numpy()
                modified_images_bottom_remove = remove(
                    preprocessed_image.copy(),
                    attribution_map,
                    replace_value=[0, 0, 0],
                    # Black in original image is -mean/std in preprocessed image
                    percentiles=percentiles,
                    bottom=True,
                    gray=True)
                modified_images_top_remove = remove(
                    preprocessed_image.copy(),
                    attribution_map,
                    replace_value=[0, 0, 0],
                    # Black in original image is -mean/std in preprocessed image
                    percentiles=percentiles,
                    bottom=False,
                    gray=True)

                # Create a batch of all images
                modified_images_top_remove = torch.from_numpy(
                    np.stack(modified_images_top_remove, axis=0)).to(device)
                modified_images_bottom_remove = torch.from_numpy(
                    np.stack(modified_images_bottom_remove, axis=0)).to(device)

                # Run forward pass - ToDo - Do in single pass
                output_top_q = model(modified_images_top_remove)
                output_bottom_q = model(modified_images_bottom_remove)

                output_top_q = torch.nn.functional.softmax(output_top_q, dim=1)
                output_bottom_q = torch.nn.functional.softmax(output_bottom_q,
                                                              dim=1)

                output_top_q = output_top_q.detach().cpu().numpy()
                output_bottom_q = output_bottom_q.detach().cpu().numpy()

                # Get output value at max_prob_index for each percentile
                output_top_q_max_class_prob = output_top_q[:, max_prob_index]
                output_bottom_q_max_class_prob = output_bottom_q[:,
                                                                 max_prob_index]

                # Compute deviation from model output for original image at max_prob_index
                top_deviation = np.abs(
                    (output[max_prob_index] - output_top_q_max_class_prob) /
                    output[max_prob_index])
                bottom_deviation = np.abs(
                    (output[max_prob_index] - output_bottom_q_max_class_prob) /
                    output[max_prob_index])

                # Add this deviation to right dimension of matrix
                output_deviation_sum[attribution_method_index, :,
                                     0] += top_deviation
                output_deviation_sum[attribution_method_index, :,
                                     1] += bottom_deviation

        # Periodic progress report every 500 processed images.
        if counter % 500 == 499:
            # Divide output_deviation_sum each element by num_samples
            output_deviation_mean = output_deviation_sum * 100.0 / (counter +
                                                                    1)

            print("\nAffect of removal of most important pixels at:-")
            for attribution_method_index, attribution_method in enumerate(
                    attribution_methods):
                with np.printoptions(precision=3,
                                     formatter={'float': '{: 0.3f}'.format},
                                     suppress=True,
                                     linewidth=np.inf):
                    print(
                        attribution_method['name'].ljust(20) + ' = ',
                        np.array2string(
                            output_deviation_mean[attribution_method_index, :,
                                                  0],
                            separator=', '))

            print("Affect of removal of least important pixels at:-")
            for attribution_method_index, attribution_method in enumerate(
                    attribution_methods):
                with np.printoptions(precision=3,
                                     formatter={'float': '{: 0.3f}'.format},
                                     suppress=True,
                                     linewidth=np.inf):
                    print(
                        attribution_method['name'].ljust(20) + ' = ',
                        np.array2string(
                            output_deviation_mean[attribution_method_index, :,
                                                  1],
                            separator=', '))
            print()

    # Divide output_deviation_sum each element by num_samples
    output_deviation_mean = output_deviation_sum * 100.0 / num_samples

    # Save in directory
    np.save(os.path.join(attribution_output_dir, 'pixel_perturbation.npy'),
            output_deviation_mean)

    with np.printoptions(precision=3,
                         formatter={'float': '{: 0.3f}'.format},
                         suppress=True,
                         linewidth=np.inf):
        print("Affect of removal of most important pixels at:- \npercentiles ",
              percentiles)
        for ind, attr in enumerate(attribution_methods):
            print(attr['name'].ljust(20), output_deviation_mean[ind, :, 0])
    print()
    with np.printoptions(precision=3,
                         formatter={'float': '{: 0.3f}'.format},
                         suppress=True,
                         linewidth=np.inf):
        print(
            "Affect of removal of least important pixels at:- \npercentiles ",
            percentiles)
        for ind, attr in enumerate(attribution_methods):
            print(attr['name'].ljust(20), output_deviation_mean[ind, :, 1])
Code example #7
0
 def get_validation_dataloader(self, data_args) -> DataLoader:
     """Return a DataLoader over a deep-copied snapshot of this dataset
     in 'validation' mode, so later mode changes don't affect the loader.
     """
     self.mode = 'validation'
     snapshot = deepcopy(self)
     return torch.utils.data.DataLoader(
         snapshot,
         batch_size=data_args['batch_size'],
         shuffle=data_args['shuffle'],
         num_workers=get_cores_count(),
     )