def batch_beam_search(self, x, beam_size=5, max_length=255, n_best=1, length_penalty=.2):
    # |x[0]| = (batch_size, n)
    batch_size = x[0].size(0)

    mask = self._generate_mask(x[0], x[1])
    # |mask| = (batch_size, n)
    x = x[0]

    mask_enc = torch.stack([mask for _ in range(x.size(1))], dim=1)
    mask_dec = mask.unsqueeze(1)
    # |mask_enc| = (batch_size, n, n)
    # |mask_dec| = (batch_size, 1, n)

    z = self.emb_dropout(self._position_encoding(self.emb_enc(x)))
    z, _ = self.encoder(z, mask_enc)
    # |z| = (batch_size, n, hidden_size)

    # Initialize one 'SingleBeamSearchSpace' per sample in the mini-batch.
    # One cached 'prev_state' tensor is registered per decoder block,
    # plus one for the embedding layer's outputs.
    spaces = [
        SingleBeamSearchSpace(
            z.device,
            [('prev_state_%d' % j, None, 0)
             for j in range(len(self.decoder._modules) + 1)],
            beam_size=beam_size,
            max_length=max_length,
        ) for i in range(batch_size)
    ]
    done_cnt = [space.is_done() for space in spaces]

    length = 0
    while sum(done_cnt) < batch_size and length <= max_length:
        fab_input, fab_z, fab_mask = [], [], []
        fab_prevs = [[] for _ in range(len(self.decoder._modules) + 1)]

        # Build a fabricated mini-batch from every unfinished sample.
        for i, space in enumerate(spaces):
            if space.is_done() == 0:
                tmp = space.get_batch()

                y_hat_ = tmp[0]
                tmp = tmp[1:]

                fab_input += [y_hat_]
                for j, prev_ in enumerate(tmp):
                    if prev_ is not None:
                        fab_prevs[j] += [prev_]
                    else:
                        fab_prevs[j] = None

                fab_z += [z[i].unsqueeze(0)] * beam_size
                fab_mask += [mask_dec[i].unsqueeze(0)] * beam_size

        fab_input = torch.cat(fab_input, dim=0)
        for i, fab_prev in enumerate(fab_prevs):
            if fab_prev is not None:
                fab_prevs[i] = torch.cat(fab_prev, dim=0)
        fab_z = torch.cat(fab_z, dim=0)
        fab_mask = torch.cat(fab_mask, dim=0)
        # |fab_input|    = (current_batch_size, 1)
        # |fab_prevs[i]| = (current_batch_size, length, hidden_size)
        # |fab_z|        = (current_batch_size, n, hidden_size)
        # |fab_mask|     = (current_batch_size, 1, n)

        # Unlike the training procedure,
        # feed only the last time-step's word during inference.
        h_t = self.emb_dropout(
            self._position_encoding(self.emb_dec(fab_input), init_pos=length)
        )
        # |h_t| = (current_batch_size, 1, hidden_size)

        if fab_prevs[0] is None:
            fab_prevs[0] = h_t
        else:
            fab_prevs[0] = torch.cat([fab_prevs[0], h_t], dim=1)

        for i, block in enumerate(self.decoder._modules.values()):
            prev = fab_prevs[i]
            # |prev| = (current_batch_size, m, hidden_size)

            h_t, _, _, _ = block(h_t, fab_z, fab_mask, prev)
            # |h_t| = (current_batch_size, 1, hidden_size)

            if fab_prevs[i + 1] is None:
                fab_prevs[i + 1] = h_t
            else:
                fab_prevs[i + 1] = torch.cat([fab_prevs[i + 1], h_t], dim=1)

        y_hat_t = self.softmax(self.generator(h_t))
        # |y_hat_t| = (current_batch_size, 1, output_size)

        # Separate the result for each sample.
        cnt = 0
        for space in spaces:
            if space.is_done() == 0:
                from_index = cnt * beam_size
                to_index = from_index + beam_size

                space.collect_result(
                    y_hat_t[from_index:to_index],
                    [(
                        'prev_state_%d' % i,
                        fab_prevs[i][from_index:to_index],
                    ) for i in range(len(self.decoder._modules) + 1)],
                )
                cnt += 1

        done_cnt = [space.is_done() for space in spaces]
        length += 1

    # Pick the n-best hypotheses for each sample.
    batch_sentences = []
    batch_probs = []

    for i, space in enumerate(spaces):
        sentences, probs = space.get_n_best(n_best, length_penalty=length_penalty)

        batch_sentences += [sentences]
        batch_probs += [probs]

    return batch_sentences, batch_probs
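The decoding loops in this section delegate all per-sample bookkeeping (which beams are still alive, their cumulative scores, and the back-pointers needed to trace a hypothesis) to `SingleBeamSearchSpace`, whose implementation is not shown here. The stub below is only a sketch of the interface the code above assumes: the constructor arguments and method names are read off the calls in the Transformer and seq2seq versions, while the comments, shapes, and return conventions are assumptions rather than the actual class. (The older seq2seq variant at the end of this section constructs its spaces with a slightly different, positional signature.)

# A minimal sketch, NOT the real implementation. Only the constructor
# arguments and method names are taken from the calls in the code above;
# everything described in the comments is an assumption.
class SingleBeamSearchSpaceInterface:

    def __init__(self, device, prev_status_config, beam_size=5, max_length=255):
        # prev_status_config: list of (name, initial_tensor_or_None, cat_dim)
        # triples, e.g. [('hidden_state', h_0, 1), ('cell_state', c_0, 1),
        # ('h_t_1_tilde', None, 0)] for the seq2seq decoder, or
        # [('prev_state_0', None, 0), ...] for the Transformer decoder.
        self.device = device
        self.beam_size = beam_size
        self.max_length = max_length

    def is_done(self):
        # Returns a non-zero value (counted into 'done_cnt') once every beam
        # of this sample has finished or max_length was reached, else 0.
        raise NotImplementedError

    def get_batch(self):
        # Returns the last predicted word indices, |y_hat| = (beam_size, 1),
        # followed by one tensor per registered state, in registration order.
        raise NotImplementedError

    def collect_result(self, y_hat, prev_status):
        # y_hat: the generator output for this step, (beam_size, 1, output_size).
        # prev_status: list of (name, tensor) pairs holding the updated states.
        # Expected to keep the top 'beam_size' continuations and back-pointers.
        raise NotImplementedError

    def get_n_best(self, n=1, length_penalty=.2):
        # Traces back-pointers of finished hypotheses and returns
        # (sentences, probs) for the n best ones.
        raise NotImplementedError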
def batch_beam_search(self, src, beam_size=5, max_length=255, n_best=1, length_penalty=.2):
    mask, x_length = None, None

    if isinstance(src, tuple):
        x, x_length = src
        mask = self.generate_mask(x, x_length)
        # |mask| = (batch_size, length)
    else:
        x = src
    batch_size = x.size(0)

    emb_src = self.emb_src(x)
    h_src, h_0_tgt = self.encoder((emb_src, x_length))
    # |h_src| = (batch_size, length, hidden_size)

    h_0_tgt, c_0_tgt = h_0_tgt
    h_0_tgt = h_0_tgt.transpose(0, 1).contiguous().view(
        batch_size, -1, self.hidden_size).transpose(0, 1).contiguous()
    c_0_tgt = c_0_tgt.transpose(0, 1).contiguous().view(
        batch_size, -1, self.hidden_size).transpose(0, 1).contiguous()
    # |h_0_tgt| = (n_layers, batch_size, hidden_size)
    h_0_tgt = (h_0_tgt, c_0_tgt)

    # Initialize one 'SingleBeamSearchSpace' per sample in the mini-batch.
    spaces = [
        SingleBeamSearchSpace(
            h_src.device,
            [
                ('hidden_state', h_0_tgt[0][:, i, :].unsqueeze(1), 1),
                ('cell_state', h_0_tgt[1][:, i, :].unsqueeze(1), 1),
                ('h_t_1_tilde', None, 0),
            ],
            beam_size=beam_size,
            max_length=max_length,
        ) for i in range(batch_size)
    ]
    done_cnt = [space.is_done() for space in spaces]

    length = 0
    # Run the loop while the sum of 'done_cnt' is smaller than batch_size,
    # and length has not yet exceeded max_length.
    while sum(done_cnt) < batch_size and length <= max_length:
        # current_batch_size = (batch_size - sum(done_cnt)) * beam_size

        # Initialize fabricated variables.
        # While batch-beam-search is running,
        # the temporary batch size of the fabricated mini-batch is
        # 'beam_size'-times bigger than the number of unfinished samples.
        fab_input, fab_hidden, fab_cell, fab_h_t_tilde = [], [], [], []
        fab_h_src, fab_mask = [], []

        # Build the fabricated mini-batch in a non-parallel way.
        # This may cause a bottleneck.
        for i, space in enumerate(spaces):
            # Batchify only if the inference for this sample is not finished yet.
            if space.is_done() == 0:
                y_hat_, hidden_, cell_, h_t_tilde_ = space.get_batch()

                fab_input += [y_hat_]
                fab_hidden += [hidden_]
                fab_cell += [cell_]
                if h_t_tilde_ is not None:
                    fab_h_t_tilde += [h_t_tilde_]
                else:
                    fab_h_t_tilde = None

                fab_h_src += [h_src[i, :, :]] * beam_size
                fab_mask += [mask[i, :]] * beam_size

        # Now, concatenate the lists of tensors.
        fab_input = torch.cat(fab_input, dim=0)
        fab_hidden = torch.cat(fab_hidden, dim=1)
        fab_cell = torch.cat(fab_cell, dim=1)
        if fab_h_t_tilde is not None:
            fab_h_t_tilde = torch.cat(fab_h_t_tilde, dim=0)
        fab_h_src = torch.stack(fab_h_src)
        fab_mask = torch.stack(fab_mask)
        # |fab_input|     = (current_batch_size, 1)
        # |fab_hidden|    = (n_layers, current_batch_size, hidden_size)
        # |fab_cell|      = (n_layers, current_batch_size, hidden_size)
        # |fab_h_t_tilde| = (current_batch_size, 1, hidden_size)
        # |fab_h_src|     = (current_batch_size, length, hidden_size)
        # |fab_mask|      = (current_batch_size, length)

        emb_t = self.emb_dec(fab_input)
        # |emb_t| = (current_batch_size, 1, word_vec_dim)

        fab_decoder_output, (fab_hidden, fab_cell) = self.decoder(
            emb_t, fab_h_t_tilde, (fab_hidden, fab_cell))
        # |fab_decoder_output| = (current_batch_size, 1, hidden_size)

        context_vector = self.attn(fab_h_src, fab_decoder_output, fab_mask)
        # |context_vector| = (current_batch_size, 1, hidden_size)

        fab_h_t_tilde = self.tanh(
            self.concat(
                torch.cat([fab_decoder_output, context_vector], dim=-1)))
        # |fab_h_t_tilde| = (current_batch_size, 1, hidden_size)

        y_hat = self.generator(fab_h_t_tilde)
        # |y_hat| = (current_batch_size, 1, output_size)

        # Separate the result for each sample.
        # fab_hidden[:, from_index:to_index, :] = (n_layers, beam_size, hidden_size)
        # fab_cell[:, from_index:to_index, :]   = (n_layers, beam_size, hidden_size)
        # fab_h_t_tilde[from_index:to_index]    = (beam_size, 1, hidden_size)
        cnt = 0
        for space in spaces:
            if space.is_done() == 0:
                # Decide the slice range for this sample.
                from_index = cnt * beam_size
                to_index = from_index + beam_size

                # Pick k-best results for each sample.
                space.collect_result(
                    y_hat[from_index:to_index],
                    [
                        ('hidden_state', fab_hidden[:, from_index:to_index, :]),
                        ('cell_state', fab_cell[:, from_index:to_index, :]),
                        ('h_t_1_tilde', fab_h_t_tilde[from_index:to_index]),
                    ],
                )
                cnt += 1

        done_cnt = [space.is_done() for space in spaces]
        length += 1

    # Pick the n-best hypotheses and collect the results.
    batch_sentences = []
    batch_probs = []

    for i, space in enumerate(spaces):
        sentences, probs = space.get_n_best(n_best, length_penalty=length_penalty)

        batch_sentences += [sentences]
        batch_probs += [probs]

    return batch_sentences, batch_probs
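The point of the "fabricated" mini-batch in both methods above is throughput: instead of decoding each sample's beams separately, every live beam of every unfinished sample is stacked into one batch of (unfinished samples) × beam_size rows, pushed through the decoder in a single forward pass, and then sliced back apart with from_index/to_index before collect_result. The toy snippet below is self-contained and only illustrates that tile-and-slice pattern; the nn.Linear standing in for the decoder and all tensor names are made up for the example.

# Standalone illustration of the tile-and-slice pattern behind the
# fabricated mini-batch. A plain Linear layer stands in for the decoder.
import torch
import torch.nn as nn

batch_size, beam_size, hidden_size = 2, 3, 8
fake_decoder = nn.Linear(hidden_size, hidden_size)

# Pretend encoder output: one vector per source sentence.
z = torch.randn(batch_size, hidden_size)

# Tile each sample's encoder output beam_size times and concatenate,
# exactly like 'fab_h_src += [h_src[i, :, :]] * beam_size' above.
fab_z = torch.cat([z[i].unsqueeze(0).expand(beam_size, -1)
                   for i in range(batch_size)], dim=0)
# |fab_z| = (batch_size * beam_size, hidden_size)

# One forward pass serves every beam of every sample at once.
fab_out = fake_decoder(fab_z)

# Slice the result back per sample, as collect_result() receives it.
for i in range(batch_size):
    from_index = i * beam_size
    to_index = from_index + beam_size
    per_sample = fab_out[from_index:to_index]
    # |per_sample| = (beam_size, hidden_size)
    print(i, per_sample.shape)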
def batch_beam_search(self, src, beam_size=5, max_length=255, n_best=1):
    mask = None
    x_length = None

    if isinstance(src, tuple):
        x, x_length = src
        mask = self.generate_mask(x, x_length)
        # |mask| = (batch_size, length)
    else:
        x = src
    batch_size = x.size(0)

    emb_src = self.emb_src(x)
    h_src, h_0_tgt = self.encoder((emb_src, x_length))
    # |h_src| = (batch_size, length, hidden_size)

    h_0_tgt, c_0_tgt = h_0_tgt
    h_0_tgt = h_0_tgt.transpose(0, 1).contiguous().view(
        batch_size, -1, self.hidden_size).transpose(0, 1).contiguous()
    c_0_tgt = c_0_tgt.transpose(0, 1).contiguous().view(
        batch_size, -1, self.hidden_size).transpose(0, 1).contiguous()
    # |h_0_tgt| = (n_layers, batch_size, hidden_size)
    h_0_tgt = (h_0_tgt, c_0_tgt)

    # Initialize beam-search.
    spaces = [
        SingleBeamSearchSpace(
            (h_0_tgt[0][:, i, :].unsqueeze(1), h_0_tgt[1][:, i, :].unsqueeze(1)),
            None,
            beam_size,
            max_length=max_length,
        ) for i in range(batch_size)
    ]
    done_cnt = [space.is_done() for space in spaces]

    length = 0
    while sum(done_cnt) < batch_size and length <= max_length:
        # current_batch_size = (batch_size - sum(done_cnt)) * beam_size

        # Initialize fabricated variables.
        fab_input, fab_hidden, fab_cell, fab_h_t_tilde = [], [], [], []
        fab_h_src, fab_mask = [], []

        # Batchify.
        for i, space in enumerate(spaces):
            if space.is_done() == 0:
                y_hat_, (hidden_, cell_), h_t_tilde_ = space.get_batch()

                fab_input += [y_hat_]
                fab_hidden += [hidden_]
                fab_cell += [cell_]
                if h_t_tilde_ is not None:
                    fab_h_t_tilde += [h_t_tilde_]
                else:
                    fab_h_t_tilde = None

                fab_h_src += [h_src[i, :, :]] * beam_size
                fab_mask += [mask[i, :]] * beam_size

        fab_input = torch.cat(fab_input, dim=0)
        fab_hidden = torch.cat(fab_hidden, dim=1)
        fab_cell = torch.cat(fab_cell, dim=1)
        if fab_h_t_tilde is not None:
            fab_h_t_tilde = torch.cat(fab_h_t_tilde, dim=0)
        fab_h_src = torch.stack(fab_h_src)
        fab_mask = torch.stack(fab_mask)
        # |fab_input|     = (current_batch_size, 1)
        # |fab_hidden|    = (n_layers, current_batch_size, hidden_size)
        # |fab_cell|      = (n_layers, current_batch_size, hidden_size)
        # |fab_h_t_tilde| = (current_batch_size, 1, hidden_size)
        # |fab_h_src|     = (current_batch_size, length, hidden_size)
        # |fab_mask|      = (current_batch_size, length)

        emb_t = self.emb_dec(fab_input)
        # |emb_t| = (current_batch_size, 1, word_vec_dim)

        fab_decoder_output, (fab_hidden, fab_cell) = self.decoder(
            emb_t, fab_h_t_tilde, (fab_hidden, fab_cell))
        # |fab_decoder_output| = (current_batch_size, 1, hidden_size)

        context_vector = self.attn(fab_h_src, fab_decoder_output, fab_mask)
        # |context_vector| = (current_batch_size, 1, hidden_size)

        fab_h_t_tilde = self.tanh(
            self.concat(
                torch.cat([fab_decoder_output, context_vector], dim=-1)))
        # |fab_h_t_tilde| = (current_batch_size, 1, hidden_size)

        y_hat = self.generator(fab_h_t_tilde)
        # |y_hat| = (current_batch_size, 1, output_size)

        # Separate the result for each sample.
        cnt = 0
        for space in spaces:
            if space.is_done() == 0:
                from_index = cnt * beam_size
                to_index = (cnt + 1) * beam_size

                # Pick k-best results for each sample.
                space.collect_result(
                    y_hat[from_index:to_index],
                    (fab_hidden[:, from_index:to_index, :],
                     fab_cell[:, from_index:to_index, :]),
                    fab_h_t_tilde[from_index:to_index],
                )
                cnt += 1

        done_cnt = [space.is_done() for space in spaces]
        length += 1

    # Pick the n-best hypotheses.
    batch_sentences = []
    batch_probs = []

    for i, space in enumerate(spaces):
        sentences, probs = space.get_n_best(n_best)

        batch_sentences += [sentences]
        batch_probs += [probs]

    return batch_sentences, batch_probs
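The first two versions pass length_penalty=.2 through to get_n_best(), while this older variant ranks finished hypotheses by raw cumulative log-probability, which tends to favor short outputs. One common remedy is a GNMT-style length normalization, sketched below; whether SingleBeamSearchSpace applies this exact formula is an assumption, and the snippet is only meant to make the role of the length_penalty (alpha) argument concrete.

# A GNMT-style length normalization, shown only as an illustration of what
# a 'length_penalty' (alpha) argument typically controls. Whether
# SingleBeamSearchSpace uses this exact formula is an assumption.
def length_penalized_score(cumulative_log_prob, length, alpha=.2, min_length=5):
    # Longer hypotheses accumulate more negative log-probability terms,
    # so the score is divided by a penalty that grows with length.
    penalty = ((min_length + length) / (min_length + 1)) ** alpha
    return cumulative_log_prob / penalty


# Two hypotheses with the same average per-token log-probability (-1.0):
print(length_penalized_score(-4.0, length=4),
      length_penalized_score(-10.0, length=10))
# The normalized gap between them is smaller than the raw gap of 6.0,
# so part of the advantage the shorter hypothesis gets from having
# fewer terms is removed before the n-best list is built.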