import time

import torch

# AverageMeter, ProgressMeter, ConfusionMatrix and accuracy are project helpers
# defined elsewhere in this repo.


def validate(validation_loader, model, loss_fn, device, print_frequency=2,
             curr_epoch=1, column_split_order=[]):
    history = {
        'loss': [],
        'accuracy': [],
        'batch_time': [],
        'classification_metrics': None,
        'confusion_matrix': None
    }
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(
        len(validation_loader), [batch_time, data_time, losses, top1],
        prefix="Epoch: [{}]".format(curr_epoch))

    # switch to evaluate mode
    model.eval()

    conf_matrix = None
    if len(column_split_order) > 0:
        conf_matrix = ConfusionMatrix(column_split_order)

    with torch.no_grad():
        # https://github.com/pytorch/pytorch/issues/16417#issuecomment-566654504
        end = time.time()
        for i, (input_ids, attention_mask, labels) in enumerate(validation_loader):
            # measure data loading time
            data_time.update(time.time() - end)

            input_ids = input_ids.to(device, non_blocking=True)
            attention_mask = attention_mask.to(device, non_blocking=True)
            # labels arrive one-hot encoded; convert to class indices
            labels = torch.argmax(labels, dim=1).to(device, non_blocking=True)

            # compute output
            output = model(input_ids, attention_mask=attention_mask)
            loss = loss_fn(output, labels)

            # measure accuracy and record loss
            acc1 = accuracy(output, labels, conf_matrix=conf_matrix)
            losses.update(loss.item(), input_ids.size(0))
            top1.update(acc1[0].tolist()[0], input_ids.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % print_frequency == 0:
                progress.display(i)

    history['accuracy'].append(float(top1.avg))
    history['loss'].append(float(losses.avg))
    history['batch_time'].append(float(batch_time.avg))
    if conf_matrix is not None:
        history['classification_metrics'] = conf_matrix.get_all_metrics()
        history['confusion_matrix'] = str(conf_matrix)
    return history
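# --- Usage sketch (illustrative only; not part of the original code) ----------
# A minimal, self-contained way to exercise `validate`, assuming the loader
# yields (input_ids, attention_mask, one-hot labels) triples as the loop above
# expects. The toy classifier and random data below are assumptions made for
# illustration; only `validate` and the repo helpers it calls are real.
from torch.utils.data import DataLoader, TensorDataset


class _ToyClassifier(torch.nn.Module):
    def __init__(self, vocab_size=100, num_classes=3, dim=16):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_size, dim)
        self.out = torch.nn.Linear(dim, num_classes)

    def forward(self, input_ids, attention_mask=None):
        h = self.emb(input_ids)
        if attention_mask is not None:
            h = h * attention_mask.unsqueeze(-1)  # zero out padded positions
        return self.out(h.mean(dim=1))  # (batch, num_classes) logits


def _toy_validation_run():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    input_ids = torch.randint(0, 100, (8, 12))
    attention_mask = torch.ones(8, 12)
    labels = torch.nn.functional.one_hot(torch.randint(0, 3, (8,)), 3).float()
    loader = DataLoader(TensorDataset(input_ids, attention_mask, labels),
                        batch_size=4)
    model = _ToyClassifier().to(device)
    return validate(loader, model, torch.nn.CrossEntropyLoss(), device,
                    print_frequency=1, curr_epoch=1,
                    column_split_order=['a', 'b', 'c'])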
def _run_epoch(self, train_data, dev_data, unlabel_data, addn_data,
               addn_data_unlab, addn_dev, ek, ek_t, ek_u, graph_embs,
               graph_embs_t, graph_embs_u):
    # .cuda() is not in-place for tensors, so keep the returned copies
    addn_dev = addn_dev.cuda()
    ek_t = ek_t.cuda()
    graph_embs_t = graph_embs_t.cuda()
    report_stats = utils.Statistics()
    cm = ConfusionMatrix(self.classes)
    _, seq_data = list(zip(*train_data))
    total_seq_words = len(list(itertools.chain.from_iterable(seq_data)))
    iter_per_epoch = (1.5 * total_seq_words) // self.config.wbatchsize

    self.encoder.train()
    self.clf.train()

    train_iter = self._create_iter(train_data, self.config.wbatchsize)
    unlabel_iter = self._create_iter(unlabel_data,
                                     self.config.wbatchsize_unlabel)
    sofar = 0
    sofar_1 = 0
    for batch_index, train_batch_raw in enumerate(train_iter):
        seq_iter = list(zip(*train_batch_raw))[1]
        seq_words = len(list(itertools.chain.from_iterable(seq_iter)))
        report_stats.n_words += seq_words
        self.global_steps += 1

        # self.enc_clf_opt.zero_grad()
        if self.config.add_noise:
            train_batch_raw = add_noise(train_batch_raw,
                                        self.config.noise_dropout,
                                        self.config.random_permutation)
        train_batch = batch_utils.seq_pad_concat(train_batch_raw, -1)
        train_embedded = self.embedder(train_batch)
        memory_bank_train, enc_final_train = self.encoder(
            train_embedded, train_batch)

        # Only materialise an unlabeled batch when a semi-supervised loss needs it
        if self.config.lambda_vat > 0 or self.config.lambda_ae > 0 or self.config.lambda_entropy:
            try:
                unlabel_batch_raw = next(unlabel_iter)
            except StopIteration:
                unlabel_iter = self._create_iter(
                    unlabel_data, self.config.wbatchsize_unlabel)
                unlabel_batch_raw = next(unlabel_iter)

            if self.config.add_noise:
                unlabel_batch_raw = add_noise(unlabel_batch_raw,
                                              self.config.noise_dropout,
                                              self.config.random_permutation)
            unlabel_batch = batch_utils.seq_pad_concat(unlabel_batch_raw, -1)
            unlabel_embedded = self.embedder(unlabel_batch)
            memory_bank_unlabel, enc_final_unlabel = self.encoder(
                unlabel_embedded, unlabel_batch)
            addn_batch_unlab = retAddnBatch(addn_data_unlab,
                                            memory_bank_unlabel.shape[0],
                                            sofar_1).cuda()
            ek_batch_unlab = retAddnBatch(ek_u, memory_bank_unlabel.shape[0],
                                          sofar_1).cuda()
            graph_embs_unlab = retAddnBatch(graph_embs_u,
                                            memory_bank_unlabel.shape[0],
                                            sofar_1).cuda()
            sofar_1 += addn_batch_unlab.shape[0]
            if sofar_1 >= ek_u.shape[0]:
                sofar_1 = 0

        addn_batch = retAddnBatch(addn_data, memory_bank_train.shape[0],
                                  sofar).cuda()
        ek_batch = retAddnBatch(ek, memory_bank_train.shape[0], sofar).cuda()
        graph_embs_batch = retAddnBatch(graph_embs,
                                        memory_bank_train.shape[0],
                                        sofar).cuda()
        sofar += addn_batch.shape[0]
        if sofar >= ek.shape[0]:
            sofar = 0

        pred = self.clf(memory_bank_train, addn_batch, ek_batch,
                        enc_final_train, graph_embs_batch)
        accuracy = self.get_accuracy(cm, pred.data, train_batch.labels.data)
        lclf = self.clf_loss(pred, train_batch.labels)

        lat = Variable(torch.FloatTensor([-1.]).type(batch_utils.FLOAT_TYPE))
        lvat = Variable(torch.FloatTensor([-1.]).type(batch_utils.FLOAT_TYPE))
        if self.config.lambda_at > 0:
            lat = at_loss(self.embedder, self.encoder, self.clf, train_batch,
                          addn_batch, ek_batch, graph_embs_batch,
                          perturb_norm_length=self.config.perturb_norm_length)

        if self.config.lambda_vat > 0:
            lvat_train = vat_loss(
                self.embedder, self.encoder, self.clf, train_batch,
                addn_batch, ek_batch, graph_embs_batch, p_logit=pred,
                perturb_norm_length=self.config.perturb_norm_length)
            if self.config.inc_unlabeled_loss:
                if memory_bank_unlabel.shape[0] != ek_batch_unlab.shape[0]:
                    print(f'Skipping; Unequal Shapes: '
                          f'{memory_bank_unlabel.shape} and {ek_batch_unlab.shape}')
                    continue
                else:
                    lvat_unlabel = vat_loss(
                        self.embedder, self.encoder, self.clf, unlabel_batch,
                        addn_batch_unlab, ek_batch_unlab, graph_embs_unlab,
                        p_logit=self.clf(memory_bank_unlabel, addn_batch_unlab,
                                         ek_batch_unlab, enc_final_unlabel,
                                         graph_embs_unlab),
                        perturb_norm_length=self.config.perturb_norm_length)
                if self.config.unlabeled_loss_type == "AvgTrainUnlabel":
                    lvat = 0.5 * (lvat_train + lvat_unlabel)
                elif self.config.unlabeled_loss_type == "Unlabel":
                    lvat = lvat_unlabel
            else:
                lvat = lvat_train

        lentropy = Variable(
            torch.FloatTensor([-1.]).type(batch_utils.FLOAT_TYPE))
        if self.config.lambda_entropy > 0:
            lentropy_train = entropy_loss(pred)
            if self.config.inc_unlabeled_loss:
                lentropy_unlabel = entropy_loss(
                    self.clf(memory_bank_unlabel, addn_batch_unlab,
                             ek_batch_unlab, enc_final_unlabel,
                             graph_embs_unlab))
                if self.config.unlabeled_loss_type == "AvgTrainUnlabel":
                    lentropy = 0.5 * (lentropy_train + lentropy_unlabel)
                elif self.config.unlabeled_loss_type == "Unlabel":
                    lentropy = lentropy_unlabel
            else:
                lentropy = lentropy_train

        lae = Variable(torch.FloatTensor([-1.]).type(batch_utils.FLOAT_TYPE))
        if self.config.lambda_ae > 0:
            lae = self.ae(memory_bank_unlabel, enc_final_unlabel,
                          unlabel_batch.sent_len, unlabel_batch_raw)

        ltotal = (self.config.lambda_clf * lclf) + \
                 (self.config.lambda_ae * lae) + \
                 (self.config.lambda_at * lat) + \
                 (self.config.lambda_vat * lvat) + \
                 (self.config.lambda_entropy * lentropy)

        report_stats.clf_loss += lclf.data.cpu().numpy()
        report_stats.at_loss += lat.data.cpu().numpy()
        report_stats.vat_loss += lvat.data.cpu().numpy()
        report_stats.ae_loss += lae.data.cpu().numpy()
        report_stats.entropy_loss += lentropy.data.cpu().numpy()
        report_stats.n_sent += len(pred)
        report_stats.n_correct += accuracy

        self.enc_clf_opt.zero_grad()
        ltotal.backward()
        params_list = self._get_trainabe_modules()
        # Exclude the embedder from the norm constraint when AT or VAT is used
        if not self.config.normalize_embedding:
            params_list += list(self.embedder.parameters())

        norm = torch.nn.utils.clip_grad_norm_(params_list,
                                              self.config.max_norm)
        report_stats.grad_norm += norm
        self.enc_clf_opt.step()
        if self.config.scheduler == "ExponentialLR":
            self.scheduler.step()
        self.ema_embedder.apply(self.embedder.named_parameters())
        self.ema_encoder.apply(self.encoder.named_parameters())
        self.ema_clf.apply(self.clf.named_parameters())

        report_func(self.epoch, batch_index, iter_per_epoch, self.time_s,
                    report_stats, self.config.report_every, self.logger)

        if self.global_steps % self.config.eval_steps == 0:
            cm_, accuracy, prc_dev = self._run_evaluate(
                dev_data, addn_dev, ek_t, graph_embs_t)
            self.logger.info(
                "- dev accuracy {} | best dev accuracy {} ".format(
                    accuracy, self.best_accuracy))
            self.writer.add_scalar("Dev_Accuracy", accuracy,
                                   self.global_steps)
            pred_, lab_ = zip(*prc_dev)
            pred_ = torch.cat(pred_)
            lab_ = torch.cat(lab_)
            self.writer.add_pr_curve("Dev PR-Curve", lab_, pred_,
                                     self.global_steps)
            pprint.pprint(cm_)
            pprint.pprint(cm_.get_all_metrics())
            if accuracy > self.best_accuracy:
                self.logger.info("- new best score!")
                self.best_accuracy = accuracy
                self._save_model()
            if self.config.scheduler == "ReduceLROnPlateau":
                self.scheduler.step(accuracy)
            # switch back to train mode after evaluation
            self.encoder.train()
            # self.embedder.train()
            self.clf.train()

    if self.config.weight_decay > 0:
        print(">> Square Norm: %1.4f " % self._get_l2_norm_loss())

    cm, train_accuracy, _ = self._run_evaluate(train_data, addn_data, ek,
                                               graph_embs)
    self.logger.info("- Train accuracy {}".format(train_accuracy))
    pprint.pprint(cm.get_all_metrics())

    cm, dev_accuracy, _ = self._run_evaluate(dev_data, addn_dev, ek_t,
                                             graph_embs_t)
    self.logger.info("- Dev accuracy {} | best dev accuracy {}".format(
        dev_accuracy, self.best_accuracy))
    pprint.pprint(cm.get_all_metrics())
    self.writer.add_scalars("Overall_Accuracy", {
        "Train_Accuracy": train_accuracy,
        "Dev_Accuracy": dev_accuracy
    }, self.global_steps)
    return dev_accuracy
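# --- Hedged sketch of `retAddnBatch` (an assumption, not the repo's code) -----
# The training loop above slices auxiliary features with
# `retAddnBatch(data, bsz, sofar)` and then advances `sofar` by the number of
# rows returned, wrapping back to 0 once the source tensor is exhausted. A
# helper consistent with that usage would simply return a window of up to
# `bsz` rows starting at `sofar`; when fewer rows remain, the shorter slice is
# what would trigger the "Skipping; Unequal Shapes" guard in the VAT branch.
def retAddnBatch_sketch(data, bsz, sofar):
    """Return up to `bsz` rows of `data` starting at row `sofar` (hypothetical)."""
    end = min(sofar + bsz, data.shape[0])
    return data[sofar:end]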