Ejemplo n.º 1
0
def compute_f1_from_list(l1, l2):
    """Return the bracket F1 between two trees over the same sentence.

    Each tree is converted to a bracket set with ``get_brackets``, and the
    whole-sentence span ``(0, length)`` is added to both sets before
    scoring. Because the lengths must match, that span always lands in
    the overlap.

    :param l1: first tree, in the representation accepted by ``get_brackets``.
    :param l2: second tree, same representation; must have the same length.
    :returns: ``compute_f1(overlap, brackets_of_l1, brackets_of_l2)``.
    """
    brackets_a, n_a = get_brackets(l1)
    brackets_b, n_b = get_brackets(l2)
    assert n_a == n_b
    for brackets, n in ((brackets_a, n_a), (brackets_b, n_b)):
        brackets.add((0, n))
    shared = brackets_a & brackets_b
    return compute_f1(shared, brackets_a, brackets_b)
Ejemplo n.º 2
0
def verify_f1(path):
    """Score each gold parse in a JSON-lines file against a balanced tree.

    Each line of *path* is a JSON object with ``sentence1_parse`` and
    ``sentence2_parse`` strings readable by ``Tree.fromstring``. For every
    sentence, the bracket set of the gold parse is compared with that of
    ``get_balanced_tree`` over the same number of leaves, using
    ``compute_f1``. Lines that fail UTF-8 re-encoding are reported and
    skipped.

    :param path: path to a UTF-8 JSON-lines file.
    :returns: ``(mean_f1, count)``, where *count* is the number of sentences
        scored (two per line).
    :raises ZeroDivisionError: if no sentence was scored.
    """
    f1_list = []
    with codecs.open(path, encoding='utf-8') as f:
        for line in f:
            try:
                line = line.encode('UTF-8')
            except UnicodeError as e:
                print "ENCODING ERROR:", line, e
                # Skip the line. Substituting "{}" would parse to an empty
                # dict and crash with KeyError on the first lookup below.
                continue
            loaded_example = json.loads(line)
            for parse_key in ('sentence1_parse', 'sentence2_parse'):
                tree = Tree.fromstring(loaded_example[parse_key])
                n_leaves = len(tree.leaves())
                gold_brackets = get_brackets(tree2list(tree))[0]
                # NOTE(review): get_balanced_tree receives the leaf *count*
                # here; confirm it accepts a length rather than a word list.
                balanced_brackets = get_brackets(
                    get_balanced_tree(n_leaves))[0]
                # The whole-sentence span (0, n_leaves) is not added to either
                # set, so the root bracket does not contribute to the score.
                f1_list.append(
                    compute_f1(gold_brackets & balanced_brackets,
                               gold_brackets, balanced_brackets))

    # float() avoids Python 2 integer division if compute_f1 returns ints.
    return float(sum(f1_list)) / len(f1_list), len(f1_list)
Ejemplo n.º 3
0
                    j += 1
                    skip = False
                    continue

                # read prediction example
                example_pred = json.loads(lines_pred[j])
                key_pred = list2words(example_pred['sent1_tree']) + ' || ' + list2words(example_pred['sent2_tree'])
                key_pred = key_pred.lower()

                if check_key(key_dev, key_pred):
                    tree1_dev = nltk.Tree.fromstring(example_dev['sentence1_parse'])
                    tree2_dev = nltk.Tree.fromstring(example_dev['sentence2_parse'])
                    tag_brackets1_dev = get_tag_brackets(tree1_dev)
                    tag_brackets2_dev = get_tag_brackets(tree2_dev)

                    brackets1_pred, len1 = get_brackets(example_pred['sent1_tree'])
                    brackets2_pred, len2 = get_brackets(example_pred['sent2_tree'])
                    brackets1_pred.add((0,len1))
                    brackets2_pred.add((0,len2))

                    for tag, bracket in tag_brackets1_dev:
                        if tag not in tag_dev_freq:
                            tag_dev_freq[tag] = 1
                        else:
                            tag_dev_freq[tag] += 1                           
                        if bracket in brackets1_pred:
                            if tag not in tag_pred_freq:
                                tag_pred_freq[tag] = 1
                            else:
                                tag_pred_freq[tag] += 1     
