def make_batch(self, seq, gpu):
    """Wrap one token sequence as a single-item, sequence-major batch.

    The sequence is framed with the vocabulary's start and stop tokens and
    converted to a ``(len(tokens), 1)`` index tensor.

    Args:
        seq: list of tokens, without start/stop markers.
        gpu: CUDA device for the batch. A negative int (the ``-1``
            convention used by the other input builders in this file)
            keeps the batch on the CPU; any other value is passed to
            ``.cuda`` unchanged.

    Returns:
        dict with key ``'encoder_input'`` holding the wrapped ``Variable``.
    """
    tokens = [self.vocab.start_token] + seq + [self.vocab.stop_token]
    # One single-element row per token -> shape (len(tokens), 1).
    idxs = torch.LongTensor([[self.vocab.index(t)] for t in tokens])
    lengths = torch.LongTensor([len(tokens)])
    # NOTE(review): pad_value is hard-coded to 0 rather than a vocab pad
    # index; harmless for a single unpadded item, but confirm.
    encoder_inputs = Variable(idxs, lengths, batch_dim=1, length_dim=0,
                              pad_value=0)
    # Only move to CUDA when a device was actually requested; the previous
    # unconditional `.cuda(gpu)` made gpu=-1 unusable as "stay on CPU".
    if not (isinstance(gpu, int) and gpu < 0):
        encoder_inputs = encoder_inputs.cuda(gpu)
    return {'encoder_input': encoder_inputs}
def init_state(self, batch_size, encoder_state):
    """Build the initial greedy/sampling search state.

    Every batch item begins with a single step holding the vocabulary's
    start index. The encoder state is passed through unchanged as the
    initial decoder state.

    Returns:
        dict with ``"output"`` (a ``(1, batch_size)`` ``Variable`` placed on
        ``encoder_state``'s device) and ``"decoder_state"``.
    """
    start_ids = torch.LongTensor([self.vocab.start_index] * batch_size)
    first_step = start_ids.view(1, -1)
    one_per_item = torch.LongTensor([1] * batch_size)
    output = Variable(
        first_step,
        lengths=one_per_item,
        length_dim=0,
        batch_dim=1,
        pad_value=self.vocab.pad_index,
    )
    # Follow the encoder state onto its (non-CPU) device.
    needs_move = str(encoder_state.device) != "cpu"
    if needs_move:
        output = output.cuda(encoder_state.device)
    return {"output": output, "decoder_state": encoder_state}
def make_generator_inputs(self, data):
    """Encode ``data["source"]`` as a single-item, sequence-major batch.

    The source tokens are framed with the source vocabulary's start/stop
    tokens, mapped to indices and laid out as ``(len(tokens), 1)``. The
    result is moved to ``self._gpu`` when that is a non-negative index.

    Returns:
        dict with key ``"source_inputs"``.
    """
    vocab = self.source_vocab
    tokens = [vocab.start_token] + data["source"] + [vocab.stop_token]
    index_row = [vocab[token] for token in tokens]
    # Build as (1, seq) and transpose to (seq, 1).
    index_tensor = torch.LongTensor([index_row]).t()
    inputs = Variable(
        index_tensor,
        lengths=torch.LongTensor([len(tokens)]),
        length_dim=0,
        batch_dim=1,
        pad_value=vocab.pad_index,
    )
    if self._gpu > -1:
        inputs = inputs.cuda(self._gpu)
    return {"source_inputs": inputs}
def _batch_labels(self, labels_batch):
    """Encode a batch of label sets as sequence-major source inputs.

    Each entry is turned into tokens by ``self._labels2input`` and then
    into source-vocabulary indices. Every item is given the same length
    (the tensor width), so the rows are expected to be equally long;
    ``torch.LongTensor`` raises on ragged nested lists.

    Returns:
        dict with key ``"source_inputs"``.
    """
    rows = []
    for labels in labels_batch:
        row = [self.source_vocab[tok] for tok in self._labels2input(labels)]
        rows.append(row)
    index_matrix = torch.LongTensor(rows)  # (batch, seq)
    seq_len = index_matrix.size(1)
    item_lengths = torch.LongTensor([seq_len] * len(labels_batch))
    # NOTE(review): pad_value is -1 here, unlike the vocab pad index used by
    # the other builders; rows are unpadded so it is likely unused — confirm.
    inputs = Variable(
        index_matrix.t(),
        lengths=item_lengths,
        length_dim=0,
        batch_dim=1,
        pad_value=-1,
    )
    if self._gpu > -1:
        inputs = inputs.cuda(self._gpu)
    return {"source_inputs": inputs}
