def work(self, sentence, oracle=None, initial_state=None):
    """Parse ``sentence`` with beam search; when training, apply early update.

    Args:
        sentence: sentence to parse; handed to
            ``transition_utils.ArcHybridState`` and ``self.getWordEmbeddings``.
        oracle: gold transition sequence. ``oracle[len(state.history)]`` is
            compared against ``(action, relation)`` pairs, so it is indexed by
            the number of transitions already taken. Passing a non-``None``
            oracle switches on training mode (loss augmentation, gold tracking,
            early update); ``None`` means plain decoding.
        initial_state: optional state to start from. When ``None``, a fresh
            ``ArcHybridState`` is created and given the word encodings.

    Returns:
        ``(state_to_update, gold_state)``. ``state_to_update`` is the
        highest-``score`` state in the final beam. ``gold_state`` is ``None``
        when not training (or when no step was taken). In training, it is the
        correct state reached at the last step. If no correct state survived
        into the new beam (early update), it is instead the gold continuation
        built from the gold candidate, and decoding stops at that step.

    Raises:
        ValueError: in training mode, when no candidate follows the oracle
            transition, so no gold state can be built for the early update.
    """
    train = oracle is not None
    if initial_state is None:
        initial_state = transition_utils.ArcHybridState(sentence, self.options)
        # NOTE(review): the original layout was collapsed onto one line; the
        # encoder call is assumed to belong to this branch (a caller-supplied
        # state is expected to already carry its LSTM map) — confirm.
        lstm_outputs = self.getWordEmbeddings(sentence, train)
        initial_state.set_lstm_map(lstm_outputs)

    beam_size = self.options.beam_size
    beam = Beam(maxsize=beam_size)
    beam.push(initial_state)
    assert not initial_state.history or initial_state.is_correct
    gold_state = None

    while not beam[0].is_finished():
        gold_state = None
        gold_candidate = None
        new_beam_candidates = Beam(maxsize=beam_size)

        # Score every state in the beam in one call. The numpy views are
        # indexed as [action_or_joint_idx, state_idx] below.
        action_scores, label_scores = self.evaluate_all_states(i for i in beam)
        action_scores_np = action_scores.npvalue()
        label_scores_np = label_scores.npvalue()
        if len(beam) == 1:
            # With a single-state beam the arrays apparently come back without
            # the trailing state axis; restore it so indexing stays uniform.
            action_scores_np = action_scores_np.reshape(action_scores_np.shape + (1,))
            label_scores_np = label_scores_np.reshape(label_scores_np.shape + (1,))

        for state_idx, state in enumerate(beam):
            can_do_action = [i.can_do_action(state) for i in self.actions]
            step = len(state.history)
            for joint_idx, (action, relation, action_idx, _) in enumerate(
                    self.actions.decoded_with_relation):
                if not can_do_action[action_idx]:
                    continue
                local_score = (action_scores_np[action_idx, state_idx]
                               + label_scores_np[joint_idx, state_idx])
                # Same short-circuit as before: the oracle is only indexed in
                # training mode and for states still on the gold path.
                is_gold = (train and state.is_correct
                           and (action, relation) == oracle[step])
                if is_gold and self.options.loss_aug:
                    # Loss augmentation: handicap the gold transition.
                    local_score -= self.options.loss_aug
                candidate = transition_utils.StateCandidate(
                    state_idx, joint_idx, state.score + local_score, local_score)
                if is_gold:
                    gold_candidate = candidate
                new_beam_candidates.push(candidate)

        def candidate_to_state(candidate_):
            """Apply ``candidate_``'s transition to a copy of its source state."""
            source = beam[candidate_.state_idx]
            new_state = source.copy()
            action, relation, action_idx, _ = \
                self.actions.decoded_with_relation[candidate_.joint_idx]
            correctness = (action, relation) == oracle[len(source.history)] if train else True
            if train:
                # Keep the score as an expression (not a float) only when
                # training, presumably so a loss can be built from it later.
                score_repr = (
                    dn.pick_batch_elem(action_scores, candidate_.state_idx)[action_idx]
                    + dn.pick_batch_elem(label_scores, candidate_.state_idx)[candidate_.joint_idx])
            else:
                score_repr = None
            action.do_action(
                new_state, relation,
                transition_utils.History(
                    ActionScore(relation, action, candidate_.local_score, score_repr),
                    correctness))
            return new_state

        new_beam = Beam(maxsize=beam_size, key=attrgetter("score"))
        for candidate in new_beam_candidates:
            new_state = candidate_to_state(candidate)
            new_beam.push(new_state)
            if train and new_state.is_correct:
                parent = beam[candidate.state_idx]
                assert not parent.history or parent.is_correct
                gold_state = new_state

        if train and gold_state is None:
            # Early update: no correct state made it into the new beam, so
            # build the gold continuation explicitly and stop decoding here.
            # This must run before `beam` is rebound (the closure reads it).
            if gold_candidate is None:
                raise ValueError(
                    "no candidate follows the oracle transition; "
                    "cannot build the gold state for early update")
            gold_state = candidate_to_state(gold_candidate)
            beam = new_beam
            break
        beam = new_beam

    state_to_update = max(beam, key=attrgetter("score"))
    return state_to_update, gold_state