def generate_training_instances(parsing_system: ParsingSystem,
                                sentences: List[List[str]],
                                vocabulary: Vocabulary,
                                trees: List[DependencyTree]) -> List[Dict]:
    """
    Generates training instances of configuration and transition labels
    from the sentences and the corresponding dependency trees.

    Each sentence whose tree is projective is replayed from its initial
    configuration by repeatedly applying the oracle transition until a
    terminal configuration is reached. Sentences whose tree is not
    projective are skipped. Every visited configuration yields one
    instance.

    Parameters
    ----------
    parsing_system : ParsingSystem
        Supplies the transitions, the oracle, and the configuration
        operations used here.
    sentences : List[List[str]]
        Sentences to generate instances from.
    vocabulary : Vocabulary
        Passed through to ``get_configuration_features``.
    trees : List[DependencyTree]
        Dependency trees, indexed in parallel with ``sentences``.

    Returns
    -------
    List[Dict]
        One dict per visited configuration with keys ``"input"`` (the
        configuration features) and ``"label"`` (one float per
        transition: ``1.0`` for the oracle transition, ``0.0`` for any
        other transition that can be applied, ``-1.0`` for a transition
        that cannot be applied).
    """
    num_transitions = parsing_system.num_transitions()
    # The accumulator is a list of dicts, matching the declared return type.
    instances: List[Dict] = []

    for i, sentence in enumerate(tqdm(sentences)):
        tree = trees[i]
        if not tree.is_projective():
            continue

        c = parsing_system.initial_configuration(sentence)
        while not parsing_system.is_terminal(c):
            oracle = parsing_system.get_oracle(c, tree)
            feature = get_configuration_features(c, vocabulary)

            label = []
            for j in range(num_transitions):
                t = parsing_system.transitions[j]
                if t == oracle:
                    label.append(1.)
                elif parsing_system.can_apply(c, t):
                    label.append(0.)
                else:
                    label.append(-1.)

            # Diagnostic output: the oracle matched none of the known
            # transitions, so this instance has no positive label.
            if 1.0 not in label:
                print(i, label)

            instances.append({"input": feature, "label": label})
            c = parsing_system.apply(c, oracle)

    return instances
def predict(model: models.Model,
            sentences: List[Sentence],
            parsing_system: ParsingSystem,
            vocabulary: Vocabulary) -> List[DependencyTree]:
    """
    Predicts the dependency tree for each given sentence by greedy decoding.

    For every sentence an initial configuration is built with
    ``parsing_system``. While the configuration is not terminal, its
    features are extracted using ``vocabulary``, the ``model`` scores all
    transitions, and the highest-scoring transition that
    ``parsing_system`` reports as applicable is applied to obtain the
    next configuration. The tree of the final configuration is collected.

    Returns one tree per sentence, in the order of ``sentences``.
    """
    num_transitions = parsing_system.num_transitions()
    predicted_trees = []

    for sentence in tqdm(sentences):
        state = parsing_system.initial_configuration(sentence)

        while not parsing_system.is_terminal(state):
            # The model is fed a single-row batch of features.
            raw_features = get_configuration_features(state, vocabulary)
            batch = np.reshape(np.array(raw_features), (1, -1))
            scores = model(batch)["logits"].numpy()

            best_score = -float('inf')
            best_transition = ""
            for idx in range(num_transitions):
                score = scores[0, idx]
                # Applicability is only checked for a strictly better score.
                if score > best_score:
                    candidate = parsing_system.transitions[idx]
                    if parsing_system.can_apply(state, candidate):
                        best_score = score
                        best_transition = candidate

            state = parsing_system.apply(state, best_transition)

        predicted_trees.append(state.tree)

    return predicted_trees