def update(self, minibatches):
    """Perform one inter-domain Mixup training step.

    Random pairs of per-domain minibatches are blended with a
    Beta(alpha, alpha)-sampled coefficient; the classification loss is
    the matching convex combination of the cross-entropies against the
    two original label sets, averaged over the number of minibatches.

    Args:
        minibatches: list of (x, y) pairs, one per training domain.

    Returns:
        dict with the scalar training 'loss' for logging.
    """
    alpha = self.hparams["mixup_alpha"]
    total_loss = 0

    for (x_a, y_a), (x_b, y_b) in random_pairs_of_minibatches(minibatches):
        # Mixing coefficient for this pair of domains.
        mix = np.random.beta(alpha, alpha)

        # A single forward pass on the blended inputs serves both terms.
        logits = self.predict(mix * x_a + (1 - mix) * x_b)

        pair_loss = mix * F.cross_entropy(logits, y_a)
        pair_loss = pair_loss + (1 - mix) * F.cross_entropy(logits, y_b)
        total_loss = total_loss + pair_loss

    total_loss = total_loss / len(minibatches)

    self.optimizer.zero_grad()
    total_loss.backward()
    self.optimizer.step()

    return {'loss': total_loss.item()}
def update(self, minibatches):
    """Perform one first-order MLDG (meta-learning for domain
    generalization) update step.

    Terms being computed:
        * Li = Loss(xi, yi, params)
        * Gi = Grad(Li, params)

        * Lj = Loss(xj, yj, Optimizer(params, grad(Li, params)))
        * Gj = Grad(Lj, params)

        * params = Optimizer(params, Grad(Li + beta * Lj, params))
        *        = Optimizer(params, Gi + beta * Gj)

    That is, when calling .step(), we want grads to be Gi + beta * Gj.

    For computational efficiency, we do not compute second derivatives
    (first-order MLDG approximation).

    Args:
        minibatches: list of (x, y) pairs, one per training domain.

    Returns:
        dict with the scalar training 'loss' (averaged sum of the
        inner and weighted outer objectives) for logging.
    """
    num_mb = len(minibatches)
    objective = 0

    self.optimizer.zero_grad()
    # Ensure every parameter has a grad tensor to accumulate into below;
    # freshly-initialized parameters may still have .grad == None.
    for p in self.network.parameters():
        if p.grad is None:
            p.grad = torch.zeros_like(p)

    for (xi, yi), (xj, yj) in random_pairs_of_minibatches(minibatches):
        # fine tune clone-network on task "i"
        inner_net = copy.deepcopy(self.network)

        inner_opt = torch.optim.Adam(
            inner_net.parameters(),
            lr=self.hparams["lr"],
            weight_decay=self.hparams['weight_decay'])

        inner_obj = F.cross_entropy(inner_net(xi), yi)

        inner_opt.zero_grad()
        inner_obj.backward()
        inner_opt.step()

        # The network has now accumulated gradients Gi
        # The clone-network has now parameters P - lr * Gi
        # Copy the clone's Gi into the real network's grads (scaled by
        # 1/num_mb so the final step averages over all pairs).
        for p_tgt, p_src in zip(self.network.parameters(),
                                inner_net.parameters()):
            if p_src.grad is not None:
                p_tgt.grad.data.add_(p_src.grad.data / num_mb)

        # `objective` is populated for reporting purposes
        objective += inner_obj.item()

        # this computes Gj on the clone-network, i.e. at the adapted
        # parameters P - lr * Gi, evaluated on held-out task "j"
        loss_inner_j = F.cross_entropy(inner_net(xj), yj)
        grad_inner_j = autograd.grad(loss_inner_j, inner_net.parameters(),
            allow_unused=True)

        # `objective` is populated for reporting purposes
        objective += (self.hparams['mldg_beta'] * loss_inner_j).item()

        # Accumulate beta * Gj into the real network's grads; entries can
        # be None for parameters unused in this forward pass (allow_unused).
        for p, g_j in zip(self.network.parameters(), grad_inner_j):
            if g_j is not None:
                p.grad.data.add_(
                    self.hparams['mldg_beta'] * g_j.data / num_mb)

        # The network has now accumulated gradients Gi + beta * Gj
        # Repeat for all train-test splits, do .step()

    objective /= len(minibatches)

    self.optimizer.step()

    return {'loss': objective}