Example 1
def run(test_phase, data_seed, n_labeled, training_length, rampdown_length):
    minibatch_size = 100
    n_labeled_per_batch = 100

    tf.reset_default_graph()
    model = Model(RunContext(__file__, data_seed))

    svhn = SVHN(n_labeled=n_labeled,
                data_seed=data_seed,
                test_phase=test_phase)

    model['ema_consistency'] = True
    model['max_consistency_cost'] = 0.0
    model['apply_consistency_to_labeled'] = False
    model['rampdown_length'] = rampdown_length
    model['training_length'] = training_length

    # Turn off augmentation
    model['translate'] = False
    model['flip_horizontally'] = False

    training_batches = minibatching.training_batches(svhn.training,
                                                     minibatch_size,
                                                     n_labeled_per_batch)
    evaluation_batches_fn = minibatching.evaluation_epoch_generator(
        svhn.evaluation, minibatch_size)

    tensorboard_dir = model.save_tensorboard_graph()
    LOG.info("Saved tensorboard graph to %r", tensorboard_dir)

    model.train(training_batches, evaluation_batches_fn)
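
A hypothetical driver for the function above, assuming `run` is invoked directly rather than through the original experiment framework; the seed range and hyperparameter values are illustrative only.

if __name__ == "__main__":
    # Illustrative sweep over a few data seeds for this no-consistency,
    # no-augmentation baseline; these are not the original settings.
    for seed in range(3):
        run(test_phase=False, data_seed=seed, n_labeled=500,
            training_length=40000, rampdown_length=10000)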
Example 2
def run(data_seed, num_logits, logit_distance_cost, test_phase=False, n_labeled=500, n_extra_unlabeled=0, model_type='mean_teacher'):
    minibatch_size = 100
    hyperparams = model_hyperparameters(model_type, n_labeled, n_extra_unlabeled)

    tf.reset_default_graph()
    model = Model(RunContext(__file__, data_seed))

    svhn = SVHN(n_labeled=n_labeled,
                n_extra_unlabeled=n_extra_unlabeled,
                data_seed=data_seed,
                test_phase=test_phase)

    model['ema_consistency'] = hyperparams['ema_consistency']
    model['max_consistency_cost'] = hyperparams['max_consistency_cost']
    model['apply_consistency_to_labeled'] = hyperparams['apply_consistency_to_labeled']
    model['training_length'] = hyperparams['training_length']
    model['num_logits'] = num_logits
    model['logit_distance_cost'] = logit_distance_cost

    training_batches = minibatching.training_batches(svhn.training,
                                                     minibatch_size,
                                                     hyperparams['n_labeled_per_batch'])
    evaluation_batches_fn = minibatching.evaluation_epoch_generator(svhn.evaluation,
                                                                    minibatch_size)

    tensorboard_dir = model.save_tensorboard_graph()
    LOG.info("Saved tensorboard graph to %r", tensorboard_dir)

    model.train(training_batches, evaluation_batches_fn)
Example 3
def run(data_seed, dropout, input_noise, augmentation,
        test_phase=False, n_labeled=250, n_extra_unlabeled=0, model_type='mean_teacher'):
    minibatch_size = 100
    hyperparams = model_hyperparameters(model_type, n_labeled, n_extra_unlabeled)

    tf.reset_default_graph()
    model = Model(RunContext(__file__, data_seed))

    svhn = SVHN(n_labeled=n_labeled,
                n_extra_unlabeled=n_extra_unlabeled,
                data_seed=data_seed,
                test_phase=test_phase)

    model['ema_consistency'] = hyperparams['ema_consistency']
    model['max_consistency_cost'] = hyperparams['max_consistency_cost']
    model['apply_consistency_to_labeled'] = hyperparams['apply_consistency_to_labeled']
    model['training_length'] = hyperparams['training_length']
    model['student_dropout_probability'] = dropout
    model['teacher_dropout_probability'] = dropout
    model['input_noise'] = input_noise
    model['translate'] = augmentation

    training_batches = minibatching.training_batches(svhn.training,
                                                     minibatch_size,
                                                     hyperparams['n_labeled_per_batch'])
    evaluation_batches_fn = minibatching.evaluation_epoch_generator(svhn.evaluation,
                                                                    minibatch_size)

    tensorboard_dir = model.save_tensorboard_graph()
    LOG.info("Saved tensorboard graph to %r", tensorboard_dir)

    model.train(training_batches, evaluation_batches_fn)
Example 4
def run():
    data_seed = 0
    date = datetime.now()
    n_labeled = 500
    n_extra_unlabeled = 0

    result_dir = "{root}/{dataset}/{model}/{date:%Y-%m-%d_%H:%M:%S}/{seed}".format(
        root='results/final_eval',
        dataset='svhn_{}_{}'.format(n_labeled, n_extra_unlabeled),
        model='mean_teacher',
        date=date,
        seed=data_seed)

    model = Model(result_dir=result_dir)
    model['rampdown_length'] = 0
    model['training_length'] = 180000

    tensorboard_dir = model.save_tensorboard_graph()
    LOG.info("Saved tensorboard graph to %r", tensorboard_dir)

    svhn = SVHN(data_seed, n_labeled, n_extra_unlabeled)
    training_batches = minibatching.training_batches(svhn.training,
                                                     n_labeled_per_batch=1)
    evaluation_batches_fn = minibatching.evaluation_epoch_generator(
        svhn.evaluation)

    model.train(training_batches, evaluation_batches_fn)
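
For reference, with the values above the format string resolves to a timestamped directory of the form shown below (timestamp illustrative):

# results/final_eval/svhn_500_0/mean_teacher/2018-04-16_12:00:00/0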
Example 5
def run(result_dir, test_phase, n_labeled, n_extra_unlabeled, data_seed,
        model_type):
    minibatch_size = 100
    hyperparams = model_hyperparameters(model_type, n_labeled,
                                        n_extra_unlabeled)

    tf.reset_default_graph()
    model = Model(result_dir=result_dir)

    svhn = SVHN(n_labeled=n_labeled,
                n_extra_unlabeled=n_extra_unlabeled,
                data_seed=data_seed,
                test_phase=test_phase)

    model['rampdown_length'] = 0
    model['ema_consistency'] = hyperparams['ema_consistency']
    model['max_consistency_coefficient'] = hyperparams[
        'max_consistency_coefficient']
    model['apply_consistency_to_labeled'] = hyperparams[
        'apply_consistency_to_labeled']
    model['training_length'] = hyperparams['training_length']

    training_batches = minibatching.training_batches(
        svhn.training, minibatch_size, hyperparams['n_labeled_per_batch'])
    evaluation_batches_fn = minibatching.evaluation_epoch_generator(
        svhn.evaluation, minibatch_size)

    tensorboard_dir = model.save_tensorboard_graph()
    LOG.info("Saved tensorboard graph to %r", tensorboard_dir)

    model.train(training_batches, evaluation_batches_fn)
Example 6
def run(test_phase, n_labeled, n_extra_unlabeled, data_seed, model_type):
    minibatch_size = 100
    hyperparams = model_hyperparameters(model_type, n_labeled, n_extra_unlabeled)

    tf.reset_default_graph()
    model = Model(RunContext(__file__, data_seed))

    svhn = SVHN(n_labeled=n_labeled,
                n_extra_unlabeled=n_extra_unlabeled,
                data_seed=data_seed,
                test_phase=test_phase)

    model['ema_consistency'] = hyperparams['ema_consistency']
    model['max_consistency_cost'] = hyperparams['max_consistency_cost']
    model['apply_consistency_to_labeled'] = hyperparams['apply_consistency_to_labeled']
    model['training_length'] = hyperparams['training_length']

    # Turn off augmentation
    model['translate'] = False
    model['flip_horizontally'] = False

    training_batches = minibatching.training_batches(svhn.training,
                                                     minibatch_size,
                                                     hyperparams['n_labeled_per_batch'])
    evaluation_batches_fn = minibatching.evaluation_epoch_generator(svhn.evaluation,
                                                                    minibatch_size)

    tensorboard_dir = model.save_tensorboard_graph()
    LOG.info("Saved tensorboard graph to %r", tensorboard_dir)

    model.train(training_batches, evaluation_batches_fn)
Example 7
def main(argv):
    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")

    if argv[2] == "mnistm":
        dataset = MNISTM(image_path=argv[1],
                         label_path=None,
                         test_mode=True,
                         transform=T.ToTensor())
        source = "usps"
        target = "mnistm"
    elif argv[2] == "svhn":
        dataset = SVHN(image_path=argv[1],
                       label_path=None,
                       test_mode=True,
                       transform=T.ToTensor())
        source = "mnistm"
        target = "svhn"
    elif argv[2] == "usps":
        dataset = USPS(image_path=argv[1],
                       label_path=None,
                       test_mode=True,
                       transform=T.ToTensor())
        source = "svhn"
        target = "usps"

    batch_size = 64
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    model = DANN(filter_num=64)
    model.to(device)

    with torch.no_grad():

        model.load_state_dict(
            torch.load("./model/{}_{}.pkl".format(source, target)))
        model.eval()

        with open(argv[3], "w", newline="") as csvfile:

            writer = csv.writer(csvfile)
            writer.writerow(["image_name", "label"])

            for _, data in enumerate(loader):
                inputs, file_name = data
                inputs = inputs.to(device)

                outputs, _ = model(inputs, 0)

                predict = torch.max(outputs, 1)[1].cpu().numpy()

                for i in range(len(file_name)):
                    writer.writerow([file_name[i], predict[i]])
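
`main` above expects an argv-style list, with argv[1] the image directory, argv[2] the target dataset name, and argv[3] the output CSV path. A hypothetical entry point, assuming it is called with `sys.argv` directly (script name illustrative):

import sys

if __name__ == "__main__":
    # e.g. python predict_dann.py ./data/svhn/test svhn predictions.csv
    main(sys.argv)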
Example 8
def run(data_seed=0):
    n_labeled = 1000
    n_extra_unlabeled = 0

    model = Model(RunContext(__file__, 0))

    tensorboard_dir = model.save_tensorboard_graph()
    LOG.info("Saved tensorboard graph to %r", tensorboard_dir)

    svhn = SVHN(data_seed, n_labeled, n_extra_unlabeled)
    training_batches = minibatching.training_batches(svhn.training,
                                                     n_labeled_per_batch=1)
    evaluation_batches_fn = minibatching.evaluation_epoch_generator(
        svhn.evaluation)

    model.train(training_batches, evaluation_batches_fn)
Example 9
def run(data_seed=0):
    n_labeled = 500
    n_extra_unlabeled = 0

    model = Model(RunContext(__file__, 0))
    model['rampdown_length'] = 0
    model['rampup_length'] = 5000
    model['training_length'] = 40000
    model['max_consistency_cost'] = 50.0

    tensorboard_dir = model.save_tensorboard_graph()
    LOG.info("Saved tensorboard graph to %r", tensorboard_dir)

    svhn = SVHN(data_seed, n_labeled, n_extra_unlabeled)
    training_batches = minibatching.training_batches(svhn.training, n_labeled_per_batch=50)
    evaluation_batches_fn = minibatching.evaluation_epoch_generator(svhn.evaluation)

    model.train(training_batches, evaluation_batches_fn)
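
The consistency weight is ramped up over `rampup_length` steps inside `Model`; the exact schedule is not shown in this snippet, but the sigmoid ramp-up from the Mean Teacher paper is the usual choice. A minimal sketch for illustration:

import numpy as np

def sigmoid_rampup(step, rampup_length=5000):
    # Sigmoid-shaped ramp from 0 to 1 over `rampup_length` steps,
    # exp(-5 * (1 - t)^2), as in the Mean Teacher paper.
    if rampup_length == 0:
        return 1.0
    t = np.clip(float(step) / rampup_length, 0.0, 1.0)
    return float(np.exp(-5.0 * (1.0 - t) ** 2))

# With max_consistency_cost = 50.0, the effective consistency weight at
# step 2500 would be roughly 50.0 * sigmoid_rampup(2500), i.e. about 14.3.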
Example 10
def main(argv):
    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")

    dataset = SVHN(image_path=argv[1],
                   label_path=None,
                   test_mode=True,
                   transform=T.ToTensor())
    source = "mnistm"
    target = "svhn"

    batch_size = 64
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    encoder = ADDA_Encoder()
    encoder.load_state_dict(torch.load("./model/adda/tgt_encoder.pkl"))
    encoder.to(device)
    encoder.eval()
    classifier = ADDA_Classifier()
    classifier.load_state_dict(torch.load("./model/adda/classifier.pkl"))
    classifier.to(device)
    classifier.eval()

    with torch.no_grad():
        with open(argv[3], "w", newline="") as csvfile:

            writer = csv.writer(csvfile)
            writer.writerow(["image_name", "label"])

            for _, data in enumerate(loader):
                inputs, file_name = data
                inputs = inputs.to(device)

                outputs = classifier(encoder(inputs))

                predict = torch.max(outputs, 1)[1].cpu().numpy()

                for i in range(len(file_name)):
                    writer.writerow([file_name[i], predict[i]])
Example 11
                                     transform=transform_train)
     train_loader = torch.utils.data.DataLoader(training_set,
                                                batch_size=args.batch_size,
                                                shuffle=True)
     testset = datasets.CIFAR10(root='./cifar10_data',
                                train=False,
                                download=True,
                                transform=transform_test)
     test_loader = torch.utils.data.DataLoader(
         testset, batch_size=args.test_batch_size, shuffle=False)
 elif args.dataset == 'SVHN':
     training_set = SVHN('./svhn_data',
                         split='train',
                         transform=transforms.Compose([
                             transforms.RandomCrop(32, padding=4),
                             transforms.RandomHorizontalFlip(),
                             transforms.ToTensor(),
                             transforms.Normalize((0.4914, 0.4822, 0.4465),
                                                  (0.2023, 0.1994, 0.2010)),
                         ]))
     train_loader = torch.utils.data.DataLoader(training_set,
                                                batch_size=128,
                                                shuffle=True)
     transform_test = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.4914, 0.4822, 0.4465),
                              (0.2023, 0.1994, 0.2010)),
     ])
     testset = SVHN(root='./svhn_data',
                    split='test',
                    download=True,
Example 12
def main():
    parser = argparse.ArgumentParser(
        description='Domain adaptation experiments with digits datasets.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '-m',
        '--model',
        default='MODAFM',
        type=str,
        metavar='',
        help=
        'model type (\'FS\' / \'DANNS\' / \'DANNM\' / \'MDAN\' / \'MODA\' / \'FM\' / \'MODAFM\')'
    )
    parser.add_argument('-d',
                        '--data_path',
                        default='/ctm-hdd-pool01/DB/',
                        type=str,
                        metavar='',
                        help='data directory path')
    parser.add_argument(
        '-t',
        '--target',
        default='MNIST',
        type=str,
        metavar='',
        help=
        'target domain (\'MNIST\' / \'MNIST_M\' / \'SVHN\' / \'SynthDigits\')')
    parser.add_argument('-o',
                        '--output',
                        default='msda.pth',
                        type=str,
                        metavar='',
                        help='model file (output of train)')
    parser.add_argument('--icfg',
                        default=None,
                        type=str,
                        metavar='',
                        help='config file (overrides args)')
    parser.add_argument('--n_src_images',
                        default=20000,
                        type=int,
                        metavar='',
                        help='number of images from each source domain')
    parser.add_argument('--n_tgt_images',
                        default=20000,
                        type=int,
                        metavar='',
                        help='number of images from the target domain')
    parser.add_argument(
        '--mu_d',
        type=float,
        default=1e-2,
        help=
        "hyperparameter of the coefficient for the domain discriminator loss")
    parser.add_argument(
        '--mu_s',
        type=float,
        default=0.2,
        help="hyperparameter of the non-sparsity regularization")
    parser.add_argument('--mu_c',
                        type=float,
                        default=1e-1,
                        help="hyperparameter of the FixMatch loss")
    parser.add_argument('--n_rand_aug',
                        type=int,
                        default=2,
                        help="N parameter of RandAugment")
    parser.add_argument('--m_min_rand_aug',
                        type=int,
                        default=3,
                        help="minimum M parameter of RandAugment")
    parser.add_argument('--m_max_rand_aug',
                        type=int,
                        default=10,
                        help="maximum M parameter of RandAugment")
    parser.add_argument('--weight_decay',
                        default=0.,
                        type=float,
                        metavar='',
                        help='hyperparameter of weight decay regularization')
    parser.add_argument('--lr',
                        default=1e-1,
                        type=float,
                        metavar='',
                        help='learning rate')
    parser.add_argument('--epochs',
                        default=30,
                        type=int,
                        metavar='',
                        help='number of training epochs')
    parser.add_argument('--batch_size',
                        default=8,
                        type=int,
                        metavar='',
                        help='batch size (per domain)')
    parser.add_argument(
        '--checkpoint',
        default=0,
        type=int,
        metavar='',
        help=
        'number of epochs between saving checkpoints (0 disables checkpoints)')
    parser.add_argument('--eval_target',
                        default=False,
                        type=int,
                        metavar='',
                        help='evaluate target during training')
    parser.add_argument('--use_cuda',
                        default=True,
                        type=int,
                        metavar='',
                        help='use CUDA capable GPU')
    parser.add_argument('--use_visdom',
                        default=False,
                        type=int,
                        metavar='',
                        help='use Visdom to visualize plots')
    parser.add_argument('--visdom_env',
                        default='digits_train',
                        type=str,
                        metavar='',
                        help='Visdom environment name')
    parser.add_argument('--visdom_port',
                        default=8888,
                        type=int,
                        metavar='',
                        help='Visdom port')
    parser.add_argument('--verbosity',
                        default=2,
                        type=int,
                        metavar='',
                        help='log verbosity level (0, 1, 2)')
    parser.add_argument('--seed',
                        default=42,
                        type=int,
                        metavar='',
                        help='random seed')
    args = vars(parser.parse_args())

    # override args with icfg (if provided)
    cfg = args.copy()
    if cfg['icfg'] is not None:
        cv_parser = ConfigParser()
        cv_parser.read(cfg['icfg'])
        cv_param_names = []
        for key, val in cv_parser.items('main'):
            cfg[key] = ast.literal_eval(val)
            cv_param_names.append(key)

    # dump cfg to a txt file for your records
    with open(cfg['output'] + '.txt', 'w') as f:
        f.write(str(cfg) + '\n')

    # use a fixed random seed for reproducibility purposes
    if cfg['seed'] > 0:
        random.seed(cfg['seed'])
        np.random.seed(seed=cfg['seed'])
        torch.manual_seed(cfg['seed'])
        torch.cuda.manual_seed(cfg['seed'])

    device = 'cuda' if (cfg['use_cuda']
                        and torch.cuda.is_available()) else 'cpu'
    log = Logger(cfg['verbosity'])
    log.print('device:', device, level=0)

    if ('FS' in cfg['model']) or ('FM' in cfg['model']):
        # weak data augmentation (small rotation + small translation)
        data_aug = T.Compose([
            T.RandomAffine(5, translate=(0.125, 0.125)),
            T.ToTensor(),
        ])
    else:
        data_aug = T.ToTensor()

    # define all datasets
    datasets = {}
    datasets['MNIST'] = MNIST(train=True,
                              path=os.path.join(cfg['data_path'], 'MNIST'),
                              transform=data_aug)
    datasets['MNIST_M'] = MNIST_M(train=True,
                                  path=os.path.join(cfg['data_path'],
                                                    'MNIST_M'),
                                  transform=data_aug)
    datasets['SVHN'] = SVHN(train=True,
                            path=os.path.join(cfg['data_path'], 'SVHN'),
                            transform=data_aug)
    datasets['SynthDigits'] = SynthDigits(train=True,
                                          path=os.path.join(
                                              cfg['data_path'], 'SynthDigits'),
                                          transform=data_aug)
    if ('FS' in cfg['model']) or ('FM' in cfg['model']):
        test_set = deepcopy(datasets[cfg['target']])
        test_set.transform = T.ToTensor()  # no data augmentation in test
    else:
        test_set = datasets[cfg['target']]

    # get a subset of cfg['n_images'] from each dataset
    # define public and private test sets: the private is not used at training time to learn invariant representations
    for ds_name in datasets:
        if ds_name == cfg['target']:
            indices = random.sample(range(len(datasets[ds_name])),
                                    cfg['n_tgt_images'] + cfg['n_src_images'])
            test_pub_set = Subset(test_set, indices[0:cfg['n_tgt_images']])
            test_priv_set = Subset(test_set, indices[cfg['n_tgt_images']::])
            datasets[cfg['target']] = Subset(datasets[cfg['target']],
                                             indices[0:cfg['n_tgt_images']])
        else:
            indices = random.sample(range(len(datasets[ds_name])),
                                    cfg['n_src_images'])
            datasets[ds_name] = Subset(datasets[ds_name],
                                       indices[0:cfg['n_src_images']])

    # build the dataloader
    train_loader = MSDA_Loader(datasets,
                               cfg['target'],
                               batch_size=cfg['batch_size'],
                               shuffle=True,
                               device=device)
    test_pub_loader = DataLoader(test_pub_set,
                                 batch_size=4 * cfg['batch_size'])
    test_priv_loader = DataLoader(test_priv_set,
                                  batch_size=4 * cfg['batch_size'])
    valid_loaders = ({
        'target pub': test_pub_loader,
        'target priv': test_priv_loader
    } if cfg['eval_target'] else None)
    log.print('target domain:',
              cfg['target'],
              '| source domains:',
              train_loader.sources,
              level=1)

    if cfg['model'] == 'FS':
        model = SimpleCNN().to(device)
        optimizer = optim.Adadelta(model.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        if valid_loaders is not None:
            del valid_loaders['target pub']
        fs_train_routine(model, optimizer, test_pub_loader, valid_loaders, cfg)
    elif cfg['model'] == 'FM':
        model = SimpleCNN().to(device)
        optimizer = optim.Adadelta(model.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        cfg['excl_transf'] = [Flip]
        fm_train_routine(model, optimizer, train_loader, valid_loaders, cfg)
    elif cfg['model'] == 'DANNS':
        for src in train_loader.sources:
            model = MODANet().to(device)
            optimizer = optim.Adadelta(model.parameters(),
                                       lr=cfg['lr'],
                                       weight_decay=cfg['weight_decay'])
            dataset_ss = {
                src: datasets[src],
                cfg['target']: datasets[cfg['target']]
            }
            train_loader = MSDA_Loader(dataset_ss,
                                       cfg['target'],
                                       batch_size=cfg['batch_size'],
                                       shuffle=True,
                                       device=device)
            dann_train_routine(model, optimizer, train_loader, valid_loaders,
                               cfg)
            torch.save(model.state_dict(), cfg['output'] + '_' + src)
    elif cfg['model'] == 'DANNM':
        model = MODANet().to(device)
        optimizer = optim.Adadelta(model.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        dann_train_routine(model, optimizer, train_loader, valid_loaders, cfg)
    elif cfg['model'] == 'MDAN':
        model = MDANet(len(train_loader.sources)).to(device)
        optimizer = optim.Adadelta(model.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        mdan_train_routine(model, optimizer, train_loader, valid_loaders, cfg)
    elif cfg['model'] == 'MDANU':
        model = MDANet(len(train_loader.sources)).to(device)
        model.grad_reverse = nn.ModuleList([
            nn.Identity() for _ in range(len(model.domain_class))
        ])  # remove grad reverse
        task_optim = optim.Adadelta(list(model.feat_ext.parameters()) +
                                    list(model.task_class.parameters()),
                                    lr=cfg['lr'],
                                    weight_decay=cfg['weight_decay'])
        adv_optim = optim.Adadelta(model.domain_class.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        optimizers = (task_optim, adv_optim)
        mdan_unif_train_routine(model, optimizers, train_loader, valid_loaders,
                                cfg)
    elif cfg['model'] == 'MDANFM':
        model = MDANet(len(train_loader.sources)).to(device)
        optimizer = optim.Adadelta(model.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        mdan_fm_train_routine(model, optimizer, train_loader, valid_loaders,
                              cfg)
    elif cfg['model'] == 'MDANUFM':
        model = MDANet(len(train_loader.sources)).to(device)
        task_optim = optim.Adadelta(list(model.feat_ext.parameters()) +
                                    list(model.task_class.parameters()),
                                    lr=cfg['lr'],
                                    weight_decay=cfg['weight_decay'])
        adv_optim = optim.Adadelta(model.domain_class.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        optimizers = (task_optim, adv_optim)
        cfg['excl_transf'] = [Flip]
        mdan_unif_fm_train_routine(model, optimizers, train_loader,
                                   valid_loaders, cfg)
    elif cfg['model'] == 'MODA':
        model = MODANet().to(device)
        optimizer = optim.Adadelta(model.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        moda_train_routine(model, optimizer, train_loader, valid_loaders, cfg)
    elif cfg['model'] == 'MODAFM':
        model = MODANet().to(device)
        optimizer = optim.Adadelta(model.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        cfg['excl_transf'] = [Flip]
        moda_fm_train_routine(model, optimizer, train_loader, valid_loaders,
                              cfg)
    else:
        raise ValueError('Unknown model {}'.format(cfg['model']))

    if cfg['model'] != 'DANNS':
        torch.save(model.state_dict(), cfg['output'])
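
The `--icfg` option above overrides parsed arguments with values from an INI file, parsed via `ast.literal_eval`. A self-contained sketch of that mechanism; the file name and keys are illustrative, not taken from the original project:

import ast
from configparser import ConfigParser

# Write an illustrative override file; values are parsed as Python literals.
with open("moda_fm_overrides.ini", "w") as f:
    f.write("[main]\nlr = 1e-2\nbatch_size = 16\nmu_c = 0.05\n")

cfg = {"lr": 1e-1, "batch_size": 8, "mu_c": 1e-1}
parser = ConfigParser()
parser.read("moda_fm_overrides.ini")
for key, val in parser.items("main"):
    cfg[key] = ast.literal_eval(val)
print(cfg)  # {'lr': 0.01, 'batch_size': 16, 'mu_c': 0.05}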
Example 13
def main():
    # N.B.: parameters defined in cv_cfg.ini override args!
    parser = argparse.ArgumentParser(description='Cross-validation over source domains for the digits datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-m', '--model', default='MODAFM', type=str, metavar='', help='model type (\'MDAN\' / \'MODA\' / \'MODAFM\')')
    parser.add_argument('-d', '--data_path', default='/ctm-hdd-pool01/DB/', type=str, metavar='', help='data directory path')
    parser.add_argument('-t', '--target', default='MNIST', type=str, metavar='', help='target domain (\'MNIST\' / \'MNIST_M\' / \'SVHN\' / \'SynthDigits\')')
    parser.add_argument('-o', '--output', default='msda_hyperparams.ini', type=str, metavar='', help='model file (output of train)')
    parser.add_argument('-n', '--n_iter', default=20, type=int, metavar='', help='number of CV iterations')
    parser.add_argument('--n_images', default=20000, type=int, metavar='', help='number of images from each domain')
    parser.add_argument('--mu', type=float, default=1e-2, help="hyperparameter of the coefficient for the domain adversarial loss")
    parser.add_argument('--beta', type=float, default=0.2, help="hyperparameter of the non-sparsity regularization")
    parser.add_argument('--lambda', type=float, default=1e-1, help="hyperparameter of the FixMatch loss")
    parser.add_argument('--n_rand_aug', type=int, default=2, help="N parameter of RandAugment")
    parser.add_argument('--m_min_rand_aug', type=int, default=3, help="minimum M parameter of RandAugment")
    parser.add_argument('--m_max_rand_aug', type=int, default=10, help="maximum M parameter of RandAugment")
    parser.add_argument('--weight_decay', default=0., type=float, metavar='', help='hyperparameter of weight decay regularization')
    parser.add_argument('--lr', default=1e-1, type=float, metavar='', help='learning rate')
    parser.add_argument('--epochs', default=30, type=int, metavar='', help='number of training epochs')
    parser.add_argument('--batch_size', default=8, type=int, metavar='', help='batch size (per domain)')
    parser.add_argument('--checkpoint', default=0, type=int, metavar='', help='number of epochs between saving checkpoints (0 disables checkpoints)')
    parser.add_argument('--use_cuda', default=True, type=int, metavar='', help='use CUDA capable GPU')
    parser.add_argument('--use_visdom', default=False, type=int, metavar='', help='use Visdom to visualize plots')
    parser.add_argument('--visdom_env', default='digits_train', type=str, metavar='', help='Visdom environment name')
    parser.add_argument('--visdom_port', default=8888, type=int, metavar='', help='Visdom port')
    parser.add_argument('--verbosity', default=2, type=int, metavar='', help='log verbosity level (0, 1, 2)')
    parser.add_argument('--seed', default=42, type=int, metavar='', help='random seed')
    args = vars(parser.parse_args())

    # override args with cv_cfg.ini
    cfg = args.copy()
    cv_parser = ConfigParser()
    cv_parser.read('cv_cfg.ini')
    cv_param_names = []
    for key, val in cv_parser.items('main'):
        cfg[key] = ast.literal_eval(val)
        cv_param_names.append(key)

    # use a fixed random seed for reproducibility purposes
    if cfg['seed'] > 0:
        random.seed(cfg['seed'])
        np.random.seed(seed=cfg['seed'])
        torch.manual_seed(cfg['seed'])
        torch.cuda.manual_seed(cfg['seed'])

    device = 'cuda' if (cfg['use_cuda'] and torch.cuda.is_available()) else 'cpu'
    log = Logger(cfg['verbosity'])
    log.print('device:', device, level=0)

    if 'FM' in cfg['model']:
        # weak data augmentation (small rotation + small translation)
        data_aug = T.Compose([
            T.RandomAffine(5, translate=(0.125, 0.125)),
            T.ToTensor(),
        ])
    else:
        data_aug = T.ToTensor()
    cfg['test_transform'] = T.ToTensor()

    # define all datasets
    datasets = {}
    datasets['MNIST'] = MNIST(train=True, path=os.path.join(cfg['data_path'], 'MNIST'), transform=data_aug)
    datasets['MNIST_M'] = MNIST_M(train=True, path=os.path.join(cfg['data_path'], 'MNIST_M'), transform=data_aug)
    datasets['SVHN'] = SVHN(train=True, path=os.path.join(cfg['data_path'], 'SVHN'), transform=data_aug)
    datasets['SynthDigits'] = SynthDigits(train=True, path=os.path.join(cfg['data_path'], 'SynthDigits'), transform=data_aug)
    del datasets[cfg['target']]

    # get a subset of cfg['n_images'] from each dataset
    for ds_name in datasets:
        if ds_name == cfg['target']:
            continue
        indices = random.sample(range(len(datasets[ds_name])), cfg['n_images'])
        datasets[ds_name] = Subset(datasets[ds_name], indices[0:cfg['n_images']])

    if cfg['model'] == 'MDAN':
        cfg['model'] = MDANet(len(datasets)-1).to(device)
        cfg['train_routine'] = lambda model, optimizer, train_loader, cfg: mdan_train_routine(model, optimizer, train_loader, dict(), cfg)
    elif cfg['model'] == 'MODA':
        cfg['model'] = MODANet().to(device)
        cfg['train_routine'] = lambda model, optimizer, train_loader, cfg: moda_train_routine(model, optimizer, train_loader, dict(), cfg)
    elif cfg['model'] == 'MODAFM':
        cfg['model'] = MODANet().to(device)
        cfg['excl_transf'] = [Flip]
        cfg['train_routine'] = lambda model, optimizer, train_loader, cfg: moda_fm_train_routine(model, optimizer, train_loader, dict(), cfg)
    else:
        raise ValueError('Unknown model {}'.format(cfg['model']))

    best_params, _ = cross_validation(datasets, cfg, cv_param_names)
    log.print('best_params:', best_params, level=1)

    results = ConfigParser()
    results.add_section('main')
    for key, value in best_params.items():
        results.set('main', key, str(value))
    with open(cfg['output'], 'w') as f:
        results.write(f)
Example 14
    gray_transforms = transforms.Compose(gray_transforms)

    # mnistm transforms
    mnistm_transforms = list()
    mnistm_transforms.append(transforms.Resize((32, 32)))
    mnistm_transforms.append(transforms.ToTensor())
    mnistm_transforms.append(
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]))
    mnistm_transforms = transforms.Compose(mnistm_transforms)

    # create the dataloaders
    dataloader = {}
    if args.source == 'svhn':
        dataloader['source_train'] = DataLoader(SVHN(os.path.join(
            args.data_root, args.source),
                                                     split='train',
                                                     transform=dset_transforms,
                                                     domain_label=0,
                                                     download=True),
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                drop_last=True)
    elif args.source == 'mnist':
        dataloader['source_train'] = DataLoader(MNIST(
            os.path.join(args.data_root, args.source),
            train=True,
            transform=gray_transforms,
            domain_label=0,
            download=True),
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                drop_last=True)
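
The excerpt begins after the preprocessing pipelines are built, so `gray_transforms` and `dset_transforms` are only partially visible. A plausible reconstruction, purely for illustration (the original lists may differ):

from torchvision import transforms

# RGB datasets such as SVHN: resize, convert to tensor, normalize per channel.
dset_transforms = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Grayscale datasets such as MNIST: replicate the single channel to three
# so source and target share the same input shape.
gray_transforms = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])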
Example 15
def main():
    parser = argparse.ArgumentParser(
        description='Domain adaptation experiments with digits datasets.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '-m',
        '--model',
        default='MODAFM',
        type=str,
        metavar='',
        help=
        'model type (\'FS\' / \'DANNS\' / \'DANNM\' / \'MDAN\' / \'MODA\' / \'FM\' / \'MODAFM\')'
    )
    parser.add_argument('-d',
                        '--data_path',
                        default='/ctm-hdd-pool01/DB/',
                        type=str,
                        metavar='',
                        help='data directory path')
    parser.add_argument(
        '-t',
        '--target',
        default='MNIST',
        type=str,
        metavar='',
        help=
        'target domain (\'MNIST\' / \'MNIST_M\' / \'SVHN\' / \'SynthDigits\')')
    parser.add_argument('-i',
                        '--input',
                        default='msda.pth',
                        type=str,
                        metavar='',
                        help='model file (output of train)')
    parser.add_argument('--n_images',
                        default=20000,
                        type=int,
                        metavar='',
                        help='number of images from each domain')
    parser.add_argument('--batch_size',
                        default=8,
                        type=int,
                        metavar='',
                        help='batch size')
    parser.add_argument('--use_cuda',
                        default=True,
                        type=int,
                        metavar='',
                        help='use CUDA capable GPU')
    parser.add_argument('--verbosity',
                        default=2,
                        type=int,
                        metavar='',
                        help='log verbosity level (0, 1, 2)')
    parser.add_argument('--seed',
                        default=42,
                        type=int,
                        metavar='',
                        help='random seed')
    args = vars(parser.parse_args())
    cfg = args.copy()

    # use a fixed random seed for reproducibility purposes
    if cfg['seed'] > 0:
        random.seed(cfg['seed'])
        np.random.seed(seed=cfg['seed'])
        torch.manual_seed(cfg['seed'])
        torch.cuda.manual_seed(cfg['seed'])

    device = 'cuda' if (cfg['use_cuda']
                        and torch.cuda.is_available()) else 'cpu'
    log = Logger(cfg['verbosity'])
    log.print('device:', device, level=0)

    # define all datasets
    datasets = {}
    datasets['MNIST'] = MNIST(train=True,
                              path=os.path.join(cfg['data_path'], 'MNIST'),
                              transform=T.ToTensor())
    datasets['MNIST_M'] = MNIST_M(train=True,
                                  path=os.path.join(cfg['data_path'],
                                                    'MNIST_M'),
                                  transform=T.ToTensor())
    datasets['SVHN'] = SVHN(train=True,
                            path=os.path.join(cfg['data_path'], 'SVHN'),
                            transform=T.ToTensor())
    datasets['SynthDigits'] = SynthDigits(train=True,
                                          path=os.path.join(
                                              cfg['data_path'], 'SynthDigits'),
                                          transform=T.ToTensor())
    test_set = datasets[cfg['target']]

    # get a subset of cfg['n_images'] from each dataset
    # define public and private test sets: the private is not used at training time to learn invariant representations
    for ds_name in datasets:
        if ds_name == cfg['target']:
            indices = random.sample(range(len(datasets[ds_name])),
                                    2 * cfg['n_images'])
            test_pub_set = Subset(test_set, indices[0:cfg['n_images']])
            test_priv_set = Subset(test_set, indices[cfg['n_images']::])
        else:
            indices = random.sample(range(len(datasets[ds_name])),
                                    cfg['n_images'])
        datasets[ds_name] = Subset(datasets[ds_name],
                                   indices[0:cfg['n_images']])

    # build the dataloader
    test_pub_loader = DataLoader(test_pub_set,
                                 batch_size=4 * cfg['batch_size'])
    test_priv_loader = DataLoader(test_priv_set,
                                  batch_size=4 * cfg['batch_size'])
    test_loaders = {
        'target pub': test_pub_loader,
        'target priv': test_priv_loader
    }
    log.print('target domain:', cfg['target'], level=0)

    if cfg['model'] in ['FS', 'FM']:
        model = SimpleCNN().to(device)
    elif cfg['model'] == 'MDAN':
        model = MDANet(len(datasets) - 1).to(device)
    elif cfg['model'] in ['DANNS', 'DANNM', 'MODA', 'MODAFM']:
        model = MODANet().to(device)
    else:
        raise ValueError('Unknown model {}'.format(cfg['model']))

    if cfg['model'] != 'DANNS':
        model.load_state_dict(torch.load(cfg['input']))
        accuracies, losses = test_routine(model, test_loaders, cfg)
        print('target pub: acc = {:.3f},'.format(accuracies['target pub']),
              'loss = {:.3f}'.format(losses['target pub']))
        print('target priv: acc = {:.3f},'.format(accuracies['target priv']),
              'loss = {:.3f}'.format(losses['target priv']))

    else:  # for DANNS, report results for the best source domain
        src_domains = ['MNIST', 'MNIST_M', 'SVHN', 'SynthDigits']
        src_domains.remove(cfg['target'])
        for i, src in enumerate(src_domains):
            model.load_state_dict(torch.load(cfg['input'] + '_' + src))
            acc, loss = test_routine(model, test_loaders, cfg)
            if i == 0:
                accuracies = acc
                losses = loss
            else:
                # keep the accuracy of the best source domain and its matching loss
                for key in accuracies.keys():
                    if acc[key] > accuracies[key]:
                        accuracies[key] = acc[key]
                        losses[key] = loss[key]
        log.print('target pub: acc = {:.3f},'.format(accuracies['target pub']),
                  'loss = {:.3f}'.format(losses['target pub']),
                  level=1)
        log.print('target priv: acc = {:.3f},'.format(
            accuracies['target priv']),
                  'loss = {:.3f}'.format(losses['target priv']),
                  level=1)