def generate_trivial_tree_dataset_debug(read_file_path,
                                        write_file_path,
                                        trivial_tree='balanced'):
    if trivial_tree == 'balanced':
        get_trivial_tree = get_balanced_tree
    elif trivial_tree == 'left_branching':
        get_trivial_tree = get_left_branching_tree
    elif trivial_tree == 'right_branching':
        get_trivial_tree = get_right_branching_tree
    else:
        raise ValueError('invalid trivial tree form!')
    print '****** generating {} tree ******'.format(trivial_tree)

    f1_list = []

    with codecs.open(read_file_path, encoding='utf-8') as f:
        for line in f:
            try:
                line = line.encode('UTF-8')
            except UnicodeError as e:
                print "ENCODING ERROR:", line, e
                line = "{}"
            loaded_example = json.loads(line)
            write_example = {}

            write_example['gold_label'] = loaded_example['gold_label']
            if 'genre' in loaded_example:
                write_example['genre'] = loaded_example['genre']
            if 'promptID' in loaded_example:
                write_example['promptID'] = loaded_example['promptID']

            t1 = Tree.fromstring(loaded_example['sentence1_parse'])
            t2 = Tree.fromstring(loaded_example['sentence2_parse'])
            words1 = filter_words(t1)
            words2 = filter_words(t2)

            if len(words1) < 1 or len(words2) < 1:
                continue

            trivial_t1 = get_trivial_tree(words1)
            trivial_t2 = get_trivial_tree(words2)

            write_example['sentence1_prpn_binary_parse'] = trivial_t1
            write_example['sentence2_prpn_binary_parse'] = trivial_t2
            write_example['sentence1_binary_parse'] = tree2list(t1)
            write_example['sentence2_binary_parse'] = tree2list(t2)
            write_example['sentence1'] = words1
            write_example['sentence2'] = words2

            t1_brackets, l1 = get_brackets(tree2list(t1))
            t2_brackets, l2 = get_brackets(tree2list(t2))
            trivial_t1_brackets, trivial_l1 = get_brackets(trivial_t1)
            trivial_t2_brackets, trivial_l2 = get_brackets(trivial_t2)
            assert l1 == len(words1)
            assert l2 == len(words2)
            assert l1 == trivial_l1
            assert l2 == trivial_l2

            t1_brackets.add((0, l1))
            trivial_t1_brackets.add((0, l1))
            t2_brackets.add((0, l2))
            trivial_t2_brackets.add((0, l2))

            f1 = compute_f1(t1_brackets & trivial_t1_brackets, t1_brackets,
                            trivial_t1_brackets)
            f1_list.append(f1)
            f1 = compute_f1(t2_brackets & trivial_t2_brackets, t2_brackets,
                            trivial_t2_brackets)
            f1_list.append(f1)

    return sum(f1_list) / len(f1_list), len(f1_list)