def next_state(self, decoder, prev_state, context, active_items, controls):
    """Advance ancestral sampling by one step.

    Asks the decoder for its next state, draws one token per item from a
    categorical distribution over ``target_logits``, and records the
    log-probability of each draw (gathered from ``log_probs``). Draws for
    items that are no longer active are overwritten with the pad index.

    Returns:
        The decoder's next-state dict with ``"output"`` and
        ``"output_log_probs"`` added.
    """
    state = decoder.next_state(prev_state, context, controls=controls)
    step_log_probs = state["log_probs"].tensor

    sampler = torch.distributions.Categorical(
        logits=state["target_logits"].tensor)
    drawn = sampler.sample()
    # presumably drawn is (1, batch): one new step per item — TODO confirm.
    step_lengths = drawn.new(drawn.size(1)).fill_(1)
    sampled = Variable(
        drawn,
        lengths=step_lengths,
        length_dim=0,
        batch_dim=1,
        pad_value=self.vocab.pad_index,
    )
    picked_log_probs = step_log_probs.gather(
        2, drawn.unsqueeze(-1)).squeeze(2)

    state["output"] = sampled
    state["output_log_probs"] = picked_log_probs

    # Pad out items whose sequence has already completed.
    finished = ~active_items
    state["output"].data.view(-1).masked_fill_(
        finished, self.vocab.pad_index)
    return state
def _variable_forward(self, query, key, value):
    """Additive (tanh) attention over sequence-major ``Variable`` inputs.

    ``key`` and ``query`` are projected by ``self.key_net`` /
    ``self.query_net`` and permuted to sequence x batch x features. Scores
    are ``tanh(key + query) @ self.weight``; pairs where either side is
    padding get ``-inf`` before the softmax over key positions.

    Returns:
        dict with ``"attention"`` (weights indexed query x batch x key; rows
        with no valid key are all zeros) and ``"output"`` (a ``Variable`` of
        attended ``self.value_net(value)`` data with the query's lengths).
    """
    key = self.key_net(key)
    query = self.query_net(query)
    key = key.permute_as_sequence_batch_features()
    query = query.permute_as_sequence_batch_features()
    assert key.dim() == query.dim() == 3
    # TODO use named tensors to allow arbitrary seq dims

    with torch.no_grad():
        # `~mask` is used as the "valid" indicator (mask presumably True at
        # padding). Count, per (query pos, key pos, batch item), the feature
        # slots valid on both sides and compare with 0 to get a bool mask.
        # The previous `~(...).byte()` was wrong on torch >= 1.2: `~` on
        # uint8 is a bitwise NOT (1 -> 254, 0 -> 255), so every position was
        # masked; `.byte()` also wrapped counts >= 256 to 0.
        both_valid = torch.einsum(
            "qbh,kbh->qkb",
            [(~query.mask).float(), (~key.mask).float()]) > 0
        mask = ~both_valid

    # Broadcast query against key (assuming length_dim == 0 after the
    # permute): (q, 1, b, h) + (1, k, b, h) -> (q, k, b, h).
    query_uns = query.data.unsqueeze(query.length_dim + 1)
    key_uns = key.data.unsqueeze(query.length_dim)
    hidden = torch.tanh(key_uns + query_uns)
    scores = hidden.matmul(self.weight)
    scores = scores.masked_fill(mask, float("-inf"))
    attention = torch.softmax(scores.transpose(1, 2), dim=2)
    # A fully masked row is all -inf, which softmax turns into NaN
    # (x != x is the NaN test); zero those rows.
    attention = attention.masked_fill(attention != attention, 0.)
    comp = torch.einsum("ijk,kjh->ijh",
                        [attention, self.value_net(value).data])
    comp = Variable(comp, lengths=query.lengths, length_dim=0, batch_dim=1)
    return {"attention": attention, "output": comp}
