import os
import sys
import time

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim

# Repo-local modules (general_dataset_loader, models, utils_pytorch, JigsawModel,
# LinearAverage, NCACrossEntropy, kNN) are assumed to be imported elsewhere.


def full_training(args):
    # Skip experiments that already have saved results.
    if not os.path.isdir(args.expdir):
        os.makedirs(args.expdir)
    elif os.path.exists(args.expdir + '/results.npy'):
        return

    if 'ae' in args.task:
        os.makedirs(args.expdir + '/figs/', exist_ok=True)

    # Rotation prediction expands each image into 4 rotated copies, so the
    # loader batch size is divided by 4 to keep the effective batch size fixed.
    train_batch_size = args.train_batch_size // 4 if args.task == 'rot' else args.train_batch_size
    test_batch_size = args.test_batch_size // 4 if args.task == 'rot' else args.test_batch_size
    yield_indices = (args.task == 'inst_disc')
    datadir = args.datadir + args.dataset

    trainloader, valloader, num_classes = general_dataset_loader.prepare_data_loaders(
        datadir, image_dim=args.image_dim, yield_indices=yield_indices,
        train_batch_size=train_batch_size, test_batch_size=test_batch_size,
        train_on_10_percent=args.train_on_10, train_on_half_classes=args.train_on_half)

    _, testloader, _ = general_dataset_loader.prepare_data_loaders(
        datadir, image_dim=args.image_dim, yield_indices=yield_indices,
        train_batch_size=train_batch_size, test_batch_size=test_batch_size)

    args.num_classes = num_classes

    # Self-supervised tasks define their own output spaces: 4 rotation classes
    # for 'rot', a low-dimensional embedding for instance discrimination.
    if args.task == 'rot':
        num_classes = 4
    elif args.task == 'inst_disc':
        num_classes = args.low_dim

    if args.task == 'ae':
        net = models.AE([args.code_dim], image_dim=args.image_dim)
    elif args.task == 'jigsaw':
        net = JigsawModel(num_perms=args.num_perms, code_dim=args.code_dim,
                          gray_prob=args.gray_prob, image_dim=args.image_dim)
    else:
        net = models.resnet26(num_classes, mlp_depth=args.mlp_depth,
                              normalize=(args.task == 'inst_disc'))

    if args.task == 'inst_disc':
        # Non-parametric memory banks for instance discrimination.
        train_lemniscate = LinearAverage(args.low_dim, len(trainloader.dataset),
                                         args.nce_t, args.nce_m)
        train_lemniscate.cuda()
        args.train_lemniscate = train_lemniscate
        test_lemniscate = LinearAverage(args.low_dim, len(valloader.dataset),
                                        args.nce_t, args.nce_m)
        test_lemniscate.cuda()
        args.test_lemniscate = test_lemniscate

    if args.source:
        try:
            old_net = torch.load(args.source)
        except Exception:
            # Checkpoints pickled under Python 2 need the latin1 encoding.
            print("Falling back to latin1 encoding")
            from functools import partial
            import pickle
            pickle.load = partial(pickle.load, encoding="latin1")
            pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1")
            old_net = torch.load(args.source,
                                 map_location=lambda storage, loc: storage,
                                 pickle_module=pickle)
        # net.load_state_dict(old_net['net'].state_dict())
        old_net = old_net['net']
        if hasattr(old_net, "module"):
            old_net = old_net.module
        old_state_dict = old_net.state_dict()
        new_state_dict = net.state_dict()
        # Copy all backbone weights; remap the first linear head only when its
        # output dimension matches the current task.
        for key, weight in old_state_dict.items():
            if 'linear' not in key:
                new_state_dict[key] = weight
            elif key == 'linears.0.weight' and weight.shape[0] == num_classes:
                new_state_dict['linears.0.0.weight'] = weight
            elif key == 'linears.0.bias' and weight.shape[0] == num_classes:
                new_state_dict['linears.0.0.bias'] = weight
        net.load_state_dict(new_state_dict)
        del old_net

    net = torch.nn.DataParallel(net).cuda()
    start_epoch = 0
    # Loss-based tasks (autoencoder, instance discrimination) track a minimum;
    # classification tasks track a maximum accuracy.
    best_acc = np.inf if args.task in ['ae', 'inst_disc'] else -1
    results = np.zeros((4, start_epoch + args.nb_epochs))
    net.cuda()
    cudnn.benchmark = True

    args.criterion = nn.MSELoss() if args.task == 'ae' else nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, net.parameters()),
                                lr=args.lr, momentum=0.9, weight_decay=args.wd)

    print("Start training")
    # Resolve the task-specific routines by name, e.g. utils_pytorch.train_rot.
    train_func = getattr(utils_pytorch, 'train_' + args.task)
    test_func = getattr(utils_pytorch, 'test_' + args.task)

    if args.test_first:
        with torch.no_grad():
            test_func(0, valloader, net, best_acc, args, optimizer)
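    # Expected interface of the task-specific routines, inferred from the call
    # sites below (the implementations live in utils_pytorch and are not shown
    # in this excerpt):
    #   train_<task>(epoch, loader, net, args, optimizer)          -> (acc, loss)
    #   test_<task>(epoch, loader, net, best_acc, args, optimizer) -> (acc, loss, best_acc)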
    for epoch in range(start_epoch, start_epoch + args.nb_epochs):
        utils_pytorch.adjust_learning_rate(optimizer, epoch, args)
        st_time = time.time()

        # Training and validation
        train_acc, train_loss = train_func(epoch, trainloader, net, args, optimizer)
        test_acc, test_loss, best_acc = test_func(epoch, valloader, net, best_acc,
                                                  args, optimizer)

        # Record statistics
        results[0:2, epoch] = [train_loss, train_acc]
        results[2:4, epoch] = [test_loss, test_acc]
        np.save(args.expdir + '/results.npy', results)
        print('Epoch lasted {0}'.format(time.time() - st_time))
        sys.stdout.flush()

        # Rotation prediction saturates quickly; optionally stop early once
        # training accuracy reaches 98%.
        if (args.task == 'rot') and (train_acc >= 98) and args.early_stopping:
            break

    if args.task == 'inst_disc':
        # Release the memory banks after training.
        args.train_lemniscate = None
        args.test_lemniscate = None
    else:
        # Re-evaluate the best checkpoint on the held-out test set.
        best_net = torch.load(args.expdir + '/checkpoint.t7')['net']
        best_acc = np.inf if args.task in ['ae', 'inst_disc'] else -1
        final_acc, final_loss, _ = test_func(0, testloader, best_net, best_acc, args, None)
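# Example invocation of full_training (a sketch: the attribute values are
# placeholders, and only the fields the function actually reads are listed):
#
#   import argparse
#   args = argparse.Namespace(
#       expdir='exp/rot_cifar10', task='rot', datadir='data/', dataset='cifar10',
#       image_dim=32, train_batch_size=128, test_batch_size=100,
#       train_on_10=False, train_on_half=False, low_dim=128, code_dim=128,
#       num_perms=100, gray_prob=0.1, mlp_depth=1, nce_t=0.07, nce_m=0.5,
#       source=None, nb_epochs=100, lr=0.1, wd=5e-4, test_first=False,
#       early_stopping=True)
#   full_training(args)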
# --- Module-level setup from the instance-discrimination training script.
# `use_cuda`, `ndata`, `trainloader`, `testloader`, `args`, and the checkpoint
# load itself are defined in the part of the script elided above.
if args.resume:  # assumed flag guarding the checkpoint-restore branch
    net = checkpoint['net']  # assumed restore of the model saved with the bank
    lemniscate = checkpoint['lemniscate']
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']
else:
    print('==> Building model..')
    net = models.__dict__['ResNet50'](low_dim=args.low_dim)
    # define lemniscate
    lemniscate = LinearAverage(args.low_dim, ndata, args.temperature,
                               args.memory_momentum)

# define loss function
criterion = NCACrossEntropy(torch.LongTensor(trainloader.dataset.targets))

if use_cuda:
    net.cuda()
    net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
    lemniscate.cuda()
    criterion.cuda()
    cudnn.benchmark = True

if args.test_only:
    # Evaluate with a k-NN classifier (k=30) on the memory-bank features.
    acc = kNN(0, net, lemniscate, trainloader, testloader, 30, args.temperature)
    sys.exit(0)

optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9,
                      weight_decay=5e-4, nesterov=True)


def adjust_learning_rate(optimizer, epoch):
    """Decay the learning rate by a factor of 10 every 50 epochs."""
    lr = args.lr * (0.1 ** (epoch // 50))
    print(lr)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
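# Typical epoch loop for the script above (a sketch; `train` is assumed to be
# this script's per-epoch routine and its signature here is hypothetical):
#
#   for epoch in range(start_epoch, start_epoch + 200):
#       adjust_learning_rate(optimizer, epoch)
#       train(epoch, net, lemniscate, criterion, optimizer, trainloader)
#       acc = kNN(epoch, net, lemniscate, trainloader, testloader, 30, args.temperature)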