Ejemplo n.º 1
0
def get_samples(target, nb_class=10, sample_index=0):
    '''
    Select one test image per class and load the pretrained SimpleCNN.

    params:
        target : [mnist, cifar10]
        nb_class : number of classes (class indices 0 .. nb_class - 1 are sampled)
        sample_index : index of image by class, i.e. which occurrence of each
            class in the test set is picked

    returns:
        original_images (numpy array) : Original images, shape = (number of class, W, H, C)
        original_targets : labels of the selected images (torch LongTensor when
            testset.targets is a list, otherwise testset.targets indexed as-is)
        pre_images (torch array) : Preprocessing images, shape = (number of class, C, W, H)
        target_classes (dictionary) : keys = class index, values = class name
        model (pytorch model) : pretrained model

    raises:
        ValueError : target is not one of the supported datasets
    '''

    # Fail fast on an unknown dataset name; otherwise `testset` would be
    # unbound and the function would crash later with an obscure error.
    if target == 'mnist':
        image_size = (28, 28, 1)
        _, _, testloader = mnist_load()
    elif target == 'cifar10':
        image_size = (32, 32, 3)
        _, _, testloader = cifar10_load()
    else:
        raise ValueError(
            "target must be 'mnist' or 'cifar10', got {!r}".format(target))
    testset = testloader.dataset

    # idx2class: invert the dataset's class-name -> index mapping
    target_classes = {
        idx: name for name, idx in testset.class_to_idx.items()
    }

    # select images: position of the `sample_index`-th occurrence of each class
    # (the label array is converted once instead of once per class)
    targets = np.array(testset.targets)
    idx_by_class = [
        np.where(targets == i)[0][sample_index] for i in range(nb_class)
    ]
    original_images = testset.data[idx_by_class]
    if not isinstance(original_images, np.ndarray):
        original_images = original_images.numpy()
    # make the channel axis explicit (e.g. mnist data has no channel axis)
    original_images = original_images.reshape((nb_class, ) + image_size)

    # select targets
    if isinstance(testset.targets, list):
        original_targets = torch.LongTensor(testset.targets)[idx_by_class]
    else:
        original_targets = testset.targets[idx_by_class]

    # model load
    # The model is never moved off the CPU here, so map the checkpoint to the
    # CPU; this also lets GPU-saved checkpoints load on CPU-only machines.
    weights = torch.load('../checkpoint/simple_cnn_{}.pth'.format(target),
                         map_location='cpu')
    model = SimpleCNN(target)
    model.load_state_dict(weights['model'])

    # image preprocessing
    # Allocate the channel-first tensor directly rather than transposing a
    # torch tensor through numpy.
    nb_images, dim1, dim2, channels = original_images.shape
    pre_images = torch.zeros((nb_images, channels, dim1, dim2))
    for i, image in enumerate(original_images):
        pre_images[i] = testset.transform(image)

    return original_images, original_targets, pre_images, target_classes, model
Ejemplo n.º 2
0
    def __init__(self,
                 model,
                 target,
                 batch_size,
                 method,
                 ensemble=None,
                 sample_pct=0.1,
                 nb_class=10,
                 sample_index=0):
        '''
        Load the test set of `target`, keep a random subsample of it and set
        up the saliency map method.

        Args
            model : pretrained model (switched to eval mode here)
            target : [mnist, cifar10]
            batch_size : batch size, stored as self.batch_size
            method : {
                VBP : Vanilla Backpropagation,
                IB : Input x Backpropagation,
                IG : Integrated Gradients,
                GB : Guided Backpropagation,
                GC : Grad CAM,
                GB-GC : Guided GradCAM,
                DeconvNet : DeconvNet,
            }
            ensemble : {
                SG : SmoothGrad,
                SG-SQ : SmoothGrad Square
                SG-VAR : SmoothGrad VAR
            }
            sample_pct : fraction of the test set kept by the random sampling
            nb_class : number of class
            sample_index : sample image index by class, counted inside the
                subsampled test set

        Raises
            ValueError : target is not one of the supported datasets

        Note
            self.deconv_model (Deconvolution model, only used if method is set
            to DeconvNet) is not a constructor argument; it is initialised to
            None here.
        '''

        # data load
        # Reject unknown dataset names explicitly; otherwise `testloader`
        # would be unbound below.
        if target == 'mnist':
            _, _, testloader = mnist_load()
        elif target == 'cifar10':
            _, _, testloader = cifar10_load()
        else:
            raise ValueError(
                "target must be 'mnist' or 'cifar10', got {!r}".format(target))
        self.target = target
        self.testset = testloader.dataset
        self.img_size = self.testset.data.shape[
            1:]  # mnist : (28,28), cifar10 : (32,32,3)
        self.batch_size = batch_size

        # sampling
        # Seed first so the same subset is drawn on every run, then shrink the
        # dataset object in place to the sampled rows.
        seed_everything()
        sample_size = int(len(self.testset) * sample_pct)
        sample_idx = np.random.choice(len(self.testset),
                                      sample_size,
                                      replace=False)
        self.testset.data = self.testset.data[sample_idx]
        self.testset.targets = np.array(self.testset.targets)[sample_idx]
        self.sample_pct = sample_pct
        self.data_size = len(self.testset)

        # model setting
        self.model = model
        self.model.eval()
        self.deconv_model = None

        # saliency map
        # method / ensemble must be set before saliency_map_choice() is called
        self.method = method
        self.ensemble = ensemble
        self.saliency_map, self.layer, self.color = self.saliency_map_choice()

        # sample
        self.nb_class = nb_class
        self.nb_checkpoint = 5
        # position (within the subsample) of the `sample_index`-th image of
        # each class; the label array is converted once, not once per class
        sampled_targets = np.array(self.testset.targets)
        self.idx_by_class = [
            np.where(sampled_targets == i)[0][sample_index]
            for i in range(self.nb_class)
        ]
