Exemplo n.º 1
0
    def train_epoch(self):
        """
        Train the multi-agent model for one epoch.

        Each batch from ``self.train_iter`` is converted with
        ``create_turn_batch`` / ``create_kb_batch`` and passed to
        ``self.model.iterate(..., is_training=True, method=self.method)``.
        That call must return exactly four metric lists (S, G, DB, DE), and
        each list is accumulated in its own MetricsManager.

        Every ``self.log_steps`` batches, all four managers are reported through
        ``self.report_log_steps``. Every ``self.valid_steps`` batches,
        ``self.model.model_S`` is validated through ``self.report_valid_steps``.
        ``self.save()`` is called once the epoch has finished.
        """
        self.epoch += 1
        # Accumulators in the same order as iterate()'s return values: S, G, DB, DE.
        managers = tuple(MetricsManager() for _ in range(4))
        total_batches = self.train_iter.n_batch
        self.train_iter.prepare_epoch()
        self.logger.info(self.train_start_message)

        for step in range(total_batches):
            self.model.train()
            tic = time.time()

            batch = self.train_iter.get_batch(step)
            turns = create_turn_batch(batch['inputs'])
            kbs = create_kb_batch(batch['kbs'])
            assert len(turns) == batch['max_turn']

            # Explicit four-way unpacking: iterate() must return exactly four lists.
            metrics_s, metrics_g, metrics_db, metrics_de = self.model.iterate(
                turns,
                kbs,
                optimizer=self.optimizer,
                grad_clip=self.grad_clip,
                is_training=True,
                method=self.method,
            )

            elapsed = time.time() - tic
            step_metrics = (metrics_s, metrics_g, metrics_db, metrics_de)
            for manager, metrics in zip(managers, step_metrics):
                manager.update(metrics)
            self.batch_num += 1

            done = step + 1
            if done % self.log_steps == 0:
                for manager in managers:
                    self.report_log_steps(train_mm=manager,
                                          elapsed=elapsed,
                                          batch_idx=step,
                                          num_batches=total_batches)

            # TODO only evaluate model_S here
            if done % self.valid_steps == 0:
                self.report_valid_steps(model=self.model.model_S,
                                        batch_idx=step,
                                        num_batches=total_batches)

        self.save()
        self.logger.info('')
Exemplo n.º 2
0
    def evaluate(model, data_iter, use_rl=False, method=None):
        """
        Evaluate ``model`` over every batch of ``data_iter`` with gradients disabled.

        Note: for the multi-agent model this evaluates each sub-model's metrics
        separately instead of producing one score for the whole model.

        Args:
            model: exposes ``eval()``, ``name`` and ``iterate(...)``.
            data_iter: exposes ``n_batch`` and ``get_batch(idx)``. Each batch is a
                dict with 'inputs', 'kbs' and 'max_turn' keys.
            use_rl: forwarded to ``model.iterate`` for non-multi-agent models.
            method: forwarded to ``model.iterate`` for the multi-agent model.
                When it is '1-3', ``iterate`` returns four metric lists. For any
                other value it returns six values, and the last two are ignored.

        Returns:
            If ``model.name == 'muti_agent'``, a tuple ``(mm_M, mm_S, mm_TB, mm_TE)``
            of MetricsManagers. Otherwise a single MetricsManager. In both cases
            the metrics are accumulated over all batches.
        """
        model.eval()
        # 'muti_agent' (sic) must match the model's ``name`` attribute exactly.
        is_multi_agent = model.name == 'muti_agent'
        mm = MetricsManager()
        mm_M, mm_S, mm_TB, mm_TE = MetricsManager(), MetricsManager(), \
            MetricsManager(), MetricsManager()
        num_batches = data_iter.n_batch
        with torch.no_grad():
            for batch_idx in range(num_batches):
                local_data = data_iter.get_batch(batch_idx)
                turn_inputs = create_turn_batch(local_data['inputs'])
                kb_inputs = create_kb_batch(local_data['kbs'])
                assert len(turn_inputs) == local_data['max_turn']

                if is_multi_agent:
                    outputs = model.iterate(turn_inputs, kb_inputs,
                                            is_training=False,
                                            method=method)
                    if method == '1-3':
                        metrics_list_M, metrics_list_S, metrics_list_TB, metrics_list_TE = outputs
                    else:
                        metrics_list_M, metrics_list_S, metrics_list_TB, metrics_list_TE, _, _ = outputs
                    mm_M.update(metrics_list_M)
                    mm_S.update(metrics_list_S)
                    mm_TB.update(metrics_list_TB)
                    # The TE metrics may be empty for some methods; only accumulate real ones.
                    if metrics_list_TE:
                        mm_TE.update(metrics_list_TE)
                else:
                    metrics_list = model.iterate(turn_inputs,
                                                 kb_inputs,
                                                 use_rl=use_rl,
                                                 is_training=False)
                    mm.update(metrics_list)

        # Return only after the loop so that every batch contributes to the
        # metrics, and so an empty iterator still yields (empty) managers.
        if is_multi_agent:
            return mm_M, mm_S, mm_TB, mm_TE
        return mm
