def build_amr(self, tokens, actions, labels, labelsA, predicates):
    """Replay a predicted action sequence and return the resulting AMR.

    The four prediction sequences are consumed in lockstep (``zip``), so the
    shortest one bounds how many actions are applied.

    Args:
        tokens: Sentence tokens; every ``"<eof>"`` entry is dropped before
            the state machine is created.
        actions: Bare action names, one per step.
        labels: Per-step argument attached to ``RA...`` actions and to
            ``LA...`` actions that do not end in ``(root)``.
        labelsA: Per-step argument attached to ``AD...`` actions.
        predicates: Per-step argument attached to ``PR...`` actions.

    Returns:
        The ``amr`` attribute of the ``AMRStateMachine`` after all actions
        have been applied.
    """

    def _with_argument(action, edge_label, add_label, pred):
        # Select which prediction (if any) is appended as "(argument)".
        if action.startswith('PR'):
            return action + f'({pred})'
        # `and` binds tighter than `or`: the "(root)" suffix only exempts
        # 'LA...' actions; every 'RA...' action receives the label.
        takes_label = action.startswith('RA') or (
            action.startswith('LA') and not action.endswith('(root)')
        )
        if takes_label:
            return action + f'({edge_label})'
        if action.startswith('AD'):
            return action + f'({add_label})'
        return action

    full_actions = [
        _with_argument(*step)
        for step in zip(actions, labels, labelsA, predicates)
    ]

    sentence = [word for word in tokens if word != "<eof>"]
    machine = AMRStateMachine(
        sentence,
        verbose=False,
        spacy_lemmatizer=self.lemmatizer,
    )
    machine.applyActions(full_actions)
    return machine.amr
def _attach_action_argument(act, label, labelA, predicate):
    """Return ``act`` with the argument selected by its prefix appended as ``(arg)``."""
    if act.startswith('PR'):
        return f'{act}({predicate})'
    # `and` binds tighter than `or`: the "(root)" suffix only exempts 'LA...'
    # actions; every 'RA...' action receives the label.
    if act.startswith('RA') or (
            act.startswith('LA') and not act.endswith('(root)')):
        return f'{act}({label})'
    if act.startswith('AD'):
        return f'{act}({labelA})'
    return act


