def create_sent_report(self, src: sent.Sentence, output: sent.ReadableSentence, attentions: np.ndarray, ref_file: Optional[str], **kwargs) -> None:
  """Render the report entry for a single sentence, including attention plots.

  Args:
    src: source-side input
    output: generated output
    attentions: attention matrices
    ref_file: path to reference file
    **kwargs: arguments to be ignored
  """
  self.cur_sent_no += 1
  # Honor the configured cap on how many sentences get reported.
  if self.max_num_sents and self.cur_sent_no > self.max_num_sents:
    return
  reference = utils.cached_file_lines(ref_file)[output.idx]
  sent_idx = src.idx
  self.add_sent_heading(sent_idx)
  # Only readable sentences carry a token/string rendering; fall back to empties.
  src_is_readable = isinstance(src, sent.ReadableSentence)
  src_tokens = src.str_tokens() if src_is_readable else []
  src_str = src.sent_str() if src_is_readable else ""
  trg_tokens = output.str_tokens()
  trg_str = output.sent_str()
  self.add_charcut_diff(trg_str, reference)
  self.add_fields_if_set({"Src": src_str})
  # Array-valued sources are plotted from their raw array, not from tokens.
  src_repr = src.get_array() if isinstance(src, sent.ArraySentence) else src_tokens
  self.add_atts(attentions, src_repr, trg_tokens, sent_idx)
  self.finish_sent()
def embed_sent(self, x: sent.Sentence) -> expression_seqs.ExpressionSequence:
  """Embed a (possibly batched) sentence into a sequence of expressions.

  Args:
    x: a single sentence or a batch of sentences.

  Returns:
    An expression sequence; lazy when the input already carries numeric arrays.
  """
  # TODO refactor: seems a bit too many special cases that need to be distinguished
  is_batch = batchers.is_batched(x)
  probe = x[0] if is_batch else x
  if hasattr(probe, "get_array"):
    # Input already holds continuous vectors: wrap lazily, no embedding lookup.
    if is_batch:
      return expression_seqs.LazyNumpyExpressionSequence(
        lazy_data=batchers.mark_as_batch(list(x)), mask=x.mask)
    return expression_seqs.LazyNumpyExpressionSequence(lazy_data=x.get_array())
  # Discrete tokens: look up one embedding per position.
  if is_batch:
    embeddings = [
      self.embed(batchers.mark_as_batch([one_sent[pos] for one_sent in x]))
      for pos in range(x.sent_len())
    ]
  else:
    embeddings = [self.embed(word) for word in x]
  return expression_seqs.ExpressionSequence(expr_list=embeddings, mask=x.mask)
def create_sent_report(self, segment_actions, src: sent.Sentence, **kwargs):
  """Write the segmentation of one source sentence to the report file.

  Args:
    segment_actions: sequence whose first element lists segment-end indices
    src: source-side input
    **kwargs: arguments to be ignored
  """
  # Lazily open the report file on first use; it stays open across calls.
  if self.report_fp is None:
    utils.make_parent_dir(self.report_path)
    self.report_fp = open(self.report_path, "w")
  boundaries = segment_actions[0]
  tokens = src.str_tokens()
  segments = []
  seg_start = 0
  for seg_end in boundaries:
    # Skip empty spans (can occur with repeated boundary indices).
    if seg_start < seg_end + 1:
      segments.append("".join(str(tok) for tok in tokens[seg_start:seg_end + 1]))
    seg_start = seg_end + 1
  print(" ".join(segments), file=self.report_fp)
def create_sent_report(self, src: sent.Sentence, output: sent.ReadableSentence, ref_file: Optional[str] = None, **kwargs) -> None:
  """Accumulate one sentence (source, hypothesis, reference) for later reporting.

  Args:
    src: source-side input
    output: generated output
    ref_file: path to reference file
    **kwargs: arguments to be ignored
  """
  ref_line = utils.cached_file_lines(ref_file)[output.idx]
  hyp_line = output.sent_str()
  # Only readable source sentences have a string rendering to store.
  if isinstance(src, sent.ReadableSentence):
    self.src_sents.append(src.sent_str())
  self.hyp_sents.append(hyp_line)
  self.ref_sents.append(ref_line)
