def decode_step(self, preds, memory, memory_mask, cache, scores, flag):
    """ decode an utterance in a stepwise way"""
    # One beam-search step: extend every live hypothesis by one token.
    #   preds:       current hypotheses, one row per (batch * beam) entry
    #   memory/memory_mask: encoder output and its mask, forwarded to self.decode
    #   cache:       dict with "decoder" (and optionally "lm") incremental state
    #   scores:      (batch * beam, 1) cumulative log-probabilities
    #   flag:        per-hypothesis finished flag fed to the masking helpers
    # Returns (extended hypotheses, cache, pruned scores, end flag).
    batch_size = int(scores.size(0) / self.beam_width)
    # NOTE(review): dec_cache / dec_attn_weights (and lm_hidden below) are
    # computed but never written back into `cache`, which is returned
    # unchanged — confirm the caller refreshes cache["decoder"] itself.
    batch_log_probs, dec_cache, dec_attn_weights = self.decode(
        preds, memory, memory_mask, cache["decoder"])
    if self.lm is not None:
        # Shallow fusion: add LM log-probs scaled by self.lm_weight.
        batch_lm_log_probs, lm_hidden = self.lm_decode(preds, cache["lm"])
        batch_lm_log_probs = batch_lm_log_probs.squeeze(1)
        batch_log_probs = batch_log_probs + self.lm_weight * batch_lm_log_probs
    else:
        lm_hidden = None
    # Drop the time axis when the decoder returns (batch*beam, 1, vocab).
    if batch_log_probs.dim() == 3:
        batch_log_probs = batch_log_probs.squeeze(1)
    # Per hypothesis, keep the beam_width best next tokens.
    last_k_scores, last_k_preds = batch_log_probs.topk(self.beam_width)
    # Freeze scores/predictions of hypotheses already marked finished.
    last_k_scores = mask_finished_scores(last_k_scores, flag)
    last_k_preds = mask_finished_preds(last_k_preds, flag)
    # update scores
    scores = scores + last_k_scores
    scores = scores.view(batch_size, self.beam_width * self.beam_width)
    # pruning: keep beam_width of the beam_width**2 candidates per batch item
    scores, offset_k_indices = flow.topk(scores, k=self.beam_width)
    scores = scores.view(-1, 1)
    device = scores.device
    # Offset of each batch item inside the flattened candidate array.
    base_k_indices = (flow.arange(batch_size, device=device).view(
        -1, 1).repeat([1, self.beam_width]))
    base_k_indices *= self.beam_width**2
    best_k_indices = base_k_indices.view(-1) + offset_k_indices.view(-1)
    # update predictions
    best_k_preds = flow.index_select(last_k_preds.view(-1),
                                     dim=0,
                                     index=best_k_indices).to(flow.int64)
    # Parent hypothesis of each surviving candidate.
    preds_index = best_k_indices.floor_divide(self.beam_width)
    preds_symbol = flow.index_select(preds, dim=0, index=preds_index)
    preds_symbol = flow.cat([preds_symbol, best_k_preds.view(-1, 1)], dim=1)
    # finished or not: a hypothesis ends once it emits EOS
    end_flag = flow.eq(preds_symbol[:, -1], EOS).view(-1, 1).to(flow.uint8)
    return preds_symbol, cache, scores, end_flag
def reselect_hidden(tensor, beam_width, indices):
    """Expand an RNN hidden state across beams and keep only the chosen rows.

    tensor:  (n_layers, batch, hidden) stacked hidden state
    indices: flat row ids into the (batch * beam_width) expansion
    Returns a contiguous (n_layers, len(indices), hidden) tensor.
    """
    depth, n_batch, width = tensor.size()
    # (layers, batch, hidden) -> (batch, beam, layers, hidden) by tiling beams.
    per_beam = tensor.transpose(0, 1).unsqueeze(1).repeat([1, beam_width, 1, 1])
    flat = per_beam.reshape(n_batch * beam_width, depth, width)
    kept = flow.index_select(flat, dim=0, index=indices)
    return kept.transpose(0, 1).contiguous()