def eval_parser(exp_name, model, args, dev_sentences, h5py_test, epoch_idx,
                smatch_file, save_model, param_groups):
    """Parse the dev sentences, write the predicted AMRs and score them.

    For every sentence the model is run in ``'predict'`` mode, the predicted
    actions are replayed on an ``AMRStateMachine`` and the resulting graph is
    written to the predictions file. The predictions file is then passed to
    ``smatch_wrapper`` together with ``args.amr_dev_data`` and one summary
    line is appended to ``smatch_file``.

    Args:
        exp_name: Prefix used for the output file names.
        model: Parser model; must provide ``eval()``, ``forward_single()``
            and, when ``args.unit_tests`` is set, the ``action2idx``,
            ``labelsO2idx``, ``labelsA2idx``, ``pred2idx`` mappings and
            ``epoch_loss``.
        args: Options object; the attributes read here are ``amr_dev_data``,
            ``write_actions``, ``gpu``, ``no_bert`` and ``unit_tests``.
        dev_sentences: Iterable of token lists, one per sentence.
        h5py_test: Handle forwarded to ``get_bert_embeddings`` (unused when
            ``args.no_bert`` is set).
        epoch_idx: Epoch number, used in file names and the summary line.
        smatch_file: Path of the file the summary line is appended to.
        save_model: Output directory; if falsy, files are written relative to
            the current working directory.
        param_groups: Iterable of mappings with an ``'lr'`` entry; the values
            are recorded in the summary line.

    Returns:
        The score returned by ``smatch_wrapper``.

    Raises:
        AssertionError: If ``args.unit_tests`` is set and the epoch loss or
            the dev hash differ from the hard-coded reference values.
    """
    # Local import keeps this block self-contained without touching the
    # file-level import section.
    from contextlib import ExitStack

    # save also current learning rate
    learning_rate = ",".join(str(group['lr']) for group in param_groups)

    model.eval()

    # evaluate on dev
    print_log('eval', f'Evaluating on: {args.amr_dev_data}')

    out_prefix = f'{save_model}/{exp_name}' if save_model else exp_name
    predicted_amr_file = f'{out_prefix}_amrs.epoch{epoch_idx}.dev.txt'
    actions_file = f'{out_prefix}_actions.epoch{epoch_idx}.dev.txt'

    dev_hash = 0
    # Open each output file once for the whole evaluation instead of
    # re-opening it in append mode for every sentence. Both files are closed
    # (and therefore flushed) before SMATCH reads the predictions.
    with ExitStack() as outputs:
        amr_out = outputs.enter_context(open(predicted_amr_file, 'w'))
        print_log('eval', f'Writing amr graphs to: {predicted_amr_file}')

        actions_out = None
        if args.write_actions:
            print_log('eval', f'Writing actions to: {actions_file}')
            actions_out = outputs.enter_context(open(actions_file, 'w'))

        for sent_idx, tokens in enumerate(tqdm(dev_sentences)):
            sent_rep = utils.vectorize_words(
                model, tokens, training=False, gpu=args.gpu)
            dev_b_emb = (
                None if args.no_bert
                else get_bert_embeddings(h5py_test, sent_idx, tokens)
            )

            _, actions, labels, labelsA, predicates = model.forward_single(
                sent_rep,
                mode='predict',
                tokens=tokens,
                bert_embedding=dev_b_emb,
            )

            apply_actions = [
                _attach_action_argument(act, label, labelA, predicate)
                for act, label, labelA, predicate
                in zip(actions, labels, labelsA, predicates)
            ]

            if args.unit_tests:
                # Fold the predictions into a checksum compared against a
                # reference value at the end of the evaluation.
                dev_hash += sum(model.action2idx[a] for a in actions)
                dev_hash += sum(model.labelsO2idx[l] for l in labels if l)
                dev_hash += sum(model.labelsA2idx[l] for l in labelsA if l)
                dev_hash += sum(
                    model.pred2idx[p] if p in model.pred2idx else 0
                    for p in predicates if p
                )

            if actions_out is not None:
                actions_out.write('\t'.join(tokens) + '\n')
                actions_out.write('\t'.join(apply_actions) + '\n\n')

            # write amr graphs
            tr = AMRStateMachine(tokens, verbose=False)
            tr.applyActions(apply_actions)
            amr_out.write(tr.amr.toJAMRString())

    # run smatch
    print_log('eval', 'Computing SMATCH')
    smatch_score = smatch_wrapper(
        args.amr_dev_data, predicted_amr_file, significant=3)
    print_log('eval', f'SMATCH: {smatch_score}')

    # Drop the fractional seconds from the timestamp.
    timestamp = str(datetime.now()).split('.')[0]

    # store all information in file
    print_log('eval', f'Writing SMATCH and other info to: {smatch_file}')
    with open(smatch_file, 'a') as fid:
        fid.write("\t".join([
            f'epoch {epoch_idx}',
            f'learning_rate {learning_rate}',
            f'time {timestamp}',
            f'F-score {smatch_score}\n'
        ]))

    if args.unit_tests:
        # Regression check against fixed reference values; exact equality is
        # intentional here.
        test1 = (model.epoch_loss == 3360.1150283813477)
        test2 = (dev_hash == 6038)
        print(
            f'[run tests] epoch_loss==3360.1150283813477 (got {model.epoch_loss}) {"pass" if test1 else "fail"}',
            file=sys.stderr)
        print(
            f'[run tests] dev hash==6038 (got {dev_hash}) {"pass" if test2 else "fail"}',
            file=sys.stderr)
        assert (test1)
        assert (test2)

    return smatch_score