def from_args(cls, args):
    """Build and populate a configuration instance from parsed arguments.

    Args:
        cls: Class to instantiate. Its no-argument constructor provides the
            fallback values kept for ``name`` and ``param_path`` when
            ``args.model_path`` / ``args.param_path`` are empty.
        args: Parsed command-line arguments (e.g. an ``argparse.Namespace``)
            carrying the attributes read below.

    Returns:
        The populated ``cls`` instance. ``dev_ud`` and ``gold_ud`` are only
        set when the data format resolves to CoNLL-U.
    """
    c = cls()
    # Keep the constructor defaults unless a non-empty path was given.
    c.name = args.model_path or c.name
    c.param_path = args.param_path or c.param_path
    c.mode = args.mode
    c.train_set_path = args.train_path
    c.dev_set_path = args.dev_path
    c.test_set_path = args.test_path
    c.test_set_path2 = args.test_path2
    c.test_conllu_gold_path = args.test_conllu_gold_path
    c.wembpath = args.word_emb_path
    # Select the format by string *value*. `is` would test object identity,
    # which is not guaranteed for strings parsed from the command line.
    # Any other value falls back to CoNLL-U.
    conll_formats = {'conll06': const.CONLL06, 'conll09': const.CONLL09}
    c.conll_format = conll_formats.get(args.data_format, const.CONLLU)
    c.tune_las = args.tune == 'las'
    c.use_bilstm_input = args.use_bilstm_input
    c.in_lstm_size = args.input_bilstm_size
    c.in_lstm_count = args.input_bilstm_layers
    c.hid_mlp_size = args.hidden_layer_size
    c.hid_mlp_count = args.hidden_layers
    c.hid_lstm_size = args.encoding_lstm_size
    c.hid_lstm_count = args.encoding_lstm_layers
    c.batch_size = args.batch_size
    c.stack_feats = args.stack_feats
    c.buffer_feats = args.buffer_feats
    c.max_epochs = args.epochs
    c.iter_per_test = args.iter_per_test
    c.min_iter_for_test = args.min_iter_for_test
    # Imported locally, presumably to avoid a circular import at module load
    # time -- confirm before hoisting it to the top of the file.
    from networks import RRTNetwork
    c.nn = RRTNetwork.from_args(args)
    if c.conll_format == const.CONLLU:
        # Pre-load the dev set and the gold test file for CoNLL-U evaluation.
        c.dev_ud = load_conllu_file(c.dev_set_path)
        c.gold_ud = load_conllu_file(c.test_conllu_gold_path)
    return c
# Load the data root_factors = [ud_dataset.UDDataset.FORMS] train = ud_dataset.UDDataset("{}-ud-train.conllu".format(args.basename), args.lr_allow_copy, root_factors) dev = ud_dataset.UDDataset("{}-ud-dev.conllu".format(args.basename), args.lr_allow_copy, root_factors, train=train, shuffle_batches=False) dev_udpipe = ud_dataset.UDDataset("{}-ud-dev-udpipe.conllu".format( args.basename), args.lr_allow_copy, root_factors, train=train, shuffle_batches=False) dev_conllu = conll18_ud_eval.load_conllu_file("{}-ud-dev.conllu".format( args.basename)) # Construct the network network = Network(threads=args.threads) network.construct( args, len(train.factors[train.FORMS].words), len(train.factors[train.FORMS].alphabet), dict((tag, len(train.factors[train.FACTORS_MAP[tag]].words)) for tag in args.tags)) if args.checkpoint: network.saver_train.restore(network.session, args.checkpoint) with open("{}/cmd".format(args.logdir), "w") as cmd_file: cmd_file.write(command_line) log_file = open("{}/log".format(args.logdir), "w")
def _evaluate_treebank(args, goldfile, outfile):
    """Load and score one gold/system CoNLL-U pair.

    Args:
        args: Parsed command-line arguments; ``args.truth`` and
            ``args.system`` are the directories holding the gold and system
            files.
        goldfile: Gold file name, relative to ``args.truth``.
        outfile: System output file name, relative to ``args.system``.

    Returns:
        A ``(status, evaluation)`` pair. ``status`` is the human-readable
        status message. ``evaluation`` is the result of ``evaluate`` on
        success and ``None`` when the treebank could not be scored.
    """
    # Failures are reported rather than raised, so that one broken treebank
    # does not abort the evaluation of all the others.
    try:
        gold = load_conllu_file(args.truth + "/" + goldfile)
    except Exception:
        return "Error: Cannot load gold file", None

    # Load system data, distinguishing the common validation failures.
    try:
        system = load_conllu_file(args.system + "/" + outfile)
    except UDError as e:
        # Guard against a UDError raised without a message.
        message = str(e.args[0]) if e.args else ""
        if message.startswith("There is a cycle"):
            return "Error: There is a cycle in generated CoNLL-U file", None
        if message.startswith("There are multiple roots"):
            return ("Error: There are multiple roots in a sentence in generated CoNLL-U file",
                    None)
        return ("Error: There is a format error (tabs, ID values, etc) in generated CoNLL-U file",
                None)
    except Exception:
        return "Error: Cannot open generated CoNLL-U file", None

    # The system output must reproduce exactly the gold (non-space) text.
    if not system.characters:
        return "Error: The system file is empty", None
    if system.characters != gold.characters:
        return ("Error: The concatenation of tokens in gold file and in system file differ, "
                "system file has {} nonspace characters, which is approximately {}% of the gold file"
                .format(
                    len(system.characters),
                    # max(..., 1) avoids ZeroDivisionError on an empty gold file.
                    int(100 * len(system.characters) / max(len(gold.characters), 1)))), None

    try:
        evaluation = evaluate(gold, system)
    except Exception:  # Should not happen
        return "Error: Cannot evaluate generated CoNLL-U file, internal error", None

    status = ("OK: Result F1 scores rounded to 5% are LAS={:.0f}% MLAS={:.0f}% BLEX={:.0f}%"
              .format(100 * round_score(evaluation["LAS"].f1),
                      100 * round_score(evaluation["MLAS"].f1),
                      100 * round_score(evaluation["BLEX"].f1)))
    return status, evaluation