Ejemplo n.º 5
0
def evaluate(FLAGS,
             model,
             data_manager,
             eval_set,
             log_entry,
             logger,
             step,
             vocabulary=None,
             show_sample=False,
             eval_index=0):
    """Run ``model`` over one evaluation set; record accuracy and parse F1.

    For every batch this does the following:
      - runs the model (after switching it to eval mode);
      - accumulates class-prediction counts;
      - computes an unlabeled bracket F1 between the trees from
        ``model.get_samples`` and the reference brackets in ``silver_tree``.

    A new entry is appended to ``log_entry.evaluation``. ``eval_stats`` is
    then called with the accumulator and that entry, presumably to fill in
    its statistics. When ``FLAGS.write_eval_report`` is set, a per-example
    report is also written under ``FLAGS.log_path``.

    Args:
        FLAGS: run configuration. Fields read here include model_type,
            show_progress_bar, use_internal_parser, validate_transitions,
            write_eval_report, eval_report_use_preds, log_path,
            experiment_name and the pyramid_temperature_* settings.
        model: the model under evaluation.
        data_manager: passed through to ``eval_accumulate``.
        eval_set: ``(filename, dataset)``; ``dataset`` is iterated as batches.
        log_entry: log record that receives a new ``evaluation`` entry.
        logger: used to log one sample tree if any were collected.
        step: current training step (used for the Pyramid temperature).
        vocabulary: passed to ``model.get_samples``.
        show_sample: whether to collect sample trees for display/report.
        eval_index: index embedded in the eval report filename.

    Returns:
        ``(eval_class_accuracy, eval_transition_accuracy, f1)`` read back
        from the new evaluation log entry.
    """
    filename, dataset = eval_set

    A = Accumulator()
    # NOTE(review): `index` is computed but never used below.
    index = len(log_entry.evaluation)
    # Fresh evaluation record; its filename is set after the loop.
    eval_log = log_entry.evaluation.add()
    reporter = EvalReporter()
    # Most recently collected sample trees (only set when show_sample and
    # the model can sample); survives the loop so one can be logged.
    tree_strs = None

    # Evaluate
    total_batches = len(dataset)
    progress_bar = SimpleProgressBar(msg="Run Eval",
                                     bar_length=60,
                                     enabled=FLAGS.show_progress_bar)
    progress_bar.step(0, total=total_batches)
    total_tokens = 0
    start = time.time()

    # Pyramid models take a temperature multiplier: exponential decay per
    # 10k steps, optionally scaled by a cosine cycle whose factor lies in
    # [min_temp / 2, 1 + min_temp / 2].
    if FLAGS.model_type in ["Pyramid", "ChoiPyramid"]:
        pyramid_temperature_multiplier = FLAGS.pyramid_temperature_decay_per_10k_steps**(
            step / 10000.0)
        if FLAGS.pyramid_temperature_cycle_length > 0.0:
            min_temp = 1e-5
            pyramid_temperature_multiplier *= (math.cos(
                (step) / FLAGS.pyramid_temperature_cycle_length) + 1 +
                                               min_temp) / 2
    else:
        pyramid_temperature_multiplier = None

    model.eval()

    for i, dataset_batch in enumerate(dataset):
        batch = get_batch(dataset_batch)
        eval_X_batch, eval_transitions_batch, eval_y_batch, eval_num_transitions_batch, eval_ids, _, silver_tree = batch
        # eval_X_batch: <batch x maxlen x 2>
        # eval_y_batch: <batch >
        # silver_tree: reference brackets as "start,end" strings, indexed
        #   below as silver_tree[example, slot, sentence]; "-1,-1" marks
        #   unused slots.
        # the dist is invalid for val

        # Run model.
        output = model(
            eval_X_batch,
            eval_transitions_batch,
            eval_y_batch,
            use_internal_parser=FLAGS.use_internal_parser,
            validate_transitions=FLAGS.validate_transitions,
            pyramid_temperature_multiplier=pyramid_temperature_multiplier,
            store_parse_masks=True,
            example_lengths=eval_num_transitions_batch)

        # TODO: Restore support in Pyramid if using.
        can_sample = FLAGS.model_type in [
            "ChoiPyramid"
        ] or (FLAGS.model_type == "SPINN" and FLAGS.use_internal_parser)
        if show_sample and can_sample:
            tmp_samples = model.get_samples(
                eval_X_batch, vocabulary, only_one=not FLAGS.write_eval_report)
            # tree_strs = prettyprint_trees(tmp_samples)
            tree_strs = [tree for tree in tmp_samples]

        # Trees for every example (only_one=False), used by the F1 below.
        # When a display sample was taken above, this is a second
        # get_samples call on the same batch.
        tmp_samples = model.get_samples(eval_X_batch,
                                        vocabulary,
                                        only_one=False)

        # def get_max(s):
        #     # test f1
        #     max = 0
        #     for x in s:
        #         _, idx = x.split(',')
        #         if int(idx) > max:
        #             max = int(idx)
        #     return max

        # Unlabeled bracket F1 per sentence. tmp_samples holds all
        # sentence-1 trees of the batch, then all sentence-2 trees (see the
        # index expression below).
        for s in (range(int(model.use_sentence_pair) + 1)):
            for b in (range(silver_tree.shape[0])):
                model_out = tmp_samples[s * silver_tree.shape[0] + b]
                std_out = silver_tree[b, :, s]
                std_out = set([x for x in std_out if x != '-1,-1'])
                model_out_brackets, model_out_max_l = get_brackets(model_out)
                model_out = set(convert_brackets_to_string(model_out_brackets))

                # Count the whole-sentence span on both sides, using the
                # length that get_brackets reports for the predicted tree.
                outmost_bracket = '{:d},{:d}'.format(0, model_out_max_l)
                std_out.add(outmost_bracket)
                model_out.add(outmost_bracket)

                # print get_max(model_out), get_max(std_out)
                # print model_out
                # print std_out
                # print '=' * 30
                # assert get_max(model_out) == get_max(std_out)

                overlap = model_out & std_out
                # Small epsilons keep the divisions finite.
                prec = float(len(overlap)) / (len(model_out) + 1e-8)
                reca = float(len(overlap)) / (len(std_out) + 1e-8)
                # NOTE(review): std_out always contains outmost_bracket at
                # this point, so this branch is unreachable.
                if len(std_out) == 0:
                    reca = 1.
                    if len(model_out) == 0:
                        prec = 1.
                f1 = 2 * prec * reca / (prec + reca + 1e-8)
                A.add('f1', f1)

        if not FLAGS.write_eval_report:
            # Only show one sample, regardless of the number of batches.
            show_sample = False

        # Normalize output.
        # NOTE(review): no explicit `dim`; PyTorch infers one and emits a
        # deprecation warning. Presumably dim=1 for (batch, classes) output.
        logits = F.log_softmax(output)

        # Calculate class accuracy.
        target = torch.from_numpy(eval_y_batch).long()

        # get the index of the max log-probability
        pred = logits.data.max(1, keepdim=False)[1].cpu()

        eval_accumulate(model, data_manager, A, batch)
        A.add('class_correct', pred.eq(target).sum())
        A.add('class_total', target.size(0))

        # Optionally calculate transition loss/acc.
        #model.transition_loss if hasattr(model, 'transition_loss') else None
        # TODO: review this. the original line seems to have no effect

        # Update Aggregate Accuracies
        # Presumably nt = 2n - 1 shift/reduce transitions for n tokens, so
        # (nt + 1) / 2 recovers the token count.
        total_tokens += sum([(nt + 1) / 2
                             for nt in eval_num_transitions_batch.reshape(-1)])

        if FLAGS.write_eval_report:
            # Per-example transitions exist only for SPINN with the internal
            # parser; otherwise both values are None.
            transitions_per_example, _ = model.spinn.get_transitions_per_example(
                style="preds" if FLAGS.eval_report_use_preds else "given") if (
                    FLAGS.model_type == "SPINN"
                    and FLAGS.use_internal_parser) else (None, None)

            # For sentence pairs, the first batch_size rows belong to
            # sentence 1 and the remaining rows to sentence 2.
            if model.use_sentence_pair:
                batch_size = pred.size(0)
                sent1_transitions = transitions_per_example[:
                                                            batch_size] if transitions_per_example is not None else None
                sent2_transitions = transitions_per_example[
                    batch_size:] if transitions_per_example is not None else None

                sent1_trees = tree_strs[:
                                        batch_size] if tree_strs is not None else None
                sent2_trees = tree_strs[
                    batch_size:] if tree_strs is not None else None
            else:
                sent1_transitions = transitions_per_example if transitions_per_example is not None else None
                sent2_transitions = None

                sent1_trees = tree_strs if tree_strs is not None else None
                sent2_trees = None

            reporter.save_batch(pred, target, eval_ids,
                                output.data.cpu().numpy(), sent1_transitions,
                                sent2_transitions, sent1_trees, sent2_trees)

        # Print Progress
        progress_bar.step(i + 1, total=total_batches)
    progress_bar.finish()
    if tree_strs is not None:
        logger.Log('Sample: ' + str(tree_strs[0]))

    end = time.time()
    total_time = end - start

    A.add('total_tokens', total_tokens)
    A.add('total_time', total_time)

    eval_stats(model, A, eval_log)  # get the eval statistics (e.g. average F1)
    eval_log.filename = filename

    if FLAGS.write_eval_report:
        eval_report_path = os.path.join(
            FLAGS.log_path,
            FLAGS.experiment_name + ".eval_set_" + str(eval_index) + ".report")
        reporter.write_report(eval_report_path)

    eval_class_acc = eval_log.eval_class_accuracy
    eval_trans_acc = eval_log.eval_transition_accuracy
    eval_f1 = eval_log.f1

    return eval_class_acc, eval_trans_acc, eval_f1
