def forward(self, output, target):
    """Compute the label-smoothed KL-divergence loss, summed over the batch.

    Args:
        output (FloatTensor): ``batch_size x n_classes`` tensor handed to
            ``F.kl_div`` as its input.
        target (LongTensor): ``batch_size`` gold class indices.

    Returns:
        Scalar tensor holding the KL divergence summed over all entries.
    """
    batch_size = target.size(0)
    gold_column = target.unsqueeze(1)
    ignored_rows = (target == self.ignore_index).unsqueeze(1)

    # Every row starts as a copy of the smoothing template in self.one_hot.
    smoothed = self.one_hot.repeat(batch_size, 1)
    # The gold class of each row receives the confidence mass.
    smoothed.scatter_(1, gold_column, self.confidence)
    # Rows whose target equals the ignore index are zeroed out entirely.
    smoothed.masked_fill_(ignored_rows, 0)

    return F.kl_div(output, smoothed, reduction='sum')
def softmax_kl_loss(input_logits, target_logits):
    """Take softmax on both sides and return the KL divergence.

    Note:
    - Returns the sum over all examples. Divide by the batch size afterwards
      if you want the mean.
    - Sends gradients to inputs but not the targets.

    Args:
        input_logits: 2-D tensor of student logits (softmax over ``dim=1``).
        target_logits: tensor of teacher logits with the same size.

    Returns:
        Scalar tensor: ``sum(p_target * (log p_target - log p_input))``.
    """
    assert input_logits.size() == target_logits.size()
    input_log_softmax = F.log_softmax(input_logits, dim=1)
    # Detach so the documented "no gradient to the targets" contract holds
    # even when the caller passes target logits that require grad.
    target_softmax = F.softmax(target_logits.detach(), dim=1)
    # reduction='sum' is the supported spelling of the deprecated
    # size_average=False.
    return F.kl_div(input_log_softmax, target_softmax, reduction='sum')
def forward(self, output, target):
    """Return the negated, per-example label-smoothed KL divergence.

    Args:
        output (FloatTensor): ``batch_size x tgt_vocab_size`` tensor handed
            to ``F.kl_div`` as its input.
        target (LongTensor): ``batch_size`` gold word indices.

    Returns:
        Tensor of length ``batch_size``: the element-wise KL terms summed over
        the last dimension and multiplied by -1.
    """
    n_rows = target.size(0)
    gold_index = target.unsqueeze(-1)
    is_padding = (target == self.padding_idx).unsqueeze(-1)

    # (batch_size, tgt_vocab_size) copy of the smoothing template
    smoothed = self.one_hot.repeat(n_rows, 1)
    # the gold-standard word position gets the confidence value
    smoothed.scatter_(1, gold_index, self.confidence)
    # rows belonging to padded targets are zeroed
    smoothed.masked_fill_(is_padding, 0.)

    per_token = F.kl_div(output, smoothed, reduction='none')
    # NOTE: the sign flip is part of this block's contract (negative KL).
    return -per_token.sum(-1)
def train(self, epoch):
    """Run one training epoch with separated cross-entropy and optional KD.

    For the first task (``tasknum == 0``) only the current-task batches are
    used. For later tasks every current batch is zipped with a replay batch
    drawn from the exemplar loader; the two are concatenated before the
    forward pass.

    Loss composition per step:
      * ``ablation == 'naive'``: one cross-entropy over ``output[:, :end]``
        on the concatenated targets.
      * otherwise: cross-entropy on the current rows over columns
        ``mid:end`` plus (for ``tasknum > 0``) cross-entropy on the replay
        rows over columns ``0:mid``, plus a distillation term against
        ``self.model_fixed`` selected by ``self.distill`` ('global' or
        'local').

    Args:
        epoch: epoch number, only used for the progress print.
    """
    self.model.train()
    print("Epochs %d"%epoch)
    T = 2  # distillation temperature

    tasknum = self.incremental_loader.t
    end = self.incremental_loader.end
    mid = end - self.args.step_size
    start = 0

    exemplar_dataset_loaders = trainer.ExemplarLoader(self.incremental_loader)
    exemplar_iterator = torch.utils.data.DataLoader(exemplar_dataset_loaders,
                                                    batch_size=self.args.replay_batch_size,
                                                    shuffle=True, drop_last=True, **self.kwargs)

    if tasknum > 0:
        iterator = zip(self.train_iterator, exemplar_iterator)
    else:
        iterator = self.train_iterator

    for samples in tqdm(iterator):
        if tasknum > 0:
            curr, prev = samples

            data, target = curr
            if self.args.ablation == 'None':
                # re-index current-task labels relative to the task's first class
                target = target%(end-mid)
            batch_size = data.shape[0]

            data_r, target_r = prev
            replay_size = data_r.shape[0]

            data, data_r = data.cuda(), data_r.cuda()
            data = torch.cat((data,data_r))
            target, target_r = target.cuda(), target_r.cuda()
        else:
            data, target = samples
            data = data.cuda()
            target = target.cuda()
            batch_size = data.shape[0]
            # No replay data exists on the first task. Define these so the
            # 'naive' branch below does not hit unbound names.
            target_r = None
            replay_size = 0

        output = self.model(data)
        loss_KD = 0

        if self.args.ablation == 'naive':
            if tasknum > 0:
                target = torch.cat((target, target_r))
            loss_CE = self.loss(output[:,:end],target) / (batch_size + replay_size)
        else:
            loss_CE_curr = 0
            loss_CE_prev = 0

            curr = output[:batch_size,mid:end]
            loss_CE_curr = self.loss(curr, target)

            if tasknum > 0:
                prev = output[batch_size:batch_size+replay_size,start:mid]
                loss_CE_prev = self.loss(prev, target_r)
                loss_CE = (loss_CE_curr + loss_CE_prev) / (batch_size + replay_size)

                # loss_KD: distil the old-class logits of the frozen model
                score = self.model_fixed(data)[:,:mid].data
                if self.distill == 'global':
                    soft_target = F.softmax(score / T, dim=1)
                    output_log = F.log_softmax(output[:,:mid] / T, dim=1)
                    loss_KD = F.kl_div(output_log, soft_target, reduction='batchmean') * (T ** 2)
                elif self.distill == 'local':
                    loss_KD = torch.zeros(tasknum).cuda()
                    for t in range(tasknum):
                        # local distillation: softmax restricted to task t's classes
                        start_KD = (t) * self.args.step_size
                        end_KD = (t+1) * self.args.step_size

                        soft_target = F.softmax(score[:,start_KD:end_KD] / T, dim=1)
                        output_log = F.log_softmax(output[:,start_KD:end_KD] / T, dim=1)
                        loss_KD[t] = F.kl_div(output_log, soft_target, reduction='batchmean') * (T**2)
                    loss_KD = loss_KD.sum()
            else:
                loss_CE = loss_CE_curr / batch_size

        self.optimizer.zero_grad()
        (loss_KD + loss_CE).backward()
        self.optimizer.step()
def forward(self, data):
    """Apply attention-based node pooling to a batch of graphs.

    ``data`` is a sequence whose first five entries are unpacked as
    ``x, A, mask, _, params_dict``. The method computes per-node attention
    ``alpha`` (predicted through ``self.proj`` for pool types 'unsup'/'sup',
    or taken from ``params_dict['node_attn']`` otherwise), scales ``x`` by
    it, shrinks ``mask`` either by top-k selection or by thresholding, and
    optionally drops nodes/edges via ``self.drop_nodes_edges``.

    Side effects on ``params_dict``: 'node_attn' / 'node_attn_eval' are
    wrapped into lists and extended, 'N_nodes' is overwritten, and 'reg',
    'alpha', 'mask' (plus the ``*_debug`` keys when ``self.debug``) are
    appended to.

    Returns:
        ``[x, A, mask, *data[3:]]`` with the pooled tensors.

    Raises:
        ValueError: if ground-truth attention is required by the pool type
            but 'node_attn' is missing from ``params_dict``.
    """
    KL_loss = None
    x, A, mask, _, params_dict = data[:5]

    mask_float = mask.float()
    N_nodes_float = params_dict['N_nodes'].float()
    B, N, C = x.shape
    alpha_gt = None
    if 'node_attn' in params_dict:
        # normalise to a list so that later layers can append their own entry
        if not isinstance(params_dict['node_attn'], list):
            params_dict['node_attn'] = [params_dict['node_attn']]
        alpha_gt = params_dict['node_attn'][-1].view(B, N)
    if 'node_attn_eval' in params_dict:
        if not isinstance(params_dict['node_attn_eval'], list):
            params_dict['node_attn_eval'] = [params_dict['node_attn_eval']]

    if (self.pool_type[1] == 'gt' or
        (self.pool_type[1] == 'sup' and self.training)) and alpha_gt is None:
        raise ValueError(
            'ground truth node attention values node_attn required for %s' %
            self.pool_type)

    if self.pool_type[1] in ['unsup', 'sup']:
        # attention is predicted either from the last element of data or from x
        attn_input = data[-1] if self.pool_arch[1] == 'prev' else x.clone()
        if self.pool_arch[0] == 'fc':
            alpha_pre = self.proj(attn_input).view(B, N)
        else:
            alpha_pre = self.proj([attn_input, *data[1:]])[0].view(B, N)
        # softmax with masking out dummy nodes
        alpha_pre = torch.clamp(alpha_pre, -self.clamp_value, self.clamp_value)
        alpha = normalize_batch(
            self.mask_out(torch.exp(alpha_pre), mask_float).view(B, N))
        if self.pool_type[1] == 'sup' and self.training:
            # supervised attention: element-wise KL against the ground truth
            KL_loss_per_node = self.mask_out(
                F.kl_div(torch.log(alpha + 1e-14), alpha_gt, reduction='none'),
                mask_float.view(B, N))
            KL_loss = self.kl_weight * torch.mean(
                KL_loss_per_node.sum(dim=1) /
                (N_nodes_float + 1e-7))  # mean over nodes, then mean over batches
    else:
        alpha = alpha_gt

    x = x * alpha.view(B, N, 1)
    if self.large_graph:
        # For large graphs during training, all alpha values can be very small hindering training
        x = x * N_nodes_float.view(B, 1, 1)
    if self.is_topk:
        N_remove = torch.round(N_nodes_float * (1 - self.topk_ratio)).long(
        )  # number of nodes to be removed for each graph
        idx = torch.sort(
            alpha, dim=1,
            descending=False)[1]  # indices of alpha in ascending order
        mask = mask.clone().view(B, N)
        for b in range(B):
            idx_b = idx[b, mask[b, idx[
                b]]]  # take indices of non-dummy nodes for current data example
            mask[b, idx_b[:N_remove[b]]] = 0
    else:
        # keep only nodes whose attention exceeds the threshold
        mask = (mask & (alpha.view_as(mask) > self.threshold)).view(B, N)

    if self.drop_nodes:
        x, A, mask, N_nodes_pooled, idx = self.drop_nodes_edges(x, A, mask)
        if idx is not None and 'node_attn' in params_dict:
            # update ground truth (or weakly labeled) attention for a reduced graph
            params_dict['node_attn'].append(
                normalize_batch(
                    self.mask_out(torch.gather(alpha_gt, dim=1, index=idx),
                                  mask.float())))
        if idx is not None and 'node_attn_eval' in params_dict:
            # update ground truth (or weakly labeled) attention for a reduced graph
            params_dict['node_attn_eval'].append(
                normalize_batch(
                    self.mask_out(
                        torch.gather(params_dict['node_attn_eval'][-1],
                                     dim=1,
                                     index=idx), mask.float())))
    else:
        N_nodes_pooled = torch.sum(mask, dim=1).long()  # B
        if 'node_attn' in params_dict:
            params_dict['node_attn'].append(
                (self.mask_out(params_dict['node_attn'][-1], mask.float())))
        if 'node_attn_eval' in params_dict:
            params_dict['node_attn_eval'].append(
                (self.mask_out(params_dict['node_attn_eval'][-1],
                               mask.float())))

    params_dict['N_nodes'] = N_nodes_pooled

    # an edge survives only if both of its endpoints survive
    mask_matrix = mask.unsqueeze(2) & mask.unsqueeze(1)
    A = A * mask_matrix.float()  # or A[~mask_matrix] = 0

    # Add additional losses regularizing the model
    if KL_loss is not None:
        if 'reg' not in params_dict:
            params_dict['reg'] = []
        params_dict['reg'].append(KL_loss)

    # Keep attention coefficients for evaluation
    for key, value in zip(['alpha', 'mask'], [alpha, mask]):
        if key not in params_dict:
            params_dict[key] = []
        params_dict[key].append(value.detach())

    if self.debug and alpha_gt is not None:
        # attention mass on nodes with non-zero / zero ground-truth attention
        idx_correct_pool = (alpha_gt > 0)
        idx_correct_drop = (alpha_gt == 0)
        alpha_correct_pool = alpha[idx_correct_pool].sum(
        ) / N_nodes_float.sum()
        alpha_correct_drop = alpha[idx_correct_drop].sum(
        ) / N_nodes_float.sum()
        ratio_avg = (N_nodes_pooled.float() / N_nodes_float).mean()

        for key, values in zip([
                'alpha_correct_pool_debug', 'alpha_correct_drop_debug',
                'ratio_avg_debug'
        ], [alpha_correct_pool, alpha_correct_drop, ratio_avg]):
            if key not in params_dict:
                params_dict[key] = []
            params_dict[key].append(values.detach())

    return [x, A, mask, *data[3:]]
