def generate_all_potential_parses_for_sentence(self,
                                                   tagged_sentence,
                                                   predicted_tags,
                                                   min_probability=0.1):
        """Enumerate candidate parses for a single tagged sentence.

        Extracts the positional relation-tag sequence for the sentence and
        recursively explores shift-reduce action sequences via
        ``self.recursively_parse``, pruning branches below *min_probability*.

        Args:
            tagged_sentence: sequence of (word, tags) pairs.
            predicted_tags: per-word predicted tag collections.
            min_probability: probability threshold passed to the recursive
                parse to prune unlikely action sequences.

        Returns:
            List of candidate parses; empty list when no relation tags were
            predicted or fewer than two concept codes are present.
        """
        pos_ptag_seq, _, tag2span, all_predicted_rtags, _ = self.get_tags_relations_for(
            tagged_sentence, predicted_tags, self.cr_tags)

        if len(all_predicted_rtags) == 0:
            return []

        # tags without positional info
        rtag_seq = [t for t, i in pos_ptag_seq if t[0].isdigit()]
        # if not at least 2 concept codes, then can't parse
        if len(rtag_seq) < 2:
            return []

        words = [wd for wd, tags in tagged_sentence]

        # Initialize stack, basic parser and oracle
        parser = ShiftReduceParser(Stack(verbose=False))
        # ROOT entry needs to be a (tag, index) tuple like real stack items
        parser.stack.push((ROOT, 0))
        oracle = Oracle([], parser)

        # Map each positional tag to the ngrams of the word span it covers.
        # Kept as a defaultdict so downstream lookups of unseen tags yield [].
        tag2words = defaultdict(list)
        for tag_pair in pos_ptag_seq:  # enumerate() removed: index was unused
            bstart, bstop = tag2span[tag_pair]
            tag2words[tag_pair] = self.ngram_extractor.extract(
                words[bstart:bstop + 1])  # type: List[str]

        all_parses = self.recursively_parse(defaultdict(set), defaultdict(set),
                                            oracle, pos_ptag_seq, tag2span,
                                            tag2words, 0, words,
                                            defaultdict(list), min_probability)
        return all_parses
    def generate_training_data(self,
                               tagged_sentence,
                               predicted_tags,
                               out_parse_examples,
                               out_crel_examples,
                               predict_only=False):
        """Shift-reduce parse one sentence, collecting parser/crel training examples.

        Uses a DAgger-style schedule: with probability self.beta the latest
        trained model picks each action, otherwise the oracle (or the previous
        model) does.  Gold actions and per-action costs are recorded into the
        two accumulators unless *predict_only* is set.

        Args:
            tagged_sentence: sequence of (word, tags) pairs.
            predicted_tags: per-word predicted tag collections.
            out_parse_examples: accumulator with .add(feats, gold_action, costs).
            out_crel_examples: accumulator with .add(feats, gold_lr_action).
            predict_only: when True, gold labels are cleared and the latest
                trained models are always consulted.

        Returns:
            Set of predicted denormalized "Causer->Result" relation strings.
            NOTE(review): early-exit paths return [] (a list), not a set —
            callers appear to only check emptiness, but confirm.
        """
        pos_ptag_seq, pos_ground_truth_crels, tag2span, all_predicted_rtags, all_actual_crels = self.get_tags_relations_for(
            tagged_sentence, predicted_tags, self.cr_tags)
        if predict_only:
            # clear labels
            pos_ground_truth_crels = []
            all_actual_crels = set()

        if len(all_predicted_rtags) == 0:
            return []

        words = [wd for wd, tags in tagged_sentence]

        # Initialize stack, basic parser and oracle
        stack = Stack(verbose=False)
        # needs to be a tuple
        stack.push((ROOT, 0))
        parser = ShiftReduceParser(stack)
        oracle = Oracle(pos_ground_truth_crels, parser)

        predicted_relations = set()  # type: Set[str]

        # instead of head and modifiers, we will map causers to effects, and vice versa
        effect2causers = defaultdict(set)
        # heads can have multiple modifiers
        cause2effects = defaultdict(set)

        # tags without positional info
        rtag_seq = [t for t, i in pos_ptag_seq if t[0].isdigit()]
        # if not at least 2 concept codes, then can't parse
        if len(rtag_seq) < 2:
            return []

        # Map each positional tag to the ngrams of the word span it covers
        tag2words = defaultdict(list)
        for ix, tag_pair in enumerate(pos_ptag_seq):
            bstart, bstop = tag2span[tag_pair]
            word_seq = words[bstart:bstop + 1]
            tag2words[tag_pair] = self.ngram_extractor.extract(
                word_seq)  # type: List[str]

        # Oracle parsing logic
        # consume the buffer
        for tag_ix, buffer_tag_pair in enumerate(pos_ptag_seq):
            buffer_tag = buffer_tag_pair[0]
            bstart, bstop = tag2span[buffer_tag_pair]

            remaining_buffer_tags = pos_ptag_seq[tag_ix:]
            # Consume the stack
            while True:
                tos_tag_pair = oracle.tos()
                tos_tag = tos_tag_pair[0]

                # Returns -1,-1 if TOS is ROOT
                if tos_tag == ROOT:
                    tstart, tstop = -1, -1
                else:
                    tstart, tstop = tag2span[tos_tag_pair]

                # Note that the end ix in tag2span is always the last index, not the last + 1
                btwn_start, btwn_stop = min(tstop + 1,
                                            len(words)), max(0, bstart)

                # Words strictly between the TOS span and the buffer span
                btwn_word_seq = words[btwn_start:btwn_stop]
                distance = len(btwn_word_seq)
                btwn_word_ngrams = self.ngram_extractor.extract(
                    btwn_word_seq)  # type: List[str]

                feats = self.feat_extractor.extract(
                    stack_tags=stack.contents(),
                    buffer_tags=remaining_buffer_tags,
                    tag2word_seq=tag2words,
                    between_word_seq=btwn_word_ngrams,
                    distance=distance,
                    cause2effects=cause2effects,
                    effect2causers=effect2causers,
                    positive_val=self.positive_val)

                # Consult Oracle or Model based on coin toss
                if predict_only:
                    action = self.predict_parse_action(
                        feats=feats,
                        tos=tos_tag,
                        models=self.parser_models[-1],
                        vectorizer=self.parser_feature_vectorizers[-1])
                else:  # if training
                    gold_action = oracle.consult(tos_tag_pair, buffer_tag_pair)
                    rand_float = np.random.random_sample(
                    )  # between [0,1) (half-open interval, includes 0 but not 1)
                    # If no trained models, always use Oracle
                    if len(self.parser_models) == 0:
                        action = gold_action
                    elif rand_float <= self.beta:
                        # with prob. beta, roll out with the latest model
                        action = self.predict_parse_action(
                            feats=feats,
                            tos=tos_tag,
                            models=self.parser_models[-1],
                            vectorizer=self.parser_feature_vectorizers[-1])
                    else:
                        if len(self.parser_models) < 2:
                            action = gold_action
                        # use previous model if available
                        else:
                            action = self.predict_parse_action(
                                feats=feats,
                                tos=tos_tag,
                                models=self.parser_models[-2],
                                vectorizer=self.parser_feature_vectorizers[-2])

                    # Given the remaining tags, what is the cost of this decision
                    # in terms of the optimal decision(s) that can be made?
                    cost_per_action = self.cost_function(
                        pos_ground_truth_crels, remaining_buffer_tags, oracle)
                    # make a copy as changing later
                    out_parse_examples.add(dict(feats), gold_action,
                                           cost_per_action)

                # Decide the direction of the causal relation
                if action in [LARC, RARC]:

                    c_e_pair = (tos_tag, buffer_tag)
                    # Convert to a string Causer:{l}->Result:{r}
                    cause_effect = denormalize_cr(c_e_pair)

                    e_c_pair = (buffer_tag, tos_tag)
                    # Convert to a string Causer:{l}->Result:{r}
                    effect_cause = denormalize_cr(e_c_pair)

                    if predict_only:
                        gold_lr_action = None
                    else:
                        # Gold direction derived from which orderings occur in
                        # the actual relation set
                        if cause_effect in all_actual_crels and effect_cause in all_actual_crels:
                            gold_lr_action = CAUSE_AND_EFFECT
                        elif cause_effect in all_actual_crels:
                            gold_lr_action = CAUSE_EFFECT
                        elif effect_cause in all_actual_crels:
                            gold_lr_action = EFFECT_CAUSE
                        else:
                            gold_lr_action = REJECT

                    # Add additional features
                    # needs to be before predict below
                    crel_feats = self.crel_features(action, tos_tag,
                                                    buffer_tag)
                    feats.update(crel_feats)
                    rand_float = np.random.random_sample()

                    if predict_only:
                        lr_action = self.predict_crel_action(
                            feats=feats,
                            model=self.crel_models[-1],
                            vectorizer=self.crel_feat_vectorizers[-1])
                    else:
                        # Same beta schedule as the parse action above
                        if len(self.crel_models) == 0:
                            lr_action = gold_lr_action
                        elif rand_float <= self.beta:
                            lr_action = self.predict_crel_action(
                                feats=feats,
                                model=self.crel_models[-1],
                                vectorizer=self.crel_feat_vectorizers[-1])
                        else:
                            if len(self.crel_models) < 2:
                                lr_action = gold_lr_action
                            else:
                                lr_action = self.predict_crel_action(
                                    feats=feats,
                                    model=self.crel_models[-2],
                                    vectorizer=self.crel_feat_vectorizers[-2])

                    # Record the predicted relation(s) and update the
                    # causer/effect maps used as parse features
                    if lr_action == CAUSE_AND_EFFECT:
                        predicted_relations.add(cause_effect)
                        predicted_relations.add(effect_cause)

                        cause2effects[tos_tag_pair].add(buffer_tag_pair)
                        effect2causers[buffer_tag_pair].add(tos_tag_pair)

                        cause2effects[buffer_tag_pair].add(tos_tag_pair)
                        effect2causers[tos_tag_pair].add(buffer_tag_pair)

                    elif lr_action == CAUSE_EFFECT:
                        predicted_relations.add(cause_effect)

                        cause2effects[tos_tag_pair].add(buffer_tag_pair)
                        effect2causers[buffer_tag_pair].add(tos_tag_pair)

                    elif lr_action == EFFECT_CAUSE:
                        predicted_relations.add(effect_cause)

                        cause2effects[buffer_tag_pair].add(tos_tag_pair)
                        effect2causers[tos_tag_pair].add(buffer_tag_pair)

                    elif lr_action == REJECT:
                        pass
                    else:
                        raise Exception("Invalid CREL type")

                    # cost is always 1 for this action (cost of 1 for getting it wrong)
                    #  because getting the wrong direction won't screw up the parse as it doesn't modify the stack
                    if not predict_only:
                        out_crel_examples.add(dict(feats), gold_lr_action)

                        # Not sure we want to condition on the actions of this crel model
                        # action_history.append(lr_action)
                        # action_tag_pair_history.append((lr_action, tos, buffer))

                # end if action in [LARC,RARC]
                # execute() mutates the stack; False means advance the buffer
                if not oracle.execute(action, tos_tag_pair, buffer_tag_pair):
                    break
                if oracle.is_stack_empty():
                    break

        # Validation logic. Break on pass as relations that should be parsed
        # for pcr in all_actual_crels:
        #     l,r = normalize_cr(pcr)
        #     if l in rtag_seq and r in rtag_seq and pcr not in predicted_relations:
        #         pass

        return predicted_relations