def main():
    """Evaluate every treebank listed in ``<truth>/metadata.json``.

    Each metadata entry names a gold file (under ``truth``) and a system
    output file (under ``system``).

    Output:
        * ``<output>/evaluation.prototext``: one ``measure{...}`` record per
          result key. The ``total-*`` averages come first, then each
          treebank's status and per-metric F1 values.
        * stdout: one LAS/MLAS/BLEX summary line per treebank.
        * stderr: one status line per treebank.

    The ``total-*`` averages divide by the number of listed treebanks,
    including failed ones, so a failed treebank counts as a score of 0.
    """
    # Parse arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("truth", type=str, help="Directory name of the truth dataset.")
    parser.add_argument("system", type=str, help="Directory name of system output.")
    parser.add_argument("output", type=str, help="Directory name of the output directory.")
    args = parser.parse_args()

    # Load input dataset metadata.json (JSON is UTF-8 by specification).
    with open(args.truth + "/metadata.json", "r", encoding="utf-8") as metadata_file:
        metadata = json.load(metadata_file)

    # Evaluate and compute sum of all treebanks
    metrics = [
        "Tokens", "Sentences", "Words", "UPOS", "XPOS", "UFeats", "AllTags",
        "Lemmas", "UAS", "LAS", "CLAS", "MLAS", "BLEX"
    ]
    treebanks = 0
    summation = {}
    results = []
    results_las, results_mlas, results_blex = {}, {}, {}
    for entry in metadata:
        treebanks += 1
        ltcode = "_".join((entry['lcode'], entry['tcode']))
        status, evaluation = _evaluate_treebank(args, entry['goldfile'], entry['outfile'])
        results.append((ltcode + "-Status", status))
        if evaluation is None:
            continue

        # Generate output metrics and compute sum
        for metric in metrics:
            f1 = evaluation[metric].f1
            results.append((ltcode + "-" + metric + "-F1", "{:.9f}".format(100 * f1)))
            summation[metric] = summation.get(metric, 0) + f1
        results_las[ltcode] = evaluation["LAS"].f1
        results_mlas[ltcode] = evaluation["MLAS"].f1
        results_blex[ltcode] = evaluation["BLEX"].f1

    # Compute averages and put them first, in `metrics` order.
    # max(..., 1) keeps an empty metadata list from raising ZeroDivisionError.
    denominator = max(treebanks, 1)
    totals = [("total-" + metric + "-F1",
               "{:.9f}".format(100 * summation.get(metric, 0) / denominator))
              for metric in metrics]
    results = totals + results

    # Generate evaluation.prototext
    with open(args.output + "/evaluation.prototext", "w", encoding="utf-8") as prototext:
        for key, value in results:
            print('measure{{\n key: "{}"\n value: "{}"\n}}'.format(key, value), file=prototext)

    # Generate LAS-F1, MLAS-F1, BLEX-F1 + Status on stdout, Status on stderr
    for key, value in results:
        if not key.endswith("-Status"):
            continue
        ltcode = key[:-len("-Status")]
        print("{:13} LAS={:10.6f}% MLAS={:10.6f}% BLEX={:10.6f}% ({})".format(
            ltcode, 100 * results_las.get(ltcode, 0.), 100 * results_mlas.get(ltcode, 0.),
            100 * results_blex.get(ltcode, 0.), value), file=sys.stdout)
        print("{:13} {}".format(ltcode, value), file=sys.stderr)
def conll_eval(system_file, gold_file):
    """Score a system CoNLL-U file against its gold-standard counterpart.

    Mind the argument order: this function takes the system file first,
    whereas ``evaluate`` receives the gold data first. The gold file is
    loaded before the system file.

    Returns:
        The result of ``evaluate`` on the two loaded files.
    """
    return evaluate(load_conllu_file(gold_file), load_conllu_file(system_file))