예제 #1
0
파일: train.py 프로젝트: dpernes/modafm
def main():
    """Train a domain adaptation model on the digits datasets.

    Command-line arguments are parsed into a dict; if ``--icfg`` is given,
    every option found in the ``[main]`` section of that INI file overrides
    the corresponding argument (values are parsed with ``ast.literal_eval``).
    The resulting configuration is written to ``<output>.txt``, the datasets
    and loaders are built, the training routine matching ``cfg['model']`` is
    run, and the model's ``state_dict`` is saved.

    Raises:
        ValueError: if ``cfg['model']`` is not one of the supported models.
    """
    parser = argparse.ArgumentParser(
        description='Domain adaptation experiments with digits datasets.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '-m',
        '--model',
        default='MODAFM',
        type=str,
        metavar='',
        help=
        'model type (\'FS\' / \'DANNS\' / \'DANNM\' / \'MDAN\' / \'MODA\' / \'FM\' / \'MODAFM\''
    )
    parser.add_argument('-d',
                        '--data_path',
                        default='/ctm-hdd-pool01/DB/',
                        type=str,
                        metavar='',
                        help='data directory path')
    parser.add_argument(
        '-t',
        '--target',
        default='MNIST',
        type=str,
        metavar='',
        help=
        'target domain (\'MNIST\' / \'MNIST_M\' / \'SVHN\' / \'SynthDigits\')')
    parser.add_argument('-o',
                        '--output',
                        default='msda.pth',
                        type=str,
                        metavar='',
                        help='model file (output of train)')
    parser.add_argument('--icfg',
                        default=None,
                        type=str,
                        metavar='',
                        help='config file (overrides args)')
    parser.add_argument('--n_src_images',
                        default=20000,
                        type=int,
                        metavar='',
                        help='number of images from each source domain')
    parser.add_argument('--n_tgt_images',
                        default=20000,
                        type=int,
                        metavar='',
                        help='number of images from the target domain')
    parser.add_argument(
        '--mu_d',
        type=float,
        default=1e-2,
        help=
        "hyperparameter of the coefficient for the domain discriminator loss")
    parser.add_argument(
        '--mu_s',
        type=float,
        default=0.2,
        help="hyperparameter of the non-sparsity regularization")
    parser.add_argument('--mu_c',
                        type=float,
                        default=1e-1,
                        help="hyperparameter of the FixMatch loss")
    parser.add_argument('--n_rand_aug',
                        type=int,
                        default=2,
                        help="N parameter of RandAugment")
    parser.add_argument('--m_min_rand_aug',
                        type=int,
                        default=3,
                        help="minimum M parameter of RandAugment")
    parser.add_argument('--m_max_rand_aug',
                        type=int,
                        default=10,
                        help="maximum M parameter of RandAugment")
    parser.add_argument('--weight_decay',
                        default=0.,
                        type=float,
                        metavar='',
                        help='hyperparameter of weight decay regularization')
    parser.add_argument('--lr',
                        default=1e-1,
                        type=float,
                        metavar='',
                        help='learning rate')
    parser.add_argument('--epochs',
                        default=30,
                        type=int,
                        metavar='',
                        help='number of training epochs')
    parser.add_argument('--batch_size',
                        default=8,
                        type=int,
                        metavar='',
                        help='batch size (per domain)')
    parser.add_argument(
        '--checkpoint',
        default=0,
        type=int,
        metavar='',
        help=
        'number of epochs between saving checkpoints (0 disables checkpoints)')
    parser.add_argument('--eval_target',
                        default=False,
                        type=int,
                        metavar='',
                        help='evaluate target during training')
    parser.add_argument('--use_cuda',
                        default=True,
                        type=int,
                        metavar='',
                        help='use CUDA capable GPU')
    parser.add_argument('--use_visdom',
                        default=False,
                        type=int,
                        metavar='',
                        help='use Visdom to visualize plots')
    parser.add_argument('--visdom_env',
                        default='digits_train',
                        type=str,
                        metavar='',
                        help='Visdom environment name')
    parser.add_argument('--visdom_port',
                        default=8888,
                        type=int,
                        metavar='',
                        help='Visdom port')
    parser.add_argument('--verbosity',
                        default=2,
                        type=int,
                        metavar='',
                        help='log verbosity level (0, 1, 2)')
    parser.add_argument('--seed',
                        default=42,
                        type=int,
                        metavar='',
                        help='random seed')
    args = vars(parser.parse_args())

    # override args with icfg (if provided)
    cfg = args.copy()
    if cfg['icfg'] is not None:
        cv_parser = ConfigParser()
        cv_parser.read(cfg['icfg'])
        for key, val in cv_parser.items('main'):
            cfg[key] = ast.literal_eval(val)

    # dump cfg to a txt file for your records
    with open(cfg['output'] + '.txt', 'w') as f:
        f.write(str(cfg) + '\n')

    # use a fixed random seed for reproducibility purposes
    # (a non-positive seed leaves the generators unseeded)
    if cfg['seed'] > 0:
        random.seed(cfg['seed'])
        np.random.seed(seed=cfg['seed'])
        torch.manual_seed(cfg['seed'])
        torch.cuda.manual_seed(cfg['seed'])

    device = 'cuda' if (cfg['use_cuda']
                        and torch.cuda.is_available()) else 'cpu'
    log = Logger(cfg['verbosity'])
    log.print('device:', device, level=0)

    # any model whose name contains 'FS' or 'FM' trains with augmentation
    uses_data_aug = ('FS' in cfg['model']) or ('FM' in cfg['model'])
    if uses_data_aug:
        # weak data augmentation (small rotation + small translation)
        data_aug = T.Compose([
            T.RandomAffine(5, translate=(0.125, 0.125)),
            T.ToTensor(),
        ])
    else:
        data_aug = T.ToTensor()

    # define all datasets
    datasets = {}
    datasets['MNIST'] = MNIST(train=True,
                              path=os.path.join(cfg['data_path'], 'MNIST'),
                              transform=data_aug)
    datasets['MNIST_M'] = MNIST_M(train=True,
                                  path=os.path.join(cfg['data_path'],
                                                    'MNIST_M'),
                                  transform=data_aug)
    datasets['SVHN'] = SVHN(train=True,
                            path=os.path.join(cfg['data_path'], 'SVHN'),
                            transform=data_aug)
    datasets['SynthDigits'] = SynthDigits(train=True,
                                          path=os.path.join(
                                              cfg['data_path'], 'SynthDigits'),
                                          transform=data_aug)
    if uses_data_aug:
        # evaluate on a copy so the training set keeps its augmentation
        test_set = deepcopy(datasets[cfg['target']])
        test_set.transform = T.ToTensor()  # no data augmentation in test
    else:
        test_set = datasets[cfg['target']]

    # get a subset of cfg['n_src_images'] from each source dataset and of
    # cfg['n_tgt_images'] from the target dataset
    # define public and private test sets: the private is not used at training time to learn invariant representations
    for ds_name in datasets:
        if ds_name == cfg['target']:
            # the first n_tgt_images indices form the public (training) split;
            # the remaining n_src_images indices form the private split
            indices = random.sample(range(len(datasets[ds_name])),
                                    cfg['n_tgt_images'] + cfg['n_src_images'])
            test_pub_set = Subset(test_set, indices[0:cfg['n_tgt_images']])
            test_priv_set = Subset(test_set, indices[cfg['n_tgt_images']::])
            datasets[cfg['target']] = Subset(datasets[cfg['target']],
                                             indices[0:cfg['n_tgt_images']])
        else:
            indices = random.sample(range(len(datasets[ds_name])),
                                    cfg['n_src_images'])
            datasets[ds_name] = Subset(datasets[ds_name],
                                       indices[0:cfg['n_src_images']])

    # build the dataloader
    train_loader = MSDA_Loader(datasets,
                               cfg['target'],
                               batch_size=cfg['batch_size'],
                               shuffle=True,
                               device=device)
    test_pub_loader = DataLoader(test_pub_set,
                                 batch_size=4 * cfg['batch_size'])
    test_priv_loader = DataLoader(test_priv_set,
                                  batch_size=4 * cfg['batch_size'])
    valid_loaders = ({
        'target pub': test_pub_loader,
        'target priv': test_priv_loader
    } if cfg['eval_target'] else None)
    log.print('target domain:',
              cfg['target'],
              '| source domains:',
              train_loader.sources,
              level=1)

    if cfg['model'] == 'FS':
        model = SimpleCNN().to(device)
        optimizer = optim.Adadelta(model.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        # the public target split is the training data here, so it is
        # dropped from the validation loaders
        if valid_loaders is not None:
            del valid_loaders['target pub']
        fs_train_routine(model, optimizer, test_pub_loader, valid_loaders, cfg)
    elif cfg['model'] == 'FM':
        model = SimpleCNN().to(device)
        optimizer = optim.Adadelta(model.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        cfg['excl_transf'] = [Flip]
        fm_train_routine(model, optimizer, train_loader, valid_loaders, cfg)
    elif cfg['model'] == 'DANNS':
        # one model per source domain, each saved to '<output>_<source>'
        for src in train_loader.sources:
            model = MODANet().to(device)
            optimizer = optim.Adadelta(model.parameters(),
                                       lr=cfg['lr'],
                                       weight_decay=cfg['weight_decay'])
            dataset_ss = {
                src: datasets[src],
                cfg['target']: datasets[cfg['target']]
            }
            train_loader = MSDA_Loader(dataset_ss,
                                       cfg['target'],
                                       batch_size=cfg['batch_size'],
                                       shuffle=True,
                                       device=device)
            dann_train_routine(model, optimizer, train_loader, valid_loaders,
                               cfg)
            torch.save(model.state_dict(), cfg['output'] + '_' + src)
    elif cfg['model'] == 'DANNM':
        model = MODANet().to(device)
        optimizer = optim.Adadelta(model.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        dann_train_routine(model, optimizer, train_loader, valid_loaders, cfg)
    elif cfg['model'] == 'MDAN':
        model = MDANet(len(train_loader.sources)).to(device)
        optimizer = optim.Adadelta(model.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        mdan_train_routine(model, optimizer, train_loader, valid_loaders, cfg)
    elif cfg['model'] == 'MDANU':
        model = MDANet(len(train_loader.sources)).to(device)
        model.grad_reverse = nn.ModuleList([
            nn.Identity() for _ in range(len(model.domain_class))
        ])  # remove grad reverse
        # separate optimizers for the task network and the domain classifiers
        task_optim = optim.Adadelta(list(model.feat_ext.parameters()) +
                                    list(model.task_class.parameters()),
                                    lr=cfg['lr'],
                                    weight_decay=cfg['weight_decay'])
        adv_optim = optim.Adadelta(model.domain_class.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        optimizers = (task_optim, adv_optim)
        mdan_unif_train_routine(model, optimizers, train_loader, valid_loaders,
                                cfg)
    elif cfg['model'] == 'MDANFM':
        model = MDANet(len(train_loader.sources)).to(device)
        optimizer = optim.Adadelta(model.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        mdan_fm_train_routine(model, optimizer, train_loader, valid_loaders,
                              cfg)
    elif cfg['model'] == 'MDANUFM':
        # NOTE(review): unlike the 'MDANU' branch, grad_reverse is not
        # replaced with identities here -- confirm this is intended.
        model = MDANet(len(train_loader.sources)).to(device)
        task_optim = optim.Adadelta(list(model.feat_ext.parameters()) +
                                    list(model.task_class.parameters()),
                                    lr=cfg['lr'],
                                    weight_decay=cfg['weight_decay'])
        adv_optim = optim.Adadelta(model.domain_class.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        optimizers = (task_optim, adv_optim)
        cfg['excl_transf'] = [Flip]
        # pass the (task, adversarial) optimizer pair built above; no single
        # `optimizer` exists in this branch
        mdan_unif_fm_train_routine(model, optimizers, train_loader,
                                   valid_loaders, cfg)
    elif cfg['model'] == 'MODA':
        model = MODANet().to(device)
        optimizer = optim.Adadelta(model.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        moda_train_routine(model, optimizer, train_loader, valid_loaders, cfg)
    elif cfg['model'] == 'MODAFM':
        model = MODANet().to(device)
        optimizer = optim.Adadelta(model.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        cfg['excl_transf'] = [Flip]
        moda_fm_train_routine(model, optimizer, train_loader, valid_loaders,
                              cfg)
    else:
        raise ValueError('Unknown model {}'.format(cfg['model']))

    # 'DANNS' already saved one file per source inside its loop
    if cfg['model'] != 'DANNS':
        torch.save(model.state_dict(), cfg['output'])
예제 #2
0
def main():
    """Cross-validate hyperparameters over the source domains of the digits datasets.

    Command-line arguments are parsed first; every option in the ``[main]``
    section of ``cv_cfg.ini`` then overrides the matching argument and is
    recorded as a parameter name to pass to ``cross_validation``. The best
    parameters returned are written to ``cfg['output']`` in INI format.

    Raises:
        ValueError: if ``cfg['model']`` is not 'MDAN', 'MODA' or 'MODAFM'.
    """
    # N.B.: parameters defined in cv_cfg.ini override args!
    parser = argparse.ArgumentParser(
        description='Cross-validation over source domains for the digits datasets.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # (flags, keyword arguments) for every option, in declaration order
    option_table = [
        (('-m', '--model'),
         dict(default='MODAFM', type=str, metavar='',
              help="model type ('MDAN' / 'MODA' / 'MODAFM'")),
        (('-d', '--data_path'),
         dict(default='/ctm-hdd-pool01/DB/', type=str, metavar='',
              help='data directory path')),
        (('-t', '--target'),
         dict(default='MNIST', type=str, metavar='',
              help="target domain ('MNIST' / 'MNIST_M' / 'SVHN' / 'SynthDigits')")),
        (('-o', '--output'),
         dict(default='msda_hyperparams.ini', type=str, metavar='',
              help='model file (output of train)')),
        (('-n', '--n_iter'),
         dict(default=20, type=int, metavar='',
              help='number of CV iterations')),
        (('--n_images',),
         dict(default=20000, type=int, metavar='',
              help='number of images from each domain')),
        (('--mu',),
         dict(type=float, default=1e-2,
              help='hyperparameter of the coefficient for the domain adversarial loss')),
        (('--beta',),
         dict(type=float, default=0.2,
              help='hyperparameter of the non-sparsity regularization')),
        (('--lambda',),
         dict(type=float, default=1e-1,
              help='hyperparameter of the FixMatch loss')),
        (('--n_rand_aug',),
         dict(type=int, default=2,
              help='N parameter of RandAugment')),
        (('--m_min_rand_aug',),
         dict(type=int, default=3,
              help='minimum M parameter of RandAugment')),
        (('--m_max_rand_aug',),
         dict(type=int, default=10,
              help='maximum M parameter of RandAugment')),
        (('--weight_decay',),
         dict(default=0., type=float, metavar='',
              help='hyperparameter of weight decay regularization')),
        (('--lr',),
         dict(default=1e-1, type=float, metavar='',
              help='learning rate')),
        (('--epochs',),
         dict(default=30, type=int, metavar='',
              help='number of training epochs')),
        (('--batch_size',),
         dict(default=8, type=int, metavar='',
              help='batch size (per domain)')),
        (('--checkpoint',),
         dict(default=0, type=int, metavar='',
              help='number of epochs between saving checkpoints (0 disables checkpoints)')),
        (('--use_cuda',),
         dict(default=True, type=int, metavar='',
              help='use CUDA capable GPU')),
        (('--use_visdom',),
         dict(default=False, type=int, metavar='',
              help='use Visdom to visualize plots')),
        (('--visdom_env',),
         dict(default='digits_train', type=str, metavar='',
              help='Visdom environment name')),
        (('--visdom_port',),
         dict(default=8888, type=int, metavar='',
              help='Visdom port')),
        (('--verbosity',),
         dict(default=2, type=int, metavar='',
              help='log verbosity level (0, 1, 2)')),
        (('--seed',),
         dict(default=42, type=int, metavar='',
              help='random seed')),
    ]
    for flags, options in option_table:
        parser.add_argument(*flags, **options)

    # start from the parsed arguments, then let cv_cfg.ini override them
    cfg = dict(vars(parser.parse_args()))
    ini = ConfigParser()
    ini.read('cv_cfg.ini')
    cv_param_names = []
    for name, raw_value in ini.items('main'):
        cfg[name] = ast.literal_eval(raw_value)
        cv_param_names.append(name)

    # use a fixed random seed for reproducibility purposes
    seed = cfg['seed']
    if seed > 0:
        random.seed(seed)
        np.random.seed(seed=seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    if cfg['use_cuda'] and torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
    log = Logger(cfg['verbosity'])
    log.print('device:', device, level=0)

    # weak data augmentation (small rotation + small translation) is applied
    # only to models whose name contains 'FM'
    data_aug = T.Compose([
        T.RandomAffine(5, translate=(0.125, 0.125)),
        T.ToTensor(),
    ]) if 'FM' in cfg['model'] else T.ToTensor()
    cfg['test_transform'] = T.ToTensor()

    # define all datasets, then discard the target domain
    root = cfg['data_path']
    datasets = {
        'MNIST': MNIST(train=True, path=os.path.join(root, 'MNIST'), transform=data_aug),
        'MNIST_M': MNIST_M(train=True, path=os.path.join(root, 'MNIST_M'), transform=data_aug),
        'SVHN': SVHN(train=True, path=os.path.join(root, 'SVHN'), transform=data_aug),
        'SynthDigits': SynthDigits(train=True, path=os.path.join(root, 'SynthDigits'), transform=data_aug),
    }
    del datasets[cfg['target']]

    # keep a random subset of cfg['n_images'] from each remaining dataset
    # (the target key was just deleted, so every entry is a source domain)
    n_images = cfg['n_images']
    for ds_name in list(datasets):
        full_set = datasets[ds_name]
        picked = random.sample(range(len(full_set)), n_images)
        datasets[ds_name] = Subset(full_set, picked[0:n_images])

    def without_validation(routine):
        """Adapt a training routine so it is called with an empty validation dict."""
        def run(model, optimizer, train_loader, cfg):
            return routine(model, optimizer, train_loader, dict(), cfg)
        return run

    # replace the model name in cfg with the model instance and attach the
    # matching training routine
    model_name = cfg['model']
    if model_name == 'MDAN':
        cfg['model'] = MDANet(len(datasets)-1).to(device)
        cfg['train_routine'] = without_validation(mdan_train_routine)
    elif model_name == 'MODA':
        cfg['model'] = MODANet().to(device)
        cfg['train_routine'] = without_validation(moda_train_routine)
    elif model_name == 'MODAFM':
        cfg['model'] = MODANet().to(device)
        cfg['excl_transf'] = [Flip]
        cfg['train_routine'] = without_validation(moda_fm_train_routine)
    else:
        raise ValueError('Unknown model {}'.format(model_name))

    best_params, _ = cross_validation(datasets, cfg, cv_param_names)
    log.print('best_params:', best_params, level=1)

    # persist the best hyperparameters as an INI file
    results = ConfigParser()
    results.add_section('main')
    for param in best_params:
        results.set('main', param, str(best_params[param]))
    with open(cfg['output'], 'w') as out_file:
        results.write(out_file)
예제 #3
0
파일: train.py 프로젝트: dpernes/modafm
def main():
    """Train a domain adaptation model on the Amazon dataset.

    Command-line arguments are parsed into a dict; if ``--icfg`` is given,
    every option found in the ``[main]`` section of that INI file overrides
    the corresponding argument (values are parsed with ``ast.literal_eval``).
    The resulting configuration is written to ``<output>.txt``, the datasets
    and loaders are built, the training routine matching ``cfg['model']`` is
    run, and the model's ``state_dict`` is saved to ``cfg['output']``.

    Raises:
        ValueError: if ``cfg['model']`` is not one of the supported models.
    """
    parser = argparse.ArgumentParser(
        description='Domain adaptation experiments with Amazon dataset.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '-m',
        '--model',
        default='MODAFM',
        type=str,
        metavar='',
        help=
        'model type (\'FS\' / \'DANNS\' / \'DANNM\' / \'MDAN\' / \'MODA\' / \'FM\' / \'MODAFM\''
    )
    parser.add_argument('-d',
                        '--data_path',
                        default='/ctm-hdd-pool01/DB/Amazon',
                        type=str,
                        metavar='',
                        help='data directory path')
    parser.add_argument(
        '-t',
        '--target',
        default='books',
        type=str,
        metavar='',
        help=
        'target domain (\'books\' / \'dvd\' / \'electronics\' / \'kitchen\')')
    parser.add_argument('-o',
                        '--output',
                        default='msda.pth',
                        type=str,
                        metavar='',
                        help='model file (output of train)')
    parser.add_argument('--icfg',
                        default=None,
                        type=str,
                        metavar='',
                        help='config file (overrides args)')
    parser.add_argument('--n_samples',
                        default=2000,
                        type=int,
                        metavar='',
                        help='number of samples from each domain')
    parser.add_argument('--n_features',
                        default=5000,
                        type=int,
                        metavar='',
                        help='number of features to use')
    parser.add_argument(
        '--mu',
        type=float,
        default=1e-2,
        help="hyperparameter of the coefficient for the domain adversarial loss"
    )
    parser.add_argument(
        '--beta',
        type=float,
        default=2e-1,
        help="hyperparameter of the non-sparsity regularization")
    parser.add_argument('--lambda',
                        type=float,
                        default=1e-1,
                        help="hyperparameter of the FixMatch loss")
    # dropout rates are fractions, so they must be parsed as floats
    parser.add_argument('--min_dropout',
                        type=float,
                        default=2e-1,
                        help="minimum dropout rate")
    parser.add_argument('--max_dropout',
                        type=float,
                        default=8e-1,
                        help="maximum dropout rate")
    parser.add_argument('--weight_decay',
                        default=0.,
                        type=float,
                        metavar='',
                        help='hyperparameter of weight decay regularization')
    parser.add_argument('--lr',
                        default=1e0,
                        type=float,
                        metavar='',
                        help='learning rate')
    parser.add_argument('--epochs',
                        default=15,
                        type=int,
                        metavar='',
                        help='number of training epochs')
    parser.add_argument('--batch_size',
                        default=20,
                        type=int,
                        metavar='',
                        help='batch size (per domain)')
    parser.add_argument(
        '--checkpoint',
        default=0,
        type=int,
        metavar='',
        help=
        'number of epochs between saving checkpoints (0 disables checkpoints)')
    parser.add_argument('--eval_target',
                        default=False,
                        type=int,
                        metavar='',
                        help='evaluate target during training')
    parser.add_argument('--use_cuda',
                        default=True,
                        type=int,
                        metavar='',
                        help='use CUDA capable GPU')
    parser.add_argument('--use_visdom',
                        default=False,
                        type=int,
                        metavar='',
                        help='use Visdom to visualize plots')
    parser.add_argument('--visdom_env',
                        default='amazon_train',
                        type=str,
                        metavar='',
                        help='Visdom environment name')
    parser.add_argument('--visdom_port',
                        default=8888,
                        type=int,
                        metavar='',
                        help='Visdom port')
    parser.add_argument('--verbosity',
                        default=2,
                        type=int,
                        metavar='',
                        help='log verbosity level')
    parser.add_argument('--seed',
                        default=42,
                        type=int,
                        metavar='',
                        help='random seed')
    args = vars(parser.parse_args())

    # override args with icfg (if provided)
    cfg = args.copy()
    if cfg['icfg'] is not None:
        cv_parser = ConfigParser()
        cv_parser.read(cfg['icfg'])
        for key, val in cv_parser.items('main'):
            cfg[key] = ast.literal_eval(val)

    # dump cfg to a txt file for your records
    with open(cfg['output'] + '.txt', 'w') as f:
        f.write(str(cfg) + '\n')

    device = 'cuda' if (cfg['use_cuda']
                        and torch.cuda.is_available()) else 'cpu'
    log = Logger(cfg['verbosity'])
    log.print('device:', device, level=0)

    # use a fixed random seed for reproducibility purposes
    # (read from cfg so that a seed set in the config file takes effect)
    if cfg['seed'] > 0:
        random.seed(cfg['seed'])
        np.random.seed(seed=cfg['seed'])
        torch.manual_seed(cfg['seed'])
        torch.cuda.manual_seed(cfg['seed'])

    domains = ['books', 'dvd', 'electronics', 'kitchen']
    datasets = {}
    for domain in domains:
        # NOTE(review): the dataset file path is hard-coded and
        # cfg['data_path'] is not used here -- confirm this is intended.
        datasets[domain] = Amazon('./amazon.npz',
                                  domain,
                                  dimension=cfg['n_features'],
                                  transform=torch.from_numpy)
        indices = random.sample(range(len(datasets[domain])), cfg['n_samples'])
        if domain == cfg['target']:
            # the private test set is every target sample that was not drawn
            priv_indices = list(
                set(range(len(datasets[cfg['target']]))) - set(indices))
            test_priv_set = Subset(datasets[cfg['target']], priv_indices)
        datasets[domain] = Subset(datasets[domain], indices)
    test_pub_set = datasets[cfg['target']]

    train_loader = MSDA_Loader(datasets,
                               cfg['target'],
                               batch_size=cfg['batch_size'],
                               shuffle=True,
                               device=device)
    test_pub_loader = DataLoader(test_pub_set,
                                 batch_size=4 * cfg['batch_size'])
    test_priv_loader = DataLoader(test_priv_set,
                                  batch_size=4 * cfg['batch_size'])
    valid_loaders = {
        'target pub': test_pub_loader,
        'target priv': test_priv_loader
    } if cfg['eval_target'] else None
    log.print('target domain:',
              cfg['target'],
              'source domains:',
              train_loader.sources,
              level=1)

    if cfg['model'] == 'FS':
        model = SimpleMLP(input_dim=cfg['n_features'], n_classes=2).to(device)
        optimizer = optim.Adadelta(model.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        # the public target split is the training data here, so it is
        # dropped from the validation loaders
        if cfg['eval_target']:
            del valid_loaders['target pub']
        fs_train_routine(model, optimizer, test_pub_loader, valid_loaders, cfg)
    elif cfg['model'] == 'FM':
        model = SimpleMLP(input_dim=cfg['n_features'], n_classes=2).to(device)
        optimizer = optim.Adadelta(model.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        mlp_fm_train_routine(model, optimizer, train_loader, valid_loaders,
                             cfg)
    elif cfg['model'] == 'DANNS':
        # one model per source domain, each saved to '<output>_<source>'
        for src in train_loader.sources:
            model = MODANet(input_dim=cfg['n_features'],
                            n_classes=2).to(device)
            optimizer = optim.Adadelta(model.parameters(),
                                       lr=cfg['lr'],
                                       weight_decay=cfg['weight_decay'])
            dataset_ss = {
                src: datasets[src],
                cfg['target']: datasets[cfg['target']]
            }
            train_loader = MSDA_Loader(dataset_ss,
                                       cfg['target'],
                                       batch_size=cfg['batch_size'],
                                       shuffle=True,
                                       device=device)
            dann_train_routine(model, optimizer, train_loader, valid_loaders,
                               cfg)
            torch.save(model.state_dict(), cfg['output'] + '_' + src)
    elif cfg['model'] == 'DANNM':
        model = MODANet(input_dim=cfg['n_features'], n_classes=2).to(device)
        optimizer = optim.Adadelta(model.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        dann_train_routine(model, optimizer, train_loader, valid_loaders, cfg)
    elif cfg['model'] == 'MDAN':
        model = MDANet(input_dim=cfg['n_features'],
                       n_classes=2,
                       n_domains=len(train_loader.sources)).to(device)
        optimizer = optim.Adadelta(model.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        mdan_train_routine(model, optimizer, train_loader, valid_loaders, cfg)
    elif cfg['model'] == 'MODA':
        model = MODANet(input_dim=cfg['n_features'], n_classes=2).to(device)
        # use cfg (not args) so a learning rate from the config file applies
        optimizer = optim.Adadelta(model.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        moda_train_routine(model, optimizer, train_loader, valid_loaders, cfg)
    elif cfg['model'] == 'MODAFM':
        model = MODANet(input_dim=cfg['n_features'], n_classes=2).to(device)
        optimizer = optim.Adadelta(model.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        moda_mlp_fm_train_routine(model, optimizer, train_loader,
                                  valid_loaders, cfg)
    else:
        # fail with a clear message instead of a NameError on `model` below
        raise ValueError('Unknown model {}'.format(cfg['model']))

    # for 'DANNS' this additionally saves the model of the last source
    torch.save(model.state_dict(), cfg['output'])
예제 #4
0
파일: train.py 프로젝트: dpernes/modafm
def main():
    """Train one domain adaptation model on DomainNet and save its weights.

    Command-line arguments are parsed first. If ``--icfg`` names a config
    file, every key in its ``[main]`` section overrides the argument of the
    same name. The resulting configuration is dumped to ``<output>.txt``,
    the datasets and loaders are built, the model selected by
    ``cfg['model']`` is trained with its routine, and the final
    ``state_dict`` is written to ``cfg['output']``. For ``'DANNS'`` one
    model per source domain is trained and additionally saved to
    ``<output>_<source>``.

    Raises:
        ValueError: if ``cfg['model']`` is not one of the handled model types.
    """
    parser = argparse.ArgumentParser(
        description='Domain adaptation experiments with the DomainNet dataset.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '-m',
        '--model',
        default='MODAFM',
        type=str,
        metavar='',
        help=
        'model type (\'FS\' / \'DANNS\' / \'DANNM\' / \'MDAN\' / \'MODA\' / \'FM\' / \'MODAFM\''
    )
    parser.add_argument('-d',
                        '--data_path',
                        default='/ctm-hdd-pool01/DB/DomainNet192',
                        type=str,
                        metavar='',
                        help='data directory path')
    parser.add_argument(
        '-t',
        '--target',
        default='clipart',
        type=str,
        metavar='',
        help=
        'target domain (\'clipart\' / \'infograph\' / \'painting\' / \'quickdraw\' / \'real\' / \'sketch\')'
    )
    parser.add_argument('-o',
                        '--output',
                        default='msda.pth',
                        type=str,
                        metavar='',
                        help='model file (output of train)')
    parser.add_argument('--icfg',
                        default=None,
                        type=str,
                        metavar='',
                        help='config file (overrides args)')
    parser.add_argument(
        '--arch',
        default='resnet152',
        type=str,
        metavar='',
        help='network architecture (\'resnet101\' / \'resnet152\'')
    parser.add_argument(
        '--mu_d',
        type=float,
        default=1e-2,
        help=
        "hyperparameter of the coefficient for the domain discriminator loss")
    parser.add_argument(
        '--mu_s',
        type=float,
        default=0.2,
        help="hyperparameter of the non-sparsity regularization")
    parser.add_argument('--mu_c',
                        type=float,
                        default=1e-1,
                        help="hyperparameter of the FixMatch loss")
    parser.add_argument('--n_rand_aug',
                        type=int,
                        default=2,
                        help="N parameter of RandAugment")
    parser.add_argument('--m_min_rand_aug',
                        type=int,
                        default=3,
                        help="minimum M parameter of RandAugment")
    parser.add_argument('--m_max_rand_aug',
                        type=int,
                        default=10,
                        help="maximum M parameter of RandAugment")
    parser.add_argument('--weight_decay',
                        default=0.,
                        type=float,
                        metavar='',
                        help='hyperparameter of weight decay regularization')
    parser.add_argument('--lr',
                        default=1e-3,
                        type=float,
                        metavar='',
                        help='learning rate')
    parser.add_argument('--epochs',
                        default=50,
                        type=int,
                        metavar='',
                        help='number of training epochs')
    parser.add_argument('--batch_size',
                        default=8,
                        type=int,
                        metavar='',
                        help='batch size (per domain)')
    parser.add_argument(
        '--checkpoint',
        default=0,
        type=int,
        metavar='',
        help=
        'number of epochs between saving checkpoints (0 disables checkpoints)')
    parser.add_argument('--eval_target',
                        default=False,
                        type=int,
                        metavar='',
                        help='evaluate target during training')
    parser.add_argument('--use_cuda',
                        default=True,
                        type=int,
                        metavar='',
                        help='use CUDA capable GPU')
    parser.add_argument('--use_visdom',
                        default=False,
                        type=int,
                        metavar='',
                        help='use Visdom to visualize plots')
    parser.add_argument('--visdom_env',
                        default='domainnet_train',
                        type=str,
                        metavar='',
                        help='Visdom environment name')
    parser.add_argument('--visdom_port',
                        default=8888,
                        type=int,
                        metavar='',
                        help='Visdom port')
    parser.add_argument('--verbosity',
                        default=2,
                        type=int,
                        metavar='',
                        help='log verbosity level (0, 1, 2)')
    parser.add_argument('--seed',
                        default=42,
                        type=int,
                        metavar='',
                        help='random seed')
    args = vars(parser.parse_args())

    # override args with icfg (if provided); from here on only cfg is read
    cfg = args.copy()
    if cfg['icfg'] is not None:
        cv_parser = ConfigParser()
        cv_parser.read(cfg['icfg'])
        for key, val in cv_parser.items('main'):
            cfg[key] = ast.literal_eval(val)

    # dump args to a txt file for your records
    with open(cfg['output'] + '.txt', 'w') as f:
        f.write(str(cfg) + '\n')

    # use a fixed random seed for reproducibility purposes
    if cfg['seed'] > 0:
        random.seed(cfg['seed'])
        np.random.seed(seed=cfg['seed'])
        torch.manual_seed(cfg['seed'])
        torch.cuda.manual_seed(cfg['seed'])

    device = 'cuda' if (cfg['use_cuda']
                        and torch.cuda.is_available()) else 'cpu'
    log = Logger(cfg['verbosity'])
    log.print('device:', device, level=0)

    # normalization transformation (required for pretrained networks)
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
    if 'FM' in cfg['model']:
        # weak data augmentation (small rotation + small translation)
        data_aug = T.Compose([
            T.RandomHorizontalFlip(),
            T.RandomAffine(5, translate=(0.125, 0.125)),
            T.ToTensor(),
            # normalize,  # normalization disrupts FixMatch
        ])

        eval_transf = T.Compose([
            T.ToTensor(),
        ])

    else:
        data_aug = T.Compose([
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            normalize,
        ])

        eval_transf = T.Compose([
            T.ToTensor(),
            normalize,
        ])

    domains = [
        'clipart', 'infograph', 'painting', 'quickdraw', 'real', 'sketch'
    ]
    datasets = {
        domain: DomainNet(cfg['data_path'],
                          domain=domain,
                          train=True,
                          transform=data_aug)
        for domain in domains
    }
    n_classes = len(datasets[cfg['target']].class_names)

    test_set = DomainNet(cfg['data_path'],
                         domain=cfg['target'],
                         train=False,
                         transform=eval_transf)
    if 'FM' in cfg['model']:
        # evaluate on a copy so that the training set keeps its augmentation
        target_pub = deepcopy(datasets[cfg['target']])
        target_pub.transform = eval_transf  # no data augmentation in test
    else:
        target_pub = datasets[cfg['target']]

    if cfg['model'] != 'FS':
        train_loader = MSDA_Loader(datasets,
                                   cfg['target'],
                                   batch_size=cfg['batch_size'],
                                   shuffle=True,
                                   num_workers=0,
                                   device=device)
        if cfg['eval_target']:
            valid_loaders = {
                'target pub':
                DataLoader(target_pub, batch_size=6 * cfg['batch_size']),
                'target priv':
                DataLoader(test_set, batch_size=6 * cfg['batch_size'])
            }
        else:
            valid_loaders = None
        log.print('target domain:',
                  cfg['target'],
                  '| source domains:',
                  train_loader.sources,
                  level=1)
    else:
        train_loader = DataLoader(datasets[cfg['target']],
                                  batch_size=cfg['batch_size'],
                                  shuffle=True)
        test_loader = DataLoader(test_set, batch_size=cfg['batch_size'])
        log.print('target domain:', cfg['target'], level=1)

    def make_optimizer(net):
        """Build Adadelta with two parameter groups for ``net``.

        Parameters whose name contains 'fc' (case-insensitive) use
        ``cfg['lr']``; all remaining parameters use a learning rate ten
        times smaller. Both groups share ``cfg['weight_decay']``.
        """
        conv_params, fc_params = [], []
        for name, param in net.named_parameters():
            if 'fc' in name.lower():
                fc_params.append(param)
            else:
                conv_params.append(param)
        return optim.Adadelta([{
            'params': conv_params,
            'lr': 0.1 * cfg['lr'],
            'weight_decay': cfg['weight_decay']
        }, {
            'params': fc_params,
            'lr': cfg['lr'],
            'weight_decay': cfg['weight_decay']
        }])

    if cfg['model'] == 'FS':
        model = SimpleCNN(n_classes=n_classes, arch=cfg['arch']).to(device)
        optimizer = make_optimizer(model)
        valid_loaders = {
            'target pub': test_loader
        } if cfg['eval_target'] else None
        fs_train_routine(model, optimizer, train_loader, valid_loaders, cfg)

    elif cfg['model'] == 'FM':
        model = SimpleCNN(n_classes=n_classes, arch=cfg['arch']).to(device)
        # the parameter lists used to be referenced here before being
        # created, which raised NameError; the helper always creates them
        optimizer = make_optimizer(model)
        cfg['excl_transf'] = None
        fm_train_routine(model, optimizer, train_loader, valid_loaders, cfg)

    elif cfg['model'] == 'DANNS':
        # one single-source model per source domain; train_loader is rebound
        # inside the loop, the iterable below is evaluated only once
        for src in train_loader.sources:
            model = MODANet(n_classes=n_classes, arch=cfg['arch']).to(device)
            optimizer = make_optimizer(model)
            dataset_ss = {
                src: datasets[src],
                cfg['target']: datasets[cfg['target']]
            }
            train_loader = MSDA_Loader(dataset_ss,
                                       cfg['target'],
                                       batch_size=cfg['batch_size'],
                                       shuffle=True,
                                       device=device)
            dann_train_routine(model, optimizer, train_loader, valid_loaders,
                               cfg)
            torch.save(model.state_dict(), cfg['output'] + '_' + src)

    elif cfg['model'] == 'DANNM':
        model = MODANet(n_classes=n_classes, arch=cfg['arch']).to(device)
        optimizer = make_optimizer(model)
        dann_train_routine(model, optimizer, train_loader, valid_loaders, cfg)

    # cfg (not args) so that a model selected through --icfg is honored
    elif cfg['model'] == 'MDAN':
        model = MDANet(n_classes=n_classes,
                       n_domains=len(train_loader.sources),
                       arch=cfg['arch']).to(device)
        optimizer = make_optimizer(model)
        mdan_train_routine(model, optimizer, train_loader, valid_loaders, cfg)

    elif cfg['model'] == 'MODA':
        model = MODANet(n_classes=n_classes, arch=cfg['arch']).to(device)
        optimizer = make_optimizer(model)
        moda_train_routine(model, optimizer, train_loader, valid_loaders, cfg)

    elif cfg['model'] == 'MODAFM':
        model = MODANet(n_classes=n_classes, arch=cfg['arch']).to(device)
        optimizer = make_optimizer(model)
        cfg['excl_transf'] = None
        moda_fm_train_routine(model, optimizer, train_loader, valid_loaders,
                              cfg)

    else:
        raise ValueError('Unknown model {}'.format(cfg['model']))

    # for 'DANNS' this stores the model trained on the last source domain
    torch.save(model.state_dict(), cfg['output'])
예제 #5
0
파일: cross_val.py 프로젝트: dpernes/modafm
def main():
    """Cross-validate hyperparameters over the Office source domains.

    Command-line arguments are parsed and then overridden by every key in
    the ``[main]`` section of ``cv_cfg.ini``; the overridden key names are
    the parameters handed to ``cross_validation``. The configuration is
    dumped to ``<output>.txt``, the target domain is removed from the
    datasets, the selected model plus its parameter groups and training
    routine are stored in ``cfg``, and the best parameters returned by
    ``cross_validation`` are written to ``cfg['output']`` in INI format.

    Raises:
        ValueError: if ``cfg['model']`` is not 'MDAN', 'MODA' or 'MODAFM'.
    """
    parser = argparse.ArgumentParser(
        description=
        'Cross-validation over source domains for the Office dataset.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-m',
                        '--model',
                        default='MODAFM',
                        type=str,
                        metavar='',
                        help='model type (\'MDAN\' / \'MODA\' / \'MODAFM\'')
    parser.add_argument('-d',
                        '--data_path',
                        default='/ctm-hdd-pool01/DB/OfficeRsz',
                        type=str,
                        metavar='',
                        help='data directory path')
    parser.add_argument(
        '-t',
        '--target',
        default='amazon',
        type=str,
        metavar='',
        help='target domain (\'amazon\' / \'dslr\' / \'webcam\')')
    parser.add_argument(
        '-o',
        '--output',
        default='cv_out.ini',
        type=str,
        metavar='',
        help='best hyperparameters (output of cross validation')
    parser.add_argument('-n',
                        '--n_iter',
                        default=20,
                        type=int,
                        metavar='',
                        help='number of CV iterations')
    parser.add_argument(
        '--mu_d',
        type=float,
        default=1e-2,
        help=
        "hyperparameter of the coefficient for the domain discriminator loss")
    parser.add_argument(
        '--mu_s',
        type=float,
        default=0.2,
        help="hyperparameter of the non-sparsity regularization")
    parser.add_argument('--mu_c',
                        type=float,
                        default=1e-1,
                        help="hyperparameter of the FixMatch loss")
    parser.add_argument('--n_rand_aug',
                        type=int,
                        default=2,
                        help="N parameter of RandAugment")
    parser.add_argument('--m_min_rand_aug',
                        type=int,
                        default=3,
                        help="minimum M parameter of RandAugment")
    parser.add_argument('--m_max_rand_aug',
                        type=int,
                        default=10,
                        help="maximum M parameter of RandAugment")
    parser.add_argument('--weight_decay',
                        default=0.,
                        type=float,
                        metavar='',
                        help='hyperparameter of weight decay regularization')
    parser.add_argument('--lr',
                        default=1e-1,
                        type=float,
                        metavar='',
                        help='learning rate')
    parser.add_argument('--epochs',
                        default=15,
                        type=int,
                        metavar='',
                        help='number of training epochs')
    parser.add_argument('--batch_size',
                        default=8,
                        type=int,
                        metavar='',
                        help='batch size (per domain)')
    parser.add_argument(
        '--checkpoint',
        default=0,
        type=int,
        metavar='',
        help=
        'number of epochs between saving checkpoints (0 disables checkpoints)')
    parser.add_argument('--use_cuda',
                        default=True,
                        type=int,
                        metavar='',
                        help='use CUDA capable GPU')
    parser.add_argument('--use_visdom',
                        default=False,
                        type=int,
                        metavar='',
                        help='use Visdom to visualize plots')
    parser.add_argument('--visdom_env',
                        default='office_train',
                        type=str,
                        metavar='',
                        help='Visdom environment name')
    parser.add_argument('--visdom_port',
                        default=8888,
                        type=int,
                        metavar='',
                        help='Visdom port')
    parser.add_argument('--verbosity',
                        default=2,
                        type=int,
                        metavar='',
                        help='log verbosity level (0, 1, 2)')
    parser.add_argument('--seed',
                        default=42,
                        type=int,
                        metavar='',
                        help='random seed')
    args = vars(parser.parse_args())

    # override args with cv_cfg.ini; from here on only cfg is read
    cfg = args.copy()
    cv_parser = ConfigParser()
    cv_parser.read('cv_cfg.ini')
    cv_param_names = []
    for key, val in cv_parser.items('main'):
        cfg[key] = ast.literal_eval(val)
        cv_param_names.append(key)

    device = 'cuda' if (cfg['use_cuda']
                        and torch.cuda.is_available()) else 'cpu'
    log = Logger(cfg['verbosity'])
    log.print('device:', device, level=0)

    # dump cfg to a txt file for your records
    with open(cfg['output'] + '.txt', 'w') as f:
        f.write(str(cfg) + '\n')

    # use a fixed random seed for reproducibility purposes
    if cfg['seed'] > 0:
        random.seed(cfg['seed'])
        np.random.seed(seed=cfg['seed'])
        torch.manual_seed(cfg['seed'])
        torch.cuda.manual_seed(cfg['seed'])

    # normalization transformation (required for pretrained networks)
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
    if 'FM' in cfg['model']:
        # weak data augmentation (small rotation + small translation)
        data_aug = T.Compose([
            T.RandomHorizontalFlip(),
            T.RandomAffine(5, translate=(0.125, 0.125)),
            T.ToTensor(),
            # normalize,  # normalization disrupts FixMatch
        ])
        cfg['test_transform'] = T.ToTensor()
    else:
        data_aug = T.Compose([
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            normalize,
        ])
        cfg['test_transform'] = T.Compose([
            T.ToTensor(),
            normalize,
        ])

    domains = ['amazon', 'dslr', 'webcam']
    datasets = {
        domain: Office(cfg['data_path'], domain=domain, transform=data_aug)
        for domain in domains
    }
    n_classes = len(datasets[cfg['target']].class_names)
    # drop the real target so that only source domains are cross-validated;
    # cfg (not args) so that a target set in cv_cfg.ini is the one removed
    del datasets[cfg['target']]

    # pick the network and the routine that trains it
    if cfg['model'] == 'MDAN':
        # one of the remaining source domains is not counted as a source
        # here (presumably it is held out by cross_validation - confirm)
        model = MDANet(n_classes=n_classes,
                       n_domains=len(datasets) - 1).to(device)
        routine = mdan_train_routine
    elif cfg['model'] == 'MODA':
        model = MixMDANet(n_classes=n_classes).to(device)
        routine = moda_train_routine
    elif cfg['model'] == 'MODAFM':
        model = MixMDANet(n_classes=n_classes).to(device)
        routine = moda_fm_train_routine
        cfg['excl_transf'] = None
    else:
        raise ValueError('Unknown model {}'.format(cfg['model']))

    # from here on cfg['model'] holds the network instead of its name
    cfg['model'] = model

    # parameters whose name contains 'fc' (case-insensitive) train with the
    # full learning rate, all others with a ten times smaller one
    conv_params, fc_params = [], []
    for name, param in model.named_parameters():
        if 'FC' in name.upper():
            fc_params.append(param)
        else:
            conv_params.append(param)
    cfg['param_groups'] = [{
        'params': conv_params,
        'lr': 0.1 * cfg['lr'],
        'weight_decay': cfg['weight_decay']
    }, {
        'params': fc_params,
        'lr': cfg['lr'],
        'weight_decay': cfg['weight_decay']
    }]

    # adapter that calls the routine with an empty validation-loader dict
    cfg['train_routine'] = lambda model, optimizer, train_loader, cfg: routine(
        model, optimizer, train_loader, dict(), cfg)

    best_params, _ = cross_validation(datasets, cfg, cv_param_names)
    log.print('best_params:', best_params, level=1)

    results = ConfigParser()
    results.add_section('main')
    for key, value in best_params.items():
        results.set('main', key, str(value))
    with open(cfg['output'], 'w') as f:
        results.write(f)
예제 #6
0
def main():
    """Cross-validate hyperparameters over the Amazon source domains.

    Command-line arguments are parsed and then overridden by every key in
    the ``[main]`` section of ``cv_cfg.ini``; the overridden key names are
    the parameters handed to ``cross_validation``. A random subset of
    ``cfg['n_samples']`` examples is drawn from every non-target domain, the
    selected model and its training routine are stored in ``cfg``, and the
    best parameters returned by ``cross_validation`` are written to
    ``cfg['output']`` in INI format.

    Raises:
        ValueError: if ``cfg['model']`` is not 'MDAN', 'MODA' or 'MODAFM'.
    """
    # N.B.: parameters defined in cv_cfg.ini override args!
    parser = argparse.ArgumentParser(
        description=
        'Cross-validation over source domains for the Amazon dataset.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-m',
                        '--model',
                        default='MODAFM',
                        type=str,
                        metavar='',
                        help='model type (\'MDAN\' / \'MODA\' / \'MODAFM\'')
    parser.add_argument('-d',
                        '--data_path',
                        default='/ctm-hdd-pool01/DB/Amazon',
                        type=str,
                        metavar='',
                        help='data directory path')
    parser.add_argument(
        '-t',
        '--target',
        default='books',
        type=str,
        metavar='',
        help=
        'target domain (\'books\' / \'dvd\' / \'electronics\' / \'kitchen\')')
    parser.add_argument('-o',
                        '--output',
                        default='msda_hyperparams.ini',
                        type=str,
                        metavar='',
                        help='output file')
    parser.add_argument('-n',
                        '--n_iter',
                        default=20,
                        type=int,
                        metavar='',
                        help='number of CV iterations')
    parser.add_argument('--n_samples',
                        default=2000,
                        type=int,
                        metavar='',
                        help='number of samples from each domain')
    parser.add_argument('--n_features',
                        default=5000,
                        type=int,
                        metavar='',
                        help='number of features to use')
    parser.add_argument(
        '--mu',
        type=float,
        default=1e-2,
        help="hyperparameter of the coefficient for the domain adversarial loss"
    )
    parser.add_argument(
        '--beta',
        type=float,
        default=2e-1,
        help="hyperparameter of the non-sparsity regularization")
    parser.add_argument('--lambda',
                        type=float,
                        default=1e-1,
                        help="hyperparameter of the FixMatch loss")
    # dropout rates are fractions: with type=int any fractional value given
    # on the command line was rejected by argparse
    parser.add_argument('--min_dropout',
                        type=float,
                        default=2e-1,
                        help="minimum dropout rate")
    parser.add_argument('--max_dropout',
                        type=float,
                        default=8e-1,
                        help="maximum dropout rate")
    parser.add_argument('--weight_decay',
                        default=0.,
                        type=float,
                        metavar='',
                        help='hyperparameter of weight decay regularization')
    parser.add_argument('--lr',
                        default=1e0,
                        type=float,
                        metavar='',
                        help='learning rate')
    parser.add_argument('--epochs',
                        default=15,
                        type=int,
                        metavar='',
                        help='number of training epochs')
    parser.add_argument('--batch_size',
                        default=20,
                        type=int,
                        metavar='',
                        help='batch size (per domain)')
    parser.add_argument(
        '--checkpoint',
        default=0,
        type=int,
        metavar='',
        help=
        'number of epochs between saving checkpoints (0 disables checkpoints)')
    parser.add_argument('--use_cuda',
                        default=True,
                        type=int,
                        metavar='',
                        help='use CUDA capable GPU')
    parser.add_argument('--use_visdom',
                        default=False,
                        type=int,
                        metavar='',
                        help='use Visdom to visualize plots')
    parser.add_argument('--visdom_env',
                        default='amazon_train',
                        type=str,
                        metavar='',
                        help='Visdom environment name')
    parser.add_argument('--visdom_port',
                        default=8888,
                        type=int,
                        metavar='',
                        help='Visdom port')
    parser.add_argument('--verbosity',
                        default=2,
                        type=int,
                        metavar='',
                        help='log verbosity level')
    parser.add_argument('--seed',
                        default=42,
                        type=int,
                        metavar='',
                        help='random seed')
    args = vars(parser.parse_args())

    # override args with cv_cfg.ini
    cfg = args.copy()
    cv_parser = ConfigParser()
    cv_parser.read('cv_cfg.ini')
    cv_param_names = []
    for key, val in cv_parser.items('main'):
        cfg[key] = ast.literal_eval(val)
        cv_param_names.append(key)

    # use a fixed random seed for reproducibility purposes
    if cfg['seed'] > 0:
        random.seed(cfg['seed'])
        np.random.seed(seed=cfg['seed'])
        torch.manual_seed(cfg['seed'])
        torch.cuda.manual_seed(cfg['seed'])

    device = 'cuda' if (cfg['use_cuda']
                        and torch.cuda.is_available()) else 'cpu'
    log = Logger(cfg['verbosity'])
    log.print('device:', device, level=0)

    # NOTE(review): cfg['data_path'] is not read in this function; the data
    # is loaded from './amazon.npz' - confirm that this is intended
    domains = ['books', 'dvd', 'electronics', 'kitchen']
    datasets = {}
    for domain in domains:
        # the real target domain takes no part in the cross-validation
        if domain == cfg['target']:
            continue
        datasets[domain] = Amazon('./amazon.npz',
                                  domain,
                                  dimension=cfg['n_features'],
                                  transform=torch.from_numpy)
        # keep a random subset of n_samples examples per domain
        indices = random.sample(range(len(datasets[domain])), cfg['n_samples'])
        datasets[domain] = Subset(datasets[domain], indices)
    cfg['test_transform'] = torch.from_numpy

    # pick the network and the routine that trains it
    if cfg['model'] == 'MDAN':
        # two of the listed domains are not counted as sources: the real
        # target and (presumably) the one held out by cross_validation
        model = MDANet(input_dim=cfg['n_features'],
                       n_classes=2,
                       n_domains=len(domains) - 2).to(device)
        routine = mdan_train_routine
    elif cfg['model'] == 'MODA':
        model = MODANet(input_dim=cfg['n_features'], n_classes=2).to(device)
        routine = moda_train_routine
    elif cfg['model'] == 'MODAFM':
        model = MODANet(input_dim=cfg['n_features'], n_classes=2).to(device)
        routine = moda_mlp_fm_train_routine
    else:
        # fail early instead of starting cross-validation without a model
        raise ValueError('Unknown model {}'.format(cfg['model']))

    # from here on cfg['model'] holds the network instead of its name
    cfg['model'] = model
    # adapter that calls the routine with an empty validation-loader dict
    cfg['train_routine'] = lambda model, optimizer, train_loader, cfg: routine(
        model, optimizer, train_loader, dict(), cfg)

    best_params, _ = cross_validation(datasets, cfg, cv_param_names)
    log.print('best_params:', best_params, level=1)

    results = ConfigParser()
    results.add_section('main')
    for key, value in best_params.items():
        results.set('main', key, str(value))
    with open(cfg['output'], 'w') as f:
        results.write(f)