def _run_evaluate(self, test_data):
    """Score the model on ``test_data`` and collect PR-curve inputs.

    The data is walked in the chunks produced by ``self.chunks``; each chunk
    is padded with ``batch_utils.seq_pad_concat`` and handed to
    ``self._predict_batch`` together with a fresh confusion matrix.

    Args:
        test_data: Examples to evaluate. Must support ``len()``.

    Returns:
        A ``(confusion_matrix, accuracy, pr_curve_data)`` tuple where
        ``accuracy`` is ``100 * sum(per-batch values) / len(test_data)`` and
        ``pr_curve_data`` is a list with one
        ``(softmax column 1, batch labels)`` pair per batch.

    Raises:
        ZeroDivisionError: If ``test_data`` is empty.
    """
    confusion = ConfusionMatrix(self.classes)
    pr_points = []
    batch_scores = []

    for raw_batch in self.chunks(test_data):
        padded = batch_utils.seq_pad_concat(raw_batch, -1)
        # The second return value is presumably the number of correct
        # predictions in the batch -- verify against _predict_batch.
        logits, batch_score = self._predict_batch(confusion, padded)
        batch_scores.append(batch_score)

        # Column 1 of the softmax output is paired with the gold labels.
        class_one_prob = F.softmax(logits, -1)[:, 1].data
        pr_points.append((class_one_prob, padded.labels.data))

    accuracy = 100 * (sum(batch_scores) / len(test_data))
    return confusion, accuracy, pr_points
def _run_evaluate(self, test_data, addn_test, chunk_size=15):
    """Evaluate on ``test_data`` using row-aligned additional inputs.

    ``test_data`` is walked in the chunks produced by ``self.chunks``. For
    the chunk at position ``i`` the rows
    ``addn_test[i * chunk_size:(i + 1) * chunk_size]`` are passed to
    ``self._predict_batch`` alongside the padded batch.

    Args:
        test_data: Examples to evaluate. Must support ``len()``.
        addn_test: Sliceable container of additional per-example inputs,
            assumed to be in the same order as ``test_data``.
        chunk_size: Number of ``addn_test`` rows taken per chunk. It has to
            match the chunk length yielded by ``self.chunks`` (presumably
            15, the previously hard-coded value -- verify against
            ``chunks``).

    Returns:
        A ``(confusion_matrix, accuracy, pr_curve_data)`` tuple where
        ``accuracy`` is ``100 * sum(per-batch values) / len(test_data)`` and
        ``pr_curve_data`` is a list with one
        ``(softmax column 1, batch labels)`` pair per successfully
        predicted batch.

    Raises:
        ZeroDivisionError: If ``test_data`` is empty.
    """
    pr_curve_data = []
    cm = ConfusionMatrix(self.classes)
    accuracy_list = []
    test_iter = self.chunks(test_data)
    for batch_index, test_batch in enumerate(test_iter):
        start = batch_index * chunk_size
        addn_batch = addn_test[start:start + chunk_size]
        test_batch = batch_utils.seq_pad_concat(test_batch, -1)
        try:
            pred, acc = self._predict_batch(cm, test_batch, addn_batch)
        except Exception as err:
            # Evaluation stays best-effort: a failing batch is skipped, but
            # the failure is reported instead of vanishing silently, and
            # KeyboardInterrupt / SystemExit are no longer swallowed.
            # NOTE: a skipped batch still counts in the len(test_data)
            # denominator below, so it lowers the reported accuracy.
            self.logger.warning(
                "Skipping evaluation batch %d: %r", batch_index, err)
            continue
        accuracy_list.append(acc)
        pr_curve_data.append(
            (F.softmax(pred, -1)[:, 1].data, test_batch.labels.data))
    accuracy = 100 * (sum(accuracy_list) / len(test_data))
    return cm, accuracy, pr_curve_data