Ejemplo n.º 6
0
def get_tag_acc_from_fpath(fpath_dev,
                           fpath_pred,
                           tag_dev_freq,
                           tag_pred_freq,
                           prpn=True):
    with open(fpath_pred, 'r') as fr_pred:
        with open(fpath_dev, 'r') as fr_dev:
            lines_pred = fr_pred.readlines()
            lines_dev = fr_dev.readlines()
            i = 0
            j = 0
            s = 0
            x = 0
            while i < len(lines_dev) and j < len(lines_pred):
                # read dev example
                example_dev = json.loads(lines_dev[i])
                key_dev = example_dev['pairID']

                # read prediction example
                example_pred = json.loads(lines_pred[j])
                if not prpn:
                    key_pred = example_pred['example_id']
                else:
                    key_pred = example_pred['pairID']

                if key_dev == key_pred:
                    tree1_dev = nltk.Tree.fromstring(
                        example_dev['sentence1_parse'])
                    tree2_dev = nltk.Tree.fromstring(
                        example_dev['sentence2_parse'])
                    tag_brackets1_dev = get_tag_brackets(tree1_dev)
                    tag_brackets2_dev = get_tag_brackets(tree2_dev)

                    if not prpn:
                        brackets1_pred, len1 = get_brackets(
                            example_pred['sent1_tree'])
                        brackets2_pred, len2 = get_brackets(
                            example_pred['sent2_tree'])
                    else:
                        brackets1_pred, len1 = get_brackets(
                            example_pred['sentence1_prpn_binary_parse'])
                        brackets2_pred, len2 = get_brackets(
                            example_pred['sentence2_prpn_binary_parse'])

                    brackets1_pred.add((0, len1))
                    brackets2_pred.add((0, len2))

                    for tag, bracket in tag_brackets1_dev:
                        if tag not in tag_dev_freq:
                            tag_dev_freq[tag] = 1
                        else:
                            tag_dev_freq[tag] += 1
                        if bracket in brackets1_pred:
                            if tag not in tag_pred_freq:
                                tag_pred_freq[tag] = 1
                            else:
                                tag_pred_freq[tag] += 1

                    for tag, bracket in tag_brackets2_dev:
                        if tag not in tag_dev_freq:
                            tag_dev_freq[tag] = 1
                        else:
                            tag_dev_freq[tag] += 1
                        if bracket in brackets2_pred:
                            if tag not in tag_pred_freq:
                                tag_pred_freq[tag] = 1
                            else:
                                tag_pred_freq[tag] += 1
                    i += 1
                    j += 1
                    s += 1
                else:
                    i += 1
            print '\t{}: {}/{}|{}'.format(
                fpath_pred.split('/')[-1], len(lines_dev), len(lines_pred), s)