# Ejemplo n.º 3 — stray example-separator text left over from source extraction
# 0
 def create_oracle(self):
     """Build a fresh Oracle over an empty gold-relation list."""
     empty_stack = Stack(verbose=False)
     shift_reduce = ShiftReduceParser(empty_stack)
     # the ROOT marker must be a (tag, index) tuple like every stack entry
     shift_reduce.stack.push((ROOT, 0))
     return Oracle([], shift_reduce)
    def predict_sentence(self, tagged_sentence, predicted_tags):
        """Predict causal relations for one sentence using trained models only.

        Runs the shift-reduce parser over the positional relation tags,
        consulting the trained parse-action and crel-direction models at each
        step (no oracle supervision).

        Args:
            tagged_sentence: sequence of (word, tags) pairs.
            predicted_tags: per-word predicted tag collections.

        Returns:
            Set of denormalized "Causer->Result" relation strings.
            NOTE(review): early-exit paths return [] (a list), not a set.
        """
        action_history = []
        action_tag_pair_history = []

        pos_ptag_seq, _, tag2span, all_predicted_rtags, _ = self.get_tags_relations_for(tagged_sentence, predicted_tags, self.cr_tags)

        if len(all_predicted_rtags) == 0:
            return []

        words = [wd for wd, tags in tagged_sentence]

        # Initialize stack, basic parser and oracle
        stack = Stack(verbose=False)
        # needs to be a tuple
        stack.push((ROOT,0))
        parser = Parser(stack)
        oracle = Oracle([], parser)

        predicted_relations = set()

        # tags without positional info
        tag_seq = [t for t,i in pos_ptag_seq]
        rtag_seq = [t for t in tag_seq if t[0].isdigit()]
        # if not at least 2 concept codes, then can't parse
        if len(rtag_seq) < 2:
            return []

        # Oracle parsing logic
        for tag_ix, buffer in enumerate(pos_ptag_seq):
            buffer_tag = buffer[0]
            bstart, bstop = tag2span[buffer]
            buffer_word_seq = words[bstart:bstop + 1]
            buffer_feats = self.feat_extractor.extract(buffer_tag, buffer_word_seq, self.positive_val)
            buffer_feats = self.__prefix_feats_("BUFFER", buffer_feats)

            # Consume the stack against the current buffer item
            while True:
                tos = oracle.tos()
                tos_tag = tos[0]
                if tos_tag == ROOT:
                    tos_feats = {}
                    tstart, tstop = -1,-1
                else:
                    tstart, tstop = tag2span[tos]
                    tos_word_seq = words[tstart:tstop + 1]

                    tos_feats = self.feat_extractor.extract(tos_tag, tos_word_seq, self.positive_val)
                    tos_feats = self.__prefix_feats_("TOS", tos_feats)

                # Words between the TOS span and the buffer span (inclusive slice)
                btwn_start, btwn_stop = min(tstop+1, len(words)-1), max(0, bstart-1)
                btwn_words = words[btwn_start:btwn_stop + 1]
                btwn_feats = self.feat_extractor.extract("BETWEEN", btwn_words, self.positive_val)
                btwn_feats = self.__prefix_feats_("__BTWN__", btwn_feats)

                # Merge conditional (history-based) and interaction features
                feats = self.get_conditional_feats(action_history, action_tag_pair_history, tos_tag, buffer_tag,
                                                   tag_seq[:tag_ix], tag_seq [tag_ix + 1:])
                interaction_feats = self.get_interaction_feats(tos_feats, buffer_feats)
                feats.update(buffer_feats)
                feats.update(tos_feats)
                feats.update(btwn_feats)
                feats.update(interaction_feats)

                # Consult Oracle or Model based on coin toss
                action = self.predict_parse_action(feats, tos_tag)

                action_history.append(action)
                action_tag_pair_history.append((action, tos_tag, buffer_tag))

                # Decide the direction of the causal relation
                if action in [LARC, RARC]:

                    # Both possible orderings as "Causer:{l}->Result:{r}" strings
                    cause_effect = denormalize_cr((tos_tag,    buffer_tag))
                    effect_cause = denormalize_cr((buffer_tag, tos_tag))

                    # Add additional features
                    # needs to be before predict below
                    feats.update(self.crel_features(action, tos_tag, buffer_tag))
                    lr_action = self.predict_crel_action(feats)

                    if lr_action == CAUSE_AND_EFFECT:
                        predicted_relations.add(cause_effect)
                        predicted_relations.add(effect_cause)
                    elif lr_action == CAUSE_EFFECT:
                        predicted_relations.add(cause_effect)
                    elif lr_action == EFFECT_CAUSE:
                        predicted_relations.add(effect_cause)
                    elif lr_action == REJECT:
                        pass
                    else:
                        raise Exception("Invalid CREL type")

                # end if action in [LARC,RARC]
                # execute() mutates the stack; False means advance the buffer
                if not oracle.execute(action, tos, buffer):
                    break
                if oracle.is_stack_empty():
                    break
        # Validation logic. Break on pass as relations that should be parsed
        return predicted_relations
    def generate_training_data(self, tagged_sentence, predicted_tags, parse_examples, crel_examples):
        """Parse one sentence with oracle supervision, collecting training examples.

        Shift-reduce parses the positional relation tags.  With probability
        (1 - self.beta) the trained model chooses each action instead of the
        oracle; gold actions and costs are always recorded into the example
        accumulators.

        Args:
            tagged_sentence: sequence of (word, tags) pairs.
            predicted_tags: per-word predicted tag collections.
            parse_examples: accumulator with .add(feats, gold_action, costs).
            crel_examples: accumulator with .add(feats, gold_lr_action).

        Returns:
            Set of predicted denormalized "Causer->Result" relation strings.
            NOTE(review): early-exit paths return [] (a list), not a set.
        """
        action_history = []
        action_tag_pair_history = []

        pos_ptag_seq, pos_ground_truth, tag2span, all_predicted_rtags, all_actual_crels = self.get_tags_relations_for(tagged_sentence, predicted_tags, self.cr_tags)

        if len(all_predicted_rtags) == 0:
            return []

        words = [wd for wd, tags in tagged_sentence]

        # Initialize stack, basic parser and oracle
        stack = Stack(verbose=False)
        # needs to be a tuple
        stack.push((ROOT,0))
        parser = Parser(stack)
        oracle = Oracle(pos_ground_truth, parser)

        predicted_relations = set()

        # tags without positional info
        tag_seq = [t for t,i in pos_ptag_seq]
        rtag_seq = [t for t in tag_seq if t[0].isdigit()]
        # if not at least 2 concept codes, then can't parse
        if len(rtag_seq) < 2:
            return []

        # Oracle parsing logic
        for tag_ix, buffer in enumerate(pos_ptag_seq):
            buffer_tag = buffer[0]
            bstart, bstop = tag2span[buffer]
            buffer_word_seq = words[bstart:bstop + 1]
            buffer_feats = self.feat_extractor.extract(buffer_tag, buffer_word_seq, self.positive_val)
            buffer_feats = self.__prefix_feats_("BUFFER", buffer_feats)

            # Consume the stack against the current buffer item
            while True:
                tos = oracle.tos()
                tos_tag = tos[0]
                if tos_tag == ROOT:
                    tos_feats = {}
                    tstart, tstop = -1,-1
                else:
                    tstart, tstop = tag2span[tos]
                    tos_word_seq = words[tstart:tstop + 1]

                    tos_feats = self.feat_extractor.extract(tos_tag, tos_word_seq, self.positive_val)
                    tos_feats = self.__prefix_feats_("TOS", tos_feats)

                # Words between the TOS span and the buffer span (inclusive slice)
                btwn_start, btwn_stop = min(tstop+1, len(words)-1), max(0, bstart-1)
                btwn_words = words[btwn_start:btwn_stop + 1]
                btwn_feats = self.feat_extractor.extract("BETWEEN", btwn_words, self.positive_val)
                btwn_feats = self.__prefix_feats_("__BTWN__", btwn_feats)

                # Merge conditional (history-based) and interaction features
                feats = self.get_conditional_feats(action_history, action_tag_pair_history, tos_tag, buffer_tag,
                                                   tag_seq[:tag_ix], tag_seq [tag_ix + 1:])
                interaction_feats = self.get_interaction_feats(tos_feats, buffer_feats)
                feats.update(buffer_feats)
                feats.update(tos_feats)
                feats.update(btwn_feats)
                feats.update(interaction_feats)

                gold_action = oracle.consult(tos, buffer)

                # Consult Oracle or Model based on coin toss
                rand_float = np.random.random_sample()  # between [0,1) (half-open interval, includes 0 but not 1)
                # If no trained models, always use Oracle
                if rand_float >= self.beta and len(self.parser_models) > 0:
                    action = self.predict_parse_action(feats, tos_tag)
                else:
                    action = gold_action

                action_history.append(action)
                action_tag_pair_history.append((action, tos_tag, buffer_tag))

                # Cost of this decision relative to the optimal one(s)
                cost_per_action = self.compute_cost(pos_ground_truth, pos_ptag_seq[tag_ix:], oracle)
                # make a copy as changing later
                parse_examples.add(dict(feats), gold_action, cost_per_action)

                # Decide the direction of the causal relation
                if action in [LARC, RARC]:

                    # Both possible orderings as "Causer:{l}->Result:{r}" strings
                    cause_effect = denormalize_cr((tos_tag,    buffer_tag))
                    effect_cause = denormalize_cr((buffer_tag, tos_tag))

                    # Gold direction derived from which orderings occur in the
                    # actual relation set
                    if cause_effect in all_actual_crels and effect_cause in all_actual_crels:
                        gold_lr_action = CAUSE_AND_EFFECT
                    elif cause_effect in all_actual_crels:
                        gold_lr_action = CAUSE_EFFECT
                    elif effect_cause in all_actual_crels:
                        gold_lr_action = EFFECT_CAUSE
                    else:
                        gold_lr_action = REJECT

                    # Add additional features
                    # needs to be before predict below
                    feats.update(self.crel_features(action, tos_tag, buffer_tag))
                    rand_float = np.random.random_sample()
                    if rand_float >= self.beta and len(self.crel_models) > 0:
                        lr_action = self.predict_crel_action(feats)
                    else:
                        lr_action = gold_lr_action

                    if lr_action == CAUSE_AND_EFFECT:
                        predicted_relations.add(cause_effect)
                        predicted_relations.add(effect_cause)
                    elif lr_action == CAUSE_EFFECT:
                        predicted_relations.add(cause_effect)
                    elif lr_action == EFFECT_CAUSE:
                        predicted_relations.add(effect_cause)
                    elif lr_action == REJECT:
                        pass
                    else:
                        raise Exception("Invalid CREL type")

                    # cost is always 1 for this action (cost of 1 for getting it wrong)
                    #  because getting the wrong direction won't screw up the parse as it doesn't modify the stack
                    crel_examples.add(dict(feats), gold_lr_action)
                    # Not sure we want to condition on the actions of this crel model
                    # action_history.append(lr_action)
                    # action_tag_pair_history.append((lr_action, tos, buffer))

                # end if action in [LARC,RARC]
                # execute() mutates the stack; False means advance the buffer
                if not oracle.execute(action, tos, buffer):
                    break
                if oracle.is_stack_empty():
                    break
        # Validation logic. Break on pass as relations that should be parsed
        for pcr in all_actual_crels:
            l,r = normalize_cr(pcr)
            if l in rtag_seq and r in rtag_seq and pcr not in predicted_relations:
                pass

        return predicted_relations
    def predict_sentence(self, tagged_sentence, predicted_tags):
        """Predict causal relations for one sentence using trained models only.

        Runs the shift-reduce parser over the positional relation tags,
        consulting the trained parse-action and crel-direction models at each
        step (no oracle supervision).

        Args:
            tagged_sentence: sequence of (word, tags) pairs.
            predicted_tags: per-word predicted tag collections.

        Returns:
            Set of denormalized "Causer->Result" relation strings.
            NOTE(review): early-exit paths return [] (a list), not a set.
        """
        action_history = []
        action_tag_pair_history = []

        pos_ptag_seq, _, tag2span, all_predicted_rtags, _ = self.get_tags_relations_for(tagged_sentence, predicted_tags, self.cr_tags)

        if len(all_predicted_rtags) == 0:
            return []

        words = [wd for wd, tags in tagged_sentence]

        # Initialize stack, basic parser and oracle
        stack = Stack(verbose=False)
        # needs to be a tuple
        stack.push((ROOT,0))
        parser = Parser(stack)
        oracle = Oracle([], parser)

        predicted_relations = set()

        # tags without positional info
        tag_seq = [t for t,i in pos_ptag_seq]
        rtag_seq = [t for t in tag_seq if t[0].isdigit()]
        # if not at least 2 concept codes, then can't parse
        if len(rtag_seq) < 2:
            return []

        # Oracle parsing logic
        for tag_ix, buffer in enumerate(pos_ptag_seq):
            buffer_tag = buffer[0]
            bstart, bstop = tag2span[buffer]
            buffer_word_seq = words[bstart:bstop + 1]
            buffer_feats = self.feat_extractor.extract(buffer_tag, buffer_word_seq, self.positive_val)
            buffer_feats = self.__prefix_feats_("BUFFER", buffer_feats)

            # Consume the stack against the current buffer item
            while True:
                tos = oracle.tos()
                tos_tag = tos[0]
                if tos_tag == ROOT:
                    tos_feats = {}
                    tstart, tstop = -1,-1
                else:
                    tstart, tstop = tag2span[tos]
                    tos_word_seq = words[tstart:tstop + 1]

                    tos_feats = self.feat_extractor.extract(tos_tag, tos_word_seq, self.positive_val)
                    tos_feats = self.__prefix_feats_("TOS", tos_feats)

                # Words between the TOS span and the buffer span (inclusive slice)
                btwn_start, btwn_stop = min(tstop+1, len(words)-1), max(0, bstart-1)
                btwn_words = words[btwn_start:btwn_stop + 1]
                btwn_feats = self.feat_extractor.extract("BETWEEN", btwn_words, self.positive_val)
                btwn_feats = self.__prefix_feats_("__BTWN__", btwn_feats)

                # Merge conditional (history-based) and interaction features
                feats = self.get_conditional_feats(action_history, action_tag_pair_history, tos_tag, buffer_tag,
                                                   tag_seq[:tag_ix], tag_seq [tag_ix + 1:])
                interaction_feats = self.get_interaction_feats(tos_feats, buffer_feats)
                feats.update(buffer_feats)
                feats.update(tos_feats)
                feats.update(btwn_feats)
                feats.update(interaction_feats)

                # Consult Oracle or Model based on coin toss
                action = self.predict_parse_action(feats, tos_tag)

                action_history.append(action)
                action_tag_pair_history.append((action, tos_tag, buffer_tag))

                # Decide the direction of the causal relation
                if action in [LARC, RARC]:

                    # Both possible orderings as "Causer:{l}->Result:{r}" strings
                    cause_effect = denormalize_cr((tos_tag,    buffer_tag))
                    effect_cause = denormalize_cr((buffer_tag, tos_tag))

                    # Add additional features
                    # needs to be before predict below
                    feats.update(self.crel_features(action, tos_tag, buffer_tag))
                    lr_action = self.predict_crel_action(feats)

                    if lr_action == CAUSE_AND_EFFECT:
                        predicted_relations.add(cause_effect)
                        predicted_relations.add(effect_cause)
                    elif lr_action == CAUSE_EFFECT:
                        predicted_relations.add(cause_effect)
                    elif lr_action == EFFECT_CAUSE:
                        predicted_relations.add(effect_cause)
                    elif lr_action == REJECT:
                        pass
                    else:
                        raise Exception("Invalid CREL type")

                # end if action in [LARC,RARC]
                # execute() mutates the stack; False means advance the buffer
                if not oracle.execute(action, tos, buffer):
                    break
                if oracle.is_stack_empty():
                    break
        # Validation logic. Break on pass as relations that should be parsed
        return predicted_relations
    def generate_training_data(self, tagged_sentence, predicted_tags, parse_examples, crel_examples):
        """Parse one sentence with oracle supervision, collecting training examples.

        Shift-reduce parses the positional relation tags.  With probability
        (1 - self.beta) the trained model chooses each action instead of the
        oracle; gold actions and costs are always recorded into the example
        accumulators.

        Args:
            tagged_sentence: sequence of (word, tags) pairs.
            predicted_tags: per-word predicted tag collections.
            parse_examples: accumulator with .add(feats, gold_action, costs).
            crel_examples: accumulator with .add(feats, gold_lr_action).

        Returns:
            Set of predicted denormalized "Causer->Result" relation strings.
            NOTE(review): early-exit paths return [] (a list), not a set.
        """
        action_history = []
        action_tag_pair_history = []

        pos_ptag_seq, pos_ground_truth, tag2span, all_predicted_rtags, all_actual_crels = self.get_tags_relations_for(tagged_sentence, predicted_tags, self.cr_tags)

        if len(all_predicted_rtags) == 0:
            return []

        words = [wd for wd, tags in tagged_sentence]

        # Initialize stack, basic parser and oracle
        stack = Stack(verbose=False)
        # needs to be a tuple
        stack.push((ROOT,0))
        parser = Parser(stack)
        oracle = Oracle(pos_ground_truth, parser)

        predicted_relations = set()

        # tags without positional info
        tag_seq = [t for t,i in pos_ptag_seq]
        rtag_seq = [t for t in tag_seq if t[0].isdigit()]
        # if not at least 2 concept codes, then can't parse
        if len(rtag_seq) < 2:
            return []

        # Oracle parsing logic
        for tag_ix, buffer in enumerate(pos_ptag_seq):
            buffer_tag = buffer[0]
            bstart, bstop = tag2span[buffer]
            buffer_word_seq = words[bstart:bstop + 1]
            buffer_feats = self.feat_extractor.extract(buffer_tag, buffer_word_seq, self.positive_val)
            buffer_feats = self.__prefix_feats_("BUFFER", buffer_feats)

            # Consume the stack against the current buffer item
            while True:
                tos = oracle.tos()
                tos_tag = tos[0]
                if tos_tag == ROOT:
                    tos_feats = {}
                    tstart, tstop = -1,-1
                else:
                    tstart, tstop = tag2span[tos]
                    tos_word_seq = words[tstart:tstop + 1]

                    tos_feats = self.feat_extractor.extract(tos_tag, tos_word_seq, self.positive_val)
                    tos_feats = self.__prefix_feats_("TOS", tos_feats)

                # Words between the TOS span and the buffer span (inclusive slice)
                btwn_start, btwn_stop = min(tstop+1, len(words)-1), max(0, bstart-1)
                btwn_words = words[btwn_start:btwn_stop + 1]
                btwn_feats = self.feat_extractor.extract("BETWEEN", btwn_words, self.positive_val)
                btwn_feats = self.__prefix_feats_("__BTWN__", btwn_feats)

                # Merge conditional (history-based) and interaction features
                feats = self.get_conditional_feats(action_history, action_tag_pair_history, tos_tag, buffer_tag,
                                                   tag_seq[:tag_ix], tag_seq [tag_ix + 1:])
                interaction_feats = self.get_interaction_feats(tos_feats, buffer_feats)
                feats.update(buffer_feats)
                feats.update(tos_feats)
                feats.update(btwn_feats)
                feats.update(interaction_feats)

                gold_action = oracle.consult(tos, buffer)

                # Consult Oracle or Model based on coin toss
                rand_float = np.random.random_sample()  # between [0,1) (half-open interval, includes 0 but not 1)
                # If no trained models, always use Oracle
                if rand_float >= self.beta and len(self.parser_models) > 0:
                    action = self.predict_parse_action(feats, tos_tag)
                else:
                    action = gold_action

                action_history.append(action)
                action_tag_pair_history.append((action, tos_tag, buffer_tag))

                # Cost of this decision relative to the optimal one(s)
                cost_per_action = self.compute_cost(pos_ground_truth, pos_ptag_seq[tag_ix:], oracle)
                # make a copy as changing later
                parse_examples.add(dict(feats), gold_action, cost_per_action)

                # Decide the direction of the causal relation
                if action in [LARC, RARC]:

                    # Both possible orderings as "Causer:{l}->Result:{r}" strings
                    cause_effect = denormalize_cr((tos_tag,    buffer_tag))
                    effect_cause = denormalize_cr((buffer_tag, tos_tag))

                    # Gold direction derived from which orderings occur in the
                    # actual relation set
                    if cause_effect in all_actual_crels and effect_cause in all_actual_crels:
                        gold_lr_action = CAUSE_AND_EFFECT
                    elif cause_effect in all_actual_crels:
                        gold_lr_action = CAUSE_EFFECT
                    elif effect_cause in all_actual_crels:
                        gold_lr_action = EFFECT_CAUSE
                    else:
                        gold_lr_action = REJECT

                    # Add additional features
                    # needs to be before predict below
                    feats.update(self.crel_features(action, tos_tag, buffer_tag))
                    rand_float = np.random.random_sample()
                    if rand_float >= self.beta and len(self.crel_models) > 0:
                        lr_action = self.predict_crel_action(feats)
                    else:
                        lr_action = gold_lr_action

                    if lr_action == CAUSE_AND_EFFECT:
                        predicted_relations.add(cause_effect)
                        predicted_relations.add(effect_cause)
                    elif lr_action == CAUSE_EFFECT:
                        predicted_relations.add(cause_effect)
                    elif lr_action == EFFECT_CAUSE:
                        predicted_relations.add(effect_cause)
                    elif lr_action == REJECT:
                        pass
                    else:
                        raise Exception("Invalid CREL type")

                    # cost is always 1 for this action (cost of 1 for getting it wrong)
                    #  because getting the wrong direction won't screw up the parse as it doesn't modify the stack
                    crel_examples.add(dict(feats), gold_lr_action)
                    # Not sure we want to condition on the actions of this crel model
                    # action_history.append(lr_action)
                    # action_tag_pair_history.append((lr_action, tos, buffer))

                # end if action in [LARC,RARC]
                # execute() mutates the stack; False means advance the buffer
                if not oracle.execute(action, tos, buffer):
                    break
                if oracle.is_stack_empty():
                    break
        # Validation logic. Break on pass as relations that should be parsed
        for pcr in all_actual_crels:
            l,r = normalize_cr(pcr)
            if l in rtag_seq and r in rtag_seq and pcr not in predicted_relations:
                pass

        return predicted_relations