예제 #1
0
    def train_classifier(self, epochs):
        """
        Train the classifier used by the classification-accuracy metric of category text generation.

        Note: the classifier's train/eval data are the opposite of the generator's. The classifier
        is fit on ``self.clas_samples_list`` and model-selected by accuracy on
        ``self.train_samples_list``, because it is later used to score generated samples from a
        generator that was trained on ``self.train_samples_list``.

        Synthetic (oracle) data experiments have no test data, so they do not need a classifier.

        After training, the classifier weights from the epoch with the highest eval accuracy are
        restored. If no epoch ever exceeded an eval accuracy of 0 (including ``epochs == 0``),
        the current weights are kept unchanged.

        :param epochs: number of training epochs.
        """
        import copy

        # Prepare data for the classifier (train/eval swapped relative to the generator, see above).
        clas_data = CatClasDataIter(self.clas_samples_list)
        eval_clas_data = CatClasDataIter(self.train_samples_list)

        max_acc = 0
        best_clas = None
        for epoch in range(epochs):
            c_loss, c_acc = self.train_dis_epoch(self.clas, clas_data.loader,
                                                 self.clas_criterion,
                                                 self.clas_opt)
            _, eval_acc = self.eval_dis(self.clas, eval_clas_data.loader,
                                        self.clas_criterion)
            if eval_acc > max_acc:
                # state_dict() holds references to the live parameter tensors, so take a deep
                # copy; otherwise later epochs would overwrite the saved "best" weights.
                best_clas = copy.deepcopy(self.clas.state_dict())
                max_acc = eval_acc
            self.log.info(
                '[PRE-CLAS] epoch %d: c_loss = %.4f, c_acc = %.4f, eval_acc = %.4f, max_eval_acc = %.4f',
                epoch, c_loss, c_acc, eval_acc, max_acc)

        if best_clas is not None:
            # load_state_dict copies the snapshot values into the model's own parameters,
            # so the snapshot itself does not need to be copied again.
            self.clas.load_state_dict(best_clas)  # Reload the best classifier
        else:
            # Previously this path passed None to load_state_dict and crashed.
            self.log.info(
                '[PRE-CLAS] no epoch improved eval_acc above 0; keeping current classifier weights')
예제 #2
0
    def train_discriminator(self, d_step, d_epoch, phase='MLE'):
        """
        Train the discriminator on real samples and samples drawn from the per-class generators.

        The data are redrawn ``d_step`` times. Each time, fake samples are drawn from every
        generator in ``self.gen_list`` and paired with the real per-class samples in
        ``self.oracle_samples_list``. The discriminator is then trained on that data for
        ``d_epoch`` epochs.

        :param d_step: number of times to redraw the training samples.
        :param d_epoch: number of training epochs per drawn sample set.
        :param phase: tag included in the log line. The discriminator checkpoint at
            ``cfg.pretrained_dis_path`` is only written when ``phase == 'MLE'``,
            ``cfg.if_save`` is set and ``cfg.if_test`` is not.
        """
        per_class = cfg.samples_num // cfg.k_label
        for step in range(d_step):
            # Draw a fresh set of fake samples for this step, one chunk per class generator.
            real_samples = []
            fake_samples = []
            for i in range(cfg.k_label):
                real_samples.append(self.oracle_samples_list[i])
                fake_samples.append(self.gen_list[i].sample(
                    per_class, 8 * cfg.batch_size))

            # Fakes (pooled over all classes) go first, followed by each class's real samples.
            # NOTE(review): presumably CatClasDataIter labels by list position
            # (0 = fake, i + 1 = real class i); confirm against CatClasDataIter.
            dis_samples_list = [torch.cat(fake_samples, dim=0)] + real_samples
            dis_data = CatClasDataIter(dis_samples_list)

            # Locals instead of module globals. NaN makes "no epoch ran" (d_epoch < 1) visible
            # in the log, rather than raising NameError or reporting stale values.
            d_loss = train_acc = float('nan')
            for _ in range(d_epoch):
                d_loss, train_acc = self.train_dis_epoch(
                    self.dis, dis_data.loader, self.dis_criterion,
                    self.dis_opt)

            # Report the metrics of the last training epoch of this step.
            self.log.info(
                '[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f',
                phase, step, d_loss, train_acc)

            # Saved after every step, so the file ends up holding the final step's weights.
            if cfg.if_save and not cfg.if_test and phase == 'MLE':
                torch.save(self.dis.state_dict(), cfg.pretrained_dis_path)
예제 #3
0
    def cal_metrics_with_label(self, label_i):
        """
        Evaluate the generator conditioned on a single class label.

        Draws samples from ``self.gen`` for ``label_i``, then resets the label-dependent metrics
        (``bleu``, ``nll_gen``, ``nll_div``, ``self_bleu``, ``clas_acc``, ``ppl``) against that
        label's data. Finally, it collects ``get_score()`` from every metric in
        ``self.all_metrics``, in that list's order.

        :param label_i: label index; must be exactly of type ``int``. It is used to index
            ``self.test_data_list`` and ``self.train_data_list``.
        :return: list with one score per entry of ``self.all_metrics``.
        """
        assert type(label_i) is int, 'missing label'

        vocab = self.idx2word_dict
        with torch.no_grad():
            # Main draw for this label, used by BLEU, NLL-div, classifier accuracy and PPL.
            samples = self.gen.sample(cfg.samples_num, 8 * cfg.batch_size, label_i=label_i)
            samples_iter = GenDataIter(samples)
            sample_tokens = tensor_to_tokens(samples, vocab)
            # A separate, smaller draw (sample(200, 200)) that Self-BLEU scores against the main draw.
            self_bleu_draw = self.gen.sample(200, 200, label_i=label_i)
            self_bleu_tokens = tensor_to_tokens(self_bleu_draw, vocab)
            clas_iter = CatClasDataIter([samples], label_i)

            # Point each label-dependent metric at this label's data.
            reference_tokens = self.test_data_list[label_i].tokens
            self.bleu.reset(test_text=sample_tokens, real_text=reference_tokens)
            train_loader = self.train_data_list[label_i].loader
            self.nll_gen.reset(self.gen, train_loader, label_i)
            self.nll_div.reset(self.gen, samples_iter.loader, label_i)
            self.self_bleu.reset(test_text=self_bleu_tokens, real_text=sample_tokens)
            self.clas_acc.reset(self.clas, clas_iter.loader)
            self.ppl.reset(sample_tokens)

        scores = []
        for metric in self.all_metrics:
            scores.append(metric.get_score())
        return scores