Ejemplo n.º 7
0
def get_f1_against_right_branching_from_list(tree_list):
    """Score a tree against the right-branching tree over the same leaves.

    ``get_brackets`` yields the tree's bracket set and its length ``n``;
    the whole-sentence span ``(0, n)`` is added to that set. The reference
    contains every suffix span ``(start, n)`` for ``start`` in
    ``0 .. n - 2``, which includes ``(0, n)`` whenever ``n >= 2``.

    :param tree_list: tree in the representation accepted by ``get_brackets``.
    :returns: ``compute_f1(overlap, tree_brackets, reference_brackets)``.
    """
    brackets, n_leaves = get_brackets(tree_list)
    brackets.add((0, n_leaves))
    reference = set()
    for start in range(n_leaves - 1):
        reference.add((start, n_leaves))
    return compute_f1(brackets & reference, brackets, reference)
Ejemplo n.º 8
0
def _trivial_baseline_brackets(length):
    """Return ``(left_branching, right_branching)`` baseline bracket sets.

    Sentences shorter than three words get two empty sets.
    """
    if length < 3:
        return set(), set()
    # NOTE(review): for n words these ranges omit (0, n - 1) from the
    # left-branching set and (n - 2, n) from the right-branching set, and
    # for n == 3 both sets are empty. If get_brackets drops only the root
    # span, the complete non-root sets would use range(2, n) and
    # range(1, n - 1). Left unchanged so reported baselines stay
    # comparable; confirm the intent.
    left_branching = {(0, i) for i in range(2, length - 1)}
    right_branching = {(i, length) for i in range(1, length - 2)}
    return left_branching, right_branching