Exemplo n.º 3
0
    def evaluate(model, data_iter, use_rl=False, entity_dir=None):
        """
        Evaluate ``model`` on every batch of ``data_iter`` with gradients disabled.

        Args:
            model: exposes ``eval()`` and ``iterate(...)``.
            data_iter: exposes ``n_batch`` and ``get_batch(idx)``. Each batch is a
                dict with 'inputs', 'kbs' and 'max_turn' keys.
            use_rl: forwarded to ``model.iterate``.
            entity_dir: forwarded to ``model.iterate``.

        Returns:
            A MetricsManager holding the metrics accumulated over all batches.
        """
        model.eval()
        accumulated = MetricsManager()
        total_batches = data_iter.n_batch
        with torch.no_grad():
            for idx in range(total_batches):
                batch = data_iter.get_batch(idx)
                turns = create_turn_batch(batch['inputs'])
                kbs = create_kb_batch(batch['kbs'])
                assert len(turns) == batch['max_turn']

                batch_metrics = model.iterate(turns, kbs,
                                              use_rl=use_rl,
                                              entity_dir=entity_dir,
                                              is_training=False)
                accumulated.update(batch_metrics)
        return accumulated
Exemplo n.º 4
0
    def generate(self, data_iter, save_file=None, verbos=False):
        """
        Decode every batch of ``data_iter``, summarise the outputs and
        optionally write them to disk.

        Args:
            data_iter: exposes ``n_batch`` and ``get_batch(idx)``. Each batch is a
                dict with 'inputs', 'kbs' and 'max_turn' keys.
            save_file: if not None, path of a file that receives one JSON object
                per line. Each object has the keys turn_label, result, target,
                task, gold_entity and kb.
            verbos: if True, print the average-length / inter-distinct summary
                for targets and results.
        """
        batch_results = []
        for idx in range(data_iter.n_batch):
            batch = data_iter.get_batch(idx)
            turns = create_turn_batch(batch['inputs'])
            kbs = create_kb_batch(batch['kbs'])
            assert len(turns) == batch['max_turn']
            batch_results.append(self.forward(turn_inputs=turns,
                                              kb_inputs=kbs,
                                              enc_hidden=None))

        # Flatten the per-batch outputs into a single list of per-turn results.
        turn_results = [turn for result_batch in batch_results for turn in result_batch]
        hyps = [turn.preds[0].split(" ") for turn in turn_results]
        refs = [turn.tgt.split(" ") for turn in turn_results]
        turn_labels = [turn.turn_label for turn in turn_results]
        tasks = [turn.task for turn in turn_results]
        gold_entity = [turn.gold_entity for turn in turn_results]
        kb_word = [turn.kb_word for turn in turn_results]

        outputs = [
            {
                "turn_label": str(label),
                "result": " ".join(hyp),
                "target": " ".join(ref),
                "task": task,
                "gold_entity": entity,
                "kb": kb,
            }
            for label, hyp, ref, task, entity, kb in zip(turn_labels, hyps, refs,
                                                         tasks, gold_entity, kb_word)
        ]

        # Hypothesis statistics; only the inter-distinct scores are reported.
        avg_len = np.average([len(s) for s in hyps])
        _, _, inter_dist1, inter_dist2 = distinct(hyps)
        report_message = "Result:  Avg_Len={:.3f}   ".format(avg_len) + \
                         "Inter_Dist={:.4f}/{:.4f}".format(inter_dist1, inter_dist2)

        # Reference statistics, in the same format.
        avg_len = np.average([len(s) for s in refs])
        _, _, inter_dist1, inter_dist2 = distinct(refs)
        target_message = "Target:  Avg_Len={:.3f}   ".format(avg_len) + \
                         "Inter_Dist={:.4f}/{:.4f}".format(inter_dist1, inter_dist2)

        if verbos:
            print(target_message + "\n" + report_message)

        if save_file is not None:
            with open(save_file, 'w', encoding='utf-8') as fw:
                fw.writelines(json.dumps(sample) + '\n' for sample in outputs)
            print("Saved generation results to '{}.'".format(save_file))
