def test_argsort(self):
    """Check argsort on plain Python lists and on a LongTensor of keys."""
    sort_keys = [5, 4, 3, 2, 1]
    words = ["five", "four", "three", "two", "one"]
    letters = ["e", "d", "c", "b", "a"]

    # The keys are strictly decreasing, so the default (ascending) order
    # hands every list back reversed.
    ascending = argsort(sort_keys, words, letters)
    assert ascending == [words[::-1], letters[::-1]]

    # Descending order matches the order the lists were given in.
    descending = argsort(sort_keys, words, letters, descending=True)
    assert descending == [words, letters]

    # A tensor sorted by itself comes back as a tensor holding 1..5.
    key_tensor = torch.LongTensor(sort_keys)
    sorted_tensor = argsort(key_tensor, key_tensor)[0]
    assert np.all(sorted_tensor.numpy() == np.arange(1, 6))
def batchify(self, *args, **kwargs):
    """Batchify for seq2seq, adding control variables and history.

    Sorting is always switched on, because the batch is later packed with
    pack_padded. The parent class builds the regular batch; the example
    ordering is then recomputed here (same steps as the parent's text
    handling) so that the history objects and the control vector are built
    from the examples in the order they are sorted into.

    :param args: positional arguments for the parent ``batchify``;
        ``args[0]`` is the observation batch.
    :param kwargs: keyword arguments for the parent ``batchify``; ``sort``
        is overwritten with True.
    :returns: an empty ``Batch()`` when there is nothing valid to batch,
        otherwise a namedtuple with the parent batch's fields plus
        ``ctrl_vec`` and ``history``.
    """
    kwargs['sort'] = True  # pack_padded needs length-sorted input
    base_batch = super().batchify(*args, **kwargs)

    observations = args[0]
    want_sorted = kwargs['sort']

    # ---- Reproduce the example ordering used by the parent batchify ----
    if not len(observations):
        return Batch()

    kept = [
        (idx, ex)
        for idx, ex in enumerate(observations)
        if 'text_vec' in ex or 'image' in ex  # parent's default validity test
    ]
    if not kept:
        return Batch()
    kept_inds, exs = zip(*kept)

    has_text = any('text_vec' in ex for ex in exs)
    if has_text:
        text_vecs = [ex.get('text_vec', self.EMPTY) for ex in exs]
        padded, lengths = padded_tensor(
            text_vecs, self.NULL_IDX, self.use_cuda
        )
        if want_sorted:
            # Only the reordered examples are needed from here on.
            *_, exs = argsort(
                lengths, padded, lengths, kept_inds, exs, descending=True
            )
    # ---- End of the reproduced ordering logic ----

    # One history object per example, in batch order.
    history = []
    for ex in exs:
        history.append(ConvAI2History(ex['text'], dictionary=self.dict))

    # Control variables for the batch; this is a tensor or None.
    ctrl_vec = get_ctrl_vec(exs, history, self.control_settings)
    if ctrl_vec is not None and self.use_cuda:
        ctrl_vec = ctrl_vec.cuda()

    # Swap the parent's batch for one that also carries the new fields.
    field_names = tuple(base_batch.keys()) + ('ctrl_vec', 'history')
    ControlBatch = namedtuple('Batch', field_names)
    return ControlBatch(
        ctrl_vec=ctrl_vec, history=history, **dict(base_batch)
    )