def init_state(self, batch_size, encoder_state):
    """Create the initial beam-search state for ``batch_size`` items.

    ``encoder_state`` is unpacked as (layers, batch, features). Each item's
    slice is tiled ``beam_size`` times so the decoder state covers
    ``batch_size * beam_size`` hypotheses, all starting from the
    vocabulary's start index.

    Returns:
        dict with ``decoder_state``, ``output``, ``accum_log_prob`` and
        ``terminal_mask``.
    """
    num_layers, enc_batch, num_feats = encoder_state.size()
    assert batch_size == enc_batch
    hypotheses = batch_size * self.beam_size

    beam_state = {}
    tiled = encoder_state.unsqueeze(2).repeat(1, 1, self.beam_size, 1)
    beam_state["decoder_state"] = tiled.view(
        num_layers, hypotheses, num_feats)

    first_tokens = torch.LongTensor([self.vocab.start_index] * hypotheses)
    start_output = Variable(
        first_tokens.view(1, -1),
        lengths=torch.LongTensor([1] * hypotheses),
        length_dim=0,
        batch_dim=1,
        pad_value=self.vocab.pad_index,
    )
    if str(encoder_state.device) != "cpu":
        start_output = start_output.cuda(encoder_state.device)
    beam_state["output"] = start_output

    # Start the first beam of each batch with 0 log prob, and all others
    # with -inf (delegated to the helper).
    beam_state["accum_log_prob"] = self._init_accum_log_probs(
        batch_size, encoder_state.device)

    # Nothing has terminated at step 0: an all-zero mask on the encoder's
    # device. NOTE(review): this is uint8 while other beam initialisers use
    # bool; masked_fill with uint8 masks is deprecated — confirm consumers.
    beam_state["terminal_mask"] = encoder_state.new_zeros(
        (1, hypotheses, 1), dtype=torch.uint8)
    return beam_state
def output(self, as_indices=False, n_best=-1):
    """Return beam-search results as index tensors or token lists.

    Args:
        as_indices: return indices instead of token strings.
        n_best: beams per batch item to return. Values < 1 select only the
            first beam; values >= ``self.beam_size`` return every beam.

    Returns:
        * ``n_best < 1`` and ``as_indices``: a batch-major ``Variable`` of the
          first beam with lengths from ``self._lengths[:, 0]``.
        * ``n_best < 1``: one token list per batch item.
        * ``n_best >= 1`` and ``as_indices``: the selected slice of
          ``self._output`` (batch x beams x length).
        * ``n_best >= 1``: per batch item, one token list per beam.

    Pad indices are dropped from every token list.
    """
    def _row_tokens(row):
        # Map one row of indices to tokens, skipping padding.
        return [self.vocab.token(t) for t in row
                if t != self.vocab.pad_index]

    if n_best < 1:
        top = self._output[:, 0]
        if as_indices:
            return Variable(top, lengths=self._lengths[:, 0], batch_dim=0,
                            length_dim=1,
                            pad_value=int(self.vocab.pad_index))
        return [_row_tokens(row) for row in top]

    if n_best < self.beam_size:
        beams = self._output[:, :n_best]
    else:
        beams = self._output
    if as_indices:
        return beams
    return [[_row_tokens(row) for row in beam] for beam in beams]
def init_state(self, batch_size, device, source_indices, encoder_state=None):
    """Build the initial state for greedy pointer decoding.

    Every item starts from the vocabulary's start index. The same
    ``Variable`` object is stored as both ``"target"`` and ``"inputs"``.
    The pointer mask is the transpose of ``source_indices.mask``.

    Returns:
        dict with ``"target"``, ``"decoder_state"``, ``"inputs"`` and
        ``'pointer_mask'``.
    """
    start_ids = torch.LongTensor([self.vocab.start_index] * batch_size)
    start = Variable(
        start_ids.view(1, -1),
        lengths=torch.LongTensor([1] * batch_size),
        length_dim=0,
        batch_dim=1,
        pad_value=self.vocab.pad_index,
    )
    if str(device) != "cpu":
        start = start.cuda(device)
    pointer_mask = source_indices.mask.t()
    return {
        "target": start,
        "decoder_state": encoder_state,
        "inputs": start,
        'pointer_mask': pointer_mask,
    }
def seq2tsr(seq, vocab):
    """Encode one token sequence as a ``(len(seq), 1)`` sequence-major Variable.

    Tokens are looked up with ``vocab[token]``; padding uses
    ``vocab.pad_index``.
    """
    seq_length = torch.LongTensor([len(seq)])
    indices = [vocab[token] for token in seq]
    # (1, len) row, transposed to a (len, 1) column.
    column = torch.LongTensor([indices]).t()
    return Variable(
        column,
        lengths=seq_length,
        length_dim=0,
        batch_dim=1,
        pad_value=vocab.pad_index,
    )