def _run_epoch(self, train_data, dev_data, unlabel_data, addn_data,
               addn_data_unlab, addn_dev, ek, ek_t, ek_u, graph_embs,
               graph_embs_t, graph_embs_u):
    """Train for one pass over ``train_data`` and return the dev accuracy.

    Every iteration encodes one labeled batch, optionally one unlabeled
    batch, combines the classification, adversarial (AT), virtual
    adversarial (VAT), entropy and auto-encoder losses weighted by the
    ``self.config.lambda_*`` values, and takes one optimizer step. Every
    ``self.config.eval_steps`` global steps the dev set is evaluated and the
    model is saved when the dev accuracy improves.

    Args:
        train_data: Labeled examples; each item is a pair whose second
            element is a token sequence (used for the word count).
        dev_data: Examples used for periodic and final evaluation.
        unlabel_data: Unlabeled examples, cycled through as often as needed.
        addn_data, ek, graph_embs: Additional classifier inputs for
            ``train_data``; batches are taken with ``retAddnBatch``.
        addn_data_unlab, ek_u, graph_embs_u: The same for ``unlabel_data``.
        addn_dev, ek_t, graph_embs_t: The same for ``dev_data``; forwarded
            to ``self._run_evaluate``.

    Returns:
        The dev accuracy measured after the last training batch.
    """
    # ``.cuda()`` returns a copy for tensors rather than moving them in
    # place, so the results must be rebound for the transfer to take effect.
    addn_dev = addn_dev.cuda()
    ek_t = ek_t.cuda()
    graph_embs_t = graph_embs_t.cuda()

    report_stats = utils.Statistics()
    cm = ConfusionMatrix(self.classes)
    _, seq_data = list(zip(*train_data))
    total_seq_words = len(list(itertools.chain.from_iterable(seq_data)))
    # Estimated number of iterations, only used for progress reporting.
    # NOTE(review): the 1.5 factor is unexplained here -- confirm intent.
    iter_per_epoch = (1.5 * total_seq_words) // self.config.wbatchsize

    self.encoder.train()
    self.clf.train()

    train_iter = self._create_iter(train_data, self.config.wbatchsize)
    unlabel_iter = self._create_iter(unlabel_data,
                                     self.config.wbatchsize_unlabel)
    # Read cursors into the additional inputs of the labeled (sofar) and
    # unlabeled (sofar_1) data; each wraps to 0 at the end of its table.
    sofar = 0
    sofar_1 = 0
    for batch_index, train_batch_raw in enumerate(train_iter):
        seq_iter = list(zip(*train_batch_raw))[1]
        seq_words = len(list(itertools.chain.from_iterable(seq_iter)))
        report_stats.n_words += seq_words
        self.global_steps += 1

        # self.enc_clf_opt.zero_grad()
        if self.config.add_noise:
            train_batch_raw = add_noise(train_batch_raw,
                                        self.config.noise_dropout,
                                        self.config.random_permutation)
        train_batch = batch_utils.seq_pad_concat(train_batch_raw, -1)

        train_embedded = self.embedder(train_batch)
        memory_bank_train, enc_final_train = self.encoder(
            train_embedded, train_batch)

        # An unlabeled batch is only needed by the VAT, AE and entropy
        # losses. All three weights are tested with ``> 0`` so this gate
        # agrees with the branches below that consume the batch.
        if (self.config.lambda_vat > 0 or self.config.lambda_ae > 0
                or self.config.lambda_entropy > 0):
            try:
                unlabel_batch_raw = next(unlabel_iter)
            except StopIteration:
                # Unlabeled data exhausted: start another pass over it.
                unlabel_iter = self._create_iter(
                    unlabel_data, self.config.wbatchsize_unlabel)
                unlabel_batch_raw = next(unlabel_iter)

            if self.config.add_noise:
                unlabel_batch_raw = add_noise(
                    unlabel_batch_raw, self.config.noise_dropout,
                    self.config.random_permutation)
            unlabel_batch = batch_utils.seq_pad_concat(
                unlabel_batch_raw, -1)
            unlabel_embedded = self.embedder(unlabel_batch)
            memory_bank_unlabel, enc_final_unlabel = self.encoder(
                unlabel_embedded, unlabel_batch)

            # presumably retAddnBatch returns the rows starting at the
            # cursor, sized to the batch -- verify against its definition.
            addn_batch_unlab = retAddnBatch(addn_data_unlab,
                                            memory_bank_unlabel.shape[0],
                                            sofar_1).cuda()
            ek_batch_unlab = retAddnBatch(ek_u,
                                          memory_bank_unlabel.shape[0],
                                          sofar_1).cuda()
            graph_embs_unlab = retAddnBatch(graph_embs_u,
                                            memory_bank_unlabel.shape[0],
                                            sofar_1).cuda()
            sofar_1 += addn_batch_unlab.shape[0]
            if sofar_1 >= ek_u.shape[0]:
                sofar_1 = 0

        addn_batch = retAddnBatch(addn_data, memory_bank_train.shape[0],
                                  sofar).cuda()
        ek_batch = retAddnBatch(ek, memory_bank_train.shape[0],
                                sofar).cuda()
        graph_embs_batch = retAddnBatch(graph_embs,
                                        memory_bank_train.shape[0],
                                        sofar).cuda()
        sofar += addn_batch.shape[0]
        if sofar >= ek.shape[0]:
            sofar = 0

        pred = self.clf(memory_bank_train, addn_batch, ek_batch,
                        enc_final_train, graph_embs_batch)
        accuracy = self.get_accuracy(cm, pred.data,
                                     train_batch.labels.data)
        lclf = self.clf_loss(pred, train_batch.labels)

        # -1 is the placeholder for a loss whose branch is not taken; the
        # placeholder is still added to report_stats below.
        lat = Variable(
            torch.FloatTensor([-1.]).type(batch_utils.FLOAT_TYPE))
        lvat = Variable(
            torch.FloatTensor([-1.]).type(batch_utils.FLOAT_TYPE))
        if self.config.lambda_at > 0:
            lat = at_loss(
                self.embedder,
                self.encoder,
                self.clf,
                train_batch,
                addn_batch,
                ek_batch,
                graph_embs_batch,
                perturb_norm_length=self.config.perturb_norm_length)

        if self.config.lambda_vat > 0:
            lvat_train = vat_loss(
                self.embedder,
                self.encoder,
                self.clf,
                train_batch,
                addn_batch,
                ek_batch,
                graph_embs_batch,
                p_logit=pred,
                perturb_norm_length=self.config.perturb_norm_length)
            if self.config.inc_unlabeled_loss:
                if memory_bank_unlabel.shape[0] != ek_batch_unlab.shape[0]:
                    # Additional inputs do not line up with the unlabeled
                    # batch: abandon this iteration (no backward pass and
                    # no optimizer step).
                    print(
                        f'Skipping; Unequal Shapes: {memory_bank_unlabel.shape} and {ek_batch_unlab.shape}'
                    )
                    continue
                else:
                    lvat_unlabel = vat_loss(
                        self.embedder,
                        self.encoder,
                        self.clf,
                        unlabel_batch,
                        addn_batch_unlab,
                        ek_batch_unlab,
                        graph_embs_unlab,
                        p_logit=self.clf(memory_bank_unlabel,
                                         addn_batch_unlab, ek_batch_unlab,
                                         enc_final_unlabel,
                                         graph_embs_unlab),
                        perturb_norm_length=self.config.perturb_norm_length
                    )
                if self.config.unlabeled_loss_type == "AvgTrainUnlabel":
                    lvat = 0.5 * (lvat_train + lvat_unlabel)
                elif self.config.unlabeled_loss_type == "Unlabel":
                    lvat = lvat_unlabel
            else:
                lvat = lvat_train

        lentropy = Variable(
            torch.FloatTensor([-1.]).type(batch_utils.FLOAT_TYPE))
        if self.config.lambda_entropy > 0:
            lentropy_train = entropy_loss(pred)
            if self.config.inc_unlabeled_loss:
                lentropy_unlabel = entropy_loss(
                    self.clf(memory_bank_unlabel, addn_batch_unlab,
                             ek_batch_unlab, enc_final_unlabel,
                             graph_embs_unlab))
                if self.config.unlabeled_loss_type == "AvgTrainUnlabel":
                    lentropy = 0.5 * (lentropy_train + lentropy_unlabel)
                elif self.config.unlabeled_loss_type == "Unlabel":
                    lentropy = lentropy_unlabel
            else:
                lentropy = lentropy_train

        lae = Variable(
            torch.FloatTensor([-1.]).type(batch_utils.FLOAT_TYPE))
        if self.config.lambda_ae > 0:
            lae = self.ae(memory_bank_unlabel, enc_final_unlabel,
                          unlabel_batch.sent_len, unlabel_batch_raw)

        ltotal = (self.config.lambda_clf * lclf) + \
            (self.config.lambda_ae * lae) + \
            (self.config.lambda_at * lat) + \
            (self.config.lambda_vat * lvat) + \
            (self.config.lambda_entropy * lentropy)

        report_stats.clf_loss += lclf.data.cpu().numpy()
        report_stats.at_loss += lat.data.cpu().numpy()
        report_stats.vat_loss += lvat.data.cpu().numpy()
        report_stats.ae_loss += lae.data.cpu().numpy()
        report_stats.entropy_loss += lentropy.data.cpu().numpy()
        report_stats.n_sent += len(pred)
        report_stats.n_correct += accuracy
        self.enc_clf_opt.zero_grad()
        ltotal.backward()

        params_list = self._get_trainabe_modules()
        # Excluding embedder from norm constraint when AT or VAT
        if not self.config.normalize_embedding:
            params_list += list(self.embedder.parameters())

        norm = torch.nn.utils.clip_grad_norm(params_list,
                                             self.config.max_norm)
        report_stats.grad_norm += norm
        self.enc_clf_opt.step()
        if self.config.scheduler == "ExponentialLR":
            self.scheduler.step()
        # presumably refreshes moving averages of the weights -- verify
        # against the ema_* implementation.
        self.ema_embedder.apply(self.embedder.named_parameters())
        self.ema_encoder.apply(self.encoder.named_parameters())
        self.ema_clf.apply(self.clf.named_parameters())

        report_func(self.epoch, batch_index, iter_per_epoch, self.time_s,
                    report_stats, self.config.report_every, self.logger)

        if self.global_steps % self.config.eval_steps == 0:
            cm_, accuracy, prc_dev = self._run_evaluate(
                dev_data, addn_dev, ek_t, graph_embs_t)
            self.logger.info(
                "- dev accuracy {} | best dev accuracy {} ".format(
                    accuracy, self.best_accuracy))
            self.writer.add_scalar("Dev_Accuracy", accuracy,
                                   self.global_steps)
            pred_, lab_ = zip(*prc_dev)
            pred_ = torch.cat(pred_)
            lab_ = torch.cat(lab_)
            self.writer.add_pr_curve("Dev PR-Curve", lab_, pred_,
                                     self.global_steps)
            pprint.pprint(cm_)
            pprint.pprint(cm_.get_all_metrics())
            if accuracy > self.best_accuracy:
                self.logger.info("- new best score!")
                self.best_accuracy = accuracy
                self._save_model()
            if self.config.scheduler == "ReduceLROnPlateau":
                self.scheduler.step(accuracy)
            # Back to training mode; presumably evaluation switched the
            # modules to eval mode.
            self.encoder.train()
            # self.embedder.train()
            self.clf.train()

            if self.config.weight_decay > 0:
                print(">> Square Norm: %1.4f " % self._get_l2_norm_loss())

    cm, train_accuracy, _ = self._run_evaluate(train_data, addn_data, ek,
                                               graph_embs)
    self.logger.info("- Train accuracy {}".format(train_accuracy))
    pprint.pprint(cm.get_all_metrics())

    cm, dev_accuracy, _ = self._run_evaluate(dev_data, addn_dev, ek_t,
                                             graph_embs_t)
    self.logger.info("- Dev accuracy {} | best dev accuracy {}".format(
        dev_accuracy, self.best_accuracy))
    pprint.pprint(cm.get_all_metrics())
    self.writer.add_scalars("Overall_Accuracy", {
        "Train_Accuracy": train_accuracy,
        "Dev_Accuracy": dev_accuracy
    }, self.global_steps)
    return dev_accuracy