def compute_f1_baseline(path):
    '''
    for RL

    Mean bracket F1 of baseline and PRPN parses against the gold binary parse.

    Each line of *path* is a JSON object. Lines whose ``gold_label`` is not
    in ``LABEL_MAP`` are skipped, as are lines that fail UTF-8 re-encoding.
    For each of ``sentence1`` and ``sentence2``, four bracket sets are
    scored against ``<sent>_binary_parse`` with ``compute_f1``:
      - right-branching and left-branching baselines;
      - ``<sent>_prpn_binary_parse``;
      - a tree built by ``build_tree_by_definition`` from
        ``<sent>_prpn_gates``.

    :param path: path to a UTF-8 JSON-lines file.
    :returns: ``(rb_f1, lb_f1, prpn_f1, prpn_f1_df)`` means over all
        scored sentences.
    :raises ZeroDivisionError: if no example was scored.
    '''
    rb_f1_list = []
    lb_f1_list = []
    prpn_f1_list = []
    prpn_f1_df_list = []
    with codecs.open(path, encoding='utf-8') as f:
        for line in f:
            try:
                line = line.encode('UTF-8')
            except UnicodeError as e:
                print "ENCODING ERROR:", line, e
                # Skip the line. Substituting "{}" would parse to an empty
                # dict and crash with KeyError on "gold_label" below.
                continue
            loaded_example = json.loads(line)
            if loaded_example["gold_label"] not in LABEL_MAP:
                # 158 here (presumably the number of such examples the
                # author saw in their data)
                continue

            for sent in ('sentence1', 'sentence2'):
                gates = loaded_example[sent + '_prpn_gates']
                # The first gate is dropped before building the tree.
                # NOTE(review): presumably it belongs to a start token.
                prpn_df_tree = get_brackets(
                    build_tree_by_definition(gates[1:],
                                             loaded_example[sent]))[0]
                std_tree = get_brackets(
                    loaded_example[sent + '_binary_parse'])[0]
                prpn_tree = get_brackets(
                    loaded_example[sent + '_prpn_binary_parse'])[0]
                lb_tree, rb_tree = _trivial_baseline_brackets(
                    len(loaded_example[sent]))

                rb_f1_list.append(
                    compute_f1(rb_tree & std_tree, std_tree, rb_tree))
                lb_f1_list.append(
                    compute_f1(lb_tree & std_tree, std_tree, lb_tree))
                prpn_f1_list.append(
                    compute_f1(prpn_tree & std_tree, std_tree, prpn_tree))
                prpn_f1_df_list.append(
                    compute_f1(prpn_df_tree & std_tree, std_tree,
                               prpn_df_tree))

    def _mean(values):
        # float() avoids Python 2 integer division if compute_f1 returns ints.
        return float(sum(values)) / len(values)

    return (_mean(rb_f1_list), _mean(lb_f1_list), _mean(prpn_f1_list),
            _mean(prpn_f1_df_list))