def distillation(y, teacher_scores, labels, T, alpha, ignore_index=255,
                 reduction='mean'):
    """Blend a temperature-scaled distillation term with hard-label CE.

    Args:
        y: student logits, softmax taken over ``dim=1``.
        teacher_scores: teacher logits with the same shape as ``y``.
        labels: hard class indices for the cross-entropy term.
        T: softmax temperature applied to both student and teacher logits.
        alpha: mixing weight; the KL term is scaled by ``T*T * 2 * alpha``
            and the cross-entropy term by ``1 - alpha``.
        ignore_index: label value skipped by the cross-entropy term
            (default 255, the value this function always used).
        reduction: reduction passed to ``F.kl_div``. The default ``'mean'``
            keeps the historical behaviour (average over every element);
            pass ``'batchmean'`` for the mathematically correct per-sample
            KL divergence.

    Returns:
        Scalar loss tensor.
    """
    student_log_probs = F.log_softmax(y/T, dim=1)
    teacher_probs = F.softmax(teacher_scores/T, dim=1)
    kl_div = F.kl_div(student_log_probs, teacher_probs, reduction=reduction)
    hard_loss = F.cross_entropy(y, labels, ignore_index=ignore_index)
    return kl_div * (T*T * 2. * alpha) + hard_loss * (1. - alpha)
def update(self):
    """Run ``self.ppo_steps`` policy/value updates with an adaptive KL penalty.

    Reads the rollout buffers ``self.rewards``, ``self.states``,
    ``self.actions``, ``self.saved_log_probs`` and ``self.value_list``,
    builds normalised discounted returns and advantages, and then repeatedly
    re-evaluates ``self.model`` on the stored states. The penalty
    coefficient ``self.beta`` is halved or doubled each step depending on how
    the measured KL term compares to ``self.kl_target``.

    Returns:
        float: mean over the PPO steps of ``policy_loss + value_loss``.
    """
    eps = np.finfo(np.float32).eps.item()

    def norm(a):
        # zero-mean / unit-std normalisation; eps guards against std == 0
        return (a - a.mean()) / (a.std() + eps)

    # discount reward, back to front:
    # R_i = r_i + GAMMA * R_{i+1}
    r_list = []
    R = 0
    for r in self.rewards[::-1]:
        R = r + self.gamma * R
        r_list.append(R)
    r_list = r_list[::-1]
    r_list = torch.tensor(r_list)
    r_list = norm(r_list).to(self.device)

    states = torch.cat(self.states)
    actions = torch.cat(self.actions).to(self.device)
    saved_log_probs = torch.cat(self.saved_log_probs).to(self.device)
    v_list = torch.cat(self.value_list).squeeze(-1)
    advantage = r_list - v_list
    advantage = norm(advantage)

    # Everything collected during the rollout is treated as constant data.
    states = states.detach()
    actions = actions.detach()
    log_prob_actions = saved_log_probs.detach()
    advantage = advantage.detach()
    r_list = r_list.detach()

    loss = 0
    for _ in range(self.ppo_steps):
        # get new log prob of actions for all input states
        action_pred, value_pred = self.model(states)
        value_pred = value_pred.squeeze(-1)
        action_prob = F.softmax(action_pred, dim=-1)
        dist = Categorical(action_prob)

        # new log prob using old actions
        new_log_prob_actions = dist.log_prob(actions)
        policy_ratio = (new_log_prob_actions - log_prob_actions).exp()
        policy_loss_1 = policy_ratio * advantage

        kl = F.kl_div(log_prob_actions, new_log_prob_actions.exp(), reduction='mean')
        # Adapt the penalty coefficient towards the KL target.
        if kl < self.kl_target/1.5:
            self.beta /= 2
        elif kl > self.kl_target*1.5:
            self.beta *= 2

        # KL-penalised surrogate (used here instead of the clipped surrogate).
        # The coefficient lives on the instance; a bare `beta` is undefined.
        policy_loss = -(policy_loss_1 - self.beta*kl).mean()
        # Prediction first, target second. The loss value is symmetric, but
        # this is the argument order the API documents.
        value_loss = F.smooth_l1_loss(value_pred, r_list).mean()

        loss += policy_loss.item() + value_loss.item()

        self.optimizer.zero_grad()
        # One backward over the summed loss yields the same gradients as two
        # separate calls and also works when actor and critic share layers.
        (policy_loss + value_loss).backward()
        self.optimizer.step()

    loss /= self.ppo_steps
    return loss
def klloss(pred, true):
    """Return ``F.kl_div`` of ``pred`` against ``true`` with default settings.

    ``pred`` is passed as the input of ``F.kl_div`` and ``true`` as its
    target; no reduction is specified, so the library default applies.
    """
    divergence = F.kl_div(input=pred, target=true)
    return divergence
