def beam_search(self, hidden):
    """Decode a single sequence from `hidden` with sampled beam search.

    At each step, every live beam samples `self.beam_sample` candidate tokens
    from the softmax distribution (torch.multinomial), scores them with the
    accumulated log-probability, and the best `self.beam_width` (beam, token)
    pairs survive to the next step.  Runs for exactly `self.max_len` steps.

    Args:
        hidden: encoder output passed to `self.decoder_state` to build the
            initial decoder state.  # assumes a single example (batch 1) —
            the initial inputs are built with batch_size=1; TODO confirm.

    Returns:
        The token sequence of the highest-ranked beam, wrapped via `variable`.
    """
    sequence_len = self.max_len
    # Start with a single beam holding the initial decoder state and the
    # start-of-sequence input token.
    beams = [
        Beam(hidden=self.decoder_state(hidden), inputs=self.decoder_initial_inputs(1).squeeze(0)),
    ]
    for _ in range(sequence_len):
        beams_current = []
        # Batch all live beams together so one decoder call advances them all.
        decoder_hidden = torch.stack([b.hidden for b in beams])
        decoder_inputs = torch.stack([b.inputs for b in beams])
        beam_scores = torch.stack([b.score for b in beams])
        decoder_inputs = self.embedding(decoder_inputs)
        if self.dropout is not None:
            decoder_inputs = self.dropout(decoder_inputs)
        decoder_hidden = self.decoder(decoder_inputs, decoder_hidden)
        decoder_outputs = decoder_hidden
        if self.dropout is not None:
            decoder_outputs = self.dropout(decoder_outputs)
        out = self.out(decoder_outputs)
        # Probabilities drive the sampling; log-probabilities accumulate the score.
        out_probs = F.softmax(out, dim=-1)
        out_scores = F.log_softmax(out, dim=-1)
        # Sample beam_sample candidate tokens per beam (without replacement by
        # multinomial's default), then gather their log-prob scores.
        candidates = torch.multinomial(out_probs, self.beam_sample)
        candidates_scores = torch.gather(out_scores, 1, candidates)
        # Each candidate's total score = parent beam score + token log-prob;
        # flatten (n_beams, beam_sample) -> (n_beams * beam_sample,).
        beams_current_scores = beam_scores.unsqueeze(1) + candidates_scores
        beams_current_scores = beams_current_scores.view(-1)
        # Parallel flat index mapping each candidate back to its parent beam.
        beams_current_hidden_indices = variable(torch.arange(len(beams)))
        beams_current_hidden_indices = beams_current_hidden_indices.unsqueeze(1).expand(-1, self.beam_sample)
        beams_current_hidden_indices = beams_current_hidden_indices.contiguous().view(-1)
        beams_current_candidates = candidates.view(-1)
        # Keep the beam_width best (beam, token) pairs overall.
        top_scores, top_indices = torch.topk(beams_current_scores.view(-1), self.beam_width)
        top_hidden_indices = beams_current_hidden_indices[top_indices]
        top_candidates = beams_current_candidates[top_indices]
        for hidden_idx, candidate, score in zip(top_hidden_indices, top_candidates, top_scores):
            hidden_idx = int(hidden_idx)
            beam = beams[hidden_idx]
            # decoder_hidden was re-bound above to the *updated* state, so this
            # is the post-step hidden state of the parent beam.
            beam_hidden = decoder_hidden[hidden_idx]
            beams_current.append(
                beam.step(score, candidate, beam_hidden)
            )
        beams = beams_current
    # topk returns scores in descending order, so beams[0] is the best beam.
    outputs = variable(np.array(beams[0].sequence))
    return outputs
def zero_state(self, batch_size):
    """Build an all-zero initial hidden state.

    Returns a list with one zero tensor of shape (batch_size, hidden_size)
    per decoder layer.
    """
    states = []
    for _ in range(self.nb_layers):
        states.append(variable(torch.zeros(batch_size, self.hidden_size)))
    return states
def _forward_pass(self, batch):
    """Run the model over one (inputs, targets) batch.

    Returns a pair (outputs, targets), each normalized to a tuple via
    `tuplify` so downstream code can treat single and multiple
    tensors uniformly.
    """
    inputs, targets = variable(batch)
    targets = tuplify(targets)
    model_inputs = tuplify(inputs)
    outputs = tuplify(self.model(*model_inputs))
    return outputs, targets
def __init__(self, hidden, inputs):
    """One beam-search hypothesis: decoder state, next input, and score so far."""
    self.hidden = hidden
    self.inputs = inputs
    # Tokens decoded so far for this hypothesis.
    self.sequence = []
    # NOTE(review): the accumulated log-prob score starts at 1.0 rather than
    # 0.0.  Since the same offset applies to every beam it cannot change the
    # top-k ranking, but confirm the intent before relying on raw scores.
    self.score = variable(torch.ones(1)).squeeze()
def decoder_initial_inputs(self, batch_size):
    """Return the start-of-sequence token ids for the decoder.

    Builds a length-1 array filled with `self.init_token` and expands it to
    shape (batch_size,) without copying.

    Args:
        batch_size: number of sequences being decoded in parallel.

    Returns:
        A (batch_size,) int64 tensor of start tokens, wrapped via `variable`.
    """
    # `np.long` was deprecated in NumPy 1.20 and removed in 1.24+; np.int64 is
    # the explicit equivalent and the dtype torch expects for token indices.
    start = np.full((1,), self.init_token, dtype=np.int64)
    inputs = variable(torch.from_numpy(start).expand((batch_size,)))
    return inputs