def search2inputs(search_outputs, vocab, gpu):
    """Turn search outputs into a padded, start-prefixed input batch.

    Each output list is prefixed with ``vocab.start_token``, right-padded
    with ``vocab.pad_token`` to the longest length, and mapped to indices.
    Lengths include the start token. The batch is sequence-major and is
    moved to ``gpu`` when that is a non-negative index.

    Returns:
        dict with key ``"inputs"``.
    """
    # +1 for the leading start token.
    lengths = [len(tokens) + 1 for tokens in search_outputs]
    width = max(lengths)

    rows = []
    for tokens in search_outputs:
        prefixed = [vocab.start_token] + tokens
        shortfall = width - len(prefixed)
        if shortfall > 0:
            prefixed = prefixed + [vocab.pad_token] * shortfall
        rows.append([vocab[t] for t in prefixed])

    index_matrix = torch.LongTensor(rows)
    inputs = Variable(
        index_matrix.t(),
        lengths=torch.LongTensor(lengths),
        batch_dim=1,
        length_dim=0,
        pad_value=vocab.pad_index,
    )
    if gpu > -1:
        inputs = inputs.cuda(gpu)
    return {"inputs": inputs}
def make_batch(self, input_tokens, gpu):
    """Wrap one source token sequence as a single-item encoder batch.

    Tokens are mapped with ``self.src_vocab.index`` and framed with the
    vocabulary's start and stop indices. The result has shape
    ``(len(input_tokens) + 2, 1)``.

    Args:
        input_tokens: list of source tokens.
        gpu: CUDA device for the batch. A negative int (the ``-1``
            convention used elsewhere in this file) keeps it on the CPU;
            any other value is passed to ``.cuda`` unchanged.

    Returns:
        dict with key ``'encoder_input'``.
    """
    vocab = self.src_vocab
    input_indices = [vocab.index(t) for t in input_tokens]
    input_indices = [vocab.start_index] + input_indices + [vocab.stop_index]
    index_tensor = torch.LongTensor([input_indices])
    # Length includes the start and stop markers.
    lengths = torch.LongTensor([len(input_indices)])
    variable = Variable(index_tensor.t(), lengths=lengths, length_dim=0,
                        batch_dim=1, pad_value=vocab.pad_index)
    # Previously `.cuda(gpu)` ran unconditionally, so gpu=-1 could not
    # mean "stay on CPU".
    if not (isinstance(gpu, int) and gpu < 0):
        variable = variable.cuda(gpu)
    return {'encoder_input': variable}
