Esempio n. 1
0
def main():
    """Load a saved top-down RST parser and run prediction on dev and test.

    The only command-line argument is ``--config_path`` (required); every
    other setting (alphabet path, model file, data paths, GPU flag) is read
    from the configuration loaded from that path.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--config_path', required=True)
    parsed = cli.parse_args()

    config = Config(None)
    config.load_config(parsed.config_path)

    logger = get_logger("RSTParser (Top-Down) RUN", config.use_dynamic_oracle,
                        config.model_path)

    # create_alphabet returns the eight alphabets in this fixed order; Vocab
    # takes them in a different order (etype comes third), so they are
    # unpacked by name rather than forwarded positionally.
    (word_alpha, tag_alpha, gold_action_alpha, action_label_alpha,
     relation_alpha, nuclear_alpha, nuclear_relation_alpha,
     etype_alpha) = create_alphabet(None, config.alphabet_path, logger)
    vocab = Vocab(word_alpha, tag_alpha, etype_alpha, gold_action_alpha,
                  action_label_alpha, relation_alpha, nuclear_alpha,
                  nuclear_relation_alpha)

    network = MainArchitecture(vocab, config)

    if config.use_gpu:
        state_dict = torch.load(config.model_name)
        network.load_state_dict(state_dict)
        network = network.cuda()
    else:
        # Remap stored tensors onto the CPU when no GPU is requested.
        state_dict = torch.load(config.model_name,
                                map_location=torch.device('cpu'))
        network.load_state_dict(state_dict)

    network.eval()

    def run_split(message, data_path, syn_feat_path):
        # Log, read one data split and hand its instances to predict().
        logger.info(message)
        instances = Reader(data_path, syn_feat_path).read_data()
        predict(network, instances, vocab, config, logger)

    run_split('Reading dev instance, and predict...',
              config.dev_path, config.dev_syn_feat_path)
    run_split('Reading test instance, and predict...',
              config.test_path, config.test_syn_feat_path)
Esempio n. 2
0
def main():
    """Evaluate a saved RST parser on the test set and log the metrics.

    The only command-line argument is ``--config_path`` (required); the
    alphabet path, model file, test data paths, batch size and GPU flag are
    read from the configuration loaded from that path.

    The test instances are run through the network in batches of
    ``config.batch_size``, each prediction is scored against its instance,
    and the span (S), nuclear (N), relation (R) and full (F) metrics are
    logged together with the elapsed time.
    """
    args_parser = argparse.ArgumentParser()
    args_parser.add_argument('--config_path', required=True)
    args = args_parser.parse_args()
    config = Config(None)
    config.load_config(args.config_path)

    logger = get_logger("RSTParser RUN", config.use_dynamic_oracle,
                        config.model_path)
    (word_alpha, tag_alpha, gold_action_alpha, action_label_alpha,
     etype_alpha) = create_alphabet(None, config.alphabet_path, logger)
    vocab = Vocab(word_alpha, tag_alpha, etype_alpha, gold_action_alpha,
                  action_label_alpha)

    network = MainArchitecture(vocab, config)
    # Without map_location a checkpoint saved from a GPU cannot be loaded on
    # a CPU-only run; remap to the CPU whenever the GPU is not requested.
    map_location = None if config.use_gpu else torch.device('cpu')
    network.load_state_dict(
        torch.load(config.model_name, map_location=map_location))

    if config.use_gpu:
        network = network.cuda()
    network.eval()

    logger.info('Reading test instance')
    reader = Reader(config.test_path, config.test_syn_feat_path)
    test_instances = reader.read_data()
    time_start = datetime.now()
    batch_size = config.batch_size
    span = Metric()
    nuclear = Metric()
    relation = Metric()
    full = Metric()
    predictions = []
    total_data_test = len(test_instances)

    # Inference only: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        for start in range(0, total_data_test, batch_size):
            # The last batch may be shorter than batch_size.
            end = min(start + batch_size, total_data_test)
            indices = np.arange(start, end)
            subset_data_test = batch_data_variable(test_instances, indices,
                                                   vocab, config)
            predictions.extend(network.loss(subset_data_test, None))

    # Indexing (rather than zip) keeps a missing prediction a loud IndexError.
    for i, instance in enumerate(test_instances):
        span, nuclear, relation, full = instance.evaluate(
            predictions[i], span, nuclear, relation, full)

    time_elapsed = datetime.now() - time_start
    # total_seconds() includes whole days, which timedelta.seconds drops.
    m, s = divmod(int(time_elapsed.total_seconds()), 60)
    logger.info('TEST is finished in {} mins {} secs'.format(m, s))
    logger.info("S: " + span.print_metric())
    logger.info("N: " + nuclear.print_metric())
    logger.info("R: " + relation.print_metric())
    logger.info("F: " + full.print_metric())