示例#1
0
def selecticity_evaluation(args):
    """Run the selectivity evaluation for a pretrained ``SimpleCNN``.

    Creates ``../evaluation`` when it is missing, restores the
    ``../checkpoint/simple_cnn_<target>.pth`` checkpoint (state dict stored
    under the ``'model'`` key) and hands the model to ``Selectivity``.

    Args:
        args: parsed command-line namespace; reads ``target``, ``batch_size``,
            ``method``, ``ratio`` and ``steps``.
    """
    eval_dir = '../evaluation'
    if not os.path.isdir(eval_dir):
        os.mkdir(eval_dir)

    # restore the pretrained classifier
    checkpoint_path = f'../checkpoint/simple_cnn_{args.target}.pth'
    checkpoint = torch.load(checkpoint_path)
    net = SimpleCNN(args.target)
    net.load_state_dict(checkpoint['model'])

    evaluator = Selectivity(
        model=net,
        target=args.target,
        batch_size=args.batch_size,
        method=args.method,
        sample_pct=args.ratio,
    )
    evaluator.eval(args.steps)
示例#2
0
def train(seq_len,
          window_size,
          model_type,
          params,
          batch_size=16,
          num_epochs=20,
          print_every=5):
    """Train an LSTM or CNN classifier and keep the best model by validation F2.

    Every ``print_every`` iterations (counted within an epoch) the model is
    scored on the validation and training sets with ``check_accuracy``. A
    tab-separated metrics row is then recorded. Whenever the validation F2
    improves, a deep copy of the model is kept.

    Args:
        seq_len: sequence length passed to ``load_pytorch_data``; the CNN
            input size is ``int(HZ * seq_len)``.
        window_size: window size passed to ``load_pytorch_data``.
        model_type: ``'LSTM'`` or ``'CNN'``.
        params: hyperparameter dict; reads ``'lr'``, ``'loss_weights'`` and
            ``'lstm_hidden_size'`` or ``'cnn_hidden_size'``.
        batch_size: not used by this function (batching is whatever
            ``load_pytorch_data`` returns); kept for interface compatibility.
        num_epochs: number of passes over the training data.
        print_every: evaluation interval in iterations.

    Returns:
        ``(best_model, max_val_f2_score, metrics)``. ``metrics`` is a list of
        tab-separated strings with the columns: train loss, val F1, val F2,
        val precision, val recall, val accuracy, train F1, train F2,
        train precision, train recall, train accuracy. ``best_model`` is
        ``None`` if the validation F2 never exceeded 0.

    Raises:
        ValueError: if ``model_type`` is neither ``'LSTM'`` nor ``'CNN'``.
    """
    metrics = []
    max_val_f2_score = 0.
    best_model = None

    train_data, validation_data = load_pytorch_data(seq_len, window_size)
    if model_type == 'LSTM':
        model = SimpleLSTM(INPUT_SIZE, params['lstm_hidden_size'])
    elif model_type == 'CNN':
        model = SimpleCNN(int(HZ * seq_len), params['cnn_hidden_size'])
    else:
        raise ValueError('invalid model type')
    optimizer = torch.optim.Adam(model.parameters(), lr=params['lr'])
    # CrossEntropyLoss needs a floating-point weight tensor; an integer list
    # such as [1, 5] would otherwise produce an int64 tensor and fail.
    criterion = torch.nn.CrossEntropyLoss(
        weight=torch.tensor(params['loss_weights'], dtype=torch.float32))
    print('starting training!')
    for epoch in range(num_epochs):
        print('starting epoch {}...'.format(epoch))
        for step, (X_batch, y_batch, _idx) in enumerate(train_data):
            X_batch = X_batch.float()
            y_batch = y_batch.long()
            output = model(X_batch)
            # removes dimension 0 only when it has size 1
            output = torch.squeeze(output, 0)
            loss = criterion(output, y_batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % print_every == 0:
                f1_val, f2_val, precision_val, recall_val, accuracy_val = check_accuracy(
                    model, validation_data, False)
                f1_train, f2_train, precision_train, recall_train, accuracy_train = check_accuracy(
                    model, train_data, False)
                train_loss = loss.item()
                # column order: loss, 5 validation metrics, 5 training metrics
                metrics.append(
                    '{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}'.format(
                        train_loss, f1_val, f2_val, precision_val,
                        recall_val, accuracy_val, f1_train, f2_train,
                        precision_train, recall_train, accuracy_train))

                if f2_val > max_val_f2_score:
                    max_val_f2_score = f2_val
                    best_model = copy.deepcopy(model)

    print('finished training!')
    return best_model, max_val_f2_score, metrics
示例#3
0
def test():
    """Build attention-masked versions of a pickled sample-image batch.

    Loads the ``SimpleCNN('cifar10', 'CAM')`` checkpoint from
    ``{baseAddr}/model.pth`` onto the CPU. It then reads sample images from
    ``images_sample.pickle`` in the working directory. Finally it wraps the
    model in ``CAM`` and passes both to ``attention_mask_filter``, together
    with the per-channel mean/std values given below.

    Returns:
        Whatever ``attention_mask_filter`` returns for the sample images.
    """
    device = 'cpu'
    # map_location lets a checkpoint saved from a GPU run be restored on the
    # CPU device the model is placed on below
    weights = torch.load(f"{baseAddr}/model.pth", map_location=device)
    model = SimpleCNN('cifar10', 'CAM').to(device)
    model.load_state_dict(weights['model'])

    # acquire sample image (requires `pip3 install pickle5`)
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    import pickle5 as pickle
    with open('images_sample.pickle', 'rb') as file:
        sample_image = pickle.load(file)

    CAM_cifar10 = CAM(model)
    masked_imgs = attention_mask_filter(sample_image,
                                        showImage=0,
                                        payload={
                                            'model': CAM_cifar10,
                                            'mean': (0.4914, 0.4822, 0.4465),
                                            'std': (0.2023, 0.1994, 0.2010)
                                        })
    return masked_imgs
            acc.update((output, targets))

    acc_value = acc.compute()
    test_loss /= len(test_loader.sampler)
    writer.add_scalar('Test Loss', test_loss, int((epoch + 1)))
    writer.add_scalar('Test Acc', acc_value, int((epoch + 1)))

    print('Test set: Average loss: {:.4f}, Accuracy: ({:.0f}%)\n'.format(
        test_loss, acc_value * 100))


if __name__ == "__main__":
    writer = SummaryWriter(tmp_dir.joinpath(args.run_id, 'log'))

    model = SimpleCNN(num_targets=len(train_dataset.classes),
                      num_channels=num_channels)
    model = model.to(device)

    # we choose binary cross entropy loss with logits (i.e. sigmoid applied before calculating loss)
    # because we want to detect multiple events concurrently (at least later on, when we have more labels)
    criterion = nn.BCEWithLogitsLoss()

    # for most cases adam is a good choice
    optimizer = optim.Adam(model.parameters(), lr=args.lr, amsgrad=True)
    # optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)

    start_epoch = 0

    # optionally, resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
示例#5
0
文件: train.py 项目: dpernes/modafm
def main():
    """Train a (multi-source) domain adaptation model on the digits datasets.

    Steps:
      1. Parse the command-line options. If ``--icfg`` is given, the
         ``[main]`` section of that config file overrides them.
      2. Write the effective config to ``<output>.txt``.
      3. Build random subsets of MNIST / MNIST_M / SVHN / SynthDigits.
      4. Train the model selected by ``--model``.
      5. Save its state dict to ``--output``. For DANNS, one file per
         source is written instead, named ``<output>_<source>``.

    Raises:
        ValueError: if ``--model`` is not one of the supported model types.
    """
    parser = argparse.ArgumentParser(
        description='Domain adaptation experiments with digits datasets.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '-m',
        '--model',
        default='MODAFM',
        type=str,
        metavar='',
        help=
        'model type (\'FS\' / \'DANNS\' / \'DANNM\' / \'MDAN\' / \'MDANU\' / \'MDANFM\' / \'MDANUFM\' / \'MODA\' / \'FM\' / \'MODAFM\')'
    )
    parser.add_argument('-d',
                        '--data_path',
                        default='/ctm-hdd-pool01/DB/',
                        type=str,
                        metavar='',
                        help='data directory path')
    parser.add_argument(
        '-t',
        '--target',
        default='MNIST',
        type=str,
        metavar='',
        help=
        'target domain (\'MNIST\' / \'MNIST_M\' / \'SVHN\' / \'SynthDigits\')')
    parser.add_argument('-o',
                        '--output',
                        default='msda.pth',
                        type=str,
                        metavar='',
                        help='model file (output of train)')
    parser.add_argument('--icfg',
                        default=None,
                        type=str,
                        metavar='',
                        help='config file (overrides args)')
    parser.add_argument('--n_src_images',
                        default=20000,
                        type=int,
                        metavar='',
                        help='number of images from each source domain')
    parser.add_argument('--n_tgt_images',
                        default=20000,
                        type=int,
                        metavar='',
                        help='number of images from the target domain')
    parser.add_argument(
        '--mu_d',
        type=float,
        default=1e-2,
        help=
        "hyperparameter of the coefficient for the domain discriminator loss")
    parser.add_argument(
        '--mu_s',
        type=float,
        default=0.2,
        help="hyperparameter of the non-sparsity regularization")
    parser.add_argument('--mu_c',
                        type=float,
                        default=1e-1,
                        help="hyperparameter of the FixMatch loss")
    parser.add_argument('--n_rand_aug',
                        type=int,
                        default=2,
                        help="N parameter of RandAugment")
    parser.add_argument('--m_min_rand_aug',
                        type=int,
                        default=3,
                        help="minimum M parameter of RandAugment")
    parser.add_argument('--m_max_rand_aug',
                        type=int,
                        default=10,
                        help="maximum M parameter of RandAugment")
    parser.add_argument('--weight_decay',
                        default=0.,
                        type=float,
                        metavar='',
                        help='hyperparameter of weight decay regularization')
    parser.add_argument('--lr',
                        default=1e-1,
                        type=float,
                        metavar='',
                        help='learning rate')
    parser.add_argument('--epochs',
                        default=30,
                        type=int,
                        metavar='',
                        help='number of training epochs')
    parser.add_argument('--batch_size',
                        default=8,
                        type=int,
                        metavar='',
                        help='batch size (per domain)')
    parser.add_argument(
        '--checkpoint',
        default=0,
        type=int,
        metavar='',
        help=
        'number of epochs between saving checkpoints (0 disables checkpoints)')
    parser.add_argument('--eval_target',
                        default=False,
                        type=int,
                        metavar='',
                        help='evaluate target during training')
    parser.add_argument('--use_cuda',
                        default=True,
                        type=int,
                        metavar='',
                        help='use CUDA capable GPU')
    parser.add_argument('--use_visdom',
                        default=False,
                        type=int,
                        metavar='',
                        help='use Visdom to visualize plots')
    parser.add_argument('--visdom_env',
                        default='digits_train',
                        type=str,
                        metavar='',
                        help='Visdom environment name')
    parser.add_argument('--visdom_port',
                        default=8888,
                        type=int,
                        metavar='',
                        help='Visdom port')
    parser.add_argument('--verbosity',
                        default=2,
                        type=int,
                        metavar='',
                        help='log verbosity level (0, 1, 2)')
    parser.add_argument('--seed',
                        default=42,
                        type=int,
                        metavar='',
                        help='random seed')
    args = vars(parser.parse_args())

    # override args with icfg (if provided); each value is parsed with
    # ast.literal_eval, so string values in the file must be quoted
    cfg = args.copy()
    if cfg['icfg'] is not None:
        cv_parser = ConfigParser()
        cv_parser.read(cfg['icfg'])
        for key, val in cv_parser.items('main'):
            cfg[key] = ast.literal_eval(val)

    # dump cfg to a txt file for your records
    with open(cfg['output'] + '.txt', 'w') as f:
        f.write(str(cfg) + '\n')

    # use a fixed random seed for reproducibility purposes
    if cfg['seed'] > 0:
        random.seed(cfg['seed'])
        np.random.seed(seed=cfg['seed'])
        torch.manual_seed(cfg['seed'])
        torch.cuda.manual_seed(cfg['seed'])

    device = 'cuda' if (cfg['use_cuda']
                        and torch.cuda.is_available()) else 'cpu'
    log = Logger(cfg['verbosity'])
    log.print('device:', device, level=0)

    # substring test: every model name containing 'FS' or 'FM'
    # (e.g. 'MDANFM', 'MODAFM') gets the weak augmentation
    if ('FS' in cfg['model']) or ('FM' in cfg['model']):
        # weak data augmentation (small rotation + small translation)
        data_aug = T.Compose([
            T.RandomAffine(5, translate=(0.125, 0.125)),
            T.ToTensor(),
        ])
    else:
        data_aug = T.ToTensor()

    # define all datasets
    datasets = {}
    datasets['MNIST'] = MNIST(train=True,
                              path=os.path.join(cfg['data_path'], 'MNIST'),
                              transform=data_aug)
    datasets['MNIST_M'] = MNIST_M(train=True,
                                  path=os.path.join(cfg['data_path'],
                                                    'MNIST_M'),
                                  transform=data_aug)
    datasets['SVHN'] = SVHN(train=True,
                            path=os.path.join(cfg['data_path'], 'SVHN'),
                            transform=data_aug)
    datasets['SynthDigits'] = SynthDigits(train=True,
                                          path=os.path.join(
                                              cfg['data_path'], 'SynthDigits'),
                                          transform=data_aug)
    if ('FS' in cfg['model']) or ('FM' in cfg['model']):
        test_set = deepcopy(datasets[cfg['target']])
        test_set.transform = T.ToTensor()  # no data augmentation in test
    else:
        test_set = datasets[cfg['target']]

    # Subsample every dataset: n_src_images per source domain and
    # n_tgt_images + n_src_images for the target domain.
    # The target sample is split into a public and a private test set. The
    # private one is not used at training time to learn invariant
    # representations. The target training subset uses the same indices as
    # the public test set.
    for ds_name in datasets:
        if ds_name == cfg['target']:
            indices = random.sample(range(len(datasets[ds_name])),
                                    cfg['n_tgt_images'] + cfg['n_src_images'])
            test_pub_set = Subset(test_set, indices[0:cfg['n_tgt_images']])
            test_priv_set = Subset(test_set, indices[cfg['n_tgt_images']::])
            datasets[cfg['target']] = Subset(datasets[cfg['target']],
                                             indices[0:cfg['n_tgt_images']])
        else:
            indices = random.sample(range(len(datasets[ds_name])),
                                    cfg['n_src_images'])
            datasets[ds_name] = Subset(datasets[ds_name],
                                       indices[0:cfg['n_src_images']])

    # build the dataloader
    train_loader = MSDA_Loader(datasets,
                               cfg['target'],
                               batch_size=cfg['batch_size'],
                               shuffle=True,
                               device=device)
    test_pub_loader = DataLoader(test_pub_set,
                                 batch_size=4 * cfg['batch_size'])
    test_priv_loader = DataLoader(test_priv_set,
                                  batch_size=4 * cfg['batch_size'])
    valid_loaders = ({
        'target pub': test_pub_loader,
        'target priv': test_priv_loader
    } if cfg['eval_target'] else None)
    log.print('target domain:',
              cfg['target'],
              '| source domains:',
              train_loader.sources,
              level=1)

    if cfg['model'] == 'FS':
        model = SimpleCNN().to(device)
        optimizer = optim.Adadelta(model.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        # FS trains directly on the public target subset, so that subset is
        # dropped from the validation loaders
        if valid_loaders is not None:
            del valid_loaders['target pub']
        fs_train_routine(model, optimizer, test_pub_loader, valid_loaders, cfg)
    elif cfg['model'] == 'FM':
        model = SimpleCNN().to(device)
        optimizer = optim.Adadelta(model.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        cfg['excl_transf'] = [Flip]
        fm_train_routine(model, optimizer, train_loader, valid_loaders, cfg)
    elif cfg['model'] == 'DANNS':
        # one model per (single source, target) pair, each saved separately
        for src in train_loader.sources:
            model = MODANet().to(device)
            optimizer = optim.Adadelta(model.parameters(),
                                       lr=cfg['lr'],
                                       weight_decay=cfg['weight_decay'])
            dataset_ss = {
                src: datasets[src],
                cfg['target']: datasets[cfg['target']]
            }
            train_loader = MSDA_Loader(dataset_ss,
                                       cfg['target'],
                                       batch_size=cfg['batch_size'],
                                       shuffle=True,
                                       device=device)
            dann_train_routine(model, optimizer, train_loader, valid_loaders,
                               cfg)
            torch.save(model.state_dict(), cfg['output'] + '_' + src)
    elif cfg['model'] == 'DANNM':
        model = MODANet().to(device)
        optimizer = optim.Adadelta(model.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        dann_train_routine(model, optimizer, train_loader, valid_loaders, cfg)
    elif cfg['model'] == 'MDAN':
        model = MDANet(len(train_loader.sources)).to(device)
        optimizer = optim.Adadelta(model.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        mdan_train_routine(model, optimizer, train_loader, valid_loaders, cfg)
    elif cfg['model'] == 'MDANU':
        model = MDANet(len(train_loader.sources)).to(device)
        model.grad_reverse = nn.ModuleList([
            nn.Identity() for _ in range(len(model.domain_class))
        ])  # remove grad reverse
        # separate optimizers: feature extractor + task classifier vs.
        # domain classifiers
        task_optim = optim.Adadelta(list(model.feat_ext.parameters()) +
                                    list(model.task_class.parameters()),
                                    lr=cfg['lr'],
                                    weight_decay=cfg['weight_decay'])
        adv_optim = optim.Adadelta(model.domain_class.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        optimizers = (task_optim, adv_optim)
        mdan_unif_train_routine(model, optimizers, train_loader, valid_loaders,
                                cfg)
    elif cfg['model'] == 'MDANFM':
        model = MDANet(len(train_loader.sources)).to(device)
        optimizer = optim.Adadelta(model.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        mdan_fm_train_routine(model, optimizer, train_loader, valid_loaders,
                              cfg)
    elif cfg['model'] == 'MDANUFM':
        model = MDANet(len(train_loader.sources)).to(device)
        task_optim = optim.Adadelta(list(model.feat_ext.parameters()) +
                                    list(model.task_class.parameters()),
                                    lr=cfg['lr'],
                                    weight_decay=cfg['weight_decay'])
        adv_optim = optim.Adadelta(model.domain_class.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        optimizers = (task_optim, adv_optim)
        cfg['excl_transf'] = [Flip]
        # pass the optimizer pair built above (no `optimizer` exists here)
        mdan_unif_fm_train_routine(model, optimizers, train_loader,
                                   valid_loaders, cfg)
    elif cfg['model'] == 'MODA':
        model = MODANet().to(device)
        optimizer = optim.Adadelta(model.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        moda_train_routine(model, optimizer, train_loader, valid_loaders, cfg)
    elif cfg['model'] == 'MODAFM':
        model = MODANet().to(device)
        optimizer = optim.Adadelta(model.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        cfg['excl_transf'] = [Flip]
        moda_fm_train_routine(model, optimizer, train_loader, valid_loaders,
                              cfg)
    else:
        raise ValueError('Unknown model {}'.format(cfg['model']))

    # DANNS already saved one file per source inside its loop
    if cfg['model'] != 'DANNS':
        torch.save(model.state_dict(), cfg['output'])
示例#6
0
def main():
    """Evaluate a trained domain adaptation model on an Office target domain.

    Loads the state dict given by ``--input`` and reports accuracy and loss
    on the target domain. For ``DANNS``, one state dict per source domain
    (``<input>_<source>``) is evaluated. For each metric key, the best
    accuracy across sources is reported together with the loss of that
    same source model.

    Raises:
        ValueError: if ``--model`` is not one of the supported model types.
    """
    parser = argparse.ArgumentParser(description='Domain adaptation experiments with Office dataset.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-m', '--model', default='MODAFM', type=str, metavar='', help='model type (\'FS\' / \'DANNS\' / \'DANNM\' / \'MDAN\' / \'MODA\' / \'FM\' / \'MODAFM\'')
    parser.add_argument('-d', '--data_path', default='/ctm-hdd-pool01/DB/OfficeRsz', type=str, metavar='', help='data directory path')
    parser.add_argument('-t', '--target', default='amazon', type=str, metavar='', help='target domain (\'amazon\' / \'dslr\' / \'webcam\')')
    parser.add_argument('-i', '--input', default='msda.pth', type=str, metavar='', help='model file (output of train)')
    parser.add_argument('--arch', default='resnet50', type=str, metavar='', help='network architecture (\'resnet50\' / \'alexnet\'')
    parser.add_argument('--batch_size', default=20, type=int, metavar='', help='batch size (per domain)')
    parser.add_argument('--use_cuda', default=True, type=int, metavar='', help='use CUDA capable GPU')
    parser.add_argument('--verbosity', default=2, type=int, metavar='', help='log verbosity level (0, 1, 2)')
    parser.add_argument('--seed', default=42, type=int, metavar='', help='random seed')
    args = vars(parser.parse_args())
    cfg = args.copy()

    # use a fixed random seed for reproducibility purposes
    if cfg['seed'] > 0:
        random.seed(cfg['seed'])
        np.random.seed(seed=cfg['seed'])
        torch.manual_seed(cfg['seed'])
        torch.cuda.manual_seed(cfg['seed'])

    device = 'cuda' if (cfg['use_cuda'] and torch.cuda.is_available()) else 'cpu'
    log = Logger(cfg['verbosity'])
    log.print('device:', device, level=0)

    # normalization transformation (required for pretrained networks)
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
    if 'FM' in cfg['model']:
        transform = T.ToTensor()
    else:
        transform = T.Compose([
            T.ToTensor(),
            normalize,
        ])

    domains = ['amazon', 'dslr', 'webcam']
    datasets = {domain: Office(cfg['data_path'], domain=domain, transform=transform) for domain in domains}
    n_classes = len(datasets[cfg['target']].class_names)

    if 'FM' in cfg['model']:
        test_set = deepcopy(datasets[cfg['target']])
        test_set.transform = T.ToTensor()  # no data augmentation in test
    else:
        test_set = datasets[cfg['target']]

    if cfg['model'] != 'FS':
        test_loader = {'target pub': DataLoader(test_set, batch_size=3*cfg['batch_size'])}
    else:
        # FS is evaluated only on the target images outside a random 80% draw.
        # NOTE(review): presumably this reproduces the training split via the
        # fixed seed -- confirm against the training script.
        n_target = len(datasets[cfg['target']])
        train_indices = random.sample(range(n_target), int(0.8*n_target))
        test_indices = list(set(range(n_target)) - set(train_indices))
        test_loader = {'target pub': DataLoader(
            datasets[cfg['target']],
            batch_size=cfg['batch_size'],
            sampler=SubsetRandomSampler(test_indices))}
    log.print('target domain:', cfg['target'], level=1)

    if cfg['model'] in ['FS', 'FM']:
        model = SimpleCNN(n_classes=n_classes, arch=cfg['arch']).to(device)
    elif cfg['model'] == 'MDAN':
        model = MDANet(n_classes=n_classes, n_domains=len(domains)-1, arch=cfg['arch']).to(device)
    elif cfg['model'] in ['DANNS', 'DANNM', 'MODA', 'MODAFM']:
        model = MODANet(n_classes=n_classes, arch=cfg['arch']).to(device)
    else:
        raise ValueError('Unknown model {}'.format(cfg['model']))

    # map_location lets GPU-saved weights load when running on the CPU
    if cfg['model'] != 'DANNS':
        model.load_state_dict(torch.load(cfg['input'], map_location=device))
        accuracies, losses = test_routine(model, test_loader, cfg)
        print('target pub: acc = {:.3f},'.format(accuracies['target pub']), 'loss = {:.3f}'.format(losses['target pub']))

    else:  # for DANNS, report results for the best source domain
        src_domains = [domain for domain in domains if domain != cfg['target']]
        accuracies, losses = None, None
        for src in src_domains:
            model.load_state_dict(torch.load(cfg['input']+'_'+src, map_location=device))
            acc, loss = test_routine(model, test_loader, cfg)
            if accuracies is None:
                accuracies, losses = acc, loss
                continue
            for key in accuracies:
                # compare before overwriting, so the reported loss comes from
                # the same source model as the best accuracy
                if acc[key] > accuracies[key]:
                    accuracies[key] = acc[key]
                    losses[key] = loss[key]
        log.print('target pub: acc = {:.3f},'.format(accuracies['target pub']), 'loss = {:.3f}'.format(losses['target pub']), level=1)
示例#7
0
def get_samples(target,
                nb_class=10,
                sample_index=0,
                attention=None,
                device='cpu'):
    '''
    Get samples: original images, their labels, preprocessed images, class names and a trained model

    args:
    - target: 'mnist' or 'cifar10'
    - nb_class: number of classes; one image is taken for each class index 0..nb_class-1
    - sample_index: which image (by position) to take within each class
    - attention: optional attention variant; 'CAM'/'CBAM' load a SimpleCNN with that attention,
      'RAN'/'WARN' load RAN / WideResNetAttention; anything else loads a SimpleCNN(target, attention)
    - device: device the model is moved to (also used as torch.load map_location)

    return:
    - original_images (numpy array): shape = (nb_class, 28, 28, 1) for mnist, (nb_class, 32, 32, 3) for cifar10
    - original_targets (torch tensor): labels of the selected images
    - pre_images (torch tensor): testset.transform applied to each image, shape = (nb_class, C, 28|32, 28|32)
    - target_classes (dictionary): keys = class index, values = class name
    - model (pytorch model): pretrained model loaded from ../checkpoint

    raises:
    - ValueError: if target is not 'mnist' or 'cifar10'
    '''
    if target == 'mnist':
        image_size = (28, 28, 1)
        _, _, testloader = mnist_load()
    elif target == 'cifar10':
        image_size = (32, 32, 3)
        _, _, testloader = cifar10_load()
    else:
        raise ValueError(
            "target must be 'mnist' or 'cifar10', got {!r}".format(target))
    testset = testloader.dataset

    # idx2class
    target_classes = {idx: name for name, idx in testset.class_to_idx.items()}

    # select the sample_index-th image of every class
    labels = np.array(testset.targets)
    idx_by_class = [
        np.where(labels == i)[0][sample_index] for i in range(nb_class)
    ]
    original_images = testset.data[idx_by_class]
    if not isinstance(original_images, np.ndarray):
        original_images = original_images.numpy()
    original_images = original_images.reshape((nb_class, ) + image_size)
    # select targets
    if isinstance(testset.targets, list):
        original_targets = torch.LongTensor(testset.targets)[idx_by_class]
    else:
        original_targets = testset.targets[idx_by_class]

    # model load
    filename = f'simple_cnn_{target}'
    if attention in ['CAM', 'CBAM']:
        filename += f'_{attention}'
    elif attention in ['RAN', 'WARN']:
        filename = f'{target}_{attention}'
    print('filename: ', filename)
    # map_location allows GPU-saved checkpoints to load on the requested device
    weights = torch.load(f'../checkpoint/{filename}.pth', map_location=device)

    if attention == 'RAN':
        model = RAN(target).to(device)
    elif attention == 'WARN':
        model = WideResNetAttention(target).to(device)
    else:
        model = SimpleCNN(target, attention).to(device)
    model.load_state_dict(weights['model'])

    # image preprocessing: (N, H, W, C) originals -> (N, C, H, W) tensor
    n_images, height, width, channels = original_images.shape
    pre_images = torch.zeros((n_images, channels, height, width))
    for i, image in enumerate(original_images):
        pre_images[i] = testset.transform(image)

    return original_images, original_targets, pre_images, target_classes, model
示例#8
0
def main(args, **kwargs):
    """Train and test a classifier, optionally on ROAR/KAR-modified data.

    The run is skipped via ``sys.exit()`` when
    ``../logs/<model_name>_logs.txt`` already exists. This check happens
    before any data is loaded.

    Otherwise the MNIST or CIFAR-10 loaders are built. When ``args.eval`` is
    ``'ROAR'`` or ``'KAR'``, the training loader is replaced by
    ``adjust_image(kwargs['ratio'], ...)``. That call uses the saliency maps
    read from ``../saliency_maps/[<target>]<method>[_<attention>]_train.hdf5``.

    The network is then trained with SGD (momentum 0.9, weight decay 5e-4)
    and cross-entropy loss, and evaluated with ``ModelTest``. ``../checkpoint``
    is passed to both ``ModelTrain`` and ``ModelTest`` as the save/load
    directory. The training history, including the test result, is dumped
    as JSON to the log file.

    Args:
        args: parsed command-line namespace; reads ``epochs``, ``batch_size``,
            ``valid_rate``, ``lr``, ``verbose``, ``target``, ``attention``,
            ``monitor``, ``mode``, ``train``, ``eval`` and, for ROAR/KAR,
            ``method``.
        **kwargs: ``ratio`` is required whenever ``args.eval`` is not None.

    Raises:
        ValueError: if ``args.target`` is not ``'mnist'`` or ``'cifar10'``.
    """
    #################################
    # Config
    #################################
    epochs = args.epochs
    batch_size = args.batch_size
    valid_rate = args.valid_rate
    lr = args.lr
    verbose = args.verbose

    # checkpoint
    target = args.target
    attention = args.attention
    monitor = args.monitor
    mode = args.mode

    # fail fast instead of hitting an undefined loader further down
    if target not in ('mnist', 'cifar10'):
        raise ValueError(
            "target must be 'mnist' or 'cifar10', got {!r}".format(target))

    # save name
    model_name = 'simple_cnn_{}'.format(target)
    if attention in ['CAM', 'CBAM']:
        model_name = model_name + '_{}'.format(attention)
    elif attention in ['RAN', 'WARN']:
        model_name = '{}_{}'.format(target, attention)

    # save directory
    savedir = '../checkpoint'
    logdir = '../logs'

    # device setting cpu or cuda(gpu)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print('=====Setting=====')
    print('Training: ', args.train)
    print('Epochs: ', epochs)
    print('Batch Size: ', batch_size)
    print('Validation Rate: ', valid_rate)
    print('Learning Rate: ', lr)
    print('Target: ', target)
    print('Monitor: ', monitor)
    print('Model Name: ', model_name)
    print('Mode: ', mode)
    print('Attention: ', attention)
    print('Save Directory: ', savedir)
    print('Log Directory: ', logdir)
    print('Device: ', device)
    print('Verbose: ', verbose)
    print()
    print('Evaluation: ', args.eval)
    if args.eval is not None:
        print('Pixel ratio: ', kwargs['ratio'])
    print()
    print('Setting Random Seed')
    print()
    seed_everything()  # seed setting

    remove_and_retrain = args.eval in ('ROAR', 'KAR')
    if remove_and_retrain:
        # ROAR/KAR runs get a distinct model (and log) name per method/ratio
        model_name = model_name + '_{0:}_{1:}{2:.1f}'.format(
            args.method, args.eval, kwargs['ratio'])

    # skip runs whose log already exists, before loading any data
    log_path = '{}/{}_logs.txt'.format(logdir, model_name)
    if os.path.isfile(log_path):
        print('Log file {} already exists; skipping.'.format(log_path))
        sys.exit()

    #################################
    # Data Load
    #################################
    print('=====Data Load=====')
    load_fn = mnist_load if target == 'mnist' else cifar10_load
    trainloader, validloader, testloader = load_fn(
        batch_size=batch_size, validation_rate=valid_rate, shuffle=True)

    #################################
    # ROAR or KAR
    #################################
    if remove_and_retrain:
        # saliency map load
        filename = f'../saliency_maps/[{args.target}]{args.method}'
        if attention in ['CBAM', 'RAN']:
            filename += f'_{attention}'
        # the context manager closes the file even if reading fails
        with h5py.File(f'{filename}_train.hdf5', 'r') as hf:
            sal_maps = np.array(hf['saliencys'])
        # adjust image
        trainloader = adjust_image(kwargs['ratio'], trainloader, sal_maps,
                                   args.eval)

    #################################
    # Load model
    #################################
    print('=====Model Load=====')
    if attention == 'RAN':
        net = RAN(target).to(device)
    elif attention == 'WARN':
        net = WideResNetAttention(target).to(device)
    else:
        net = SimpleCNN(target, attention).to(device)
    n_parameters = sum(np.prod(p.size()) for p in net.parameters())
    print('Total number of parameters:', n_parameters)
    print()

    # Model compile
    optimizer = optim.SGD(net.parameters(),
                          lr=lr,
                          momentum=0.9,
                          weight_decay=0.0005)
    criterion = nn.CrossEntropyLoss()

    #################################
    # Train
    #################################
    modeltrain = ModelTrain(model=net,
                            data=trainloader,
                            epochs=epochs,
                            criterion=criterion,
                            optimizer=optimizer,
                            device=device,
                            model_name=model_name,
                            savedir=savedir,
                            monitor=monitor,
                            mode=mode,
                            validation=validloader,
                            verbose=verbose)

    #################################
    # Test
    #################################
    modeltest = ModelTest(model=net,
                          data=testloader,
                          loaddir=savedir,
                          model_name=model_name,
                          device=device)

    modeltrain.history['test_result'] = modeltest.results

    # History save as json file
    os.makedirs(logdir, exist_ok=True)
    with open(log_path, 'w') as outfile:
        json.dump(modeltrain.history, outfile)
示例#9
0
    # TODO: Tensorboard Check

    # python main.py --train --target=['mnist','cifar10'] --attention=['CAM','CBAM','RAN','WARN']
    if args.train:
        main(args=args)

    elif args.eval == 'selectivity':
        # make evalutation directory
        if not os.path.isdir('../evaluation'):
            os.mkdir('../evaluation')

        # pretrained model load
        weights = torch.load('../checkpoint/simple_cnn_{}.pth'.format(
            args.target))
        model = SimpleCNN(args.target)
        model.load_state_dict(weights['model'])

        # selectivity evaluation
        selectivity_method = Selectivity(model=model,
                                         target=args.target,
                                         batch_size=args.batch_size,
                                         method=args.method,
                                         sample_pct=args.ratio)
        # evaluation
        selectivity_method.eval(steps=args.steps, save_dir='../evaluation')

    elif (args.eval == 'ROAR') or (args.eval == 'KAR'):
        # ratio
        ratio_lst = np.arange(0, 1, args.ratio)[1:]  # exclude zero
        for ratio in ratio_lst:
示例#10
0
文件: train.py 项目: dpernes/modafm
def _build_optimizer(model, cfg):
    """Create the two-group Adadelta optimizer used by every model type.

    Parameters whose name contains ``'fc'`` (case-insensitive) are trained
    with ``cfg['lr']``; every other parameter uses ``0.1 * cfg['lr']``.
    Both groups use ``cfg['weight_decay']``.
    """
    conv_params, fc_params = [], []
    for name, param in model.named_parameters():
        if 'fc' in name.lower():
            fc_params.append(param)
        else:
            conv_params.append(param)
    return optim.Adadelta([{
        'params': conv_params,
        'lr': 0.1 * cfg['lr'],
        'weight_decay': cfg['weight_decay']
    }, {
        'params': fc_params,
        'lr': cfg['lr'],
        'weight_decay': cfg['weight_decay']
    }])


def main():
    """Train a domain-adaptation model on DomainNet and save its weights.

    Options come from the command line and may be overridden by the
    ``[main]`` section of an INI file passed with ``--icfg`` (values are
    parsed with ``ast.literal_eval``). The resolved configuration is written
    to ``<output>.txt``; the final ``state_dict`` is saved to ``<output>``,
    and for ``DANNS`` each per-source model is also saved to
    ``<output>_<source>``.

    Raises:
        ValueError: if the configured model type is unknown.
    """
    parser = argparse.ArgumentParser(
        description='Domain adaptation experiments with the DomainNet dataset.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '-m',
        '--model',
        default='MODAFM',
        type=str,
        metavar='',
        help=
        'model type (\'FS\' / \'DANNS\' / \'DANNM\' / \'MDAN\' / \'MODA\' / \'FM\' / \'MODAFM\''
    )
    parser.add_argument('-d',
                        '--data_path',
                        default='/ctm-hdd-pool01/DB/DomainNet192',
                        type=str,
                        metavar='',
                        help='data directory path')
    parser.add_argument(
        '-t',
        '--target',
        default='clipart',
        type=str,
        metavar='',
        help=
        'target domain (\'clipart\' / \'infograph\' / \'painting\' / \'quickdraw\' / \'real\' / \'sketch\')'
    )
    parser.add_argument('-o',
                        '--output',
                        default='msda.pth',
                        type=str,
                        metavar='',
                        help='model file (output of train)')
    parser.add_argument('--icfg',
                        default=None,
                        type=str,
                        metavar='',
                        help='config file (overrides args)')
    parser.add_argument(
        '--arch',
        default='resnet152',
        type=str,
        metavar='',
        help='network architecture (\'resnet101\' / \'resnet152\'')
    parser.add_argument(
        '--mu_d',
        type=float,
        default=1e-2,
        help=
        "hyperparameter of the coefficient for the domain discriminator loss")
    parser.add_argument(
        '--mu_s',
        type=float,
        default=0.2,
        help="hyperparameter of the non-sparsity regularization")
    parser.add_argument('--mu_c',
                        type=float,
                        default=1e-1,
                        help="hyperparameter of the FixMatch loss")
    parser.add_argument('--n_rand_aug',
                        type=int,
                        default=2,
                        help="N parameter of RandAugment")
    parser.add_argument('--m_min_rand_aug',
                        type=int,
                        default=3,
                        help="minimum M parameter of RandAugment")
    parser.add_argument('--m_max_rand_aug',
                        type=int,
                        default=10,
                        help="maximum M parameter of RandAugment")
    parser.add_argument('--weight_decay',
                        default=0.,
                        type=float,
                        metavar='',
                        help='hyperparameter of weight decay regularization')
    parser.add_argument('--lr',
                        default=1e-3,
                        type=float,
                        metavar='',
                        help='learning rate')
    parser.add_argument('--epochs',
                        default=50,
                        type=int,
                        metavar='',
                        help='number of training epochs')
    parser.add_argument('--batch_size',
                        default=8,
                        type=int,
                        metavar='',
                        help='batch size (per domain)')
    parser.add_argument(
        '--checkpoint',
        default=0,
        type=int,
        metavar='',
        help=
        'number of epochs between saving checkpoints (0 disables checkpoints)')
    parser.add_argument('--eval_target',
                        default=False,
                        type=int,
                        metavar='',
                        help='evaluate target during training')
    parser.add_argument('--use_cuda',
                        default=True,
                        type=int,
                        metavar='',
                        help='use CUDA capable GPU')
    parser.add_argument('--use_visdom',
                        default=False,
                        type=int,
                        metavar='',
                        help='use Visdom to visualize plots')
    parser.add_argument('--visdom_env',
                        default='domainnet_train',
                        type=str,
                        metavar='',
                        help='Visdom environment name')
    parser.add_argument('--visdom_port',
                        default=8888,
                        type=int,
                        metavar='',
                        help='Visdom port')
    parser.add_argument('--verbosity',
                        default=2,
                        type=int,
                        metavar='',
                        help='log verbosity level (0, 1, 2)')
    parser.add_argument('--seed',
                        default=42,
                        type=int,
                        metavar='',
                        help='random seed')
    args = vars(parser.parse_args())

    # override args with icfg (if provided); every value in the [main]
    # section is parsed as a Python literal
    cfg = args.copy()
    if cfg['icfg'] is not None:
        cv_parser = ConfigParser()
        cv_parser.read(cfg['icfg'])
        for key, val in cv_parser.items('main'):
            cfg[key] = ast.literal_eval(val)

    # dump the resolved configuration to a txt file for your records
    with open(cfg['output'] + '.txt', 'w') as f:
        f.write(str(cfg) + '\n')

    # use a fixed random seed for reproducibility purposes (seed <= 0 skips)
    if cfg['seed'] > 0:
        random.seed(cfg['seed'])
        np.random.seed(seed=cfg['seed'])
        torch.manual_seed(cfg['seed'])
        torch.cuda.manual_seed(cfg['seed'])

    device = 'cuda' if (cfg['use_cuda']
                        and torch.cuda.is_available()) else 'cpu'
    log = Logger(cfg['verbosity'])
    log.print('device:', device, level=0)

    # normalization transformation (required for pretrained networks)
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
    if 'FM' in cfg['model']:
        # weak data augmentation (small rotation + small translation);
        # FixMatch-based models (FM, MODAFM) are fed unnormalized tensors
        data_aug = T.Compose([
            T.RandomHorizontalFlip(),
            T.RandomAffine(5, translate=(0.125, 0.125)),
            T.ToTensor(),
        ])
        eval_transf = T.Compose([
            T.ToTensor(),
        ])
    else:
        data_aug = T.Compose([
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            normalize,
        ])
        eval_transf = T.Compose([
            T.ToTensor(),
            normalize,
        ])

    domains = [
        'clipart', 'infograph', 'painting', 'quickdraw', 'real', 'sketch'
    ]
    datasets = {
        domain: DomainNet(cfg['data_path'],
                          domain=domain,
                          train=True,
                          transform=data_aug)
        for domain in domains
    }
    n_classes = len(datasets[cfg['target']].class_names)

    test_set = DomainNet(cfg['data_path'],
                         domain=cfg['target'],
                         train=False,
                         transform=eval_transf)
    if 'FM' in cfg['model']:
        # separate copy so the training dataset keeps its augmentation
        target_pub = deepcopy(datasets[cfg['target']])
        target_pub.transform = eval_transf  # no data augmentation in test
    else:
        target_pub = datasets[cfg['target']]

    if cfg['model'] != 'FS':
        train_loader = MSDA_Loader(datasets,
                                   cfg['target'],
                                   batch_size=cfg['batch_size'],
                                   shuffle=True,
                                   num_workers=0,
                                   device=device)
        if cfg['eval_target']:
            valid_loaders = {
                'target pub':
                DataLoader(target_pub, batch_size=6 * cfg['batch_size']),
                'target priv':
                DataLoader(test_set, batch_size=6 * cfg['batch_size'])
            }
        else:
            valid_loaders = None
        log.print('target domain:',
                  cfg['target'],
                  '| source domains:',
                  train_loader.sources,
                  level=1)
    else:
        # fully supervised baseline: train on the target domain only
        train_loader = DataLoader(datasets[cfg['target']],
                                  batch_size=cfg['batch_size'],
                                  shuffle=True)
        test_loader = DataLoader(test_set, batch_size=cfg['batch_size'])
        log.print('target domain:', cfg['target'], level=1)

    if cfg['model'] == 'FS':
        model = SimpleCNN(n_classes=n_classes, arch=cfg['arch']).to(device)
        optimizer = _build_optimizer(model, cfg)
        valid_loaders = {
            'target pub': test_loader
        } if cfg['eval_target'] else None
        fs_train_routine(model, optimizer, train_loader, valid_loaders, cfg)

    elif cfg['model'] == 'FM':
        model = SimpleCNN(n_classes=n_classes, arch=cfg['arch']).to(device)
        optimizer = _build_optimizer(model, cfg)
        cfg['excl_transf'] = None
        fm_train_routine(model, optimizer, train_loader, valid_loaders, cfg)

    elif cfg['model'] == 'DANNS':
        # one single-source DANN per source domain, each saved separately
        for src in train_loader.sources:
            model = MODANet(n_classes=n_classes, arch=cfg['arch']).to(device)
            optimizer = _build_optimizer(model, cfg)
            dataset_ss = {
                src: datasets[src],
                cfg['target']: datasets[cfg['target']]
            }
            pair_loader = MSDA_Loader(dataset_ss,
                                      cfg['target'],
                                      batch_size=cfg['batch_size'],
                                      shuffle=True,
                                      device=device)
            dann_train_routine(model, optimizer, pair_loader, valid_loaders,
                               cfg)
            torch.save(model.state_dict(), cfg['output'] + '_' + src)

    elif cfg['model'] == 'DANNM':
        model = MODANet(n_classes=n_classes, arch=cfg['arch']).to(device)
        optimizer = _build_optimizer(model, cfg)
        dann_train_routine(model, optimizer, train_loader, valid_loaders, cfg)

    elif cfg['model'] == 'MDAN':
        model = MDANet(n_classes=n_classes,
                       n_domains=len(train_loader.sources),
                       arch=cfg['arch']).to(device)
        optimizer = _build_optimizer(model, cfg)
        mdan_train_routine(model, optimizer, train_loader, valid_loaders, cfg)

    elif cfg['model'] == 'MODA':
        model = MODANet(n_classes=n_classes, arch=cfg['arch']).to(device)
        optimizer = _build_optimizer(model, cfg)
        moda_train_routine(model, optimizer, train_loader, valid_loaders, cfg)

    elif cfg['model'] == 'MODAFM':
        model = MODANet(n_classes=n_classes, arch=cfg['arch']).to(device)
        optimizer = _build_optimizer(model, cfg)
        cfg['excl_transf'] = None
        moda_fm_train_routine(model, optimizer, train_loader, valid_loaders,
                              cfg)

    else:
        raise ValueError('Unknown model {}'.format(cfg['model']))

    # for DANNS this saves the model trained on the last source domain
    torch.save(model.state_dict(), cfg['output'])
示例#11
0
def get_model(args):
    """Return a ``SimpleCNN`` built from the ``gsz`` attribute of *args*."""
    network_cls = SimpleCNN
    return network_cls(args.gsz)
示例#12
0
        correct = 0
        total = 0

        for image, label in test_loader:
            output = model(image)
            _, pred = torch.max(output, 1)
            if label.item() == pred.item():
                correct += 1

        return correct / len(test_loader) * 100


# Convert to tensors, then normalize with the commonly used MNIST
# mean / std pixel statistics.
tf = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.1307, ), (0.3081, ))])
train_data = datasets.MNIST(root='./data/',
                            train=True,
                            transform=tf,
                            download=True)

test_data = datasets.MNIST(root='./data/', train=False, transform=tf)
torch.manual_seed(33)

model = SimpleCNN()
batch_size = 64
epochs = 50

# Fit on the training split only; the test split is held out for the
# evaluation below (training on test_data would make the score meaningless).
train(model, train_data, batch_size, epochs)

print("percentage correctly classified images: ", eval(model, test_data))
示例#13
0
def main():
    """Evaluate a trained digits domain-adaptation model on the target domain.

    Loads the weights from ``--input`` and reports accuracy and loss on a
    public and a private subset of the target domain. For ``DANNS``, one model
    per source domain is loaded from ``<input>_<source>``. For each subset,
    the best accuracy over those models is reported, together with the loss
    of the model that achieved it.

    Raises:
        ValueError: if ``--model`` names an unknown model type.
    """
    parser = argparse.ArgumentParser(
        description='Domain adaptation experiments with digits datasets.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '-m',
        '--model',
        default='MODAFM',
        type=str,
        metavar='',
        help=
        'model type (\'FS\' / \'DANNS\' / \'DANNM\' / \'MDAN\' / \'MODA\' / \'FM\' / \'MODAFM\''
    )
    parser.add_argument('-d',
                        '--data_path',
                        default='/ctm-hdd-pool01/DB/',
                        type=str,
                        metavar='',
                        help='data directory path')
    parser.add_argument(
        '-t',
        '--target',
        default='MNIST',
        type=str,
        metavar='',
        help=
        'target domain (\'MNIST\' / \'MNIST_M\' / \'SVHN\' / \'SynthDigits\')')
    parser.add_argument('-i',
                        '--input',
                        default='msda.pth',
                        type=str,
                        metavar='',
                        help='model file (output of train)')
    parser.add_argument('--n_images',
                        default=20000,
                        type=int,
                        metavar='',
                        help='number of images from each domain')
    parser.add_argument('--batch_size',
                        default=8,
                        type=int,
                        metavar='',
                        help='batch size')
    parser.add_argument('--use_cuda',
                        default=True,
                        type=int,
                        metavar='',
                        help='use CUDA capable GPU')
    parser.add_argument('--verbosity',
                        default=2,
                        type=int,
                        metavar='',
                        help='log verbosity level (0, 1, 2)')
    parser.add_argument('--seed',
                        default=42,
                        type=int,
                        metavar='',
                        help='random seed')
    args = vars(parser.parse_args())
    cfg = args.copy()

    # use a fixed random seed for reproducibility purposes (seed <= 0 skips)
    if cfg['seed'] > 0:
        random.seed(cfg['seed'])
        np.random.seed(seed=cfg['seed'])
        torch.manual_seed(cfg['seed'])
        torch.cuda.manual_seed(cfg['seed'])

    device = 'cuda' if (cfg['use_cuda']
                        and torch.cuda.is_available()) else 'cpu'
    log = Logger(cfg['verbosity'])
    log.print('device:', device, level=0)

    # define all datasets
    datasets = {}
    datasets['MNIST'] = MNIST(train=True,
                              path=os.path.join(cfg['data_path'], 'MNIST'),
                              transform=T.ToTensor())
    datasets['MNIST_M'] = MNIST_M(train=True,
                                  path=os.path.join(cfg['data_path'],
                                                    'MNIST_M'),
                                  transform=T.ToTensor())
    datasets['SVHN'] = SVHN(train=True,
                            path=os.path.join(cfg['data_path'], 'SVHN'),
                            transform=T.ToTensor())
    datasets['SynthDigits'] = SynthDigits(train=True,
                                          path=os.path.join(
                                              cfg['data_path'], 'SynthDigits'),
                                          transform=T.ToTensor())
    # keep a handle on the full target dataset before it is subsetted below
    test_set = datasets[cfg['target']]

    # get a subset of cfg['n_images'] from each dataset
    # define public and private test sets: the private is not used at training
    # time to learn invariant representations
    n_images = cfg['n_images']
    for ds_name in datasets:
        if ds_name == cfg['target']:
            indices = random.sample(range(len(datasets[ds_name])),
                                    2 * n_images)
            test_pub_set = Subset(test_set, indices[0:n_images])
            test_priv_set = Subset(test_set, indices[n_images::])
        else:
            indices = random.sample(range(len(datasets[ds_name])), n_images)
        datasets[ds_name] = Subset(datasets[ds_name], indices[0:n_images])

    # build the dataloader
    test_pub_loader = DataLoader(test_pub_set,
                                 batch_size=4 * cfg['batch_size'])
    test_priv_loader = DataLoader(test_priv_set,
                                  batch_size=4 * cfg['batch_size'])
    test_loaders = {
        'target pub': test_pub_loader,
        'target priv': test_priv_loader
    }
    log.print('target domain:', cfg['target'], level=0)

    if cfg['model'] in ['FS', 'FM']:
        model = SimpleCNN().to(device)
    elif cfg['model'] == 'MDAN':
        # one domain head per source domain (all datasets except the target)
        model = MDANet(len(datasets) - 1).to(device)
    elif cfg['model'] in ['DANNS', 'DANNM', 'MODA', 'MODAFM']:
        model = MODANet().to(device)
    else:
        raise ValueError('Unknown model {}'.format(cfg['model']))

    if cfg['model'] != 'DANNS':
        model.load_state_dict(torch.load(cfg['input']))
        accuracies, losses = test_routine(model, test_loaders, cfg)
        print('target pub: acc = {:.3f},'.format(accuracies['target pub']),
              'loss = {:.3f}'.format(losses['target pub']))
        print('target priv: acc = {:.3f},'.format(accuracies['target priv']),
              'loss = {:.3f}'.format(losses['target priv']))

    else:  # for DANNS, report results for the best source domain
        src_domains = ['MNIST', 'MNIST_M', 'SVHN', 'SynthDigits']
        src_domains.remove(cfg['target'])
        for i, src in enumerate(src_domains):
            model.load_state_dict(torch.load(cfg['input'] + '_' + src))
            acc, loss = test_routine(model, test_loaders, cfg)
            if i == 0:
                accuracies = acc
                losses = loss
            else:
                # compare BEFORE overwriting so the reported loss belongs to
                # the same (best-accuracy) source model as the accuracy
                for key in accuracies:
                    if acc[key] > accuracies[key]:
                        accuracies[key] = acc[key]
                        losses[key] = loss[key]
        log.print('target pub: acc = {:.3f},'.format(accuracies['target pub']),
                  'loss = {:.3f}'.format(losses['target pub']),
                  level=1)
        log.print('target priv: acc = {:.3f},'.format(
            accuracies['target priv']),
                  'loss = {:.3f}'.format(losses['target priv']),
                  level=1)