def _next_candidates(self, batch_size, log_probs, candidates):
    """Keep the ``beam_size`` best of each item's beam_size ** 2 candidates.

    Args:
        batch_size: number of batch items.
        log_probs: cumulative candidate log-probs, viewable as
            (batch_size, beam_size ** 2).
        candidates: candidate token indices aligned with ``log_probs``.

    Returns:
        ``(beam_seq_lps, beam_seq_scores, next_output, beam_indexing)``:
        the survivors' cumulative log-probs and step-normalised scores
        (each viewed as (1, batch_size * beam_size, 1)), a ``Variable`` of
        their tokens shaped (1, batch_size * beam_size), and the flat index
        of each survivor's parent beam.
    """
    beam = self.beam_size
    flat_lps = log_probs.view(batch_size, -1)
    # Ranking score: cumulative log prob divided by steps taken.
    flat_scores = flat_lps / (self.steps + 1)
    top_scores, rel_idx = torch.topk(flat_scores, k=beam, dim=1)

    # Clamp any out-of-range positions to 0 before using them as indices.
    out_of_range = (rel_idx >= flat_scores.size(1)) | (rel_idx < 0)
    rel_idx = rel_idx.masked_fill(out_of_range, 0)
    top_lps = flat_lps.gather(1, rel_idx)

    # Per-item offsets into the flat (batch * beam) beam axis and the flat
    # (batch * beam * beam) candidate axis.
    beam_offsets = (
        torch.arange(batch_size, device=top_lps.device) * beam
    ).view(batch_size, 1)
    candidate_offsets = beam_offsets * beam
    parent_beams = (rel_idx // beam + beam_offsets).view(-1)

    hypotheses = batch_size * beam
    top_lps = top_lps.view(1, hypotheses, 1)
    top_scores = top_scores.view(1, hypotheses, 1)

    flat_choice = (rel_idx + candidate_offsets).view(-1)
    chosen_tokens = candidates.view(-1)[flat_choice].view(1, -1)
    ones = torch.ones(hypotheses, dtype=torch.long, device=candidates.device)
    next_output = Variable(
        chosen_tokens,
        lengths=ones,
        length_dim=0,
        batch_dim=1,
        pad_value=self.vocab.pad_index,
    )
    return top_lps, top_scores, next_output, parent_beams
def init_state_context(self, encoder_state):
    """Build the initial state and context for multi-sample decoding.

    Each batch item is replicated ``self.samples`` times. The decoder state
    is tiled from ``encoder_state["state"]`` (unpacked as layers x batch x
    ...). The encoder output is repeated along its batch dimension via
    ``repeat_batch_dim``.

    Returns:
        ``(search_state, context)``. ``search_state`` holds ``"output"``
        (start indices) and ``"decoder_state"``; ``context`` holds
        ``"encoder_output"``.
    """
    enc_state = encoder_state["state"]
    batch_size = enc_state.size(1) * self.samples

    start_tokens = torch.LongTensor(
        [self.vocab.start_index] * batch_size).view(1, -1)
    output = Variable(
        start_tokens,
        lengths=torch.LongTensor([1] * batch_size),
        length_dim=0,
        batch_dim=1,
        pad_value=self.vocab.pad_index,
    )
    if str(enc_state.device) != "cpu":
        output = output.cuda(enc_state.device)

    layers = enc_state.size(0)
    tiled = enc_state.unsqueeze(2).repeat(1, 1, self.samples, 1)
    decoder_state = tiled.view(layers, batch_size, -1)

    search_state = {"output": output, "decoder_state": decoder_state}
    repeated_output = encoder_state["output"].repeat_batch_dim(self.samples)
    context = {"encoder_output": repeated_output}
    return search_state, context
def init_state(self, batch_size, encoder_state, encoder_input, device):
    """Build the initial beam-search state, including a vocabulary mask.

    If ``encoder_state`` is given (unpacked as layers x batch x features),
    it is tiled ``beam_size`` times per item as the decoder state.

    ``cons`` is a (batch_size * beam_size, len(vocab)) bool tensor. It is
    False for vocabulary entries that occur in the item's encoder input and
    True otherwise. The stop index and a fixed list of act/special/rating
    tokens are always set True.
    NOTE(review): how ``cons`` is consumed (True = blocked vs. allowed) is
    not visible here — confirm with the decoder.

    Returns:
        dict with optional ``decoder_state`` plus ``cons``, ``target``,
        ``inputs`` (same object as ``target``), ``accum_log_prob`` and
        ``terminal_mask``.
    """
    beam_state = {}
    hypotheses = batch_size * self.beam_size

    # Explicit None test: `if encoder_state:` raised "Boolean value of
    # Tensor with more than one element is ambiguous" for real states.
    if encoder_state is not None:
        n, bs, ds = encoder_state.size()
        assert batch_size == bs
        beam_state["decoder_state"] = encoder_state.unsqueeze(2) \
            .repeat(1, 1, self.beam_size, 1) \
            .view(n, hypotheses, ds)

    ei = encoder_input.permute_as_batch_sequence_features().data
    cons = torch.ones(batch_size, len(self.vocab), dtype=torch.bool,
                      device=ei.device)
    # Clear the entries for tokens present in each item's input.
    cons = cons.scatter_(1, ei, 0)
    cons = cons.unsqueeze(1).repeat(1, self.beam_size, 1).view(hypotheses, -1)
    cons[:, self.vocab.stop_index] = 1
    special_tokens = (
        "inform", "give_opinion", "verify_attribute", "request",
        "recommend", "request_explanation", "confirm", "suggest",
        "request_attribute", "<pad>", "<sos>", 'rating=N/A',
        'rating=average', 'rating=excellent', 'rating=good', 'rating=poor',
    )
    for v in special_tokens:
        if v in self.vocab:
            cons[:, self.vocab.index(v)] = 1
    beam_state['cons'] = cons

    target = Variable(
        torch.LongTensor([self.vocab.start_index] * hypotheses).view(1, -1),
        lengths=torch.LongTensor([1] * hypotheses),
        length_dim=0,
        batch_dim=1,
        pad_value=self.vocab.pad_index,
    ).to(device)
    beam_state["target"] = target
    beam_state['inputs'] = target

    # Start the first beam of each batch with 0 log prob, and all others
    # with -inf.
    beam_state["accum_log_prob"] = self._init_accum_log_probs(
        batch_size, device)

    # At the first time step no sequences have been terminated so this mask
    # is all False.
    beam_state["terminal_mask"] = torch.zeros(
        1, hypotheses, 1, dtype=torch.bool, device=device)
    return beam_state
def _make_classifier_inputs(self, batch):
    """Build padded, sequence-major classifier inputs from output sequences.

    For each sequence ``out``, the last element is dropped (``out[:-1]``).
    The rest is mapped through ``self.input_vocab`` and framed with its
    start and stop indices. Rows are right-padded with the pad index to the
    longest row, and the batch is moved to ``self.gpu`` when that is a
    non-negative index.

    Returns:
        A ``Variable`` of shape (max_len, len(batch)) with per-item lengths.
    """
    vocab = self.input_vocab
    clf_rows = [
        [vocab.start_index] + [vocab[t] for t in out[:-1]]
        + [vocab.stop_index]
        for out in batch
    ]
    # Take lengths from the rows actually built. The old `len(out) + 1`
    # matched for non-empty `out` but reported 1 for an empty one, whose
    # row still holds start + stop (2 entries).
    lens = torch.LongTensor([len(row) for row in clf_rows])
    max_len = lens.max().item()
    padded = [row + [vocab.pad_index] * (max_len - len(row))
              for row in clf_rows]
    clf_inputs = torch.LongTensor(padded).t()
    clf_inputs = Variable(clf_inputs, lengths=lens, batch_dim=1,
                          length_dim=0, pad_value=vocab.pad_index)
    if self.gpu > -1:
        return clf_inputs.cuda(self.gpu)
    return clf_inputs
def output(self, as_indices=False):
    """Return the search outputs as indices or as token lists.

    With ``as_indices`` the raw ``self._outputs`` (sequence x batch) are
    wrapped in a ``Variable``. Lengths count the unmasked steps in
    ``self._mask_T``, and the sequence mask is applied.

    Otherwise each batch item is decoded to tokens, stopping at the first
    pad or stop index.
    """
    if as_indices:
        step_counts = (~self._mask_T).sum(0)
        wrapped = Variable(
            self._outputs,
            lengths=step_counts,
            length_dim=0,
            batch_dim=1,
            pad_value=self.vocab.pad_index,
        )
        return wrapped.apply_sequence_mask()

    stop_at = [self.vocab.pad_index, self.vocab.stop_index]

    def _decode(column):
        # Collect tokens until the first pad/stop index.
        decoded = []
        for idx in column:
            if idx in stop_at:
                return decoded
            decoded.append(self.vocab.token(idx))
        return decoded

    return [_decode(column) for column in self._outputs.t()]
def init_state(self, batch_size, device, source_indices, encoder_state):
    """Build the initial state for pointer-network beam search.

    If ``encoder_state`` is given (unpacked as layers x batch x features),
    it is tiled ``beam_size`` times per item as the decoder state. Every
    hypothesis starts from the start index. The source mask is repeated
    per beam to form the initial pointer mask.

    Returns:
        dict with optional ``decoder_state`` plus ``target``, ``inputs``
        (same object as ``target``), ``pointer_mask``, ``accum_log_prob``
        and ``terminal_mask``.
    """
    beam_state = {}
    hypotheses = batch_size * self.beam_size

    # Explicit None test: `if encoder_state:` raised "Boolean value of
    # Tensor with more than one element is ambiguous" for real states.
    if encoder_state is not None:
        n, bs, ds = encoder_state.size()
        assert batch_size == bs
        beam_state["decoder_state"] = encoder_state.unsqueeze(2) \
            .repeat(1, 1, self.beam_size, 1) \
            .view(n, hypotheses, ds)

    target = Variable(
        torch.LongTensor([self.vocab.start_index] * hypotheses).view(1, -1),
        lengths=torch.LongTensor([1] * hypotheses),
        length_dim=0,
        batch_dim=1,
        pad_value=self.vocab.pad_index,
    ).to(device)
    beam_state["target"] = target
    beam_state['inputs'] = target

    source_indices = source_indices.permute_as_batch_sequence_features()
    # presumably mask is (batch, src_len); repeat once per beam ->
    # (batch * beam, src_len) — TODO confirm mask rank.
    mask = source_indices.mask.unsqueeze(1).repeat(1, self.beam_size, 1) \
        .contiguous().view(-1, source_indices.size(1))
    beam_state['pointer_mask'] = mask

    # Start the first beam of each batch with 0 log prob, and all others
    # with -inf.
    beam_state["accum_log_prob"] = self._init_accum_log_probs(
        batch_size, device)

    # At the first time step no sequences have been terminated so this mask
    # is all False.
    beam_state["terminal_mask"] = torch.zeros(
        1, hypotheses, 1, dtype=torch.bool, device=device)
    return beam_state
def next_state(self, decoder, prev_state, context, active_items, controls):
    """Advance greedy pointer decoding by one step.

    The argmax over the decoder's ``log_probs`` picks a source position per
    item. That position's token in ``context['source_indices']`` becomes
    the next target, and the position is marked in a copy of the pointer
    mask. Targets of inactive items are overwritten with the pad index,
    and the target is appended to the running ``inputs``.

    Returns:
        The decoder's next-state dict with ``target``, ``pointer_mask`` and
        ``inputs`` set. The caller uses this return value; the previous
        ``# return next_state`` left it returning None.
    """
    next_state = decoder.next_state(prev_state, context, controls=controls)

    # Greedy source position per item (a stray trailing backslash here
    # previously joined the following statement onto this line).
    next_pointer = next_state['log_probs'] \
        .permute_as_sequence_batch_features().max(2)[1]

    # source_indices data is transposed so rows are batch items.
    target = context['source_indices'].data.t().gather(
        1, next_pointer.data.t())
    target = prev_state['target'].new_with_meta(target.t())
    next_state['target'] = target

    # Copy before the in-place scatter. The previous state's mask may alias
    # the source Variable's mask (the greedy init_state stores
    # `source_indices.mask.t()`), and mutating it also rewrote stored states.
    pm = prev_state['pointer_mask'].clone()
    pm.scatter_(1, next_pointer.data.t(), 1)
    next_state['pointer_mask'] = pm

    # Mask outputs if we have already completed that batch item.
    next_state["target"].data.view(-1).masked_fill_(
        ~active_items, int(self.vocab.pad_index))

    next_inputs = torch.cat(
        [prev_state['inputs'].data, next_state['target'].data], 0)
    next_state['inputs'] = Variable(
        next_inputs,
        prev_state['inputs'].lengths + 1,
        length_dim=0,
        batch_dim=1,
        pad_value=prev_state['inputs'].pad_value,
    )
    return next_state
def __call__(self, decoder, encoder_state, controls=None):
    """Run greedy pointer decoding until every item stops or max_steps.

    Args:
        decoder: passed to ``self.next_state`` once per step.
        encoder_state: dict read for ``"batch_size"``, ``"device"``,
            ``"source_indices"``, ``"output"`` and, optionally, ``"state"``.
        controls: forwarded unchanged to ``self.next_state``.

    Returns:
        ``self``, after ``_collect_search_states`` has run. The per-step
        "already inactive" masks are stored in ``_mask_T`` (steps x batch)
        and ``_mask`` (batch x steps).
    """
    self.reset()
    batch_size = encoder_state["batch_size"]
    device = encoder_state["device"]
    search_state = self.init_state(batch_size, device,
                                   encoder_state['source_indices'],
                                   encoder_state.get("state", None))
    context = {
        "encoder_output": encoder_state["output"],
        "source_indices": encoder_state['source_indices'],
    }
    # Every item starts active with zero generated tokens.
    active_items = torch.BoolTensor(batch_size).fill_(1)
    output_length = torch.LongTensor(batch_size).fill_(0)
    if str(encoder_state["device"]) != "cpu":
        active_items = active_items.cuda(encoder_state["device"])
        output_length = output_length.cuda(encoder_state["device"])
    # Initial (start-token) target; prepended to every "prev_output" below.
    prev_decoder_inputs = search_state["target"]
    step_masks = []

    # Perform search until we either trigger a termination condition for
    # each batch item or we reach the maximum number of search steps.
    # NOTE(review): self.is_finished is read before being assigned here —
    # presumably reset() initialises it; the end of this method sets the
    # separate attribute self._is_finished.
    while self.steps < self.max_steps and not self.is_finished:

        inactive_items = ~active_items

        # Mask any inputs that are finished, so that greedy would
        # be identical to forward passes.
        search_state["target"].data.view(-1).masked_fill_(
            inactive_items, int(self.vocab.pad_index))
        step_masks.append(inactive_items)

        self.steps += 1
        search_state = self.next_state(
            decoder, search_state, context, active_items, controls)

        self._states.append(search_state)
        self._outputs.append(search_state["target"].clone()\
            .permute_as_sequence_batch_features())
        # Only items still active this step grow in length.
        output_length = output_length + active_items.long()
        # Rebuild the whole output so far (start token + every step) for
        # the termination check; lengths therefore count the start token.
        temp_outputs = torch.cat(
            [prev_decoder_inputs.tensor] + \
            [x.tensor for x in self._outputs], 0)
        temp_outputs = Variable(temp_outputs, lengths=output_length + 1,
                                batch_dim=1, length_dim=0,
                                pad_value=self.vocab.pad_index)
        search_state["prev_output"] = temp_outputs
        active_items = self.check_termination(search_state, active_items)
        self.is_finished = torch.all(~active_items)

    # Finish the search by collecting final sequences, and other
    # stats.
    self._collect_search_states(active_items)
    self._incomplete_items = active_items
    self._is_finished = True
    # One "already inactive" mask per step, stacked as (steps, batch).
    self._mask_T = torch.stack(step_masks)
    self._mask = self._mask_T.t().contiguous()
    return self
def next_state(self, decoder, batch_size, prev_state, context, active_items,
               controls):
    """Advance pointer-network beam search by one step.

    The top ``beam_size`` continuations of every beam are scored against
    the beams' cumulative log-probs, and the best ``beam_size`` per item
    are kept via ``self._next_candidates``. The chosen source positions are
    turned into target tokens, and pointer masks and running inputs are
    reordered to follow the surviving beams.

    ``active_items`` is accepted for interface compatibility but unused.

    Returns:
        dict with ``decoder_state``, ``target``, ``accum_log_prob``,
        ``beam_score``, ``beam_indices``, ``inputs`` and ``pointer_mask``.
    """
    next_state = decoder.next_state(prev_state, context, controls=controls)

    # topk_lps, candidate_outputs: (1, batch size, beam size, beam size)
    topk_lps, candidate_outputs = torch.topk(
        next_state["log_probs"].data
        .view(1, batch_size, self.beam_size, -1),
        k=self.beam_size, dim=3)

    # Sequences completed last step get -inf so they are not extended past
    # the terminal token. slp: (1, batch size, beam size, 1)
    slp = prev_state["accum_log_prob"] \
        .masked_fill(prev_state["terminal_mask"], float("-inf")) \
        .view(1, batch_size, self.beam_size, 1)

    # log P(y_<=t) = log P(y_<t) + log P(y_t | y_<t)
    candidate_log_probs = slp + topk_lps

    # Rerank beam_size ** 2 candidates down to beam_size per item.
    # b_seq_lps, b_scores: (1, batch * beam, 1); b_pointer: (1, batch * beam)
    # b_indices: (batch * beam,) parent-beam indices.
    b_seq_lps, b_scores, b_pointer, b_indices = self._next_candidates(
        batch_size, candidate_log_probs, candidate_outputs)

    # Clamp pointer positions outside the source range to 0.
    b_pointer_mask = (
        (b_pointer.data >= context['source_indices'].size(1))
        | (b_pointer.data < 0)
    )
    b_pointer = b_pointer.masked_fill(b_pointer_mask, 0)

    b_output = context['source_indices'].gather(1, b_pointer.data.t())
    b_output = b_pointer.new_with_meta(b_output.t())

    # index_select copies, so the scatter below leaves prev_state intact.
    pm = prev_state['pointer_mask'].index_select(0, b_indices)
    pm.scatter_(1, b_pointer.data.t(), 1)

    # TODO: re-implement staged batch indexing
    # (formerly `next_state.stage_indexing("batch", b_indices)`).

    next_inputs = torch.cat(
        [
            prev_state['inputs'].index_select(1, b_indices).data,
            b_output.data,
        ], 0)
    next_inputs = Variable(
        next_inputs,
        prev_state['inputs'].lengths + 1,
        length_dim=0,
        batch_dim=1,
        pad_value=prev_state['inputs'].pad_value,
    )

    return {
        "decoder_state": next_state["decoder_state"]
        .index_select(1, b_indices),
        "target": b_output,
        "accum_log_prob": b_seq_lps,
        "beam_score": b_scores,
        "beam_indices": b_indices,
        "inputs": next_inputs,
        'pointer_mask': pm,
    }