Ejemplo n.º 3
0
def main(args, **kwargs):
    '''
    Train and test a model on args.target and save the history as json.

    args:
    - args: namespace providing epochs, batch_size, valid_rate, lr, verbose,
      target, attention, monitor, mode, train, eval and method
    - kwargs: 'ratio' must be given whenever args.eval is not None (it is
      printed and used by the ROAR / KAR branch)

    raises:
    - ValueError: args.target is not one of the supported datasets

    The process exits early (sys.exit) when the log file of this
    configuration already exists.
    '''
    #################################
    # Config
    #################################
    # Validate the dataset name before doing any work; an unknown target
    # would otherwise leave the data loaders unbound.
    target = args.target
    if target not in ('mnist', 'cifar10'):
        raise ValueError(
            "target must be 'mnist' or 'cifar10', got {!r}".format(target))

    epochs = args.epochs
    batch_size = args.batch_size
    valid_rate = args.valid_rate
    lr = args.lr
    verbose = args.verbose

    # checkpoint
    attention = args.attention
    monitor = args.monitor
    mode = args.mode

    # save name
    model_name = 'simple_cnn_{}'.format(target)
    if attention in ['CAM', 'CBAM']:
        model_name = model_name + '_{}'.format(attention)
    elif attention in ['RAN', 'WARN']:
        model_name = '{}_{}'.format(target, attention)

    # save directory
    savedir = '../checkpoint'
    logdir = '../logs'

    # device setting cpu or cuda(gpu)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print('=====Setting=====')
    print('Training: ', args.train)
    print('Epochs: ', epochs)
    print('Batch Size: ', batch_size)
    print('Validation Rate: ', valid_rate)
    print('Learning Rate: ', lr)
    print('Target: ', target)
    print('Monitor: ', monitor)
    print('Model Name: ', model_name)
    print('Mode: ', mode)
    print('Attention: ', attention)
    print('Save Directory: ', savedir)
    print('Log Directory: ', logdir)
    print('Device: ', device)
    print('Verbose: ', verbose)
    print()
    print('Evaluation: ', args.eval)
    if args.eval is not None:
        print('Pixel ratio: ', kwargs['ratio'])
    print()
    print('Setting Random Seed')
    print()
    seed_everything()  # seed setting

    #################################
    # Data Load
    #################################
    print('=====Data Load=====')
    load_data = mnist_load if target == 'mnist' else cifar10_load
    trainloader, validloader, testloader = load_data(
        batch_size=batch_size, validation_rate=valid_rate, shuffle=True)

    #################################
    # ROAR or KAR
    #################################
    if args.eval in ('ROAR', 'KAR'):
        # saliency map load
        filename = f'../saliency_maps/[{args.target}]{args.method}'
        if attention in ['CBAM', 'RAN']:
            filename += f'_{attention}'
        # The maps are copied into memory, so the file can be closed (even on
        # error) before the images are adjusted.
        with h5py.File(f'{filename}_train.hdf5', 'r') as hf:
            sal_maps = np.array(hf['saliencys'])
        # adjust image
        # NOTE(review): the return value was never used by this function;
        # presumably adjust_image modifies trainloader in place - confirm.
        adjust_image(kwargs['ratio'], trainloader, sal_maps, args.eval)
        # model name
        model_name = model_name + '_{0:}_{1:}{2:.1f}'.format(
            args.method, args.eval, kwargs['ratio'])

    # check exit: skip configurations whose log file already exists
    if os.path.isfile('{}/{}_logs.txt'.format(logdir, model_name)):
        sys.exit()

    #################################
    # Load model
    #################################
    print('=====Model Load=====')
    if attention == 'RAN':
        net = RAN(target).to(device)
    elif attention == 'WARN':
        net = WideResNetAttention(target).to(device)
    else:
        net = SimpleCNN(target, attention).to(device)
    n_parameters = sum(np.prod(p.size()) for p in net.parameters())
    print('Total number of parameters:', n_parameters)
    print()

    # Model compile
    optimizer = optim.SGD(net.parameters(),
                          lr=lr,
                          momentum=0.9,
                          weight_decay=0.0005)
    criterion = nn.CrossEntropyLoss()

    #################################
    # Train
    #################################
    # NOTE(review): presumably ModelTrain runs the training loop in its
    # constructor (its history is read right after) - confirm.
    modeltrain = ModelTrain(model=net,
                            data=trainloader,
                            epochs=epochs,
                            criterion=criterion,
                            optimizer=optimizer,
                            device=device,
                            model_name=model_name,
                            savedir=savedir,
                            monitor=monitor,
                            mode=mode,
                            validation=validloader,
                            verbose=verbose)

    #################################
    # Test
    #################################
    modeltest = ModelTest(model=net,
                          data=testloader,
                          loaddir=savedir,
                          model_name=model_name,
                          device=device)

    modeltrain.history['test_result'] = modeltest.results

    # History save as json file
    # makedirs with exist_ok avoids the isdir/mkdir race and creates parents
    os.makedirs(logdir, exist_ok=True)
    with open(f'{logdir}/{model_name}_logs.txt', 'w') as outfile:
        json.dump(modeltrain.history, outfile)
