def train_s2m2(base_loader, base_loader_test, val_loader, model, start_epoch, stop_epoch, params, tmp):
    """Train a backbone with the S2M2 strategy.

    Each batch contributes two losses whose gradients are accumulated before a
    single optimizer step: (1) a manifold-mixup classification loss, and (2) an
    auxiliary self-supervised loss that predicts the rotation (0/90/180/270)
    applied to a quarter of the batch, computed by a small linear head on top
    of the backbone features. After every epoch the model is evaluated on
    ``base_loader_test`` and on cached few-shot validation episodes; the best
    checkpoint (by the epoch-300 few-shot accuracy) is saved as ``best.tar``.

    Args:
        base_loader: training DataLoader yielding (images, labels).
        base_loader_test: DataLoader for the per-epoch classification eval.
        val_loader: DataLoader of few-shot validation episodes; its batches are
            cached to disk once so every epoch evaluates on identical episodes.
        model: backbone; forward supports both ``model(x, y, mixup_hidden=...)``
            (returns features, logits, and the two mixed target sets) and plain
            ``model(x)`` (returns features, logits).
        start_epoch, stop_epoch: epoch range ``[start_epoch, stop_epoch)``.
        params: config namespace (checkpoint_dir, dataset, model, alpha,
            dct_status, channels, save_freq, train_n_way, n_shot).
        tmp: checkpoint dict; if it has a 'rotate' key, the rotation head is
            restored from it.

    Returns:
        The trained model.

    Raises:
        ValueError: if ``params.model`` has no known feature width for the
            rotation head.
    """
    channels = params.channels if params.dct_status else 3
    val_acc_best = 0.0
    if not os.path.isdir(params.checkpoint_dir):
        os.makedirs(params.checkpoint_dir)

    # Cache the validation episodes to disk once so that every epoch (and
    # every run against this checkpoint dir) scores the same fixed episodes.
    val_cache = params.checkpoint_dir + '/val_' + params.dataset + '.pt'
    if path.exists(val_cache):
        loader = torch.load(val_cache)
    else:
        loader = [x for x, _ in val_loader]
        torch.save(loader, val_cache)

    def mixup_criterion(criterion, pred, y_a, y_b, lam):
        # Convex combination of the losses against the two mixed label sets.
        return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

    criterion = nn.CrossEntropyLoss()

    # The 4-way rotation head must match the backbone's feature width.
    if params.model == 'WideResNet28_10':
        rotate_classifier = nn.Sequential(nn.Linear(640, 4))
    elif params.model == 'ResNet18':
        rotate_classifier = nn.Sequential(nn.Linear(512, 4))
    else:
        # Fail fast with a clear message instead of a NameError below.
        raise ValueError('train_s2m2: unknown feature width for model %r' % params.model)
    rotate_classifier.cuda()

    if 'rotate' in tmp:
        print("loading rotate model")
        rotate_classifier.load_state_dict(tmp['rotate'])

    # One optimizer drives both the backbone and the rotation head.
    optimizer = torch.optim.Adam([
        {'params': model.parameters()},
        {'params': rotate_classifier.parameters()},
    ])

    print("stop_epoch", start_epoch, stop_epoch)

    for epoch in range(start_epoch, stop_epoch):
        print('\nEpoch: %d' % epoch)
        model.train()
        train_loss = 0
        rotate_loss = 0
        correct = 0
        total = 0
        torch.cuda.empty_cache()
        print("inside base_loader: ", len(base_loader))

        for batch_idx, (inputs, targets) in enumerate(base_loader):
            if use_gpu:
                inputs, targets = inputs.cuda(), targets.cuda()

            # --- manifold-mixup classification loss ---
            lam = np.random.beta(params.alpha, params.alpha)
            f, outputs, target_a, target_b = model(inputs, targets,
                                                   mixup_hidden=True,
                                                   mixup_alpha=params.alpha,
                                                   lam=lam)
            loss = mixup_criterion(criterion, outputs, target_a, target_b, lam)
            train_loss += loss.item()
            optimizer.zero_grad()
            # Intentionally no optimizer.step() yet: gradients of the mixup
            # loss accumulate with the rotation loss below, then step once.
            loss.backward()
            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)
            # Mixup accuracy: credit each prediction proportionally to lam.
            correct += (lam * predicted.eq(target_a.data).cpu().sum().float()
                        + (1 - lam) * predicted.eq(target_b.data).cpu().sum().float())

            # --- self-supervised rotation loss ---
            # Take a random quarter of the batch; each chosen image
            # contributes all four rotations (assumes batch size >= 4).
            bs = inputs.size(0)
            inputs_ = []
            targets_ = []
            a_ = []
            indices = np.arange(bs)
            np.random.shuffle(indices)
            split_size = int(bs / 4)
            for j in indices[0:split_size]:
                # transpose(2, 1).flip(1) rotates an HxW image by 90 degrees.
                x90 = inputs[j].transpose(2, 1).flip(1)
                x180 = x90.transpose(2, 1).flip(1)
                x270 = x180.transpose(2, 1).flip(1)
                inputs_ += [inputs[j], x90, x180, x270]
                targets_ += [targets[j] for _ in range(4)]
                a_ += [torch.tensor(0), torch.tensor(1),
                       torch.tensor(2), torch.tensor(3)]
            inputs = torch.stack(inputs_, 0)
            targets = torch.stack(targets_, 0)
            a_ = torch.stack(a_, 0)
            if use_gpu:
                inputs = inputs.cuda()
                targets = targets.cuda()
                a_ = a_.cuda()

            rf, outputs = model(inputs)
            rotate_outputs = rotate_classifier(rf)
            rloss = criterion(rotate_outputs, a_)
            closs = criterion(outputs, targets)
            loss = (rloss + closs) / 2.0
            rotate_loss += rloss.item()
            loss.backward()
            optimizer.step()

            if batch_idx % 50 == 0:
                print('{0}/{1}'.format(batch_idx, len(base_loader)),
                      'Loss: %.3f | Acc: %.3f%% | RotLoss: %.3f ' %
                      (train_loss / (batch_idx + 1), 100. * correct / total,
                       rotate_loss / (batch_idx + 1)))

        if not os.path.isdir(params.checkpoint_dir):
            os.makedirs(params.checkpoint_dir)
        if (epoch % params.save_freq == 0) or (epoch == stop_epoch - 1):
            outfile = os.path.join(params.checkpoint_dir, '{:d}.tar'.format(epoch))
            torch.save({'epoch': epoch, 'state': model.state_dict()}, outfile)

        # --- per-epoch classification evaluation ---
        model.eval()
        with torch.no_grad():
            test_loss = 0
            correct = 0
            total = 0
            for batch_idx, (inputs, targets) in enumerate(base_loader_test):
                if use_gpu:
                    inputs, targets = inputs.cuda(), targets.cuda()
                f, outputs = model.forward(inputs)
                loss = criterion(outputs, targets)
                test_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total += targets.size(0)
                correct += predicted.eq(targets.data).cpu().sum()
            print('Loss: %.3f | Acc: %.3f%%' %
                  (test_loss / (batch_idx + 1), 100. * correct / total))

        # --- few-shot validation on the cached episodes ---
        if params.dct_status:
            valmodel = BaselineFinetune(model_dict[params.model + '_dct'],
                                        params.train_n_way, params.n_shot,
                                        loss_type='dist')
        else:
            valmodel = BaselineFinetune(model_dict[params.model],
                                        params.train_n_way, params.n_shot,
                                        loss_type='dist')
        valmodel.n_query = 15
        acc_all1, acc_all2, acc_all3 = [], [], []
        for i, x in enumerate(loader):
            if params.dct_status:
                x = x.view(-1, channels, image_size_dct, image_size_dct)
            else:
                x = x.view(-1, channels, image_size, image_size)
            if use_gpu:
                x = x.cuda()
            with torch.no_grad():
                f, scores = model(x)
            f = f.view(params.train_n_way, params.n_shot + valmodel.n_query, -1)
            # set_forward_adaptation returns scores at three fine-tune
            # checkpoints (presumably 100/200/300 steps, per the prints below).
            scores = valmodel.set_forward_adaptation(f.cpu())
            acc = []
            for each_score in scores:
                pred = each_score.data.cpu().numpy().argmax(axis=1)
                y = np.repeat(range(5), 15)
                acc.append(np.mean(pred == y) * 100)
            acc_all1.append(acc[0])
            acc_all2.append(acc[1])
            acc_all3.append(acc[2])
        print('Test Acc at 100= %4.2f%%' % (np.mean(acc_all1)))
        print('Test Acc at 200= %4.2f%%' % (np.mean(acc_all2)))
        print('Test Acc at 300= %4.2f%%' % (np.mean(acc_all3)))

        # Track the best checkpoint by the final-stage few-shot accuracy.
        if np.mean(acc_all3) > val_acc_best:
            val_acc_best = np.mean(acc_all3)
            bestfile = os.path.join(params.checkpoint_dir, 'best.tar')
            torch.save({'epoch': epoch,
                        'state': model.state_dict(),
                        'rotate': rotate_classifier.state_dict()}, bestfile)

    return model
