test_loss, acc_value * 100))


if __name__ == "__main__":
    # Script entry point: build the model, loss and optimizer, then optionally
    # restore training state from a checkpoint file.

    # summary writer that logs under <tmp_dir>/<run_id>/log
    writer = SummaryWriter(tmp_dir.joinpath(args.run_id, 'log'))

    # one output per class of the training dataset
    model = SimpleCNN(num_targets=len(train_dataset.classes),
                      num_channels=num_channels)
    model = model.to(device)

    # we choose binary cross entropy loss with logits (the sigmoid is applied
    # inside the loss, so the model must output raw logits) because we want to
    # detect multiple events concurrently (at least later on, when we have
    # more labels)
    criterion = nn.BCEWithLogitsLoss()

    # for most cases adam is a good choice (here with the AMSGrad variant)
    optimizer = optim.Adam(model.parameters(), lr=args.lr, amsgrad=True)
    # optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)

    # epoch to start from; overwritten below when resuming from a checkpoint
    start_epoch = 0

    # optionally, resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("loading checkpoint '{}'".format(args.resume))
            # NOTE(review): no map_location is given, so tensors are restored
            # onto the device they were saved from -- confirm that this is
            # compatible with `device`.
            checkpoint = torch.load(args.resume)
            # the checkpoint is read as a dict with the keys 'epoch',
            # 'state_dict' (model weights) and 'optimizer' (optimizer state)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
Example #2
0
def _build_arg_parser():
    """Return the command-line parser for the digits training script."""
    parser = argparse.ArgumentParser(
        description='Domain adaptation experiments with digits datasets.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '-m',
        '--model',
        default='MODAFM',
        type=str,
        metavar='',
        help="model type ('FS' / 'DANNS' / 'DANNM' / 'MDAN' / 'MDANU' / "
        "'MDANFM' / 'MDANUFM' / 'MODA' / 'FM' / 'MODAFM')")
    parser.add_argument('-d',
                        '--data_path',
                        default='/ctm-hdd-pool01/DB/',
                        type=str,
                        metavar='',
                        help='data directory path')
    parser.add_argument(
        '-t',
        '--target',
        default='MNIST',
        type=str,
        metavar='',
        help="target domain ('MNIST' / 'MNIST_M' / 'SVHN' / 'SynthDigits')")
    parser.add_argument('-o',
                        '--output',
                        default='msda.pth',
                        type=str,
                        metavar='',
                        help='model file (output of train)')
    parser.add_argument('--icfg',
                        default=None,
                        type=str,
                        metavar='',
                        help='config file (overrides args)')
    parser.add_argument('--n_src_images',
                        default=20000,
                        type=int,
                        metavar='',
                        help='number of images from each source domain')
    parser.add_argument('--n_tgt_images',
                        default=20000,
                        type=int,
                        metavar='',
                        help='number of images from the target domain')
    parser.add_argument(
        '--mu_d',
        type=float,
        default=1e-2,
        help=
        "hyperparameter of the coefficient for the domain discriminator loss")
    parser.add_argument(
        '--mu_s',
        type=float,
        default=0.2,
        help="hyperparameter of the non-sparsity regularization")
    parser.add_argument('--mu_c',
                        type=float,
                        default=1e-1,
                        help="hyperparameter of the FixMatch loss")
    parser.add_argument('--n_rand_aug',
                        type=int,
                        default=2,
                        help="N parameter of RandAugment")
    parser.add_argument('--m_min_rand_aug',
                        type=int,
                        default=3,
                        help="minimum M parameter of RandAugment")
    parser.add_argument('--m_max_rand_aug',
                        type=int,
                        default=10,
                        help="maximum M parameter of RandAugment")
    parser.add_argument('--weight_decay',
                        default=0.,
                        type=float,
                        metavar='',
                        help='hyperparameter of weight decay regularization')
    parser.add_argument('--lr',
                        default=1e-1,
                        type=float,
                        metavar='',
                        help='learning rate')
    parser.add_argument('--epochs',
                        default=30,
                        type=int,
                        metavar='',
                        help='number of training epochs')
    parser.add_argument('--batch_size',
                        default=8,
                        type=int,
                        metavar='',
                        help='batch size (per domain)')
    parser.add_argument(
        '--checkpoint',
        default=0,
        type=int,
        metavar='',
        help=
        'number of epochs between saving checkpoints (0 disables checkpoints)')
    # the boolean switches below are parsed as integers (0 / 1)
    parser.add_argument('--eval_target',
                        default=False,
                        type=int,
                        metavar='',
                        help='evaluate target during training')
    parser.add_argument('--use_cuda',
                        default=True,
                        type=int,
                        metavar='',
                        help='use CUDA capable GPU')
    parser.add_argument('--use_visdom',
                        default=False,
                        type=int,
                        metavar='',
                        help='use Visdom to visualize plots')
    parser.add_argument('--visdom_env',
                        default='digits_train',
                        type=str,
                        metavar='',
                        help='Visdom environment name')
    parser.add_argument('--visdom_port',
                        default=8888,
                        type=int,
                        metavar='',
                        help='Visdom port')
    parser.add_argument('--verbosity',
                        default=2,
                        type=int,
                        metavar='',
                        help='log verbosity level (0, 1, 2)')
    parser.add_argument('--seed',
                        default=42,
                        type=int,
                        metavar='',
                        help='random seed')
    return parser


