import json
import time

import numpy as np
import torch
from tqdm import tqdm

# MetricsManager, create_turn_batch, create_kb_batch, create_situation_batch,
# create_user_profile_batch and distinct are project-level helpers assumed to
# be importable from the surrounding package.


def train_epoch(self):
    """
    train_epoch
    """
    self.epoch += 1
    # keep a separate metric stream for each sub-model (S, G, DB, DE)
    train_mm_S, train_mm_G, train_mm_DB, train_mm_DE = \
        MetricsManager(), MetricsManager(), MetricsManager(), MetricsManager()
    num_batches = self.train_iter.n_batch
    self.train_iter.prepare_epoch()
    self.logger.info(self.train_start_message)

    for batch_idx in range(num_batches):
        self.model.train()
        start_time = time.time()
        local_data = self.train_iter.get_batch(batch_idx)
        turn_inputs = create_turn_batch(local_data['inputs'])
        kb_inputs = create_kb_batch(local_data['kbs'])
        assert len(turn_inputs) == local_data['max_turn']

        metrics_list_S, metrics_list_G, metrics_list_DB, metrics_list_DE = \
            self.model.iterate(turn_inputs, kb_inputs,
                               optimizer=self.optimizer,
                               grad_clip=self.grad_clip,
                               is_training=True,
                               method=self.method)
        elapsed = time.time() - start_time

        train_mm_S.update(metrics_list_S)
        train_mm_G.update(metrics_list_G)
        train_mm_DB.update(metrics_list_DB)
        train_mm_DE.update(metrics_list_DE)
        self.batch_num += 1

        if (batch_idx + 1) % self.log_steps == 0:
            # report each sub-model's metric stream at the same step
            for train_mm in (train_mm_S, train_mm_G, train_mm_DB, train_mm_DE):
                self.report_log_steps(train_mm=train_mm, elapsed=elapsed,
                                      batch_idx=batch_idx,
                                      num_batches=num_batches)

        # TODO: only evaluate model_S here
        if (batch_idx + 1) % self.valid_steps == 0:
            self.report_valid_steps(model=self.model.model_S,
                                    batch_idx=batch_idx,
                                    num_batches=num_batches)

    self.save()
    self.logger.info('')
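
# --- Sketch: the MetricsManager interface assumed throughout this file ---
# This is NOT the repo's implementation; it is a minimal stand-in that matches
# the calls made above (update / report_val / report_cum / get), assuming each
# element of a metrics list is a dict of scalar metric values (the real class
# may store richer Metrics objects).
class MetricsManagerSketch(object):

    def __init__(self):
        self.latest = {}   # most recent value per metric, for report_val()
        self.sums = {}     # running totals per metric, for report_cum()/get()
        self.counts = {}

    def update(self, metrics_list):
        for metrics in metrics_list:
            for name, value in metrics.items():
                self.latest[name] = value
                self.sums[name] = self.sums.get(name, 0.0) + value
                self.counts[name] = self.counts.get(name, 0) + 1

    def report_val(self):
        return " ".join("{}={:.3f}".format(k, v)
                        for k, v in self.latest.items())

    def report_cum(self):
        return " ".join("{}={:.3f}".format(k, self.get(k)) for k in self.sums)

    def get(self, name):
        return self.sums[name] / max(self.counts[name], 1)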
def evaluate(model, data_iter, use_rl=False, method=None):
    """
    evaluate
    note: this function evaluates a single sub-model instead of the whole
    (multi-agent) model
    """
    model.eval()
    mm = MetricsManager()
    mm_M, mm_S, mm_TB, mm_TE = \
        MetricsManager(), MetricsManager(), MetricsManager(), MetricsManager()
    num_batches = data_iter.n_batch
    with torch.no_grad():
        for batch_idx in range(num_batches):
            local_data = data_iter.get_batch(batch_idx)
            turn_inputs = create_turn_batch(local_data['inputs'])
            kb_inputs = create_kb_batch(local_data['kbs'])
            assert len(turn_inputs) == local_data['max_turn']

            if model.name == 'muti_agent':
                if method == '1-3':
                    metrics_list_M, metrics_list_S, metrics_list_TB, metrics_list_TE = \
                        model.iterate(turn_inputs, kb_inputs,
                                      is_training=False, method=method)
                else:
                    metrics_list_M, metrics_list_S, metrics_list_TB, metrics_list_TE, _, _ = \
                        model.iterate(turn_inputs, kb_inputs,
                                      is_training=False, method=method)
                mm_M.update(metrics_list_M)
                mm_S.update(metrics_list_S)
                mm_TB.update(metrics_list_TB)
                if metrics_list_TE:
                    mm_TE.update(metrics_list_TE)
            else:
                metrics_list = model.iterate(turn_inputs, kb_inputs,
                                             use_rl=use_rl, is_training=False)
                mm.update(metrics_list)

    # return cumulative metrics only after the full pass over the data
    if model.name == 'muti_agent':
        return mm_M, mm_S, mm_TB, mm_TE
    return mm
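
# Usage sketch (`muti_agent_model` and `valid_iter` are illustrative names,
# not from this repo); the four-manager return applies to the multi-agent
# branch only:
#
#     mm_M, mm_S, mm_TB, mm_TE = evaluate(muti_agent_model, valid_iter,
#                                         method='1-3')
#     print(mm_S.report_cum())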
def evaluate(model, data_iter, use_rl=False, entity_dir=None):
    """
    evaluate
    """
    model.eval()
    mm = MetricsManager()
    num_batches = data_iter.n_batch
    with torch.no_grad():
        for batch_idx in range(num_batches):
            local_data = data_iter.get_batch(batch_idx)
            turn_inputs = create_turn_batch(local_data['inputs'])
            kb_inputs = create_kb_batch(local_data['kbs'])
            assert len(turn_inputs) == local_data['max_turn']
            metrics_list = model.iterate(turn_inputs, kb_inputs,
                                         use_rl=use_rl,
                                         entity_dir=entity_dir,
                                         is_training=False)
            mm.update(metrics_list)
    return mm
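
# Usage sketch (the entity_dir path below is illustrative):
#
#     test_mm = evaluate(model, test_iter, entity_dir='data/entities')
#     print(test_mm.report_cum())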
def generate(self, data_iter, save_file=None, verbos=False):
    """
    generate
    """
    results = []
    num_batches = data_iter.n_batch
    for batch_idx in range(num_batches):
        local_data = data_iter.get_batch(batch_idx)
        turn_inputs = create_turn_batch(local_data['inputs'])
        kb_inputs = create_kb_batch(local_data['kbs'])
        assert len(turn_inputs) == local_data['max_turn']
        result_batch = self.forward(turn_inputs=turn_inputs,
                                    kb_inputs=kb_inputs,
                                    enc_hidden=None)
        results.append(result_batch)

    # post-process the results
    hyps = []
    refs = []
    turn_labels = []
    tasks = []
    gold_entity = []
    kb_word = []
    for result_batch in results:
        for result_turn in result_batch:
            hyps.append(result_turn.preds[0].split(" "))
            refs.append(result_turn.tgt.split(" "))
            turn_labels.append(result_turn.turn_label)
            tasks.append(result_turn.task)
            gold_entity.append(result_turn.gold_entity)
            kb_word.append(result_turn.kb_word)

    outputs = []
    for tl, hyp, ref, tk, ent, kb in zip(turn_labels, hyps, refs,
                                         tasks, gold_entity, kb_word):
        sample = {
            "turn_label": str(tl),
            "result": " ".join(hyp),
            "target": " ".join(ref),
            "task": tk,
            "gold_entity": ent,
            "kb": kb
        }
        outputs.append(sample)

    avg_len = np.average([len(s) for s in hyps])
    intra_dist1, intra_dist2, inter_dist1, inter_dist2 = distinct(hyps)
    report_message = "Result: Avg_Len={:.3f} ".format(avg_len) + \
                     "Inter_Dist={:.4f}/{:.4f}".format(inter_dist1, inter_dist2)

    avg_len = np.average([len(s) for s in refs])
    intra_dist1, intra_dist2, inter_dist1, inter_dist2 = distinct(refs)
    target_message = "Target: Avg_Len={:.3f} ".format(avg_len) + \
                     "Inter_Dist={:.4f}/{:.4f}".format(inter_dist1, inter_dist2)

    message = target_message + "\n" + report_message
    if verbos:
        print(message)

    if save_file is not None:
        with open(save_file, 'w', encoding='utf-8') as fw:
            for sample in outputs:
                fw.write(json.dumps(sample))
                fw.write('\n')
        print("Saved generation results to '{}'.".format(save_file))
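
# --- Sketch: the distinct() diversity metric assumed above ---
# distinct-1/2 (Li et al., 2016): the ratio of unique uni-/bi-grams, computed
# per sequence (intra) and over the whole corpus (inter). This mirrors the
# four-value return signature used above but is not the repo's implementation.
from collections import Counter

def distinct_sketch(seqs):
    """seqs: list of token lists; returns (intra1, intra2, inter1, inter2)."""
    intra_dist1, intra_dist2 = [], []
    unigrams_all, bigrams_all = Counter(), Counter()
    for seq in seqs:
        unigrams = Counter(seq)
        bigrams = Counter(zip(seq, seq[1:]))
        # small constants guard against empty sequences
        intra_dist1.append((len(unigrams) + 1e-12) / (sum(unigrams.values()) + 1e-5))
        intra_dist2.append((len(bigrams) + 1e-12) / (sum(bigrams.values()) + 1e-5))
        unigrams_all.update(unigrams)
        bigrams_all.update(bigrams)
    inter_dist1 = (len(unigrams_all) + 1e-12) / (sum(unigrams_all.values()) + 1e-5)
    inter_dist2 = (len(bigrams_all) + 1e-12) / (sum(bigrams_all.values()) + 1e-5)
    return (np.average(intra_dist1), np.average(intra_dist2),
            inter_dist1, inter_dist2)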
def train_epoch(self):
    """
    train_epoch
    """
    self.epoch += 1
    train_mm = MetricsManager()
    num_batches = self.train_iter.n_batch
    self.train_iter.prepare_epoch()
    self.logger.info(self.train_start_message)

    for batch_idx in range(num_batches):
        self.model.train()
        start_time = time.time()
        local_data = self.train_iter.get_batch(batch_idx)
        turn_inputs = create_turn_batch(local_data['inputs'])
        kb_inputs = create_kb_batch(local_data['kbs'])
        situation_inputs = create_situation_batch(local_data['situation'])
        user_profile_inputs = create_user_profile_batch(
            local_data['user_profile'])
        assert len(turn_inputs) == local_data['max_turn']

        metrics_list = self.model.iterate(turn_inputs, kb_inputs,
                                          situation_inputs,
                                          user_profile_inputs,
                                          optimizer=self.optimizer,
                                          grad_clip=self.grad_clip,
                                          use_rl=self.use_rl,
                                          is_training=True)
        elapsed = time.time() - start_time
        train_mm.update(metrics_list)
        self.batch_num += 1

        if (batch_idx + 1) % self.log_steps == 0:
            message_prefix = "[Train][{:2d}][{}/{}]".format(
                self.epoch, batch_idx + 1, num_batches)
            metrics_message = train_mm.report_val()
            message_postfix = "TIME={:.2f}s".format(elapsed)
            self.logger.info(" ".join(
                [message_prefix, metrics_message, message_postfix]))

        if (batch_idx + 1) % self.valid_steps == 0:
            self.logger.info(self.valid_start_message)
            valid_mm = self.evaluate(self.model, self.valid_iter,
                                     use_rl=self.use_rl)
            message_prefix = "[Valid][{:2d}][{}/{}]".format(
                self.epoch, batch_idx + 1, num_batches)
            metrics_message = valid_mm.report_cum()
            self.logger.info(" ".join([message_prefix, metrics_message]))

            # track the best validation metric and checkpoint on improvement
            cur_valid_metric = valid_mm.get(self.valid_metric_name)
            if self.is_decreased_valid_metric:
                is_best = cur_valid_metric < self.best_valid_metric
            else:
                is_best = cur_valid_metric > self.best_valid_metric
            if is_best:
                self.best_valid_metric = cur_valid_metric
            self.save(is_best, is_rl=self.use_rl)
            if self.lr_scheduler is not None:
                self.lr_scheduler.step(cur_valid_metric)
            self.logger.info("-" * 85 + "\n")

    self.save()
    self.logger.info('')
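
# --- Usage sketch: a scheduler compatible with lr_scheduler.step(metric) ---
# torch.optim.lr_scheduler.ReduceLROnPlateau accepts the monitored validation
# metric in step(); its mode should mirror is_decreased_valid_metric. The
# factor/patience values below are illustrative, not taken from this repo.
def make_plateau_scheduler(optimizer, is_decreased_valid_metric=True):
    return torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min' if is_decreased_valid_metric else 'max',
        factor=0.5,
        patience=3)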
def generate(self, data_iter, save_file=None, verbos=False):
    """
    generate
    """
    results = []
    num_batches = data_iter.n_batch
    for batch_idx in tqdm(range(num_batches)):
        local_data = data_iter.get_batch(batch_idx)
        turn_inputs = create_turn_batch(local_data['inputs'])
        kb_inputs = create_kb_batch(local_data['kbs'])
        situation_inputs = create_situation_batch(local_data['situation'])
        user_profile_inputs = create_user_profile_batch(
            local_data['user_profile'])
        assert len(turn_inputs) == local_data['max_turn']
        result_batch = self.forward(turn_inputs=turn_inputs,
                                    kb_inputs=kb_inputs,
                                    situation_inputs=situation_inputs,
                                    user_profile_inputs=user_profile_inputs,
                                    enc_hidden=None)
        results.append(result_batch)

    # post-process the results
    srcs = []
    hyps = []
    refs = []
    # turn_labels = []
    tasks = []
    gold_entity = []
    # kb_word = []
    for result_batch in results:
        for result_turn in result_batch:
            srcs.append(result_turn.src.split(" "))
            hyps.append(result_turn.preds[0].split(" "))
            refs.append(result_turn.tgt.split(" "))
            # turn_labels.append(result_turn.turn_label)
            tasks.append(result_turn.task)
            gold_entity.append(result_turn.gold_entity)
            # kb_word.append(result_turn.kb_word)

    outputs = []
    for src, hyp, ref, tk, ent in zip(srcs, hyps, refs, tasks, gold_entity):
        if self.mode == 'test':
            sample = {
                "source": " ".join(src),
                "result": " ".join(hyp),
                "target": " ".join(ref)
            }
        else:
            sample = {
                "result": " ".join(hyp),
                "target": " ".join(ref),
                "task": tk,
                "gold_entity": ent
            }
        outputs.append(sample)

    avg_len = np.average([len(s) for s in hyps])
    intra_dist1, intra_dist2, inter_dist1, inter_dist2 = distinct(hyps)
    report_message = "Result: Avg_Len={:.3f} ".format(avg_len) + \
                     "Inter_Dist={:.4f}/{:.4f}".format(inter_dist1, inter_dist2)

    if self.mode != 'test':
        avg_len = np.average([len(s) for s in refs])
        intra_dist1, intra_dist2, inter_dist1, inter_dist2 = distinct(refs)
        target_message = "Target: Avg_Len={:.3f} ".format(avg_len) + \
                         "Inter_Dist={:.4f}/{:.4f}".format(inter_dist1, inter_dist2)
    else:
        # no references in test mode, so emit placeholder statistics
        target_message = "Target: Avg_Len={:.3f} ".format(1) + \
                         "Inter_Dist={:.4f}/{:.4f}".format(0.0, 0.0)

    message = target_message + "\n" + report_message
    if verbos:
        print(message)

    if save_file is not None:
        with open(save_file, 'w', encoding='utf-8') as fw:
            for sample in outputs:
                fw.write(json.dumps(sample))
                fw.write('\n')
        print("Saved generation results to '{}'.".format(save_file))
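
# --- Usage sketch: read back the JSON-lines file written by generate() ---
# The helper name and the example path are illustrative, not part of the repo;
# generate() writes one json.dumps(sample) per line, so we parse line by line.
def load_generation_results(save_file):
    samples = []
    with open(save_file, encoding='utf-8') as fr:
        for line in fr:
            samples.append(json.loads(line))
    return samples

# e.g. samples = load_generation_results('output/test_results.json')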