def create_trajectory(self,
                      src: sent.Sentence,
                      ref: sent.Sentence = None,
                      current_state: Optional[SimultaneousState] = None,
                      from_oracle: bool = True,
                      force_decoding: bool = True,
                      max_generation: int = -1):
  """Roll out a simultaneous-translation READ/WRITE trajectory for one sentence.

  Args:
    src: source sentence; a CompoundSentence is split into the actual source
      (sents[0]) and a forced action sequence (sents[1].words).
    ref: reference sentence used for force decoding and stopping.
    current_state: optional state to resume from; a fresh initial state is
      built from ``src`` when not given.
    from_oracle: when True, actions come from the oracle action sequence
      embedded in a CompoundSentence (or the policy must be action-forced).
    force_decoding: when True, previous words fed to the decoder are taken
      from ``ref`` rather than from model predictions.
    max_generation: cap on written words when sampling from the policy
      network; -1 means no cap.

  Returns:
    Tuple (actions, outputs, decoder_states, model_states) collected over the
    greedy rollout.
  """
  # Oracle mode requires an action sequence: either packed into a
  # CompoundSentence or forced by the policy configuration.
  assert not from_oracle or type(src) == sent.CompoundSentence or self._is_action_forced()
  force_action = None
  if type(src) == sent.CompoundSentence:
    # Unpack: first sub-sentence is the real source, second holds the actions.
    src, force_action = src.sents[0], src.sents[1].words
  # Forced actions are only honored in oracle mode.
  force_action = force_action if from_oracle else None
  current_state = current_state or self._initial_state(src)
  src_len = src.len_unpadded()
  actions = []
  decoder_states = []
  outputs = []
  model_states = [current_state]

  def stoping_criterions_met(state, trg, now_action):
    # Three stopping regimes:
    # 1) oracle: stop once every forced action has been consumed;
    # 2) no policy network / forced actions: stop after writing trg length;
    # 3) free policy: stop at the generation cap or on end-of-sentence.
    look_oracle = now_action is not None and from_oracle
    if look_oracle:
      return state.has_been_read + state.has_been_written >= len(force_action)
    elif self.policy_network is None or self._is_action_forced():
      return state.has_been_written >= trg.sent_len()
    else:
      # NOTE(review): assumes ``state.written_word`` holds the last emitted
      # token id — confirm against SimultaneousState.
      return (max_generation != -1 and state.has_been_written >= max_generation) or \
             state.written_word == vocabs.Vocab.ES

  # Simultaneous greedy search
  while not stoping_criterions_met(current_state, ref, force_action):
    # Position in the forced action sequence = total READs + WRITEs so far.
    actions_taken = current_state.has_been_read + current_state.has_been_written
    if force_action is not None and actions_taken < len(force_action):
      defined_action = force_action[actions_taken]
    else:
      defined_action = None
    # Define action based on state
    policy_action = self._next_action(current_state, src_len, defined_action)
    action = policy_action.content
    if action == self.Action.READ.value:
      # Reading + Encoding
      current_state = current_state.read(self.src_encoding[current_state.has_been_read], policy_action)
    elif action == self.Action.WRITE.value:
      # Calculating losses
      if force_decoding:
        # Choose the previous word to condition on: EOS once past the
        # reference length, None for the very first write, else the
        # preceding reference token.
        if ref.len_unpadded() <= current_state.has_been_written:
          prev_word = vocabs.Vocab.ES
        elif current_state.has_been_written == 0:
          prev_word = None
        else:
          prev_word = ref[current_state.has_been_written - 1]
        # Write
        current_state = current_state.write(self.src_encoding, prev_word, policy_action)
      else:
        # TODO implement if ref is None!
        # NOTE(review): when force_decoding is False, ``prev_word`` below is
        # unbound and the append will raise — this branch is unfinished.
        pass
      # The produced words
      outputs.append(prev_word)
      decoder_states.append(current_state)
    else:
      raise ValueError(action)
    model_states.append(current_state)
    actions.append(action)
  return actions, outputs, decoder_states, model_states