def test_index_select_index_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.tensor([[1, 2, 3], [4, 5, 6]], dtype=flow.float32, requires_grad=True) index = flow.tensor([0], dtype=flow.int32) y = flow.index_select(x, 4, index) test_case.assertTrue( "Dimension out of range" in str(context.exception))
def test_index_select_index_num_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.tensor([[1, 2, 3], [4, 5, 6]], dtype=flow.float32, requires_grad=True) index = flow.tensor([[0]], dtype=flow.int32) y = flow.index_select(x, 1, index) test_case.assertTrue( "Index is supposed to be a vector" in str(context.exception))
def test_index_select_runtime_error(test_case): with test_case.assertRaises(Exception) as context: x = flow.tensor([[1, 2, 3], [4, 5, 6]], dtype=flow.float32, requires_grad=True) index = flow.tensor([0, 1], dtype=flow.float32) y = flow.index_select(x, 1, index) test_case.assertTrue("Expected dtype int32 or int64 for index" in str( context.exception))
def select_hidden(tensor, indices, beam_width):
    """Tile an RNN hidden state beam_width times per batch entry, then gather rows.

    tensor:  (n_layers, batch, hidden) stacked hidden state
    indices: flat row ids into the (batch * beam_width) expansion
    Returns a contiguous (n_layers, len(indices), hidden) tensor.
    """
    depth, _, width = tensor.size()
    batch_major = tensor.transpose(0, 1)  # (batch, layers, hidden)
    tiled = batch_major.unsqueeze(1).repeat([1, beam_width, 1, 1])
    flat = tiled.reshape(-1, depth, width)  # (batch * beam, layers, hidden)
    gathered = flow.index_select(flat, dim=0, index=indices)
    return gathered.transpose(0, 1).contiguous()
def select_chunk_states_and_mask_based_index(tensor, tensor_mask, index):
    """Pick one chunk per batch element from stacked chunk states.

    tensor:      (b, c, t, v) chunked states
    tensor_mask: (b, c, t)    matching masks
    index:       (b,)         chunk id chosen for each batch element
    Returns (states, mask) of shapes (b, t, v) and (b, 1, t).
    """
    assert tensor.dim() == 4
    assert tensor_mask.dim() == 3
    assert index.dim() == 1
    n_batch, n_chunk, n_time, n_feat = tensor.size()
    # Flatten (b, c) into one axis; row of the chosen chunk is b * c + index[b].
    flat_rows = (flow.arange(n_batch, device=tensor.device) * n_chunk + index).long()
    states = flow.index_select(
        tensor.reshape(n_batch * n_chunk, n_time, n_feat), 0, flat_rows)
    masks = flow.index_select(
        tensor_mask.reshape(n_batch * n_chunk, 1, n_time), 0, flat_rows)
    return states, masks
def select_tensor_based_index(tensor, index):
    """Gather one position per batch row.

    tensor: (b, l) or (b, l, v)   — NOT (b, c, t, v); the code handles 2-D/3-D only
    index:  (b,) position chosen inside each row
    Returns shape (b,) for 2-D input, (b, v) for 3-D input.
    """
    assert tensor.dim() >= 2
    assert index.dim() == 1
    n_batch, n_len = tensor.size(0), tensor.size(1)
    # Row-major flattening: element (b, i) lives at b * n_len + i.
    flat_index = (flow.arange(n_batch, device=tensor.device) * n_len + index).long()
    if tensor.dim() == 2:
        return flow.index_select(tensor.reshape(n_batch * n_len), 0, flat_index)
    assert tensor.dim() == 3
    return flow.index_select(
        tensor.reshape(n_batch * n_len, tensor.size(-1)), 0, flat_index)
def lm_rescoring(self, preds, pred_lens):
    """Re-rank n-best hypotheses by their length-normalized LM log-probability."""
    # preds [beam_size, lens]
    # preds_len [beam_size]
    if self.lm.model_type == "transformer_lm":
        # Transformer LM scores the whole sequence in a single call.
        log_probs = self.lm.predict(preds, last_frame=False)
    else:
        # Recurrent LM: step token by token, carrying the hidden state.
        log_probs = []
        hidden = None
        for t in range(preds.size(1)):
            log_prob, hidden = self.lm.predict(preds[:, t].unsqueeze(-1), hidden)
            log_probs.append(log_prob)
        log_probs = flow.cat(log_probs, dim=1)
    rescores = []
    max_length = log_probs.size(1)
    vocab_size = log_probs.size(-1)
    for b in range(preds.size(0)):
        # Flat index of log_probs[b, t, preds[b, t]] for every position t.
        base_index = flow.arange(max_length, device=preds.device)
        bias_index = preds[b].reshape(-1)
        index = base_index * vocab_size + bias_index
        score = flow.index_select(log_probs[b].reshape(-1), dim=-1, index=index)
        label_len = min(int(pred_lens[b]), score.size(0))
        # Zero out positions past the hypothesis before averaging.
        # NOTE(review): the slice starts at label_len - 1, so the last real
        # token's score is dropped too — confirm this off-by-one is intended,
        # and that pred_lens is never 0 (division by label_len below).
        score[label_len - 1:] = 0
        rescores.append(flow.sum(score) / label_len)
    rescores = flow.tensor(rescores, dtype=flow.float32)
    # Sort hypotheses best-first by the rescored value.
    _, indices = flow.sort(rescores, dim=-1, descending=True)
    sorted_preds = preds[indices]
    sorted_length = pred_lens[indices]
    return sorted_preds, sorted_length
def _index_select(self, dim, index):
    """Tensor-method binding for ``flow.index_select``: delegates to the
    functional op with ``self`` as the input tensor."""
    return flow.index_select(self, dim, index)