def main():
    """Train a (multi-source) domain adaptation model on the digits datasets.

    Steps, in order:
      1. parse the command line and, if ``--icfg`` is given, override the
         parsed values with the entries of the ``[main]`` section of that
         config file (each value is read with ``ast.literal_eval``);
      2. dump the resulting configuration to ``<output>.txt``;
      3. seed the random number generators when ``seed > 0``;
      4. build the four digits datasets, sub-sample them and create the
         training / evaluation loaders;
      5. dispatch to the training routine selected by ``cfg['model']``;
      6. save the model ``state_dict`` to ``cfg['output']`` (for 'DANNS', one
         file per source domain, named ``<output>_<source>``).

    Raises:
        ValueError: if ``cfg['model']`` is not a supported model name.
    """
    args = vars(_build_arg_parser().parse_args())

    # override args with icfg (if provided)
    cfg = args.copy()
    if cfg['icfg'] is not None:
        cv_parser = ConfigParser()
        cv_parser.read(cfg['icfg'])
        for key, val in cv_parser.items('main'):
            # values are Python literals (numbers, quoted strings, ...)
            cfg[key] = ast.literal_eval(val)

    # dump cfg to a txt file for your records
    with open(cfg['output'] + '.txt', 'w') as f:
        f.write(str(cfg) + '\n')

    # use a fixed random seed for reproducibility purposes
    if cfg['seed'] > 0:
        random.seed(cfg['seed'])
        np.random.seed(seed=cfg['seed'])
        torch.manual_seed(cfg['seed'])
        torch.cuda.manual_seed(cfg['seed'])

    device = 'cuda' if (cfg['use_cuda']
                        and torch.cuda.is_available()) else 'cpu'
    log = Logger(cfg['verbosity'])
    log.print('device:', device, level=0)

    # substring test: true for 'FS', 'FM' and every '*FM' variant
    weak_aug = ('FS' in cfg['model']) or ('FM' in cfg['model'])
    if weak_aug:
        # weak data augmentation (small rotation + small translation)
        data_aug = T.Compose([
            T.RandomAffine(5, translate=(0.125, 0.125)),
            T.ToTensor(),
        ])
    else:
        data_aug = T.ToTensor()

    # define all datasets; each one lives in <data_path>/<dataset name>
    dataset_classes = (('MNIST', MNIST), ('MNIST_M', MNIST_M), ('SVHN', SVHN),
                       ('SynthDigits', SynthDigits))
    datasets = {
        name: dataset_cls(train=True,
                          path=os.path.join(cfg['data_path'], name),
                          transform=data_aug)
        for name, dataset_cls in dataset_classes
    }
    if weak_aug:
        # evaluate on a copy so that the training transform is left untouched
        test_set = deepcopy(datasets[cfg['target']])
        test_set.transform = T.ToTensor()  # no data augmentation in test
    else:
        test_set = datasets[cfg['target']]

    # get a subset of images from each dataset
    # define public and private test sets: the private is not used at training time to learn invariant representations
    n_tgt, n_src = cfg['n_tgt_images'], cfg['n_src_images']
    for ds_name in datasets:
        if ds_name == cfg['target']:
            # the first n_tgt sampled indices form the public split (also used
            # for training); the remaining n_src form the private split
            indices = random.sample(range(len(datasets[ds_name])),
                                    n_tgt + n_src)
            test_pub_set = Subset(test_set, indices[0:n_tgt])
            test_priv_set = Subset(test_set, indices[n_tgt::])
            datasets[cfg['target']] = Subset(datasets[cfg['target']],
                                             indices[0:n_tgt])
        else:
            indices = random.sample(range(len(datasets[ds_name])), n_src)
            datasets[ds_name] = Subset(datasets[ds_name], indices[0:n_src])

    # build the dataloader
    train_loader = MSDA_Loader(datasets,
                               cfg['target'],
                               batch_size=cfg['batch_size'],
                               shuffle=True,
                               device=device)
    test_pub_loader = DataLoader(test_pub_set,
                                 batch_size=4 * cfg['batch_size'])
    test_priv_loader = DataLoader(test_priv_set,
                                  batch_size=4 * cfg['batch_size'])
    valid_loaders = ({
        'target pub': test_pub_loader,
        'target priv': test_priv_loader
    } if cfg['eval_target'] else None)
    log.print('target domain:',
              cfg['target'],
              '| source domains:',
              train_loader.sources,
              level=1)

    def adadelta(params):
        """Adadelta optimizer over `params` with the configured lr / decay."""
        return optim.Adadelta(params,
                              lr=cfg['lr'],
                              weight_decay=cfg['weight_decay'])

    def task_and_adv_optimizers(net):
        """Separate optimizers for (feature extractor + task classifier) and
        for the domain classifiers of `net`."""
        task_optim = adadelta(
            list(net.feat_ext.parameters()) +
            list(net.task_class.parameters()))
        adv_optim = adadelta(net.domain_class.parameters())
        return (task_optim, adv_optim)

    if cfg['model'] == 'FS':
        model = SimpleCNN().to(device)
        optimizer = adadelta(model.parameters())
        if valid_loaders is not None:
            # the public target split is the training set in this mode
            del valid_loaders['target pub']
        fs_train_routine(model, optimizer, test_pub_loader, valid_loaders, cfg)
    elif cfg['model'] == 'FM':
        model = SimpleCNN().to(device)
        optimizer = adadelta(model.parameters())
        cfg['excl_transf'] = [Flip]
        fm_train_routine(model, optimizer, train_loader, valid_loaders, cfg)
    elif cfg['model'] == 'DANNS':
        # one independent single-source model per source domain; the source
        # list is read once, before `train_loader` is rebound in the loop
        for src in train_loader.sources:
            model = MODANet().to(device)
            optimizer = adadelta(model.parameters())
            dataset_ss = {
                src: datasets[src],
                cfg['target']: datasets[cfg['target']]
            }
            train_loader = MSDA_Loader(dataset_ss,
                                       cfg['target'],
                                       batch_size=cfg['batch_size'],
                                       shuffle=True,
                                       device=device)
            dann_train_routine(model, optimizer, train_loader, valid_loaders,
                               cfg)
            torch.save(model.state_dict(), cfg['output'] + '_' + src)
    elif cfg['model'] == 'DANNM':
        model = MODANet().to(device)
        optimizer = adadelta(model.parameters())
        dann_train_routine(model, optimizer, train_loader, valid_loaders, cfg)
    elif cfg['model'] == 'MDAN':
        model = MDANet(len(train_loader.sources)).to(device)
        optimizer = adadelta(model.parameters())
        mdan_train_routine(model, optimizer, train_loader, valid_loaders, cfg)
    elif cfg['model'] == 'MDANU':
        model = MDANet(len(train_loader.sources)).to(device)
        model.grad_reverse = nn.ModuleList([
            nn.Identity() for _ in range(len(model.domain_class))
        ])  # remove grad reverse
        optimizers = task_and_adv_optimizers(model)
        mdan_unif_train_routine(model, optimizers, train_loader, valid_loaders,
                                cfg)
    elif cfg['model'] == 'MDANFM':
        model = MDANet(len(train_loader.sources)).to(device)
        optimizer = adadelta(model.parameters())
        mdan_fm_train_routine(model, optimizer, train_loader, valid_loaders,
                              cfg)
    elif cfg['model'] == 'MDANUFM':
        model = MDANet(len(train_loader.sources)).to(device)
        # NOTE(review): unlike 'MDANU', the gradient-reversal layers are not
        # replaced by nn.Identity here -- confirm that this is intended.
        optimizers = task_and_adv_optimizers(model)
        cfg['excl_transf'] = [Flip]
        # pass the (task, adversarial) optimizer pair built above; the name
        # `optimizer` is never defined in this branch
        mdan_unif_fm_train_routine(model, optimizers, train_loader,
                                   valid_loaders, cfg)
    elif cfg['model'] == 'MODA':
        model = MODANet().to(device)
        optimizer = adadelta(model.parameters())
        moda_train_routine(model, optimizer, train_loader, valid_loaders, cfg)
    elif cfg['model'] == 'MODAFM':
        model = MODANet().to(device)
        optimizer = adadelta(model.parameters())
        cfg['excl_transf'] = [Flip]
        moda_fm_train_routine(model, optimizer, train_loader, valid_loaders,
                              cfg)
    else:
        raise ValueError('Unknown model {}'.format(cfg['model']))

    # 'DANNS' already saved one model per source inside its loop
    if cfg['model'] != 'DANNS':
        torch.save(model.state_dict(), cfg['output'])