import os
import logging

import numpy as np
import torch

# `step`, `print_predictions`, and `eval_file` are assumed to be project
# helpers defined elsewhere; `logger` is assumed to be a standard
# module-level logger.
logger = logging.getLogger(__name__)


def dev(cfg, dataset, model):
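    """Evaluate on the dev split and return ent_score + exact_rel_score."""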
    logger.info("Validate starting...")
    model.zero_grad()

    all_outputs = []
    all_ent_loss = []
    all_rel_loss = []

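    # Key passed to dataset.get_batch for sorting batches; only set for the
    # word+char embedding model.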
    if cfg.embedding_model == 'word_char':
        sort_key = "tokens"
    else:
        sort_key = None

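    # Run the model over the dev batches with gradients disabled.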
    for _, batch in dataset.get_batch('dev', cfg.test_batch_size, sort_key):
        model.eval()
        with torch.no_grad():
            batch_outputs, ent_loss, rel_loss = step(
                cfg, model, batch, cfg.device)
        all_outputs.extend(batch_outputs)
        all_ent_loss.append(ent_loss.item())
        all_rel_loss.append(rel_loss.item())
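    # Average entity / relation losses over all dev batches.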
    mean_ent_loss = np.mean(all_ent_loss)
    mean_rel_loss = np.mean(all_rel_loss)
    mean_loss = mean_ent_loss + mean_rel_loss

    logger.info("Validate Avgloss: {} (Ent_loss: {} Rel_loss: {})".format(
        mean_loss, mean_ent_loss, mean_rel_loss))

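    # Write dev predictions to disk and score the output file.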
    dev_output_file = os.path.join(cfg.save_dir, "dev.output")
    label_key = ('entity_labels' if cfg.entity_model == 'joint'
                 else 'entity_span_labels')
    print_predictions(all_outputs, dev_output_file, dataset.vocab, label_key)
    token_score, ent_score, rel_score, exact_rel_score = eval_file(
        dev_output_file)
    return ent_score + exact_rel_score


def test(cfg, dataset, model):
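    """Evaluate on the test split and write predictions to test.output."""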
    logger.info("Testing starting...")
    model.zero_grad()

    all_outputs = []

    if cfg.embedding_model == 'word_char':
        sort_key = "tokens"
    else:
        sort_key = None

    for _, batch in dataset.get_batch('test', cfg.test_batch_size, sort_key):
        model.eval()
        with torch.no_grad():
            batch_outputs, ent_loss, rel_loss = step(
                cfg, model, batch, cfg.device)
        all_outputs.extend(batch_outputs)
    test_output_file = os.path.join(cfg.save_dir, "test.output")
    label_key = ('entity_labels' if cfg.entity_model == 'joint'
                 else 'entity_span_labels')
    print_predictions(all_outputs, test_output_file, dataset.vocab, label_key)
    eval_file(test_output_file)
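
# --- Usage sketch (not part of the original project) ----------------------
# A minimal illustration of how dev()/test() above might be driven from a
# training loop. `cfg.epochs` and `train_one_epoch` are hypothetical names
# assumed for this sketch only.
def run_experiment(cfg, dataset, model):
    best_score = 0.0
    for _ in range(cfg.epochs):
        train_one_epoch(cfg, dataset, model)  # hypothetical training step
        score = dev(cfg, dataset, model)      # ent_score + exact_rel_score
        if score > best_score:
            # Keep the checkpoint with the best dev score.
            best_score = score
            torch.save(model.state_dict(),
                       os.path.join(cfg.save_dir, "best_model.pt"))
    test(cfg, dataset, model)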

Example #3

def test(args, dataset, model):
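    """Variant of test() that reports token/span/ent/rel/exact-rel metrics."""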
    logger.info("Testing starting...")
    model.zero_grad()

    all_outputs = []

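    # sort_key is also set when an LSTM encoder is used (lstm_layers > 0).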
    if args.embedding_model == 'word_char' or args.lstm_layers > 0:
        sort_key = "tokens"
    else:
        sort_key = None

    for _, batch in dataset.get_batch('test', args.test_batch_size, sort_key):
        model.eval()
        with torch.no_grad():
            batch_outputs, ent_loss, rel_loss = step(
                args, model, batch, args.device)
        all_outputs.extend(batch_outputs)
    test_output_file = os.path.join(args.save_dir, "test.output")
    label_key = ('entity_labels' if args.entity_model == 'joint'
                 else 'entity_span_labels')
    print_predictions(all_outputs, test_output_file, dataset.vocab, label_key)
    eval_metrics = ['token', 'span', 'ent', 'rel', 'exact-rel']
    eval_file(test_output_file, eval_metrics)