def validate(validation_loader, model, loss_fn, device, print_frequency=2,
             curr_epoch=1, column_split_order=None):
    history = {
        'loss': [],
        'accuracy':[],
        'batch_time':[],
        'classification_metrics':None,
        'confusion_matrix':None
    }
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(
        len(validation_loader),
        [batch_time, data_time, losses, top1],
        prefix="Epoch: [{}]".format(curr_epoch))

    # switch to evaluate mode
    model.eval()
    conf_matrix = None
    if column_split_order:
        conf_matrix = ConfusionMatrix(column_split_order)

    with torch.no_grad():
        # https://github.com/pytorch/pytorch/issues/16417#issuecomment-566654504
        end = time.time()
        for i, (input_ids, attention_mask, labels) in enumerate(validation_loader):
            # measure data loading time
            data_time.update(time.time() - end)

            input_ids = input_ids.to(device, non_blocking=True)
            attention_mask = attention_mask.to(device, non_blocking=True)
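            # labels arrive one-hot (or as class probabilities); argmax yields the class indices loss_fn expects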
            labels = torch.argmax(labels, dim=1).to(device, non_blocking=True)
            # compute output
            output = model(input_ids, attention_mask=attention_mask)

            loss = loss_fn(output, labels)

            # measure accuracy and record loss
            acc1 = accuracy(output, labels, conf_matrix=conf_matrix)

            losses.update(loss.item(), input_ids.size(0))
            top1.update(acc1[0].item(), input_ids.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            if i % print_frequency == 0:
                progress.display(i)
        
        history['accuracy'].append(float(top1.avg))
        history['loss'].append(float(losses.avg))
        history['batch_time'].append(float(batch_time.avg))
        if conf_matrix is not None:
            history['classification_metrics'] = conf_matrix.get_all_metrics()
            history['confusion_matrix'] = str(conf_matrix)

    return history
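A minimal usage sketch for the function above (hedged: the model, loader, and class names are invented, and AverageMeter, ProgressMeter, and accuracy are assumed to follow the PyTorch ImageNet-example helpers):

import torch
import torch.nn as nn

# Hypothetical driver; validation_loader must yield (input_ids, attention_mask, one_hot_labels).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)          # any model returning [batch, num_classes] logits
loss_fn = nn.CrossEntropyLoss()   # matches the integer labels produced by argmax

history = validate(validation_loader, model, loss_fn, device,
                   print_frequency=10, curr_epoch=1,
                   column_split_order=['negative', 'positive'])  # class names for ConfusionMatrix
print(history['accuracy'], history['loss'])
print(history['confusion_matrix'])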
Example #2
    def _run_epoch(self, train_data, dev_data, unlabel_data, addn_data,
                   addn_data_unlab, addn_dev, ek, ek_t, ek_u, graph_embs,
                   graph_embs_t, graph_embs_u):
        # move dev-side feature tensors to the GPU (Tensor.cuda() returns a copy, so reassign)
        addn_dev = addn_dev.cuda()
        ek_t = ek_t.cuda()
        graph_embs_t = graph_embs_t.cuda()
        report_stats = utils.Statistics()
        cm = ConfusionMatrix(self.classes)
        _, seq_data = list(zip(*train_data))
        total_seq_words = len(list(itertools.chain.from_iterable(seq_data)))
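        # rough estimate of batches per epoch; used only for progress reporting via report_func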
        iter_per_epoch = (1.5 * total_seq_words) // self.config.wbatchsize

        self.encoder.train()
        self.clf.train()

        train_iter = self._create_iter(train_data, self.config.wbatchsize)

        unlabel_iter = self._create_iter(unlabel_data,
                                         self.config.wbatchsize_unlabel)

        # rolling row offsets into the additional-feature tensors, wrapped to 0 once exhausted
        sofar = 0
        sofar_1 = 0
        for batch_index, train_batch_raw in enumerate(train_iter):
            seq_iter = list(zip(*train_batch_raw))[1]
            seq_words = len(list(itertools.chain.from_iterable(seq_iter)))
            report_stats.n_words += seq_words
            self.global_steps += 1

            # self.enc_clf_opt.zero_grad()
            if self.config.add_noise:
                train_batch_raw = add_noise(train_batch_raw,
                                            self.config.noise_dropout,
                                            self.config.random_permutation)
            train_batch = batch_utils.seq_pad_concat(train_batch_raw, -1)

            train_embedded = self.embedder(train_batch)

            memory_bank_train, enc_final_train = self.encoder(
                train_embedded, train_batch)

            if (self.config.lambda_vat > 0 or self.config.lambda_ae > 0
                    or self.config.lambda_entropy > 0):
                try:
                    unlabel_batch_raw = next(unlabel_iter)
                except StopIteration:
                    unlabel_iter = self._create_iter(
                        unlabel_data, self.config.wbatchsize_unlabel)
                    unlabel_batch_raw = next(unlabel_iter)

                if self.config.add_noise:
                    unlabel_batch_raw = add_noise(
                        unlabel_batch_raw, self.config.noise_dropout,
                        self.config.random_permutation)
                unlabel_batch = batch_utils.seq_pad_concat(
                    unlabel_batch_raw, -1)
                unlabel_embedded = self.embedder(unlabel_batch)
                memory_bank_unlabel, enc_final_unlabel = self.encoder(
                    unlabel_embedded, unlabel_batch)
                addn_batch_unlab = retAddnBatch(addn_data_unlab,
                                                memory_bank_unlabel.shape[0],
                                                sofar_1).cuda()
                ek_batch_unlab = retAddnBatch(ek_u,
                                              memory_bank_unlabel.shape[0],
                                              sofar_1).cuda()
                graph_embs_unlab = retAddnBatch(graph_embs_u,
                                                memory_bank_unlabel.shape[0],
                                                sofar_1).cuda()
                sofar_1 += addn_batch_unlab.shape[0]
                if sofar_1 >= ek_u.shape[0]:
                    sofar_1 = 0
            addn_batch = retAddnBatch(addn_data, memory_bank_train.shape[0],
                                      sofar).cuda()
            ek_batch = retAddnBatch(ek, memory_bank_train.shape[0],
                                    sofar).cuda()
            graph_embs_batch = retAddnBatch(graph_embs,
                                            memory_bank_train.shape[0],
                                            sofar).cuda()
            sofar += addn_batch.shape[0]
            if sofar >= ek.shape[0]:
                sofar = 0
            pred = self.clf(memory_bank_train, addn_batch, ek_batch,
                            enc_final_train, graph_embs_batch)
            accuracy = self.get_accuracy(cm, pred.data,
                                         train_batch.labels.data)
            lclf = self.clf_loss(pred, train_batch.labels)

            # -1 sentinel tensors for regularizers that are switched off; a disabled
            # term's lambda weight is 0, so the sentinel cancels out of ltotal below
            lat = torch.tensor([-1.]).type(batch_utils.FLOAT_TYPE)
            lvat = torch.tensor([-1.]).type(batch_utils.FLOAT_TYPE)
            if self.config.lambda_at > 0:
                lat = at_loss(
                    self.embedder,
                    self.encoder,
                    self.clf,
                    train_batch,
                    addn_batch,
                    ek_batch,
                    graph_embs_batch,
                    perturb_norm_length=self.config.perturb_norm_length)

            if self.config.lambda_vat > 0:
                lvat_train = vat_loss(
                    self.embedder,
                    self.encoder,
                    self.clf,
                    train_batch,
                    addn_batch,
                    ek_batch,
                    graph_embs_batch,
                    p_logit=pred,
                    perturb_norm_length=self.config.perturb_norm_length)
                if self.config.inc_unlabeled_loss:
                    if memory_bank_unlabel.shape[0] != ek_batch_unlab.shape[0]:
                        print(
                            f'Skipping; Unequal Shapes: {memory_bank_unlabel.shape} and {ek_batch_unlab.shape}'
                        )
                        continue
                    else:
                        lvat_unlabel = vat_loss(
                            self.embedder,
                            self.encoder,
                            self.clf,
                            unlabel_batch,
                            addn_batch_unlab,
                            ek_batch_unlab,
                            graph_embs_unlab,
                            p_logit=self.clf(memory_bank_unlabel,
                                             addn_batch_unlab, ek_batch_unlab,
                                             enc_final_unlabel,
                                             graph_embs_unlab),
                            perturb_norm_length=self.config.perturb_norm_length
                        )
                    if self.config.unlabeled_loss_type == "AvgTrainUnlabel":
                        lvat = 0.5 * (lvat_train + lvat_unlabel)
                    elif self.config.unlabeled_loss_type == "Unlabel":
                        lvat = lvat_unlabel
                else:
                    lvat = lvat_train

            lentropy = torch.tensor([-1.]).type(batch_utils.FLOAT_TYPE)
            if self.config.lambda_entropy > 0:
                lentropy_train = entropy_loss(pred)
                if self.config.inc_unlabeled_loss:
                    lentropy_unlabel = entropy_loss(
                        self.clf(memory_bank_unlabel, addn_batch_unlab,
                                 ek_batch_unlab, enc_final_unlabel,
                                 graph_embs_unlab))
                    if self.config.unlabeled_loss_type == "AvgTrainUnlabel":
                        lentropy = 0.5 * (lentropy_train + lentropy_unlabel)
                    elif self.config.unlabeled_loss_type == "Unlabel":
                        lentropy = lentropy_unlabel
                else:
                    lentropy = lentropy_train

            lae = torch.tensor([-1.]).type(batch_utils.FLOAT_TYPE)
            if self.config.lambda_ae > 0:
                lae = self.ae(memory_bank_unlabel, enc_final_unlabel,
                              unlabel_batch.sent_len, unlabel_batch_raw)

            ltotal = (self.config.lambda_clf * lclf) + \
                     (self.config.lambda_ae * lae) + \
                     (self.config.lambda_at * lat) + \
                     (self.config.lambda_vat * lvat) + \
                     (self.config.lambda_entropy * lentropy)

            report_stats.clf_loss += lclf.data.cpu().numpy()
            report_stats.at_loss += lat.data.cpu().numpy()
            report_stats.vat_loss += lvat.data.cpu().numpy()
            report_stats.ae_loss += lae.data.cpu().numpy()
            report_stats.entropy_loss += lentropy.data.cpu().numpy()
            report_stats.n_sent += len(pred)
            report_stats.n_correct += accuracy
            self.enc_clf_opt.zero_grad()
            ltotal.backward()

            params_list = self._get_trainabe_modules()
            # Excluding embedder from norm constraint when AT or VAT is used
            if not self.config.normalize_embedding:
                params_list += list(self.embedder.parameters())

            norm = torch.nn.utils.clip_grad_norm_(params_list,
                                                  self.config.max_norm)
            report_stats.grad_norm += norm
            self.enc_clf_opt.step()
            if self.config.scheduler == "ExponentialLR":
                self.scheduler.step()
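            # update the exponential-moving-average shadows of the model parameters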
            self.ema_embedder.apply(self.embedder.named_parameters())
            self.ema_encoder.apply(self.encoder.named_parameters())
            self.ema_clf.apply(self.clf.named_parameters())

            report_func(self.epoch, batch_index, iter_per_epoch, self.time_s,
                        report_stats, self.config.report_every, self.logger)

            if self.global_steps % self.config.eval_steps == 0:
                cm_, dev_accuracy, prc_dev = self._run_evaluate(
                    dev_data, addn_dev, ek_t, graph_embs_t)
                self.logger.info(
                    "- dev accuracy {} | best dev accuracy {} ".format(
                        dev_accuracy, self.best_accuracy))
                self.writer.add_scalar("Dev_Accuracy", dev_accuracy,
                                       self.global_steps)
                pred_, lab_ = zip(*prc_dev)
                pred_ = torch.cat(pred_)
                lab_ = torch.cat(lab_)
                self.writer.add_pr_curve("Dev PR-Curve", lab_, pred_,
                                         self.global_steps)
                pprint.pprint(cm_)
                pprint.pprint(cm_.get_all_metrics())
                if dev_accuracy > self.best_accuracy:
                    self.logger.info("- new best score!")
                    self.best_accuracy = dev_accuracy
                    self._save_model()
                if self.config.scheduler == "ReduceLROnPlateau":
                    self.scheduler.step(dev_accuracy)
                self.encoder.train()
                # self.embedder.train()
                self.clf.train()

                if self.config.weight_decay > 0:
                    print(">> Square Norm: %1.4f " % self._get_l2_norm_loss())

        cm, train_accuracy, _ = self._run_evaluate(train_data, addn_data, ek,
                                                   graph_embs)
        self.logger.info("- Train accuracy  {}".format(train_accuracy))
        pprint.pprint(cm.get_all_metrics())

        cm, dev_accuracy, _ = self._run_evaluate(dev_data, addn_dev, ek_t,
                                                 graph_embs_t)
        self.logger.info("- Dev accuracy  {} | best dev accuracy {}".format(
            dev_accuracy, self.best_accuracy))
        pprint.pprint(cm.get_all_metrics())
        self.writer.add_scalars("Overall_Accuracy", {
            "Train_Accuracy": train_accuracy,
            "Dev_Accuracy": dev_accuracy
        }, self.global_steps)
        return dev_accuracy
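A note on the loss combination in _run_epoch: each regularizer (lat, lvat, lentropy, lae) starts as a -1 sentinel tensor and is only computed when its lambda is positive, so the weighted sum stays correct because a disabled term is always multiplied by a zero weight. A minimal sketch of that pattern (the weights and loss values here are invented):

import torch

lambda_clf, lambda_at, lambda_vat = 1.0, 0.0, 1.0
lclf = torch.tensor(0.70)      # classification loss, always computed
lat = torch.tensor([-1.0])     # sentinel: AT disabled, so lambda_at == 0
lvat = torch.tensor(0.05)      # computed because lambda_vat > 0

ltotal = (lambda_clf * lclf) + (lambda_at * lat) + (lambda_vat * lvat)
# lat contributes 0.0 * -1.0 == 0, so the sentinel never reaches the gradient.
print(ltotal)  # tensor([0.7500])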