def batchify(self, obs_batch, sort=False,
             is_valid=lambda obs: 'text_vec' in obs or 'image' in obs):
    """Create a batch of valid observations from an unchecked batch.

    A valid observation is one that passes the lambda provided to the
    function, which defaults to checking if the preprocessed 'text_vec'
    field or an 'image' field is present in the observation.

    Returns a namedtuple Batch. See the original definition of Batch for an
    in-depth explanation of each field.

    If you want to include additional fields in the batch, you can subclass
    this function and return your own "Batch" namedtuple: copy the Batch
    namedtuple, and then add whatever additional fields you want to be able
    to access. You can then call super().batchify(...) to set up the
    original fields and then set up the additional fields in your subclass
    and return that batch instead.

    :param obs_batch: List of vectorized observations
    :param sort: Default False, orders the observations by length of
        vectors. Set to true when using
        torch.nn.utils.rnn.pack_padded_sequence. Uses the text vectors if
        available, otherwise uses the label vectors if available. Every
        per-example field of the returned batch follows the same order.
    :param is_valid: Function that determines if an observation is valid;
        by default checks whether 'text_vec' or 'image' is in the
        observation.
    :returns: a Batch; an empty ``Batch()`` when ``obs_batch`` is empty or
        contains no valid observation.
    """
    if len(obs_batch) == 0:
        return Batch()

    valid_obs = [(i, ex) for i, ex in enumerate(obs_batch) if is_valid(ex)]
    if len(valid_obs) == 0:
        return Batch()

    valid_inds, exs = zip(*valid_obs)

    # TEXT
    xs, x_lens = None, None
    if any('text_vec' in ex for ex in exs):
        _xs = [ex.get('text_vec', self.EMPTY) for ex in exs]
        xs, x_lens = padded_tensor(_xs, self.NULL_IDX, self.use_cuda)
        if sort:
            sort = False  # now we won't sort on labels
            xs, x_lens, valid_inds, exs = argsort(
                x_lens, xs, x_lens, valid_inds, exs, descending=True
            )

    # LABELS
    labels_avail = any('labels_vec' in ex for ex in exs)
    some_labels_avail = (
        labels_avail or any('eval_labels_vec' in ex for ex in exs)
    )

    ys, y_lens, labels = None, None, None
    if some_labels_avail:
        field = 'labels' if labels_avail else 'eval_labels'
        label_vecs = [ex.get(field + '_vec', self.EMPTY) for ex in exs]
        labels = [ex.get(field + '_choice') for ex in exs]
        ys, y_lens = padded_tensor(label_vecs, self.NULL_IDX, self.use_cuda)
        if sort and xs is None:
            # No text to sort on, so sort on label length instead. The
            # examples are reordered too, so that the candidates, images,
            # memories and observations collected below stay aligned with
            # label_vec and valid_indices.
            ys, valid_inds, label_vecs, labels, y_lens, exs = argsort(
                y_lens, ys, valid_inds, label_vecs, labels, y_lens, exs,
                descending=True
            )

    # LABEL_CANDIDATES
    cands, cand_vecs = None, None
    if any('label_candidates_vecs' in ex for ex in exs):
        cands = [ex.get('label_candidates', None) for ex in exs]
        cand_vecs = [ex.get('label_candidates_vecs', None) for ex in exs]

    # IMAGE
    imgs = None
    if any('image' in ex for ex in exs):
        imgs = [ex.get('image', None) for ex in exs]

    # MEMORIES
    mems = None
    if any('memory_vecs' in ex for ex in exs):
        mems = [ex.get('memory_vecs', None) for ex in exs]

    return Batch(text_vec=xs, text_lengths=x_lens, label_vec=ys,
                 label_lengths=y_lens, labels=labels,
                 valid_indices=valid_inds, candidates=cands,
                 candidate_vecs=cand_vecs, image=imgs, memory_vecs=mems,
                 observations=exs)