Exemplo n.º 5
0
    def train_epoch(self):
        """
        Train the model for one epoch, with periodic logging and validation.

        Each batch is built from its 'inputs', 'kbs', 'situation' and
        'user_profile' fields and passed to ``self.model.iterate`` with
        ``is_training=True``. The returned metrics are accumulated in a
        MetricsManager.

        Every ``self.log_steps`` batches, the running metrics and the step time
        are logged. Every ``self.valid_steps`` batches, the following happens:
        - ``self.evaluate`` runs on ``self.valid_iter``;
        - ``self.best_valid_metric`` is updated if the validation metric improved
          (lower is better when ``self.is_decreased_valid_metric`` is set);
        - a checkpoint is saved;
        - ``self.lr_scheduler`` (if any) is stepped with the metric.

        ``self.save()`` is called once the epoch has finished.
        """
        self.epoch += 1
        running = MetricsManager()
        total_batches = self.train_iter.n_batch
        self.train_iter.prepare_epoch()
        self.logger.info(self.train_start_message)

        for step in range(total_batches):
            self.model.train()
            tic = time.time()

            batch = self.train_iter.get_batch(step)

            turns = create_turn_batch(batch['inputs'])
            kbs = create_kb_batch(batch['kbs'])
            situations = create_situation_batch(batch['situation'])
            profiles = create_user_profile_batch(batch['user_profile'])
            assert len(turns) == batch['max_turn']

            step_metrics = self.model.iterate(turns, kbs, situations, profiles,
                                              optimizer=self.optimizer,
                                              grad_clip=self.grad_clip,
                                              use_rl=self.use_rl,
                                              is_training=True)

            elapsed = time.time() - tic
            running.update(step_metrics)
            self.batch_num += 1

            done = step + 1
            if done % self.log_steps == 0:
                parts = [
                    "[Train][{:2d}][{}/{}]".format(self.epoch, done, total_batches),
                    running.report_val(),
                    "TIME={:.2f}s".format(elapsed),
                ]
                self.logger.info("   ".join(parts))

            if done % self.valid_steps == 0:
                self.logger.info(self.valid_start_message)
                valid_mm = self.evaluate(self.model, self.valid_iter, use_rl=self.use_rl)

                header = "[Valid][{:2d}][{}/{}]".format(self.epoch, done, total_batches)
                self.logger.info("   ".join([header, valid_mm.report_cum()]))

                current = valid_mm.get(self.valid_metric_name)
                # Improvement direction depends on whether the metric is a loss-like quantity.
                is_best = (current < self.best_valid_metric) if self.is_decreased_valid_metric \
                    else (current > self.best_valid_metric)
                if is_best:
                    self.best_valid_metric = current
                self.save(is_best, is_rl=self.use_rl)
                if self.lr_scheduler is not None:
                    self.lr_scheduler.step(current)
                self.logger.info("-" * 85 + "\n")

        self.save()
        self.logger.info('')
