def beam_decode(model, input, hidden, max_len_decode, beam_size, pad_id, sos_id,
                eos_id, tup_idx=4, batch_size=1, use_constraints=True,
                init_beam=False, roles=None):
    """Beam-search decode a single sequence from a recurrent decoder.

    Args:
        model: decoder callable; invoked as ``model(inp, hidden)`` or, when
            ``args.emb_type`` is set, ``model(inp, hidden, inp2)`` with a role
            embedding input. Returns ``(logit, hidden)``.
        input: tensor holding the last token of the conditioning sequence;
            only used (first element) when ``init_beam`` is True.
        hidden: decoder hidden state, shape [1, 1, hidden_size]; repeated
            across the beam dimension internally.
        max_len_decode: maximum number of decoding steps.
        beam_size: number of hypotheses kept per step (must be > 0).
        pad_id, sos_id, eos_id: special-token ids handed to the Beam object.
        tup_idx: id passed to ``ge.update_verb_list`` (presumably the tuple
            separator token — TODO confirm against ge module).
        batch_size: must be 1; beam decoding here is single-example only.
        use_constraints: apply ``ge.schema_constraint`` to the log-probs
            before each beam advance.
        init_beam: seed the beam's first output with ``input`` instead of BOS.
        roles: sequence of role ids indexed by ``step % 5`` when
            ``args.emb_type`` is set ([tup, v, s, o, prep] cycle).

    Returns:
        Whatever ``_from_beam(beam, args.n_best)`` produces (n-best token
        sequences extracted from the finished beam).
    """
    assert beam_size > 0 and batch_size == 1, "Beam decoding batch size must be 1 and Beam size greater than 0."

    # Helper functions for working with beams and batches
    def var(a):
        # volatile=True: legacy (pre-0.4) PyTorch inference mode, no autograd.
        return Variable(a, volatile=True)

    def unbottle(m):
        # [beam*batch, ...] -> [beam, batch, ...]
        return m.view(beam_size, batch_size, -1)

    def beam_update(e, idx, positions, beam_size):
        """Reorder the beam dimension of state tensor ``e`` (in place, via
        .data) according to the back-pointers ``positions`` chosen by the
        beam at this step."""
        sizes = e.size()  # [1, beam_size, hidden_size]
        br = sizes[1]
        if len(sizes) == 3:
            sent_states = e.view(sizes[0], beam_size, br // beam_size,
                                 sizes[2])[:, :, idx]
        else:
            sent_states = e.view(sizes[0], beam_size, br // beam_size,
                                 sizes[2], sizes[3])[:, :, idx]
        # [1, beam_size, hidden_size] — select surviving hypotheses' states.
        sent_states.data.copy_(sent_states.data.index_select(1, positions))

    # 1 beam object as we have batch_size 1 during decoding
    beam = [
        Beam(beam_size, n_best=args.n_best, cuda=use_cuda, pad=pad_id,
             eos=eos_id, bos=sos_id, min_length=10)
    ]
    if init_beam:
        # Seed the beam with the id of the last element in the input seq.
        for b in beam:
            b.next_ys[0][0] = np.asscalar(input.data.numpy()[0])

    # [1, beam_size, hidden_size]
    hidden = hidden.repeat(1, beam_size, 1)

    # One independent list per beam hypothesis (constraint bookkeeping).
    # NOTE: was ``[[]] * beam_size`` which aliases a single list across all
    # hypotheses — any in-place mutation would leak across beams.
    verb_list = [[] for _ in range(beam_size)]

    # Run the decoder to generate the sequence.
    for i in range(max_len_decode):
        # Once all beams have EOS, stop.
        if all(b.done() for b in beam):
            break

        # Beam advance already set the current state to the previous output;
        # gather it as the next decoder input. Shape: [beam_size].
        inp = var(torch.stack([b.get_current_state()
                               for b in beam]).t().contiguous().view(-1))

        # Run one step of the decoder. dec_out: beam x rnn_size.
        inp = inp.unsqueeze(1)
        if args.emb_type:
            curr_idx = i % 5  # index of the role type: [tup, v, s, o, prep]
            curr_role = roles[curr_idx]
            # Wrap into a tensor/var and repeat beam times.
            inp2 = var(torch.LongTensor([curr_role])).repeat(beam_size, 1)
            logit, hidden = model(inp, hidden, inp2)
        else:
            logit, hidden = model(inp, hidden)

        # [1, beam_size, hidden_size]
        logit = torch.unsqueeze(logit, 0)
        probs = F.log_softmax(logit, dim=2).data
        out = unbottle(probs)  # [beam_size, 1, vocab_size], already log-probs

        # Advance each beam. We have 1 beam object.
        for j, b in enumerate(beam):
            # out[:, j] is [beam_size, vocab_size]
            if use_constraints:
                b.advance(ge.schema_constraint(out[:, j], b.next_ys[-1],
                                               verb_list))
            else:
                b.advance(out[:, j])
            # Reorder hidden state to follow the surviving hypotheses.
            beam_update(hidden, j, b.get_current_origin(), beam_size)
            if use_constraints:
                verb_list = ge.update_verb_list(verb_list, b, tup_idx)

    # Extract sentences from beam and return.
    return _from_beam(beam, args.n_best)