def forward(self, encoder_output, his_turn_end_ids):
    """Compute the turn-ordering loss (``dli_loss``) over history turns.

    Each history turn is represented by the mean of its encoder states.
    The turn vectors go through ``self.uni_lstm``; for every turn ``j`` the
    LSTM output at ``j`` is concatenated with the mean state of every later
    turn ``k > j``, each pair is scored by ``self.con_fc``, and
    ``self.c_loss`` is applied with the immediately following turn
    (index 0 of the candidates) as the target.

    :param encoder_output: encoder states, indexed as
        ``encoder_output[i][start:end + 1]`` and averaged over dim 0;
        presumably (bsz x seqlen x enc_dim) -- TODO confirm.
    :param his_turn_end_ids: for each example, a 1-D integer tensor with the
        index of the last token of every history turn. Turn ``j`` starts one
        token after turn ``j - 1`` ends; the first turn starts at 0.
    :returns: scalar tensor, the mean loss over all scored turns; zero when
        no example has more than one turn.

    Tensors are created with ``.cuda()``, so this method requires CUDA.
    """
    bsz = encoder_output.size(0)
    turn_lengths = [len(his_turn_end_ids[i]) for i in range(bsz)]
    his_max_len = max(turn_lengths)
    if his_max_len <= 1:
        # With at most one turn there is no (current, later) pair to score.
        return torch.zeros(1)[0].cuda()

    # Mean encoder state of every turn; missing turns stay zero.
    his_turn_states = torch.zeros(bsz, his_max_len, self.enc_dim).cuda()
    for i in range(bsz):
        # Plain Python ints keep the slice bounds integral. Adding a float
        # tensor to the integer end ids would promote them to float, which
        # cannot be used as a slice index.
        ends = his_turn_end_ids[i].tolist()
        starts = [0] + [end + 1 for end in ends[:-1]]
        for j, (s, e) in enumerate(zip(starts, ends)):
            his_turn_states[i][j] = torch.mean(
                encoder_output[i][s:e + 1], dim=0)

    # Order examples by number of turns (descending), run the LSTM, then
    # restore the original example order.
    # NOTE(review): pack_sequence receives rows that all have length
    # his_max_len, so the zero rows of shorter histories are also fed to
    # the LSTM -- confirm this is intended.
    in_batch_ids = list(range(bsz))
    sorted_his_turn_states, sorted_in_batch_ids = argsort(
        turn_lengths, his_turn_states, in_batch_ids, descending=True)
    his_turn_states_packed = nn.utils.rnn.pack_sequence(
        sorted_his_turn_states)
    out_packed, _ = self.uni_lstm(his_turn_states_packed)
    out_padded, _ = pad_packed_sequence(out_packed, batch_first=True)
    after_sort_idxs = torch.LongTensor(
        argsort(sorted_in_batch_ids, in_batch_ids,
                descending=False)[0]).cuda()
    turns_encoder_out = torch.index_select(out_padded, 0, after_sort_idxs)

    # For each turn j, pair its LSTM output with every later turn's state.
    # The candidates are in order k = j + 1, j + 2, ..., so the true next
    # turn is always candidate 0.
    all_pairs = []
    for i in range(bsz):
        for j in range(turn_lengths[i]):
            current_step_encoder_out = turns_encoder_out[i][j]
            later_pairs = [
                torch.cat(
                    [current_step_encoder_out, his_turn_states[i][k]], -1)
                for k in range(j + 1, turn_lengths[i])
            ]
            if later_pairs:
                all_pairs.append(torch.stack(later_pairs))

    # Same target for every group, so build it once.
    ground_truth = torch.LongTensor([0]).cuda()
    losses = [
        self.c_loss(
            input=self.con_fc(pairs).squeeze(1).unsqueeze(0),
            target=ground_truth)
        for pairs in all_pairs
    ]
    return torch.mean(torch.stack(losses))
