# Outer loop with an implicit (iMAML-style) meta-gradient, approximated with a
# truncated Neumann series (see `neumann_approx`).
def outer_loop(self, batch, is_train):
    train_inputs, train_targets, test_inputs, test_targets = self.unpack_batch(batch)

    loss_log = 0
    acc_log = 0
    grad_list = []
    loss_list = []

    for (train_input, train_target, test_input, test_target) in zip(
            train_inputs, train_targets, test_inputs, test_targets):
        # The inner adaptation is not differentiated through: the meta-gradient
        # comes from implicit differentiation instead.
        with higher.innerloop_ctx(self.network, self.inner_optimizer,
                                  track_higher_grads=False) as (fmodel, diffopt):
            for step in range(self.args.n_inner):
                self.inner_loop(fmodel, diffopt, train_input, train_target)

            train_logit = fmodel(train_input)
            in_loss = F.cross_entropy(train_logit, train_target)

            test_logit = fmodel(test_input)
            outer_loss = F.cross_entropy(test_logit, test_target)
            loss_log += outer_loss.item() / self.batch_size

            with torch.no_grad():
                acc_log += get_accuracy(test_logit, test_target).item() / self.batch_size

            if is_train:
                params = list(fmodel.parameters(time=-1))
                in_grad = torch.nn.utils.parameters_to_vector(
                    torch.autograd.grad(in_loss, params, create_graph=True))
                outer_grad = torch.nn.utils.parameters_to_vector(
                    torch.autograd.grad(outer_loss, params))
                implicit_grad = self.neumann_approx(in_grad, outer_grad, params)
                grad_list.append(implicit_grad)
                loss_list.append(outer_loss.item())

    if is_train:
        self.outer_optimizer.zero_grad()
        weight = torch.ones(len(grad_list))
        weight = weight / torch.sum(weight)
        grad = mix_grad(grad_list, weight)
        grad_log = apply_grad(self.network, grad)
        self.outer_optimizer.step()
        return loss_log, acc_log, grad_log
    else:
        return loss_log, acc_log
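# The helper `self.neumann_approx` is not shown above. The following is a
# minimal, self-contained sketch of what such a helper could look like,
# assuming it approximates the implicit meta-gradient (I + H/lam)^{-1} * outer_grad
# with a truncated Neumann series, where H is the Hessian of the inner loss at
# the adapted parameters. The names `n_neumann` and `lam` are illustrative
# hyperparameters, not the repository's.
import torch


def neumann_approx_sketch(in_grad, outer_grad, params, n_neumann=5, lam=1.0):
    x = outer_grad.clone().detach()  # running sum of the Neumann series
    p = outer_grad.clone().detach()  # current term (-H/lam)^j * outer_grad
    for _ in range(n_neumann):
        # Hessian-vector product: differentiate the (create_graph=True) inner
        # gradient once more, contracted with p.
        hvp = torch.autograd.grad(in_grad, params, grad_outputs=p, retain_graph=True)
        hvp = torch.nn.utils.parameters_to_vector(hvp)
        p = -hvp / lam
        x = x + p
    # The series converges when the spectral radius of H/lam is below one.
    return x.detach()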
# MAML-style outer loop: backpropagate through the unrolled inner loop
# (track_higher_grads=is_train) and take the gradient at the initial weights.
def outer_loop(self, batch, is_train):
    self.network.zero_grad()
    train_inputs, train_targets, test_inputs, test_targets = self.unpack_batch(batch)

    loss_log = 0
    acc_log = 0
    grad_list = []
    loss_list = []

    for (train_input, train_target, test_input, test_target) in zip(
            train_inputs, train_targets, test_inputs, test_targets):
        with higher.innerloop_ctx(self.network, self.inner_optimizer,
                                  track_higher_grads=is_train) as (fmodel, diffopt):
            for step in range(self.args.n_inner):
                self.inner_loop(fmodel, diffopt, train_input, train_target)

            test_logit = fmodel(test_input)
            outer_loss = F.cross_entropy(test_logit, test_target)
            loss_log += outer_loss.item() / self.batch_size

            with torch.no_grad():
                acc_log += get_accuracy(test_logit, test_target).item() / self.batch_size

            if is_train:
                # Gradient of the query loss w.r.t. the pre-adaptation weights.
                outer_grad = torch.autograd.grad(outer_loss, fmodel.parameters(time=0))
                grad_list.append(outer_grad)
                loss_list.append(outer_loss.item())

    if is_train:
        weight = torch.ones(len(grad_list))
        weight = weight / torch.sum(weight)
        grad = mix_grad(grad_list, weight)
        grad_log = apply_grad(self.network, grad)
        self.outer_optimizer.step()
        return loss_log, acc_log, grad_log
    else:
        return loss_log, acc_log
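# `mix_grad` and `apply_grad` are shared helpers that are not shown here. Below
# is a minimal sketch of one way they could be implemented, assuming each entry
# of grad_list is a tuple/list of per-parameter gradients (as in the variant
# above): mix_grad averages the task gradients with the given weights, and
# apply_grad writes the result into .grad so outer_optimizer.step() can consume
# it, returning the gradient norm for logging. This is an illustration, not the
# repository's code.
import torch


def mix_grad_sketch(grad_list, weight):
    mixed_grad = []
    for per_param in zip(*grad_list):  # group the same parameter across tasks
        stacked = torch.stack([w * g for w, g in zip(weight, per_param)])
        mixed_grad.append(stacked.sum(dim=0))
    return mixed_grad


def apply_grad_sketch(model, grad):
    grad_norm = 0.0
    for p, g in zip(model.parameters(), grad):
        g = g.detach()
        p.grad = g.clone() if p.grad is None else p.grad + g
        grad_norm += g.pow(2).sum().item()
    return grad_norm ** 0.5  # scalar norm, useful as a training log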
# Reptile-style outer loop: the meta-gradient is -(theta_T - theta_0), the
# negative of the change produced by the inner adaptation.
def outer_loop(self, batch, is_train):
    self.network.zero_grad()
    train_inputs, train_targets, test_inputs, test_targets = self.unpack_batch(batch)

    loss_log = 0
    acc_log = 0
    grad_list = []
    loss_list = []

    for (train_input, train_target, test_input, test_target) in zip(
            train_inputs, train_targets, test_inputs, test_targets):
        override = self.inner_optimizer if is_train else None  # note: not used below
        with higher.innerloop_ctx(self.network, self.inner_optimizer,
                                  track_higher_grads=False) as (fmodel, diffopt):
            for step in range(self.args.n_inner):
                if is_train:
                    # During training, sample a fresh mini-batch of 10 examples
                    # from the query set for each inner step.
                    index = np.random.permutation(np.arange(len(test_input)))[:10]
                    train_input = test_input[index]
                    train_target = test_target[index]
                self.inner_loop(fmodel, diffopt, train_input, train_target)

            with torch.no_grad():
                test_logit = fmodel(test_input)
                outer_loss = F.cross_entropy(test_logit, test_target)
                loss_log += outer_loss.item() / self.batch_size
                loss_list.append(outer_loss.item())
                acc_log += get_accuracy(test_logit, test_target).item() / self.batch_size

            if is_train:
                outer_grad = []
                # time=-1 selects the parameters after the final inner step
                # (time=step would stop one update short).
                for p_0, p_T in zip(fmodel.parameters(time=0),
                                    fmodel.parameters(time=-1)):
                    outer_grad.append(-(p_T - p_0).detach())
                grad_list.append(outer_grad)

    if is_train:
        weight = torch.ones(len(grad_list)) / len(grad_list)
        grad = mix_grad(grad_list, weight)
        grad_log = apply_grad(self.network, grad)
        self.outer_optimizer.step()
        return loss_log, acc_log, grad_log
    else:
        return loss_log, acc_log
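# The higher-based variants above call `self.inner_loop(fmodel, diffopt,
# train_input, train_target)`, which is not shown. A minimal sketch of what a
# single adaptation step could look like, assuming a plain cross-entropy inner
# objective (the names mirror the call site, but this is an illustration, not
# the repository's code):
import torch.nn.functional as F


def inner_loop_sketch(self, fmodel, diffopt, train_input, train_target):
    train_logit = fmodel(train_input)
    inner_loss = F.cross_entropy(train_logit, train_target)
    diffopt.step(inner_loss)  # updates fmodel's fast weights in place
    return inner_loss.item()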
# CAVIA-style outer loop: only a per-task context vector is adapted in the
# inner loop; the network weights receive the outer (meta) gradient.
def outer_loop(self, batch, is_train):
    self.network.zero_grad()
    train_inputs, train_targets, test_inputs, test_targets = self.unpack_batch(batch)

    loss_log = 0
    acc_log = 0
    grad_list = []
    loss_list = []

    for (train_input, train_target, test_input, test_target) in zip(
            train_inputs, train_targets, test_inputs, test_targets):
        # Fresh zero-initialized context parameters for every task.
        context = torch.zeros(self.context_dim).cuda().requires_grad_()
        for step in range(self.args.n_inner):
            context = self.inner_loop(context, train_input, train_target, is_train)

        test_logit = self.cavia_forward(test_input, context)
        outer_loss = F.cross_entropy(test_logit, test_target)
        loss_log += outer_loss.item() / self.batch_size

        with torch.no_grad():
            acc_log += get_accuracy(test_logit, test_target).item() / self.batch_size

        if is_train:
            outer_grad = torch.autograd.grad(outer_loss, self.network.parameters())
            grad_list.append(outer_grad)
            loss_list.append(outer_loss.item())

    if is_train:
        weight = torch.ones(len(grad_list))
        weight = weight / torch.sum(weight)
        grad = mix_grad(grad_list, weight)
        grad_log = apply_grad(self.network, grad)
        self.outer_optimizer.step()
        return loss_log, acc_log, grad_log
    else:
        return loss_log, acc_log
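# In the CAVIA-style variant above, `self.inner_loop(context, ...)` returns an
# updated context vector rather than touching the network weights. A minimal
# sketch of such an update, assuming a plain gradient step on the context with
# step size `self.args.inner_lr` (a name borrowed from the other variants) and
# `self.cavia_forward` conditioning the network on the context; an
# illustration, not the repository's code:
import torch
import torch.nn.functional as F


def inner_loop_context_sketch(self, context, train_input, train_target, is_train):
    train_logit = self.cavia_forward(train_input, context)
    inner_loss = F.cross_entropy(train_logit, train_target)
    # Keep the graph during training so that the outer gradient can flow
    # through the context update back into the network parameters.
    grad = torch.autograd.grad(inner_loss, context, create_graph=is_train)[0]
    return context - self.args.inner_lr * grad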
# Outer loop for a variant that builds per-class prototypes from word
# embeddings and adapts a small per-task decoder on top of the feature
# extractor.
def outer_loop(self, batch, reverse_dict_list, is_train):
    self.network.zero_grad()
    train_inputs, train_targets, test_inputs, test_targets = self.unpack_batch(batch)

    loss_log = 0
    acc_log = 0
    grad_list = []
    loss_list = []

    for i, (train_input, train_target, test_input, test_target) in enumerate(
            zip(train_inputs, train_targets, test_inputs, test_targets)):
        # Re-initialize the per-task decoder and give it its own inner optimizer.
        self.network.init_decoder()
        inner_optimizer = torch.optim.SGD(self.network.decoder, lr=self.args.inner_lr)

        with higher.innerloop_ctx(self.network, inner_optimizer,
                                  track_higher_grads=is_train) as (fmodel, diffopt):
            # Warm-up forward pass on a dummy input.
            fmodel(torch.zeros(1, 3, 80, 80).type(torch.float32).cuda())

            # Convert numerical labels to word labels.
            target = [[reverse_dict_list[i][j] for j in range(self.args.num_way)]]

            # Get transformed word embeddings (prototypes) and lambda.
            word_proto = fmodel.word_embedding(target, is_train)[0].permute([1, 0])

            train_input = fmodel(train_input)
            for step in range(self.args.n_inner):
                self.inner_loop(fmodel.decoder, diffopt, train_input,
                                train_target, word_proto)

            test_logit = fmodel(test_input)
            test_logit = torch.matmul(test_logit, 2 * word_proto * fmodel.decoder[0]) \
                - ((word_proto * fmodel.decoder[0]) ** 2).sum(dim=0, keepdim=True) \
                + fmodel.decoder[1]
            outer_loss = F.cross_entropy(test_logit, test_target)
            loss_log += outer_loss.item() / self.batch_size

            with torch.no_grad():
                acc_log += get_accuracy(test_logit, test_target).item() / self.batch_size

            if is_train:
                params = fmodel.parameters(time=0)
                outer_grad = torch.autograd.grad(outer_loss, params)
                grad_list.append(outer_grad)
                loss_list.append(outer_loss.item())
                # self._lambda = _lambda.detach()

    if is_train:
        weight = torch.ones(len(grad_list))
        weight = weight / torch.sum(weight)
        grad = mix_grad(grad_list, weight)
        grad_log = apply_grad(self.network, grad)
        self.outer_optimizer.step()
        return loss_log, acc_log, grad_log
    else:
        return loss_log, acc_log
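# The logit expression above (2 * x.c - ||c||^2 + b, with c = word_proto *
# decoder[0]) equals, up to a per-sample constant, a negative squared Euclidean
# distance to the scaled word prototypes plus a bias. A small standalone check
# of that identity (all names and shapes here are illustrative):
import torch

feat = torch.randn(4, 64)    # embedded query examples  [batch, dim]
proto = torch.randn(64, 5)   # scaled prototypes        [dim, num_way]
bias = torch.randn(5)

logits = feat @ (2 * proto) - (proto ** 2).sum(dim=0, keepdim=True) + bias
sq_dist = ((feat.unsqueeze(-1) - proto.unsqueeze(0)) ** 2).sum(dim=1)
alt = -sq_dist + (feat ** 2).sum(dim=1, keepdim=True) + bias  # add back ||x||^2

assert torch.allclose(logits, alt, atol=1e-4)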
# Implicit (iMAML-style) outer loop for segmentation: the meta-gradient is
# obtained with a conjugate-gradient solve (see `cg`), and performance is
# tracked with the Dice coefficient instead of accuracy.
def outer_loop(self, batch, is_train):
    train_inputs, train_targets, test_inputs, test_targets = self.unpack_batch(batch)

    loss_log = 0
    dice_log = 0
    grad_list = []
    loss_list = []

    for (train_input, train_target, test_input, test_target) in zip(
            train_inputs, train_targets, test_inputs, test_targets):
        with higher.innerloop_ctx(self.network, self.inner_optimizer,
                                  track_higher_grads=False) as (fmodel, diffopt):
            for step in range(self.args.n_inner):
                self.inner_loop(fmodel, diffopt, train_input, train_target)

            train_logit = fmodel(train_input)
            in_loss = bce_dice_loss(train_logit, train_target)

            test_logit = fmodel(test_input)
            outer_loss = bce_dice_loss(test_logit, test_target)
            loss_log += outer_loss.item() / self.batch_size

            # Binarize the prediction at a threshold of 0.3 before computing Dice.
            out_cut = np.copy(test_logit.data.cpu().numpy())
            out_cut[np.nonzero(out_cut < 0.3)] = 0.0
            out_cut[np.nonzero(out_cut >= 0.3)] = 1.0

            with torch.no_grad():
                dice_log += dice_coef(
                    out_cut, test_target.data.cpu().numpy()).item() / self.batch_size

            if is_train:
                params = list(fmodel.parameters(time=-1))
                in_grad = torch.nn.utils.parameters_to_vector(
                    torch.autograd.grad(in_loss, params, create_graph=True))
                outer_grad = torch.nn.utils.parameters_to_vector(
                    torch.autograd.grad(outer_loss, params))
                implicit_grad = self.cg(in_grad, outer_grad, params)
                grad_list.append(implicit_grad)
                loss_list.append(outer_loss.item())

    if is_train:
        self.outer_optimizer.zero_grad()
        weight = torch.ones(len(grad_list))
        weight = weight / torch.sum(weight)
        grad = mix_grad(grad_list, weight)
        grad_log = apply_grad(self.network, grad)
        self.outer_optimizer.step()
        return loss_log, dice_log, grad_log
    else:
        return loss_log, dice_log
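# `self.cg` is not shown above. A minimal sketch of a conjugate-gradient solver
# for the implicit meta-gradient, assuming it solves (I + H/lam) x = outer_grad
# with H the Hessian of the inner loss; `n_cg` and `lam` are illustrative
# hyperparameters, not the repository's.
import torch


def cg_sketch(in_grad, outer_grad, params, n_cg=5, lam=1.0):
    def Avp(v):
        # Matrix-vector product with A = I + H/lam, using a Hessian-vector
        # product through the (create_graph=True) inner gradient.
        hvp = torch.autograd.grad(in_grad, params, grad_outputs=v, retain_graph=True)
        hvp = torch.nn.utils.parameters_to_vector(hvp)
        return v + hvp / lam

    x = torch.zeros_like(outer_grad)
    r = outer_grad.clone().detach()  # residual of A x = outer_grad at x = 0
    p = r.clone()
    for _ in range(n_cg):
        Ap = Avp(p)
        alpha = (r @ r) / (p @ Ap)
        x = x + alpha * p
        r_new = r - alpha * Ap
        p = r_new + ((r_new @ r_new) / (r @ r)) * p
        r = r_new
    return x.detach()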