def beam_decode(self, input, dhidden, latent_values, beam_size, max_len_decode,
                n_best=1, use_constraints=True, min_len_decode=0,
                init_beam=False):
    """Beam-search decode one sequence with this model's attentional decoder.

    Args:
        input: tensor whose first element seeds the beam when ``init_beam``.
        dhidden: decoder hidden state, shape [1, 1, hidden_size].
        latent_values: latent conditioning tensor, [1, 3, hidden_size]
            (repeated per beam hypothesis).
        beam_size: number of hypotheses; must be >= ``n_best``.
        max_len_decode: maximum decoding steps.
        n_best: number of completed hypotheses to keep.
        use_constraints: apply ``ge.schema_constraint`` before each advance.
        min_len_decode: minimum length forwarded to the constraint function.
        init_beam: if True, the decoder's existing input feed is reused and
            the beam is seeded with ``input``; otherwise the input feed is
            zero-initialized.

    Returns:
        The best hypothesis (token ids) from ``self._from_beam``.

    Side effects: mutates ``self.decoder.input_feed`` each step and calls
    ``self.decoder.reset_feed_()`` before returning so autograd cleans up.
    """
    # Fixed here for now: beam decoding is single-example only.
    batch_size = 1
    assert beam_size >= n_best, "n_best cannot be greater than beam_size."

    # Helper functions for working with beams and batches
    def var(a):
        # volatile=True: legacy (pre-0.4) PyTorch inference mode, no autograd.
        return Variable(a, volatile=True)

    def unbottle(m):
        # [beam*batch, ...] -> [beam, batch, ...]
        return m.view(beam_size, batch_size, -1)

    def beam_update(e, idx, positions, beam_size):
        """Reorder the beam dimension of state tensor ``e`` (in place, via
        .data) according to the back-pointers ``positions`` chosen by the
        beam at this step."""
        sizes = e.size()  # [1, beam_size, hidden_size]
        br = sizes[1]
        if len(sizes) == 3:
            sent_states = e.view(sizes[0], beam_size, br // beam_size,
                                 sizes[2])[:, :, idx]
        else:
            sent_states = e.view(sizes[0], beam_size, br // beam_size,
                                 sizes[2], sizes[3])[:, :, idx]
        # [1, beam_size, hidden_size] — select surviving hypotheses' states.
        sent_states.data.copy_(sent_states.data.index_select(1, positions))

    # 1. Everything needs to be repeated for the beams.
    # [1, beam_size, hidden_size]
    dhidden = dhidden.repeat(1, beam_size, 1)
    # [beam_size, hidden_size, 3]
    latent_values = latent_values.repeat(beam_size, 1, 1)

    # Init after repeating everything.
    if init_beam:
        # Init with the current input feed.
        self.decoder.input_feed = self.decoder.input_feed.repeat(beam_size, 1)
    else:
        # Initialize the input feed with beam_size zeros.
        feed = torch.zeros(beam_size, self.decoder.attn_dim)
        if self.use_cuda:
            feed = feed.cuda()
        self.decoder.init_feed_(var(feed))

    # 2. Beam object: one, as we have batch_size 1 during decoding.
    beam = [
        Beam(beam_size, n_best=n_best, cuda=self.use_cuda, pad=1,
             eos=self.eos_idx, bos=self.sos_idx, min_length=10)
    ]
    if init_beam:
        # Seed the beam with the initial input token.
        for b in beam:
            b.next_ys[0][0] = np.asscalar(input.data.numpy()[0])

    # One independent list per beam hypothesis (constraint bookkeeping).
    # NOTE: was ``[[]] * beam_size`` which aliases a single list across all
    # hypotheses — any in-place mutation would leak across beams.
    verb_list = [[] for _ in range(beam_size)]

    # Run the decoder to generate the sequence.
    for i in range(max_len_decode):
        # Once all beams have EOS, stop.
        if all(b.done() for b in beam):
            break

        # Beam advance already set the current state to the previous output;
        # gather it as the next decoder input. Shape: [beam_size].
        inp = var(torch.stack([b.get_current_state()
                               for b in beam]).t().contiguous().view(-1))

        # Run one step of the decoder. dec_out: beam x rnn_size.
        dec_output, dhidden = self.decoder(inp, dhidden, latent_values)

        # [1, beam_size, hidden_size]
        dec_output = torch.unsqueeze(dec_output, 0)
        logits = self.logits_out(dec_output)
        probs = F.log_softmax(logits, dim=2).data
        out = unbottle(probs)  # [beam_size, 1, vocab_size], already log-probs

        # Advance each beam.
        for j, b in enumerate(beam):
            if use_constraints:
                b.advance(ge.schema_constraint(out[:, j], b.next_ys[-1],
                                               verb_list,
                                               min_len_decode=min_len_decode,
                                               step=i, eos_idx=self.eos_idx))
            else:
                b.advance(out[:, j])
            # Reorder hidden state and input feed to follow the survivors.
            beam_update(dhidden, j, b.get_current_origin(), beam_size)
            beam_update(dec_output, j, b.get_current_origin(), beam_size)
            if use_constraints:
                # Update list of currently used verbs.
                verb_list = ge.update_verb_list(verb_list, b, self.tup_idx)

        # Update input feed for the next step.
        self.decoder.input_feed = dec_output.squeeze(dim=0)

    # Extract sentences (token ids) from the beam: best hypothesis only.
    ret = self._from_beam(beam, n_best=n_best)[0][0]
    # Reset input feed so pytorch correctly cleans the graph.
    self.decoder.reset_feed_()
    return ret