def main():
    opt = TrainParser().parse()

    # set up logger
    utils.set_logger(opt=opt, filename='train.log', filemode='w')

    if opt.seed:
        utils.make_deterministic(opt.seed)

    model = models.get_model(opt)
    model = model.to(device)
    if opt.multi_gpu and device == 'cuda':
        model = torch.nn.DataParallel(model)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = utils.get_optimizer(opt,
                                    params=model.parameters(),
                                    lr=opt.lr,
                                    weight_decay=opt.weight_decay,
                                    momentum=opt.momentum)

    trainer_cls = AdversarialTrainer if opt.adversarial else Trainer
    trainer = trainer_cls(opt=opt,
                          model=model,
                          optimizer=optimizer,
                          val_metric_name='acc (%)',
                          val_metric_obj='max')
    if hasattr(model, 'update_centers'):
        trainer.update_centers_eval = lambda: utils.update_centers_eval(model)

    loader = datasets.get_dataloaders(opt)
    # TODO a hacky way to load some dummy validation data
    opt.is_train = False
    val_loader = datasets.get_dataloaders(opt)
    opt.is_train = True

    if opt.load_model:
        trainer.load()

    # save init model
    trainer.save(
        epoch=trainer.start_epoch - 1,
        val_metric_value=trainer.best_val_metric,
        force_save=True
    )

    utils.update_centers_eval(model)
    train(opt,
          n_epochs=opt.n_epochs,
          trainer=trainer,
          loader=loader,
          val_loader=val_loader,
          criterion=criterion,
          device=device)
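
# The `train` function called at the end of main() is defined elsewhere in the
# repository. As a rough illustration of the interface used above, a minimal
# epoch loop could look like the sketch below. The attribute names
# `trainer.model` and `trainer.optimizer`, and the commented-out
# validation/checkpointing hooks, are assumptions for illustration only, not
# the repository's actual implementation.
def train_sketch(opt, n_epochs, trainer, loader, val_loader, criterion, device):
    for epoch in range(trainer.start_epoch, n_epochs):
        # one pass over the training set
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            trainer.optimizer.zero_grad()
            loss = criterion(trainer.model(x), y)
            loss.backward()
            trainer.optimizer.step()
        # hypothetical hooks: validate, then checkpoint on the best metric
        # val_metric = evaluate(trainer.model, val_loader, device)
        # trainer.save(epoch=epoch, val_metric_value=val_metric)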
    fig.savefig(save_name, bbox_inches='tight')


if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    opt = TestParser().parse()

    # set up logger
    utils.set_logger(opt=opt, filename='test.log', filemode='a')
    logger = logging.getLogger()

    if opt.adversarial:
        from cleverhans.future.torch.attacks import \
            fast_gradient_method, projected_gradient_descent

    net = models.get_model(opt)
    net = net.to(device)
    if opt.multi_gpu and device == 'cuda':
        net = torch.nn.DataParallel(net)
    net.load_state_dict(
        torch.load(os.path.join(opt.checkpoint_dir, 'net.pth'))['state_dict'])
    if hasattr(net, 'update_centers'):
        utils.update_centers_eval(net)

    # everything except the final (output) layer
    net_head = torch.nn.Sequential(*list(net.children())[:-1])

    loader = datasets.get_dataloaders(opt)
    raw_data, labels, activations = test(opt, net, loader)

    idx = np.random.permutation(10000)
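
# The adversarial branch above only imports the CleverHans attacks; the actual
# perturbation happens inside `test`. A minimal sketch of how a batch could be
# perturbed before being fed to `net` is given below. The helper name and the
# attack hyperparameters (`eps`, the PGD step size and iteration count) are
# illustrative assumptions, not the script's actual settings.
import numpy as np
from cleverhans.future.torch.attacks import (
    fast_gradient_method, projected_gradient_descent)


def perturb_batch(net, x, eps=8 / 255, use_pgd=True):
    """Return an adversarially perturbed copy of x for robust evaluation."""
    if use_pgd:
        # iterative L-inf attack: step size eps/4, 40 iterations (illustrative)
        return projected_gradient_descent(net, x, eps, eps / 4, 40, np.inf)
    # single-step L-inf attack
    return fast_gradient_method(net, x, eps, np.inf)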
def main():
    opt = TrainParser().parse()

    # set up logger
    utils.set_logger(opt=opt, filename='train.log', filemode='w')

    if opt.seed:
        utils.make_deterministic(opt.seed)

    loader = datasets.get_dataloaders(opt)
    # TODO a hacky way to load some dummy validation data
    opt.is_train = False
    val_loader = datasets.get_dataloaders(opt)
    opt.is_train = True

    model = models.get_model(opt)
    model = model.to(device)
    modules, params = model.split(n_parts=opt.n_parts, mode=opt.split_mode)

    trainer_cls = AdversarialTrainer if opt.adversarial else Trainer
    output_layer = list(model.children())[-1]
    hidden_criterion = getattr(losses, opt.hidden_objective)(
        output_layer.phi, opt.n_classes)
    output_criterion = torch.nn.CrossEntropyLoss()

    optimizers, trainers = [], []
    for i in range(1, opt.n_parts + 1):
        optimizers.append(
            utils.get_optimizer(opt,
                                params=params[i - 1],
                                lr=getattr(opt, 'lr{}'.format(i)),
                                weight_decay=getattr(
                                    opt, 'weight_decay{}'.format(i)),
                                momentum=getattr(opt, 'momentum{}'.format(i))))
        trainer = trainer_cls(
            opt=opt,
            model=modules[i - 1],
            set_eval=modules[i - 2] if i > 1 else None,
            optimizer=optimizers[i - 1],
            val_metric_name=opt.hidden_objective if i < opt.n_parts else 'acc',
            val_metric_obj='max')
        trainers.append(trainer)

        if opt.load_model:
            if i < opt.n_parts:
                # load hidden layer
                trainers[i - 1].load('net_part{}.pth'.format(i))
            else:
                # load output layer
                trainers[i - 1].load('net.pth')

    # save init model
    trainers[0].save(epoch=trainers[0].start_epoch - 1,
                     val_metric_value=trainers[0].best_val_metric,
                     model_name='net_part{}.pth'.format(1),
                     force_save=True)

    # train the first hidden module
    train_hidden(opt,
                 n_epochs=opt.n_epochs1,
                 trainer=trainers[0],
                 loader=loader,
                 val_loader=val_loader,
                 criterion=hidden_criterion,
                 part_id=1,
                 device=device)

    # train other hidden modules
    for i in range(2, opt.n_parts):
        # save init model
        trainers[i - 1].save(epoch=trainers[i - 1].start_epoch - 1,
                             val_metric_value=trainers[i - 1].best_val_metric,
                             model_name='net_part{}.pth'.format(i),
                             force_save=True)

        # prepare centers
        utils.update_centers_eval(model)
        # exclude certain network part(s) from the graph to make things faster
        utils.exclude_during_backward(modules[i - 2])
        train_hidden(opt,
                     n_epochs=getattr(opt, 'n_epochs{}'.format(i)),
                     trainer=trainers[i - 1],
                     loader=loader,
                     val_loader=val_loader,
                     criterion=hidden_criterion,
                     part_id=i,
                     device=device)

    # save init model
    trainers[-1].save(epoch=trainers[-1].start_epoch - 1,
                      val_metric_value=trainers[-1].best_val_metric,
                      model_name='net.pth',
                      force_save=True)

    # train output layer
    utils.update_centers_eval(model)
    utils.exclude_during_backward(modules[-2])
    train_output(opt,
                 n_epochs=getattr(opt, 'n_epochs{}'.format(opt.n_parts)),
                 trainer=trainers[-1],
                 loader=loader,
                 val_loader=val_loader,
                 criterion=output_criterion,
                 part_id=opt.n_parts,
                 device=device)
    utils.include_during_backward(modules[-2])
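
# `utils.exclude_during_backward`/`include_during_backward` are used above to
# keep already-trained modules out of the backward pass while a later module
# is being trained. A minimal sketch of how such helpers could be written in
# plain PyTorch is shown below; the function names carry a `_sketch` suffix to
# emphasize that this is an assumed implementation, not the repository's own.
import torch


def exclude_during_backward_sketch(module: torch.nn.Module):
    # freeze parameters so autograd skips them and no optimizer can update them
    for p in module.parameters():
        p.requires_grad = False


def include_during_backward_sketch(module: torch.nn.Module):
    # re-enable gradients once training of the later module has finished
    for p in module.parameters():
        p.requires_grad = True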