def train(vocab, train_sets, dev_sets, test_sets, unlabeled_sets):
    """Train (or, with ``opt.test_only``, just evaluate) the adversarial model.

    train_sets, dev_sets, test_sets: dict[domain] -> AmazonDataset
    For unlabeled domains, no train_sets are available

    The model consists of a shared feature extractor ``F_s``, one private
    extractor per labeled domain (``F_d``), a sentiment classifier ``C`` and
    a domain discriminator ``D``. Each iteration first trains ``D`` for
    ``n_critic`` steps with the extractors frozen, then trains the
    extractors and ``C`` with ``D`` frozen.

    Returns:
        In test-only mode ``{'valid': acc, 'test': test_acc}``; otherwise the
        ``best_acc`` mapping holding the per-domain validation/test accuracy
        of the epoch with the best average validation accuracy.

    Raises:
        Exception: for an unknown ``opt.unlabeled_data`` or ``opt.model``.
    """
    # dataset loaders
    train_loaders, unlabeled_loaders = {}, {}
    train_iters, unlabeled_iters = {}, {}
    dev_loaders, test_loaders = {}, {}
    my_collate = utils.sorted_collate if opt.model == 'lstm' else utils.unsorted_collate
    for domain in opt.domains:
        train_loaders[domain] = DataLoader(train_sets[domain],
                                           opt.batch_size,
                                           shuffle=True,
                                           collate_fn=my_collate)
        train_iters[domain] = iter(train_loaders[domain])
    for domain in opt.dev_domains:
        dev_loaders[domain] = DataLoader(dev_sets[domain],
                                         opt.batch_size,
                                         shuffle=False,
                                         collate_fn=my_collate)
        test_loaders[domain] = DataLoader(test_sets[domain],
                                          opt.batch_size,
                                          shuffle=False,
                                          collate_fn=my_collate)
    for domain in opt.all_domains:
        if domain in opt.unlabeled_domains:
            uset = unlabeled_sets[domain]
        else:
            # for labeled domains, consider which data to use as unlabeled set
            if opt.unlabeled_data == 'both':
                uset = ConcatDataset(
                    [train_sets[domain], unlabeled_sets[domain]])
            elif opt.unlabeled_data == 'unlabeled':
                uset = unlabeled_sets[domain]
            elif opt.unlabeled_data == 'train':
                uset = train_sets[domain]
            else:
                raise Exception(
                    f'Unknown options for the unlabeled data usage: {opt.unlabeled_data}'
                )
        unlabeled_loaders[domain] = DataLoader(uset,
                                               opt.batch_size,
                                               shuffle=True,
                                               collate_fn=my_collate)
        unlabeled_iters[domain] = iter(unlabeled_loaders[domain])

    # model
    F_s = None
    F_d = {}
    C, D = None, None
    if opt.model.lower() == 'dan':
        F_s = DanFeatureExtractor(vocab, opt.F_layers, opt.shared_hidden_size,
                                  opt.sum_pooling, opt.dropout, opt.F_bn)
        for domain in opt.domains:
            F_d[domain] = DanFeatureExtractor(vocab, opt.F_layers,
                                              opt.domain_hidden_size,
                                              opt.sum_pooling, opt.dropout,
                                              opt.F_bn)
    elif opt.model.lower() == 'lstm':
        F_s = LSTMFeatureExtractor(vocab, opt.F_layers, opt.shared_hidden_size,
                                   opt.dropout, opt.bdrnn, opt.attn)
        for domain in opt.domains:
            F_d[domain] = LSTMFeatureExtractor(vocab, opt.F_layers,
                                               opt.domain_hidden_size,
                                               opt.dropout, opt.bdrnn,
                                               opt.attn)
    elif opt.model.lower() == 'cnn':
        F_s = CNNFeatureExtractor(vocab, opt.F_layers, opt.shared_hidden_size,
                                  opt.kernel_num, opt.kernel_sizes,
                                  opt.dropout)
        for domain in opt.domains:
            F_d[domain] = CNNFeatureExtractor(vocab, opt.F_layers,
                                              opt.domain_hidden_size,
                                              opt.kernel_num, opt.kernel_sizes,
                                              opt.dropout)
    else:
        raise Exception(f'Unknown model architecture {opt.model}')

    C = SentimentClassifier(opt.C_layers,
                            opt.shared_hidden_size + opt.domain_hidden_size,
                            opt.shared_hidden_size + opt.domain_hidden_size,
                            opt.num_labels, opt.dropout, opt.C_bn)
    D = DomainClassifier(opt.D_layers, opt.shared_hidden_size,
                         opt.shared_hidden_size, len(opt.all_domains),
                         opt.loss, opt.dropout, opt.D_bn)

    F_s, C, D = F_s.to(opt.device), C.to(opt.device), D.to(opt.device)
    for f_d in F_d.values():
        f_d = f_d.to(opt.device)
    # optimizers: one for extractors + classifier, one for the discriminator
    optimizer = optim.Adam(itertools.chain(
        *map(list, [F_s.parameters() if F_s else [],
                    C.parameters()] + [f.parameters() for f in F_d.values()])),
                           lr=opt.learning_rate)
    optimizerD = optim.Adam(D.parameters(), lr=opt.D_learning_rate)

    # testing
    if opt.test_only:
        log.info(f'Loading model from {opt.exp3_model_save_file}...')
        if F_s:
            F_s.load_state_dict(
                torch.load(
                    os.path.join(opt.exp3_model_save_file, 'netF_s.pth')))
        for domain in opt.all_domains:
            if domain in F_d:
                F_d[domain].load_state_dict(
                    torch.load(
                        os.path.join(opt.exp3_model_save_file,
                                     f'net_F_d_{domain}.pth')))
        C.load_state_dict(
            torch.load(os.path.join(opt.exp3_model_save_file, 'netC.pth')))
        D.load_state_dict(
            torch.load(os.path.join(opt.exp3_model_save_file, 'netD.pth')))

        log.info('Evaluating validation sets:')
        acc = {}
        # dev/test loaders exist only for opt.dev_domains (see above), and the
        # averages below are over opt.dev_domains, so iterate exactly those.
        for domain in opt.dev_domains:
            acc[domain] = evaluate(domain, dev_loaders[domain], F_s,
                                   F_d[domain] if domain in F_d else None, C)
        avg_acc = sum([acc[d] for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info(f'Average validation accuracy: {avg_acc}')
        log.info('Evaluating test sets:')
        test_acc = {}
        for domain in opt.dev_domains:
            test_acc[domain] = evaluate(domain, test_loaders[domain], F_s,
                                        F_d[domain] if domain in F_d else None,
                                        C)
        avg_test_acc = sum([test_acc[d] for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info(f'Average test accuracy: {avg_test_acc}')
        return {'valid': acc, 'test': test_acc}

    # training
    best_acc, best_avg_acc = defaultdict(float), 0.0
    for epoch in range(opt.max_epoch):
        F_s.train()
        C.train()
        D.train()
        for f in F_d.values():
            f.train()

        # training accuracy
        correct, total = defaultdict(int), defaultdict(int)
        # D accuracy
        d_correct, d_total = 0, 0
        # conceptually view 1 epoch as 1 epoch of the first domain
        num_iter = len(train_loaders[opt.domains[0]])
        for i in tqdm(range(num_iter)):
            # D iterations
            utils.freeze_net(F_s)
            # Explicit loop: map() is lazy in Python 3 and, unconsumed,
            # would never call freeze_net.
            for f_d in F_d.values():
                utils.freeze_net(f_d)
            utils.freeze_net(C)
            utils.unfreeze_net(D)
            # WGAN n_critic trick since D trains slower
            n_critic = opt.n_critic
            if opt.wgan_trick:
                if opt.n_critic > 0 and ((epoch == 0 and i < 25)
                                         or i % 500 == 0):
                    n_critic = 100

            for _ in range(n_critic):
                D.zero_grad()
                loss_d = {}
                # train on both labeled and unlabeled domains
                for domain in opt.all_domains:
                    # targets not used
                    d_inputs, _ = utils.endless_get_next_batch(
                        unlabeled_loaders, unlabeled_iters, domain)
                    d_targets = utils.get_domain_label(opt.loss, domain,
                                                       len(d_inputs[1]))
                    shared_feat = F_s(d_inputs)
                    d_outputs = D(shared_feat)
                    # D accuracy
                    _, pred = torch.max(d_outputs, 1)
                    d_total += len(d_inputs[1])
                    if opt.loss.lower() == 'l2':
                        _, tgt_indices = torch.max(d_targets, 1)
                        d_correct += (pred == tgt_indices).sum().item()
                        l_d = functional.mse_loss(d_outputs, d_targets)
                        l_d.backward()
                    else:
                        d_correct += (pred == d_targets).sum().item()
                        l_d = functional.nll_loss(d_outputs, d_targets)
                        l_d.backward()
                    loss_d[domain] = l_d.item()
                optimizerD.step()

            # F&C iteration
            utils.unfreeze_net(F_s)
            for f_d in F_d.values():
                utils.unfreeze_net(f_d)
            utils.unfreeze_net(C)
            utils.freeze_net(D)
            if opt.fix_emb:
                utils.freeze_net(F_s.word_emb)
                for f_d in F_d.values():
                    utils.freeze_net(f_d.word_emb)
            F_s.zero_grad()
            for f_d in F_d.values():
                f_d.zero_grad()
            C.zero_grad()
            for domain in opt.domains:
                inputs, targets = utils.endless_get_next_batch(
                    train_loaders, train_iters, domain)
                targets = targets.to(opt.device)
                shared_feat = F_s(inputs)
                domain_feat = F_d[domain](inputs)
                features = torch.cat((shared_feat, domain_feat), dim=1)
                c_outputs = C(features)
                l_c = functional.nll_loss(c_outputs, targets)
                l_c.backward(retain_graph=True)
                _, pred = torch.max(c_outputs, 1)
                total[domain] += targets.size(0)
                correct[domain] += (pred == targets).sum().item()
            # update F with D gradients on all domains
            for domain in opt.all_domains:
                d_inputs, _ = utils.endless_get_next_batch(
                    unlabeled_loaders, unlabeled_iters, domain)
                shared_feat = F_s(d_inputs)
                d_outputs = D(shared_feat)
                if opt.loss.lower() == 'gr':
                    d_targets = utils.get_domain_label(opt.loss, domain,
                                                       len(d_inputs[1]))
                    l_d = functional.nll_loss(d_outputs, d_targets)
                    if opt.lambd > 0:
                        # gradient reversal: the extractor maximises D's loss
                        l_d *= -opt.lambd
                elif opt.loss.lower() == 'bs':
                    d_targets = utils.get_random_domain_label(
                        opt.loss, len(d_inputs[1]))
                    # reduction='sum' is the supported spelling of the
                    # deprecated size_average=False
                    l_d = functional.kl_div(d_outputs,
                                            d_targets,
                                            reduction='sum')
                    if opt.lambd > 0:
                        l_d *= opt.lambd
                elif opt.loss.lower() == 'l2':
                    d_targets = utils.get_random_domain_label(
                        opt.loss, len(d_inputs[1]))
                    l_d = functional.mse_loss(d_outputs, d_targets)
                    if opt.lambd > 0:
                        l_d *= opt.lambd
                l_d.backward()

            optimizer.step()

        # end of epoch
        log.info('Ending epoch {}'.format(epoch + 1))
        if d_total > 0:
            log.info('D Training Accuracy: {}%'.format(100.0 * d_correct /
                                                       d_total))
        log.info('Training accuracy:')
        log.info('\t'.join(opt.domains))
        log.info('\t'.join(
            [str(100.0 * correct[d] / total[d]) for d in opt.domains]))
        log.info('Evaluating validation sets:')
        acc = {}
        for domain in opt.dev_domains:
            acc[domain] = evaluate(domain, dev_loaders[domain], F_s,
                                   F_d[domain] if domain in F_d else None, C)
        avg_acc = sum([acc[d] for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info(f'Average validation accuracy: {avg_acc}')
        log.info('Evaluating test sets:')
        test_acc = {}
        for domain in opt.dev_domains:
            test_acc[domain] = evaluate(domain, test_loaders[domain], F_s,
                                        F_d[domain] if domain in F_d else None,
                                        C)
        avg_test_acc = sum([test_acc[d] for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info(f'Average test accuracy: {avg_test_acc}')

        if avg_acc > best_avg_acc:
            # checkpoint everything whenever the validation average improves
            log.info(f'New best average validation accuracy: {avg_acc}')
            best_acc['valid'] = acc
            best_acc['test'] = test_acc
            best_avg_acc = avg_acc
            with open(os.path.join(opt.exp3_model_save_file, 'options.pkl'),
                      'wb') as ouf:
                pickle.dump(opt, ouf)
            torch.save(F_s.state_dict(),
                       '{}/netF_s.pth'.format(opt.exp3_model_save_file))
            for d in opt.domains:
                if d in F_d:
                    torch.save(
                        F_d[d].state_dict(),
                        '{}/net_F_d_{}.pth'.format(opt.exp3_model_save_file,
                                                   d))
            torch.save(C.state_dict(),
                       '{}/netC.pth'.format(opt.exp3_model_save_file))
            torch.save(D.state_dict(),
                       '{}/netD.pth'.format(opt.exp3_model_save_file))

    # end of training
    log.info(f'Best average validation accuracy: {best_avg_acc}')
    return best_acc
def train(self, epoch, triplet=False):
    """Run one pass over ``self.train_data_iterator`` and update the model.

    Per batch the loss is a cross-entropy over ``output[:, :end]`` plus,
    depending on ``triplet`` and ``self.args.KD``, a feature-anchor term
    (``loss_mr``) and a knowledge-distillation term against
    ``self.model_fixed``.

    Args:
        epoch: current epoch; printed, and compared against
            ``self.args.triplet_epoch`` to choose the anchor term.
        triplet: when true (and ``tasknum > 0``) the anchor-based term is
            added and the losses are summed without the ``lamb`` weighting.

    NOTE(review): this block was recovered from a whitespace-collapsed
    source; the nesting of the trailing weight-clipping statements (inside
    the batch loop) and the final ``print`` (after the loop) is the most
    plausible reading — confirm against the original file.
    """
    T = 2  # distillation temperature
    self.model.train()
    print("Epochs %d" % epoch)

    tasknum = self.train_data_iterator.dataset.t
    end = self.train_data_iterator.dataset.end
    start = end - self.args.step_size
    # fraction of old classes; weights KD vs. CE in the non-triplet case
    lamb = start / end

    self.training_target = torch.tensor([]).cuda()
    self.training_output = torch.tensor([]).cuda()
    self.training_idx = torch.tensor([]).cuda()

    for data, target, traindata_idx in tqdm(self.train_data_iterator):
        target = target.type(dtype=torch.long)
        data, target = data.cuda(), target.cuda()

        # batch positions of samples from earlier tasks / the current task
        old_idx = torch.where(target < start)[0]
        new_idx = torch.where(target >= start)[0]

        loss_KD = 0
        loss_CE = 0
        loss_triplet = 0
        loss_mr = 0

        if tasknum > 0 and triplet == True:
            if epoch > self.args.triplet_epoch:
                target_a = target
                target_b = self.class_corr_prob(target)
                output, a_feature = self.model(data, feature_return=True)

                p_feature = torch.zeros(
                    (data.size()[0], a_feature.size()[1])).cuda()
                n_feature = torch.zeros(
                    (data.size()[0], a_feature.size()[1])).cuda()
                # positive anchor: the stored anchor of each sample's own class
                for i in range(data.size()[0]):
                    p_feature[i] = self.class_anchor[target_a[i]]
                p_target = torch.ones(data.size()[0], 1).cuda()
                # norm of the whole difference tensor, scaled by triplet_lam
                loss_mr = torch.mean(torch.norm(
                    (a_feature - p_feature))) * self.args.triplet_lam
                # loss_CE = self.loss(output[:, :end], target) + loss_triplet
                loss_CE = self.loss(output[:, :end], target)
            else:
                output, a_feature = self.model(data, feature_return=True)
                target_a = target
                target_b = self.class_corr_prob(target)

                # negative anchors are built for the current-task samples only
                n_feature = torch.zeros((data[new_idx, :].size()[0],
                                         a_feature.size()[1])).cuda()
                for i in range(data[new_idx, :].size()[0]):
                    n_feature[i] = self.class_anchor[target_b[new_idx][i]]
                n_target = torch.ones(data[new_idx, :].size()[0], 1).cuda()
                n_target = n_target * -1
                dist = torch.norm((a_feature[new_idx, :] - n_feature), dim=1)
                loss_mr = self.mr_loss(dist, n_target)
                loss_CE = self.loss(output[:, :end], target)
        else:
            output, a_feature = self.model(data, feature_return=True)
            loss_CE = self.loss(output[:, :end], target)
            loss_triplet = 0

        _, predicted = output.max(1)
        prob = F.softmax(output, dim=1).cpu()

        if tasknum > 0:
            self.model_fixed.eval()
            end_KD = start
            start_KD = end_KD - self.args.step_size
            # logits of the frozen previous model over the old classes
            score = self.model_fixed(data)[:, :end_KD].data

            if self.args.KD == "naive_global":
                soft_target = F.softmax(score / T, dim=1)
                output_log = F.log_softmax(output[:, :end_KD] / T, dim=1)
                loss_KD = F.kl_div(output_log, soft_target,
                                   reduction='batchmean')
            elif self.args.KD == "naive_local":
                # one KL term per previous task, softmax restricted to its classes
                local_KD = torch.zeros(tasknum).cuda()
                for t in range(tasknum):
                    local_start_KD = (t) * self.args.step_size
                    local_end_KD = (t + 1) * self.args.step_size

                    soft_target = F.softmax(
                        score[:, local_start_KD:local_end_KD] / T, dim=1)
                    output_log = F.log_softmax(
                        output[:, local_start_KD:local_end_KD] / T, dim=1)
                    local_KD[t] = F.kl_div(output_log, soft_target,
                                           reduction="batchmean")
                loss_KD = local_KD.sum()

        self.optimizer.zero_grad()
        if (self.args.KD == "PCKD") or (self.args.KD == "naive_global") or (
                self.args.KD == "naive_local"):
            if triplet:
                # (lamb * loss_KD + (1 - lamb) * loss_CE + loss_triplet).backward()
                (loss_KD + loss_CE + loss_triplet + loss_mr).backward()
            else:
                (lamb * loss_KD + (1 - lamb) * loss_CE).backward()
        elif self.args.KD == "No":
            loss_CE.backward()
        # NOTE(review): for any other self.args.KD value no backward() runs
        # before this step.
        self.optimizer.step()

        # weight clipping: zero out negative classifier weights and reset the bias
        weight = self.model.fc.weight.data
        weight[weight < 0] = 0
        self.model.fc.bias.data[:] = 0

    print(self.train_data_iterator.dataset.__len__())
def train(net, dataloader, optimizer, epoch):
    """Train ``net`` for one epoch, optionally with self-distillation.

    From epoch ``args.distill_from`` on (and when ``args.distill > 0``) the
    network's own eval-mode predictions are used as soft targets: a
    temperature-scaled KL term, weighted by ``args.distill`` and
    ``args.temp ** 2``, is added to the cross-entropy loss.

    Progress is written to stdout after every batch, and a summary line is
    written to ``log_file`` at the end of the epoch. The process exits if the
    loss becomes NaN.

    Args:
        net: model to train.
        dataloader: iterable of ``(inputs, targets)`` batches.
        optimizer: optimizer stepping ``net``'s parameters.
        epoch: current epoch number (used for logging and the distillation
            start condition).
    """
    criterion = nn.CrossEntropyLoss()
    net.train()
    hard_loss_sum = 0
    soft_loss_sum = 0
    loss_sum = 0
    correct = 0
    total = 0

    # Loop-invariant: whether the distillation term is active this epoch.
    use_distill = epoch >= args.distill_from and args.distill > 0

    print('\n=> [%s] Training Epoch #%d, lr=%.4f' %
          (model_name, epoch, cf.learning_rate(args.lr, epoch)))
    log_file.write('\n=> [%s] Training Epoch #%d, lr=%.4f\n' %
                   (model_name, epoch, cf.learning_rate(args.lr, epoch)))
    for batch_idx, (inputs, targets) in enumerate(dataloader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        # obtain soft_target by forwarding data in test mode
        if use_distill:
            with torch.no_grad():
                net.eval()
                try:
                    soft_target = net(inputs)
                finally:
                    # never leave the model stuck in eval mode
                    net.train()

        optimizer.zero_grad()
        outputs = net(inputs)  # Forward Propagation
        loss = criterion(outputs, targets)  # Loss
        hard_loss_sum = hard_loss_sum + loss.item() * targets.size(0)

        # compute distillation loss
        if use_distill:
            heat_output = outputs / args.temp
            heat_soft_target = soft_target / args.temp
            # Explicit dim=1 matches the log_softmax call. reduction='sum'
            # replaces the deprecated size_average=False.
            distill_loss = F.kl_div(F.log_softmax(heat_output, 1),
                                    F.softmax(heat_soft_target, dim=1),
                                    reduction='sum') / targets.size(0)
            soft_loss_sum = soft_loss_sum + distill_loss.item() * targets.size(
                0)
            distill_loss = distill_loss * (args.temp * args.temp)
            loss = loss + args.distill * distill_loss

        loss.backward()  # Backward Propagation
        optimizer.step()  # Optimizer update

        loss_sum = loss_sum + loss.item() * targets.size(0)

        _, predicted = torch.max(outputs.detach(), 1)
        total += targets.size(0)
        correct += predicted.eq(targets.detach()).long().sum().item()

        if math.isnan(loss.item()):
            print('@@@@@@@nan@@@@@@@@@@@@')
            log_file.write('@@@@@@@@@@@nan @@@@@@@@@@@@@\n')
            sys.exit(0)

        sys.stdout.write('\r')
        sys.stdout.write(
            '| Epoch [%3d/%3d] Iter[%3d/%3d]\tLoss: %.4g Acc@1: %.2f%% Hard: %.4g Soft: %.4g'
            % (epoch, args.num_epochs, batch_idx + 1,
               (len(trainset) // args.bs) + 1, loss_sum / total, 100. *
               correct / total, hard_loss_sum / total, soft_loss_sum / total))
        sys.stdout.flush()

    log_file.write(
        '| Epoch [%3d/%3d] \tLoss: %.4f Acc@1: %.2f%% Hard: %.4f Soft: %.8f' %
        (epoch, args.num_epochs, loss_sum / total, 100. * correct / total,
         hard_loss_sum / total, soft_loss_sum / total))
import numpy as np
import math
from copy import deepcopy
import pdb
import torch
import torch.nn as nn
import torch.nn.functional as F
import utils
from utils import get_grad_vector, get_future_step_parameters, add_grad, get_weight_accumelated_gradient_norm, weight_gradient_norm, get_future_step_parameters_with_grads, get_nearbysamples, get_mask_unused_memories, get_weight_norm_diff
from VAE.loss import calculate_loss

#----------
# Functions

# KL divergence between softmax(t_s) (target) and softmax(y) (input): the
# element-wise mean returned by F.kl_div is rescaled by the size of y's first
# dimension.
dist_kl = lambda y, t_s: F.kl_div(F.log_softmax(y, dim=-1),
                                  F.softmax(t_s, dim=-1),
                                  reduction='mean') * y.size(0)
# this returns -entropy
entropy_fn = lambda x: torch.sum(
    F.softmax(x, dim=-1) * F.log_softmax(x, dim=-1), dim=-1)

# soft-label cross entropy of softmax(y) against softmax(t_s), averaged over
# the remaining dimensions
cross_entropy = lambda y, t_s: -torch.sum(
    F.log_softmax(y, dim=-1) * F.softmax(t_s, dim=-1), dim=-1).mean()
mse = torch.nn.MSELoss()


# NOTE(review): only the first statement of this function is visible here; the
# body looks truncated, so its purpose and return value cannot be documented.
def retrieve_gen_for_cls(args, x, cls, prev_cls, prev_gen):
    grad_vector = get_grad_vector(args, cls.parameters, cls.grad_dims)
def train(net, train_loader, optimizer):
    """Run a single training epoch and return smoothed statistics.

    With ``args.no_jsd`` set, plain cross-entropy is used. Otherwise each
    batch carries three views (clean and two augmented) and a consistency
    term between the views and their mixture is added to the clean
    cross-entropy.

    Returns:
        Tuple ``(loss_ema, acc1_ema, batch_ema)`` of moving averages.
    """
    net.train()

    def _blend(previous, current):
        # moving-average update shared by all tracked statistics
        return previous * 0.1 + float(current) * 0.9

    data_ema = 0.
    batch_ema = 0.
    loss_ema = 0.
    acc1_ema = 0.
    acc5_ema = 0.

    tick = time.time()
    for step, (images, targets) in enumerate(train_loader):
        # Time spent waiting for the loader
        data_time = time.time() - tick
        optimizer.zero_grad()

        if not args.no_jsd:
            views = torch.cat(images, 0).cuda()
            targets = targets.cuda()
            all_logits = net(views)
            logits_clean, logits_aug1, logits_aug2 = torch.split(
                all_logits, images[0].size(0))

            # Only the clean view contributes to the cross-entropy
            loss = F.cross_entropy(logits_clean, targets)

            p_clean = F.softmax(logits_clean, dim=1)
            p_aug1 = F.softmax(logits_aug1, dim=1)
            p_aug2 = F.softmax(logits_aug2, dim=1)

            # Clamp the mixture before the log so the KL terms cannot explode
            log_mixture = torch.clamp((p_clean + p_aug1 + p_aug2) / 3., 1e-7,
                                      1).log()
            kl_clean = F.kl_div(log_mixture, p_clean, reduction='batchmean')
            kl_aug1 = F.kl_div(log_mixture, p_aug1, reduction='batchmean')
            kl_aug2 = F.kl_div(log_mixture, p_aug2, reduction='batchmean')
            loss += 12 * (kl_clean + kl_aug1 + kl_aug2) / 3.
            acc1, acc5 = accuracy(logits_clean, targets, topk=(1, 5))  # pylint: disable=unbalanced-tuple-unpacking
        else:
            images = images.cuda()
            targets = targets.cuda()
            logits = net(images)
            loss = F.cross_entropy(logits, targets)
            acc1, acc5 = accuracy(logits, targets, topk=(1, 5))  # pylint: disable=unbalanced-tuple-unpacking

        loss.backward()
        optimizer.step()

        # Full step time, then restart the clock for the next batch
        batch_time = time.time() - tick
        tick = time.time()

        data_ema = _blend(data_ema, data_time)
        batch_ema = _blend(batch_ema, batch_time)
        loss_ema = _blend(loss_ema, loss)
        acc1_ema = _blend(acc1_ema, acc1)
        acc5_ema = _blend(acc5_ema, acc5)

        if step % args.print_freq == 0:
            print(
                'Batch {}/{}: Data Time {:.3f} | Batch Time {:.3f} | Train Loss {:.3f} | Train Acc1 '
                '{:.3f} | Train Acc5 {:.3f}'.format(step, len(train_loader),
                                                    data_ema, batch_ema,
                                                    loss_ema, acc1_ema,
                                                    acc5_ema))

    return loss_ema, acc1_ema, batch_ema
def _kl_div(logit1, logit2): return F.kl_div(F.log_softmax(logit1, dim=1), F.softmax(logit2, dim=1), reduction='batchmean')
def forward(self, input, target):
    """Summed KL divergence between the smoothed labels and the predictions.

    ``input`` is converted to log-probabilities along ``dim=1``; the target
    distribution comes from ``self._smooth_labels(num_classes, target)``.
    """
    num_classes = input.size(1)
    log_probs = F.log_softmax(input, dim=1)
    smoothed_target = self._smooth_labels(num_classes, target)
    return F.kl_div(log_probs, smoothed_target, reduction='sum')
def _kl_div(log_probs, probs): # pytorch KLDLoss is averaged over all dim if size_average=True kld = F.kl_div(log_probs, probs, size_average=False) return kld / log_probs.shape[0]
def beam_sample(self, src, src_len, dict_spk2idx, tgt, beam_size=1):
    """Decode ``src`` with beam search and run the separation model on the result.

    Steps: (1) encode ``src``; (2) run the decoder step by step for at most
    ``self.config.max_tgt_len`` steps, advancing one ``models.Beam`` per batch
    element; (3) collect the best hypothesis of every beam and feed the
    collected hidden states / embeddings to ``self.ss_model``.

    Returns:
        ``(allHyps, allAttn, allHiddens, predicted_maps)``.

    NOTE(review): Python 2 code (``print`` statement, ``volatile`` Variables).
    The block was recovered from a whitespace-collapsed source, so the nesting
    of a few statements (the ``print allHyps`` and the ``top1`` branch) is the
    most plausible reading — confirm against the original file.
    """
    src = src.transpose(0, 1)
    #beam_size = self.config.beam_size
    batch_size = src.size(0)

    # (1) Run the encoder on the src. Done!!!!
    if self.use_cuda:
        src = src.cuda()
        src_len = src_len.cuda()

    lengths, indices = torch.sort(src_len, dim=0, descending=True)
    # _, ind = torch.sort(indices)
    # src = Variable(torch.index_select(src, dim=1, index=indices), volatile=True)
    contexts, encState = self.encoder(src, lengths.data.cpu().numpy()[0])

    # (1b) Initialize for the decoder.
    def var(a):
        # wrap a tensor for inference (no gradient tracking)
        return Variable(a, volatile=True)

    def rvar(a):
        # repeat along dim 1 beam_size times, then wrap
        return var(a.repeat(1, beam_size, 1))

    def bottle(m):
        return m.view(batch_size * beam_size, -1)

    def unbottle(m):
        return m.view(beam_size, batch_size, -1)

    # Repeat everything beam_size times.
    contexts = rvar(contexts.data).transpose(0, 1)
    decState = (rvar(encState[0].data), rvar(encState[1].data))
    #decState.repeat_beam_size_times(beam_size)
    beam = [
        models.Beam(beam_size, dict_spk2idx, n_best=1, cuda=self.use_cuda)
        for __ in range(batch_size)
    ]

    # (2) run the decoder to generate sentences, using beam search.

    mask = None
    soft_score = None
    tmp_hiddens = []
    tmp_soft_score = []
    for i in range(self.config.max_tgt_len):

        if all((b.done() for b in beam)):
            break

        # Construct batch x beam_size nxt words.
        # Get all the pending current beam words and arrange for forward.
        inp = var(
            torch.stack([b.getCurrentState()
                         for b in beam]).t().contiguous().view(-1))

        if self.config.schmidt and i > 0:
            assert len(beam[0].sch_hiddens[-1]) == i
            tmp_hiddens = []
            for xxx in range(i):  # the list of hiddens before each sch step
                one_len = []
                for bm_idx in range(beam_size):
                    for bs_idx in range(batch_size):
                        one_len.append(
                            beam[bs_idx].sch_hiddens[-1][xxx][bm_idx, :])
                tmp_hiddens.append(var(torch.stack(one_len)))

        # Run one step.
        output, decState, attn, hidden, emb = self.decoder.sample_one(
            inp, soft_score, decState, tmp_hiddens, contexts, mask)
        # print "sample after decState:",decState[0].data.cpu().numpy().mean()
        if self.config.schmidt:
            tmp_hiddens += [hidden]
        if self.config.ct_recu:
            # zero the context rows whose attention exceeded 0.003
            contexts = (1 -
                        (attn > 0.003).float()).unsqueeze(-1) * contexts
        soft_score = F.softmax(output)
        if self.config.tmp_score:
            tmp_soft_score += [soft_score]
            if i == 1:
                # per-beam kl_div between the step-1 and step-0 scores
                kl_loss = np.array([])
                for kk in range(self.config.beam_size):
                    kl_loss = np.append(
                        kl_loss,
                        F.kl_div(soft_score[kk],
                                 tmp_soft_score[0][kk]).data[0])
                kl_loss = Variable(
                    torch.from_numpy(kl_loss).float().cuda().unsqueeze(-1))
        predicted = output.max(1)[1]
        if self.config.mask:
            if mask is None:
                mask = predicted.unsqueeze(1).long()
            else:
                mask = torch.cat((mask, predicted.unsqueeze(1)), 1)
        # decOut: beam x rnn_size

        # (b) Compute a vector of batch*beam word scores.
        if self.config.tmp_score and i == 1:
            output = unbottle(
                self.log_softmax(output) + self.config.tmp_score * kl_loss)
        else:
            output = unbottle(self.log_softmax(output))
        attn = unbottle(attn)
        hidden = unbottle(hidden)
        emb = unbottle(emb)
        # beam x tgt_vocab

        # (c) Advance each beam.
        # update state
        for j, b in enumerate(beam):
            b.advance(output.data[:, j], attn.data[:, j], hidden.data[:, j],
                      emb.data[:, j])
            b.beam_update(decState, j)  # this call updates the original decState in place (by direct assignment, not via a return value)!
            if self.config.ct_recu:
                b.beam_update_context(
                    contexts, j)  # this call likewise updates its argument in place (by direct assignment, not via a return value)!
        # print "beam after decState:",decState[0].data.cpu().numpy().mean()

    # (3) Package everything up.
    allHyps, allScores, allAttn, allHiddens, allEmbs = [], [], [], [], []

    ind = range(batch_size)
    for j in ind:
        b = beam[j]
        n_best = 1
        scores, ks = b.sortFinished(minimum=n_best)
        hyps, attn, hiddens, embs = [], [], [], []
        for i, (times, k) in enumerate(ks[:n_best]):
            hyp, att, hidden, emb = b.getHyp(times, k)
            if self.config.relitu:
                relitu_line(626, 1, att[0].cpu().numpy())
                relitu_line(626, 1, att[1].cpu().numpy())
            hyps.append(hyp)
            attn.append(att.max(1)[1])
            hiddens.append(hidden)
            embs.append(emb)
        # only the top hypothesis of each beam is kept
        allHyps.append(hyps[0])
        allScores.append(scores[0])
        allAttn.append(attn[0])
        allHiddens.append(hiddens[0])
        allEmbs.append(embs[0])

    print allHyps
    if not self.config.global_emb:
        outputs = Variable(torch.stack(allHiddens, 0).transpose(
            0, 1))  # to [decLen, bs, dim]
        if not self.config.hidden_mix:
            predicted_maps = self.ss_model(src, outputs[:-1, :], tgt[1:-1])
        else:
            ss_embs = Variable(torch.stack(allEmbs, 0).transpose(
                0, 1))  # to [decLen, bs, dim]
            mix = torch.cat((outputs[:-1, :], ss_embs[1:]), dim=2)
            predicted_maps = self.ss_model(src, mix, tgt[1:-1])
        if self.config.top1:
            predicted_maps = predicted_maps[:, 0].unsqueeze(1)
    else:
        ss_embs = Variable(torch.stack(allEmbs, 0).transpose(
            0, 1))  # to [decLen, bs, dim]
        if not self.config.top1:
            predicted_maps = self.ss_model(src, ss_embs[1:, :], tgt[1:-1],
                                           dict_spk2idx)
        else:
            predicted_maps = self.ss_model(src, ss_embs[1:2], tgt[1:2])
    return allHyps, allAttn, allHiddens, predicted_maps  #.transpose(0,1)
def train_sdcn(dataset):
    """Train an SDCN model on ``dataset`` for 200 epochs.

    Uses the module-level ``args`` (``name``, ``k``, ``n_input``, ``n_z``,
    ``n_clusters``, ``lr``) and ``device``.  The cluster layer is initialised
    from K-means centres computed on the auto-encoder embedding, then the
    model is optimised on
    ``0.1 * KL(p || q) + 0.01 * KL(p || pred) + MSE(x_bar, data)``.

    Args:
        dataset: object exposing ``x`` (features) and ``y`` (labels used only
            for evaluation through ``eva``).
    """
    # KNN graph
    adj = load_graph(args.name, args.k)
    adj = adj.cuda()

    model = SDCN(dataset=dataset,
                 n_enc_1=500,
                 n_enc_2=500,
                 n_enc_3=2000,
                 n_dec_1=2000,
                 n_dec_2=500,
                 n_dec_3=500,
                 n_input=args.n_input,
                 n_z=args.n_z,
                 n_clusters=args.n_clusters,
                 v=1.0,
                 adj=adj).to(device)

    # Only 'cite' is trained with weight decay.  Every other dataset
    # (including 'dblp') uses plain Adam; previously any name other than
    # 'cite'/'dblp' left ``optimizer`` unbound and crashed in the loop.
    if args.name == 'cite':
        optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=5e-4)
    else:
        optimizer = Adam(model.parameters(), lr=args.lr)

    # Initialise the cluster centres from K-means on the AE embedding.
    data = torch.Tensor(dataset.x).to(device)
    y = dataset.y
    with torch.no_grad():
        _, _, _, _, z = model.ae(data)

    kmeans = KMeans(n_clusters=args.n_clusters, n_init=20)
    y_pred = kmeans.fit_predict(z.data.cpu().numpy())
    model.cluster_layer.data = torch.tensor(kmeans.cluster_centers_).to(device)
    eva(y, y_pred, 'pae')

    for epoch in range(200):
        # Evaluation / target-distribution pass.  Its outputs are only used
        # detached, so no autograd graph is needed.
        # q is the t-distribution soft assignment; pred is the post-softmax
        # class prediction (per the original comments).
        with torch.no_grad():
            _, tmp_q, pred, _, adj_re = model(data, adj)
        tmp_q = tmp_q.data
        p = target_distribution(tmp_q)

        predicted = pred.data.cpu().numpy().argmax(1)
        eva(y, predicted, str(epoch) + 'Z')

        # Training pass.  x_bar is the decoder reconstruction.
        x_bar, q, pred, _, adj_re = model(data, adj)

        kl_loss = F.kl_div(q.log(), p, reduction='batchmean')
        ce_loss = F.kl_div(pred.log(), p, reduction='batchmean')
        re_loss = F.mse_loss(x_bar, data)

        loss = 0.1 * kl_loss + 0.01 * ce_loss + re_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
def kdloss(y, teacher_scores):
    """Return the batch-averaged KL divergence from teacher to student.

    Args:
        y: student logits; the softmax is taken over dim 1.
        teacher_scores: teacher logits with the same shape as ``y``.

    Returns:
        Scalar tensor: ``KL(softmax(teacher) || softmax(student))`` summed
        over all elements and divided by ``y.shape[0]``.
    """
    student_log_probs = F.log_softmax(y, dim=1)
    teacher_probs = F.softmax(teacher_scores, dim=1)
    # reduction='sum' is the supported spelling of the deprecated
    # size_average=False (same value, no deprecation warning).
    return F.kl_div(student_log_probs, teacher_probs, reduction='sum') / y.shape[0]
def distillation(y, teacher_scores, labels, T, alpha):
    """Blend a temperature-softened KL term with the hard-label loss.

    Args:
        y: student logits (softmax over dim 1).
        teacher_scores: teacher logits, same shape as ``y``.
        labels: class indices for ``F.cross_entropy``.
        T: softmax temperature applied to both sets of logits.
        alpha: weight of the KL term; the cross-entropy gets ``1 - alpha``.

    Returns:
        Scalar tensor ``alpha * KL * T**2 / batch + (1 - alpha) * CE``.
    """
    batch = y.shape[0]
    student_log_probs = F.log_softmax(y / T, dim=1)
    teacher_probs = F.softmax(teacher_scores / T, dim=1)

    soft_term = F.kl_div(student_log_probs, teacher_probs, reduction='sum')
    soft_term = soft_term * (T**2) / batch
    hard_term = F.cross_entropy(y, labels)

    return soft_term * alpha + hard_term * (1. - alpha)
def reg_crit(self, x, y):
    """Summed KL term between ``y`` (target) and ``x``.

    ``x`` is passed through ``log`` before being handed to ``F.kl_div``,
    so it is expected to hold probabilities rather than log-probabilities.
    """
    log_x = x.log()
    return F.kl_div(log_x, y, reduction="sum")
def main():
    """Run the step-based training loop, with optional mix-up on unlabeled data.

    Everything except the local meters comes from module scope: ``args``,
    ``model``, ``optimizer``, ``train_loader`` (consumed with ``next``),
    ``test_loader``, ``beta_distribution``, ``compute_lr``, ``accuracy``,
    ``evaluate``, ``save_checkpoint``, ``logger`` and ``writer``.

    NOTE(review): the source reached review with its indentation collapsed;
    the nesting of the "Test and save model" tail below was reconstructed
    (``val_acc`` is only defined inside that branch) -- confirm.
    """
    data_times, batch_times, losses, acc = [AverageMeter() for _ in range(4)]
    if args.mix_up:
        unlabel_acc = AverageMeter()
    best_acc = 0.
    model.train()
    logger.info("Start training...")
    for step in range(args.start_step, args.total_steps):
        # Load data and distribute to devices
        data_start = time.time()
        if args.mix_up:
            label_img, label_gt, unlabel_img, unlabel_gt = next(train_loader)
        else:
            label_img, label_gt = next(train_loader)

        if args.gpu:
            label_img = label_img.cuda()
            label_gt = label_gt.cuda()
        if args.mix_up:
            if args.gpu:
                unlabel_img = unlabel_img.cuda()
                unlabel_gt = unlabel_gt.cuda()
            # One-hot labels so they can be interpolated with soft predictions.
            _label_gt = F.one_hot(label_gt, num_classes=args.num_classes).float()
        data_end = time.time()

        # Compute learning rate
        lr = compute_lr(step)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        if args.mix_up:
            # Adopt mix-up augmentation
            # Predictions used as (pseudo) targets are computed in eval mode
            # without gradients.
            model.eval()
            with torch.no_grad():
                label_pred = model(label_img)
                unlabel_pred = F.softmax(model(unlabel_img), dim=1)
            alpha = beta_distribution.sample((args.batch_size,))
            if args.gpu:
                alpha = alpha.cuda()
            # assumes images are 4-D so alpha broadcasts per sample -- TODO confirm
            _alpha = alpha.view(-1, 1, 1, 1)
            interp_img = (label_img * _alpha + unlabel_img * (1. - _alpha)).detach()
            # NOTE(review): this product only broadcasts per sample if
            # ``alpha`` is (batch, 1) -- confirm the shape produced by
            # ``beta_distribution``.
            interp_pseudo_gt = (_label_gt * alpha + unlabel_pred * (1. - alpha)).detach()
            model.train()
            interp_pred = model(interp_img)
            loss = F.kl_div(F.log_softmax(interp_pred, dim=1), interp_pseudo_gt, reduction='batchmean')
        else:
            # Regular label loss
            label_pred = model(label_img)
            loss = F.cross_entropy(label_pred, label_gt, reduction='mean')

        # One SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Compute accuracy
        # (in mix-up mode ``label_pred`` is the eval-mode prediction from above)
        top1, = accuracy(label_pred, label_gt, topk=(1,))
        if args.mix_up:
            unlabel_top1, = accuracy(unlabel_pred, unlabel_gt, topk=(1,))

        # Update AverageMeter stats
        data_times.update(data_end - data_start)
        batch_times.update(time.time() - data_end)
        losses.update(loss.item(), label_img.size(0))
        acc.update(top1.item(), label_img.size(0))
        if args.mix_up:
            unlabel_acc.update(unlabel_top1.item(), unlabel_img.size(0))

        # Print and log
        if step % args.print_freq == 0:
            logger.info("Step: [{0:05d}/{1:05d}] Dtime: {dtimes.avg:.3f} Btime: {btimes.avg:.3f} "
                        "loss: {losses.val:.3f} (avg {losses.avg:.3f}) Lacc: {label.val:.3f} (avg {label.avg:.3f}) "
                        "LR: {2:.4f}".format(step, args.total_steps, optimizer.param_groups[0]['lr'],
                                             dtimes=data_times, btimes=batch_times, losses=losses, label=acc))

        # Test and save model
        if (step + 1) % args.test_freq == 0 or step == args.total_steps - 1:
            val_acc = evaluate(test_loader, model)

            # Remember best accuracy and save checkpoint
            is_best = val_acc > best_acc
            if is_best:
                best_acc = val_acc
            logger.info("Best Accuracy: %.5f" % best_acc)
            save_checkpoint({
                'step': step + 1,
                'model': model.state_dict(),
                'best_acc': best_acc,
                'optimizer' : optimizer.state_dict()
            }, is_best, path=args.save_path, filename="checkpoint.pth")

            # Write to tfboard
            writer.add_scalar('train/label-acc', top1.item(), step)
            if args.mix_up:
                writer.add_scalar('train/unlabel-acc', unlabel_top1.item(), step)
            writer.add_scalar('train/loss', loss.item(), step)
            writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], step)
            writer.add_scalar('test/accuracy', val_acc, step)

            # Reset AverageMeters
            # (the timing meters are not reset here)
            losses.reset()
            acc.reset()
            if args.mix_up:
                unlabel_acc.reset()
def get_q_values(self, batch):
    """Return per-word Q values for the taken actions and KL-control rewards.

    The update treats states as whole conversations, each made of several
    sentences, and actions as one sentence (a series of words).  Q values
    are per word.  Target Q values are over the next word in the sentence,
    or, if at the end of the sentence, the first word in a new sentence
    after the user response.

    Args:
        batch: dict with keys ``'state'``, ``'state_lens'``, ``'action'``,
            ``'action_lens'`` and, when ``config.model_averaging`` is set,
            ``'model_averaged_probs'``.

    Returns:
        ``(q_values, prior_rewards)``.  ``q_values`` is ``[total words]``;
        ``prior_rewards`` is ``None`` unless model averaging or KL control
        is enabled.

    Side effects:
        Appends to ``kl_div_batch_history``, ``logp_batch_history``,
        ``logp_logq_batch_history`` and, under KL control,
        ``kl_reward_batch_history``.
    """
    action_lens = batch['action_lens']
    actions = to_var(torch.LongTensor(batch['action']))  # [batch_size]

    # Prepare inputs to Q network: append the action sentence to each
    # conversation (and its length to the sentence lengths).
    conversations = [
        np.concatenate((conv, np.atleast_2d(batch['action'][i])))
        for i, conv in enumerate(batch['state'])
    ]
    sent_lens = [
        np.concatenate((lens, np.atleast_1d(action_lens[i])))
        for i, lens in enumerate(batch['state_lens'])
    ]
    target_conversations = [conv[1:] for conv in conversations]
    conv_lens = [len(conv) - 1 for conv in conversations]
    if self.config.model not in VariationalModels:
        conversations = [conv[:-1] for conv in conversations]
        sent_lens = np.concatenate([lens[:-1] for lens in sent_lens])
    else:
        sent_lens = np.concatenate(sent_lens)
    conv_lens = to_var(torch.LongTensor(conv_lens))

    # Run Q network. Will produce [num_sentences, max sent len, vocab size]
    all_q_values = self.run_seq2seq_model(self.q_net, conversations,
                                          sent_lens, target_conversations,
                                          conv_lens)

    # Row index of the last sentence (the action taken) of each
    # conversation inside the flattened sentence axis.  Computed once and
    # reused for the prior network below.
    start_q = torch.cumsum(
        torch.cat((to_var(conv_lens.data.new(1).zero_()), conv_lens[:-1])),
        0)
    last_sentence_rows = [
        start + length - 1
        for start, length in zip(start_q.data.tolist(),
                                 conv_lens.data.tolist())
    ]

    # [num_sentences, max_sent_len, vocab_size]
    conv_q_values = torch.stack(
        [all_q_values[row, :, :] for row in last_sentence_rows], 0)

    # Limit by actual sentence length (remove padding) and flatten into
    # long list of words
    word_q_values = torch.cat(
        [conv_q_values[i, :n, :] for i, n in enumerate(action_lens)],
        0)  # [total words, vocab_size]
    word_actions = torch.cat(
        [actions[i, :n] for i, n in enumerate(action_lens)],
        0)  # [total words]

    # Extract q values corresponding to actions taken.  squeeze(1) rather
    # than squeeze() so a batch holding a single word stays 1-D.
    action_index = word_actions.unsqueeze(1)
    q_values = word_q_values.gather(1, action_index).squeeze(1)  # [total words]

    # ---- Compute KL metrics ----
    prior_rewards = None

    # Get probabilities from policy network
    q_dists = F.softmax(word_q_values, 1)
    q_probs = q_dists.gather(1, action_index).squeeze(1)

    # NOTE(review): the extent of this no_grad block was reconstructed from
    # a whitespace-collapsed source; rewards and metrics are kept inside it
    # so they never carry gradients -- confirm.
    with torch.no_grad():
        # Run pretrained prior network.
        # [num_sentences, max sent len, vocab size]
        all_prior_logits = self.run_seq2seq_model(self.pretrained_prior,
                                                  conversations, sent_lens,
                                                  target_conversations,
                                                  conv_lens)

        # Get relevant actions. [num_sentences, max_sent_len, vocab_size]
        conv_prior = torch.stack(
            [all_prior_logits[row, :, :] for row in last_sentence_rows], 0)

        # Limit by actual sentence length (remove padding) and flatten.
        # [total words, vocab_size]
        word_prior_logits = torch.cat(
            [conv_prior[i, :n, :] for i, n in enumerate(action_lens)], 0)

        # Take the softmax
        prior_dists = F.softmax(word_prior_logits, 1)

        # reduction='none' replaces the deprecated reduce=False.
        kl_div = F.kl_div(q_dists.log(), prior_dists, reduction='none')

        # [total words]
        prior_probs = prior_dists.gather(1, action_index).squeeze(1)
        logp_logq = prior_probs.log() - q_probs.log()

        if self.config.model_averaging:
            model_avg_sentences = batch['model_averaged_probs']

            # Convert to tensors and flatten into [num_words]
            word_model_avg = torch.cat(
                [to_var(torch.FloatTensor(m)) for m in model_avg_sentences],
                0)

            # Compute KL from model-averaged prior
            prior_rewards = word_model_avg.log() - q_probs.log()

            # Clip because KL should never be negative, so because we
            # are subtracting KL, rewards should never be positive
            prior_rewards = torch.clamp(prior_rewards, max=0.0)
        elif self.config.kl_control and self.config.kl_calc == 'integral':
            # Note: we reward the negative KL divergence to ensure the
            # RL model stays close to the prior
            prior_rewards = -1.0 * torch.sum(kl_div, dim=1)
        elif self.config.kl_control:
            prior_rewards = logp_logq

        if self.config.kl_control:
            prior_rewards = prior_rewards * self.config.kl_weight_c
            self.kl_reward_batch_history.append(
                torch.sum(prior_rewards).item())

        # Track all metrics
        self.kl_div_batch_history.append(torch.mean(kl_div).item())
        self.logp_batch_history.append(
            torch.mean(prior_probs.log()).item())
        self.logp_logq_batch_history.append(torch.mean(logp_logq).item())

    return q_values, prior_rewards
def tts_train_loop_af_online(paths: Paths,
                             model: Tacotron,
                             model_tf: Tacotron,
                             optimizer,
                             train_set,
                             lr,
                             train_steps,
                             attn_example,
                             hp=None):
    """Train ``model`` with online attention forcing against ``model_tf``.

    For every batch, ``model_tf`` is run without gradients to obtain a
    reference attention ``attn_ref``; ``model`` is then run with that
    reference and optimised on the two L1 mel losses plus
    ``hp.attn_loss_coeff`` times the KL divergence between the smoothed
    reference attention and the smoothed model attention.

    Args:
        paths: project paths (checkpoints, attention/mel plots, log).
        model: model being trained.
        model_tf: reference model providing ``attn_ref``.
        optimizer: optimizer over ``model``'s parameters.
        train_set: iterable of ``(x, m, ids, _)`` batches with ``len()``.
        lr: learning rate written into every param group.
        train_steps: step budget; ``train_steps // len(train_set) + 1``
            epochs are run.
        attn_example: sample id whose attention/spectrogram is plotted
            whenever it appears in a batch.
        hp: hyper-parameter object; ``attn_loss_coeff``,
            ``tts_clip_grad_norm`` and ``tts_checkpoint_every`` are read.

    Raises:
        ValueError: if ``hp`` is not provided.
    """
    if hp is None:
        # Fail fast: hp is dereferenced on every step below.
        raise ValueError('hp is required by tts_train_loop_af_online')

    def smooth(d, eps=float(1e-10)):
        # Mix with a uniform distribution over dim 2 so log() stays finite.
        u = 1.0 / float(d.size()[2])
        return eps * u + (1 - eps) * d

    # use same device as model parameters
    device = next(model.parameters()).device

    for g in optimizer.param_groups:
        g['lr'] = lr

    total_iters = len(train_set)
    epochs = train_steps // total_iters + 1

    for e in range(1, epochs + 1):
        start = time.time()
        running_loss_out, running_loss_attn = 0, 0

        # Perform 1 epoch
        for i, (x, m, ids, _) in enumerate(train_set, 1):
            x, m = x.to(device), m.to(device)

            # Parallelize model onto GPUS using workaround due to python bug
            if device.type == 'cuda' and torch.cuda.device_count() > 1:
                with torch.no_grad():
                    _, _, attn_ref = data_parallel_workaround(model_tf, x, m)
                m1_hat, m2_hat, attention = data_parallel_workaround(
                    model, x, m, False, attn_ref)
            else:
                with torch.no_grad():
                    _, _, attn_ref = model_tf(x, m)
                m1_hat, m2_hat, attention = model(x,
                                                  m,
                                                  generate_gta=False,
                                                  attn_ref=attn_ref)

            m1_loss = F.l1_loss(m1_hat, m)
            m2_loss = F.l1_loss(m2_hat, m)
            # Element-wise KL, summed over dim 2 and averaged over the rest.
            attn_loss = F.kl_div(torch.log(smooth(attention)),
                                 smooth(attn_ref),
                                 reduction='none')
            attn_loss = attn_loss.sum(2).mean()

            loss_out = m1_loss + m2_loss
            loss_attn = attn_loss * hp.attn_loss_coeff
            loss = loss_out + loss_attn

            optimizer.zero_grad()
            loss.backward()
            if hp.tts_clip_grad_norm is not None:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), hp.tts_clip_grad_norm)
                # clip_grad_norm_ may return a (possibly CUDA) tensor, which
                # np.isnan cannot consume directly; float() handles both a
                # tensor and a plain number.
                if np.isnan(float(grad_norm)):
                    print('grad_norm was NaN!')

            optimizer.step()

            running_loss_out += loss_out.item()
            avg_loss_out = running_loss_out / i
            running_loss_attn += loss_attn.item()
            avg_loss_attn = running_loss_attn / i

            speed = i / (time.time() - start)

            step = model.get_step()
            k = step // 1000

            if step % hp.tts_checkpoint_every == 0:
                ckpt_name = f'taco_step{k}K'
                save_checkpoint('tts',
                                paths,
                                model,
                                optimizer,
                                name=ckpt_name,
                                is_silent=True)

            if attn_example in ids:
                idx = ids.index(attn_example)
                save_attention(np_now(attn_ref[idx][:, :160]),
                               paths.tts_attention / f'{step}_tf')
                save_attention(np_now(attention[idx][:, :160]),
                               paths.tts_attention / f'{step}_af')
                save_spectrogram(np_now(m2_hat[idx]),
                                 paths.tts_mel_plot / f'{step}', 600)

            msg = f'| Epoch: {e}/{epochs} ({i}/{total_iters}) | Loss_out: {avg_loss_out:#.4}; Loss_attn: {avg_loss_attn:#.4} | {speed:#.2} steps/s | Step: {k}k | '
            stream(msg)

        # Must save latest optimizer state to ensure that resuming training
        # doesn't produce artifacts
        save_checkpoint('tts', paths, model, optimizer, is_silent=True)
        model.log(paths.tts_log, msg)
        print(' ')
def _get_loss_f(self, x, y, targeted, reduction):
    """Build the two-argument objective ``l_f(data, data_out)``.

    Args:
        x: original reference data (only forwarded to a custom loss).
        y: target passed to the selected loss.
        targeted: whether the attack is targeted; selects the sign.
        reduction: reduction to use: 'sum', 'mean', 'none'.

    Returns:
        Callable ``(data, data_out) -> loss``.

    Raises:
        ValueError: if ``self.loss`` is a string naming no supported loss.
    """

    def negated(fn):
        # Same objective with the opposite sign.
        return lambda data, data_out: -fn(data, data_out)

    if not isinstance(self.loss, str):
        # custom 5 argument loss: (x_adv, x_adv_out, x, y, reduction)
        return lambda data, data_out: self.loss(
            data, data_out, x, y, reduction=reduction)

    loss_name = self.loss.lower()
    # Every named loss except 'conf' is negated for untargeted attacks;
    # 'conf' is negated for targeted ones instead.
    negate = not targeted

    if loss_name in ['crossentropy', 'ce']:

        def base(data, data_out):
            return F.cross_entropy(data_out, y, reduction=reduction)
    elif loss_name == 'kl':

        def base(data, data_out):
            per_sample = F.kl_div(torch.log_softmax(data_out, dim=1),
                                  y,
                                  reduction='none').sum(dim=1)
            return reduce(per_sample, reduction)
    elif loss_name == 'logitsdiff':
        y_oh = F.one_hot(y, self.num_classes).float()

        def base(data, data_out):
            return logits_diff_loss(data_out, y_oh, reduction=reduction)
    elif loss_name == 'conf':
        negate = targeted

        def base(data, data_out):
            return confidence_loss(data_out, y, reduction=reduction)
    elif loss_name == 'confdiff':
        y_oh = F.one_hot(y, self.num_classes).float()

        def base(data, data_out):
            return conf_diff_loss(data_out, y_oh, reduction=reduction)
    else:
        raise ValueError(f'Loss {self.loss} not supported')

    if negate:
        return negated(base)
    return base
def distillation(y, teacher_scores, labels, T, alpha):
    """Knowledge-distillation loss: softened KL term plus hard-label CE.

    Args:
        y: student logits with the class axis at dim 1.
        teacher_scores: teacher logits, same shape as ``y``.
        labels: class indices for ``F.cross_entropy``.
        T: softmax temperature applied to both sets of logits.
        alpha: mixing weight; the KL term is scaled by ``T*T * 2 * alpha``
            and the cross-entropy by ``1 - alpha``.

    Returns:
        Scalar loss tensor.

    Note:
        The softmax axis is now explicit (``dim=1``), which is what the
        deprecated implicit choice resolved to for 2-D logits.  The KL term
        deliberately keeps ``F.kl_div``'s default element-wise mean so the
        loss scale seen by existing hyper-parameters is unchanged.
    """
    student_log_probs = F.log_softmax(y / T, dim=1)
    teacher_probs = F.softmax(teacher_scores / T, dim=1)
    soft_loss = F.kl_div(student_log_probs, teacher_probs) * (T * T * 2. * alpha)
    hard_loss = F.cross_entropy(y, labels) * (1. - alpha)
    return soft_loss + hard_loss
def trn_step(self, epoch, sample_l, sample_u, scaler=None):
    """Run one semi-supervised optimisation step and return its metrics.

    The labeled images and both views of the unlabeled images go through
    the model in a single forward pass.  The total loss is the supervised
    criterion plus a weighted consistency term ``KL(p_ori || p_aug)``
    (the original-view prediction is detached and acts as the target).

    Args:
        epoch: current epoch; only used by the ``"gradual"`` ratio mode.
        sample_l: dict with ``'image'`` and ``'label'`` tensors.
        sample_u: dict with ``'image_ori'`` and ``'image_aug'`` tensors.
        scaler: optional gradient scaler; when given, the forward pass runs
            under ``autocast()`` and the step goes through the scaler.

    Returns:
        dict with ``'l_loss'``, ``'t_loss'``, ``'kl_loss'``, ``'batch_acc'``
        and ``'batch_f1'``.

    Raises:
        ValueError: if ``cfg.ratio_mode`` is neither ``'constant'`` nor
            ``'gradual'`` (previously an UnboundLocalError on ``t_loss``).
    """
    self.optim.zero_grad()

    images_l = sample_l['image'].to(self.cfg.device)
    images_o = sample_u["image_ori"].to(self.cfg.device)
    images_a = sample_u["image_aug"].to(self.cfg.device)
    labels = sample_l['label'].to(self.cfg.device)

    batch_s = images_l.size(0)
    images_t = torch.cat([images_l, images_o, images_a])

    def forward_losses():
        # Shared by the AMP and full-precision paths.
        logits_t = self.model(images_t)
        logits_l = logits_t[:batch_s]
        logits_o, logits_a = logits_t[batch_s:].chunk(2)
        del logits_t

        preds_o = F.softmax(logits_o, dim=-1).detach()
        preds_a = F.log_softmax(logits_a, dim=-1)
        kl_loss = F.kl_div(preds_a, preds_o, reduction='none')
        kl_loss = torch.mean(torch.sum(kl_loss, dim=-1))  # scalar

        l_loss = self.trn_crit(logits_l, labels)

        if self.cfg.ratio_mode == 'constant':
            t_loss = l_loss + self.cfg.ratio * kl_loss
        elif self.cfg.ratio_mode == "gradual":
            t_loss = epoch / self.cfg.t_epoch * self.cfg.ratio * kl_loss + l_loss
        else:
            raise ValueError(
                f"Unsupported ratio_mode: {self.cfg.ratio_mode!r}")
        return logits_l, l_loss, kl_loss, t_loss

    if scaler is not None:
        with autocast():
            logits_l, l_loss, kl_loss, t_loss = forward_losses()

        scaler.scale(t_loss).backward()

        # Clipping point: AGC (adaptive gradient clipping) takes over the
        # role of batchnorm.  Gradients must be unscaled before clipping.
        scaler.unscale_(self.optim)
        if self.cfg.clipping:
            timm.utils.adaptive_clip_grad(self.model.parameters())

        scaler.step(self.optim)
        scaler.update()
    else:
        logits_l, l_loss, kl_loss, t_loss = forward_losses()

        t_loss.backward()
        if self.cfg.clipping:
            timm.utils.adaptive_clip_grad(self.model.parameters())
        self.optim.step()

    batch_acc = self.accuracy(logits_l, labels)
    batch_f1 = self.f1_score(logits_l, labels)

    result = {
        'l_loss': l_loss,
        't_loss': t_loss,
        'kl_loss': kl_loss,
        'batch_acc': batch_acc,
        'batch_f1': batch_f1
    }
    return result
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, sentence_span_list=None,
            sentence_labels=None, sentence_ids=None, sentence_prob=None, max_sentences=0):
    """Score answer choices with question-guided sentence attention.

    The BERT output is sliced to this model's view (``self.view_id``),
    split into question and per-sentence document tokens, pooled first over
    words and then over sentences with question-conditioned attention, and
    classified over ``self.num_choices`` choices.

    Args:
        input_ids, token_type_ids, attention_mask: BERT inputs; flattened
            over every leading dimension before encoding.
        labels: optional gold choice indices; enables the loss.
        sentence_span_list: sentence spans forwarded to
            ``rep_layers.split_doc_sen_que``.
        sentence_labels: accepted for interface compatibility; unused here.
        sentence_ids: optional gold evidence sentence index per example
            (``-1`` is ignored); used when ``self.multi_evidence`` is false.
        sentence_prob: optional soft evidence distribution over sentences;
            used when ``self.multi_evidence`` is true.
        max_sentences: padded number of sentences per document.

    Returns:
        dict with ``'loss'`` and ``'sentence_loss'`` when the corresponding
        supervision is given and, outside training, ``'choice_logits'``
        (softmax probabilities) and ``'sentence_logits'`` (sentence
        attention weights).
    """
    flat_input_ids = input_ids.view(-1, input_ids.size(-1))
    flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
    flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
    seq_output, _ = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask,
                              output_all_encoded_layers=False)

    # The hidden dimension holds two equally sized views; keep this model's.
    origin_hidden_size = seq_output.size(2) // 2
    seq_output = seq_output[:, :, self.view_id * origin_hidden_size: (self.view_id + 1) * origin_hidden_size]

    # mask: 1 for masked value and 0 for true value
    doc_sen, que, doc_sen_mask, que_mask, sentence_mask = \
        rep_layers.split_doc_sen_que(seq_output, flat_token_type_ids, flat_attention_mask, sentence_span_list,
                                     max_sentences=max_sentences)

    batch, max_sen, doc_len = doc_sen_mask.size()
    assert max_sen == max_sentences

    que_vec = self.que_self_attn(que, que_mask).view(batch, 1, -1)

    # Word-level attention: question vs. every document token.
    doc = doc_sen.reshape(batch, max_sen * doc_len, -1)
    word_sim = self.word_similarity(que_vec, doc).view(batch * max_sen, doc_len)
    doc = doc_sen.reshape(batch * max_sen, doc_len, -1)
    doc_mask = doc_sen_mask.reshape(batch * max_sen, doc_len)
    word_hidden = rep_layers.masked_softmax(word_sim, doc_mask, dim=1).unsqueeze(1).bmm(doc)
    word_hidden = word_hidden.view(batch, max_sen, -1)

    # Sentence-level attention: question vs. sentence vectors.
    doc_vecs = self.doc_sen_self_attn(doc, doc_mask).view(batch, max_sen, -1)
    sentence_sim = self.vector_similarity(que_vec, doc_vecs)
    sentence_alpha = rep_layers.masked_softmax(sentence_sim, sentence_mask)
    sentence_hidden = sentence_alpha.bmm(word_hidden).squeeze(1)

    choice_logits = self.classifier(torch.cat([sentence_hidden, que_vec.squeeze(1)], dim=1)).reshape(
        -1, self.num_choices)

    if self.training:
        output_dict = {}
    else:
        output_dict = {
            'choice_logits': torch.softmax(choice_logits, dim=-1).detach().cpu().float(),
            'sentence_logits': sentence_alpha.reshape(choice_logits.size(0), self.num_choices,
                                                      max_sen).detach().cpu().float()
        }

    if labels is not None:
        choice_loss = functional.cross_entropy(choice_logits, labels)
        loss = choice_loss
        if self.multi_evidence and sentence_prob is not None:
            sentence_prob = sentence_prob.reshape(batch, -1).to(sentence_sim.dtype)
            true_prob_mask = (sentence_prob > 0).to(dtype=sentence_sim.dtype) * \
                sentence_mask.to(dtype=sentence_sim.dtype)

            # F.kl_div expects *log*-probabilities as its first argument;
            # the attention weights were previously passed as raw
            # probabilities.  They are also squeezed to (batch, max_sen) so
            # the mask does not broadcast against the singleton query axis
            # (assumes sentence_alpha is (batch, 1, max_sen), as the bmm /
            # squeeze(1) above suggest -- TODO confirm).
            alpha = sentence_alpha.squeeze(1)
            # Clamp so masked (zero) weights give a finite log; those
            # positions are zeroed by the mask and have a zero target.
            log_alpha = alpha.clamp_min(torch.finfo(alpha.dtype).tiny).log()
            kl_sentence_loss = functional.kl_div(log_alpha * true_prob_mask,
                                                 sentence_prob * true_prob_mask,
                                                 reduction='sum')
            output_dict['sentence_loss'] = kl_sentence_loss.item()
            loss = loss + self.evidence_lam * kl_sentence_loss / choice_logits.size(0)
        elif not self.multi_evidence and sentence_ids is not None:
            assert sentence_mask.sum() != 0, sentence_mask.sum()
            log_masked_sentence_prob = rep_layers.masked_log_softmax(sentence_sim.squeeze(1), sentence_mask)
            sentence_loss = functional.nll_loss(log_masked_sentence_prob, sentence_ids.view(batch),
                                                reduction='sum', ignore_index=-1)
            loss = loss + self.evidence_lam * sentence_loss / choice_logits.size(0)
            output_dict['sentence_loss'] = sentence_loss.item()
        output_dict['loss'] = loss
    return output_dict
def train(train_loader, model, optimizer, epoch, lr_schedule, half=False):
    """Train ``model`` for one epoch with the method in ``configs.TRAIN.methods``.

    Methods handled below: 'nat', 'input', 'cutmix', 'manifold', 'graphcut'
    and 'augmix'.  ``configs``, ``criterion``, ``criterion_batch``,
    ``mixup_box``, ``get_mask``, ``amp``, ``accuracy`` and ``AverageMeter``
    are taken from module scope.

    Args:
        train_loader: iterable of ``(input, target)`` batches; for 'augmix'
            ``input`` is a sequence of tensors that is concatenated.
        model: network being trained.
        optimizer: its optimizer; the learning rate is overwritten each step.
        epoch: current epoch, fed (with the batch fraction) to ``lr_schedule``.
        lr_schedule: callable mapping a fractional epoch to a learning rate.
        half: when true, backward passes go through ``amp.scale_loss``.

    NOTE(review): nesting was reconstructed from a whitespace-collapsed
    source -- confirm against the original file.
    """
    # Per-channel normalisation constants expanded to image size.
    mean = torch.Tensor(
        np.array(configs.TRAIN.mean)[:, np.newaxis, np.newaxis])
    mean = mean.expand(3, configs.DATA.crop_size,
                       configs.DATA.crop_size).cuda()
    std = torch.Tensor(np.array(configs.TRAIN.std)[:, np.newaxis, np.newaxis])
    std = std.expand(3, configs.DATA.crop_size, configs.DATA.crop_size).cuda()

    # Initialize the meters
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()
    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        if configs.TRAIN.methods != 'augmix':
            input = input.cuda(non_blocking=True)
        else:
            # augmix batches arrive as several tensors; stack along batch.
            input = torch.cat(input, 0).cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)
        data_time.update(time.time() - end)

        # update learning rate
        lr = lr_schedule(epoch + (i + 1) / len(train_loader))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        optimizer.zero_grad()

        # In-place normalisation of the batch.
        input.sub_(mean).div_(std)
        lam = np.random.beta(configs.TRAIN.alpha, configs.TRAIN.alpha)
        if configs.TRAIN.methods == 'manifold' or configs.TRAIN.methods == 'graphcut':
            # Permute within the first quarter of the batch and repeat the
            # same permutation, offset, for the other three quarters.
            permuted_idx1 = np.random.permutation(input.size(0) // 4)
            permuted_idx2 = permuted_idx1 + input.size(0) // 4
            permuted_idx3 = permuted_idx2 + input.size(0) // 4
            permuted_idx4 = permuted_idx3 + input.size(0) // 4
            permuted_idx = np.concatenate(
                [permuted_idx1, permuted_idx2, permuted_idx3, permuted_idx4],
                axis=0)
        else:
            permuted_idx = torch.tensor(np.random.permutation(input.size(0)))

        if configs.TRAIN.methods == 'input':
            input = lam * input + (1 - lam) * input[permuted_idx]

        elif configs.TRAIN.methods == 'cutmix':
            input, lam = mixup_box(input, lam=lam, permuted_idx=permuted_idx)

        elif configs.TRAIN.methods == 'augmix':
            logit = model(input)
            # The batch holds three equal parts: clean, aug1, aug2.
            logit_clean, logit_aug1, logit_aug2 = torch.split(
                logit,
                logit.size(0) // 3)
            output = logit_clean

            p_clean = F.softmax(logit_clean, dim=1)
            p_aug1 = F.softmax(logit_aug1, dim=1)
            p_aug2 = F.softmax(logit_aug2, dim=1)

            # Log of the clamped mixture, used as kl_div's log-space input.
            p_mixture = torch.clamp((p_clean + p_aug1 + p_aug2) / 3., 1e-7,
                                    1).log()
            loss_JSD = 4 * (
                F.kl_div(p_mixture, p_clean, reduction='batchmean') +
                F.kl_div(p_mixture, p_aug1, reduction='batchmean') +
                F.kl_div(p_mixture, p_aug2, reduction='batchmean'))

        elif configs.TRAIN.methods == 'graphcut':
            # Clean forward/backward to get the input gradient.
            input_var = Variable(input, requires_grad=True)
            output = model(input_var)
            loss_clean = criterion(output, target)

            # NOTE(review): parameter gradients from this clean backward are
            # not zeroed before the mixed loss's backward below, so they
            # accumulate into the optimizer step -- confirm this is intended.
            if half:
                with amp.scale_loss(loss_clean, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss_clean.backward()
            # Root-mean-square of the input gradient over dim 1.
            unary = torch.sqrt(torch.mean(input_var.grad**2, dim=1))

            block_num = 2**(np.random.randint(1, 5))
            mask = get_mask(input,
                            unary,
                            block_num,
                            permuted_idx,
                            alpha=lam,
                            mean=mean,
                            std=std)
            output, lam = model(input,
                                graphcut=True,
                                permuted_idx=permuted_idx1,
                                block_num=block_num,
                                mask=mask,
                                unary=unary)

        if configs.TRAIN.methods == 'manifold':
            output = model(input,
                           manifold=True,
                           lam=lam,
                           permuted_idx=permuted_idx1)
        elif configs.TRAIN.methods != 'augmix' and configs.TRAIN.methods != 'graphcut':
            # augmix and graphcut already computed ``output`` above.
            output = model(input)

        if configs.TRAIN.methods == 'nat':
            loss = criterion(output, target)
        elif configs.TRAIN.methods == 'augmix':
            loss = criterion(output, target) + loss_JSD
        else:
            # Mixed-label loss: per-sample losses blended by lam, then averaged.
            loss = lam * criterion_batch(output, target) + (
                1 - lam) * criterion_batch(output, target[permuted_idx])
            loss = torch.mean(loss)

        # compute gradient and do SGD step
        #optimizer.zero_grad()
        if half:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        optimizer.step()

        prec1, prec5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1[0], input.size(0))
        top5.update(prec5[0], input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % configs.TRAIN.print_freq == 0:
            print('Train Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {cls_loss.val:.4f} ({cls_loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t'
                  'LR {lr:.3f}'.format(epoch,
                                       i,
                                       len(train_loader),
                                       batch_time=batch_time,
                                       data_time=data_time,
                                       top1=top1,
                                       top5=top5,
                                       cls_loss=losses,
                                       lr=lr))
            sys.stdout.flush()
err_d = err_d_real + err_d_fake d_optimizer.step() ######################## # (2) Update G network # ######################## generator.zero_grad() output = discriminator(fake_images) err_g = criterion(output, real_label) D_G_z2 = output.mean().item() if args.classifier: output = F.log_softmax(classifier(fake_images), dim=1) uniform_dist = torch.zeros((true_batch_size, 10), device=device).fill_((1. / 10)) err_kl = args.kl_beta * F.kl_div(output, uniform_dist, reduction='batchmean') err_g += err_kl err_g.backward() g_optimizer.step() if i % 100 == 99: if args.classifier: print( '[%3d|%3d] Loss_D: %.4f Loss_G: %.4f Loss_C: %.4f D(x): %.4f D(G(z)): %.4f / %.4f' % (epoch, i + 1, err_d.item(), err_g.item(), err_kl.item(), D_x, D_G_z1, D_G_z2) ) else: print( '[%3d|%3d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f' % (epoch, i + 1, err_d.item(), err_g.item(), D_x, D_G_z1, D_G_z2)
def distillation(y, teacher_scores, T):
    """Return the temperature-scaled distillation KL term.

    Args:
        y: student logits; the softmax is taken over dim 1.
        teacher_scores: teacher logits, same shape as ``y``.
        T: softmax temperature applied to both sets of logits.

    Returns:
        Scalar tensor: summed ``KL(softmax(teacher/T) || softmax(student/T))``
        times ``T**2``, divided by ``y.shape[0]``.
    """
    student_log_probs = F.log_softmax(y / T, dim=1)
    teacher_probs = F.softmax(teacher_scores / T, dim=1)
    # reduction='sum' is the supported spelling of the deprecated
    # size_average=False (same value, no deprecation warning).
    l_kl = F.kl_div(student_log_probs, teacher_probs, reduction='sum') * (T ** 2) / y.shape[0]
    return l_kl
def cross_entropy(p, q, dim=-1):
    """Return the summed ``KL(softmax(p) || softmax(q))`` of two logit tensors.

    Despite its name this is a KL divergence, not a cross-entropy: it is
    zero when ``p`` and ``q`` describe the same distribution.

    Args:
        p: logits of the target distribution.
        q: logits of the approximating distribution, same shape as ``p``.
        dim: axis over which both softmaxes are taken.  The default (last
            axis) matches what the deprecated implicit-dim behaviour chose
            for 1-D and 2-D inputs.

    Returns:
        Scalar tensor with the divergence summed over all elements.
    """
    return F.kl_div(F.log_softmax(q, dim=dim), F.softmax(p, dim=dim), reduction='sum')
def thread_target(i):
    """Worker loop that trains model ``i`` in a multi-model mutual-learning run.

    Each iteration has two phases gated by outer-scope events:

    1. after ``run_event1`` is set: forward pass and per-model losses; the
       softmax outputs and distance matrices are published in
       ``probs_list`` / ``g_dist_mat_list`` / ``l_dist_mat_list`` and
       ``done_list1[i]`` is raised;
    2. after ``run_event2`` is set: mutual losses against the other models'
       published outputs, backward, optimizer step, then ``done_list2[i]``
       is raised.

    The loop ends when ``exit_event`` is set.  All shared state (events,
    lists, ``cfg``, loss functions, meters) comes from the enclosing scope.

    Args:
        i: index of the model/optimizer/data slot this worker owns.

    NOTE(review): nesting was reconstructed from a whitespace-collapsed
    source -- confirm against the original file.
    """
    while not exit_event.isSet():
        # If the run event is not set, the thread just waits.
        if not run_event1.wait(0.001):
            continue

        ######################################
        # Phase 1: Forward and Separate Loss #
        ######################################

        TVT = TVTs[i]
        model_w = model_ws[i]
        ims = ims_list[i]
        labels = labels_list[i]
        optimizer = optimizers[i]

        # ims / labels are numpy arrays; TVT transfers tensors for slot i.
        ims_var = Variable(TVT(torch.from_numpy(ims).float()))
        labels_t = TVT(torch.from_numpy(labels).long())
        labels_var = Variable(labels_t)

        global_feat, local_feat, logits = model_w(ims_var)
        # NOTE(review): softmax/log_softmax are called without ``dim`` and
        # rely on the deprecated implicit-dim choice.
        probs = F.softmax(logits)
        log_probs = F.log_softmax(logits)

        g_loss, p_inds, n_inds, g_dist_ap, g_dist_an, g_dist_mat = global_loss(
            g_tri_loss,
            global_feat,
            labels_t,
            normalize_feature=cfg.normalize_feature)

        if cfg.l_loss_weight == 0:
            l_loss, l_dist_mat = 0, 0
        elif cfg.local_dist_own_hard_sample:
            # Let local distance find its own hard samples.
            l_loss, l_dist_ap, l_dist_an, l_dist_mat = local_loss(
                l_tri_loss,
                local_feat,
                None,
                None,
                labels_t,
                normalize_feature=cfg.normalize_feature)
        else:
            # Reuse the hard samples mined by the global loss.
            l_loss, l_dist_ap, l_dist_an = local_loss(
                l_tri_loss,
                local_feat,
                p_inds,
                n_inds,
                labels_t,
                normalize_feature=cfg.normalize_feature)
            l_dist_mat = 0

        id_loss = 0
        if cfg.id_loss_weight > 0:
            id_loss = id_criterion(logits, labels_var)

        # Publish this model's outputs for the other workers' mutual losses.
        probs_list[i] = probs
        g_dist_mat_list[i] = g_dist_mat
        l_dist_mat_list[i] = l_dist_mat

        done_list1[i] = True

        # Wait for event to be set, meanwhile checking if need to exit.
        while True:
            phase2_ready = run_event2.wait(0.001)
            if exit_event.isSet():
                return
            if phase2_ready:
                break

        #####################################
        # Phase 2: Mutual Loss and Backward #
        #####################################

        # Probability Mutual Loss (KL Loss)
        pm_loss = 0
        if (cfg.num_models > 1) and (cfg.pm_loss_weight > 0):
            for j in range(cfg.num_models):
                if j != i:
                    # The positional ``False`` is kl_div's legacy
                    # ``size_average`` argument (i.e. summed, not averaged).
                    pm_loss += F.kl_div(log_probs, TVT(probs_list[j]).detach(), False)
            pm_loss /= 1. * (cfg.num_models - 1) * len(ims)

        # Global Distance Mutual Loss (L2 Loss)
        gdm_loss = 0
        if (cfg.num_models > 1) and (cfg.gdm_loss_weight > 0):
            for j in range(cfg.num_models):
                if j != i:
                    gdm_loss += torch.sum(torch.pow(
                        g_dist_mat - TVT(g_dist_mat_list[j]).detach(), 2))
            gdm_loss /= 1. * (cfg.num_models - 1) * len(ims) * len(ims)

        # Local Distance Mutual Loss (L2 Loss)
        ldm_loss = 0
        if (cfg.num_models > 1) \
                and cfg.local_dist_own_hard_sample \
                and (cfg.ldm_loss_weight > 0):
            for j in range(cfg.num_models):
                if j != i:
                    ldm_loss += torch.sum(torch.pow(
                        l_dist_mat - TVT(l_dist_mat_list[j]).detach(), 2))
            ldm_loss /= 1. * (cfg.num_models - 1) * len(ims) * len(ims)

        loss = g_loss * cfg.g_loss_weight \
            + l_loss * cfg.l_loss_weight \
            + id_loss * cfg.id_loss_weight \
            + pm_loss * cfg.pm_loss_weight \
            + gdm_loss * cfg.gdm_loss_weight \
            + ldm_loss * cfg.ldm_loss_weight

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        ##################################
        # Step Log For One of the Models #
        ##################################

        # These meters are outer-scope variables

        # Just record for the first model
        if i == 0:

            # precision
            g_prec = (g_dist_an > g_dist_ap).data.float().mean()
            # the proportion of triplets that satisfy margin
            g_m = (g_dist_an > g_dist_ap + cfg.global_margin).data.float().mean()
            g_d_ap = g_dist_ap.data.mean()
            g_d_an = g_dist_an.data.mean()

            g_prec_meter.update(g_prec)
            g_m_meter.update(g_m)
            g_dist_ap_meter.update(g_d_ap)
            g_dist_an_meter.update(g_d_an)
            g_loss_meter.update(to_scalar(g_loss))

            if cfg.l_loss_weight > 0:
                # precision
                l_prec = (l_dist_an > l_dist_ap).data.float().mean()
                # the proportion of triplets that satisfy margin
                l_m = (l_dist_an > l_dist_ap + cfg.local_margin).data.float().mean()
                l_d_ap = l_dist_ap.data.mean()
                l_d_an = l_dist_an.data.mean()

                l_prec_meter.update(l_prec)
                l_m_meter.update(l_m)
                l_dist_ap_meter.update(l_d_ap)
                l_dist_an_meter.update(l_d_an)
                l_loss_meter.update(to_scalar(l_loss))

            if cfg.id_loss_weight > 0:
                id_loss_meter.update(to_scalar(id_loss))

            if (cfg.num_models > 1) and (cfg.pm_loss_weight > 0):
                pm_loss_meter.update(to_scalar(pm_loss))

            if (cfg.num_models > 1) and (cfg.gdm_loss_weight > 0):
                gdm_loss_meter.update(to_scalar(gdm_loss))

            if (cfg.num_models > 1) \
                    and cfg.local_dist_own_hard_sample \
                    and (cfg.ldm_loss_weight > 0):
                ldm_loss_meter.update(to_scalar(ldm_loss))

            loss_meter.update(to_scalar(loss))

        ###################
        # End Up One Step #
        ###################

        # Re-arm both phase events and report completion of this step.
        run_event1.clear()
        run_event2.clear()

        done_list2[i] = True
def train(trainloader, model, criterion, optimizer, epoch, use_cuda, teacher):
    """Run one pass over ``trainloader``, training ``model`` against ``teacher``.

    For every batch the student ``model`` is trained on a blend of two terms:

    * a *soft* loss between student and teacher outputs -- an MSE on the raw
      outputs when ``args.caruana`` is set, otherwise a temperature-softened
      KL divergence averaged over the batch;
    * a *hard* loss ``criterion(outputs, targets)``.

    The two are combined as ``alpha * soft + (1 - alpha) * hard``.  After
    ``backward()`` the gradients of every ``nn.Conv2d`` weight are masked so
    that entries whose weight is exactly zero receive a zero gradient.

    Args:
        trainloader: iterable yielding ``(inputs, targets)`` batches; must
            support ``len()``.
        model: student network, put into train mode here.
        criterion: callable computing the hard loss from
            ``(outputs, targets)``.
        optimizer: optimizer stepping the student's parameters.
        epoch: current epoch index (not used inside this function).
        use_cuda: when true, each batch is moved to the GPU.
        teacher: network whose outputs are the distillation targets.  It is
            only run forward; no gradient flows into it.

    Returns:
        Tuple ``(losses.avg, top1.avg)``: running averages of the loss and
        the top-1 precision meters over this pass.

    Note:
        Reads ``args.caruana``, ``args.temperature`` and ``args.alpha`` from
        a module-level ``args`` object.
    """
    # switch to train mode
    model.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    bar = Bar('Processing', max=len(trainloader))
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        # measure data loading time
        data_time.update(time.time() - end)

        if use_cuda:
            # `async` is a reserved word since Python 3.7; `non_blocking`
            # is the keyword PyTorch uses for the same request.
            inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True)

        # compute output
        outputs = model(inputs)  # student logits (batch_size x n_classes)
        # The teacher only supplies targets, so skip building its graph.
        with torch.no_grad():
            teacher_outputs = teacher(inputs)

        if args.caruana:
            # Regress the student's raw outputs onto the teacher's.
            soft_loss = F.mse_loss(outputs, teacher_outputs, reduction="mean")
        else:
            soft_log_probs = F.log_softmax(outputs / args.temperature, dim=1)
            soft_targets = F.softmax(teacher_outputs / args.temperature, dim=1)
            # Sum over all elements, then average over the batch dimension.
            # NOTE(review): no temperature**2 rescaling is applied here --
            # confirm that is intended.
            soft_loss = F.kl_div(
                soft_log_probs, soft_targets, reduction='sum'
            ) / outputs.shape[0]

        hard_loss = criterion(outputs, targets)
        # Weighted sum of the two terms (the previous code multiplied them).
        loss = args.alpha * soft_loss + (1 - args.alpha) * hard_loss

        # measure accuracy and record loss
        prec1, prec5 = accuracy(outputs.detach(), targets, topk=(1, 5))
        n_samples = inputs.size(0)
        losses.update(loss.item(), n_samples)
        top1.update(prec1.item(), n_samples)
        top5.update(prec5.item(), n_samples)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()

        # Zero the gradient wherever a conv weight is exactly zero
        # (presumably to keep pruned weights pruned -- confirm).
        for m in model.modules():
            if not isinstance(m, nn.Conv2d):
                continue
            grad = m.weight.grad
            if grad is None:
                # Frozen / unused layer: nothing to mask.
                continue
            # Build the mask on the gradient's own device and dtype so this
            # also works when training on CPU.
            mask = m.weight.detach().abs().gt(0).to(
                device=grad.device, dtype=grad.dtype
            )
            grad.mul_(mask)

        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format(
            batch=batch_idx + 1,
            size=len(trainloader),
            data=data_time.avg,
            bt=batch_time.avg,
            total=bar.elapsed_td,
            eta=bar.eta_td,
            loss=losses.avg,
            top1=top1.avg,
            top5=top5.avg,
        )
        bar.next()
    bar.finish()
    return (losses.avg, top1.avg)