def beam_search_forward(self, col, score, embedding, pre_hid_info, s_embedding):
    # One decoding step for all live beams: advance the GRU, score every
    # candidate next token, and keep the 10 best (beam width) expansions.
    n = col.shape[0]

    # GRU update/reset gates from the previous token embedding, the source
    # column vector and the previous hidden state
    input_info = T.concatenate([embedding, col, pre_hid_info], axis=-1)
    u1 = get_output(self.gru_update_3, input_info)
    r1 = get_output(self.gru_reset_3, input_info)

    # Candidate hidden state and gated interpolation with the previous state
    reset_h1 = pre_hid_info * r1
    c_in = T.concatenate([embedding, col, reset_h1], axis=1)
    c1 = get_output(self.gru_candidate_3, c_in)
    h1 = (1.0 - u1) * pre_hid_info + u1 * c1

    # Score every target token against the new hidden state
    s = get_output(self.score, h1)
    sample_score = T.dot(s, s_embedding.T)
    k = sample_score.shape[-1]

    # Add the accumulated beam scores (broadcast over the vocabulary axis),
    # then flatten the (beam, vocab) grid per example
    sample_score = sample_score.reshape((n, 1, k))
    sample_score = sample_score + score
    sample_score = sample_score.reshape((n, 10 * k))

    # Keep the 10 highest-scoring expansions per example
    sort_index = T.argsort(-sample_score, axis=-1)
    sample_score = T.sort(-sample_score, axis=-1)
    tops = sort_index[:, :10]
    sample_score = -sample_score[:, :10]

    # Map each flat index to its token id (flat index modulo the vocabulary
    # size); the int8 cast assumes a small target vocabulary
    tops = T.cast(tops % self.target_vocab_size, "int8")

    # Embed the selected tokens as input for the next decoding step
    embedding = get_output(self.target_input_embedding, tops)
    d = embedding.shape[-1]
    embedding = embedding.reshape((n * 10, d))
    return sample_score, embedding, h1, tops
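# Illustration (assumed names and values, not part of the model): the sorted
# indices above address a flattened (beam, vocab) grid of candidate scores, so
# integer division recovers the source beam and the remainder recovers the
# token id. A minimal NumPy sketch of that arithmetic:
#
#     import numpy as np
#
#     beam_width, vocab = 10, 50
#     grid = np.random.rand(beam_width, vocab)            # candidate scores
#     flat = np.argsort(-grid.reshape(-1))[:beam_width]   # best flat indices
#     parent_beam = flat // vocab                         # beam each expansion came from
#     token_id = flat % vocab                             # token that was appended
#     best = grid[parent_beam, token_id]                  # selected scores, descending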
def beam_search_backward(self, top, idx):
    # One backtracking step: read the index stored for beam `idx`, then split
    # it into the beam to follow at the previous step (quotient) and the
    # token id emitted at this step (remainder).
    max_idx = top[idx]
    idx = max_idx // self.target_vocab_size
    max_idx = max_idx % self.target_vocab_size
    return idx, max_idx
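# Illustration (assumed names and values, not part of the model): a sketch of
# how a table of per-step flat indices would be walked backwards with the same
# quotient/remainder split as beam_search_backward to recover one hypothesis:
#
#     import numpy as np
#
#     vocab = 5
#     tops_per_step = [np.array([3, 8, 12]),   # flat indices kept at step 0
#                      np.array([7, 14, 2]),   # step 1
#                      np.array([9, 1, 6])]    # step 2
#     beam, tokens = 0, []                     # start from the best final beam
#     for tops in reversed(tops_per_step):
#         flat = int(tops[beam])
#         beam = flat // vocab                 # parent beam at the previous step
#         tokens.append(flat % vocab)          # token emitted at this step
#     tokens.reverse()                         # token ids in left-to-right order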
def divmod(self, l, r):
    # Elementwise quotient and remainder
    return l // r, l % r