Example #1
def selectivity_evaluation(args):
    # create the evaluation directory if it does not exist
    if not os.path.isdir('../evaluation'):
        os.mkdir('../evaluation')

    # load the pretrained model weights
    weights = torch.load('../checkpoint/simple_cnn_{}.pth'.format(args.target))
    model = SimpleCNN(args.target)
    model.load_state_dict(weights['model'])

    # build the selectivity evaluation method
    selectivity_method = Selectivity(model=model,
                                     target=args.target,
                                     batch_size=args.batch_size,
                                     method=args.method,
                                     sample_pct=args.ratio)
    # run the evaluation
    selectivity_method.eval(args.steps)
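
A hypothetical driver for this function; all argument names mirror the args.* attributes used above, and every default value here is an assumption:

import argparse

parser = argparse.ArgumentParser(description='selectivity evaluation driver (sketch)')
parser.add_argument('--target', default='cifar10', type=str)
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--method', default='CAM', type=str)
parser.add_argument('--ratio', default=0.1, type=float)  # passed to Selectivity as sample_pct
parser.add_argument('--steps', default=10, type=int)
selectivity_evaluation(parser.parse_args())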
Example #2
def test():
    device = 'cpu'
    # load the pretrained model
    weights = torch.load(f"{baseAddr}/model.pth")
    model = SimpleCNN('cifar10', 'CAM').to(device)
    model.load_state_dict(weights['model'])

    # acquire a sample image (pickle5 backports pickle protocol 5;
    # install with `pip3 install pickle5` if needed)
    import pickle5 as pickle
    with open('images_sample.pickle', 'rb') as f:
        sample_image = pickle.load(f)

    # build the CAM attention extractor
    CAM_cifar10 = CAM(model)
    masked_imgs = attention_mask_filter(sample_image,
                                        showImage=0,
                                        payload={
                                            'model': CAM_cifar10,
                                            'mean': (0.4914, 0.4822, 0.4465),
                                            'std': (0.2023, 0.1994, 0.2010)
                                        })
    return masked_imgs
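
The mean/std entries in the payload are the standard CIFAR-10 normalization statistics; presumably attention_mask_filter uses them to map normalized tensors back to image space. A minimal sketch of that inverse transform (a hypothetical helper, not part of the snippet above):

import torch

def denormalize(img,
                mean=(0.4914, 0.4822, 0.4465),
                std=(0.2023, 0.1994, 0.2010)):
    # invert T.Normalize: x_norm = (x - mean) / std  =>  x = x_norm * std + mean
    mean = torch.tensor(mean).view(3, 1, 1)
    std = torch.tensor(std).view(3, 1, 1)
    return img * std + mean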
Example #4
def main():
    parser = argparse.ArgumentParser(description='Domain adaptation experiments with Office dataset.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-m', '--model', default='MODAFM', type=str, metavar='', help='model type (\'FS\' / \'DANNS\' / \'DANNM\' / \'MDAN\' / \'MODA\' / \'FM\' / \'MODAFM\')')
    parser.add_argument('-d', '--data_path', default='/ctm-hdd-pool01/DB/OfficeRsz', type=str, metavar='', help='data directory path')
    parser.add_argument('-t', '--target', default='amazon', type=str, metavar='', help='target domain (\'amazon\' / \'dslr\' / \'webcam\')')
    parser.add_argument('-i', '--input', default='msda.pth', type=str, metavar='', help='model file (output of train)')
    parser.add_argument('--arch', default='resnet50', type=str, metavar='', help='network architecture (\'resnet50\' / \'alexnet\')')
    parser.add_argument('--batch_size', default=20, type=int, metavar='', help='batch size (per domain)')
    parser.add_argument('--use_cuda', default=True, type=int, metavar='', help='use CUDA capable GPU')
    parser.add_argument('--verbosity', default=2, type=int, metavar='', help='log verbosity level (0, 1, 2)')
    parser.add_argument('--seed', default=42, type=int, metavar='', help='random seed')
    args = vars(parser.parse_args())
    cfg = args.copy()

    # use a fixed random seed for reproducibility purposes
    if cfg['seed'] > 0:
        random.seed(cfg['seed'])
        np.random.seed(seed=cfg['seed'])
        torch.manual_seed(cfg['seed'])
        torch.cuda.manual_seed(cfg['seed'])
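        # note: bit-exact reproducibility on GPU may additionally require
        # torch.backends.cudnn.deterministic = True (and cudnn.benchmark = False)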

    device = 'cuda' if (cfg['use_cuda'] and torch.cuda.is_available()) else 'cpu'
    log = Logger(cfg['verbosity'])
    log.print('device:', device, level=0)

    # normalization transformation (required for pretrained networks)
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
    if 'FM' in cfg['model']:
        transform = T.ToTensor()
    else:
        transform = T.Compose([
            T.ToTensor(),
            normalize,
        ])

    domains = ['amazon', 'dslr', 'webcam']
    datasets = {domain: Office(cfg['data_path'], domain=domain, transform=transform) for domain in domains}
    n_classes = len(datasets[cfg['target']].class_names)

    if 'FM' in cfg['model']:
        test_set = deepcopy(datasets[cfg['target']])
        test_set.transform = T.ToTensor()  # no data augmentation in test
    else:
        test_set = datasets[cfg['target']]

    if cfg['model'] != 'FS':
        test_loader = {'target pub': DataLoader(test_set, batch_size=3*cfg['batch_size'])}
    else:
        train_indices = random.sample(range(len(datasets[cfg['target']])), int(0.8*len(datasets[cfg['target']])))
        test_indices = list(set(range(len(datasets[cfg['target']]))) - set(train_indices))
        test_loader = {'target pub': DataLoader(
            datasets[cfg['target']],
            batch_size=cfg['batch_size'],
            sampler=SubsetRandomSampler(test_indices))}
    log.print('target domain:', cfg['target'], level=1)

    if cfg['model'] in ['FS', 'FM']:
        model = SimpleCNN(n_classes=n_classes, arch=cfg['arch']).to(device)
    elif cfg['model'] == 'MDAN':
        model = MDANet(n_classes=n_classes, n_domains=len(domains)-1, arch=cfg['arch']).to(device)
    elif cfg['model'] in ['DANNS', 'DANNM', 'MODA', 'MODAFM']:
        model = MODANet(n_classes=n_classes, arch=cfg['arch']).to(device)
    else:
        raise ValueError('Unknown model {}'.format(cfg['model']))

    if cfg['model'] != 'DANNS':
        model.load_state_dict(torch.load(cfg['input']))
        accuracies, losses = test_routine(model, test_loader, cfg)
        print('target pub: acc = {:.3f},'.format(accuracies['target pub']), 'loss = {:.3f}'.format(losses['target pub']))

    else:  # for DANNS, report results for the best source domain
        src_domains = ['amazon', 'dslr', 'webcam']
        src_domains.remove(cfg['target'])
        for i, src in enumerate(src_domains):
            model.load_state_dict(torch.load(cfg['input']+'_'+src))
            acc, loss = test_routine(model, test_loader, cfg)
            if i == 0:
                accuracies = acc
                losses = loss
            else:
                for key in accuracies.keys():
                    # keep the accuracy and loss of the best source domain;
                    # test the condition before overwriting the stored accuracy
                    if acc[key] > accuracies[key]:
                        accuracies[key] = acc[key]
                        losses[key] = loss[key]
        log.print('target pub: acc = {:.3f},'.format(accuracies['target pub']), 'loss = {:.3f}'.format(losses['target pub']), level=1)
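
A hypothetical invocation of this script (the file name test_office.py is an assumption):

# python test_office.py -m MODAFM -d /ctm-hdd-pool01/DB/OfficeRsz -t amazon -i msda.pth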
Example #5
    # TODO: Tensorboard Check

    # usage: python main.py --train --target={mnist|cifar10} --attention={CAM|CBAM|RAN|WARN}
    if args.train:
        main(args=args)

    elif args.eval == 'selectivity':
        # make the evaluation directory
        if not os.path.isdir('../evaluation'):
            os.mkdir('../evaluation')

        # pretrained model load
        weights = torch.load('../checkpoint/simple_cnn_{}.pth'.format(
            args.target))
        model = SimpleCNN(args.target)
        model.load_state_dict(weights['model'])

        # selectivity evaluation
        selectivity_method = Selectivity(model=model,
                                         target=args.target,
                                         batch_size=args.batch_size,
                                         method=args.method,
                                         sample_pct=args.ratio)
        # evaluation
        selectivity_method.eval(steps=args.steps, save_dir='../evaluation')

    elif args.eval in ('ROAR', 'KAR'):
        # ratios to sweep for ROAR/KAR (zero excluded);
        # e.g. args.ratio = 0.1 gives [0.1, 0.2, ..., 0.9]
        ratio_lst = np.arange(0, 1, args.ratio)[1:]
        for ratio in ratio_lst:
            main(args=args, ratio=ratio)
Example #6
    # BCEWithLogitsLoss supports detecting multiple events concurrently
    # (multi-label targets, relevant once we have more labels)
    criterion = nn.BCEWithLogitsLoss()

    # Adam is a good default choice for most cases
    optimizer = optim.Adam(model.parameters(), lr=args.lr, amsgrad=True)
    # optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)

    start_epoch = 0

    # optionally, resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint["epoch"]
            model.load_state_dict(checkpoint["state_dict"])
            optimizer.load_state_dict(checkpoint["optimizer"])
            print("loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint["epoch"]))
        else:
            print("no checkpoint found at '{}'".format(args.resume))

    # resume from start_epoch if a checkpoint was loaded
    for epoch in range(start_epoch, args.epochs):
        # after linear_decrease_start_epoch, decay the learning rate
        # linearly from args.lr down to 0 at the final epoch
        if epoch > args.linear_decrease_start_epoch:
            for g in optimizer.param_groups:
                g["lr"] = args.lr - args.lr * (
                    epoch - args.linear_decrease_start_epoch) / (
                        args.epochs - args.linear_decrease_start_epoch)

        tqdm.write(str(epoch))
        tqdm.write("Training")
Example #7
def main():
    parser = argparse.ArgumentParser(
        description='Domain adaptation experiments with digits datasets.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '-m',
        '--model',
        default='MODAFM',
        type=str,
        metavar='',
        help=
        'model type (\'FS\' / \'DANNS\' / \'DANNM\' / \'MDAN\' / \'MODA\' / \'FM\' / \'MODAFM\')'
    )
    parser.add_argument('-d',
                        '--data_path',
                        default='/ctm-hdd-pool01/DB/',
                        type=str,
                        metavar='',
                        help='data directory path')
    parser.add_argument(
        '-t',
        '--target',
        default='MNIST',
        type=str,
        metavar='',
        help=
        'target domain (\'MNIST\' / \'MNIST_M\' / \'SVHN\' / \'SynthDigits\')')
    parser.add_argument('-i',
                        '--input',
                        default='msda.pth',
                        type=str,
                        metavar='',
                        help='model file (output of train)')
    parser.add_argument('--n_images',
                        default=20000,
                        type=int,
                        metavar='',
                        help='number of images from each domain')
    parser.add_argument('--batch_size',
                        default=8,
                        type=int,
                        metavar='',
                        help='batch size')
    parser.add_argument('--use_cuda',
                        default=True,
                        type=int,
                        metavar='',
                        help='use CUDA capable GPU')
    parser.add_argument('--verbosity',
                        default=2,
                        type=int,
                        metavar='',
                        help='log verbosity level (0, 1, 2)')
    parser.add_argument('--seed',
                        default=42,
                        type=int,
                        metavar='',
                        help='random seed')
    args = vars(parser.parse_args())
    cfg = args.copy()

    # use a fixed random seed for reproducibility purposes
    if cfg['seed'] > 0:
        random.seed(cfg['seed'])
        np.random.seed(seed=cfg['seed'])
        torch.manual_seed(cfg['seed'])
        torch.cuda.manual_seed(cfg['seed'])

    device = 'cuda' if (cfg['use_cuda']
                        and torch.cuda.is_available()) else 'cpu'
    log = Logger(cfg['verbosity'])
    log.print('device:', device, level=0)

    # define all datasets
    datasets = {}
    datasets['MNIST'] = MNIST(train=True,
                              path=os.path.join(cfg['data_path'], 'MNIST'),
                              transform=T.ToTensor())
    datasets['MNIST_M'] = MNIST_M(train=True,
                                  path=os.path.join(cfg['data_path'],
                                                    'MNIST_M'),
                                  transform=T.ToTensor())
    datasets['SVHN'] = SVHN(train=True,
                            path=os.path.join(cfg['data_path'], 'SVHN'),
                            transform=T.ToTensor())
    datasets['SynthDigits'] = SynthDigits(train=True,
                                          path=os.path.join(
                                              cfg['data_path'], 'SynthDigits'),
                                          transform=T.ToTensor())
    test_set = datasets[cfg['target']]

    # get a subset of cfg['n_images'] from each dataset
    # define public and private test sets: the private is not used at training time to learn invariant representations
    for ds_name in datasets:
        if ds_name == cfg['target']:
            indices = random.sample(range(len(datasets[ds_name])),
                                    2 * cfg['n_images'])
            test_pub_set = Subset(test_set, indices[0:cfg['n_images']])
            test_priv_set = Subset(test_set, indices[cfg['n_images']:])
        else:
            indices = random.sample(range(len(datasets[ds_name])),
                                    cfg['n_images'])
        datasets[ds_name] = Subset(datasets[ds_name],
                                   indices[0:cfg['n_images']])

    # build the dataloader
    test_pub_loader = DataLoader(test_pub_set,
                                 batch_size=4 * cfg['batch_size'])
    test_priv_loader = DataLoader(test_priv_set,
                                  batch_size=4 * cfg['batch_size'])
    test_loaders = {
        'target pub': test_pub_loader,
        'target priv': test_priv_loader
    }
    log.print('target domain:', cfg['target'], level=0)

    if cfg['model'] in ['FS', 'FM']:
        model = SimpleCNN().to(device)
    elif cfg['model'] == 'MDAN':
        model = MDANet(len(datasets) - 1).to(device)
    elif cfg['model'] in ['DANNS', 'DANNM', 'MODA', 'MODAFM']:
        model = MODANet().to(device)
    else:
        raise ValueError('Unknown model {}'.format(cfg['model']))

    if cfg['model'] != 'DANNS':
        model.load_state_dict(torch.load(cfg['input']))
        accuracies, losses = test_routine(model, test_loaders, cfg)
        print('target pub: acc = {:.3f},'.format(accuracies['target pub']),
              'loss = {:.3f}'.format(losses['target pub']))
        print('target priv: acc = {:.3f},'.format(accuracies['target priv']),
              'loss = {:.3f}'.format(losses['target priv']))

    else:  # for DANNS, report results for the best source domain
        src_domains = ['MNIST', 'MNIST_M', 'SVHN', 'SynthDigits']
        src_domains.remove(cfg['target'])
        for i, src in enumerate(src_domains):
            model.load_state_dict(torch.load(cfg['input'] + '_' + src))
            acc, loss = test_routine(model, test_loaders, cfg)
            if i == 0:
                accuracies = acc
                losses = loss
            else:
                for key in accuracies.keys():
                    # keep the accuracy and loss of the best source domain;
                    # test the condition before overwriting the stored accuracy
                    if acc[key] > accuracies[key]:
                        accuracies[key] = acc[key]
                        losses[key] = loss[key]
        log.print('target pub: acc = {:.3f},'.format(accuracies['target pub']),
                  'loss = {:.3f}'.format(losses['target pub']),
                  level=1)
        log.print('target priv: acc = {:.3f},'.format(
            accuracies['target priv']),
                  'loss = {:.3f}'.format(losses['target priv']),
                  level=1)