Example no. 1
    def _get_feed_dict(self,
                       sl,
                       sr,
                       desc,
                       cand,
                       is_train,
                       y=None,
                       lr=None,
                       keep_prob=None):
        """Assemble the TensorFlow feed dict for one batch.

        pad_sequence and max_sent_len are assumed to be defined elsewhere in
        the module; y, lr and keep_prob are fed only when given, so the same
        method serves both training and inference.
        """
        # Left-pad the left and right context token sequences to max_sent_len.
        sl, sl_seq_len = pad_sequence(sl,
                                      max_length=max_sent_len,
                                      pad_tok=0,
                                      pad_left=True,
                                      nlevels=1)
        sr, sr_seq_len = pad_sequence(sr,
                                      max_length=max_sent_len,
                                      pad_tok=0,
                                      pad_left=True,
                                      nlevels=1)
        # Full sequence length: left context + right context + 1, presumably
        # for the mention token sitting between the two contexts.
        sent_seq_len = [x + y + 1 for x, y in zip(sl_seq_len, sr_seq_len)]
        # Descriptions are nested sequences, so pad at two levels.
        desc, desc_seq_len = pad_sequence(desc,
                                          max_length_2=max_sent_len,
                                          pad_tok=0,
                                          pad_left=True,
                                          nlevels=2)
        feed_dict = {
            self.sl: sl,
            self.sr: sr,
            self.sent_seq_len: sent_seq_len,
            self.desc: desc,
            self.desc_seq_len: desc_seq_len,
            self.cand: cand,
            self.is_train: is_train
        }
        # Optional placeholders: labels, learning rate, dropout keep prob.
        if y is not None:
            feed_dict[self.y] = y
        if lr is not None:
            feed_dict[self.lr] = lr
        if keep_prob is not None:
            feed_dict[self.keep_prob] = keep_prob
        return feed_dict
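
The pad_sequence helper used above is not shown. A minimal sketch of what it presumably does, inferred from the call sites (the exact signature and the nlevels=2 behaviour are assumptions):

def pad_sequence(sequences, max_length=None, pad_tok=0, pad_left=False,
                 nlevels=1, max_length_2=None):
    # nlevels == 1: pad (or truncate) each token list to max_length and
    # return the padded batch together with the original lengths.
    if nlevels == 1:
        if max_length is None:
            max_length = max((len(s) for s in sequences), default=0)
        padded, lengths = [], []
        for seq in sequences:
            seq = list(seq)[:max_length]
            lengths.append(len(seq))
            pad = [pad_tok] * (max_length - len(seq))
            padded.append(pad + seq if pad_left else seq + pad)
        return padded, lengths
    # nlevels == 2: apply the same padding to every inner sequence, with
    # max_length_2 as the inner maximum length.
    padded, lengths = [], []
    for inner in sequences:
        p, l = pad_sequence(inner, max_length=max_length_2, pad_tok=pad_tok,
                            pad_left=pad_left, nlevels=1)
        padded.append(p)
        lengths.append(l)
    return padded, lengths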
Example no. 2
    def batch_act(self, observations):
        # Assumes the enclosing module imports torch, torch.nn.functional as F,
        # and pad_sequence (torch.nn.utils.rnn.pad_sequence fits the calls below).
        def is_valid_history(history):
            return len(history['dialog']) > 0

        def to_tensor(string):
            ids = [self.vocab.bos_id] + self.vocab.string2ids(string) + [self.vocab.eos_id]
            return torch.tensor(ids, dtype=torch.long)

        batch_reply = [{'id': self.getID(), 'text': '', 'text_candidates': []} for _ in range(len(observations))]
        valid_ids = [i for i, obs in enumerate(observations) if is_valid_history(obs['agent'].history)]
        batch_size = len(valid_ids)

        if batch_size == 0:
            return batch_reply

        try:
            valid_observations = [observations[i] for i in valid_ids]

            # Truncate persona info and dialog history so they fit within the
            # model's position embeddings (leaving room for the BOS/EOS
            # wrappers added next).
            infos = [obs['agent'].history['info'][:self.model.n_pos_embeddings-3] for obs in valid_observations]
            infos = [([self.vocab.info_bos_id] + ifo + [self.vocab.info_eos_id] if len(ifo) else ifo) for ifo in infos]
            dialogs = [list(obs['agent'].history['dialog'])[-self.model.n_pos_embeddings+1:] for obs in valid_observations]
            contexts = []

            if max(map(len, infos)) > 0:
                infos = [torch.tensor(i, dtype=torch.long) for i in infos]
                infos = pad_sequence(infos, batch_first=True, padding_value=self.model.padding_idx)
                if self.use_cuda:
                    infos = infos.cuda()
                contexts.append(infos)

            if max(map(len, dialogs)) > 0:
                dialogs = [torch.tensor(d, dtype=torch.long) for d in dialogs]
                dialogs = pad_sequence(dialogs, batch_first=True, padding_value=self.model.padding_idx)
                if self.use_cuda:
                    dialogs = dialogs.cuda()
                contexts.append(dialogs)

            enc_contexts = [self.model.encode(c) for c in contexts]
            pred_texts = self.model.beam_search(enc_contexts)

            for i in range(batch_size):
                pred_text_str, pred_text = self._postprocess_text(pred_texts[i], valid_observations[i]['agent'])

                valid_observations[i]['agent'].history['dialog'].extend([self.vocab.talker2_bos_id] +
                                                                        pred_text +
                                                                        [self.vocab.talker2_eos_id])
                batch_reply[valid_ids[i]]['text'] = pred_text_str
                batch_reply[valid_ids[i]]['episode_done'] = valid_observations[i]['agent'].episode_done

            if self.opt['rank_candidates']:
                candidates = [list(obs.get('label_candidates', [])) for obs in valid_observations]
                lens_candidates = [len(c) for c in candidates]

                if max(lens_candidates) > 0:
                    candidates = [c + ['' for _ in range(max(lens_candidates) - len(c))] for c in candidates]
                    scores = [[] for _ in range(len(candidates))]

                    for i in range(max(lens_candidates)):
                        current_cands = [to_tensor(c[i])[:self.model.n_pos_embeddings-1] for c in candidates]
                        current_cands = pad_sequence(current_cands, batch_first=True, padding_value=self.model.padding_idx)
                        if self.use_cuda:
                            current_cands = current_cands.cuda()

                        logits = self.model.decode(current_cands[:, :-1], enc_contexts)
                        log_probas = F.log_softmax(logits, dim=-1)
                        log_probas = torch.gather(log_probas, -1, current_cands[:, 1:].unsqueeze(-1)).squeeze(-1)
                        log_probas.masked_fill_(current_cands[:, 1:].eq(self.model.padding_idx), 0)

                        current_lens = current_cands[:, 1:].ne(self.model.padding_idx).float().sum(dim=-1)
                        current_scores = log_probas.sum(dim=-1) / current_lens

                        for k, s in enumerate(current_scores):
                            if i < lens_candidates[k]:
                                scores[k].append(s.item())

                    ranked_ids = [sorted(range(len(s)), key=lambda k: s[k], reverse=True) for s in scores]
                    ranked_strings = [[c[i] for i in ids] for ids, c in zip(ranked_ids, candidates)]

                    for i in range(batch_size):
                        batch_reply[valid_ids[i]]['text_candidates'] = ranked_strings[i]

        except Exception as e:
            # Print and swallow errors so a single bad batch does not abort
            # the whole evaluation loop; re-raise here when debugging.
            print(e)

        return batch_reply
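
The ranking branch above scores each candidate by its length-normalized log-likelihood under the decoder: gather the log-probability of every realized target token, zero out padding positions, then divide the summed log-probabilities by the number of real tokens. Distilled into a standalone helper (a minimal sketch; the names score_candidates, logits, targets and padding_idx are illustrative, not from the original):

import torch
import torch.nn.functional as F

def score_candidates(logits, targets, padding_idx):
    # logits: [batch, seq_len, vocab] decoder outputs; targets: [batch, seq_len]
    # tokens those positions should predict (candidate tokens, padded).
    log_probas = F.log_softmax(logits, dim=-1)
    # Log-probability of each realized target token.
    token_lp = torch.gather(log_probas, -1, targets.unsqueeze(-1)).squeeze(-1)
    # Padding contributes nothing to the score ...
    token_lp = token_lp.masked_fill(targets.eq(padding_idx), 0)
    # ... and is excluded from the length normalization.
    lengths = targets.ne(padding_idx).float().sum(dim=-1)
    return token_lp.sum(dim=-1) / lengths

The listing that follows is a later variant of the same method: it adds per-token dialog embeddings, a single_input mode that concatenates persona info and dialog into one left-padded context, and a small to_cuda helper.
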
    def batch_act(self, observations):
        def is_valid_history(history):
            return len(history['dialog']) > 0

        def to_tensor(string):
            ids = [self.vocab.bos_id] + self.vocab.string2ids(string) + [self.vocab.eos_id]
            # Pair each token with a dialog-segment embedding id.
            ids = self._add_dialog_embeddings(ids, self.vocab.sent_dialog_id)
            return torch.tensor(ids, dtype=torch.long)

        def to_cuda(data):
            if not self.use_cuda:
                return data

            if isinstance(data, (list, tuple, map)):
                return list(map(lambda x: x.cuda(), data))

            return data.cuda()

        batch_reply = [{'id': self.getID(), 'text': '', 'text_candidates': []} for _ in range(len(observations))]
        valid_ids = [i for i, obs in enumerate(observations) if is_valid_history(obs['agent'].history)]
        batch_size = len(valid_ids)

        if batch_size == 0:
            return batch_reply

        try:
            valid_observations = [observations[i] for i in valid_ids]

            infos = [obs['agent'].history['info'] for obs in valid_observations]
            dialogs = [list(obs['agent'].history['dialog'])[-self.model.n_pos_embeddings+1:] for obs in valid_observations]
            contexts = []

            if max(map(len, infos)) > 0:
                infos = [torch.tensor(i, dtype=torch.long) for i in infos]
                contexts.append(infos)

            if max(map(len, dialogs)) > 0:
                dialogs = [torch.tensor(d, dtype=torch.long) for d in dialogs]
                contexts.append(dialogs)

            if self.single_input:
                contexts = [torch.cat(c, dim=0) for c in zip(*contexts)]
                raw_context = contexts if self.opt['rank_candidates'] else None
                contexts = pad_sequence(contexts, batch_first=True, padding_value=self.model.padding_idx, left=True)
            else:
                contexts = map(lambda x: pad_sequence(x, batch_first=True, padding_value=self.model.padding_idx),
                               contexts)

            contexts = to_cuda(contexts)

            pred_texts = self.model.predict(contexts)

            for i in range(batch_size):
                pred_toks = self._process_2nd_replica(pred_texts[i])
                valid_observations[i]['agent'].history['dialog'].extend(pred_toks)
                batch_reply[valid_ids[i]]['text'] = self.vocab.ids2string(pred_texts[i])
                batch_reply[valid_ids[i]]['episode_done'] = valid_observations[i]['agent'].episode_done

            if self.opt['rank_candidates']:
                enc_contexts = [self.model.encode(c) for c in contexts] if not self.single_input else []

                candidates = [list(obs.get('label_candidates', [])) for obs in valid_observations]
                lens_candidates = [len(c) for c in candidates]

                if max(lens_candidates) > 0:
                    candidates = [c + ['' for _ in range(max(lens_candidates) - len(c))] for c in candidates]
                    scores = [[] for _ in range(len(candidates))]

                    for i in range(max(lens_candidates)):
                        current_cands = [to_tensor(c[i])[:self.model.n_pos_embeddings-1] for c in candidates]

                        if self.single_input:
                            # Remember each candidate's own length before the
                            # context is prepended; used below to mask the
                            # context out of the score.
                            lens = [c.size(0) for c in current_cands]
                            current_cands = [torch.cat(c, dim=0)[-self.model.n_pos_embeddings:]
                                             for c in zip(raw_context, current_cands)]

                        current_cands = to_cuda(current_cands)
                        current_cands = pad_sequence(current_cands, batch_first=True,
                                                     padding_value=self.model.padding_idx)

                        logits = self.model.decode(current_cands[:, :-1], enc_contexts)

                        if current_cands.dim() == 3:
                            current_cands = current_cands[:, :, 0]

                        log_probas = F.log_softmax(logits, dim=-1)
                        log_probas = torch.gather(log_probas, -1, current_cands[:, 1:].unsqueeze(-1)).squeeze(-1)

                        if self.single_input:
                            # Mask the prepended context tokens so only the
                            # candidate itself contributes to the score.
                            for j, l in enumerate(lens):
                                current_cands[j, :-l+1] = self.model.padding_idx

                        log_probas.masked_fill_(current_cands[:, 1:].eq(self.model.padding_idx), 0)

                        current_lens = current_cands[:, 1:].ne(self.model.padding_idx).float().sum(dim=-1)
                        current_scores = log_probas.sum(dim=-1) / current_lens

                        for k, s in enumerate(current_scores):
                            if i < lens_candidates[k]:
                                scores[k].append(s.item())

                    ranked_ids = [sorted(range(len(s)), key=lambda k: s[k], reverse=True) for s in scores]
                    ranked_strings = [[c[i] for i in ids] for ids, c in zip(ranked_ids, candidates)]

                    for i in range(batch_size):
                        batch_reply[valid_ids[i]]['text_candidates'] = ranked_strings[i]

        except Exception as e:
            # Print and swallow errors so a single bad batch does not abort
            # the whole evaluation loop; re-raise here when debugging.
            print(e)

        return batch_reply
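
torch.nn.utils.rnn.pad_sequence accepts no left argument, so the left-padded call in single_input mode presumably goes through a project-local helper. A minimal sketch of one way such a wrapper could work (an assumption, not the original helper): flip each sequence, pad on the right with the built-in, then flip back.

import torch
from torch.nn.utils.rnn import pad_sequence as pad_right

def pad_sequence(sequences, batch_first=False, padding_value=0, left=False):
    # Right padding is the built-in behaviour.
    if not left:
        return pad_right(sequences, batch_first=batch_first, padding_value=padding_value)
    # Left padding: reverse along the time dimension, pad right, reverse back.
    flipped = [seq.flip(0) for seq in sequences]
    padded = pad_right(flipped, batch_first=batch_first, padding_value=padding_value)
    return padded.flip(1 if batch_first else 0)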