import torch
from torch.utils.data import DataLoader
from tqdm import tqdm


def train(train_loader: DataLoader, model, optimizer, epoch, args):
    model.train()
    total_loss = 0.0
    train_score = 0
    total_norm = 0
    count_norm = 0
    grad_clip = 0.25

    for (v, b, q, a) in tqdm(train_loader):
        v = v.cuda(args.gpu)
        b = b.cuda(args.gpu)
        q = q.cuda(args.gpu)
        a = a.cuda(args.gpu)

        pred, att = model(v, b, q, a)
        loss = instance_bce_with_logits(pred, a)

        optimizer.zero_grad()
        loss.backward()
        total_norm += torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        count_norm += 1
        total_loss += loss.item()
        optimizer.step()

        batch_score = compute_score_with_logits(pred, a.data).sum()
        train_score += batch_score.item()

    total_loss /= len(train_loader)
    train_score /= len(train_loader.dataset)
    print('total_loss=', total_loss, '; train_score=', train_score)
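# The two helpers called above, instance_bce_with_logits and compute_score_with_logits,
# are not defined in this section. Below is a minimal sketch of how they are commonly
# implemented in bottom-up-attention style VQA code; the exact bodies (in particular
# the `reduction` handling) are assumptions, not the authoritative implementation.
import torch
import torch.nn.functional as F


def instance_bce_with_logits(logits, labels, reduction='mean'):
    # Multi-label BCE over the answer vocabulary; the mean is rescaled by the number
    # of answer classes so the magnitude matches a per-instance sum.
    assert logits.dim() == 2
    loss = F.binary_cross_entropy_with_logits(logits, labels, reduction=reduction)
    if reduction == 'mean':
        loss = loss * labels.size(1)
    return loss


def compute_score_with_logits(logits, labels):
    # VQA soft accuracy: take the argmax answer and read off its soft label score.
    pred = torch.max(logits, 1)[1]              # (batch,) predicted answer ids
    one_hots = torch.zeros_like(labels)
    one_hots.scatter_(1, pred.view(-1, 1), 1)
    return one_hots * labels                    # (batch, num_answers) per-answer scores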
def virtual_step(self, trn_X, trn_y, w_optim, Likelihood, step, args):
    """
    Compute unrolled weight w' (virtual step)

    Step process:
    1) forward
    2) calc loss
    3) compute gradient (by backprop)
    4) update gradient

    Args:
        xi: learning rate for virtual gradient step (same as weights lr)
        w_optim: weights optimizer
    """
    # forward & calc loss
    dataIndex = len(trn_y) + step * args.batch_size  # end index of this batch in Likelihood
    v, b, q = trn_X
    a = trn_y

    # forward
    pred, att = self.v_net(v, b, q, a)
    print('len batch', a.size())      # debug
    print('len batch2', pred.size())  # debug
    print('v size', v.size())         # debug

    # sigmoid loss: per-example BCE reweighted by sigmoid(Likelihood) and normalized
    first = torch.sigmoid(Likelihood[step * args.batch_size:dataIndex])
    second = instance_bce_with_logits(pred, a, reduction='none').mean(1).cuda()
    print(first.size())
    print(second.size())
    lossup = torch.dot(first, second)
    lossdiv = torch.sigmoid(Likelihood[step * args.batch_size:dataIndex]).sum()
    loss = lossup / lossdiv
    # loss = torch.dot(torch.sigmoid(Likelihood[step*batch_size:dataIndex]), ignore_crit(logits, trn_y))/(torch.sigmoid(Likelihood[step*batch_size:dataIndex]).sum())

    # compute gradient of the weighted train loss towards Likelihood
    loss.backward()

    # do virtual step (update gradient)
    # below operations do not need gradient tracking
    with torch.no_grad():
        # dict key is not the value, but the pointer, so the original network's weights
        # have to be iterated over as well.
        for w, vw in zip(self.net.parameters(), self.v_net.parameters()):
            m = w_optim.state[w].get('momentum_buffer', 0.) * self.w_momentum
            if w.grad is not None:
                vw.copy_(w - args.lr * (m + w.grad + self.w_weight_decay * w))
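# The loss built in virtual_step is a normalized, per-example reweighting:
#     L = sum_i sigmoid(Likelihood_i) * BCE_i  /  sum_i sigmoid(Likelihood_i)
# The standalone toy below reproduces just that arithmetic on small random tensors so
# the slicing and normalization are easy to check in isolation; the shapes and the
# helper name are made up for illustration.
import torch


def weighted_batch_loss(likelihood_slice, per_example_loss):
    # likelihood_slice: (batch,) raw scores; per_example_loss: (batch,) losses
    weights = torch.sigmoid(likelihood_slice)
    return torch.dot(weights, per_example_loss) / weights.sum()


def _demo_weighted_batch_loss(batch=4):
    lik = torch.zeros(batch)        # sigmoid(0) = 0.5 for every example
    per_ex = torch.rand(batch)      # stand-in for the per-example BCE values
    loss = weighted_batch_loss(lik, per_ex)
    # With equal weights the reweighted loss is just the plain mean of the per-example losses.
    assert torch.allclose(loss, per_ex.mean())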
def train(train_loader: DataLoader, eval_loader, train_dset, model, optimizer, epoch, args):
    model.train()
    total_loss = 0.0
    train_score = 0
    total_norm = 0
    count_norm = 0
    grad_clip = 0.25

    architect = Architect(model, 0.9, 3e-4)
    # One learnable scalar per training example; sigmoid(Likelihood) acts as the example weight.
    Likelihood = torch.nn.Parameter(torch.ones(len(train_dset)).cuda(), requires_grad=True)
    Likelihood_optim = torch.optim.Adam([Likelihood], lr=0.1, betas=(0.5, 0.999))

    for step, ((v, b, q, a), (vv, vb, vq, va)) in enumerate(zip(train_loader, eval_loader)):
        v = v.cuda(args.gpu)
        b = b.cuda(args.gpu)
        q = q.cuda(args.gpu)
        a = a.cuda(args.gpu)
        vv = vv.cuda(args.gpu)
        vb = vb.cuda(args.gpu)
        vq = vq.cuda(args.gpu)
        va = va.cuda(args.gpu)

        # update the per-example weights (Likelihood) with the unrolled, bilevel step
        Likelihood_optim.zero_grad()
        Likelihood, Likelihood_optim, valid_loss = architect.unrolled_backward(
            (v, b, q), a, (vv, vb, vq), va, optimizer, Likelihood, Likelihood_optim, step, args)

        # regular weight update on the training batch
        pred, att = model(v, b, q, a)
        loss = instance_bce_with_logits(pred, a)

        optimizer.zero_grad()
        loss.backward()
        total_norm += torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        count_norm += 1
        total_loss += loss.item()
        optimizer.step()

        batch_score = compute_score_with_logits(pred, a.data).sum()
        train_score += batch_score.item()

    total_loss /= len(train_loader)
    train_score /= len(train_loader.dataset)
    print('total_loss=', total_loss, '; train_score=', train_score)
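# The Architect class used above appears in this section only through its methods
# (virtual_step, unrolled_backward, compute_hessian); its constructor is not shown.
# Below is a minimal sketch of how such a class is typically wired up. The
# deepcopy-based shadow network and the argument names are assumptions inferred from
# the call Architect(model, 0.9, 3e-4) and from the attributes the methods reference.
import copy


class Architect:
    def __init__(self, net, w_momentum, w_weight_decay):
        self.net = net                        # network whose weights w_optim actually updates
        self.v_net = copy.deepcopy(net)       # shadow copy that receives the virtual weights w'
        self.w_momentum = w_momentum          # momentum used when simulating w_optim's step
        self.w_weight_decay = w_weight_decay  # weight decay used when simulating w_optim's step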
def unrolled_backward(self, trn_X, trn_y, val_X, val_y, w_optim, Likelihood, Likelihood_optim, step, args):
    """
    Compute unrolled loss and backward its gradients

    Args:
        xi: learning rate for virtual gradient step (same as net lr)
        w_optim: weights optimizer - for virtual step
    """
    xi = 0.01

    # do virtual step (calc w')
    self.virtual_step(trn_X, trn_y, w_optim, Likelihood, step, args)

    vv, vb, vq = val_X
    va = val_y  # labels of the validation batch

    # calc val prediction
    pred, att = self.v_net(vv, vb, vq, va)
    # calc unrolled validation loss
    loss = instance_bce_with_logits(pred, va)  # L_val(w')

    # compute gradient of validation loss towards the virtual weights;
    # weights unused in the graph return None, so allow_unused is set and
    # each dw entry is either None or the 1-tuple returned by autograd.grad
    v_weights = tuple(self.v_net.parameters())
    dw = []
    for w in v_weights:
        if w.requires_grad:
            dw.append(torch.autograd.grad(loss, w, allow_unused=True, retain_graph=True))
        else:
            dw.append(None)

    hessian = self.compute_hessian(dw, trn_X, trn_y, Likelihood, args.batch_size, step)

    Likelihood_optim.zero_grad()

    # update final gradient = -xi * hessian
    # (DARTS-style second-order approximation:
    #  dL_val/dLikelihood ~= -xi * (d^2 L_trn / dLikelihood dw) . dL_val/dw',
    #  where the mixed second derivative is the finite-difference `hessian` above)
    # with torch.no_grad():
    #     for likelihood, h in zip(Likelihood, hessian):
    #         print(len(hessian))
    #         likelihood.grad = - xi*h
    with torch.no_grad():
        Likelihood.grad = -xi * hessian[0]

    Likelihood_optim.step()
    return Likelihood, Likelihood_optim, loss
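# Small shape sanity check (illustrative helper, not part of the pipeline above):
# torch.autograd.grad(loss, Likelihood) covers the whole Likelihood vector, but only
# the slice belonging to the current batch entered the loss, so every entry outside
# that slice is zero. The helper name and arguments are hypothetical.
def hypergrad_for_batch(Likelihood, hessian, step, batch_len, batch_size):
    grad = hessian[0]
    assert grad.shape == Likelihood.shape               # one value per training example
    lo, hi = step * batch_size, step * batch_size + batch_len
    assert grad[:lo].abs().sum() == 0 and grad[hi:].abs().sum() == 0
    return grad[lo:hi]                                  # entries for the current batch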
def compute_hessian(self, dw, trn_X, trn_y, Likelihood, batch_size, step):
    """
    dw = dw' { L_val(w', alpha) }
    w+ = w + eps * dw
    w- = w - eps * dw
    hessian = (dalpha { L_trn(w+, alpha) } - dalpha { L_trn(w-, alpha) }) / (2*eps)
    eps = 0.01 / ||dw||
    """
    norm = torch.cat([w[0].view(-1) for w in dw
                      if (w is not None) and (w[0] is not None)]).norm()
    eps = 0.01 / norm

    v, b, q = trn_X
    a = trn_y

    # w+ = w + eps*dw'
    with torch.no_grad():
        for p, d in zip(self.net.parameters(), dw):
            if d is not None and d[0] is not None:
                p += eps * d[0]

    # forward & calc loss
    dataIndex = len(trn_y) + step * batch_size  # end index of this batch in Likelihood

    # forward
    logits, att = self.net(v, b, q, a)

    # sigmoid loss (same reweighted loss as in virtual_step)
    first = torch.sigmoid(Likelihood[step * batch_size:dataIndex])
    second = instance_bce_with_logits(logits, a, reduction='none').mean(1).cuda()
    lossup = torch.dot(first, second)
    lossdiv = torch.sigmoid(Likelihood[step * batch_size:dataIndex]).sum()
    loss = lossup / lossdiv
    dalpha_pos = torch.autograd.grad(loss, Likelihood)  # dalpha { L_trn(w+) }

    # w- = w - eps*dw'
    with torch.no_grad():
        for p, d in zip(self.net.parameters(), dw):
            if d is not None and d[0] is not None:
                p -= 2. * eps * d[0]

    # forward
    logits, att = self.net(v, b, q, a)

    # sigmoid loss
    first = torch.sigmoid(Likelihood[step * batch_size:dataIndex])
    second = instance_bce_with_logits(logits, a, reduction='none').mean(1).cuda()
    lossup = torch.dot(first, second)
    lossdiv = torch.sigmoid(Likelihood[step * batch_size:dataIndex]).sum()
    loss = lossup / lossdiv
    dalpha_neg = torch.autograd.grad(loss, Likelihood)  # dalpha { L_trn(w-) }

    # recover w
    with torch.no_grad():
        for p, d in zip(self.net.parameters(), dw):
            if d is not None and d[0] is not None:
                p += eps * d[0]

    hessian = [(p - n) / (2. * eps) for p, n in zip(dalpha_pos, dalpha_neg)]
    return hessian
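# Standalone toy check of the two approximations used above: (1) the hypergradient
#     d/dalpha L_val(w')  with  w' = w - xi * dL_trn/dw(w, alpha)
# equals -xi * (d^2 L_trn / dalpha dw) . dL_val/dw'(w'), and (2) that mixed second
# derivative is well approximated by the central difference in compute_hessian.
# The toy losses below are invented purely for this check; they are not the VQA losses.
import torch


def _check_hypergradient_approximation(xi=0.01, eps=1e-3):
    w = torch.randn(5, requires_grad=True)
    alpha = torch.randn(5, requires_grad=True)

    def l_trn(w_, a_):
        return torch.dot(torch.sigmoid(a_), w_ ** 2)   # stands in for the reweighted train loss

    def l_val(w_):
        return ((w_ - 1.0) ** 2).sum()                 # stands in for L_val

    # exact hypergradient: differentiate through the virtual step
    gw = torch.autograd.grad(l_trn(w, alpha), w, create_graph=True)[0]
    w_virtual = w - xi * gw
    exact = torch.autograd.grad(l_val(w_virtual), alpha)[0]

    # finite-difference version, mirroring compute_hessian
    with torch.no_grad():
        dw = 2.0 * (w_virtual - 1.0)                   # dL_val/dw' at w'
        w_plus, w_minus = w + eps * dw, w - eps * dw
    dalpha_pos = torch.autograd.grad(l_trn(w_plus, alpha), alpha)[0]
    dalpha_neg = torch.autograd.grad(l_trn(w_minus, alpha), alpha)[0]
    approx = -xi * (dalpha_pos - dalpha_neg) / (2 * eps)

    assert torch.allclose(exact, approx, atol=1e-5)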