Example #1
    def test_argsort(self):
        keys = [5, 4, 3, 2, 1]
        items = ["five", "four", "three", "two", "one"]
        items2 = ["e", "d", "c", "b", "a"]
        torch_keys = torch.LongTensor(keys)
        assert argsort(keys, items, items2) == [
            list(reversed(items)), list(reversed(items2))
        ]
        assert argsort(keys, items, items2, descending=True) == [items, items2]

        assert np.all(argsort(torch_keys, torch_keys)[0].numpy() == np.arange(1, 6))
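The `argsort` helper exercised above is not shown on this page. A minimal sketch consistent with these assertions (the actual ParlAI helper may differ in details): it reorders any number of lists or tensors by the sort order of `keys` and returns them as a list.

    import torch

    def argsort(keys, *lists, descending=False):
        """Reorder each element of `lists` by the sort order of `keys`."""
        ind_sorted = sorted(range(len(keys)), key=lambda k: keys[k])
        if descending:
            ind_sorted = list(reversed(ind_sorted))
        output = []
        for lst in lists:
            if isinstance(lst, torch.Tensor):
                output.append(lst[ind_sorted])
            else:
                output.append([lst[i] for i in ind_sorted])
        return output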
Example #2
    def batchify(self, *args, **kwargs):
        """Override batchify options for seq2seq."""
        kwargs['sort'] = True  # need sorted for pack_padded
        batch = super().batchify(*args, **kwargs)

        # Get some args needed for batchify
        obs_batch = args[0]
        sort = kwargs['sort']
        # from TorchAgent.batchify
        is_valid = lambda obs: 'text_vec' in obs or 'image' in obs

        # Run this part of TorchAgent's batchify to get exs in correct order

        # ==================== START COPIED FROM TORCHAGENT ===================
        if len(obs_batch) == 0:
            return Batch()

        valid_obs = [(i, ex) for i, ex in enumerate(obs_batch) if is_valid(ex)]

        if len(valid_obs) == 0:
            return Batch()

        valid_inds, exs = zip(*valid_obs)

        # TEXT
        xs, x_lens = None, None
        if any('text_vec' in ex for ex in exs):
            _xs = [ex.get('text_vec', self.EMPTY) for ex in exs]
            xs, x_lens = padded_tensor(_xs, self.NULL_IDX, self.use_cuda)
            if sort:
                sort = False  # now we won't sort on labels
                xs, x_lens, valid_inds, exs = argsort(x_lens,
                                                      xs,
                                                      x_lens,
                                                      valid_inds,
                                                      exs,
                                                      descending=True)

        # ======== END COPIED FROM TORCHAGENT ========

        # Add history to the batch
        history = [
            ConvAI2History(ex['text'], dictionary=self.dict) for ex in exs
        ]

        # Add CT control vars to batch
        ctrl_vec = get_ctrl_vec(exs, history,
                                self.control_settings)  # tensor or None
        if self.use_cuda and ctrl_vec is not None:
            ctrl_vec = ctrl_vec.cuda()

        # Replace the old namedtuple with a new one that includes ctrl_vec and history
        ControlBatch = namedtuple(
            'Batch',
            tuple(batch.keys()) + ('ctrl_vec', 'history'))
        batch = ControlBatch(ctrl_vec=ctrl_vec, history=history, **dict(batch))

        return batch
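The namedtuple-extension trick at the end of this example is worth isolating. A minimal sketch with toy fields (the code above assumes the base `Batch` is dict-like, so `batch.keys()` and `dict(batch)` work; with a plain namedtuple you would use `_fields` and `_asdict()` instead):

    from collections import namedtuple

    Base = namedtuple('Batch', ('text_vec', 'label_vec'))
    batch = Base(text_vec=[1, 2], label_vec=[3])

    # Rebuild the namedtuple type with two extra fields, then re-wrap the
    # old batch's values alongside the new ones.
    ControlBatch = namedtuple(
        'Batch', tuple(batch._fields) + ('ctrl_vec', 'history'))
    batch = ControlBatch(ctrl_vec=None, history=[], **batch._asdict())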
Example #3
    def batchify(self, obs_batch, sort=False,
                 is_valid=lambda obs: 'text_vec' in obs or 'image' in obs):
        """Create a batch of valid observations from an unchecked batch.

        A valid observation is one that passes the lambda provided to the
        function, which defaults to checking whether the preprocessed
        'text_vec' field is present (it would have been set by this agent's
        'vectorize' function).

        Returns a namedtuple Batch. See original definition above for in-depth
        explanation of each field.

        If you want to include additional fields in the batch, you can subclass
        this function and return your own "Batch" namedtuple: copy the Batch
        namedtuple at the top of this class, and then add whatever additional
        fields that you want to be able to access. You can then call
        super().batchify(...) to set up the original fields and then set up the
        additional fields in your subclass and return that batch instead.

        :param obs_batch: List of vectorized observations
        :param sort:      Default False; orders the observations by length of
                          vectors. Set to True when using
                          torch.nn.utils.rnn.pack_padded_sequence.
                          Uses the text vectors if available, otherwise uses
                          the label vectors if available.
        :param is_valid:  Function that determines whether an observation is
                          valid; the default checks for 'text_vec' or 'image'
        """
        if len(obs_batch) == 0:
            return Batch()

        valid_obs = [(i, ex) for i, ex in enumerate(obs_batch) if is_valid(ex)]

        if len(valid_obs) == 0:
            return Batch()

        valid_inds, exs = zip(*valid_obs)

        # TEXT
        xs, x_lens = None, None
        if any('text_vec' in ex for ex in exs):
            _xs = [ex.get('text_vec', self.EMPTY) for ex in exs]
            xs, x_lens = padded_tensor(_xs, self.NULL_IDX, self.use_cuda)
            if sort:
                sort = False  # now we won't sort on labels
                xs, x_lens, valid_inds, exs = argsort(
                    x_lens, xs, x_lens, valid_inds, exs, descending=True
                )

        # LABELS
        labels_avail = any('labels_vec' in ex for ex in exs)
        some_labels_avail = (labels_avail or
                             any('eval_labels_vec' in ex for ex in exs))

        ys, y_lens, labels = None, None, None
        if some_labels_avail:
            field = 'labels' if labels_avail else 'eval_labels'

            label_vecs = [ex.get(field + '_vec', self.EMPTY) for ex in exs]
            labels = [ex.get(field + '_choice') for ex in exs]

            ys, y_lens = padded_tensor(label_vecs, self.NULL_IDX, self.use_cuda)
            if sort and xs is None:
                ys, valid_inds, label_vecs, labels, y_lens = argsort(
                    y_lens, ys, valid_inds, label_vecs, labels, y_lens,
                    descending=True
                )

        # LABEL_CANDIDATES
        cands, cand_vecs = None, None
        if any('label_candidates_vecs' in ex for ex in exs):
            cands = [ex.get('label_candidates', None) for ex in exs]
            cand_vecs = [ex.get('label_candidates_vecs', None) for ex in exs]

        # IMAGE
        imgs = None
        if any('image' in ex for ex in exs):
            imgs = [ex.get('image', None) for ex in exs]

        # MEMORIES
        mems = None
        if any('memory_vecs' in ex for ex in exs):
            mems = [ex.get('memory_vecs', None) for ex in exs]

        return Batch(text_vec=xs, text_lengths=x_lens, label_vec=ys,
                     label_lengths=y_lens, labels=labels,
                     valid_indices=valid_inds, candidates=cands,
                     candidate_vecs=cand_vecs, image=imgs, memory_vecs=mems,
                     observations=exs)
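`padded_tensor` is a ParlAI utility used throughout these examples. A rough sketch of its contract (the real helper has more options, e.g. left padding): it pads a list of 1-D LongTensors into an `n x max_len` matrix filled with `pad_idx` and also returns the original lengths.

    import torch

    def padded_tensor(items, pad_idx=0, use_cuda=False):
        # pad a list of 1-D LongTensors into an (n x max_len) matrix
        lens = [len(item) for item in items]
        max_len = max(lens)
        output = torch.full((len(items), max_len), pad_idx, dtype=torch.long)
        for i, (item, length) in enumerate(zip(items, lens)):
            output[i, :length] = item
        return (output.cuda() if use_cuda else output), lens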
Example #4
    def forward(self, encoder_output, his_turn_end_ids):
        bsz = encoder_output.size(0)
        turn_lengths = [len(his_turn_end_ids[i]) for i in range(bsz)]
        in_batch_ids = list(range(bsz))
        his_max_len = max(turn_lengths)
        if his_max_len == 1:
            # a single turn leaves no pairs to contrast, so the loss is zero
            dli_loss = torch.zeros(1)[0].cuda()
        else:
            his_turn_states = torch.zeros(bsz, his_max_len,
                                          self.enc_dim).cuda()
            for i in range(bsz):
                # each turn spans [start, end]: starts are the previous ends
                # shifted by one, with 0 prepended for the first turn
                end_ids = his_turn_end_ids[i]
                start_ids = his_turn_end_ids[i] + 1
                start_ids = start_ids[:-1]
                start_0 = torch.zeros(1).long().cuda()
                start_ids = torch.cat([start_0, start_ids])
                for j in range(len(start_ids)):
                    s = start_ids[j]
                    e = end_ids[j]
                    # mean-pool the encoder states within the turn
                    his_turn_states[i][j] = torch.mean(
                        encoder_output[i][s:e + 1], dim=0)

            sorted_his_turn_states, sorted_in_batch_ids, sorted_turn_lengths = argsort(
                turn_lengths,
                his_turn_states,
                in_batch_ids,
                turn_lengths,
                descending=True)
            # pack with the true turn counts so the LSTM skips padded turns
            his_turn_states_packed = nn.utils.rnn.pack_padded_sequence(
                sorted_his_turn_states, sorted_turn_lengths, batch_first=True)
            out_packed, _ = self.uni_lstm(his_turn_states_packed)
            out_padded, _ = pad_packed_sequence(out_packed, batch_first=True)
            # invert the length-sort to restore the original batch order
            after_sort_idxs = torch.LongTensor(
                argsort(sorted_in_batch_ids, in_batch_ids,
                        descending=False)[0]).cuda()
            turns_encoder_out = torch.index_select(out_padded, 0,
                                                   after_sort_idxs)

            all_pairs = []
            all_gt = []
            for i in range(bsz):
                for j in range(turn_lengths[i]):
                    current_step_encoder_out = turns_encoder_out[i][j]
                    tmp_pairs = []
                    tmp_gt = []
                    for k in range(j + 1, turn_lengths[i]):
                        tmp_pairs.append(
                            torch.cat([
                                current_step_encoder_out, his_turn_states[i][k]
                            ], -1))
                        # the immediately-following turn is the positive pair
                        tmp_gt.append(1 if k == j + 1 else 0)
                    if len(tmp_pairs) != 0 and len(tmp_gt) != 0:
                        all_pairs.append(torch.stack(tmp_pairs))
                        all_gt.append(tmp_gt)  # kept but unused below

            loss = []
            for i in range(len(all_pairs)):
                final_out_i = self.con_fc(all_pairs[i]).squeeze(1).unsqueeze(0)
                # the positive (adjacent-turn) pair is always appended first,
                # so the target class index is 0
                ground_truth_i = torch.LongTensor([0]).cuda()
                dli_loss_i = self.c_loss(input=final_out_i,
                                         target=ground_truth_i)
                loss.append(dli_loss_i)
            dli_loss = torch.mean(torch.stack(loss))

        return dli_loss
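Examples #4, #5, and #7 all rely on the same sort/unsort idiom around a packed RNN: sort by length (descending) before packing, then call `argsort` on the sorted original indices to recover the inverse permutation. A small self-contained illustration with toy values, assuming the `argsort` helper sketched under Example #1:

    import torch

    lengths = [2, 5, 3]
    data = torch.arange(3).unsqueeze(1).float()  # stand-in batch, one row each
    orig_ids = list(range(len(lengths)))

    # sort by length for pack_padded_sequence
    sorted_data, sorted_ids, sorted_lens = argsort(
        lengths, data, orig_ids, lengths, descending=True)

    # ... run the packed RNN on sorted_data here ...

    # sorting the sorted ids ascending yields the inverse permutation
    unsort = torch.LongTensor(argsort(sorted_ids, orig_ids)[0])
    assert torch.equal(torch.index_select(sorted_data, 0, unsort), data)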
Example #5
    def forward(self, input, his_turn_end_ids):
        """
            input data is a FloatTensor of shape [batch, seq_len, dim]
            mask is a ByteTensor of shape [batch, seq_len], filled with 1 when
            inside the sequence and 0 outside.
        """
        # print(input)
        # print(his_turn_end_ids)
        bsz = len(input)
        turn_lengths = [len(his_turn_end_ids[i]) for i in range(bsz)]
        his_turns = torch.zeros(bsz, self.max_turns,
                                self.max_single_seq_len).long().cuda()
        mask = torch.zeros(bsz, self.max_turns).cuda()
        for i in range(bsz):
            # turn j spans [start_ids[j], end_ids[j]] in the flat input
            end_ids = his_turn_end_ids[i]
            start_ids = his_turn_end_ids[i] + 1
            start_ids = start_ids[:-1]
            start_0 = torch.zeros(1).long().cuda()
            start_ids = torch.cat([start_0, start_ids])
            his_len = len(start_ids)
            if his_len <= self.max_turns:
                for j in range(his_len):
                    s = start_ids[j]
                    e = end_ids[j]
                    # copy the turn, truncating it to max_single_seq_len
                    if e - s < self.max_single_seq_len:
                        his_turns[i][j][0:e + 1 - s] = input[i][s:e + 1]
                    else:
                        his_turns[i][j][0:self.max_single_seq_len] = \
                            input[i][s:s + self.max_single_seq_len]
                    mask[i][j] = 1
                for k in range(his_len, self.max_turns):
                    # give empty turn slots a dummy token so every row has
                    # nonzero length when packed
                    his_turns[i][k][0] = 1
            else:
                # keep only the most recent max_turns turns
                longer = his_len - self.max_turns
                for j in range(longer, his_len):
                    s = start_ids[j]
                    e = end_ids[j]
                    if e - s < self.max_single_seq_len:
                        his_turns[i][j - longer][0:e + 1 - s] = \
                            input[i][s:e + 1]
                    else:
                        his_turns[i][j - longer][0:self.max_single_seq_len] = \
                            input[i][s:s + self.max_single_seq_len]
                    mask[i][j - longer] = 1
        his_turns = his_turns.view(-1, self.max_single_seq_len)

        xs = self.rnn_input_dropout(his_turns)
        xes = self.rnn_dropout(self.embeddings(xs))
        attn_mask = xs.ne(0)
        x_lens = torch.sum(attn_mask.int(), dim=1)

        in_flatten_ids = list(range(len(xs)))
        sorted_xes, sorted_in_flatten_ids, sorted_x_lens = argsort(
            x_lens, xes, in_flatten_ids, x_lens, descending=True)

        xes_packed = pack_padded_sequence(sorted_xes,
                                          sorted_x_lens,
                                          batch_first=True)
        out_packed, _ = self.rnn(xes_packed)
        out_padded, _ = pad_packed_sequence(out_packed, batch_first=True)
        # invert the length-sort to restore the original row order
        after_sort_idxs = torch.LongTensor(
            argsort(sorted_in_flatten_ids, in_flatten_ids,
                    descending=False)[0]).cuda()
        his_encoder_outs = torch.index_select(out_padded, 0, after_sort_idxs)
        real_max_seq_len = his_encoder_outs.size(1)

        his_encoder_outs = his_encoder_outs.view(bsz, self.max_turns,
                                                 real_max_seq_len,
                                                 self.rnn_hsz)

        expand_mask = mask.unsqueeze(-1).expand(
            bsz, self.max_turns, real_max_seq_len * self.rnn_hsz)
        expand_mask = expand_mask.view(bsz, self.max_turns, real_max_seq_len,
                                       self.rnn_hsz)

        his_encoder_outs = his_encoder_outs.mul(expand_mask)

        # take the last nonzero timestep of each turn as its summary state
        final_encoder_outs = []
        for i in range(bsz):
            for j in range(self.max_turns):
                # default to the first step in case the whole row is zero
                tmp = his_encoder_outs[i][j][0]
                for k in range(real_max_seq_len):
                    if len(torch.nonzero(his_encoder_outs[i][j][k])) != 0:
                        tmp = his_encoder_outs[i][j][k]
                    else:
                        break
                final_encoder_outs.append(tmp)
        final_encoder_outs = torch.stack(final_encoder_outs).view(
            bsz, self.max_turns, -1)

        positions = mask.new(self.max_turns).long()
        positions = torch.arange(self.max_turns, out=positions).unsqueeze(0)
        tensor = final_encoder_outs
        if self.embeddings_scale:
            tensor = tensor * np.sqrt(self.dim)
        tensor = tensor + self.position_embeddings(positions).expand_as(tensor)

        tensor *= mask.unsqueeze(-1).float()
        for i in range(self.n_layers):
            tensor = self.layers[i](tensor, mask)

        if self.reduction:
            divisor = mask.float().sum(dim=1).unsqueeze(-1).clamp(min=1e-20)
            output = tensor.sum(dim=1) / divisor
            return output
        else:
            output = tensor
            return output, mask
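The `reduction` branch at the end is a masked mean over the turn axis. The same computation in isolation, with hypothetical shapes:

    import torch

    tensor = torch.randn(2, 4, 8)  # bsz x max_turns x dim
    mask = torch.tensor([[1., 1., 0., 0.],
                         [1., 1., 1., 0.]])

    tensor = tensor * mask.unsqueeze(-1)  # zero out padded turns
    divisor = mask.sum(dim=1).unsqueeze(-1).clamp(min=1e-20)
    output = tensor.sum(dim=1) / divisor  # bsz x dim, mean over real turns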
Example #6
    def eval_step(self, batch):
        """Process batch of inputs.

        If the batch includes labels, calculate validation metrics as well.
        If --skip-generation is not set, return a prediction for each input.

        :param batch: parlai.core.torch_agent.Batch, contains tensorized
                      version of observations.
        """
        if batch.text_vec is None:
            return
        self.is_training = False
        samples = self._make_sample(batch.text_vec, batch.label_vec)
        self.model.eval()
        if batch.label_vec is not None:
            # Interactive mode won't have a gold label
            self.trainer.valid_step(samples)

        # Output placeholders
        reranked_cands = None
        generated_output = None

        # Grade each of the candidate sequences
        if batch.candidate_vecs is not None:
            bsz = len(batch.text_vec)
            reranked_cands = []
            # score the candidates for each item in the batch separately, so
            # that we can support a variable number of candidates
            for i in range(bsz):
                cands = batch.candidate_vecs[i]
                if not cands:
                    reranked_cands.append(None)
                    continue
                ncand = len(cands)
                # repeat the input many times
                xs = batch.text_vec[i].unsqueeze(0).expand(ncand, -1)
                # some models crash if there's leading padding on every example
                xs = xs[:, :batch.text_lengths[i]]
                # and appropriately pack the outputs
                ys, _ = padded_tensor(cands, self.NULL_IDX, self.use_cuda)
                s = self._make_sample(xs, ys)
                # perform the actual grading, extract the scores
                scored = list(
                    self.scorer.score_batched_itr([s], cuda=self.use_cuda))
                scores = [s[3][0]['score'].item() for s in scored]
                # intentional hanging comma here; argsort returns a list
                ranked, = argsort(scores, batch.candidates[i], descending=True)
                reranked_cands.append(ranked)

        # Next generate freely to create our response
        if not self.args.skip_generation:
            generated_output = self._generate(samples)
        elif reranked_cands:
            # we're skipping generation, but we're also grading candidates,
            # so output the highest-ranked candidate. With zero candidates
            # there is nothing to rank, so pass the None through. (A
            # conditional expression avoids the and/or pitfall when the top
            # candidate is itself falsy, e.g. an empty string.)
            generated_output = [
                ranked[0] if ranked else None for ranked in reranked_cands
            ]
        else:
            # no output at all
            pass

        return Output(generated_output, reranked_cands)
Example #7
    def forward(self, input, his_turn_end_ids):
        """Encode sequence.

        :param input: (bsz x seqlen) LongTensor of input token indices

        :returns: encoder outputs, hidden state, attention mask
            encoder outputs are the output state at each step of the encoding.
            the hidden state is the final hidden state of the encoder.
            the attention mask is a mask of which input values are nonzero.
        """

        bsz = len(input)
        turn_lengths = [len(his_turn_end_ids[i]) for i in range(bsz)]
        his_turns = torch.zeros(bsz, self.max_turns,
                                self.max_single_seq_len).long().cuda()
        mask = torch.zeros(bsz, self.max_turns).cuda()
        for i in range(bsz):
            # turn j spans [start_ids[j], end_ids[j]] in the flat input
            end_ids = his_turn_end_ids[i]
            start_ids = his_turn_end_ids[i] + 1
            start_ids = start_ids[:-1]
            start_0 = torch.zeros(1).long().cuda()
            start_ids = torch.cat([start_0, start_ids])
            his_len = len(start_ids)
            if his_len <= self.max_turns:
                for j in range(his_len):
                    s = start_ids[j]
                    e = end_ids[j]
                    # copy the turn, truncating it to max_single_seq_len
                    if e - s < self.max_single_seq_len:
                        his_turns[i][j][0:e + 1 - s] = input[i][s:e + 1]
                    else:
                        his_turns[i][j][0:self.max_single_seq_len] = \
                            input[i][s:s + self.max_single_seq_len]
                    mask[i][j] = 1
                for k in range(his_len, self.max_turns):
                    # give empty turn slots a dummy token so every row has
                    # nonzero length when packed
                    his_turns[i][k][0] = 1
            else:
                # keep only the most recent max_turns turns
                longer = his_len - self.max_turns
                for j in range(longer, his_len):
                    s = start_ids[j]
                    e = end_ids[j]
                    if e - s < self.max_single_seq_len:
                        his_turns[i][j - longer][0:e + 1 - s] = \
                            input[i][s:e + 1]
                    else:
                        his_turns[i][j - longer][0:self.max_single_seq_len] = \
                            input[i][s:s + self.max_single_seq_len]
                    mask[i][j - longer] = 1
        his_turns = his_turns.view(-1, self.max_single_seq_len)

        xs = self.input_dropout(his_turns)
        xes = self.dropout(self.lt(xs))
        attn_mask = xs.ne(0)
        x_lens = torch.sum(attn_mask.int(), dim=1)

        in_flatten_ids = list(range(len(xs)))
        sorted_xes, sorted_in_flatten_ids, sorted_x_lens = argsort(
            x_lens, xes, in_flatten_ids, x_lens, descending=True)

        xes_packed = pack_padded_sequence(sorted_xes,
                                          sorted_x_lens,
                                          batch_first=True)
        out_packed, _ = self.rnn(xes_packed)
        out_padded, _ = pad_packed_sequence(out_packed, batch_first=True)
        # invert the length-sort to restore the original row order
        after_sort_idxs = torch.LongTensor(
            argsort(sorted_in_flatten_ids, in_flatten_ids,
                    descending=False)[0]).cuda()
        his_encoder_outs = torch.index_select(out_padded, 0, after_sort_idxs)
        real_max_seq_len = his_encoder_outs.size(1)
        his_encoder_outs = his_encoder_outs.view(bsz, self.max_turns,
                                                 real_max_seq_len, self.hsz)

        expand_mask = mask.unsqueeze(-1).expand(bsz, self.max_turns,
                                                real_max_seq_len * self.hsz)
        expand_mask = expand_mask.view(bsz, self.max_turns, real_max_seq_len,
                                       self.hsz)

        his_encoder_outs = his_encoder_outs.mul(expand_mask)

        # take the last nonzero timestep of each turn as its summary state
        final_encoder_outs = []
        for i in range(bsz):
            for j in range(self.max_turns):
                # default to the first step in case the whole row is zero
                tmp = his_encoder_outs[i][j][0]
                for k in range(real_max_seq_len):
                    if len(torch.nonzero(his_encoder_outs[i][j][k])) != 0:
                        tmp = his_encoder_outs[i][j][k]
                    else:
                        break
                final_encoder_outs.append(tmp)
        final_encoder_outs = torch.stack(final_encoder_outs).view(
            bsz, self.max_turns, -1)

        hier_xes = final_encoder_outs
        hier_x_lens = torch.sum(mask.int(), dim=1)

        in_example_ids = list(range(len(hier_xes)))

        sorted_hier_xes, sorted_in_example_ids, sorted_hier_x_lens = argsort(
            hier_x_lens,
            hier_xes,
            in_example_ids,
            hier_x_lens,
            descending=True)

        hier_xes_packed = pack_padded_sequence(sorted_hier_xes,
                                               sorted_hier_x_lens,
                                               batch_first=True)

        hier_out_packed, hier_hidden_packed = self.hier_rnn(hier_xes_packed)
        hier_out_padded, _ = pad_packed_sequence(hier_out_packed,
                                                 batch_first=True)

        hier_after_sort_idxs = torch.LongTensor(
            argsort(sorted_in_example_ids, in_example_ids,
                    descending=False)[0]).cuda()
        hier_his_encoder_outs = torch.index_select(hier_out_padded, 0,
                                                   hier_after_sort_idxs)

        real_max_his_n_turn = hier_his_encoder_outs.size(1)
        hier_final_encoder_outs = hier_his_encoder_outs.view(
            bsz, real_max_his_n_turn, -1)
        transpose_hidden = _transpose_hidden_state(hier_hidden_packed)

        hier_final_hidden = torch.index_select(transpose_hidden, 0,
                                               hier_after_sort_idxs)

        # attention mask over the (possibly shortened) turn axis
        hier_attn_mask = mask[:, :real_max_his_n_turn]

        if self.hier_dirs > 1:
            # project to decoder dimension by taking the sum of the forward
            # and backward directions
            if isinstance(self.hier_rnn, nn.LSTM):
                hier_final_hidden = (
                    hier_final_hidden[0].view(
                        -1, self.hier_dirs, bsz, self.hier_hsz).sum(1),
                    hier_final_hidden[1].view(
                        -1, self.hier_dirs, bsz, self.hier_hsz).sum(1),
                )
            else:
                hier_final_hidden = hier_final_hidden.view(
                    -1, self.hier_dirs, bsz, self.hier_hsz).sum(1)
            hier_final_hidden = _transpose_hidden_state(hier_final_hidden)

        return hier_final_encoder_outs, hier_final_hidden, hier_attn_mask
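`_transpose_hidden_state` is a ParlAI helper not shown on this page; its usual behavior is to swap the layer and batch dimensions of an RNN hidden state, recursing into the `(h, c)` tuple for LSTMs. A sketch:

    def _transpose_hidden_state(hidden_state):
        # LSTM hidden states are (h, c) tuples; GRU/RNN use a single tensor
        if isinstance(hidden_state, tuple):
            return tuple(map(_transpose_hidden_state, hidden_state))
        return hidden_state.transpose(0, 1)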