model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("no checkpoint found at '{}'".format(args.resume))

    for epoch in range(args.epochs):
        if epoch > args.linear_decrease_start_epoch:
            for g in optimizer.param_groups:
                g['lr'] = args.lr - args.lr * (
                    epoch - args.linear_decrease_start_epoch) / (
                        args.epochs - args.linear_decrease_start_epoch)

        tqdm.write(str(epoch))
        tqdm.write('Training')
        train(epoch, model, train_loader, writer)
        tqdm.write('Testing')
        test(epoch, model, test_loader, writer)

        # is_best = True
        state = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        checkpoint_dir = tmp_dir.joinpath(args.run_id, 'checkpoints')
        os.makedirs(checkpoint_dir, exist_ok=True)
        filename = checkpoint_dir.joinpath('checkpoint.pth')
        torch.save(state, filename)
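
Note: the learning-rate update in the loop above implements a linear decay from `args.lr` at `args.linear_decrease_start_epoch` down to zero at `args.epochs`. A minimal standalone sketch of the same schedule, with illustrative names not taken from the snippet:

def linear_decay_lr(epoch, base_lr, decay_start, total_epochs):
    """Constant at base_lr until decay_start, then linearly down to 0 at total_epochs."""
    if epoch <= decay_start:
        return base_lr
    return base_lr - base_lr * (epoch - decay_start) / (total_epochs - decay_start)

# e.g. with base_lr=0.1, decay_start=30, total_epochs=50:
# epoch 30 -> 0.1, epoch 40 -> 0.05, epoch 50 -> 0.0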
Example No. 2
def main():
    parser = argparse.ArgumentParser(
        description='Domain adaptation experiments with the DomainNet dataset.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '-m',
        '--model',
        default='MODAFM',
        type=str,
        metavar='',
        help=
        'model type (\'FS\' / \'DANNS\' / \'DANNM\' / \'MDAN\' / \'MODA\' / \'FM\' / \'MODAFM\')'
    )
    parser.add_argument('-d',
                        '--data_path',
                        default='/ctm-hdd-pool01/DB/DomainNet192',
                        type=str,
                        metavar='',
                        help='data directory path')
    parser.add_argument(
        '-t',
        '--target',
        default='clipart',
        type=str,
        metavar='',
        help=
        'target domain (\'clipart\' / \'infograph\' / \'painting\' / \'quickdraw\' / \'real\' / \'sketch\')'
    )
    parser.add_argument('-o',
                        '--output',
                        default='msda.pth',
                        type=str,
                        metavar='',
                        help='model file (output of train)')
    parser.add_argument('--icfg',
                        default=None,
                        type=str,
                        metavar='',
                        help='config file (overrides args)')
    parser.add_argument(
        '--arch',
        default='resnet152',
        type=str,
        metavar='',
        help='network architecture (\'resnet101\' / \'resnet152\')')
    parser.add_argument(
        '--mu_d',
        type=float,
        default=1e-2,
        help=
        "hyperparameter of the coefficient for the domain discriminator loss")
    parser.add_argument(
        '--mu_s',
        type=float,
        default=0.2,
        help="hyperparameter of the non-sparsity regularization")
    parser.add_argument('--mu_c',
                        type=float,
                        default=1e-1,
                        help="hyperparameter of the FixMatch loss")
    parser.add_argument('--n_rand_aug',
                        type=int,
                        default=2,
                        help="N parameter of RandAugment")
    parser.add_argument('--m_min_rand_aug',
                        type=int,
                        default=3,
                        help="minimum M parameter of RandAugment")
    parser.add_argument('--m_max_rand_aug',
                        type=int,
                        default=10,
                        help="maximum M parameter of RandAugment")
    parser.add_argument('--weight_decay',
                        default=0.,
                        type=float,
                        metavar='',
                        help='hyperparameter of weight decay regularization')
    parser.add_argument('--lr',
                        default=1e-3,
                        type=float,
                        metavar='',
                        help='learning rate')
    parser.add_argument('--epochs',
                        default=50,
                        type=int,
                        metavar='',
                        help='number of training epochs')
    parser.add_argument('--batch_size',
                        default=8,
                        type=int,
                        metavar='',
                        help='batch size (per domain)')
    parser.add_argument(
        '--checkpoint',
        default=0,
        type=int,
        metavar='',
        help=
        'number of epochs between saving checkpoints (0 disables checkpoints)')
    parser.add_argument('--eval_target',
                        default=False,
                        type=int,
                        metavar='',
                        help='evaluate target during training')
    parser.add_argument('--use_cuda',
                        default=True,
                        type=int,
                        metavar='',
                        help='use CUDA capable GPU')
    parser.add_argument('--use_visdom',
                        default=False,
                        type=int,
                        metavar='',
                        help='use Visdom to visualize plots')
    parser.add_argument('--visdom_env',
                        default='domainnet_train',
                        type=str,
                        metavar='',
                        help='Visdom environment name')
    parser.add_argument('--visdom_port',
                        default=8888,
                        type=int,
                        metavar='',
                        help='Visdom port')
    parser.add_argument('--verbosity',
                        default=2,
                        type=int,
                        metavar='',
                        help='log verbosity level (0, 1, 2)')
    parser.add_argument('--seed',
                        default=42,
                        type=int,
                        metavar='',
                        help='random seed')
    args = vars(parser.parse_args())

    # override args with icfg (if provided)
    cfg = args.copy()
    if cfg['icfg'] is not None:
        cv_parser = ConfigParser()
        cv_parser.read(cfg['icfg'])
        cv_param_names = []
        for key, val in cv_parser.items('main'):
            cfg[key] = ast.literal_eval(val)
            cv_param_names.append(key)

    # dump the effective config (cfg) to a txt file for the records
    with open(cfg['output'] + '.txt', 'w') as f:
        f.write(str(cfg) + '\n')

    # use a fixed random seed for reproducibility purposes
    if cfg['seed'] > 0:
        random.seed(cfg['seed'])
        np.random.seed(seed=cfg['seed'])
        torch.manual_seed(cfg['seed'])
        torch.cuda.manual_seed(cfg['seed'])

    device = 'cuda' if (cfg['use_cuda']
                        and torch.cuda.is_available()) else 'cpu'
    log = Logger(cfg['verbosity'])
    log.print('device:', device, level=0)

    # normalization transformation (required for pretrained networks)
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
    if 'FM' in cfg['model']:
        # weak data augmentation (small rotation + small translation)
        data_aug = T.Compose([
            # T.RandomCrop(224),
            # T.Resize(128),
            T.RandomHorizontalFlip(),
            T.RandomAffine(5, translate=(0.125, 0.125)),
            T.ToTensor(),
            # normalize,  # normalization disrupts FixMatch
        ])

        eval_transf = T.Compose([
            # T.RandomCrop(224),
            # T.Resize(128),
            T.ToTensor(),
        ])

    else:
        data_aug = T.Compose([
            # T.RandomCrop(224),
            # T.Resize(128),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            normalize,
        ])

        eval_transf = T.Compose([
            # T.RandomCrop(224),
            # T.Resize(128),
            T.ToTensor(),
            normalize,
        ])

    domains = [
        'clipart', 'infograph', 'painting', 'quickdraw', 'real', 'sketch'
    ]
    datasets = {
        domain: DomainNet(cfg['data_path'],
                          domain=domain,
                          train=True,
                          transform=data_aug)
        for domain in domains
    }
    n_classes = len(datasets[cfg['target']].class_names)

    test_set = DomainNet(cfg['data_path'],
                         domain=cfg['target'],
                         train=False,
                         transform=eval_transf)
    if 'FM' in cfg['model']:
        target_pub = deepcopy(datasets[cfg['target']])
        target_pub.transform = eval_transf  # no data augmentation in test
    else:
        target_pub = datasets[cfg['target']]

    if cfg['model'] != 'FS':
        train_loader = MSDA_Loader(datasets,
                                   cfg['target'],
                                   batch_size=cfg['batch_size'],
                                   shuffle=True,
                                   num_workers=0,
                                   device=device)
        if cfg['eval_target']:
            valid_loaders = {
                'target pub':
                DataLoader(target_pub, batch_size=6 * cfg['batch_size']),
                'target priv':
                DataLoader(test_set, batch_size=6 * cfg['batch_size'])
            }
        else:
            valid_loaders = None
        log.print('target domain:',
                  cfg['target'],
                  '| source domains:',
                  train_loader.sources,
                  level=1)
    else:
        train_loader = DataLoader(datasets[cfg['target']],
                                  batch_size=cfg['batch_size'],
                                  shuffle=True)
        test_loader = DataLoader(test_set, batch_size=cfg['batch_size'])
        log.print('target domain:', cfg['target'], level=1)

    if cfg['model'] == 'FS':
        model = SimpleCNN(n_classes=n_classes, arch=cfg['arch']).to(device)
        conv_params, fc_params = [], []
        for name, param in model.named_parameters():
            if 'fc' in name.lower():
                fc_params.append(param)
            else:
                conv_params.append(param)
        optimizer = optim.Adadelta([{
            'params': conv_params,
            'lr': 0.1 * cfg['lr'],
            'weight_decay': cfg['weight_decay']
        }, {
            'params': fc_params,
            'lr': cfg['lr'],
            'weight_decay': cfg['weight_decay']
        }])
        valid_loaders = {
            'target pub': test_loader
        } if cfg['eval_target'] else None
        fs_train_routine(model, optimizer, train_loader, valid_loaders, cfg)

    elif cfg['model'] == 'FM':
        model = SimpleCNN(n_classes=n_classes, arch=cfg['arch']).to(device)
        conv_params, fc_params = [], []
        for name, param in model.named_parameters():
            if 'fc' in name.lower():
                fc_params.append(param)
            else:
                conv_params.append(param)
        optimizer = optim.Adadelta([{
            'params': conv_params,
            'lr': 0.1 * cfg['lr'],
            'weight_decay': cfg['weight_decay']
        }, {
            'params': fc_params,
            'lr': cfg['lr'],
            'weight_decay': cfg['weight_decay']
        }])
        cfg['excl_transf'] = None
        fm_train_routine(model, optimizer, train_loader, valid_loaders, cfg)

    elif cfg['model'] == 'DANNS':
        for src in train_loader.sources:
            model = MODANet(n_classes=n_classes, arch=cfg['arch']).to(device)
            conv_params, fc_params = [], []
            for name, param in model.named_parameters():
                if 'fc' in name.lower():
                    fc_params.append(param)
                else:
                    conv_params.append(param)
            optimizer = optim.Adadelta([{
                'params': conv_params,
                'lr': 0.1 * cfg['lr'],
                'weight_decay': cfg['weight_decay']
            }, {
                'params': fc_params,
                'lr': cfg['lr'],
                'weight_decay': cfg['weight_decay']
            }])
            dataset_ss = {
                src: datasets[src],
                cfg['target']: datasets[cfg['target']]
            }
            train_loader = MSDA_Loader(dataset_ss,
                                       cfg['target'],
                                       batch_size=cfg['batch_size'],
                                       shuffle=True,
                                       device=device)
            dann_train_routine(model, optimizer, train_loader, valid_loaders,
                               cfg)
            torch.save(model.state_dict(), cfg['output'] + '_' + src)

    elif cfg['model'] == 'DANNM':
        model = MODANet(n_classes=n_classes, arch=cfg['arch']).to(device)
        conv_params, fc_params = [], []
        for name, param in model.named_parameters():
            if 'fc' in name.lower():
                fc_params.append(param)
            else:
                conv_params.append(param)
        optimizer = optim.Adadelta([{
            'params': conv_params,
            'lr': 0.1 * cfg['lr'],
            'weight_decay': cfg['weight_decay']
        }, {
            'params': fc_params,
            'lr': cfg['lr'],
            'weight_decay': cfg['weight_decay']
        }])
        dann_train_routine(model, optimizer, train_loader, valid_loaders, cfg)

    elif cfg['model'] == 'MDAN':
        model = MDANet(n_classes=n_classes,
                       n_domains=len(train_loader.sources),
                       arch=cfg['arch']).to(device)
        conv_params, fc_params = [], []
        for name, param in model.named_parameters():
            if 'fc' in name.lower():
                fc_params.append(param)
            else:
                conv_params.append(param)
        optimizer = optim.Adadelta([{
            'params': conv_params,
            'lr': 0.1 * cfg['lr'],
            'weight_decay': cfg['weight_decay']
        }, {
            'params': fc_params,
            'lr': cfg['lr'],
            'weight_decay': cfg['weight_decay']
        }])
        mdan_train_routine(model, optimizer, train_loader, valid_loaders, cfg)

    elif cfg['model'] == 'MODA':
        model = MODANet(n_classes=n_classes, arch=cfg['arch']).to(device)
        conv_params, fc_params = [], []
        for name, param in model.named_parameters():
            if 'fc' in name.lower():
                fc_params.append(param)
            else:
                conv_params.append(param)
        optimizer = optim.Adadelta([{
            'params': conv_params,
            'lr': 0.1 * cfg['lr'],
            'weight_decay': cfg['weight_decay']
        }, {
            'params': fc_params,
            'lr': cfg['lr'],
            'weight_decay': cfg['weight_decay']
        }])
        moda_train_routine(model, optimizer, train_loader, valid_loaders, cfg)

    elif cfg['model'] == 'MODAFM':
        model = MODANet(n_classes=n_classes, arch=cfg['arch']).to(device)
        conv_params, fc_params = [], []
        for name, param in model.named_parameters():
            if 'fc' in name.lower():
                fc_params.append(param)
            else:
                conv_params.append(param)
        optimizer = optim.Adadelta([{
            'params': conv_params,
            'lr': 0.1 * cfg['lr'],
            'weight_decay': cfg['weight_decay']
        }, {
            'params': fc_params,
            'lr': cfg['lr'],
            'weight_decay': cfg['weight_decay']
        }])
        cfg['excl_transf'] = None
        moda_fm_train_routine(model, optimizer, train_loader, valid_loaders,
                              cfg)

    else:
        raise ValueError('Unknown model {}'.format(cfg['model']))

    torch.save(model.state_dict(), cfg['output'])
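
Note: every branch above repeats the same optimizer construction: parameters are split into the pretrained convolutional backbone and the freshly initialized fully connected head, and the backbone gets a 10x smaller learning rate through Adadelta parameter groups. A sketch of that pattern as a helper (the function name is hypothetical; the script itself inlines this logic in each branch):

import torch.optim as optim

def make_optimizer(model, lr, weight_decay):
    # fc layers train at lr; everything else (the backbone) at 0.1 * lr
    conv_params, fc_params = [], []
    for name, param in model.named_parameters():
        (fc_params if 'fc' in name.lower() else conv_params).append(param)
    return optim.Adadelta([
        {'params': conv_params, 'lr': 0.1 * lr, 'weight_decay': weight_decay},
        {'params': fc_params, 'lr': lr, 'weight_decay': weight_decay},
    ])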
Example No. 3
def main():
    parser = argparse.ArgumentParser(
        description='Domain adaptation experiments with digits datasets.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '-m',
        '--model',
        default='MODAFM',
        type=str,
        metavar='',
        help=
        'model type (\'FS\' / \'DANNS\' / \'DANNM\' / \'MDAN\' / \'MODA\' / \'FM\' / \'MODAFM\')'
    )
    parser.add_argument('-d',
                        '--data_path',
                        default='/ctm-hdd-pool01/DB/',
                        type=str,
                        metavar='',
                        help='data directory path')
    parser.add_argument(
        '-t',
        '--target',
        default='MNIST',
        type=str,
        metavar='',
        help=
        'target domain (\'MNIST\' / \'MNIST_M\' / \'SVHN\' / \'SynthDigits\')')
    parser.add_argument('-o',
                        '--output',
                        default='msda.pth',
                        type=str,
                        metavar='',
                        help='model file (output of train)')
    parser.add_argument('--icfg',
                        default=None,
                        type=str,
                        metavar='',
                        help='config file (overrides args)')
    parser.add_argument('--n_src_images',
                        default=20000,
                        type=int,
                        metavar='',
                        help='number of images from each source domain')
    parser.add_argument('--n_tgt_images',
                        default=20000,
                        type=int,
                        metavar='',
                        help='number of images from the target domain')
    parser.add_argument(
        '--mu_d',
        type=float,
        default=1e-2,
        help=
        "hyperparameter of the coefficient for the domain discriminator loss")
    parser.add_argument(
        '--mu_s',
        type=float,
        default=0.2,
        help="hyperparameter of the non-sparsity regularization")
    parser.add_argument('--mu_c',
                        type=float,
                        default=1e-1,
                        help="hyperparameter of the FixMatch loss")
    parser.add_argument('--n_rand_aug',
                        type=int,
                        default=2,
                        help="N parameter of RandAugment")
    parser.add_argument('--m_min_rand_aug',
                        type=int,
                        default=3,
                        help="minimum M parameter of RandAugment")
    parser.add_argument('--m_max_rand_aug',
                        type=int,
                        default=10,
                        help="maximum M parameter of RandAugment")
    parser.add_argument('--weight_decay',
                        default=0.,
                        type=float,
                        metavar='',
                        help='hyperparameter of weight decay regularization')
    parser.add_argument('--lr',
                        default=1e-1,
                        type=float,
                        metavar='',
                        help='learning rate')
    parser.add_argument('--epochs',
                        default=30,
                        type=int,
                        metavar='',
                        help='number of training epochs')
    parser.add_argument('--batch_size',
                        default=8,
                        type=int,
                        metavar='',
                        help='batch size (per domain)')
    parser.add_argument(
        '--checkpoint',
        default=0,
        type=int,
        metavar='',
        help=
        'number of epochs between saving checkpoints (0 disables checkpoints)')
    parser.add_argument('--eval_target',
                        default=False,
                        type=int,
                        metavar='',
                        help='evaluate target during training')
    parser.add_argument('--use_cuda',
                        default=True,
                        type=int,
                        metavar='',
                        help='use CUDA capable GPU')
    parser.add_argument('--use_visdom',
                        default=False,
                        type=int,
                        metavar='',
                        help='use Visdom to visualize plots')
    parser.add_argument('--visdom_env',
                        default='digits_train',
                        type=str,
                        metavar='',
                        help='Visdom environment name')
    parser.add_argument('--visdom_port',
                        default=8888,
                        type=int,
                        metavar='',
                        help='Visdom port')
    parser.add_argument('--verbosity',
                        default=2,
                        type=int,
                        metavar='',
                        help='log verbosity level (0, 1, 2)')
    parser.add_argument('--seed',
                        default=42,
                        type=int,
                        metavar='',
                        help='random seed')
    args = vars(parser.parse_args())

    # override args with icfg (if provided)
    cfg = args.copy()
    if cfg['icfg'] is not None:
        cv_parser = ConfigParser()
        cv_parser.read(cfg['icfg'])
        cv_param_names = []
        for key, val in cv_parser.items('main'):
            cfg[key] = ast.literal_eval(val)
            cv_param_names.append(key)

    # dump cfg to a txt file for your records
    with open(cfg['output'] + '.txt', 'w') as f:
        f.write(str(cfg) + '\n')

    # use a fixed random seed for reproducibility purposes
    if cfg['seed'] > 0:
        random.seed(cfg['seed'])
        np.random.seed(seed=cfg['seed'])
        torch.manual_seed(cfg['seed'])
        torch.cuda.manual_seed(cfg['seed'])

    device = 'cuda' if (cfg['use_cuda']
                        and torch.cuda.is_available()) else 'cpu'
    log = Logger(cfg['verbosity'])
    log.print('device:', device, level=0)

    if ('FS' in cfg['model']) or ('FM' in cfg['model']):
        # weak data augmentation (small rotation + small translation)
        data_aug = T.Compose([
            T.RandomAffine(5, translate=(0.125, 0.125)),
            T.ToTensor(),
        ])
    else:
        data_aug = T.ToTensor()

    # define all datasets
    datasets = {}
    datasets['MNIST'] = MNIST(train=True,
                              path=os.path.join(cfg['data_path'], 'MNIST'),
                              transform=data_aug)
    datasets['MNIST_M'] = MNIST_M(train=True,
                                  path=os.path.join(cfg['data_path'],
                                                    'MNIST_M'),
                                  transform=data_aug)
    datasets['SVHN'] = SVHN(train=True,
                            path=os.path.join(cfg['data_path'], 'SVHN'),
                            transform=data_aug)
    datasets['SynthDigits'] = SynthDigits(train=True,
                                          path=os.path.join(
                                              cfg['data_path'], 'SynthDigits'),
                                          transform=data_aug)
    if ('FS' in cfg['model']) or ('FM' in cfg['model']):
        test_set = deepcopy(datasets[cfg['target']])
        test_set.transform = T.ToTensor()  # no data augmentation in test
    else:
        test_set = datasets[cfg['target']]

    # get a subset of cfg['n_images'] from each dataset
    # define public and private test sets: the private is not used at training time to learn invariant representations
    for ds_name in datasets:
        if ds_name == cfg['target']:
            indices = random.sample(range(len(datasets[ds_name])),
                                    cfg['n_tgt_images'] + cfg['n_src_images'])
            test_pub_set = Subset(test_set, indices[0:cfg['n_tgt_images']])
            test_priv_set = Subset(test_set, indices[cfg['n_tgt_images']::])
            datasets[cfg['target']] = Subset(datasets[cfg['target']],
                                             indices[0:cfg['n_tgt_images']])
        else:
            indices = random.sample(range(len(datasets[ds_name])),
                                    cfg['n_src_images'])
            datasets[ds_name] = Subset(datasets[ds_name],
                                       indices[0:cfg['n_src_images']])

    # build the dataloader
    train_loader = MSDA_Loader(datasets,
                               cfg['target'],
                               batch_size=cfg['batch_size'],
                               shuffle=True,
                               device=device)
    test_pub_loader = DataLoader(test_pub_set,
                                 batch_size=4 * cfg['batch_size'])
    test_priv_loader = DataLoader(test_priv_set,
                                  batch_size=4 * cfg['batch_size'])
    valid_loaders = ({
        'target pub': test_pub_loader,
        'target priv': test_priv_loader
    } if cfg['eval_target'] else None)
    log.print('target domain:',
              cfg['target'],
              '| source domains:',
              train_loader.sources,
              level=1)

    if cfg['model'] == 'FS':
        model = SimpleCNN().to(device)
        optimizer = optim.Adadelta(model.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        if valid_loaders is not None:
            # the FS baseline trains directly on the target public split,
            # so drop that split from validation
            del valid_loaders['target pub']
        fs_train_routine(model, optimizer, test_pub_loader, valid_loaders, cfg)
    elif cfg['model'] == 'FM':
        model = SimpleCNN().to(device)
        optimizer = optim.Adadelta(model.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        cfg['excl_transf'] = [Flip]
        fm_train_routine(model, optimizer, train_loader, valid_loaders, cfg)
    elif cfg['model'] == 'DANNS':
        for src in train_loader.sources:
            model = MODANet().to(device)
            optimizer = optim.Adadelta(model.parameters(),
                                       lr=cfg['lr'],
                                       weight_decay=cfg['weight_decay'])
            dataset_ss = {
                src: datasets[src],
                cfg['target']: datasets[cfg['target']]
            }
            train_loader = MSDA_Loader(dataset_ss,
                                       cfg['target'],
                                       batch_size=cfg['batch_size'],
                                       shuffle=True,
                                       device=device)
            dann_train_routine(model, optimizer, train_loader, valid_loaders,
                               cfg)
            torch.save(model.state_dict(), cfg['output'] + '_' + src)
    elif cfg['model'] == 'DANNM':
        model = MODANet().to(device)
        optimizer = optim.Adadelta(model.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        dann_train_routine(model, optimizer, train_loader, valid_loaders, cfg)
    elif cfg['model'] == 'MDAN':
        model = MDANet(len(train_loader.sources)).to(device)
        optimizer = optim.Adadelta(model.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        mdan_train_routine(model, optimizer, train_loader, valid_loaders, cfg)
    elif cfg['model'] == 'MDANU':
        model = MDANet(len(train_loader.sources)).to(device)
        model.grad_reverse = nn.ModuleList([
            nn.Identity() for _ in range(len(model.domain_class))
        ])  # remove grad reverse
        task_optim = optim.Adadelta(list(model.feat_ext.parameters()) +
                                    list(model.task_class.parameters()),
                                    lr=cfg['lr'],
                                    weight_decay=cfg['weight_decay'])
        adv_optim = optim.Adadelta(model.domain_class.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        optimizers = (task_optim, adv_optim)
        mdan_unif_train_routine(model, optimizers, train_loader, valid_loaders,
                                cfg)
    elif cfg['model'] == 'MDANFM':
        model = MDANet(len(train_loader.sources)).to(device)
        optimizer = optim.Adadelta(model.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        mdan_fm_train_routine(model, optimizer, train_loader, valid_loaders,
                              cfg)
    elif cfg['model'] == 'MDANUFM':
        model = MDANet(len(train_loader.sources)).to(device)
        task_optim = optim.Adadelta(list(model.feat_ext.parameters()) +
                                    list(model.task_class.parameters()),
                                    lr=cfg['lr'],
                                    weight_decay=cfg['weight_decay'])
        adv_optim = optim.Adadelta(model.domain_class.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        optimizers = (task_optim, adv_optim)
        cfg['excl_transf'] = [Flip]
        mdan_unif_fm_train_routine(model, optimizers, train_loader,
                                   valid_loaders, cfg)
    elif cfg['model'] == 'MODA':
        model = MODANet().to(device)
        optimizer = optim.Adadelta(model.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        moda_train_routine(model, optimizer, train_loader, valid_loaders, cfg)
    elif cfg['model'] == 'MODAFM':
        model = MODANet().to(device)
        optimizer = optim.Adadelta(model.parameters(),
                                   lr=cfg['lr'],
                                   weight_decay=cfg['weight_decay'])
        cfg['excl_transf'] = [Flip]
        moda_fm_train_routine(model, optimizer, train_loader, valid_loaders,
                              cfg)
    else:
        raise ValueError('Unknown model {}'.format(cfg['model']))

    if cfg['model'] != 'DANNS':
        torch.save(model.state_dict(), cfg['output'])
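
Note: the target split above relies on `random.sample` drawing indices without replacement, so the public subset (visible during adaptation) and the private subset (held out for final evaluation) are disjoint by construction. A small self-contained sketch of the same idea, with a stand-in dataset:

import random
from torch.utils.data import Subset

dataset = list(range(100_000))   # stand-in for the target dataset
n_pub, n_priv = 20_000, 20_000   # cfg['n_tgt_images'] / cfg['n_src_images'] in the script
indices = random.sample(range(len(dataset)), n_pub + n_priv)  # no repeated indices
pub_set = Subset(dataset, indices[:n_pub])    # 'target pub': seen during adaptation
priv_set = Subset(dataset, indices[n_pub:])   # 'target priv': held out for evaluation
assert not set(indices[:n_pub]) & set(indices[n_pub:])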