def generate_training_instances(parsing_system: ParsingSystem,
                                sentences: List[List[str]],
                                vocabulary: Vocabulary,
                                trees: List[DependencyTree]) -> List[Dict]:
    """
    Generates training instances of configurations and transition labels
    from the sentences and the corresponding dependency trees.
    """
    num_transitions = parsing_system.num_transitions()
    instances: List[Dict] = []
    for i in tqdm(range(len(sentences))):
        # Skip non-projective trees: the transition oracle cannot
        # reproduce them, so they yield no consistent training signal.
        if trees[i].is_projective():
            c = parsing_system.initial_configuration(sentences[i])
            while not parsing_system.is_terminal(c):
                oracle = parsing_system.get_oracle(c, trees[i])
                feature = get_configuration_features(c, vocabulary)
                # Label vector over all transitions: 1 for the oracle
                # transition, 0 for other applicable transitions, and
                # -1 for transitions that cannot be applied here.
                label = []
                for j in range(num_transitions):
                    t = parsing_system.transitions[j]
                    if t == oracle:
                        label.append(1.)
                    elif parsing_system.can_apply(c, t):
                        label.append(0.)
                    else:
                        label.append(-1.)
                if 1.0 not in label:
                    # Sanity check: every configuration should have
                    # exactly one oracle transition.
                    print(i, label)
                instances.append({"input": feature, "label": label})
                c = parsing_system.apply(c, oracle)
    return instances
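# The 1/0/-1 label scheme above is consumed downstream by the loss. As a
# minimal sketch (not the script's actual DependencyParser code), one common
# choice is a softmax cross-entropy restricted to feasible transitions:
# entries labeled -1 are masked out of the normalizer, and the oracle entry
# (label 1) is the target. `logits` and `labels` are assumed to be float
# tensors of shape (batch, num_transitions); the function name is
# illustrative only.
import tensorflow as tf

def masked_softmax_loss_sketch(logits, labels):
    feasible = tf.cast(labels != -1.0, tf.float32)   # 1 where applicable
    # Numerically stable softmax over feasible transitions only.
    shifted = logits - tf.reduce_max(logits, axis=-1, keepdims=True)
    exp_scores = tf.exp(shifted) * feasible
    probs = exp_scores / tf.reduce_sum(exp_scores, axis=-1, keepdims=True)
    gold = tf.cast(labels == 1.0, tf.float32)        # one-hot oracle target
    return -tf.reduce_mean(
        tf.reduce_sum(gold * tf.math.log(probs + 1e-9), axis=-1))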
def predict(model: models.Model,
            sentences: List[Sentence],
            parsing_system: ParsingSystem,
            vocabulary: Vocabulary) -> List[DependencyTree]:
    """
    Predicts the dependency tree for each given sentence by greedy decoding.

    We generate an initial configuration for the sentence using
    ``parsing_system`` and featurize it with ``vocabulary``. We then apply
    the ``model`` to score the transitions for this configuration and
    greedily apply the best applicable transition with ``parsing_system``
    to get the next configuration. This repeats until a terminal
    configuration is reached.
    """
    predicted_trees = []
    num_transitions = parsing_system.num_transitions()
    for sentence in tqdm(sentences):
        configuration = parsing_system.initial_configuration(sentence)
        while not parsing_system.is_terminal(configuration):
            features = get_configuration_features(configuration, vocabulary)
            features = np.array(features).reshape((1, -1))
            logits = model(features)["logits"].numpy()
            # Pick the highest-scoring transition that is applicable
            # in the current configuration.
            opt_score = -float("inf")
            opt_trans = ""
            for j in range(num_transitions):
                if (logits[0, j] > opt_score
                        and parsing_system.can_apply(
                            configuration, parsing_system.transitions[j])):
                    opt_score = logits[0, j]
                    opt_trans = parsing_system.transitions[j]
            configuration = parsing_system.apply(configuration, opt_trans)
        predicted_trees.append(configuration.tree)
    return predicted_trees
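# An equivalent, slightly more compact way to pick the transition in the
# greedy loop above: mask inapplicable transitions to -inf and take the
# argmax. This is an illustrative sketch only; `select_transition` is a
# hypothetical helper, not part of the original script.
def select_transition(logits_row: np.ndarray, configuration,
                      parsing_system: ParsingSystem) -> str:
    scores = logits_row.copy()
    for j, t in enumerate(parsing_system.transitions):
        if not parsing_system.can_apply(configuration, t):
            scores[j] = -np.inf  # rule out inapplicable transitions
    # argmax over the remaining (feasible) scores
    return parsing_system.transitions[int(np.argmax(scores))]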
# If cached training data is asked for.
if args.cache_processed_data:
    print("Caching training instances for later use")
    cache_processed_data_path = args.train_data_file_path.replace(
        "conll", "jsonl")
    with open(cache_processed_data_path, "w") as file:
        for instance in tqdm(train_instances):
            file.write(json.dumps(instance) + "\n")

# Setup Model
config_dict = {
    "vocab_size": len(vocabulary.id_to_token),
    "embedding_dim": args.embedding_dim,
    "num_tokens": args.num_tokens,
    "hidden_dim": args.hidden_dim,
    "num_transitions": parsing_system.num_transitions(),
    "regularization_lambda": args.regularization_lambda,
    "trainable_embeddings": args.trainable_embeddings,
    "activation_name": args.activation_name
}
model = DependencyParser(**config_dict)
if args.pretrained_embedding_file:
    embedding_matrix = load_embeddings(args.pretrained_embedding_file,
                                       vocabulary, args.embedding_dim)
    model.embeddings.assign(embedding_matrix)

# Setup Optimizer
optimizer = optimizers.Adam()

# Train
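# A minimal sketch of the gradient step the training section would run
# (assumptions: TF2 eager execution; the model returns a dict with a
# "logits" entry, as in `predict` above; `compute_loss` is a hypothetical
# method of DependencyParser, named here for illustration only).
def train_step(batch_inputs: np.ndarray, batch_labels: np.ndarray):
    with tf.GradientTape() as tape:
        logits = model(batch_inputs)["logits"]
        loss = model.compute_loss(logits, batch_labels)  # hypothetical name
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss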