Ejemplo n.º 4
0
def get_samples(target,
                nb_class=10,
                sample_index=0,
                attention=None,
                device='cpu'):
    '''
    Get samples : original images, targets, preprocessed images, target class, trained model

    args:
    - target: [mnist, cifar10]
    - nb_class: number of classes (class indices 0 .. nb_class - 1 are sampled)
    - sample_index: index of image by class, i.e. which occurrence of each
      class in the test set is picked
    - attention: None, 'CAM' or 'CBAM' (SimpleCNN variants), 'RAN' or 'WARN';
      selects both the checkpoint file and the model class
    - device: device the model is moved to and the checkpoint is mapped to

    return:
    - original_images (numpy array): Original images, shape = (number of class, W, H, C)
    - original_targets: labels of the selected images (torch LongTensor when
      testset.targets is a list, otherwise testset.targets indexed as-is)
    - pre_images (torch array): Preprocessing images, shape = (number of class, C, W, H)
    - target_classes (dictionary): keys = class index, values = class name
    - model (pytorch model): pretrained model

    raises:
    - ValueError: target is not one of the supported datasets
    '''

    # Fail fast on an unknown dataset name; otherwise `testset` would be
    # unbound and the function would crash later with an obscure error.
    if target == 'mnist':
        image_size = (28, 28, 1)
        _, _, testloader = mnist_load()
    elif target == 'cifar10':
        image_size = (32, 32, 3)
        _, _, testloader = cifar10_load()
    else:
        raise ValueError(
            "target must be 'mnist' or 'cifar10', got {!r}".format(target))
    testset = testloader.dataset

    # idx2class: invert the dataset's class-name -> index mapping
    target_classes = {
        idx: name for name, idx in testset.class_to_idx.items()
    }

    # select images: position of the `sample_index`-th occurrence of each class
    # (the label array is converted once instead of once per class)
    targets = np.array(testset.targets)
    idx_by_class = [
        np.where(targets == i)[0][sample_index] for i in range(nb_class)
    ]
    original_images = testset.data[idx_by_class]
    if not isinstance(original_images, np.ndarray):
        original_images = original_images.numpy()
    # make the channel axis explicit (e.g. mnist data has no channel axis)
    original_images = original_images.reshape((nb_class, ) + image_size)

    # select targets
    if isinstance(testset.targets, list):
        original_targets = torch.LongTensor(testset.targets)[idx_by_class]
    else:
        original_targets = testset.targets[idx_by_class]

    # model load
    filename = f'simple_cnn_{target}'
    if attention in ['CAM', 'CBAM']:
        filename += f'_{attention}'
    elif attention in ['RAN', 'WARN']:
        filename = f'{target}_{attention}'
    print('filename: ', filename)
    # Map the checkpoint onto the requested device so that GPU-saved weights
    # can also be loaded on a CPU-only machine.
    weights = torch.load(f'../checkpoint/{filename}.pth',
                         map_location=device)

    if attention == 'RAN':
        model = RAN(target).to(device)
    elif attention == 'WARN':
        model = WideResNetAttention(target).to(device)
    else:
        model = SimpleCNN(target, attention).to(device)
    model.load_state_dict(weights['model'])

    # image preprocessing
    # Allocate the channel-first tensor directly rather than transposing a
    # torch tensor through numpy.
    nb_images, dim1, dim2, channels = original_images.shape
    pre_images = torch.zeros((nb_images, channels, dim1, dim2))
    for i, image in enumerate(original_images):
        pre_images[i] = testset.transform(image)

    return original_images, original_targets, pre_images, target_classes, model
