def generate(self,
             src: Union[batchers.Batch, sent.Sentence],
             forced_trg_ids: Optional[Sequence[numbers.Integral]] = None,
             normalize_scores: bool = False):
    """Predict one class label per source sentence.

    Args:
      src: A single sentence or a batch of sentences; a single sentence is wrapped into a batch of one.
      forced_trg_ids: If given and truthy, these labels are reported instead of the argmax prediction.
        When ``src`` had to be wrapped, ``forced_trg_ids`` is wrapped the same way.
      normalize_scores: If ``True``, scores come from ``self.scorer.calc_log_probs``; otherwise from
        ``self.scorer.calc_scores``.

    Returns:
      A list containing one ``sent.ScalarSentence`` per batch element, carrying the chosen label
      (``value``) and the score looked up for that label (``score``).
    """
    if not batchers.is_batched(src):
        src = batchers.mark_as_batch([src])
        # The forced labels are only re-wrapped when the source itself was unbatched.
        if forced_trg_ids:
            forced_trg_ids = batchers.mark_as_batch([forced_trg_ids])

    encoding = self._encode_src(src)
    if normalize_scores:
        score_expr = self.scorer.calc_log_probs(encoding)
    else:
        score_expr = self.scorer.calc_scores(encoding)
    # NOTE(review): assumes npvalue() has no batch axis when batch size is 1 and puts the batch
    # on axis 1 otherwise (the indexing below relies on that) — TODO confirm.
    all_scores = score_expr.npvalue()

    chosen = forced_trg_ids if forced_trg_ids else np.argmax(all_scores, axis=0)

    n_items = src.batch_size()
    results = []
    for item_i in range(n_items):
        if n_items == 1:
            # NOTE(review): with forced labels and a batch of one, `chosen` is the whole
            # wrapped forced sequence rather than its first element — confirm this is intended.
            action = chosen
            action_score = all_scores[action]
        else:
            action = chosen[item_i]
            action_score = all_scores[:, item_i][action]
        results.append(sent.ScalarSentence(value=action, score=action_score))
    return results
def read_sent(self, line: str, idx: numbers.Integral) -> sent.Sentence:
    """Turn one line of whitespace-separated text into a sentence object.

    Args:
      line: Raw text of the line.
      idx: Index of this sentence; stored on the returned object.

    Returns:
      If ``self.read_sent_len`` is set, a ``sent.ScalarSentence`` whose value is the number of
      whitespace-separated tokens. Otherwise a ``sent.SimpleSentence`` whose words are the tokens
      converted with ``self.vocab.convert`` (or ``convert_int`` when there is no truthy vocab),
      followed by ``vocabs.Vocab.ES``.
    """
    to_id = self.vocab.convert if self.vocab else convert_int
    # str.split() with no separator already discards leading/trailing whitespace,
    # so this matches line.strip().split().
    tokens = line.split()
    if self.read_sent_len:
        return sent.ScalarSentence(idx=idx, value=len(tokens))
    word_ids = list(map(to_id, tokens))
    word_ids.append(vocabs.Vocab.ES)
    return sent.SimpleSentence(idx=idx,
                               words=word_ids,
                               vocab=self.vocab,
                               output_procs=self.output_procs)
def generate(self, src: Union[batchers.Batch, sent.Sentence], normalize_scores: bool = False):
    """Predict the single best class label for each source sentence.

    Args:
      src: A single sentence or a batch of sentences; a single sentence is wrapped into a batch of one.
      normalize_scores: Forwarded unchanged to ``self.scorer.best_k``.

    Returns:
      A list containing one ``sent.ScalarSentence`` per batch element, carrying the best label
      (``value``) and its score (``score``).
    """
    if not batchers.is_batched(src):
        src = batchers.mark_as_batch([src])
    batch_size = src.batch_size()
    h = self._encode_src(src)
    best_words, best_scores = self.scorer.best_k(h, k=1, normalize_scores=normalize_scores)
    # Both arrays are 2-D with shape (1, batch_size), including when batch_size == 1.
    assert best_words.shape == (1, batch_size)
    assert best_scores.shape == (1, batch_size)
    # Index [0, batch_i] for every batch size. The former `best_words[0]` special case for a
    # batch of one returned a length-1 array instead of a scalar label/score.
    return [sent.ScalarSentence(value=best_words[0, batch_i], score=best_scores[0, batch_i])
            for batch_i in range(batch_size)]
def read_sent(self, line: str, idx: numbers.Integral) -> sent.ScalarSentence:
    """Parse a line holding a single integer into a scalar sentence.

    Args:
      line: Text of the line; surrounding whitespace is stripped before conversion.
      idx: Index of this sentence; stored on the returned object.

    Returns:
      A ``sent.ScalarSentence`` whose value is the parsed integer.

    Raises:
      ValueError: if the stripped line is not a valid integer literal.
    """
    parsed_value = int(line.strip())
    return sent.ScalarSentence(idx=idx, value=parsed_value)