def compute_f1_from_list(l1, l2):
    """Bracketing F1 between two binary-tree lists over the same sentence.

    Extracts each tree's bracket set, adds the trivial whole-sentence span
    to both sides, and scores their overlap with compute_f1.
    """
    brackets_a, n_a = get_brackets(l1)
    brackets_b, n_b = get_brackets(l2)
    # Both trees must cover the same number of leaves.
    assert n_a == n_b
    # Include the trivial full-sentence bracket on both sides.
    brackets_a.add((0, n_a))
    brackets_b.add((0, n_b))
    overlap = brackets_a & brackets_b
    return compute_f1(overlap, brackets_a, brackets_b)
def verify_f1(path): f1_list = [] with codecs.open(path, encoding='utf-8') as f: for line in f: try: line = line.encode('UTF-8') except UnicodeError as e: print "ENCODING ERROR:", line, e line = "{}" loaded_example = json.loads(line) t1 = Tree.fromstring(loaded_example['sentence1_parse']) l1 = len(t1.leaves()) t1 = tree2list(t1) t2 = Tree.fromstring(loaded_example['sentence2_parse']) l2 = len(t2.leaves()) t2 = tree2list(t2) # print t1 # print l1 # print t2 # print l2 bt1 = get_balanced_tree(l1) bt2 = get_balanced_tree(l2) # print bt1 # print bt2 print t1 t1 = get_brackets(t1)[0] print t1 sys.exit(0) t2 = get_brackets(t2)[0] bt1 = get_brackets(bt1)[0] bt2 = get_brackets(bt2)[0] # t1.add((0,l1)) # bt1.add((0,l1)) # t2.add((0,l2)) # bt2.add((0,l2)) # print t1 # print t2 # print bt1 # print bt2 f1 = compute_f1(t1 & bt1, t1, bt1) f1_list.append(f1) f1 = compute_f1(t2 & bt2, t2, bt2) f1_list.append(f1) return sum(f1_list) / len(f1_list), len(f1_list)
j += 1 skip = False continue # read prediction example example_pred = json.loads(lines_pred[j]) key_pred = list2words(example_pred['sent1_tree']) + ' || ' + list2words(example_pred['sent2_tree']) key_pred = key_pred.lower() if check_key(key_dev, key_pred): tree1_dev = nltk.Tree.fromstring(example_dev['sentence1_parse']) tree2_dev = nltk.Tree.fromstring(example_dev['sentence2_parse']) tag_brackets1_dev = get_tag_brackets(tree1_dev) tag_brackets2_dev = get_tag_brackets(tree2_dev) brackets1_pred, len1 = get_brackets(example_pred['sent1_tree']) brackets2_pred, len2 = get_brackets(example_pred['sent2_tree']) brackets1_pred.add((0,len1)) brackets2_pred.add((0,len2)) for tag, bracket in tag_brackets1_dev: if tag not in tag_dev_freq: tag_dev_freq[tag] = 1 else: tag_dev_freq[tag] += 1 if bracket in brackets1_pred: if tag not in tag_pred_freq: tag_pred_freq[tag] = 1 else: tag_pred_freq[tag] += 1
def generate_trivial_tree_dataset_debug(read_file_path, write_file_path, trivial_tree='balanced'): if trivial_tree == 'balanced': get_trivial_tree = get_balanced_tree elif trivial_tree == 'left_branching': get_trivial_tree = get_left_branching_tree elif trivial_tree == 'right_branching': get_trivial_tree = get_right_branching_tree else: raise ValueError('invalid trivial tree form!') print '****** generating {} tree ******'.format(trivial_tree) f1_list = [] with codecs.open(read_file_path, encoding='utf-8') as f: for line in f: try: line = line.encode('UTF-8') except UnicodeError as e: print "ENCODING ERROR:", line, e line = "{}" loaded_example = json.loads(line) write_example = {} write_example['gold_label'] = loaded_example['gold_label'] if 'genre' in loaded_example: write_example['genre'] = loaded_example['genre'] if 'promptID' in loaded_example: write_example['promptID'] = loaded_example['promptID'] t1 = Tree.fromstring(loaded_example['sentence1_parse']) t2 = Tree.fromstring(loaded_example['sentence2_parse']) words1 = filter_words(t1) words2 = filter_words(t2) if len(words1) < 1 or len(words2) < 1: continue trivial_t1 = get_trivial_tree(words1) trivial_t2 = get_trivial_tree(words2) write_example['sentence1_prpn_binary_parse'] = trivial_t1 write_example['sentence2_prpn_binary_parse'] = trivial_t2 write_example['sentence1_binary_parse'] = tree2list(t1) write_example['sentence2_binary_parse'] = tree2list(t2) write_example['sentence1'] = words1 write_example['sentence2'] = words2 t1_brackets, l1 = get_brackets(tree2list(t1)) t2_brackets, l2 = get_brackets(tree2list(t2)) trivial_t1_brackets, trivial_l1 = get_brackets(trivial_t1) trivial_t2_brackets, trivial_l2 = get_brackets(trivial_t2) assert l1 == len(words1) assert l2 == len(words2) assert l1 == trivial_l1 assert l2 == trivial_l2 t1_brackets.add((0, l1)) trivial_t1_brackets.add((0, l1)) t2_brackets.add((0, l2)) trivial_t2_brackets.add((0, l2)) f1 = compute_f1(t1_brackets & trivial_t1_brackets, t1_brackets, 
trivial_t1_brackets) f1_list.append(f1) f1 = compute_f1(t2_brackets & trivial_t2_brackets, t2_brackets, trivial_t2_brackets) f1_list.append(f1) return sum(f1_list) / len(f1_list), len(f1_list)
def evaluate(FLAGS, model, data_manager, eval_set, log_entry, logger, step,
             vocabulary=None, show_sample=False, eval_index=0):
    """Run one evaluation pass over `eval_set`.

    Accumulates classification accuracy and, when the model can produce
    parse samples, a per-sentence bracketing F1 against the batch's silver
    trees; statistics are folded into a new `log_entry.evaluation` entry.

    Returns:
        (eval_class_accuracy, eval_transition_accuracy, f1), read back
        from the populated eval log entry after eval_stats() fills it.
    """
    filename, dataset = eval_set
    A = Accumulator()
    # NOTE(review): `index` is computed but never used below.
    index = len(log_entry.evaluation)
    eval_log = log_entry.evaluation.add()
    reporter = EvalReporter()
    tree_strs = None
    # Evaluate
    total_batches = len(dataset)
    progress_bar = SimpleProgressBar(msg="Run Eval", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(0, total=total_batches)
    total_tokens = 0
    start = time.time()
    if FLAGS.model_type in ["Pyramid", "ChoiPyramid"]:
        # Temperature decays with the training step, optionally modulated
        # by a cosine cycle.
        pyramid_temperature_multiplier = FLAGS.pyramid_temperature_decay_per_10k_steps**(
            step / 10000.0)
        if FLAGS.pyramid_temperature_cycle_length > 0.0:
            min_temp = 1e-5
            pyramid_temperature_multiplier *= (math.cos(
                (step) / FLAGS.pyramid_temperature_cycle_length) + 1 + min_temp) / 2
    else:
        pyramid_temperature_multiplier = None
    # Switch off dropout/batch-norm updates for evaluation.
    model.eval()
    for i, dataset_batch in enumerate(dataset):
        batch = get_batch(dataset_batch)
        eval_X_batch, eval_transitions_batch, eval_y_batch, eval_num_transitions_batch, eval_ids, _, silver_tree = batch
        # eval_X_batch: <batch x maxlen x 2>
        # eval_y_batch: <batch >
        # silver_tree:
        # the dist is invalid for val
        # Run model.
        output = model(
            eval_X_batch,
            eval_transitions_batch,
            eval_y_batch,
            use_internal_parser=FLAGS.use_internal_parser,
            validate_transitions=FLAGS.validate_transitions,
            pyramid_temperature_multiplier=pyramid_temperature_multiplier,
            store_parse_masks=True,
            example_lengths=eval_num_transitions_batch)
        # TODO: Restore support in Pyramid if using.
        can_sample = FLAGS.model_type in [
            "ChoiPyramid"
        ] or (FLAGS.model_type == "SPINN" and FLAGS.use_internal_parser)
        if show_sample and can_sample:
            tmp_samples = model.get_samples(
                eval_X_batch, vocabulary, only_one=not FLAGS.write_eval_report)
            # tree_strs = prettyprint_trees(tmp_samples)
            tree_strs = [tree for tree in tmp_samples]
        # NOTE(review): samples are re-drawn every batch (only_one=False) so
        # the F1 loop below can run even when show_sample is off.
        tmp_samples = model.get_samples(eval_X_batch, vocabulary, only_one=False)
        # def get_max(s):
        #     # test f1
        #     max = 0
        #     for x in s:
        #         _, idx = x.split(',')
        #         if int(idx) > max:
        #             max = int(idx)
        #     return max
        # One pass per sentence of the (possibly paired) example.
        for s in (range(int(model.use_sentence_pair) + 1)):
            for b in (range(silver_tree.shape[0])):
                model_out = tmp_samples[s * silver_tree.shape[0] + b]
                std_out = silver_tree[b, :, s]
                # '-1,-1' entries act as padding in the silver bracket matrix.
                std_out = set([x for x in std_out if x != '-1,-1'])
                model_out_brackets, model_out_max_l = get_brackets(model_out)
                model_out = set(convert_brackets_to_string(model_out_brackets))
                # Credit the trivial whole-sentence span to both sides.
                outmost_bracket = '{:d},{:d}'.format(0, model_out_max_l)
                std_out.add(outmost_bracket)
                model_out.add(outmost_bracket)
                # print get_max(model_out), get_max(std_out)
                # print model_out
                # print std_out
                # print '=' * 30
                # assert get_max(model_out) == get_max(std_out)
                overlap = model_out & std_out
                # Epsilon keeps the divisions safe; empty sets are then
                # explicitly scored as perfect.
                prec = float(len(overlap)) / (len(model_out) + 1e-8)
                reca = float(len(overlap)) / (len(std_out) + 1e-8)
                if len(std_out) == 0:
                    reca = 1.
                if len(model_out) == 0:
                    prec = 1.
                f1 = 2 * prec * reca / (prec + reca + 1e-8)
                A.add('f1', f1)
        if not FLAGS.write_eval_report:
            # Only show one sample, regardless of the number of batches.
            show_sample = False
        # Normalize output.
        logits = F.log_softmax(output)
        # Calculate class accuracy.
        target = torch.from_numpy(eval_y_batch).long()
        # get the index of the max log-probability
        pred = logits.data.max(1, keepdim=False)[1].cpu()
        eval_accumulate(model, data_manager, A, batch)
        A.add('class_correct', pred.eq(target).sum())
        A.add('class_total', target.size(0))
        # Optionally calculate transition loss/acc.
        # model.transition_loss if hasattr(model, 'transition_loss') else None
        # TODO: review this. the original line seems to have no effect
        # Update Aggregate Accuracies
        total_tokens += sum([(nt + 1) / 2 for nt in eval_num_transitions_batch.reshape(-1)])
        if FLAGS.write_eval_report:
            transitions_per_example, _ = model.spinn.get_transitions_per_example(
                style="preds" if FLAGS.eval_report_use_preds else "given") if (
                FLAGS.model_type == "SPINN" and FLAGS.use_internal_parser) else (None, None)
            if model.use_sentence_pair:
                # Paired input: first half of each list is sentence 1,
                # second half is sentence 2.
                batch_size = pred.size(0)
                sent1_transitions = transitions_per_example[:batch_size] if transitions_per_example is not None else None
                sent2_transitions = transitions_per_example[batch_size:] if transitions_per_example is not None else None
                sent1_trees = tree_strs[:batch_size] if tree_strs is not None else None
                sent2_trees = tree_strs[batch_size:] if tree_strs is not None else None
            else:
                sent1_transitions = transitions_per_example if transitions_per_example is not None else None
                sent2_transitions = None
                sent1_trees = tree_strs if tree_strs is not None else None
                sent2_trees = None
            reporter.save_batch(pred, target, eval_ids, output.data.cpu().numpy(),
                                sent1_transitions, sent2_transitions,
                                sent1_trees, sent2_trees)
        # Print Progress
        progress_bar.step(i + 1, total=total_batches)
    progress_bar.finish()
    if tree_strs is not None:
        logger.Log('Sample: ' + str(tree_strs[0]))
    end = time.time()
    total_time = end - start
    A.add('total_tokens', total_tokens)
    A.add('total_time', total_time)
    eval_stats(model, A, eval_log)  # get the eval statistics (e.g. average F1)
    eval_log.filename = filename
    if FLAGS.write_eval_report:
        eval_report_path = os.path.join(
            FLAGS.log_path,
            FLAGS.experiment_name + ".eval_set_" + str(eval_index) + ".report")
        reporter.write_report(eval_report_path)
    eval_class_acc = eval_log.eval_class_accuracy
    eval_trans_acc = eval_log.eval_transition_accuracy
    eval_f1 = eval_log.f1
    return eval_class_acc, eval_trans_acc, eval_f1
def get_tag_acc_from_fpath(fpath_dev, fpath_pred, tag_dev_freq, tag_pred_freq, prpn=True): with open(fpath_pred, 'r') as fr_pred: with open(fpath_dev, 'r') as fr_dev: lines_pred = fr_pred.readlines() lines_dev = fr_dev.readlines() i = 0 j = 0 s = 0 x = 0 while i < len(lines_dev) and j < len(lines_pred): # read dev example example_dev = json.loads(lines_dev[i]) key_dev = example_dev['pairID'] # read prediction example example_pred = json.loads(lines_pred[j]) if not prpn: key_pred = example_pred['example_id'] else: key_pred = example_pred['pairID'] if key_dev == key_pred: tree1_dev = nltk.Tree.fromstring( example_dev['sentence1_parse']) tree2_dev = nltk.Tree.fromstring( example_dev['sentence2_parse']) tag_brackets1_dev = get_tag_brackets(tree1_dev) tag_brackets2_dev = get_tag_brackets(tree2_dev) if not prpn: brackets1_pred, len1 = get_brackets( example_pred['sent1_tree']) brackets2_pred, len2 = get_brackets( example_pred['sent2_tree']) else: brackets1_pred, len1 = get_brackets( example_pred['sentence1_prpn_binary_parse']) brackets2_pred, len2 = get_brackets( example_pred['sentence2_prpn_binary_parse']) brackets1_pred.add((0, len1)) brackets2_pred.add((0, len2)) for tag, bracket in tag_brackets1_dev: if tag not in tag_dev_freq: tag_dev_freq[tag] = 1 else: tag_dev_freq[tag] += 1 if bracket in brackets1_pred: if tag not in tag_pred_freq: tag_pred_freq[tag] = 1 else: tag_pred_freq[tag] += 1 for tag, bracket in tag_brackets2_dev: if tag not in tag_dev_freq: tag_dev_freq[tag] = 1 else: tag_dev_freq[tag] += 1 if bracket in brackets2_pred: if tag not in tag_pred_freq: tag_pred_freq[tag] = 1 else: tag_pred_freq[tag] += 1 i += 1 j += 1 s += 1 else: i += 1 print '\t{}: {}/{}|{}'.format( fpath_pred.split('/')[-1], len(lines_dev), len(lines_pred), s)
def get_f1_against_right_branching_from_list(tree_list):
    """Bracketing F1 of `tree_list` against a right-branching tree of the
    same length; both sides include the trivial whole-sentence span."""
    tree_brackets, n = get_brackets(tree_list)
    tree_brackets.add((0, n))
    # A right-branching tree over n leaves brackets (0,n), (1,n), ..., (n-2,n).
    rb_brackets = set()
    for start in range(n - 1):
        rb_brackets.add((start, n))
    overlap = tree_brackets & rb_brackets
    return compute_f1(overlap, tree_brackets, rb_brackets)
def compute_f1_baseline(path):
    """Compute baseline bracketing F1 scores over a JSONL dataset (for RL).

    For every usable example (gold label present in LABEL_MAP), scores each
    sentence's gold binary parse against four baselines:
      * a right-branching tree,
      * a left-branching tree,
      * the PRPN binary parse stored in the file,
      * a tree rebuilt from the PRPN gate values ("by definition").

    Returns:
        (rb_f1, lb_f1, prpn_f1, prpn_f1_df): the four mean F1 scores.
    """
    rb_f1_list = []
    lb_f1_list = []
    prpn_f1_list = []
    prpn_f1_df_list = []
    with codecs.open(path, encoding='utf-8') as f:
        for line in f:
            try:
                line = line.encode('UTF-8')
            except UnicodeError as e:
                print "ENCODING ERROR:", line, e
                line = "{}"
            loaded_example = json.loads(line)
            if loaded_example["gold_label"] not in LABEL_MAP:
                # 158 here
                continue
            prpn_gates1 = loaded_example['sentence1_prpn_gates']
            prpn_gates2 = loaded_example['sentence2_prpn_gates']
            # Trees rebuilt from the PRPN gate values; the first gate is
            # dropped before building.
            prpn_df_tree1 = get_brackets(
                build_tree_by_definition(prpn_gates1[1:],
                                         loaded_example['sentence1']))[0]
            prpn_df_tree2 = get_brackets(
                build_tree_by_definition(prpn_gates2[1:],
                                         loaded_example['sentence2']))[0]
            std_tree1 = get_brackets(
                loaded_example['sentence1_binary_parse'])[0]
            prpn_tree1 = get_brackets(
                loaded_example['sentence1_prpn_binary_parse'])[0]
            std_tree2 = get_brackets(
                loaded_example['sentence2_binary_parse'])[0]
            prpn_tree2 = get_brackets(
                loaded_example['sentence2_prpn_binary_parse'])[0]
            len1 = len(loaded_example['sentence1'])
            # Sentences shorter than 3 words have no non-trivial brackets.
            # NOTE(review): the range endpoints below omit certain spans
            # (e.g. (0, len-1) for left-branching and (len-2, len) for
            # right-branching); can't confirm from this file whether that
            # matches get_brackets' bracket conventions -- verify against
            # get_brackets before changing.
            if len1 < 3:
                lb_tree1 = set()
                rb_tree1 = set()
            else:
                lb_tree1 = {(0, i) for i in range(2, len1 - 1)}
                rb_tree1 = {(i, len1) for i in range(1, len1 - 2)}
            len2 = len(loaded_example['sentence2'])
            if len2 < 3:
                lb_tree2 = set()
                rb_tree2 = set()
            else:
                lb_tree2 = {(0, i) for i in range(2, len2 - 1)}
                rb_tree2 = {(i, len2) for i in range(1, len2 - 2)}
            rb_f1_list.append(
                compute_f1(rb_tree1 & std_tree1, std_tree1, rb_tree1))
            rb_f1_list.append(
                compute_f1(rb_tree2 & std_tree2, std_tree2, rb_tree2))
            lb_f1_list.append(
                compute_f1(lb_tree1 & std_tree1, std_tree1, lb_tree1))
            lb_f1_list.append(
                compute_f1(lb_tree2 & std_tree2, std_tree2, lb_tree2))
            prpn_f1_list.append(
                compute_f1(prpn_tree1 & std_tree1, std_tree1, prpn_tree1))
            prpn_f1_list.append(
                compute_f1(prpn_tree2 & std_tree2, std_tree2, prpn_tree2))
            prpn_f1_df_list.append(
                compute_f1(prpn_df_tree1 & std_tree1, std_tree1,
                           prpn_df_tree1))
            prpn_f1_df_list.append(
                compute_f1(prpn_df_tree2 & std_tree2, std_tree2,
                           prpn_df_tree2))
    rb_f1 = sum(rb_f1_list) / len(rb_f1_list)
    lb_f1 = sum(lb_f1_list) / len(lb_f1_list)
    prpn_f1 = sum(prpn_f1_list) / len(prpn_f1_list)
    prpn_f1_df = sum(prpn_f1_df_list) / len(prpn_f1_df_list)
    return rb_f1, lb_f1, prpn_f1, prpn_f1_df