    def _run_evaluate(self, test_data):
        pr_curve_data = []
        cm = ConfusionMatrix(self.classes)
        accuracy_list = []
        # test_iter = self._create_iter(test_data, self.config.wbatchsize,
        #                               random_shuffler=utils.identity_fun)
        test_iter = self.chunks(test_data)
        for test_batch in test_iter:
            test_batch = batch_utils.seq_pad_concat(test_batch, -1)
            pred, acc = self._predict_batch(cm, test_batch)
            accuracy_list.append(acc)
            # Collect (positive-class probability, gold label) pairs for a PR curve.
            pr_curve_data.append(
                (F.softmax(pred, -1)[:, 1].data, test_batch.labels.data))
        # 'acc' is presumably the number of correct predictions in the batch, so
        # the overall accuracy is the summed count over the whole test set.
        accuracy = 100 * (sum(accuracy_list) / len(test_data))
        return cm, accuracy, pr_curve_data
Example #2
    def _run_evaluate(self, test_data, addn_test):
        pr_curve_data = []
        cm = ConfusionMatrix(self.classes)
        accuracy_list = []
        # test_iter = self._create_iter(test_data, self.config.wbatchsize,
        #                               random_shuffler=utils.identity_fun)
        test_iter = self.chunks(test_data)

        for batch_index, test_batch in enumerate(test_iter):
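            # Assumption: self.chunks yields batches of 15 examples, so the matching
            # rows of the additional test features are taken in fixed slices of 15.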
            addn_batch = addn_test[batch_index * 15:(batch_index + 1) * 15]
            test_batch = batch_utils.seq_pad_concat(test_batch, -1)
            #             print(addn_batch.shape)
            try:
                pred, acc = self._predict_batch(cm, test_batch, addn_batch)
            except Exception:
                # Skip batches whose prediction fails (e.g. when the additional
                # feature slice does not line up with the final, shorter batch).
                continue
            accuracy_list.append(acc)
            pr_curve_data.append(
                (F.softmax(pred, -1)[:, 1].data, test_batch.labels.data))
        accuracy = 100 * (sum(accuracy_list) / len(test_data))
        return cm, accuracy, pr_curve_data
Example #3
    def _run_epoch(self, train_data, dev_data, unlabel_data, addn_data,
                   addn_data_unlab, addn_dev, ek, ek_t, ek_u, graph_embs,
                   graph_embs_t, graph_embs_u):
        # Tensor.cuda() is not in-place, so keep the returned GPU copies.
        addn_dev = addn_dev.cuda()
        ek_t = ek_t.cuda()
        graph_embs_t = graph_embs_t.cuda()
        report_stats = utils.Statistics()
        cm = ConfusionMatrix(self.classes)
        _, seq_data = list(zip(*train_data))
        total_seq_words = len(list(itertools.chain.from_iterable(seq_data)))
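        # Rough estimate of batches per epoch from the total token count and the
        # word-level batch size; the 1.5 factor presumably just pads the estimate
        # used for progress reporting.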
        iter_per_epoch = (1.5 * total_seq_words) // self.config.wbatchsize

        self.encoder.train()
        self.clf.train()

        train_iter = self._create_iter(train_data, self.config.wbatchsize)

        unlabel_iter = self._create_iter(unlabel_data,
                                         self.config.wbatchsize_unlabel)

        sofar = 0
        sofar_1 = 0
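        # 'sofar' and 'sofar_1' are rolling row offsets into the labeled and
        # unlabeled additional-feature tensors; they wrap back to 0 once exhausted.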
        for batch_index, train_batch_raw in enumerate(train_iter):
            seq_iter = list(zip(*train_batch_raw))[1]
            seq_words = len(list(itertools.chain.from_iterable(seq_iter)))
            report_stats.n_words += seq_words
            self.global_steps += 1

            # self.enc_clf_opt.zero_grad()
            if self.config.add_noise:
                train_batch_raw = add_noise(train_batch_raw,
                                            self.config.noise_dropout,
                                            self.config.random_permutation)
            train_batch = batch_utils.seq_pad_concat(train_batch_raw, -1)

            train_embedded = self.embedder(train_batch)

            memory_bank_train, enc_final_train = self.encoder(
                train_embedded, train_batch)

            # If any unsupervised regularizer (VAT, autoencoder, or entropy
            # minimization) is enabled, also draw and encode an unlabeled batch.
            if (self.config.lambda_vat > 0 or self.config.lambda_ae > 0
                    or self.config.lambda_entropy > 0):
                try:
                    unlabel_batch_raw = next(unlabel_iter)
                except StopIteration:
                    unlabel_iter = self._create_iter(
                        unlabel_data, self.config.wbatchsize_unlabel)
                    unlabel_batch_raw = next(unlabel_iter)

                if self.config.add_noise:
                    unlabel_batch_raw = add_noise(
                        unlabel_batch_raw, self.config.noise_dropout,
                        self.config.random_permutation)
                unlabel_batch = batch_utils.seq_pad_concat(
                    unlabel_batch_raw, -1)
                unlabel_embedded = self.embedder(unlabel_batch)
                memory_bank_unlabel, enc_final_unlabel = self.encoder(
                    unlabel_embedded, unlabel_batch)
                addn_batch_unlab = retAddnBatch(addn_data_unlab,
                                                memory_bank_unlabel.shape[0],
                                                sofar_1).cuda()
                ek_batch_unlab = retAddnBatch(ek_u,
                                              memory_bank_unlabel.shape[0],
                                              sofar_1).cuda()
                graph_embs_unlab = retAddnBatch(graph_embs_u,
                                                memory_bank_unlabel.shape[0],
                                                sofar_1).cuda()
                sofar_1 += addn_batch_unlab.shape[0]
                if sofar_1 >= ek_u.shape[0]:
                    sofar_1 = 0
            addn_batch = retAddnBatch(addn_data, memory_bank_train.shape[0],
                                      sofar).cuda()
            ek_batch = retAddnBatch(ek, memory_bank_train.shape[0],
                                    sofar).cuda()
            graph_embs_batch = retAddnBatch(graph_embs,
                                            memory_bank_train.shape[0],
                                            sofar).cuda()
            sofar += addn_batch.shape[0]
            if sofar >= ek.shape[0]:
                sofar = 0
            pred = self.clf(memory_bank_train, addn_batch, ek_batch,
                            enc_final_train, graph_embs_batch)
            accuracy = self.get_accuracy(cm, pred.data,
                                         train_batch.labels.data)
            lclf = self.clf_loss(pred, train_batch.labels)

            lat = Variable(
                torch.FloatTensor([-1.]).type(batch_utils.FLOAT_TYPE))
            lvat = Variable(
                torch.FloatTensor([-1.]).type(batch_utils.FLOAT_TYPE))
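            # The auxiliary losses default to a -1 sentinel; each is overwritten
            # with a real value only when its lambda_* weight is positive, and a
            # zero weight cancels the sentinel's contribution to ltotal.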
            if self.config.lambda_at > 0:
                lat = at_loss(
                    self.embedder,
                    self.encoder,
                    self.clf,
                    train_batch,
                    addn_batch,
                    ek_batch,
                    graph_embs_batch,
                    perturb_norm_length=self.config.perturb_norm_length)

            if self.config.lambda_vat > 0:
                lvat_train = vat_loss(
                    self.embedder,
                    self.encoder,
                    self.clf,
                    train_batch,
                    addn_batch,
                    ek_batch,
                    graph_embs_batch,
                    p_logit=pred,
                    perturb_norm_length=self.config.perturb_norm_length)
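                # Optionally also compute VAT on the unlabeled batch; batches whose
                # additional-feature slice no longer lines up are skipped.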
                if self.config.inc_unlabeled_loss:
                    if memory_bank_unlabel.shape[0] != ek_batch_unlab.shape[0]:
                        print(
                            f'Skipping; Unequal Shapes: {memory_bank_unlabel.shape} and {ek_batch_unlab.shape}'
                        )
                        continue
                    else:
                        lvat_unlabel = vat_loss(
                            self.embedder,
                            self.encoder,
                            self.clf,
                            unlabel_batch,
                            addn_batch_unlab,
                            ek_batch_unlab,
                            graph_embs_unlab,
                            p_logit=self.clf(memory_bank_unlabel,
                                             addn_batch_unlab, ek_batch_unlab,
                                             enc_final_unlabel,
                                             graph_embs_unlab),
                            perturb_norm_length=self.config.perturb_norm_length
                        )
                    if self.config.unlabeled_loss_type == "AvgTrainUnlabel":
                        lvat = 0.5 * (lvat_train + lvat_unlabel)
                    elif self.config.unlabeled_loss_type == "Unlabel":
                        lvat = lvat_unlabel
                else:
                    lvat = lvat_train

            lentropy = Variable(
                torch.FloatTensor([-1.]).type(batch_utils.FLOAT_TYPE))
            if self.config.lambda_entropy > 0:
                lentropy_train = entropy_loss(pred)
                if self.config.inc_unlabeled_loss:
                    lentropy_unlabel = entropy_loss(
                        self.clf(memory_bank_unlabel, addn_batch_unlab,
                                 ek_batch_unlab, enc_final_unlabel,
                                 graph_embs_unlab))
                    if self.config.unlabeled_loss_type == "AvgTrainUnlabel":
                        lentropy = 0.5 * (lentropy_train + lentropy_unlabel)
                    elif self.config.unlabeled_loss_type == "Unlabel":
                        lentropy = lentropy_unlabel
                else:
                    lentropy = lentropy_train

            lae = Variable(
                torch.FloatTensor([-1.]).type(batch_utils.FLOAT_TYPE))
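            # Autoencoder reconstruction loss on the unlabeled batch (presumably
            # reconstructing the input sequence from the encoder states).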
            if self.config.lambda_ae > 0:
                lae = self.ae(memory_bank_unlabel, enc_final_unlabel,
                              unlabel_batch.sent_len, unlabel_batch_raw)

            ltotal = (self.config.lambda_clf * lclf) + \
                     (self.config.lambda_ae * lae) + \
                     (self.config.lambda_at * lat) + \
                     (self.config.lambda_vat * lvat) + \
                     (self.config.lambda_entropy * lentropy)

            report_stats.clf_loss += lclf.data.cpu().numpy()
            report_stats.at_loss += lat.data.cpu().numpy()
            report_stats.vat_loss += lvat.data.cpu().numpy()
            report_stats.ae_loss += lae.data.cpu().numpy()
            report_stats.entropy_loss += lentropy.data.cpu().numpy()
            report_stats.n_sent += len(pred)
            report_stats.n_correct += accuracy
            self.enc_clf_opt.zero_grad()
            ltotal.backward()

            params_list = self._get_trainabe_modules()
            # Exclude the embedder from the norm constraint when AT or VAT is used
            if not self.config.normalize_embedding:
                params_list += list(self.embedder.parameters())

            norm = torch.nn.utils.clip_grad_norm(params_list,
                                                 self.config.max_norm)
            report_stats.grad_norm += norm
            self.enc_clf_opt.step()
            if self.config.scheduler == "ExponentialLR":
                self.scheduler.step()
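            # Update exponential-moving-average copies of the embedder, encoder and
            # classifier parameters (presumably consumed when evaluating or saving).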
            self.ema_embedder.apply(self.embedder.named_parameters())
            self.ema_encoder.apply(self.encoder.named_parameters())
            self.ema_clf.apply(self.clf.named_parameters())

            report_func(self.epoch, batch_index, iter_per_epoch, self.time_s,
                        report_stats, self.config.report_every, self.logger)

            if self.global_steps % self.config.eval_steps == 0:
                cm_, accuracy, prc_dev = self._run_evaluate(
                    dev_data, addn_dev, ek_t, graph_embs_t)
                self.logger.info(
                    "- dev accuracy {} | best dev accuracy {} ".format(
                        accuracy, self.best_accuracy))
                self.writer.add_scalar("Dev_Accuracy", accuracy,
                                       self.global_steps)
                pred_, lab_ = zip(*prc_dev)
                pred_ = torch.cat(pred_)
                lab_ = torch.cat(lab_)
                self.writer.add_pr_curve("Dev PR-Curve", lab_, pred_,
                                         self.global_steps)
                pprint.pprint(cm_)
                pprint.pprint(cm_.get_all_metrics())
                if accuracy > self.best_accuracy:
                    self.logger.info("- new best score!")
                    self.best_accuracy = accuracy
                    self._save_model()
                if self.config.scheduler == "ReduceLROnPlateau":
                    self.scheduler.step(accuracy)
                self.encoder.train()
                #                 self.embedder.train()
                self.clf.train()

                if self.config.weight_decay > 0:
                    print(">> Square Norm: %1.4f " % self._get_l2_norm_loss())

        cm, train_accuracy, _ = self._run_evaluate(train_data, addn_data, ek,
                                                   graph_embs)
        self.logger.info("- Train accuracy  {}".format(train_accuracy))
        pprint.pprint(cm.get_all_metrics())

        cm, dev_accuracy, _ = self._run_evaluate(dev_data, addn_dev, ek_t,
                                                 graph_embs_t)
        self.logger.info("- Dev accuracy  {} | best dev accuracy {}".format(
            dev_accuracy, self.best_accuracy))
        pprint.pprint(cm.get_all_metrics())
        self.writer.add_scalars("Overall_Accuracy", {
            "Train_Accuracy": train_accuracy,
            "Dev_Accuracy": dev_accuracy
        }, self.global_steps)
        return dev_accuracy