Example #1
from typing import Dict, List

from tqdm import tqdm

# ParsingSystem, Vocabulary, DependencyTree and get_configuration_features are
# assumed to be provided by the surrounding project code.
def generate_training_instances(parsing_system: ParsingSystem,
                                sentences: List[List[str]],
                                vocabulary: Vocabulary,
                                trees: List[DependencyTree]) -> List[Dict]:
    """
    Generates training instances of configuration and transition labels
    from the sentences and the corresponding dependency trees.
    """
    num_transitions = parsing_system.num_transitions()
    instances: List[Dict] = []
    for i in tqdm(range(len(sentences))):
        if trees[i].is_projective():
            c = parsing_system.initial_configuration(sentences[i])
            while not parsing_system.is_terminal(c):
                oracle = parsing_system.get_oracle(c, trees[i])
                feature = get_configuration_features(c, vocabulary)
                # Label scheme: 1.0 marks the oracle transition, 0.0 an applicable
                # but non-oracle transition, and -1.0 an inapplicable transition.
                label = []
                for j in range(num_transitions):
                    t = parsing_system.transitions[j]
                    if t == oracle:
                        label.append(1.)
                    elif parsing_system.can_apply(c, t):
                        label.append(0.)
                    else:
                        label.append(-1.)
                if 1.0 not in label:
                    # Sanity check: the oracle transition should always be present.
                    print(i, label)
                instances.append({"input": feature, "label": label})
                c = parsing_system.apply(c, oracle)
    return instances
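The label convention above (1.0 for the oracle transition, 0.0 for other applicable transitions, -1.0 for inapplicable ones) makes batching straightforward. Below is a minimal sketch of how the returned instances could be stacked into arrays for training; it assumes only NumPy, and instances_to_arrays is a hypothetical helper, not part of the example.

import numpy as np

def instances_to_arrays(instances):
    # Stack per-configuration features and label vectors into dense arrays.
    # Entries labelled -1.0 (inapplicable transitions) can be masked out of the
    # training loss; the single 1.0 entry marks the oracle transition.
    inputs = np.array([instance["input"] for instance in instances], dtype=np.int64)
    labels = np.array([instance["label"] for instance in instances], dtype=np.float32)
    return inputs, labels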
Example #2
from typing import List

import numpy as np
from tqdm import tqdm

# models, Sentence, ParsingSystem, Vocabulary, DependencyTree and
# get_configuration_features are assumed to be provided by the surrounding project.
def predict(model: models.Model, sentences: List[Sentence],
            parsing_system: ParsingSystem,
            vocabulary: Vocabulary) -> List[DependencyTree]:
    """
    Predicts the dependency tree for each sentence by greedy decoding.
    We generate an initial configuration for a sentence using
    ``parsing_system`` and ``vocabulary``, extract its features, and apply the
    ``model`` to score the possible transitions. We then greedily apply the
    best applicable transition with ``parsing_system`` to get the next
    configuration, repeating until the terminal configuration is reached.
    """
    predicted_trees = []
    num_transitions = parsing_system.num_transitions()
    for sentence in tqdm(sentences):
        configuration = parsing_system.initial_configuration(sentence)
        while not parsing_system.is_terminal(configuration):
            features = get_configuration_features(configuration, vocabulary)
            features = np.array(features).reshape((1, -1))
            logits = model(features)["logits"].numpy()
            # Greedily pick the highest-scoring transition that can legally be
            # applied in the current configuration.
            opt_score = -float('inf')
            opt_trans = ""
            for j in range(num_transitions):
                if (logits[0, j] > opt_score and parsing_system.can_apply(
                        configuration, parsing_system.transitions[j])):
                    opt_score = logits[0, j]
                    opt_trans = parsing_system.transitions[j]
            configuration = parsing_system.apply(configuration, opt_trans)
        predicted_trees.append(configuration.tree)
    return predicted_trees
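The inner loop over transitions is just a masked argmax over the logits. Below is a minimal sketch of that formulation, assuming the same ParsingSystem API used above; best_transition is a hypothetical helper, not part of the example.

import numpy as np

def best_transition(logits_row, configuration, parsing_system):
    # Mask out transitions that cannot legally be applied in the current
    # configuration, then take the argmax over the remaining scores.
    applicable = np.array([parsing_system.can_apply(configuration, t)
                           for t in parsing_system.transitions])
    masked = np.where(applicable, logits_row, -np.inf)
    return parsing_system.transitions[int(np.argmax(masked))]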