Beispiel #1
0
    def work(self, sentence, oracle=None, initial_state=None):
        """Run beam search over transition states for ``sentence``.

        Args:
            sentence: Input passed to ``transition_utils.ArcHybridState`` and
                ``self.getWordEmbeddings`` when no ``initial_state`` is given.
            oracle: Optional sequence of ``(action, relation)`` pairs indexed
                by the number of steps already taken (``len(state.history)``).
                Supplying it switches the method into training mode.
            initial_state: Optional state to start from. When ``None`` a new
                ``ArcHybridState`` is built and the output of
                ``self.getWordEmbeddings`` is attached via ``set_lstm_map``.

        Returns:
            A pair ``(state_to_update, gold_state)``. ``state_to_update`` is
            the highest-``score`` state on the final beam. ``gold_state`` is
            the state following the oracle from the last search step; it is
            ``None`` outside training mode or if the search loop never ran.

        In training mode the search stops early as soon as no state on the
        newly built beam is correct; the oracle's candidate is then expanded
        on its own and returned as ``gold_state`` (early update).
        """
        train = oracle is not None

        if initial_state is None:
            initial_state = transition_utils.ArcHybridState(
                sentence, self.options)
            initial_state.set_lstm_map(
                self.getWordEmbeddings(sentence, train))

        frontier = Beam(maxsize=self.options.beam_size)
        frontier.push(initial_state)
        assert not initial_state.history or initial_state.is_correct
        gold_state = None

        while not frontier[0].is_finished():
            # Per-step bookkeeping: both are re-derived from this step only.
            gold_state = None
            gold_cand = None
            scored = Beam(maxsize=self.options.beam_size)

            action_scores, label_scores = self.evaluate_all_states(
                s for s in frontier)
            action_np = action_scores.npvalue()
            label_np = label_scores.npvalue()
            if len(frontier) == 1:
                # With a single state the arrays are given an explicit
                # trailing axis so the [row, state_idx] indexing below works.
                action_np = action_np.reshape(action_np.shape + (1, ))
                label_np = label_np.reshape(label_np.shape + (1, ))

            # Phase 1: score every permitted (action, relation) continuation
            # of every state on the current beam.
            for state_idx, state in enumerate(frontier):
                allowed = [a.can_do_action(state) for a in self.actions]
                for joint_idx, decoded in enumerate(
                        self.actions.decoded_with_relation):
                    action, relation, action_idx, _ = decoded
                    if not allowed[action_idx]:
                        continue

                    local_score = (action_np[action_idx, state_idx] +
                                   label_np[joint_idx, state_idx])

                    # True only when this continuation is the oracle's next
                    # step from a state that is itself still correct.
                    is_gold = (train and state.is_correct and
                               (action, relation) ==
                               oracle[len(state.history)])
                    if is_gold and self.options.loss_aug:
                        local_score -= self.options.loss_aug

                    cand = transition_utils.StateCandidate(
                        state_idx, joint_idx, state.score + local_score,
                        local_score)
                    if is_gold:
                        gold_cand = cand
                    scored.push(cand)

            def expand(cand_):
                """Apply the candidate's action to a copy of its parent."""
                parent = frontier[cand_.state_idx]
                child = parent.copy()
                action, relation, action_idx, _ = \
                    self.actions.decoded_with_relation[cand_.joint_idx]
                if train:
                    correctness = (action, relation) == oracle[len(
                        parent.history)]
                    # Expression-level score, kept only in training mode.
                    score_repr = (
                        dn.pick_batch_elem(
                            action_scores, cand_.state_idx)[action_idx] +
                        dn.pick_batch_elem(
                            label_scores, cand_.state_idx)[cand_.joint_idx])
                else:
                    correctness = True
                    score_repr = None
                step_score = ActionScore(relation, action,
                                         cand_.local_score, score_repr)
                action.do_action(
                    child, relation,
                    transition_utils.History(step_score, correctness))
                return child

            # Phase 2: materialise the surviving candidates as real states.
            next_beam = Beam(maxsize=self.options.beam_size,
                             key=attrgetter("score"))
            for cand in scored:
                child = expand(cand)
                next_beam.push(child)
                if train and child.is_correct:
                    parent = frontier[cand.state_idx]
                    assert not parent.history or parent.is_correct
                    gold_state = child

            # Early update: the gold continuation did not survive. It has to
            # be expanded *before* the beam is rebound, because ``expand``
            # looks the parent up on the current beam.
            gold_lost = train and gold_state is None
            if gold_lost:
                gold_state = expand(gold_cand)
            frontier = next_beam
            if gold_lost:
                break

        return max(frontier, key=attrgetter("score")), gold_state