def forward(self, input, his_turn_end_ids):
    """Encode dialogue-history turns, then run the stacked layers over them.

    The flat input is cut into turns, each turn is encoded with
    ``self.rnn`` and reduced to one vector, position embeddings are added
    per turn, and the result is passed through ``self.layers``.

    :param input: batch of token sequences, sliced as
        ``input[i][start:end + 1]`` and copied into a long tensor that is fed
        to ``self.embeddings``; presumably a (bsz x seqlen) LongTensor of
        token ids with 0 as padding -- TODO confirm.
    :param his_turn_end_ids: for each example, a 1-D integer tensor with the
        index of the last token of every history turn. Turn ``j`` starts one
        token after turn ``j - 1`` ends; the first turn starts at 0.
    :returns: if ``self.reduction`` is set, the mask-weighted mean over
        turns; otherwise a ``(tensor, mask)`` pair where ``mask`` is 1 for
        real turns and 0 for padding turns.

    Tensors are created with ``.cuda()``, so this method requires CUDA.
    """
    bsz = len(input)
    his_turns = torch.zeros(bsz, self.max_turns,
                            self.max_single_seq_len).long().cuda()
    mask = torch.zeros(bsz, self.max_turns).cuda()
    for i in range(bsz):
        # Plain Python ints keep the slice bounds integral. Adding a float
        # tensor to the integer end ids would promote them to float, which
        # cannot be used as a slice index.
        ends = his_turn_end_ids[i].tolist()
        starts = [0] + [end + 1 for end in ends[:-1]]
        turns = list(zip(starts, ends))
        # When there are more turns than slots, keep the most recent ones.
        turns = turns[max(0, len(turns) - self.max_turns):]
        for slot, (s, e) in enumerate(turns):
            # Turns longer than max_single_seq_len are cut at the end.
            width = min(e + 1 - s, self.max_single_seq_len)
            his_turns[i][slot][0:width] = input[i][s:s + width]
            mask[i][slot] = 1
        for slot in range(len(turns), self.max_turns):
            # Padding turns get a single token 1 so that their length is
            # non-zero when the sequences are packed below.
            his_turns[i][slot][0] = 1

    # Encode all turns of all examples as one flat batch.
    his_turns = his_turns.view(-1, self.max_single_seq_len)
    xs = self.rnn_input_dropout(his_turns)
    xes = self.rnn_dropout(self.embeddings(xs))
    attn_mask = xs.ne(0)
    x_lens = torch.sum(attn_mask.int(), dim=1)

    # Sort by length for packing, remembering how to undo the sort.
    in_flatten_ids = list(range(len(xs)))
    sorted_xes, sorted_in_flatten_ids, sorted_x_lens = argsort(
        x_lens, xes, in_flatten_ids, x_lens, descending=True)
    # pack_padded_sequence requires a lengths tensor to live on the CPU.
    xes_packed = pack_padded_sequence(sorted_xes, sorted_x_lens.cpu(),
                                      batch_first=True)
    out_packed, _ = self.rnn(xes_packed)
    out_padded, _ = pad_packed_sequence(out_packed, batch_first=True)
    after_sort_idxs = torch.LongTensor(
        argsort(sorted_in_flatten_ids, in_flatten_ids,
                descending=False)[0]).cuda()
    his_encoder_outs = torch.index_select(out_padded, 0, after_sort_idxs)

    real_max_seq_len = his_encoder_outs.size(1)
    his_encoder_outs = his_encoder_outs.view(bsz, self.max_turns,
                                             real_max_seq_len, self.rnn_hsz)
    # Zero out every state that belongs to a padding turn.
    his_encoder_outs = his_encoder_outs.mul(
        mask.view(bsz, self.max_turns, 1, 1))

    # Represent each turn by its last non-zero RNN state.
    # NOTE(review): for a padding turn the first row is already zero, so
    # ``tmp`` keeps the vector of the previous turn; those positions are
    # multiplied by the mask further down.
    final_encoder_outs = []
    for i in range(bsz):
        for j in range(self.max_turns):
            for k in range(real_max_seq_len):
                if len(torch.nonzero(his_encoder_outs[i][j][k])) != 0:
                    tmp = his_encoder_outs[i][j][k]
                else:
                    break
            final_encoder_outs.append(tmp)
    final_encoder_outs = torch.stack(final_encoder_outs).view(
        bsz, self.max_turns, -1)

    positions = mask.new(self.max_turns).long()
    positions = torch.arange(self.max_turns, out=positions).unsqueeze(0)
    tensor = final_encoder_outs
    if self.embeddings_scale:
        tensor = tensor * np.sqrt(self.dim)
    tensor = tensor + self.position_embeddings(positions).expand_as(tensor)
    tensor *= mask.unsqueeze(-1).float()
    for i in range(self.n_layers):
        tensor = self.layers[i](tensor, mask)

    if self.reduction:
        # Mean over real turns only; the clamp avoids division by zero.
        divisor = mask.float().sum(dim=1).unsqueeze(-1).clamp(min=1e-20)
        output = tensor.sum(dim=1) / divisor
        return output
    else:
        output = tensor
        return output, mask
