def dev(cfg, dataset, model):
    """Run one validation pass over the 'dev' split and return a selection score.

    Each batch from ``dataset.get_batch('dev', ...)`` goes through ``step`` with
    gradients disabled. The averaged entity and relation losses are logged,
    predictions are written to ``<cfg.save_dir>/dev.output`` via
    ``print_predictions``, and that file is scored with ``eval_file``.

    Args:
        cfg: run configuration; this function reads ``embedding_model``,
            ``test_batch_size``, ``device``, ``save_dir`` and ``entity_model``.
        dataset: object providing ``get_batch`` and ``vocab``.
        model: the model being evaluated; it is put into eval mode.

    Returns:
        The second plus the fourth value returned by ``eval_file``
        (entity score + exact relation score).

    NOTE(review): if the split yields no batches, ``np.mean`` of an empty
    list returns nan (with a RuntimeWarning) for the logged losses.
    """
    logger.info("Validate starting...")
    model.zero_grad()

    # Only the word_char embedder requests sorting on the "tokens" field;
    # how sort_key is applied is up to dataset.get_batch.
    sort_key = "tokens" if cfg.embedding_model == 'word_char' else None

    predictions = []
    ent_losses = []
    rel_losses = []
    batches = dataset.get_batch('dev', cfg.test_batch_size, sort_key)
    for _, batch in batches:
        # Set per batch; hoisting this is only safe if ``step`` never changes
        # the model's mode (not visible here).
        model.eval()
        with torch.no_grad():
            outputs, ent_loss, rel_loss = step(cfg, model, batch, cfg.device)
        predictions.extend(outputs)
        ent_losses.append(ent_loss.item())
        rel_losses.append(rel_loss.item())

    mean_ent_loss = np.mean(ent_losses)
    mean_rel_loss = np.mean(rel_losses)
    mean_loss = mean_ent_loss + mean_rel_loss
    logger.info("Validate Avgloss: {} (Ent_loss: {} Rel_loss: {})".format(
        mean_loss, mean_ent_loss, mean_rel_loss))

    label_field = ('entity_labels' if cfg.entity_model == 'joint'
                   else 'entity_span_labels')
    dev_output_file = os.path.join(cfg.save_dir, "dev.output")
    print_predictions(predictions, dev_output_file, dataset.vocab, label_field)

    # eval_file returns four scores; only the entity and exact-relation
    # scores contribute to the returned value.
    _, ent_score, _, exact_rel_score = eval_file(dev_output_file)
    return ent_score + exact_rel_score
def test(cfg, dataset, model):
    """Predict on the 'test' split, write ``<cfg.save_dir>/test.output`` and score it.

    Each batch from ``dataset.get_batch('test', ...)`` goes through ``step``
    with gradients disabled. Only the predictions are kept (the returned
    losses are ignored). They are written with ``print_predictions`` and the
    resulting file is passed to ``eval_file``.

    Args:
        cfg: run configuration; this function reads ``embedding_model``,
            ``test_batch_size``, ``device``, ``save_dir`` and ``entity_model``.
        dataset: object providing ``get_batch`` and ``vocab``.
        model: the model being evaluated; it is put into eval mode.

    Returns:
        None. The value of ``eval_file`` is discarded.

    NOTE(review): this module defines ``test`` again (taking ``args``) after
    this function, so at import time that later definition replaces this one.
    """
    logger.info("Testing starting...")
    model.zero_grad()

    # Only the word_char embedder requests sorting on the "tokens" field.
    sort_key = "tokens" if cfg.embedding_model == 'word_char' else None

    predictions = []
    for _, batch in dataset.get_batch('test', cfg.test_batch_size, sort_key):
        # Set per batch; hoisting this is only safe if ``step`` never changes
        # the model's mode (not visible here).
        model.eval()
        with torch.no_grad():
            outputs, _, _ = step(cfg, model, batch, cfg.device)
        predictions.extend(outputs)

    label_field = ('entity_labels' if cfg.entity_model == 'joint'
                   else 'entity_span_labels')
    test_output_file = os.path.join(cfg.save_dir, "test.output")
    print_predictions(predictions, test_output_file, dataset.vocab, label_field)
    eval_file(test_output_file)
def test(args, dataset, model):
    """Predict on the 'test' split, write ``<args.save_dir>/test.output`` and score it.

    Each batch from ``dataset.get_batch('test', ...)`` goes through ``step``
    with gradients disabled, and only the predictions are kept. They are
    written with ``print_predictions``. The file is then scored by
    ``eval_file`` with an explicit list of metric names.

    Args:
        args: run configuration; this function reads ``embedding_model``,
            ``lstm_layers`` (only when ``embedding_model`` is not
            ``'word_char'``), ``test_batch_size``, ``device``, ``save_dir``
            and ``entity_model``.
        dataset: object providing ``get_batch`` and ``vocab``.
        model: the model being evaluated; it is put into eval mode.

    Returns:
        None. The value of ``eval_file`` is discarded.
    """
    logger.info("Testing starting...")
    model.zero_grad()

    # Request sorting on "tokens" for the word_char embedder or whenever LSTM
    # layers are configured; the `or` short-circuits, so lstm_layers is read
    # only if the first test fails.
    wants_sorting = args.embedding_model == 'word_char' or args.lstm_layers > 0
    sort_key = "tokens" if wants_sorting else None

    predictions = []
    for _, batch in dataset.get_batch('test', args.test_batch_size, sort_key):
        # Set per batch; hoisting this is only safe if ``step`` never changes
        # the model's mode (not visible here).
        model.eval()
        with torch.no_grad():
            outputs, _, _ = step(args, model, batch, args.device)
        predictions.extend(outputs)

    label_field = ('entity_labels' if args.entity_model == 'joint'
                   else 'entity_span_labels')
    test_output_file = os.path.join(args.save_dir, "test.output")
    print_predictions(predictions, test_output_file, dataset.vocab, label_field)

    eval_metrics = ['token', 'span', 'ent', 'rel', 'exact-rel']
    eval_file(test_output_file, eval_metrics)