Ejemplo n.º 5
0
def main(args, **kwargs):
    '''
    Train and test a SimpleCNN on args.target and save the history as json.

    args:
    - args: namespace providing epochs, batch_size, valid_rate, lr, verbose,
      target, monitor, mode, eval and method
    - kwargs: 'ratio' must be given when args.eval is 'ROAR' or 'KAR'

    raises:
    - ValueError: args.target is not one of the supported datasets
    '''
    # Config
    # Validate the dataset name before doing any work; an unknown target
    # would otherwise leave the data loaders unbound.
    target = args.target
    if target not in ('mnist', 'cifar10'):
        raise ValueError(
            "target must be 'mnist' or 'cifar10', got {!r}".format(target))

    epochs = args.epochs
    batch_size = args.batch_size
    valid_rate = args.valid_rate
    lr = args.lr
    verbose = args.verbose

    # checkpoint
    monitor = args.monitor
    mode = args.mode

    # save name
    model_name = 'simple_cnn_{}'.format(target)

    # save directory
    savedir = '../checkpoint'
    logdir = '../logs'

    # device setting cpu or cuda(gpu)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print('=====Setting=====')
    print('Epochs: ', epochs)
    print('Batch Size: ', batch_size)
    print('Validation Rate: ', valid_rate)
    print('Learning Rate: ', lr)
    print('Target: ', target)
    print('Monitor: ', monitor)
    print('Model Name: ', model_name)
    print('Mode: ', mode)
    print('Save Directory: ', savedir)
    print('Log Directory: ', logdir)
    print('Device: ', device)
    print('Verbose: ', verbose)
    print()
    print('Setting Random Seed')
    print()
    seed_everything()  # seed setting

    # Data Load
    print('=====Data Load=====')
    load_data = mnist_load if target == 'mnist' else cifar10_load
    trainloader, validloader, testloader = load_data(
        batch_size=batch_size, validation_rate=valid_rate, shuffle=True)

    # ROAR or KAR
    if args.eval in ('ROAR', 'KAR'):
        # saliency map load
        # The maps are copied into memory, so the file can be closed (even on
        # error) before the images are adjusted.
        with h5py.File(
                f'../saliency_maps/[{args.target}]{args.method}_train.hdf5',
                'r') as hf:
            sal_maps = np.array(hf['saliencys'])
        # adjust image
        # NOTE(review): the return value was never used by this function;
        # presumably adjust_image modifies trainloader in place - confirm.
        adjust_image(kwargs['ratio'], trainloader, sal_maps, args.eval)
        # model name
        model_name = model_name + '_{0:}_{1:}{2:.1f}'.format(
            args.method, args.eval, kwargs['ratio'])

    print('=====Model Load=====')
    # Load model
    net = SimpleCNN(target).to(device)
    print()

    # Model compile
    optimizer = optim.SGD(net.parameters(),
                          lr=lr,
                          momentum=0.9,
                          weight_decay=0.0005)
    criterion = nn.CrossEntropyLoss()

    # Train
    # NOTE(review): presumably ModelTrain runs the training loop in its
    # constructor (its history is read right after) - confirm.
    modeltrain = ModelTrain(model=net,
                            data=trainloader,
                            epochs=epochs,
                            criterion=criterion,
                            optimizer=optimizer,
                            device=device,
                            model_name=model_name,
                            savedir=savedir,
                            monitor=monitor,
                            mode=mode,
                            validation=validloader,
                            verbose=verbose)
    # Test
    modeltest = ModelTest(model=net,
                          data=testloader,
                          loaddir=savedir,
                          model_name=model_name,
                          device=device)

    modeltrain.history['test_result'] = modeltest.results

    # History save as json file
    # makedirs with exist_ok avoids the isdir/mkdir race and creates parents
    os.makedirs(logdir, exist_ok=True)
    with open(f'{logdir}/{model_name}_logs.txt', 'w') as outfile:
        json.dump(modeltrain.history, outfile)