def train_baseline(base_loader, base_loader_test, val_loader, model, start_epoch, stop_epoch, params, tmp):
    """Train a backbone with plain cross-entropy (no mixup, no rotation head).

    Each epoch: one pass of supervised training over ``base_loader``, a
    classification evaluation on ``base_loader_test``, and a few-shot
    evaluation on cached validation episodes. The best checkpoint (by the
    epoch-300 few-shot accuracy) is saved as ``best.tar``; periodic
    checkpoints are saved every ``params.save_freq`` epochs.

    Args:
        base_loader: training DataLoader yielding (images, labels).
        base_loader_test: DataLoader for the per-epoch classification eval.
        val_loader: DataLoader of few-shot validation episodes; its batches
            are cached to disk once so every epoch scores identical episodes.
        model: backbone whose forward returns (features, logits).
        start_epoch, stop_epoch: epoch range ``[start_epoch, stop_epoch)``.
        params: config namespace (checkpoint_dir, dataset, model, dct_status,
            channels, save_freq, train_n_way, n_shot).
        tmp: checkpoint dict (unused here; kept for signature parity with
            train_s2m2).

    Returns:
        The trained model.
    """
    channels = params.channels if params.dct_status else 3
    val_acc_best = 0.0
    if not os.path.isdir(params.checkpoint_dir):
        os.makedirs(params.checkpoint_dir)

    # Cache the validation episodes to disk once so that every epoch (and
    # every run against this checkpoint dir) scores the same fixed episodes.
    val_cache = params.checkpoint_dir + '/val_' + params.dataset + '.pt'
    if path.exists(val_cache):
        loader = torch.load(val_cache)
    else:
        loader = [x for x, _ in val_loader]
        torch.save(loader, val_cache)

    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = optim.Adam(model.parameters())
    print("stop_epoch", start_epoch, stop_epoch)

    for epoch in range(start_epoch, stop_epoch):
        print('\nEpoch: %d' % epoch)
        model.train()
        train_loss = 0
        correct = 0
        total = 0

        for batch_idx, (input_var, target_var) in enumerate(base_loader):
            if use_gpu:
                input_var, target_var = input_var.cuda(), target_var.cuda()
            f, outputs = model.forward(input_var)
            loss = criterion(outputs, target_var)
            train_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += target_var.size(0)
            correct += predicted.eq(target_var.data).cpu().sum()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if batch_idx % 50 == 0:
                print('{0}/{1}'.format(batch_idx, len(base_loader)),
                      'Loss: %.3f | Acc: %.3f%% ' %
                      (train_loss / (batch_idx + 1), 100. * correct / total))

        if not os.path.isdir(params.checkpoint_dir):
            os.makedirs(params.checkpoint_dir)
        if (epoch % params.save_freq == 0) or (epoch == stop_epoch - 1):
            outfile = os.path.join(params.checkpoint_dir, '{:d}.tar'.format(epoch))
            torch.save({'epoch': epoch, 'state': model.state_dict()}, outfile)

        # --- per-epoch classification evaluation ---
        model.eval()
        with torch.no_grad():
            test_loss = 0
            correct = 0
            total = 0
            for batch_idx, (inputs, targets) in enumerate(base_loader_test):
                if use_gpu:
                    inputs, targets = inputs.cuda(), targets.cuda()
                f, outputs = model.forward(inputs)
                loss = criterion(outputs, targets)
                test_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total += targets.size(0)
                correct += predicted.eq(targets.data).cpu().sum()
            print('Loss: %.3f | Acc: %.3f%%' %
                  (test_loss / (batch_idx + 1), 100. * correct / total))
        torch.cuda.empty_cache()

        # --- few-shot validation on the cached episodes ---
        # NOTE(review): unlike train_s2m2, the '_dct' model key is not used
        # here even when params.dct_status is set — confirm this is intended.
        valmodel = BaselineFinetune(model_dict[params.model],
                                    params.train_n_way, params.n_shot,
                                    loss_type='dist')
        valmodel.n_query = 15
        acc_all1, acc_all2, acc_all3 = [], [], []
        for i, x in enumerate(loader):
            if params.dct_status:
                x = x.view(-1, channels, image_size_dct, image_size_dct)
            else:
                x = x.view(-1, channels, image_size, image_size)
            if use_gpu:
                x = x.cuda()
            with torch.no_grad():
                f, scores = model(x)
            f = f.view(params.train_n_way, params.n_shot + valmodel.n_query, -1)
            # set_forward_adaptation returns scores at three fine-tune
            # checkpoints (presumably 100/200/300 steps, per the prints below).
            scores = valmodel.set_forward_adaptation(f.cpu())
            acc = []
            for each_score in scores:
                pred = each_score.data.cpu().numpy().argmax(axis=1)
                y = np.repeat(range(5), 15)
                acc.append(np.mean(pred == y) * 100)
            acc_all1.append(acc[0])
            acc_all2.append(acc[1])
            acc_all3.append(acc[2])
        print('Test Acc at 100= %4.2f%%' % (np.mean(acc_all1)))
        print('Test Acc at 200= %4.2f%%' % (np.mean(acc_all2)))
        print('Test Acc at 300= %4.2f%%' % (np.mean(acc_all3)))

        # Track the best checkpoint by the final-stage few-shot accuracy.
        if np.mean(acc_all3) > val_acc_best:
            val_acc_best = np.mean(acc_all3)
            bestfile = os.path.join(params.checkpoint_dir, 'best.tar')
            torch.save({'epoch': epoch, 'state': model.state_dict()}, bestfile)

    return model