Exemplo n.º 6
0
    def generate(self, data_iter, save_file=None, verbos=False):
        """
        Decode every batch of ``data_iter`` (situation / user-profile model),
        summarise the outputs and optionally write them to disk.

        The output depends on ``self.mode``:
        - 'test': each record holds source/result/target, and the target
          statistics line uses fixed placeholder values.
        - any other mode: records hold result/target/task/gold_entity, and the
          target statistics are computed from the references.

        Args:
            data_iter: exposes ``n_batch`` and ``get_batch(idx)``. Each batch is a
                dict with 'inputs', 'kbs', 'situation', 'user_profile' and
                'max_turn' keys.
            save_file: if not None, path of a file that receives one JSON object
                per line.
            verbos: if True, print the target and result summaries.
        """
        batch_results = []
        for idx in tqdm(range(data_iter.n_batch)):
            batch = data_iter.get_batch(idx)
            turns = create_turn_batch(batch['inputs'])
            kbs = create_kb_batch(batch['kbs'])
            situations = create_situation_batch(batch['situation'])
            profiles = create_user_profile_batch(batch['user_profile'])
            assert len(turns) == batch['max_turn']
            batch_results.append(self.forward(turn_inputs=turns,
                                              kb_inputs=kbs,
                                              situation_inputs=situations,
                                              user_profile_inputs=profiles,
                                              enc_hidden=None))

        # Flatten the per-batch outputs into a single list of per-turn results.
        turn_results = [turn for result_batch in batch_results for turn in result_batch]
        srcs = [turn.src.split(" ") for turn in turn_results]
        hyps = [turn.preds[0].split(" ") for turn in turn_results]
        refs = [turn.tgt.split(" ") for turn in turn_results]
        tasks = [turn.task for turn in turn_results]
        gold_entity = [turn.gold_entity for turn in turn_results]

        is_test = self.mode == 'test'
        outputs = []
        for src, hyp, ref, task, entity in zip(srcs, hyps, refs, tasks, gold_entity):
            if is_test:
                record = {
                    "source": " ".join(src),
                    "result": " ".join(hyp),
                    "target": " ".join(ref),
                }
            else:
                record = {
                    "result": " ".join(hyp),
                    "target": " ".join(ref),
                    "task": task,
                    "gold_entity": entity,
                }
            outputs.append(record)

        # Hypothesis statistics; only the inter-distinct scores are reported.
        avg_len = np.average([len(s) for s in hyps])
        _, _, inter_dist1, inter_dist2 = distinct(hyps)
        report_message = "Result:  Avg_Len={:.3f}   ".format(avg_len) + \
                         "Inter_Dist={:.4f}/{:.4f}".format(inter_dist1, inter_dist2)

        if is_test:
            # Reference statistics are not computed in test mode; fixed placeholder values are printed.
            target_message = "Target:  Avg_Len={:.3f}   ".format(1) + \
                             "Inter_Dist={:.4f}/{:.4f}".format(0.0, 0.0)
        else:
            avg_len = np.average([len(s) for s in refs])
            _, _, inter_dist1, inter_dist2 = distinct(refs)
            target_message = "Target:  Avg_Len={:.3f}   ".format(avg_len) + \
                             "Inter_Dist={:.4f}/{:.4f}".format(inter_dist1, inter_dist2)

        if verbos:
            print(target_message + "\n" + report_message)

        if save_file is not None:
            with open(save_file, 'w', encoding='utf-8') as fw:
                fw.writelines(json.dumps(sample) + '\n' for sample in outputs)
            print("Saved generation results to '{}.'".format(save_file))