def predict(model: models.Model,
            sentences: List[Sentence],
            parsing_system: ParsingSystem,
            vocabulary: Vocabulary) -> List[DependencyTree]:
    """
    Predicts the dependency tree for each sentence by greedy decoding.

    We generate an initial configuration for each ``sentence`` using
    ``parsing_system`` and ``vocabulary``. Then we apply the ``model`` to
    predict the best transition for this configuration and apply that
    transition (greedily) with ``parsing_system`` to get the next
    configuration. We repeat this until the terminal configuration is reached.
    """
    predicted_trees = []
    num_transitions = parsing_system.num_transitions()
    for sentence in tqdm(sentences):
        configuration = parsing_system.initial_configuration(sentence)
        while not parsing_system.is_terminal(configuration):
            features = get_configuration_features(configuration, vocabulary)
            features = np.array(features).reshape((1, -1))
            logits = model(features)["logits"].numpy()
            # Greedily pick the highest-scoring transition that is
            # applicable in the current configuration.
            opt_score = -float('inf')
            opt_trans = ""
            for j in range(num_transitions):
                if (logits[0, j] > opt_score
                        and parsing_system.can_apply(
                            configuration, parsing_system.transitions[j])):
                    opt_score = logits[0, j]
                    opt_trans = parsing_system.transitions[j]
            configuration = parsing_system.apply(configuration, opt_trans)
        predicted_trees.append(configuration.tree)
    return predicted_trees
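# A minimal usage sketch for ``predict`` (hedged: ``model`` stands for any
# trained model from this repo whose call returns a dict with a "logits"
# entry, and the file path is illustrative only, not a path in the repo):
#
#     test_sentences, _ = read_conll_data("data/test.conll")
#     trees = predict(model, test_sentences, parsing_system, vocabulary)
#     assert len(trees) == len(test_sentences)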
def evaluate(sentences: List[Sentence],
             parsing_system: ParsingSystem,
             predicted_trees: List[DependencyTree],
             label_trees: List[DependencyTree]) -> str:
    """
    Evaluates the predicted dependency trees by comparing them with the
    gold (label) trees.
    """
    return parsing_system.evaluate(sentences, predicted_trees, label_trees)
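# Example combining the two functions above (a hedged sketch reusing the
# ``validation_*`` names defined in the training script further below):
#
#     predicted = predict(model, validation_sentences, parsing_system,
#                         vocabulary)
#     print(evaluate(validation_sentences, parsing_system, predicted,
#                    validation_trees))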
def generate_training_instances(parsing_system: ParsingSystem,
                                sentences: List[Sentence],
                                vocabulary: Vocabulary,
                                trees: List[DependencyTree]) -> List[Dict]:
    """
    Generates training instances of configurations and transition labels
    from the sentences and the corresponding dependency trees.
    """
    num_transitions = parsing_system.num_transitions()
    instances: List[Dict] = []
    for i in tqdm(range(len(sentences))):
        # Non-projective trees cannot be reproduced by this transition
        # system, so we skip them during training.
        if trees[i].is_projective():
            c = parsing_system.initial_configuration(sentences[i])
            while not parsing_system.is_terminal(c):
                oracle = parsing_system.get_oracle(c, trees[i])
                feature = get_configuration_features(c, vocabulary)
                # Label vector: 1 for the oracle transition, 0 for other
                # applicable transitions, -1 for inapplicable ones.
                label = []
                for j in range(num_transitions):
                    t = parsing_system.transitions[j]
                    if t == oracle:
                        label.append(1.)
                    elif parsing_system.can_apply(c, t):
                        label.append(0.)
                    else:
                        label.append(-1.)
                if 1.0 not in label:
                    # Sanity check: the oracle transition should always be
                    # applicable; report the instance if it is not.
                    print(i, label)
                instances.append({"input": feature, "label": label})
                c = parsing_system.apply(c, oracle)
    return instances
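# Shape of one generated instance (illustrative values; assumes a toy
# transition system with three transitions where transition 1 is the oracle
# and transition 2 is inapplicable in the current configuration):
#
#     {"input": [5, 12, 0, ...],    # feature ids from the vocabulary
#      "label": [0.0, 1.0, -1.0]}   # one entry per transition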
print("Reading training data") train_sentences, train_trees = read_conll_data(args.train_data_file_path) print("Reading validation data") validation_sentences, validation_trees = read_conll_data( args.validation_data_file_path) vocabulary = Vocabulary(train_sentences, train_trees) sorted_labels = [ item[0] for item in sorted(vocabulary.label_token_to_id.items(), key=lambda e: e[1]) ] non_null_sorted_labels = sorted_labels[1:] parsing_system = ParsingSystem(non_null_sorted_labels) # Generating training instances takes ~20 minutes everytime. So once you finalize the # feature generation and want to try different configs for experiments, you can use caching. if args.use_cached_data: print("Loading cached training instances") cache_processed_data_path = args.train_data_file_path.replace( "conll", "jsonl") if not os.path.exists(cache_processed_data_path): raise Exception( f"You asked to use cached data but {cache_processed_data_path} " f"is not available.") with open(cache_processed_data_path, "r") as file: train_instances = [ json.loads(line) for line in tqdm(file.readlines()) if line.strip()
parser.add_argument('load_serialization_dir',
                    type=str,
                    help='serialization directory of the trained model. '
                         'Used only for vocab.')
parser.add_argument('gold_data_path',
                    type=str,
                    help='gold data file path.')
parser.add_argument('prediction_data_path',
                    type=str,
                    help='predictions data file path.')
args = parser.parse_args()

print("Reading data")
sentences, label_trees = read_conll_data(args.gold_data_path)
_, predicted_trees = read_conll_data(args.prediction_data_path)

print("Reading vocabulary")
vocabulary_path = os.path.join(args.load_serialization_dir, "vocab.pickle")
vocabulary = Vocabulary.load(vocabulary_path)

sorted_labels = [
    item[0] for item in sorted(vocabulary.label_token_to_id.items(),
                               key=lambda e: e[1])
]
non_null_sorted_labels = sorted_labels[1:]
parsing_system = ParsingSystem(non_null_sorted_labels)

print("Evaluating")
report = evaluate(sentences, parsing_system, predicted_trees, label_trees)
print(report)