Beispiel #1
0
def get_samples(target, nb_class=10, sample_index=0, checkpoint_dir='../checkpoint'):
    '''
    Pick one test image per class, preprocess it, and load the matching pretrained model.

    params:
        target : dataset name, one of ['mnist', 'cifar10']
        nb_class : number of classes to sample (class indices 0 .. nb_class-1)
        sample_index : which occurrence of each class in the test set to pick
                       (0 = first test image of that class)
        checkpoint_dir : directory containing 'simple_cnn_{target}.pth'
                         (default '../checkpoint', the previously hard-coded location)

    returns:
        original_images (numpy array) : Original images, shape = (nb_class, W, H, C)
        original_targets (torch tensor) : label of each selected image, shape = (nb_class,)
        pre_images (torch tensor) : Preprocessed images, shape = (nb_class, C, W, H)
        target_classes (dictionary) : keys = class index, values = class name
        model (pytorch model) : pretrained model

    raises:
        ValueError : if target is not 'mnist' or 'cifar10'
        IndexError : if a class has fewer than sample_index + 1 test images
    '''

    if target == 'mnist':
        image_size = (28, 28, 1)
        _, _, testloader = mnist_load()
    elif target == 'cifar10':
        image_size = (32, 32, 3)
        _, _, testloader = cifar10_load()
    else:
        # Previously an unknown target fell through and crashed later with an
        # UnboundLocalError on `testset`; fail early with a clear message instead.
        raise ValueError(
            "target must be 'mnist' or 'cifar10', got {!r}".format(target))
    testset = testloader.dataset

    # idx2class: invert the dataset's {class name: class index} mapping
    target_classes = {idx: name for name, idx in testset.class_to_idx.items()}

    # select images: the `sample_index`-th test image of each class.
    # Convert the targets once instead of once per class.
    all_targets = np.array(testset.targets)
    idx_by_class = [
        np.where(all_targets == i)[0][sample_index] for i in range(nb_class)
    ]
    original_images = testset.data[idx_by_class]
    if not isinstance(original_images, np.ndarray):
        # presumably a torch tensor (dataset stores data as a tensor) — TODO confirm
        original_images = original_images.numpy()
    # force an explicit trailing channel axis so both datasets are (N, W, H, C)
    original_images = original_images.reshape((nb_class, ) + image_size)

    # select targets (targets may be stored as a plain list or as a tensor)
    if isinstance(testset.targets, list):
        original_targets = torch.LongTensor(testset.targets)[idx_by_class]
    else:
        original_targets = testset.targets[idx_by_class]

    # model load. map_location='cpu' lets checkpoints saved on a GPU load on a
    # CPU-only machine; load_state_dict copies the values into the model anyway.
    weights = torch.load(
        '{}/simple_cnn_{}.pth'.format(checkpoint_dir, target),
        map_location='cpu')
    model = SimpleCNN(target)
    model.load_state_dict(weights['model'])

    # image preprocessing: allocate the channel-first buffer directly instead of
    # calling np.transpose on a torch tensor (which relied on numpy's fallback
    # conversion and could yield a numpy array depending on library versions).
    w, h, c = image_size
    pre_images = torch.zeros((len(original_images), c, w, h))
    for i, image in enumerate(original_images):
        pre_images[i] = testset.transform(image)

    return original_images, original_targets, pre_images, target_classes, model
Beispiel #2
0
    # TODO: Tensorboard Check

    # Dispatch on the command-line mode. `args` is presumably the argparse
    # namespace parsed above this excerpt — TODO confirm.
    # NOTE(review): no `else` branch is visible here; if args.train is unset and
    # args.eval matches none of the cases below, nothing runs (unless further
    # branches follow this excerpt).

    # python main.py --train --target=['mnist','cifar10'] --attention=['CAM','CBAM','RAN','WARN']
    if args.train:
        # plain training run; main() is defined elsewhere in the file
        main(args=args)

    elif args.eval == 'selectivity':
        # make evaluation directory
        # NOTE(review): os.makedirs('../evaluation', exist_ok=True) would avoid the
        # check-then-create race and also create missing parent directories.
        if not os.path.isdir('../evaluation'):
            os.mkdir('../evaluation')

        # pretrained model load: the checkpoint dict must hold the state dict
        # under the 'model' key.
        # NOTE(review): torch.load is called without map_location, so a checkpoint
        # saved from a GPU may fail to load on a CPU-only machine.
        weights = torch.load('../checkpoint/simple_cnn_{}.pth'.format(
            args.target))
        model = SimpleCNN(args.target)
        model.load_state_dict(weights['model'])

        # selectivity evaluation; note args.ratio is passed as `sample_pct`
        selectivity_method = Selectivity(model=model,
                                         target=args.target,
                                         batch_size=args.batch_size,
                                         method=args.method,
                                         sample_pct=args.ratio)
        # evaluation over args.steps steps; savedir is presumably where the
        # results are written — TODO confirm in Selectivity.eval
        selectivity_method.eval(steps=args.steps, savedir='../evaluation')

    elif (args.eval == 'ROAR') or (args.eval == 'KAR'):
        # ratio: args.ratio is used as the step, giving ratios args.ratio,
        # 2*args.ratio, ... strictly below 1.
        # NOTE(review): np.arange with a float step can accumulate rounding error
        # (e.g. 0.30000000000000004) and its length is not guaranteed;
        # np.linspace would be more predictable.
        ratio_lst = np.arange(0, 1, args.ratio)[1:]  # exclude zero
        for ratio in ratio_lst:
            # one full main() run per ratio
            main(args=args, ratio=ratio)