def hyper_step(train_grad):
    """Estimate the hypergradient, and take an update with it."""
    zero_hypergrad(get_hyper_train)
    # The caller builds the training gradient with create_graph=True:
    #   xentropy_loss, regularized_loss = train_loss_func()
    #   train_grad = grad(regularized_loss, model.parameters(), create_graph=True)
    d_train_loss_d_w = gather_flat_grad(train_grad)

    # Compute gradients of the validation loss w.r.t. the weights/hypers.
    num_weights = sum(p.numel() for p in model.parameters())
    num_hypers = sum(p.numel() for p in get_hyper_train())
    d_val_loss_d_theta = torch.zeros(num_weights).cuda()
    model.zero_grad()
    val_loss = val_loss_func(hyperval_data)  # eval() is used inside val_loss_func
    d_val_loss_d_theta += gather_flat_grad(grad(val_loss, model.parameters()))  # Do we need create_graph=True or retain_graph=True?

    # Compute d/d lambda of (dL_v/dw . dL_t/dw)
    #   = dL_v/dw . d^2 L_t / (dw dlambda)
    # Here the preconditioner is simply d_val_loss_d_theta, i.e. a 0-term
    # Neumann approximation of the inverse Hessian.
    hypergrad = gather_flat_grad(grad(d_train_loss_d_w, get_hyper_train(),
                                      grad_outputs=d_val_loss_d_theta.detach().view(-1)))
    get_hyper_train()[0].grad = -hypergrad
    return val_loss, hypergrad.norm()
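# --- Illustrative sketch, not part of the original code. ---
# The helpers used above (gather_flat_grad, zero_hypergrad, store_hypergrad) are not shown
# in this listing; they presumably flatten / zero / write back gradients roughly as below.
# This is a minimal, assumed implementation for reference only.
import numpy as np
import torch


def gather_flat_grad(loss_grad):
    # Concatenate a sequence of per-parameter gradient tensors into one flat vector.
    return torch.cat([g.reshape(-1) for g in loss_grad])


def zero_hypergrad(get_hyper_train):
    # Zero the .grad buffers of every hyperparameter tensor in place.
    for p in get_hyper_train():
        if p.grad is not None:
            p.grad = p.grad * 0


def store_hypergrad(get_hyper_train, total_d_val_loss_d_lambda):
    # Unflatten a flat hypergradient vector back into the .grad slot of each hyperparameter.
    current_index = 0
    for p in get_hyper_train():
        p_num_params = int(np.prod(p.shape))
        p.grad = total_d_val_loss_d_lambda[current_index:current_index + p_num_params].view(p.shape)
        current_index += p_num_params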
def hyper_step(get_hyper_train, model, val_loss_func, val_loader, d_train_loss_d_w, elementary_lr, use_reg, args):
    """Estimate the hypergradient, and take an update with it.

    :param get_hyper_train: A function which returns the hyperparameters we want to tune.
    :param model: The model whose elementary parameters we tune.
    :param val_loss_func: A function which takes input x and output y, then returns the scalar valued loss.
    :param val_loader: A generator for input x, output y tuples.
    :param d_train_loss_d_w: The derivative of the training loss with respect to elementary parameters.
    :param elementary_lr: The elementary optimizer's learning rate, used to scale the Neumann series.
    :param use_reg: Whether the validation loss depends directly on the hyperparameters (e.g. a regularizer).
    :param args: The argparse namespace; args.use_cg and args.num_neumann_terms select the inverse-Hessian approximation.
    :return: The scalar valued validation loss and the hypergradient norm.
    """
    zero_hypergrad(get_hyper_train)
    d_train_loss_d_w = gather_flat_grad(d_train_loss_d_w)

    # Compute gradients of the validation loss w.r.t. the weights/hypers.
    num_weights = sum(p.numel() for p in model.parameters())
    num_hypers = sum(p.numel() for p in get_hyper_train())
    d_val_loss_d_theta = torch.zeros(num_weights).cuda()
    direct_grad = torch.zeros(num_hypers).cuda()
    model.train(), model.zero_grad()
    for batch_idx, (x, y) in enumerate(val_loader):
        val_loss = val_loss_func(x, y)
        d_val_loss_d_theta += gather_flat_grad(grad(val_loss, model.parameters(), retain_graph=use_reg))
        if use_reg:
            direct_grad += gather_flat_grad(grad(val_loss, get_hyper_train()))
            direct_grad[direct_grad != direct_grad] = 0  # zero out NaNs
        break  # a single validation batch is used per hyper step

    # Build the preconditioner, an approximation of (dL_v/dw) @ H^{-1}.
    if not args.use_cg:
        preconditioner = neumann_hyperstep_preconditioner(d_val_loss_d_theta, d_train_loss_d_w,
                                                          elementary_lr, args.num_neumann_terms)
    else:
        def A_vector_multiply_func(vec):
            # Rank-one outer-product surrogate for the Hessian, built from the validation gradient.
            p1 = d_val_loss_d_theta.view(1, -1) @ vec.view(-1, 1)
            p2 = d_val_loss_d_theta.view(-1, 1) @ p1
            return p2.view(1, -1, 1)

        preconditioner, _ = cg_batch(A_vector_multiply_func, d_val_loss_d_theta.view(1, -1, 1))

    # Compute d/d lambda of (dL_v/dw . dL_t/dw)
    #   = dL_v/dw . d^2 L_t / (dw dlambda)
    indirect_grad = gather_flat_grad(grad(d_train_loss_d_w, get_hyper_train(),
                                          grad_outputs=preconditioner.view(-1)))
    hypergrad = direct_grad + indirect_grad
    store_hypergrad(get_hyper_train, hypergrad)
    return val_loss, hypergrad.norm()
def A_vector_multiply_func(vec):
    val = gather_flat_grad(
        grad(d_train_loss_d_w, model.parameters(), grad_outputs=vec.view(-1), retain_graph=True))
    # val_2 = gather_flat_grad(grad(d_train_loss_d_w, model.parameters(), grad_outputs=val.view(-1), retain_graph=True))
    # return val_2.view(1, -1, 1)
    return val.view(1, -1, 1)
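# --- Illustrative sketch, not part of the original code. ---
# A_vector_multiply_func relies on the double-backward identity
#   grad(dL/dw, w, grad_outputs=v) == H @ v,
# which holds when dL/dw was built with create_graph=True. The toy check below
# (assumed quadratic loss; all names are of my choosing) verifies this against an explicit Hessian.
def _check_hessian_vector_product():
    import torch
    from torch.autograd import grad

    w = torch.randn(5, requires_grad=True)
    A = torch.randn(5, 5)
    A = A @ A.t() + torch.eye(5)            # symmetric positive definite "Hessian"
    loss = 0.5 * w @ A @ w                  # quadratic loss, so the Hessian is exactly A

    d_loss_d_w = grad(loss, w, create_graph=True)[0]
    v = torch.randn(5)
    hvp = grad(d_loss_d_w, w, grad_outputs=v, retain_graph=True)[0]
    print(torch.allclose(hvp, A @ v, atol=1e-5))  # expect True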
def hyper_step(elementary_lr, do_true_inverse=False):
    """Estimate the hypergradient, and take an update with it."""
    zero_hypergrad(get_hyper_train)
    num_weights = sum(p.numel() for p in model.parameters())
    num_hypers = sum(p.numel() for p in get_hyper_train())

    # First compute the training loss gradient on a single batch, keeping the graph
    # so that Hessian-vector products can be taken later.
    d_train_loss_d_w = torch.zeros(num_weights).cuda()
    model.train(), model.zero_grad()
    for batch_idx, (x, y) in enumerate(train_loader):
        train_loss, _ = train_loss_func(x, y)
        optimizer.zero_grad()
        d_train_loss_d_w += gather_flat_grad(grad(train_loss, model.parameters(), create_graph=True))
        break
    optimizer.zero_grad()

    # Compute gradients of the validation loss w.r.t. the weights.
    d_val_loss_d_theta = torch.zeros(num_weights).cuda()
    direct_grad = torch.zeros(num_hypers).cuda()
    model.train(), model.zero_grad()
    for batch_idx, (x, y) in enumerate(val_loader):
        val_loss = val_loss_func(x, y)
        optimizer.zero_grad()
        d_val_loss_d_theta += gather_flat_grad(grad(val_loss, model.parameters(), retain_graph=False))
        break

    # Use a truncated Neumann series to approximate the inverse-Hessian-vector product.
    preconditioner = neumann_hyperstep_preconditioner(d_val_loss_d_theta, d_train_loss_d_w,
                                                      elementary_lr, args.num_neumann_terms, model)

    # Compute d/d lambda of (dL_v/dw . dL_t/dw)
    #   = dL_v/dw . d^2 L_t / (dw dlambda)
    indirect_grad = gather_flat_grad(
        grad(d_train_loss_d_w, get_hyper_train(), grad_outputs=preconditioner.view(-1)))
    # The direct gradient is zero here because the validation data is not augmented.
    hypergrad = direct_grad + indirect_grad

    zero_hypergrad(get_hyper_train)
    store_hypergrad(get_hyper_train, -hypergrad)
    return val_loss, hypergrad.norm()
def neumann_hyperstep_preconditioner(d_val_loss_d_theta, d_train_loss_d_w, elementary_lr, num_neumann_terms, model):
    """Approximate (dL_v/dw) @ H^{-1} with a truncated Neumann series, where H is the
    Hessian of the training loss with respect to the weights."""
    preconditioner = d_val_loss_d_theta.detach()
    counter = preconditioner
    # Do the fixed point iteration to approximate the vector-inverse-Hessian product.
    for i in range(num_neumann_terms):
        old_counter = counter
        # This updates counter to counter @ (I - lr * H) = counter - lr * (counter @ H).
        hessian_term = gather_flat_grad(
            grad(d_train_loss_d_w, model.parameters(),
                 grad_outputs=counter.contiguous().view(-1), retain_graph=True))
        counter = old_counter - elementary_lr * hessian_term
        preconditioner = preconditioner + counter
    return elementary_lr * preconditioner
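# --- Illustrative sketch, not part of the original code. ---
# neumann_hyperstep_preconditioner approximates v @ H^{-1} with the truncated series
#   lr * sum_{k=0}^{K} v @ (I - lr*H)^k,
# which converges when the spectral radius of (I - lr*H) is below 1. The assumed toy check
# below compares that recursion against a direct matrix inverse on a small SPD matrix.
def _check_neumann_inverse():
    import torch

    torch.manual_seed(0)
    n, lr, K = 6, 0.2, 200
    H = torch.randn(n, n)
    H = H @ H.t() / n + torch.eye(n)        # SPD stand-in for the training-loss Hessian
    v = torch.randn(n)

    approx, counter = torch.zeros(n), v.clone()
    for _ in range(K + 1):
        approx += counter                   # accumulate v @ (I - lr*H)^k
        counter = counter - lr * (counter @ H)
    approx = lr * approx

    exact = v @ torch.inverse(H)
    print(torch.allclose(approx, exact, atol=1e-4))  # expect True for this well-conditioned H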
def KFAC_optimize(args, model, train_loader, val_loader, hyper_optimizer, kfac_opt, KFAC_damping, epoch_h):
    # Placeholder for the hypergradient accumulated over batches.
    total_d_val_loss_d_lambda = torch.zeros(get_hyper_train(args, model).size(0))
    if args.cuda:
        total_d_val_loss_d_lambda = total_d_val_loss_d_lambda.cuda()

    # Calculate v1 from the paper, i.e. dL_v / dw.
    num_weights = sum(p.numel() for p in model.parameters())
    d_val_loss_d_theta = torch.zeros(num_weights).cuda()
    model.train()
    for batch_idx, (x, y) in enumerate(val_loader):
        model.zero_grad()
        x, y = prepare_data(args, x, y)
        val_loss, _ = batch_loss(args, model, x, y, model, val_loss_func)
        val_loss_grad = grad(val_loss, model.parameters())
        d_val_loss_d_theta += gather_flat_grad(val_loss_grad)
        if batch_idx >= args.val_batch_num:
            break
    d_val_loss_d_theta /= (batch_idx + 1)  # TODO (@Mo): This is very bad, because it does not account for a potentially uneven batch at the end

    # Calculate the preconditioner, i.e. v1 * (inverse Hessian approximation)
    # [the orange term in Figure 2], using the per-layer K-FAC natural gradient.
    assert args.hessian == 'KFAC', f"Passed {args.hessian}, not a valid choice. Need to choose KFAC"
    flat_pre_conditioner = torch.zeros(num_weights).cuda()
    for batch_idx, (x, y) in enumerate(train_loader):
        model.train()
        model.zero_grad(), hyper_optimizer.zero_grad()
        x, y = prepare_data(args, x, y)
        train_loss, _ = batch_loss(args, model, x, y, model, train_loss_func)
        # TODO (JON): Probably don't recompute - use create_graph and retain_graph?
        d_train_loss_d_theta = grad(train_loss, model.parameters(), create_graph=True)
        flat_d_train_loss_d_theta = gather_flat_grad(d_train_loss_d_theta)

        # Apply the K-FAC block-diagonal inverse to each Linear/Conv2d block of d_val_loss_d_theta.
        current = 0
        for m in model.modules():
            if m.__class__.__name__ in ['Linear', 'Conv2d']:
                if m.__class__.__name__ == 'Conv2d':
                    size0, size1 = m.weight.size(0), m.weight.view(m.weight.size(0), -1).size(1)
                else:
                    size0, size1 = m.weight.size(0), m.weight.size(1)
                mod_size1 = size1 + 1 if m.bias is not None else size1
                shape = (size0, mod_size1)
                size = size0 * mod_size1
                pre_conditioner = kfac_opt._get_natural_grad(
                    m, d_val_loss_d_theta[current:current + size].view(shape), KFAC_damping)
                flat_pre_conditioner[current:current + size] = gather_flat_grad(pre_conditioner)
                current += size

        # Backprop the preconditioned vector through dL_t/dw to get the mixed second-derivative term.
        model.zero_grad(), hyper_optimizer.zero_grad()
        flat_d_train_loss_d_theta.backward(flat_pre_conditioner)
        total_d_val_loss_d_lambda -= get_hyper_train(args, model).grad
        if batch_idx >= args.train_batch_num:
            break
    total_d_val_loss_d_lambda /= (batch_idx + 1)  # TODO (@Mo): This is very bad, because it does not account for a potentially uneven batch at the end

    # Compute the direct gradient of the val loss w.r.t. lambda. This is usually 0.
    direct_d_val_loss_d_lambda = torch.zeros(get_hyper_train(args, model).size(0))
    if args.cuda:
        direct_d_val_loss_d_lambda = direct_d_val_loss_d_lambda.cuda()
    model.train()
    for batch_idx, (x_val, y_val) in enumerate(val_loader):
        model.zero_grad(), hyper_optimizer.zero_grad()
        x_val, y_val = prepare_data(args, x_val, y_val)
        val_loss, _ = batch_loss(args, model, x_val, y_val, model, val_loss_func)
        val_loss_grad = grad(val_loss, get_hyper_train(args, model), allow_unused=True)
        if val_loss_grad is not None and val_loss_grad[0] is not None:
            direct_d_val_loss_d_lambda += gather_flat_grad(val_loss_grad)
        else:
            break
        if batch_idx >= args.val_batch_num:
            break
    direct_d_val_loss_d_lambda /= (batch_idx + 1)  # TODO (@Mo): This is very bad, because it does not account for a potentially uneven batch at the end

    # Take the hyperparameter step.
    get_hyper_train(args, model).grad = direct_d_val_loss_d_lambda + total_d_val_loss_d_lambda
    print("weight={}, update={}".format(get_hyper_train(args, model).norm(),
                                        get_hyper_train(args, model).grad.norm()))
    hyper_optimizer.step()
    model.zero_grad(), hyper_optimizer.zero_grad()
    return get_hyper_train(args, model), get_hyper_train(args, model).grad
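# --- Illustrative sketch, not part of the original code. ---
# The K-FAC loop above slices the flat d_val_loss_d_theta into per-module blocks of shape
# (out_features, in_features + 1) when a bias is present: the weight columns plus one bias column.
# The assumed check below confirms that those block sizes tile the full flat parameter vector
# for a model made only of Conv2d/Linear layers.
def _check_kfac_block_sizes():
    import torch.nn as nn

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Linear(8, 4))
    num_weights = sum(p.numel() for p in model.parameters())
    current = 0
    for m in model.modules():
        if m.__class__.__name__ in ['Linear', 'Conv2d']:
            size0 = m.weight.size(0)
            if m.__class__.__name__ == 'Conv2d':
                size1 = m.weight.view(m.weight.size(0), -1).size(1)
            else:
                size1 = m.weight.size(1)
            mod_size1 = size1 + 1 if m.bias is not None else size1
            current += size0 * mod_size1
    print(current == num_weights)  # expect True: the blocks account for every parameter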
def hyper_step(elementary_lr, args, model, train_loader, val_loader, augment_net, reweighting_net,
               optimizer, use_reg, reg_anneal_epoch, stop_reg_epoch, graph_iter, device,
               do_true_inverse=False):
    """Estimate the hypergradient, and take an update with it.

    :param elementary_lr: The elementary optimizer's learning rate, used to scale the Neumann series.
    :param model: The model whose elementary parameters we tune.
    :param train_loader: A generator for training (x, y) batches.
    :param val_loader: A generator for validation (x, y) batches.
    :param augment_net: The data-augmentation hypernetwork being tuned.
    :param reweighting_net: The loss-reweighting hypernetwork being tuned.
    :param optimizer: The elementary optimizer (used to zero gradients and read step sizes).
    :param use_reg: Whether the validation loss depends directly on the hyperparameters.
    :param do_true_inverse: If True, form and pseudo-invert the full Hessian (only feasible for tiny models).
    :return: The scalar valued validation loss, the hypergradient norm, and the updated graph_iter counter.
    """
    zero_hypergrad(get_hyper_train, args, model, augment_net, reweighting_net)
    num_weights = sum(p.numel() for p in model.parameters())
    num_hypers = sum(p.numel() for p in get_hyper_train(args, model, augment_net, reweighting_net))
    print(f"num_weights : {num_weights}, num_hypers : {num_hypers}")

    # Compute the training loss gradient on a single batch, keeping the graph
    # so that Hessian-vector products can be taken below.
    d_train_loss_d_w = torch.zeros(num_weights).to(device)
    model.train(), model.zero_grad()
    for batch_idx, (x, y) in enumerate(train_loader):
        train_loss, _, graph_iter = train_loss_func(x, y, args, model, augment_net,
                                                    reweighting_net, graph_iter, device)
        optimizer.zero_grad()
        d_train_loss_d_w += gather_flat_grad(grad(train_loss, model.parameters(), create_graph=True))
        break  # a single training batch is used per hyper step
    optimizer.zero_grad()

    # Compute gradients of the validation loss w.r.t. the weights/hypers.
    d_val_loss_d_theta = torch.zeros(num_weights).to(device)
    direct_grad = torch.zeros(num_hypers).to(device)
    model.train(), model.zero_grad()
    for batch_idx, (x, y) in enumerate(val_loader):
        val_loss = val_loss_func(x, y, args, model, augment_net, use_reg,
                                 reg_anneal_epoch, stop_reg_epoch, device)
        optimizer.zero_grad()
        d_val_loss_d_theta += gather_flat_grad(grad(val_loss, model.parameters(), retain_graph=use_reg))
        if use_reg:
            direct_grad += gather_flat_grad(
                grad(val_loss, get_hyper_train(args, model, augment_net, reweighting_net), allow_unused=True))
            direct_grad[direct_grad != direct_grad] = 0  # zero out NaNs
        break  # a single validation batch is used per hyper step

    # Build the preconditioner: an approximation of (dL_v/dw) @ H^{-1}.
    preconditioner = d_val_loss_d_theta
    if do_true_inverse:
        # Form the full Hessian explicitly and pseudo-invert it (only for tiny models).
        hessian = torch.zeros(num_weights, num_weights).to(device)
        for i in range(num_weights):
            hess_row = gather_flat_grad(grad(d_train_loss_d_w[i], model.parameters(), retain_graph=True))
            hessian[i] = hess_row
        inv_hessian = torch.pinverse(hessian)
        preconditioner = d_val_loss_d_theta @ inv_hessian
    elif not args.use_cg:
        preconditioner = neumann_hyperstep_preconditioner(d_val_loss_d_theta, d_train_loss_d_w,
                                                          elementary_lr, args.num_neumann_terms, model)
    else:
        def A_vector_multiply_func(vec):
            # Hessian-vector product via double backward through d_train_loss_d_w.
            val = gather_flat_grad(grad(d_train_loss_d_w, model.parameters(),
                                        grad_outputs=vec.view(-1), retain_graph=True))
            return val.view(1, -1, 1)

        if args.num_neumann_terms > 0:
            preconditioner, _ = cg_batch(A_vector_multiply_func, d_val_loss_d_theta.view(1, -1, 1),
                                         maxiter=args.num_neumann_terms)

    if args.save_hessian and do_true_inverse:
        # Diagnostic: save the true inverse Hessian and its truncated Neumann approximations.
        save_hessian(inv_hessian, name='true_inv')
        new_hessian = torch.zeros(inv_hessian.shape).to(device)
        for param_group in optimizer.param_groups:
            cur_step_size = param_group['step_size']
            cur_bias_correction = param_group['bias_correction']
            print(f'size: {cur_step_size}')
            break
        for i in range(10):
            hess_term = torch.eye(inv_hessian.shape[0]).to(device)
            norm_1 = torch.norm(torch.eye(inv_hessian.shape[0]).to(device), p=2)
            norm_2 = torch.norm(hessian, p=2)
            for j in range(i):
                hess_term = hess_term @ (torch.eye(inv_hessian.shape[0]).to(device) - norm_1 / norm_2 * hessian)
            new_hessian += hess_term
            save_hessian(new_hessian, name='neumann_' + str(i))

    # Compute d/d lambda of (dL_v/dw . dL_t/dw)
    #   = dL_v/dw . d^2 L_t / (dw dlambda)
    indirect_grad = gather_flat_grad(
        grad(d_train_loss_d_w, get_hyper_train(args, model, augment_net, reweighting_net),
             grad_outputs=preconditioner.view(-1)))
    hypergrad = direct_grad + indirect_grad

    zero_hypergrad(get_hyper_train, args, model, augment_net, reweighting_net)
    store_hypergrad(get_hyper_train, -hypergrad, args, model, augment_net, reweighting_net)
    return val_loss, hypergrad.norm(), graph_iter
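# --- Illustrative sketch, not part of the original code. ---
# An end-to-end check of the hypergradient computed above, on an assumed toy bilevel problem
# where the answer is known in closed form:
#   L_t(w, lam) = 0.5*||w - lam||^2   =>   w*(lam) = lam,  Hessian H = I
#   L_v(w)      = 0.5*||w - c||^2     =>   dL_v/dlam = w*(lam) - c
# With H = I and lr = 1, the Neumann preconditioner reduces to dL_v/dw itself.
def _check_toy_hypergradient():
    import torch
    from torch.autograd import grad

    torch.manual_seed(0)
    c = torch.randn(3)
    lam = torch.randn(3, requires_grad=True)
    w = lam.detach().clone().requires_grad_(True)   # start at the trained optimum w* = lam

    train_loss = 0.5 * ((w - lam) ** 2).sum()
    d_train_loss_d_w = grad(train_loss, w, create_graph=True)[0]

    val_loss = 0.5 * ((w - c) ** 2).sum()
    preconditioner = grad(val_loss, w)[0].detach()  # = dL_v/dw, since H^{-1} = I here

    # Mixed second-derivative term: (dL_v/dw) @ d^2 L_t / (dw dlam).
    indirect = grad(d_train_loss_d_w, lam, grad_outputs=preconditioner)[0]
    hypergrad = -indirect  # sign flip as in store_hypergrad(..., -hypergrad); direct term is zero here

    print(torch.allclose(hypergrad, lam.detach() - c))  # expect True: matches the analytic w*(lam) - c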
def main():
    # Loop over epochs.
    lr = args.lr
    start_time = time.time()
    train_epoch = 0
    global_step = 0
    # At any point you can hit Ctrl + C to break out of training early.
    try:
        param_optimizer = optim.Adam(model.parameters(), lr=args.lr)
        hyper_optimizer = optim.Adam(get_hyper_train(), lr=args.hyper_lr)
        while train_epoch < args.epochs:
            epoch_start_time = time.time()
            model.train()
            xentropy_loss, regularized_loss = train_loss_func()

            # Explicitly zeroing the gradients -- why is this required? Why not model.zero_grad()?
            for p in model.parameters():
                if p.grad is not None:
                    p.grad = p.grad * 0
            param_optimizer.zero_grad()

            # Build the training gradient with create_graph=True so the hyper step can
            # take Hessian-vector products through it.
            train_grad = grad(regularized_loss, model.parameters(), create_graph=True)

            hyper_optimizer.zero_grad()
            val_loss, grad_norm = hyper_step(train_grad)
            hyper_optimizer.step()

            # Replace the original gradient for the elementary optimizer step.
            current_index = 0
            flat_train_grad = gather_flat_grad(train_grad)
            for p in model.parameters():
                p_num_params = np.prod(p.shape)
                p.grad = flat_train_grad[current_index: current_index + p_num_params].view(p.shape)
                current_index += p_num_params
            param_optimizer.step()

            cur_loss = xentropy_loss
            elapsed = time.time() - start_time
            if global_step % 100 == 0:
                val_loss = evaluate(val_data, test_batch_size)
                test_loss = evaluate(test_data, test_batch_size)
                hparam_dict = make_hparam_dict()
                iteration_dict = {'iteration': global_step,
                                  'time': time.time() - start_time,
                                  'train_loss': cur_loss.item(),
                                  'val_loss': val_loss,
                                  'train_ppl': math.exp(cur_loss.item()),
                                  'val_ppl': math.exp(val_loss),
                                  'test_loss': test_loss,
                                  'test_ppl': math.exp(test_loss),
                                  **hparam_dict}
                logger.write('iteration', iteration_dict)
                hparam_string = ' | '.join(['{}: {}'.format(key, value) for (key, value) in hparam_dict.items()])
                print('| epoch {:3d} | step {} | lr {:05.5f} | ms/batch {:5.2f} | '
                      'loss {:5.2f} | ppl {:8.2f} | val_loss: {:6.2f} | val ppl: {:6.2f} | '
                      'test_loss: {:6.2f} | test ppl: {:6.2f} | {}'.format(
                          train_epoch, global_step, param_optimizer.param_groups[0]['lr'],
                          elapsed * 1000 / args.log_interval, cur_loss, math.exp(cur_loss),
                          val_loss, math.exp(val_loss), test_loss, math.exp(test_loss), hparam_string))
                start_time = time.time()
            global_step += 1
            train_epoch += 1
    except KeyboardInterrupt:
        print('-' * 89)
        print('Exiting from training early')
        sys.stdout.flush()
def experiment(args):
    # Setup the random seeds
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)

    # Load the baseline model
    args.load_baseline_checkpoint = '/h/lorraine/PycharmProjects/CG_IFT_test/baseline_checkpoints/cifar10_resnet18_sgdm_lr0.1_wd0.0005_aug1.pt'
    args.load_finetune_checkpoint = None  # TODO: Make it load the augment net if this is provided
    model, train_loader, val_loader, test_loader, augment_net, reweighting_net, checkpoint = get_models(args)

    # Load the logger
    from train_augment_net_multiple import load_logger, get_id
    csv_logger, test_id = load_logger(args)
    args.save_loc = './finetuned_checkpoints/' + get_id(args)

    # Hyperparameter access functions
    def get_hyper_train():
        if args.use_augment_net and args.use_reweighting_net:
            return list(augment_net.parameters()) + list(reweighting_net.parameters())
        elif args.use_augment_net:
            return augment_net.parameters()
        elif args.use_reweighting_net:
            return reweighting_net.parameters()

    def get_hyper_train_flat():
        if args.use_augment_net and args.use_reweighting_net:
            return torch.cat([torch.cat([p.view(-1) for p in augment_net.parameters()]),
                              torch.cat([p.view(-1) for p in reweighting_net.parameters()])])
        elif args.use_reweighting_net:
            return torch.cat([p.view(-1) for p in reweighting_net.parameters()])
        elif args.use_augment_net:
            return torch.cat([p.view(-1) for p in augment_net.parameters()])

    # Setup the optimizers
    if args.load_baseline_checkpoint is not None:
        args.lr = args.lr * 0.2 * 0.2 * 0.2
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, nesterov=True, weight_decay=args.wdecay)
    scheduler = MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2)
    hyper_optimizer = optim.Adam(get_hyper_train(), lr=1e-3)
    hyper_scheduler = MultiStepLR(hyper_optimizer, milestones=[40, 100, 140], gamma=0.2)

    graph_iter = 0

    def train_loss_func(x, y):
        x, y = x.cuda(), y.cuda()
        reg = 0.
        if args.use_augment_net:
            x = augment_net(x, class_label=y)
        pred = model(x)
        xentropy_loss = F.cross_entropy(pred, y, reduction='none')
        if args.use_reweighting_net:
            loss_weights = reweighting_net(x)  # TODO: Or reweighting_net(augment_x) ??
            loss_weights = loss_weights.squeeze()
            loss_weights = F.sigmoid(loss_weights / 10.0) * 2.0 + 0.1
            # TODO: Want loss_weight vs xentropy_loss
            nonlocal graph_iter
            if graph_iter % 100 == 0:
                # Periodically plot per-example loss weights against their cross-entropy losses.
                import matplotlib.pyplot as plt
                np_loss = xentropy_loss.data.cpu().numpy()
                np_weight = loss_weights.data.cpu().numpy()
                for i in range(10):
                    class_indices = (y == i).cpu().numpy()
                    class_indices = [val * ind for val, ind in enumerate(class_indices) if val != 0]
                    plt.scatter(np_loss[class_indices], np_weight[class_indices], alpha=0.5, label=str(i))
                plt.ylim([np.min(np_weight) / 2.0, np.max(np_weight) * 2.0])
                plt.xlim([np.min(np_loss) / 2.0, np.max(np_loss) * 2.0])
                plt.yscale('log')
                plt.xscale('log')
                plt.axhline(1.0, c='k')
                plt.ylabel("loss_weights")
                plt.xlabel("xentropy_loss")
                plt.legend()
                plt.savefig("images/aaaa_lossWeightvsEntropy.pdf")
                plt.clf()
            xentropy_loss = xentropy_loss * loss_weights
        graph_iter += 1
        xentropy_loss = xentropy_loss.mean() + reg
        return xentropy_loss, pred

    use_reg = args.use_augment_net
    reg_anneal_epoch = 0
    stop_reg_epoch = 200
    if args.reg_weight == 0:
        use_reg = False

    def val_loss_func(x, y):
        x, y = x.cuda(), y.cuda()
        pred = model(x)
        xentropy_loss = F.cross_entropy(pred, y)
        reg = 0
        if args.use_augment_net:
            if use_reg:
                # Regularize the augmentation network: keep the mean augmented image close to
                # the original while encouraging diversity across samples.
                num_sample = 10
                xs = torch.zeros(num_sample, x.shape[0], x.shape[1], x.shape[2], x.shape[3]).cuda()
                for i in range(num_sample):
                    xs[i] = augment_net(x, class_label=y)
                xs_diffs = torch.abs(torch.mean(xs, dim=0) - x)
                diff_loss = torch.mean(xs_diffs)
                stds = torch.std(xs, dim=0)
                entrop_loss = -torch.mean(stds)
                # TODO: Remember to add the direct grad back in to hyper_step
                reg = args.reg_weight * (diff_loss + entrop_loss)
            else:
                reg = 0
        if reg_anneal_epoch >= stop_reg_epoch:
            reg *= 0
        return xentropy_loss + reg

    def test(loader, do_test_augment=True, num_augment=10):
        model.eval()  # Change model to 'eval' mode (BN uses moving mean/var).
        correct, total = 0., 0.
        losses = []
        for images, labels in loader:
            images, labels = images.cuda(), labels.cuda()
            with torch.no_grad():
                pred = model(images)
                if do_test_augment:
                    if args.use_augment_net and args.num_neumann_terms >= 0:
                        # Average predictions over several augmented copies of the batch.
                        shape_0, shape_1 = pred.shape[0], pred.shape[1]
                        pred = pred.view(1, shape_0, shape_1)  # batch size, num_classes
                        for _ in range(num_augment):
                            pred = torch.cat((pred, model(augment_net(images)).view(1, shape_0, shape_1)))
                        pred = torch.mean(pred, dim=0)
                xentropy_loss = F.cross_entropy(pred, labels)
                losses.append(xentropy_loss.item())
            pred = torch.max(pred.data, 1)[1]
            total += labels.size(0)
            correct += (pred == labels).sum().item()
        avg_loss = float(np.mean(losses))
        acc = correct / total
        model.train()
        return avg_loss, acc

    init_time = time.time()
    val_loss, val_acc = test(val_loader)
    test_loss, test_acc = test(test_loader)
    print(f"Initial Val Loss: {val_loss, val_acc}")
    print(f"Initial Test Loss: {test_loss, test_acc}")

    iteration = 0
    for epoch in range(0, args.num_finetune_epochs):
        reg_anneal_epoch = epoch
        xentropy_loss_avg = 0.
        total_val_loss, val_loss = 0., 0.
        correct = 0.
        total = 0.
        weight_norm, grad_norm = .0, .0

        progress_bar = tqdm(train_loader)
        num_tune_hyper = 45000 / 5000  # 45,000 train / 5,000 val examples: take a hyper step every 9th elementary step
        hyper_num = 0
        for i, (images, labels) in enumerate(progress_bar):
            progress_bar.set_description('Finetune Epoch ' + str(epoch))
            images, labels = images.cuda(), labels.cuda()

            xentropy_loss, pred = train_loss_func(images, labels)
            xentropy_loss_avg += xentropy_loss.item()

            # Explicitly zero the elementary gradients, then build the training gradient
            # with create_graph=True so the hyper step can differentiate through it.
            for p in model.parameters():
                if p.grad is not None:
                    p.grad = p.grad * 0
            train_grad = grad(xentropy_loss, model.parameters(), create_graph=True)

            # if args.num_neumann_terms >= 0:  # if this is less than 0, then don't do hyper_steps
            if i % num_tune_hyper == 0:
                cur_lr = 1.0
                for param_group in optimizer.param_groups:
                    cur_lr = param_group['lr']
                    break
                val_loss, grad_norm = hyper_step(get_hyper_train, model, val_loss_func, val_loader,
                                                 train_grad, cur_lr, use_reg, args)
                hyper_optimizer.step()
                weight_norm = get_hyper_train_flat().norm()
                total_val_loss += val_loss.item()
                hyper_num += 1

            # Replace the original gradient for the elementary optimizer step.
            current_index = 0
            flat_train_grad = gather_flat_grad(train_grad)
            for p in model.parameters():
                p_num_params = np.prod(p.shape)
                p.grad = flat_train_grad[current_index: current_index + p_num_params].view(p.shape)
                current_index += p_num_params
            optimizer.step()
            iteration += 1

            # Calculate running average of accuracy
            pred = torch.max(pred.data, 1)[1]
            total += labels.size(0)
            correct += (pred == labels.data).sum().item()
            accuracy = correct / total

            progress_bar.set_postfix(
                train='%.4f' % (xentropy_loss_avg / (i + 1)),
                val='%.4f' % (total_val_loss / max(hyper_num, 1)),
                acc='%.4f' % accuracy,
                weight='%.3f' % weight_norm,
                update='%.3f' % grad_norm)

            if i % (num_tune_hyper ** 2) == 0 and args.use_augment_net:
                from train_augment_net_graph import save_images
                if args.do_diagnostic:
                    save_images(images, labels, augment_net, args)
                saver(epoch, model, optimizer, augment_net, reweighting_net, hyper_optimizer, args.save_loc)
                val_loss, val_acc = test(val_loader)
                csv_logger.writerow({'epoch': str(epoch),
                                     'train_loss': str(xentropy_loss_avg / (i + 1)),
                                     'train_acc': str(accuracy),
                                     'val_loss': str(val_loss), 'val_acc': str(val_acc),
                                     'test_loss': str(test_loss), 'test_acc': str(test_acc),
                                     'run_time': time.time() - init_time, 'iteration': iteration})

        val_loss, val_acc = test(val_loader)
        test_loss, test_acc = test(test_loader)
        tqdm.write('val loss: {:6.4f} | val acc: {:6.4f} | test loss: {:6.4f} | test_acc: {:6.4f}'.format(
            val_loss, val_acc, test_loss, test_acc))
        scheduler.step(epoch)  # , hyper_scheduler.step(epoch)
        csv_logger.writerow({'epoch': str(epoch),
                             'train_loss': str(xentropy_loss_avg / (i + 1)),
                             'train_acc': str(accuracy),
                             'val_loss': str(val_loss), 'val_acc': str(val_acc),
                             'test_loss': str(test_loss), 'test_acc': str(test_acc),
                             'run_time': time.time() - init_time, 'iteration': iteration})