def main(): parser = argparse.ArgumentParser( description='Domain adaptation experiments with digits datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument( '-m', '--model', default='MODAFM', type=str, metavar='', help= 'model type (\'FS\' / \'DANNS\' / \'DANNM\' / \'MDAN\' / \'MODA\' / \'FM\' / \'MODAFM\'' ) parser.add_argument('-d', '--data_path', default='/ctm-hdd-pool01/DB/', type=str, metavar='', help='data directory path') parser.add_argument( '-t', '--target', default='MNIST', type=str, metavar='', help= 'target domain (\'MNIST\' / \'MNIST_M\' / \'SVHN\' / \'SynthDigits\')') parser.add_argument('-o', '--output', default='msda.pth', type=str, metavar='', help='model file (output of train)') parser.add_argument('--icfg', default=None, type=str, metavar='', help='config file (overrides args)') parser.add_argument('--n_src_images', default=20000, type=int, metavar='', help='number of images from each source domain') parser.add_argument('--n_tgt_images', default=20000, type=int, metavar='', help='number of images from the target domain') parser.add_argument( '--mu_d', type=float, default=1e-2, help= "hyperparameter of the coefficient for the domain discriminator loss") parser.add_argument( '--mu_s', type=float, default=0.2, help="hyperparameter of the non-sparsity regularization") parser.add_argument('--mu_c', type=float, default=1e-1, help="hyperparameter of the FixMatch loss") parser.add_argument('--n_rand_aug', type=int, default=2, help="N parameter of RandAugment") parser.add_argument('--m_min_rand_aug', type=int, default=3, help="minimum M parameter of RandAugment") parser.add_argument('--m_max_rand_aug', type=int, default=10, help="maximum M parameter of RandAugment") parser.add_argument('--weight_decay', default=0., type=float, metavar='', help='hyperparameter of weight decay regularization') parser.add_argument('--lr', default=1e-1, type=float, metavar='', help='learning rate') parser.add_argument('--epochs', default=30, type=int, metavar='', help='number of training epochs') parser.add_argument('--batch_size', default=8, type=int, metavar='', help='batch size (per domain)') parser.add_argument( '--checkpoint', default=0, type=int, metavar='', help= 'number of epochs between saving checkpoints (0 disables checkpoints)') parser.add_argument('--eval_target', default=False, type=int, metavar='', help='evaluate target during training') parser.add_argument('--use_cuda', default=True, type=int, metavar='', help='use CUDA capable GPU') parser.add_argument('--use_visdom', default=False, type=int, metavar='', help='use Visdom to visualize plots') parser.add_argument('--visdom_env', default='digits_train', type=str, metavar='', help='Visdom environment name') parser.add_argument('--visdom_port', default=8888, type=int, metavar='', help='Visdom port') parser.add_argument('--verbosity', default=2, type=int, metavar='', help='log verbosity level (0, 1, 2)') parser.add_argument('--seed', default=42, type=int, metavar='', help='random seed') args = vars(parser.parse_args()) # override args with icfg (if provided) cfg = args.copy() if cfg['icfg'] is not None: cv_parser = ConfigParser() cv_parser.read(cfg['icfg']) cv_param_names = [] for key, val in cv_parser.items('main'): cfg[key] = ast.literal_eval(val) cv_param_names.append(key) # dump cfg to a txt file for your records with open(cfg['output'] + '.txt', 'w') as f: f.write(str(cfg) + '\n') # use a fixed random seed for reproducibility purposes if cfg['seed'] > 0: random.seed(cfg['seed']) 
np.random.seed(seed=cfg['seed']) torch.manual_seed(cfg['seed']) torch.cuda.manual_seed(cfg['seed']) device = 'cuda' if (cfg['use_cuda'] and torch.cuda.is_available()) else 'cpu' log = Logger(cfg['verbosity']) log.print('device:', device, level=0) if ('FS' in cfg['model']) or ('FM' in cfg['model']): # weak data augmentation (small rotation + small translation) data_aug = T.Compose([ T.RandomAffine(5, translate=(0.125, 0.125)), T.ToTensor(), ]) else: data_aug = T.ToTensor() # define all datasets datasets = {} datasets['MNIST'] = MNIST(train=True, path=os.path.join(cfg['data_path'], 'MNIST'), transform=data_aug) datasets['MNIST_M'] = MNIST_M(train=True, path=os.path.join(cfg['data_path'], 'MNIST_M'), transform=data_aug) datasets['SVHN'] = SVHN(train=True, path=os.path.join(cfg['data_path'], 'SVHN'), transform=data_aug) datasets['SynthDigits'] = SynthDigits(train=True, path=os.path.join( cfg['data_path'], 'SynthDigits'), transform=data_aug) if ('FS' in cfg['model']) or ('FM' in cfg['model']): test_set = deepcopy(datasets[cfg['target']]) test_set.transform = T.ToTensor() # no data augmentation in test else: test_set = datasets[cfg['target']] # get a subset of cfg['n_images'] from each dataset # define public and private test sets: the private is not used at training time to learn invariant representations for ds_name in datasets: if ds_name == cfg['target']: indices = random.sample(range(len(datasets[ds_name])), cfg['n_tgt_images'] + cfg['n_src_images']) test_pub_set = Subset(test_set, indices[0:cfg['n_tgt_images']]) test_priv_set = Subset(test_set, indices[cfg['n_tgt_images']::]) datasets[cfg['target']] = Subset(datasets[cfg['target']], indices[0:cfg['n_tgt_images']]) else: indices = random.sample(range(len(datasets[ds_name])), cfg['n_src_images']) datasets[ds_name] = Subset(datasets[ds_name], indices[0:cfg['n_src_images']]) # build the dataloader train_loader = MSDA_Loader(datasets, cfg['target'], batch_size=cfg['batch_size'], shuffle=True, device=device) test_pub_loader = DataLoader(test_pub_set, batch_size=4 * cfg['batch_size']) test_priv_loader = DataLoader(test_priv_set, batch_size=4 * cfg['batch_size']) valid_loaders = ({ 'target pub': test_pub_loader, 'target priv': test_priv_loader } if cfg['eval_target'] else None) log.print('target domain:', cfg['target'], '| source domains:', train_loader.sources, level=1) if cfg['model'] == 'FS': model = SimpleCNN().to(device) optimizer = optim.Adadelta(model.parameters(), lr=cfg['lr'], weight_decay=cfg['weight_decay']) if valid_loaders is not None: del valid_loaders['target pub'] fs_train_routine(model, optimizer, test_pub_loader, valid_loaders, cfg) elif cfg['model'] == 'FM': model = SimpleCNN().to(device) optimizer = optim.Adadelta(model.parameters(), lr=cfg['lr'], weight_decay=cfg['weight_decay']) cfg['excl_transf'] = [Flip] fm_train_routine(model, optimizer, train_loader, valid_loaders, cfg) elif cfg['model'] == 'DANNS': for src in train_loader.sources: model = MODANet().to(device) optimizer = optim.Adadelta(model.parameters(), lr=cfg['lr'], weight_decay=cfg['weight_decay']) dataset_ss = { src: datasets[src], cfg['target']: datasets[cfg['target']] } train_loader = MSDA_Loader(dataset_ss, cfg['target'], batch_size=cfg['batch_size'], shuffle=True, device=device) dann_train_routine(model, optimizer, train_loader, valid_loaders, cfg) torch.save(model.state_dict(), cfg['output'] + '_' + src) elif cfg['model'] == 'DANNM': model = MODANet().to(device) optimizer = optim.Adadelta(model.parameters(), lr=cfg['lr'], weight_decay=cfg['weight_decay']) 
        dann_train_routine(model, optimizer, train_loader, valid_loaders, cfg)
    elif cfg['model'] == 'MDAN':
        model = MDANet(len(train_loader.sources)).to(device)
        optimizer = optim.Adadelta(model.parameters(), lr=cfg['lr'], weight_decay=cfg['weight_decay'])
        mdan_train_routine(model, optimizer, train_loader, valid_loaders, cfg)
    elif cfg['model'] == 'MDANU':
        model = MDANet(len(train_loader.sources)).to(device)
        model.grad_reverse = nn.ModuleList([nn.Identity() for _ in range(len(model.domain_class))])  # remove grad reverse
        task_optim = optim.Adadelta(list(model.feat_ext.parameters()) + list(model.task_class.parameters()),
                                    lr=cfg['lr'], weight_decay=cfg['weight_decay'])
        adv_optim = optim.Adadelta(model.domain_class.parameters(), lr=cfg['lr'], weight_decay=cfg['weight_decay'])
        optimizers = (task_optim, adv_optim)
        mdan_unif_train_routine(model, optimizers, train_loader, valid_loaders, cfg)
    elif cfg['model'] == 'MDANFM':
        model = MDANet(len(train_loader.sources)).to(device)
        optimizer = optim.Adadelta(model.parameters(), lr=cfg['lr'], weight_decay=cfg['weight_decay'])
        mdan_fm_train_routine(model, optimizer, train_loader, valid_loaders, cfg)
    elif cfg['model'] == 'MDANUFM':
        model = MDANet(len(train_loader.sources)).to(device)
        task_optim = optim.Adadelta(list(model.feat_ext.parameters()) + list(model.task_class.parameters()),
                                    lr=cfg['lr'], weight_decay=cfg['weight_decay'])
        adv_optim = optim.Adadelta(model.domain_class.parameters(), lr=cfg['lr'], weight_decay=cfg['weight_decay'])
        optimizers = (task_optim, adv_optim)
        cfg['excl_transf'] = [Flip]
        mdan_unif_fm_train_routine(model, optimizers, train_loader, valid_loaders, cfg)
    elif cfg['model'] == 'MODA':
        model = MODANet().to(device)
        optimizer = optim.Adadelta(model.parameters(), lr=cfg['lr'], weight_decay=cfg['weight_decay'])
        moda_train_routine(model, optimizer, train_loader, valid_loaders, cfg)
    elif cfg['model'] == 'MODAFM':
        model = MODANet().to(device)
        optimizer = optim.Adadelta(model.parameters(), lr=cfg['lr'], weight_decay=cfg['weight_decay'])
        cfg['excl_transf'] = [Flip]
        moda_fm_train_routine(model, optimizer, train_loader, valid_loaders, cfg)
    else:
        raise ValueError('Unknown model {}'.format(cfg['model']))
    if cfg['model'] != 'DANNS':
        torch.save(model.state_dict(), cfg['output'])
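# The --icfg override above reads a [main] section from an .ini file and parses each
# value with ast.literal_eval, so types (ints, floats, lists) survive the round trip.
# A minimal, self-contained sketch of that pattern; the file name, helper name and keys
# below are illustrative, not part of the original code.
import ast
from configparser import ConfigParser


def load_overrides(path, section='main'):
    """Return a dict of literal-evaluated values from one .ini section (sketch)."""
    parser = ConfigParser()
    parser.read(path)
    return {key: ast.literal_eval(val) for key, val in parser.items(section)}

# Example: an .ini containing
#   [main]
#   lr = 1e-2
#   batch_size = 16
# would yield {'lr': 0.01, 'batch_size': 16}, which could then update the args dict:
#   cfg.update(load_overrides('my_overrides.ini'))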
def main(): # N.B.: parameters defined in cv_cfg.ini override args! parser = argparse.ArgumentParser(description='Cross-validation over source domains for the digits datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument('-m', '--model', default='MODAFM', type=str, metavar='', help='model type (\'MDAN\' / \'MODA\' / \'MODAFM\'') parser.add_argument('-d', '--data_path', default='/ctm-hdd-pool01/DB/', type=str, metavar='', help='data directory path') parser.add_argument('-t', '--target', default='MNIST', type=str, metavar='', help='target domain (\'MNIST\' / \'MNIST_M\' / \'SVHN\' / \'SynthDigits\')') parser.add_argument('-o', '--output', default='msda_hyperparams.ini', type=str, metavar='', help='model file (output of train)') parser.add_argument('-n', '--n_iter', default=20, type=int, metavar='', help='number of CV iterations') parser.add_argument('--n_images', default=20000, type=int, metavar='', help='number of images from each domain') parser.add_argument('--mu', type=float, default=1e-2, help="hyperparameter of the coefficient for the domain adversarial loss") parser.add_argument('--beta', type=float, default=0.2, help="hyperparameter of the non-sparsity regularization") parser.add_argument('--lambda', type=float, default=1e-1, help="hyperparameter of the FixMatch loss") parser.add_argument('--n_rand_aug', type=int, default=2, help="N parameter of RandAugment") parser.add_argument('--m_min_rand_aug', type=int, default=3, help="minimum M parameter of RandAugment") parser.add_argument('--m_max_rand_aug', type=int, default=10, help="maximum M parameter of RandAugment") parser.add_argument('--weight_decay', default=0., type=float, metavar='', help='hyperparameter of weight decay regularization') parser.add_argument('--lr', default=1e-1, type=float, metavar='', help='learning rate') parser.add_argument('--epochs', default=30, type=int, metavar='', help='number of training epochs') parser.add_argument('--batch_size', default=8, type=int, metavar='', help='batch size (per domain)') parser.add_argument('--checkpoint', default=0, type=int, metavar='', help='number of epochs between saving checkpoints (0 disables checkpoints)') parser.add_argument('--use_cuda', default=True, type=int, metavar='', help='use CUDA capable GPU') parser.add_argument('--use_visdom', default=False, type=int, metavar='', help='use Visdom to visualize plots') parser.add_argument('--visdom_env', default='digits_train', type=str, metavar='', help='Visdom environment name') parser.add_argument('--visdom_port', default=8888, type=int, metavar='', help='Visdom port') parser.add_argument('--verbosity', default=2, type=int, metavar='', help='log verbosity level (0, 1, 2)') parser.add_argument('--seed', default=42, type=int, metavar='', help='random seed') args = vars(parser.parse_args()) # override args with cv_cfg.ini cfg = args.copy() cv_parser = ConfigParser() cv_parser.read('cv_cfg.ini') cv_param_names = [] for key, val in cv_parser.items('main'): cfg[key] = ast.literal_eval(val) cv_param_names.append(key) # use a fixed random seed for reproducibility purposes if cfg['seed'] > 0: random.seed(cfg['seed']) np.random.seed(seed=cfg['seed']) torch.manual_seed(cfg['seed']) torch.cuda.manual_seed(cfg['seed']) device = 'cuda' if (cfg['use_cuda'] and torch.cuda.is_available()) else 'cpu' log = Logger(cfg['verbosity']) log.print('device:', device, level=0) if 'FM' in cfg['model']: # weak data augmentation (small rotation + small translation) data_aug = T.Compose([ T.RandomAffine(5, translate=(0.125, 
0.125)), T.ToTensor(), ]) else: data_aug = T.ToTensor() cfg['test_transform'] = T.ToTensor() # define all datasets datasets = {} datasets['MNIST'] = MNIST(train=True, path=os.path.join(cfg['data_path'], 'MNIST'), transform=data_aug) datasets['MNIST_M'] = MNIST_M(train=True, path=os.path.join(cfg['data_path'], 'MNIST_M'), transform=data_aug) datasets['SVHN'] = SVHN(train=True, path=os.path.join(cfg['data_path'], 'SVHN'), transform=data_aug) datasets['SynthDigits'] = SynthDigits(train=True, path=os.path.join(cfg['data_path'], 'SynthDigits'), transform=data_aug) del datasets[cfg['target']] # get a subset of cfg['n_images'] from each dataset for ds_name in datasets: if ds_name == cfg['target']: continue indices = random.sample(range(len(datasets[ds_name])), cfg['n_images']) datasets[ds_name] = Subset(datasets[ds_name], indices[0:cfg['n_images']]) if cfg['model'] == 'MDAN': cfg['model'] = MDANet(len(datasets)-1).to(device) cfg['train_routine'] = lambda model, optimizer, train_loader, cfg: mdan_train_routine(model, optimizer, train_loader, dict(), cfg) elif cfg['model'] == 'MODA': cfg['model'] = MODANet().to(device) cfg['train_routine'] = lambda model, optimizer, train_loader, cfg: moda_train_routine(model, optimizer, train_loader, dict(), cfg) elif cfg['model'] == 'MODAFM': cfg['model'] = MODANet().to(device) cfg['excl_transf'] = [Flip] cfg['train_routine'] = lambda model, optimizer, train_loader, cfg: moda_fm_train_routine(model, optimizer, train_loader, dict(), cfg) else: raise ValueError('Unknown model {}'.format(cfg['model'])) best_params, _ = cross_validation(datasets, cfg, cv_param_names) log.print('best_params:', best_params, level=1) results = ConfigParser() results.add_section('main') for key, value in best_params.items(): results.set('main', key, str(value)) with open(cfg['output'], 'w') as f: results.write(f)
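# Cross-validation here is over source domains: the assumption is that each fold holds
# one source out and treats it as a pseudo-target, while the remaining sources are used
# for training. cross_validation itself is defined elsewhere; the helper below is only a
# sketch of that leave-one-source-out split, with an illustrative name.
def leave_one_source_out(datasets):
    """Yield (pseudo_target_name, sources_dict) pairs from a dict of source datasets."""
    for held_out in datasets:
        sources = {name: ds for name, ds in datasets.items() if name != held_out}
        yield held_out, sources

# Example usage (names are illustrative):
#   for pseudo_target, sources in leave_one_source_out(datasets):
#       ...train on `sources`, validate on `datasets[pseudo_target]`...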
def main(): parser = argparse.ArgumentParser(description='Domain adaptation experiments with Office dataset.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument('-m', '--model', default='MODAFM', type=str, metavar='', help='model type (\'FS\' / \'DANNS\' / \'DANNM\' / \'MDAN\' / \'MODA\' / \'FM\' / \'MODAFM\'') parser.add_argument('-d', '--data_path', default='/ctm-hdd-pool01/DB/OfficeRsz', type=str, metavar='', help='data directory path') parser.add_argument('-t', '--target', default='amazon', type=str, metavar='', help='target domain (\'amazon\' / \'dslr\' / \'webcam\')') parser.add_argument('-i', '--input', default='msda.pth', type=str, metavar='', help='model file (output of train)') parser.add_argument('--arch', default='resnet50', type=str, metavar='', help='network architecture (\'resnet50\' / \'alexnet\'') parser.add_argument('--batch_size', default=20, type=int, metavar='', help='batch size (per domain)') parser.add_argument('--use_cuda', default=True, type=int, metavar='', help='use CUDA capable GPU') parser.add_argument('--verbosity', default=2, type=int, metavar='', help='log verbosity level (0, 1, 2)') parser.add_argument('--seed', default=42, type=int, metavar='', help='random seed') args = vars(parser.parse_args()) cfg = args.copy() # use a fixed random seed for reproducibility purposes if cfg['seed'] > 0: random.seed(cfg['seed']) np.random.seed(seed=cfg['seed']) torch.manual_seed(cfg['seed']) torch.cuda.manual_seed(cfg['seed']) device = 'cuda' if (cfg['use_cuda'] and torch.cuda.is_available()) else 'cpu' log = Logger(cfg['verbosity']) log.print('device:', device, level=0) # normalization transformation (required for pretrained networks) normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) if 'FM' in cfg['model']: transform = T.ToTensor() else: transform = T.Compose([ T.ToTensor(), normalize, ]) domains = ['amazon', 'dslr', 'webcam'] datasets = {domain: Office(cfg['data_path'], domain=domain, transform=transform) for domain in domains} n_classes = len(datasets[cfg['target']].class_names) if 'FM' in cfg['model']: test_set = deepcopy(datasets[cfg['target']]) test_set.transform = T.ToTensor() # no data augmentation in test else: test_set = datasets[cfg['target']] if cfg['model'] != 'FS': test_loader = {'target pub': DataLoader(test_set, batch_size=3*cfg['batch_size'])} else: train_indices = random.sample(range(len(datasets[cfg['target']])), int(0.8*len(datasets[cfg['target']]))) test_indices = list(set(range(len(datasets[cfg['target']]))) - set(train_indices)) test_loader = {'target pub': DataLoader( datasets[cfg['target']], batch_size=cfg['batch_size'], sampler=SubsetRandomSampler(test_indices))} log.print('target domain:', cfg['target'], level=1) if cfg['model'] in ['FS', 'FM']: model = SimpleCNN(n_classes=n_classes, arch=cfg['arch']).to(device) elif args['model'] == 'MDAN': model = MDANet(n_classes=n_classes, n_domains=len(domains)-1, arch=cfg['arch']).to(device) elif cfg['model'] in ['DANNS', 'DANNM', 'MODA', 'MODAFM']: model = MODANet(n_classes=n_classes, arch=cfg['arch']).to(device) else: raise ValueError('Unknown model {}'.format(cfg['model'])) if cfg['model'] != 'DANNS': model.load_state_dict(torch.load(cfg['input'])) accuracies, losses = test_routine(model, test_loader, cfg) print('target pub: acc = {:.3f},'.format(accuracies['target pub']), 'loss = {:.3f}'.format(losses['target pub'])) else: # for DANNS, report results for the best source domain src_domains = ['amazon', 'dslr', 'webcam'] 
        src_domains.remove(cfg['target'])
        for i, src in enumerate(src_domains):
            model.load_state_dict(torch.load(cfg['input'] + '_' + src))
            acc, loss = test_routine(model, test_loader, cfg)
            if i == 0:
                accuracies = acc
                losses = loss
            else:
                # keep the metrics of whichever source model achieves the higher accuracy
                for key in accuracies.keys():
                    if acc[key] > accuracies[key]:
                        accuracies[key] = acc[key]
                        losses[key] = loss[key]
        log.print('target pub: acc = {:.3f},'.format(accuracies['target pub']),
                  'loss = {:.3f}'.format(losses['target pub']), level=1)
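# DANNS trains one model per source domain, so the report keeps, for each metric key, the
# accuracy and loss of whichever per-source model scored best. A small helper with the
# same logic; the function name is hypothetical and not part of the original code.
def keep_best_source(accuracies, losses, acc, loss):
    """Update accuracies/losses in place with the better-scoring source's results."""
    for key in accuracies:
        if acc[key] > accuracies[key]:
            accuracies[key] = acc[key]
            losses[key] = loss[key]

# Example: keep_best_source(accuracies, losses, acc, loss) inside the per-source loop.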
def main(): parser = argparse.ArgumentParser( description='Domain adaptation experiments with Amazon dataset.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument( '-m', '--model', default='MODAFM', type=str, metavar='', help= 'model type (\'FS\' / \'DANNS\' / \'DANNM\' / \'MDAN\' / \'MODA\' / \'FM\' / \'MODAFM\'' ) parser.add_argument('-d', '--data_path', default='/ctm-hdd-pool01/DB/Amazon', type=str, metavar='', help='data directory path') parser.add_argument( '-t', '--target', default='books', type=str, metavar='', help= 'target domain (\'books\' / \'dvd\' / \'electronics\' / \'kitchen\')') parser.add_argument('-o', '--output', default='msda.pth', type=str, metavar='', help='model file (output of train)') parser.add_argument('--icfg', default=None, type=str, metavar='', help='config file (overrides args)') parser.add_argument('--n_samples', default=2000, type=int, metavar='', help='number of samples from each domain') parser.add_argument('--n_features', default=5000, type=int, metavar='', help='number of features to use') parser.add_argument( '--mu', type=float, default=1e-2, help="hyperparameter of the coefficient for the domain adversarial loss" ) parser.add_argument( '--beta', type=float, default=2e-1, help="hyperparameter of the non-sparsity regularization") parser.add_argument('--lambda', type=float, default=1e-1, help="hyperparameter of the FixMatch loss") parser.add_argument('--min_dropout', type=int, default=2e-1, help="minimum dropout rate") parser.add_argument('--max_dropout', type=int, default=8e-1, help="maximum dropout rate") parser.add_argument('--weight_decay', default=0., type=float, metavar='', help='hyperparameter of weight decay regularization') parser.add_argument('--lr', default=1e0, type=float, metavar='', help='learning rate') parser.add_argument('--epochs', default=15, type=int, metavar='', help='number of training epochs') parser.add_argument('--batch_size', default=20, type=int, metavar='', help='batch size (per domain)') parser.add_argument( '--checkpoint', default=0, type=int, metavar='', help= 'number of epochs between saving checkpoints (0 disables checkpoints)') parser.add_argument('--eval_target', default=False, type=int, metavar='', help='evaluate target during training') parser.add_argument('--use_cuda', default=True, type=int, metavar='', help='use CUDA capable GPU') parser.add_argument('--use_visdom', default=False, type=int, metavar='', help='use Visdom to visualize plots') parser.add_argument('--visdom_env', default='amazon_train', type=str, metavar='', help='Visdom environment name') parser.add_argument('--visdom_port', default=8888, type=int, metavar='', help='Visdom port') parser.add_argument('--verbosity', default=2, type=int, metavar='', help='log verbosity level') parser.add_argument('--seed', default=42, type=int, metavar='', help='random seed') args = vars(parser.parse_args()) # override args with icfg (if provided) cfg = args.copy() if cfg['icfg'] is not None: cv_parser = ConfigParser() cv_parser.read(cfg['icfg']) cv_param_names = [] for key, val in cv_parser.items('main'): cfg[key] = ast.literal_eval(val) cv_param_names.append(key) # dump cfg to a txt file for your records with open(cfg['output'] + '.txt', 'w') as f: f.write(str(cfg) + '\n') device = 'cuda' if (cfg['use_cuda'] and torch.cuda.is_available()) else 'cpu' log = Logger(cfg['verbosity']) log.print('device:', device, level=0) # use a fixed random seed for reproducibility purposes if cfg['seed'] > 0: random.seed(args['seed']) 
        np.random.seed(seed=args['seed'])
        torch.manual_seed(args['seed'])
        torch.cuda.manual_seed(args['seed'])

    domains = ['books', 'dvd', 'electronics', 'kitchen']
    datasets = {}
    for domain in domains:
        datasets[domain] = Amazon('./amazon.npz', domain, dimension=cfg['n_features'], transform=torch.from_numpy)
        indices = random.sample(range(len(datasets[domain])), cfg['n_samples'])
        if domain == cfg['target']:
            priv_indices = list(set(range(len(datasets[cfg['target']]))) - set(indices))
            test_priv_set = Subset(datasets[cfg['target']], priv_indices)
        datasets[domain] = Subset(datasets[domain], indices)
    test_pub_set = datasets[cfg['target']]

    train_loader = MSDA_Loader(datasets, cfg['target'], batch_size=cfg['batch_size'], shuffle=True, device=device)
    test_pub_loader = DataLoader(test_pub_set, batch_size=4 * cfg['batch_size'])
    test_priv_loader = DataLoader(test_priv_set, batch_size=4 * cfg['batch_size'])
    valid_loaders = ({'target pub': test_pub_loader, 'target priv': test_priv_loader}
                     if cfg['eval_target'] else None)
    log.print('target domain:', cfg['target'], 'source domains:', train_loader.sources, level=1)

    if cfg['model'] == 'FS':
        model = SimpleMLP(input_dim=cfg['n_features'], n_classes=2).to(device)
        optimizer = optim.Adadelta(model.parameters(), lr=cfg['lr'], weight_decay=cfg['weight_decay'])
        if cfg['eval_target']:
            del valid_loaders['target pub']
        fs_train_routine(model, optimizer, test_pub_loader, valid_loaders, cfg)
    elif cfg['model'] == 'FM':
        model = SimpleMLP(input_dim=cfg['n_features'], n_classes=2).to(device)
        optimizer = optim.Adadelta(model.parameters(), lr=cfg['lr'], weight_decay=cfg['weight_decay'])
        mlp_fm_train_routine(model, optimizer, train_loader, valid_loaders, cfg)
    elif cfg['model'] == 'DANNS':
        for src in train_loader.sources:
            model = MODANet(input_dim=cfg['n_features'], n_classes=2).to(device)
            optimizer = optim.Adadelta(model.parameters(), lr=cfg['lr'], weight_decay=cfg['weight_decay'])
            dataset_ss = {src: datasets[src], cfg['target']: datasets[cfg['target']]}
            train_loader = MSDA_Loader(dataset_ss, cfg['target'], batch_size=cfg['batch_size'], shuffle=True, device=device)
            dann_train_routine(model, optimizer, train_loader, valid_loaders, cfg)
            torch.save(model.state_dict(), cfg['output'] + '_' + src)
    elif cfg['model'] == 'DANNM':
        model = MODANet(input_dim=cfg['n_features'], n_classes=2).to(device)
        optimizer = optim.Adadelta(model.parameters(), lr=cfg['lr'], weight_decay=cfg['weight_decay'])
        dann_train_routine(model, optimizer, train_loader, valid_loaders, cfg)
    elif cfg['model'] == 'MDAN':
        model = MDANet(input_dim=cfg['n_features'], n_classes=2, n_domains=len(train_loader.sources)).to(device)
        optimizer = optim.Adadelta(model.parameters(), lr=cfg['lr'], weight_decay=cfg['weight_decay'])
        mdan_train_routine(model, optimizer, train_loader, valid_loaders, cfg)
    elif cfg['model'] == 'MODA':
        model = MODANet(input_dim=cfg['n_features'], n_classes=2).to(device)
        optimizer = optim.Adadelta(model.parameters(), lr=cfg['lr'], weight_decay=cfg['weight_decay'])
        moda_train_routine(model, optimizer, train_loader, valid_loaders, cfg)
    elif cfg['model'] == 'MODAFM':
        model = MODANet(input_dim=cfg['n_features'], n_classes=2).to(device)
        optimizer = optim.Adadelta(model.parameters(), lr=cfg['lr'], weight_decay=cfg['weight_decay'])
        moda_mlp_fm_train_routine(model, optimizer, train_loader, valid_loaders, cfg)
    else:
        raise ValueError('Unknown model {}'.format(cfg['model']))
    if cfg['model'] != 'DANNS':
        torch.save(model.state_dict(), cfg['output'])
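# The target data is split into a "public" part (available unlabeled during training) and
# a "private" part (never seen at training time), by sampling indices without replacement.
# A self-contained sketch of that split; the helper name is made up for illustration.
import random
from torch.utils.data import Subset


def split_public_private(dataset, n_public, seed=None):
    """Return (public_subset, private_subset) of a dataset via random index sampling."""
    rng = random.Random(seed)
    public_idx = rng.sample(range(len(dataset)), n_public)
    private_idx = list(set(range(len(dataset))) - set(public_idx))
    return Subset(dataset, public_idx), Subset(dataset, private_idx)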
def main(): parser = argparse.ArgumentParser( description='Domain adaptation experiments with the DomainNet dataset.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument( '-m', '--model', default='MODAFM', type=str, metavar='', help= 'model type (\'FS\' / \'DANNS\' / \'DANNM\' / \'MDAN\' / \'MODA\' / \'FM\' / \'MODAFM\'' ) parser.add_argument('-d', '--data_path', default='/ctm-hdd-pool01/DB/DomainNet192', type=str, metavar='', help='data directory path') parser.add_argument( '-t', '--target', default='clipart', type=str, metavar='', help= 'target domain (\'clipart\' / \'infograph\' / \'painting\' / \'quickdraw\' / \'real\' / \'sketch\')' ) parser.add_argument('-o', '--output', default='msda.pth', type=str, metavar='', help='model file (output of train)') parser.add_argument('--icfg', default=None, type=str, metavar='', help='config file (overrides args)') parser.add_argument( '--arch', default='resnet152', type=str, metavar='', help='network architecture (\'resnet101\' / \'resnet152\'') parser.add_argument( '--mu_d', type=float, default=1e-2, help= "hyperparameter of the coefficient for the domain discriminator loss") parser.add_argument( '--mu_s', type=float, default=0.2, help="hyperparameter of the non-sparsity regularization") parser.add_argument('--mu_c', type=float, default=1e-1, help="hyperparameter of the FixMatch loss") parser.add_argument('--n_rand_aug', type=int, default=2, help="N parameter of RandAugment") parser.add_argument('--m_min_rand_aug', type=int, default=3, help="minimum M parameter of RandAugment") parser.add_argument('--m_max_rand_aug', type=int, default=10, help="maximum M parameter of RandAugment") parser.add_argument('--weight_decay', default=0., type=float, metavar='', help='hyperparameter of weight decay regularization') parser.add_argument('--lr', default=1e-3, type=float, metavar='', help='learning rate') parser.add_argument('--epochs', default=50, type=int, metavar='', help='number of training epochs') parser.add_argument('--batch_size', default=8, type=int, metavar='', help='batch size (per domain)') parser.add_argument( '--checkpoint', default=0, type=int, metavar='', help= 'number of epochs between saving checkpoints (0 disables checkpoints)') parser.add_argument('--eval_target', default=False, type=int, metavar='', help='evaluate target during training') parser.add_argument('--use_cuda', default=True, type=int, metavar='', help='use CUDA capable GPU') parser.add_argument('--use_visdom', default=False, type=int, metavar='', help='use Visdom to visualize plots') parser.add_argument('--visdom_env', default='domainnet_train', type=str, metavar='', help='Visdom environment name') parser.add_argument('--visdom_port', default=8888, type=int, metavar='', help='Visdom port') parser.add_argument('--verbosity', default=2, type=int, metavar='', help='log verbosity level (0, 1, 2)') parser.add_argument('--seed', default=42, type=int, metavar='', help='random seed') args = vars(parser.parse_args()) # override args with icfg (if provided) cfg = args.copy() if cfg['icfg'] is not None: cv_parser = ConfigParser() cv_parser.read(cfg['icfg']) cv_param_names = [] for key, val in cv_parser.items('main'): cfg[key] = ast.literal_eval(val) cv_param_names.append(key) # dump args to a txt file for your records with open(cfg['output'] + '.txt', 'w') as f: f.write(str(cfg) + '\n') # use a fixed random seed for reproducibility purposes if cfg['seed'] > 0: random.seed(cfg['seed']) np.random.seed(seed=cfg['seed']) torch.manual_seed(cfg['seed']) 
torch.cuda.manual_seed(cfg['seed']) device = 'cuda' if (cfg['use_cuda'] and torch.cuda.is_available()) else 'cpu' log = Logger(cfg['verbosity']) log.print('device:', device, level=0) # normalization transformation (required for pretrained networks) normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) if 'FM' in cfg['model']: # weak data augmentation (small rotation + small translation) data_aug = T.Compose([ # T.RandomCrop(224), # T.Resize(128), T.RandomHorizontalFlip(), T.RandomAffine(5, translate=(0.125, 0.125)), T.ToTensor(), # normalize, # normalization disrupts FixMatch ]) eval_transf = T.Compose([ # T.RandomCrop(224), # T.Resize(128), T.ToTensor(), ]) else: data_aug = T.Compose([ # T.RandomCrop(224), # T.Resize(128), T.RandomHorizontalFlip(), T.ToTensor(), normalize, ]) eval_transf = T.Compose([ # T.RandomCrop(224), # T.Resize(128), T.ToTensor(), normalize, ]) domains = [ 'clipart', 'infograph', 'painting', 'quickdraw', 'real', 'sketch' ] datasets = { domain: DomainNet(cfg['data_path'], domain=domain, train=True, transform=data_aug) for domain in domains } n_classes = len(datasets[cfg['target']].class_names) test_set = DomainNet(cfg['data_path'], domain=cfg['target'], train=False, transform=eval_transf) if 'FM' in cfg['model']: target_pub = deepcopy(datasets[cfg['target']]) target_pub.transform = eval_transf # no data augmentation in test else: target_pub = datasets[cfg['target']] if cfg['model'] != 'FS': train_loader = MSDA_Loader(datasets, cfg['target'], batch_size=cfg['batch_size'], shuffle=True, num_workers=0, device=device) if cfg['eval_target']: valid_loaders = { 'target pub': DataLoader(target_pub, batch_size=6 * cfg['batch_size']), 'target priv': DataLoader(test_set, batch_size=6 * cfg['batch_size']) } else: valid_loaders = None log.print('target domain:', cfg['target'], '| source domains:', train_loader.sources, level=1) else: train_loader = DataLoader(datasets[cfg['target']], batch_size=cfg['batch_size'], shuffle=True) test_loader = DataLoader(test_set, batch_size=cfg['batch_size']) log.print('target domain:', cfg['target'], level=1) if cfg['model'] == 'FS': model = SimpleCNN(n_classes=n_classes, arch=cfg['arch']).to(device) conv_params, fc_params = [], [] for name, param in model.named_parameters(): if 'fc' in name.lower(): fc_params.append(param) else: conv_params.append(param) optimizer = optim.Adadelta([{ 'params': conv_params, 'lr': 0.1 * cfg['lr'], 'weight_decay': cfg['weight_decay'] }, { 'params': fc_params, 'lr': cfg['lr'], 'weight_decay': cfg['weight_decay'] }]) valid_loaders = { 'target pub': test_loader } if cfg['eval_target'] else None fs_train_routine(model, optimizer, train_loader, valid_loaders, cfg) elif cfg['model'] == 'FM': model = SimpleCNN(n_classes=n_classes, arch=cfg['arch']).to(device) for name, param in model.named_parameters(): if 'fc' in name.lower(): fc_params.append(param) else: conv_params.append(param) optimizer = optim.Adadelta([{ 'params': conv_params, 'lr': 0.1 * cfg['lr'], 'weight_decay': cfg['weight_decay'] }, { 'params': fc_params, 'lr': cfg['lr'], 'weight_decay': cfg['weight_decay'] }]) cfg['excl_transf'] = None fm_train_routine(model, optimizer, train_loader, valid_loaders, cfg) elif cfg['model'] == 'DANNS': for src in train_loader.sources: model = MODANet(n_classes=n_classes, arch=cfg['arch']).to(device) conv_params, fc_params = [], [] for name, param in model.named_parameters(): if 'fc' in name.lower(): fc_params.append(param) else: conv_params.append(param) optimizer = optim.Adadelta([{ 'params': 
conv_params, 'lr': 0.1 * cfg['lr'], 'weight_decay': cfg['weight_decay'] }, { 'params': fc_params, 'lr': cfg['lr'], 'weight_decay': cfg['weight_decay'] }]) dataset_ss = { src: datasets[src], cfg['target']: datasets[cfg['target']] } train_loader = MSDA_Loader(dataset_ss, cfg['target'], batch_size=cfg['batch_size'], shuffle=True, device=device) dann_train_routine(model, optimizer, train_loader, valid_loaders, cfg) torch.save(model.state_dict(), cfg['output'] + '_' + src) elif cfg['model'] == 'DANNM': model = MODANet(n_classes=n_classes, arch=cfg['arch']).to(device) conv_params, fc_params = [], [] for name, param in model.named_parameters(): if 'fc' in name.lower(): fc_params.append(param) else: conv_params.append(param) optimizer = optim.Adadelta([{ 'params': conv_params, 'lr': 0.1 * cfg['lr'], 'weight_decay': cfg['weight_decay'] }, { 'params': fc_params, 'lr': cfg['lr'], 'weight_decay': cfg['weight_decay'] }]) dann_train_routine(model, optimizer, train_loader, valid_loaders, cfg) elif args['model'] == 'MDAN': model = MDANet(n_classes=n_classes, n_domains=len(train_loader.sources), arch=cfg['arch']).to(device) conv_params, fc_params = [], [] for name, param in model.named_parameters(): if 'fc' in name.lower(): fc_params.append(param) else: conv_params.append(param) optimizer = optim.Adadelta([{ 'params': conv_params, 'lr': 0.1 * cfg['lr'], 'weight_decay': cfg['weight_decay'] }, { 'params': fc_params, 'lr': cfg['lr'], 'weight_decay': cfg['weight_decay'] }]) mdan_train_routine(model, optimizer, train_loader, valid_loaders, cfg) elif cfg['model'] == 'MODA': model = MODANet(n_classes=n_classes, arch=cfg['arch']).to(device) conv_params, fc_params = [], [] for name, param in model.named_parameters(): if 'fc' in name.lower(): fc_params.append(param) else: conv_params.append(param) optimizer = optim.Adadelta([{ 'params': conv_params, 'lr': 0.1 * cfg['lr'], 'weight_decay': cfg['weight_decay'] }, { 'params': fc_params, 'lr': cfg['lr'], 'weight_decay': cfg['weight_decay'] }]) moda_train_routine(model, optimizer, train_loader, valid_loaders, cfg) elif cfg['model'] == 'MODAFM': model = MODANet(n_classes=n_classes, arch=cfg['arch']).to(device) conv_params, fc_params = [], [] for name, param in model.named_parameters(): if 'fc' in name.lower(): fc_params.append(param) else: conv_params.append(param) optimizer = optim.Adadelta([{ 'params': conv_params, 'lr': 0.1 * cfg['lr'], 'weight_decay': cfg['weight_decay'] }, { 'params': fc_params, 'lr': cfg['lr'], 'weight_decay': cfg['weight_decay'] }]) cfg['excl_transf'] = None moda_fm_train_routine(model, optimizer, train_loader, valid_loaders, cfg) else: raise ValueError('Unknown model {}'.format(cfg['model'])) torch.save(model.state_dict(), cfg['output'])
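# Every branch above builds the same two Adadelta parameter groups: backbone ("conv")
# parameters at one tenth of the learning rate and classifier ("fc") parameters at the
# full learning rate. A small helper capturing that repeated pattern; the function name
# is hypothetical and not part of the original code.
import torch.optim as optim


def make_finetune_optimizer(model, lr, weight_decay):
    """Adadelta with a reduced learning rate for all parameters outside 'fc' layers."""
    conv_params, fc_params = [], []
    for name, param in model.named_parameters():
        (fc_params if 'fc' in name.lower() else conv_params).append(param)
    return optim.Adadelta([
        {'params': conv_params, 'lr': 0.1 * lr, 'weight_decay': weight_decay},
        {'params': fc_params, 'lr': lr, 'weight_decay': weight_decay},
    ])

# Example: optimizer = make_finetune_optimizer(model, cfg['lr'], cfg['weight_decay'])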
def main(): parser = argparse.ArgumentParser( description= 'Cross-validation over source domains for the Office dataset.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument('-m', '--model', default='MODAFM', type=str, metavar='', help='model type (\'MDAN\' / \'MODA\' / \'MODAFM\'') parser.add_argument('-d', '--data_path', default='/ctm-hdd-pool01/DB/OfficeRsz', type=str, metavar='', help='data directory path') parser.add_argument( '-t', '--target', default='amazon', type=str, metavar='', help='target domain (\'amazon\' / \'dslr\' / \'webcam\')') parser.add_argument( '-o', '--output', default='cv_out.ini', type=str, metavar='', help='best hyperparameters (output of cross validation') parser.add_argument('-n', '--n_iter', default=20, type=int, metavar='', help='number of CV iterations') parser.add_argument( '--mu_d', type=float, default=1e-2, help= "hyperparameter of the coefficient for the domain discriminator loss") parser.add_argument( '--mu_s', type=float, default=0.2, help="hyperparameter of the non-sparsity regularization") parser.add_argument('--mu_c', type=float, default=1e-1, help="hyperparameter of the FixMatch loss") parser.add_argument('--n_rand_aug', type=int, default=2, help="N parameter of RandAugment") parser.add_argument('--m_min_rand_aug', type=int, default=3, help="minimum M parameter of RandAugment") parser.add_argument('--m_max_rand_aug', type=int, default=10, help="maximum M parameter of RandAugment") parser.add_argument('--weight_decay', default=0., type=float, metavar='', help='hyperparameter of weight decay regularization') parser.add_argument('--lr', default=1e-1, type=float, metavar='', help='learning rate') parser.add_argument('--epochs', default=15, type=int, metavar='', help='number of training epochs') parser.add_argument('--batch_size', default=8, type=int, metavar='', help='batch size (per domain)') parser.add_argument( '--checkpoint', default=0, type=int, metavar='', help= 'number of epochs between saving checkpoints (0 disables checkpoints)') parser.add_argument('--use_cuda', default=True, type=int, metavar='', help='use CUDA capable GPU') parser.add_argument('--use_visdom', default=False, type=int, metavar='', help='use Visdom to visualize plots') parser.add_argument('--visdom_env', default='office_train', type=str, metavar='', help='Visdom environment name') parser.add_argument('--visdom_port', default=8888, type=int, metavar='', help='Visdom port') parser.add_argument('--verbosity', default=2, type=int, metavar='', help='log verbosity level (0, 1, 2)') parser.add_argument('--seed', default=42, type=int, metavar='', help='random seed') args = vars(parser.parse_args()) # override args with cv_cfg.ini cfg = args.copy() cv_parser = ConfigParser() cv_parser.read('cv_cfg.ini') cv_param_names = [] for key, val in cv_parser.items('main'): cfg[key] = ast.literal_eval(val) cv_param_names.append(key) device = 'cuda' if (cfg['use_cuda'] and torch.cuda.is_available()) else 'cpu' log = Logger(cfg['verbosity']) log.print('device:', device, level=0) # dump cfg to a txt file for your records with open(cfg['output'] + '.txt', 'w') as f: f.write(str(cfg) + '\n') # use a fixed random seed for reproducibility purposes if cfg['seed'] > 0: random.seed(cfg['seed']) np.random.seed(seed=cfg['seed']) torch.manual_seed(cfg['seed']) torch.cuda.manual_seed(cfg['seed']) # normalization transformation (required for pretrained networks) normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) if 'FM' in cfg['model']: # weak data 
augmentation (small rotation + small translation) data_aug = T.Compose([ T.RandomHorizontalFlip(), T.RandomAffine(5, translate=(0.125, 0.125)), T.ToTensor(), # normalize, # normalization disrupts FixMatch ]) cfg['test_transform'] = T.ToTensor() else: data_aug = T.Compose([ T.RandomHorizontalFlip(), T.ToTensor(), normalize, ]) cfg['test_transform'] = T.Compose([ T.ToTensor(), normalize, ]) domains = ['amazon', 'dslr', 'webcam'] datasets = { domain: Office(cfg['data_path'], domain=domain, transform=data_aug) for domain in domains } n_classes = len(datasets[cfg['target']].class_names) del datasets[args['target']] if cfg['model'] == 'MDAN': model = MDANet(n_classes=n_classes, n_domains=len(datasets) - 1).to(device) cfg['model'] = model conv_params, fc_params = [], [] for name, param in model.named_parameters(): if 'FC' in name.upper(): fc_params.append(param) else: conv_params.append(param) cfg['param_groups'] = [] cfg['param_groups'].append({ 'params': conv_params, 'lr': 0.1 * cfg['lr'], 'weight_decay': cfg['weight_decay'] }) cfg['param_groups'].append({ 'params': fc_params, 'lr': cfg['lr'], 'weight_decay': cfg['weight_decay'] }) cfg['train_routine'] = lambda model, optimizer, train_loader, cfg: mdan_train_routine( model, optimizer, train_loader, dict(), cfg) elif cfg['model'] == 'MODA': model = MixMDANet(n_classes=n_classes).to(device) cfg['model'] = model conv_params, fc_params = [], [] for name, param in model.named_parameters(): if 'FC' in name.upper(): fc_params.append(param) else: conv_params.append(param) cfg['param_groups'] = [] cfg['param_groups'].append({ 'params': conv_params, 'lr': 0.1 * cfg['lr'], 'weight_decay': cfg['weight_decay'] }) cfg['param_groups'].append({ 'params': fc_params, 'lr': cfg['lr'], 'weight_decay': cfg['weight_decay'] }) cfg['train_routine'] = lambda model, optimizer, train_loader, cfg: moda_train_routine( model, optimizer, train_loader, dict(), cfg) elif cfg['model'] == 'MODAFM': model = MixMDANet(n_classes=n_classes).to(device) cfg['model'] = model conv_params, fc_params = [], [] for name, param in model.named_parameters(): if 'FC' in name.upper(): fc_params.append(param) else: conv_params.append(param) cfg['param_groups'] = [] cfg['param_groups'].append({ 'params': conv_params, 'lr': 0.1 * cfg['lr'], 'weight_decay': cfg['weight_decay'] }) cfg['param_groups'].append({ 'params': fc_params, 'lr': cfg['lr'], 'weight_decay': cfg['weight_decay'] }) cfg['excl_transf'] = None cfg['train_routine'] = lambda model, optimizer, train_loader, cfg: moda_fm_train_routine( model, optimizer, train_loader, dict(), cfg) else: raise ValueError('Unknown model {}'.format(cfg['model'])) best_params, _ = cross_validation(datasets, cfg, cv_param_names) log.print('best_params:', best_params, level=1) results = ConfigParser() results.add_section('main') for key, value in best_params.items(): results.set('main', key, str(value)) with open(cfg['output'], 'w') as f: results.write(f)
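# The best hyperparameters are written to an .ini file with a single [main] section, which
# the train scripts can read back through their --icfg option. A minimal sketch of that
# write step; the helper name and the example keys are illustrative only.
from configparser import ConfigParser


def save_hyperparams(path, params):
    """Write a flat dict of hyperparameters to a one-section .ini file."""
    out = ConfigParser()
    out.add_section('main')
    for key, value in params.items():
        out.set('main', key, str(value))
    with open(path, 'w') as f:
        out.write(f)

# Example: save_hyperparams('cv_out.ini', {'lr': 0.05, 'mu_d': 0.01})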
def main(): parser = argparse.ArgumentParser( description='Domain adaptation experiments with digits datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument( '-m', '--model', default='MODAFM', type=str, metavar='', help= 'model type (\'FS\' / \'DANNS\' / \'DANNM\' / \'MDAN\' / \'MODA\' / \'FM\' / \'MODAFM\'' ) parser.add_argument('-d', '--data_path', default='/ctm-hdd-pool01/DB/', type=str, metavar='', help='data directory path') parser.add_argument( '-t', '--target', default='MNIST', type=str, metavar='', help= 'target domain (\'MNIST\' / \'MNIST_M\' / \'SVHN\' / \'SynthDigits\')') parser.add_argument('-i', '--input', default='msda.pth', type=str, metavar='', help='model file (output of train)') parser.add_argument('--n_images', default=20000, type=int, metavar='', help='number of images from each domain') parser.add_argument('--batch_size', default=8, type=int, metavar='', help='batch size') parser.add_argument('--use_cuda', default=True, type=int, metavar='', help='use CUDA capable GPU') parser.add_argument('--verbosity', default=2, type=int, metavar='', help='log verbosity level (0, 1, 2)') parser.add_argument('--seed', default=42, type=int, metavar='', help='random seed') args = vars(parser.parse_args()) cfg = args.copy() # use a fixed random seed for reproducibility purposes if cfg['seed'] > 0: random.seed(cfg['seed']) np.random.seed(seed=cfg['seed']) torch.manual_seed(cfg['seed']) torch.cuda.manual_seed(cfg['seed']) device = 'cuda' if (cfg['use_cuda'] and torch.cuda.is_available()) else 'cpu' log = Logger(cfg['verbosity']) log.print('device:', device, level=0) # define all datasets datasets = {} datasets['MNIST'] = MNIST(train=True, path=os.path.join(cfg['data_path'], 'MNIST'), transform=T.ToTensor()) datasets['MNIST_M'] = MNIST_M(train=True, path=os.path.join(cfg['data_path'], 'MNIST_M'), transform=T.ToTensor()) datasets['SVHN'] = SVHN(train=True, path=os.path.join(cfg['data_path'], 'SVHN'), transform=T.ToTensor()) datasets['SynthDigits'] = SynthDigits(train=True, path=os.path.join( cfg['data_path'], 'SynthDigits'), transform=T.ToTensor()) test_set = datasets[cfg['target']] # get a subset of cfg['n_images'] from each dataset # define public and private test sets: the private is not used at training time to learn invariant representations for ds_name in datasets: if ds_name == cfg['target']: indices = random.sample(range(len(datasets[ds_name])), 2 * cfg['n_images']) test_pub_set = Subset(test_set, indices[0:cfg['n_images']]) test_priv_set = Subset(test_set, indices[cfg['n_images']::]) else: indices = random.sample(range(len(datasets[ds_name])), cfg['n_images']) datasets[ds_name] = Subset(datasets[ds_name], indices[0:cfg['n_images']]) # build the dataloader test_pub_loader = DataLoader(test_pub_set, batch_size=4 * cfg['batch_size']) test_priv_loader = DataLoader(test_priv_set, batch_size=4 * cfg['batch_size']) test_loaders = { 'target pub': test_pub_loader, 'target priv': test_priv_loader } log.print('target domain:', cfg['target'], level=0) if cfg['model'] in ['FS', 'FM']: model = SimpleCNN().to(device) elif args['model'] == 'MDAN': model = MDANet(len(datasets) - 1).to(device) elif cfg['model'] in ['DANNS', 'DANNM', 'MODA', 'MODAFM']: model = MODANet().to(device) else: raise ValueError('Unknown model {}'.format(cfg['model'])) if cfg['model'] != 'DANNS': model.load_state_dict(torch.load(cfg['input'])) accuracies, losses = test_routine(model, test_loaders, cfg) print('target pub: acc = {:.3f},'.format(accuracies['target pub']), 'loss = 
{:.3f}'.format(losses['target pub']))
        print('target priv: acc = {:.3f},'.format(accuracies['target priv']),
              'loss = {:.3f}'.format(losses['target priv']))
    else:
        # for DANNS, report results for the best source domain
        src_domains = ['MNIST', 'MNIST_M', 'SVHN', 'SynthDigits']
        src_domains.remove(cfg['target'])
        for i, src in enumerate(src_domains):
            model.load_state_dict(torch.load(cfg['input'] + '_' + src))
            acc, loss = test_routine(model, test_loaders, cfg)
            if i == 0:
                accuracies = acc
                losses = loss
            else:
                # keep the metrics of whichever source model achieves the higher accuracy
                for key in accuracies.keys():
                    if acc[key] > accuracies[key]:
                        accuracies[key] = acc[key]
                        losses[key] = loss[key]
        log.print('target pub: acc = {:.3f},'.format(accuracies['target pub']),
                  'loss = {:.3f}'.format(losses['target pub']), level=1)
        log.print('target priv: acc = {:.3f},'.format(accuracies['target priv']),
                  'loss = {:.3f}'.format(losses['target priv']), level=1)
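# test_routine (defined elsewhere) is used here as returning two dicts keyed by loader
# name ('target pub' / 'target priv'): per-loader accuracy and per-loader loss. A tiny
# formatting helper equivalent to the prints above; the name is hypothetical.
def report(accuracies, losses, keys=('target pub', 'target priv')):
    for key in keys:
        print('{}: acc = {:.3f}, loss = {:.3f}'.format(key, accuracies[key], losses[key]))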
def main(): # N.B.: parameters defined in cv_cfg.ini override args! parser = argparse.ArgumentParser( description= 'Cross-validation over source domains for the Amazon dataset.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument('-m', '--model', default='MODAFM', type=str, metavar='', help='model type (\'MDAN\' / \'MODA\' / \'MODAFM\'') parser.add_argument('-d', '--data_path', default='/ctm-hdd-pool01/DB/Amazon', type=str, metavar='', help='data directory path') parser.add_argument( '-t', '--target', default='books', type=str, metavar='', help= 'target domain (\'books\' / \'dvd\' / \'electronics\' / \'kitchen\')') parser.add_argument('-o', '--output', default='msda_hyperparams.ini', type=str, metavar='', help='output file') parser.add_argument('-n', '--n_iter', default=20, type=int, metavar='', help='number of CV iterations') parser.add_argument('--n_samples', default=2000, type=int, metavar='', help='number of samples from each domain') parser.add_argument('--n_features', default=5000, type=int, metavar='', help='number of features to use') parser.add_argument( '--mu', type=float, default=1e-2, help="hyperparameter of the coefficient for the domain adversarial loss" ) parser.add_argument( '--beta', type=float, default=2e-1, help="hyperparameter of the non-sparsity regularization") parser.add_argument('--lambda', type=float, default=1e-1, help="hyperparameter of the FixMatch loss") parser.add_argument('--min_dropout', type=int, default=2e-1, help="minimum dropout rate") parser.add_argument('--max_dropout', type=int, default=8e-1, help="maximum dropout rate") parser.add_argument('--weight_decay', default=0., type=float, metavar='', help='hyperparameter of weight decay regularization') parser.add_argument('--lr', default=1e0, type=float, metavar='', help='learning rate') parser.add_argument('--epochs', default=15, type=int, metavar='', help='number of training epochs') parser.add_argument('--batch_size', default=20, type=int, metavar='', help='batch size (per domain)') parser.add_argument( '--checkpoint', default=0, type=int, metavar='', help= 'number of epochs between saving checkpoints (0 disables checkpoints)') parser.add_argument('--use_cuda', default=True, type=int, metavar='', help='use CUDA capable GPU') parser.add_argument('--use_visdom', default=False, type=int, metavar='', help='use Visdom to visualize plots') parser.add_argument('--visdom_env', default='amazon_train', type=str, metavar='', help='Visdom environment name') parser.add_argument('--visdom_port', default=8888, type=int, metavar='', help='Visdom port') parser.add_argument('--verbosity', default=2, type=int, metavar='', help='log verbosity level') parser.add_argument('--seed', default=42, type=int, metavar='', help='random seed') args = vars(parser.parse_args()) # override args with cv_cfg.ini cfg = args.copy() cv_parser = ConfigParser() cv_parser.read('cv_cfg.ini') cv_param_names = [] for key, val in cv_parser.items('main'): cfg[key] = ast.literal_eval(val) cv_param_names.append(key) # use a fixed random seed for reproducibility purposes if cfg['seed'] > 0: random.seed(cfg['seed']) np.random.seed(seed=cfg['seed']) torch.manual_seed(cfg['seed']) torch.cuda.manual_seed(cfg['seed']) device = 'cuda' if (cfg['use_cuda'] and torch.cuda.is_available()) else 'cpu' log = Logger(cfg['verbosity']) log.print('device:', device, level=0) domains = ['books', 'dvd', 'electronics', 'kitchen'] datasets = {} for domain in domains: if domain == cfg['target']: continue datasets[domain] = Amazon('./amazon.npz', 
domain, dimension=cfg['n_features'], transform=torch.from_numpy)
        indices = random.sample(range(len(datasets[domain])), cfg['n_samples'])
        datasets[domain] = Subset(datasets[domain], indices)
    cfg['test_transform'] = torch.from_numpy

    if cfg['model'] == 'MDAN':
        model = MDANet(input_dim=cfg['n_features'], n_classes=2, n_domains=len(domains) - 2).to(device)
        cfg['model'] = model
        cfg['train_routine'] = lambda model, optimizer, train_loader, cfg: mdan_train_routine(model, optimizer, train_loader, dict(), cfg)
    elif cfg['model'] == 'MODA':
        model = MODANet(input_dim=cfg['n_features'], n_classes=2).to(device)
        cfg['model'] = model
        cfg['train_routine'] = lambda model, optimizer, train_loader, cfg: moda_train_routine(model, optimizer, train_loader, dict(), cfg)
    elif cfg['model'] == 'MODAFM':
        model = MODANet(input_dim=cfg['n_features'], n_classes=2).to(device)
        cfg['model'] = model
        cfg['train_routine'] = lambda model, optimizer, train_loader, cfg: moda_mlp_fm_train_routine(model, optimizer, train_loader, dict(), cfg)
    else:
        raise ValueError('Unknown model {}'.format(cfg['model']))

    best_params, _ = cross_validation(datasets, cfg, cv_param_names)
    log.print('best_params:', best_params, level=1)

    results = ConfigParser()
    results.add_section('main')
    for key, value in best_params.items():
        results.set('main', key, str(value))
    with open(cfg['output'], 'w') as f:
        results.write(f)
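# Each script repeats the same seeding block for the random, NumPy and PyTorch RNGs when
# a positive seed is given. A consolidated sketch of that block; the helper name is made
# up and not part of the original code.
import random
import numpy as np
import torch


def set_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) RNGs when seed > 0."""
    if seed > 0:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

# Example: set_seed(cfg['seed']) right after parsing the arguments.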