def eval_step(self, batch):
    """Process batch of inputs.

    If the batch includes labels, calculate validation metrics as well.
    If --skip-generation is not set, return a prediction for each input.
    If candidates are provided, they are scored and returned in ranked
    order; when generation is skipped the top-ranked candidate is used as
    the prediction.

    :param batch: parlai.core.torch_agent.Batch, contains tensorized
        version of observations.
    :returns: None when the batch has no text, otherwise
        ``Output(generated_output, reranked_cands)``.
    """
    if batch.text_vec is None:
        return
    self.is_training = False
    samples = self._make_sample(batch.text_vec, batch.label_vec)
    self.model.eval()
    if batch.label_vec is not None:
        # Interactive mode won't have a gold label
        self.trainer.valid_step(samples)

    # Output placeholders
    reranked_cands = None
    generated_output = None

    # Grade each of the candidate sequences
    if batch.candidate_vecs is not None:
        bsz = len(batch.text_vec)
        reranked_cands = []
        # score the candidates for each item in the batch separately, so
        # that we can support variable number of candidates
        for i in range(bsz):
            cands = batch.candidate_vecs[i]
            if not cands:
                reranked_cands.append(None)
                continue
            ncand = len(cands)
            # repeat the input once per candidate
            xs = batch.text_vec[i].unsqueeze(0).expand(ncand, -1)
            # cut the input at its own length so that no column is
            # padding on every row; some models crash otherwise
            xs = xs[:, :batch.text_lengths[i]]
            # and appropriately pack the outputs
            ys, _ = padded_tensor(cands, self.NULL_IDX, self.use_cuda)
            sample = self._make_sample(xs, ys)
            # perform the actual grading, extract the scores
            scored = list(
                self.scorer.score_batched_itr([sample], cuda=self.use_cuda))
            scores = [result[3][0]['score'].item() for result in scored]
            # intentional hanging comma here; argsort returns a list
            ranked, = argsort(scores, batch.candidates[i], descending=True)
            reranked_cands.append(ranked)

    # Next generate freely to create our response
    if not self.args.skip_generation:
        generated_output = self._generate(samples)
    elif reranked_cands:
        # we're skipping generation, but we're also grading candidates,
        # so output the highest ranked candidate. In the case of zero
        # candidates we don't have something to rank, so we pass on None.
        # A conditional expression is used so that a falsy top candidate
        # (e.g. an empty string) is still returned rather than None.
        generated_output = [
            ranked[0] if ranked else None for ranked in reranked_cands
        ]

    return Output(generated_output, reranked_cands)
def forward(self, input, his_turn_end_ids):
    """Encode the dialogue history hierarchically.

    The flat input is cut into turns; every turn is encoded with
    ``self.rnn`` and reduced to one vector, and the sequence of turn
    vectors is then encoded with ``self.hier_rnn``.

    :param input: batch of token sequences, sliced as
        ``input[i][start:end + 1]`` and copied into a long tensor that is fed
        to ``self.lt``; presumably a (bsz x seqlen) LongTensor of input
        token indices with 0 as padding -- TODO confirm.
    :param his_turn_end_ids: for each example, a 1-D integer tensor with the
        index of the last token of every history turn. Turn ``j`` starts one
        token after turn ``j - 1`` ends; the first turn starts at 0.

    :returns: turn-level encoder outputs, hidden state, attention mask.
        The encoder outputs are the outputs of ``self.hier_rnn`` at each
        turn. The hidden state is the final hidden state of
        ``self.hier_rnn`` with batch as the first dimension (a pair of
        tensors for an LSTM), with the two directions summed when
        ``self.hier_dirs > 1``. The attention mask is 1 for real turns and
        0 for padding turns.

    Tensors are created with ``.cuda()``, so this method requires CUDA.
    """
    bsz = len(input)
    his_turns = torch.zeros(bsz, self.max_turns,
                            self.max_single_seq_len).long().cuda()
    mask = torch.zeros(bsz, self.max_turns).cuda()
    for i in range(bsz):
        # Plain Python ints keep the slice bounds integral. Adding a float
        # tensor to the integer end ids would promote them to float, which
        # cannot be used as a slice index.
        ends = his_turn_end_ids[i].tolist()
        starts = [0] + [end + 1 for end in ends[:-1]]
        turns = list(zip(starts, ends))
        # When there are more turns than slots, keep the most recent ones.
        turns = turns[max(0, len(turns) - self.max_turns):]
        for slot, (s, e) in enumerate(turns):
            # Turns longer than max_single_seq_len are cut at the end.
            width = min(e + 1 - s, self.max_single_seq_len)
            his_turns[i][slot][0:width] = input[i][s:s + width]
            mask[i][slot] = 1
        for slot in range(len(turns), self.max_turns):
            # Padding turns get a single token 1 so that their length is
            # non-zero when the sequences are packed below.
            his_turns[i][slot][0] = 1

    # ---- Token level: encode all turns of all examples as one batch ----
    his_turns = his_turns.view(-1, self.max_single_seq_len)
    xs = self.input_dropout(his_turns)
    xes = self.dropout(self.lt(xs))
    attn_mask = xs.ne(0)
    x_lens = torch.sum(attn_mask.int(), dim=1)

    # Sort by length for packing, remembering how to undo the sort.
    in_flatten_ids = list(range(len(xs)))
    sorted_xes, sorted_in_flatten_ids, sorted_x_lens = argsort(
        x_lens, xes, in_flatten_ids, x_lens, descending=True)
    # pack_padded_sequence requires a lengths tensor to live on the CPU.
    xes_packed = pack_padded_sequence(sorted_xes, sorted_x_lens.cpu(),
                                      batch_first=True)
    out_packed, _ = self.rnn(xes_packed)
    out_padded, _ = pad_packed_sequence(out_packed, batch_first=True)
    after_sort_idxs = torch.LongTensor(
        argsort(sorted_in_flatten_ids, in_flatten_ids,
                descending=False)[0]).cuda()
    his_encoder_outs = torch.index_select(out_padded, 0, after_sort_idxs)

    real_max_seq_len = his_encoder_outs.size(1)
    his_encoder_outs = his_encoder_outs.view(bsz, self.max_turns,
                                             real_max_seq_len, self.hsz)
    # Zero out every state that belongs to a padding turn.
    his_encoder_outs = his_encoder_outs.mul(
        mask.view(bsz, self.max_turns, 1, 1))

    # Represent each turn by its last non-zero RNN state.
    # NOTE(review): for a padding turn the first row is already zero, so
    # ``tmp`` keeps the vector of the previous turn; padding turns are
    # excluded from the turn-level RNN by the lengths used for packing.
    final_encoder_outs = []
    for i in range(bsz):
        for j in range(self.max_turns):
            for k in range(real_max_seq_len):
                if len(torch.nonzero(his_encoder_outs[i][j][k])) != 0:
                    tmp = his_encoder_outs[i][j][k]
                else:
                    break
            final_encoder_outs.append(tmp)
    final_encoder_outs = torch.stack(final_encoder_outs).view(
        bsz, self.max_turns, -1)

    # ---- Turn level: encode the sequence of turn vectors ----
    hier_xes = final_encoder_outs
    hier_x_lens = torch.sum(mask.int(), dim=1)
    in_example_ids = list(range(len(hier_xes)))
    sorted_hier_xes, sorted_in_example_ids, sorted_hier_x_lens = argsort(
        hier_x_lens, hier_xes, in_example_ids, hier_x_lens, descending=True)
    hier_xes_packed = pack_padded_sequence(sorted_hier_xes,
                                           sorted_hier_x_lens.cpu(),
                                           batch_first=True)
    hier_out_packed, hier_hidden_packed = self.hier_rnn(hier_xes_packed)
    hier_out_padded, _ = pad_packed_sequence(hier_out_packed,
                                             batch_first=True)
    hier_after_sort_idxs = torch.LongTensor(
        argsort(sorted_in_example_ids, in_example_ids,
                descending=False)[0]).cuda()
    hier_his_encoder_outs = torch.index_select(hier_out_padded, 0,
                                               hier_after_sort_idxs)
    real_max_his_n_turn = hier_his_encoder_outs.size(1)
    hier_final_encoder_outs = hier_his_encoder_outs.view(
        bsz, real_max_his_n_turn, -1)

    def _unsort_and_merge_dirs(state):
        # ``state`` has batch as dim 0 (it is indexed with the batch
        # permutation), so dim 1 enumerates layer/direction pairs.
        state = torch.index_select(state, 0, hier_after_sort_idxs)
        if self.hier_dirs > 1:
            # project to decoder dimension by taking sum of forward and
            # back: [bsz, layers * dirs, hsz] -> [bsz, layers, hsz]
            state = state.view(
                bsz, -1, self.hier_dirs, self.hier_hsz).sum(2)
        return state

    transpose_hidden = _transpose_hidden_state(hier_hidden_packed)
    if isinstance(self.hier_rnn, nn.LSTM):
        # An LSTM yields a (hidden, cell) pair; treat each tensor alone.
        hier_final_hidden = (_unsort_and_merge_dirs(transpose_hidden[0]),
                             _unsort_and_merge_dirs(transpose_hidden[1]))
    else:
        hier_final_hidden = _unsort_and_merge_dirs(transpose_hidden)

    # The turn mask, cut to the longest history that was actually encoded.
    hier_attn_mask = mask[:, :real_max_his_n_turn].clone()

    return hier_final_encoder_outs, hier_final_hidden, hier_attn_mask