コード例 #1
0
ファイル: data_loaders.py プロジェクト: wenxianxian/demvae
    def flatten_dialog(self, data, backward_size):
        """Flatten dialogs into (context, response, prev, next) examples.

        One example is produced for every turn ``i`` that has a neighbour on
        both sides (``1 <= i <= len(dialog) - 2``):

        * ``context``   -- the turns ``dialog[max(0, i - backward_size):i]``
        * ``response``  -- ``dialog[i]``
        * ``prev_resp`` -- ``dialog[i - 1]``
        * ``next_resp`` -- ``dialog[i + 1]``

        Side effect: in every dialog with at least three turns, each turn's
        ``utt`` is replaced in place by
        ``self.pad_to(self.max_utt_size, turn.utt, do_pad=False)``. The
        returned Packs hold references to those same turn objects (no
        copies), so a single turn may appear in several examples.

        :param data: iterable of dialogs; each dialog is an indexable,
            sliceable sequence of turns supporting ``turn.utt`` and
            ``turn['utt'] = ...``.
        :param backward_size: maximum number of context turns per example.
        :return: list of ``Pack(context, response, prev_resp, next_resp)``.
        """
        results = []
        for dialog in data:
            n_turns = len(dialog)
            # A response needs both a previous and a next turn, so shorter
            # dialogs yield no example and are left untouched.
            if n_turns < 3:
                continue

            # Normalise every turn exactly once. Re-running pad_to on the same
            # turn object each time it shows up as response/prev/next/context
            # is redundant and only stable if pad_to is idempotent.
            for turn in dialog:
                turn['utt'] = self.pad_to(self.max_utt_size,
                                          turn.utt,
                                          do_pad=False)

            for i in range(1, n_turns - 1):
                start = max(0, i - backward_size)
                results.append(
                    Pack(context=list(dialog[start:i]),
                         response=dialog[i],
                         prev_resp=dialog[i - 1],
                         next_resp=dialog[i + 1]))
        return results
コード例 #2
0
ファイル: data_loaders.py プロジェクト: wenxianxian/demvae
    def _prepare_batch(self, selected_index):
        """Build a right-zero-padded batch of single utterances.

        :param selected_index: indices into ``self.data``; every selected row
            must expose ``row.utt`` (token ids) and ``row["meta"]``.
        :return: ``Pack`` with
            ``outputs``     -- int32 array ``(len(selected_index), max_len)``,
            ``output_lens`` -- int32 array of utterance lengths,
            ``metas``       -- list of the rows' ``meta`` entries.
        """
        rows = [self.data[idx] for idx in selected_index]
        input_lens = np.array([len(row.utt) for row in rows], dtype=np.int32)
        max_len = np.max(input_lens)
        # Size by the rows actually selected, not self.batch_size: otherwise a
        # short batch gets extra all-zero rows that disagree with output_lens,
        # and an oversized one raises IndexError.
        inputs = np.zeros((len(rows), max_len), dtype=np.int32)
        for b_id, row in enumerate(rows):
            inputs[b_id, :input_lens[b_id]] = row.utt

        return Pack(outputs=inputs,
                    output_lens=input_lens,
                    metas=[row["meta"] for row in rows])
コード例 #3
0
ファイル: data_loaders.py プロジェクト: wenxianxian/demvae
 def flatten_dialog(self, data, backward_size):
     """Flatten dialogs into (context, response) examples.

     For every turn ``t`` after the first, one ``Pack`` is emitted whose
     ``response`` is a copy of ``dialog[t]`` and whose ``context`` is the
     list of up to ``backward_size`` turns that precede it.

     Both the response copy and the context turns get their ``utt``
     replaced by ``self.pad_to(self.max_utt_size, utt, do_pad=False)``.
     The context turns are the original objects from ``data`` (modified in
     place and shared between examples); only the response is copied.

     :param data: iterable of dialogs (indexable, sliceable turn sequences).
     :param backward_size: maximum number of context turns.
     :return: list of ``Pack(context=..., response=...)``.
     """
     def fit(turn):
         # Rewrite the turn's utterance in place and hand the turn back.
         turn['utt'] = self.pad_to(self.max_utt_size,
                                   turn.utt,
                                   do_pad=False)
         return turn

     examples = []
     for dialog in data:
         for end in range(1, len(dialog)):
             target = fit(dialog[end].copy())
             history = [
                 fit(turn)
                 for turn in dialog[max(0, end - backward_size):end]
             ]
             examples.append(Pack(context=history, response=target))
     return examples
コード例 #4
0
ファイル: data_loaders.py プロジェクト: wenxianxian/demvae
    def _prepare_batch(self, selected_index):
        """Assemble a zero-padded (context, response) batch from ``self.data``.

        Each selected row is read as ``row.context`` (a list of turns with
        ``.utt``) and ``row.response`` (a turn with ``.utt`` and ``.meta``).

        :param selected_index: indices into ``self.data``.
        :return: ``Pack`` with
            ``contexts``     -- int32 ``(n, max_context_len, self.max_utt_size)``,
            ``context_lens`` -- number of context turns per row,
            ``outputs``      -- int32 ``(n, max_output_len)`` response tokens,
            ``output_lens``  -- response lengths,
            ``metas``        -- list of response metas,
            where ``n == len(selected_index)``; unused slots are zero.
        """
        rows = [self.data[idx] for idx in selected_index]
        # Size every array by len(rows) rather than self.batch_size: a short
        # selection used to raise IndexError and an oversized one produced
        # arrays whose first dimension disagreed with the *_lens / metas.
        n_rows = len(rows)

        context_utts = [[turn.utt for turn in row.context] for row in rows]
        out_utts = [row.response.utt for row in rows]
        metas = [row.response.meta for row in rows]

        vec_context_lens = np.array([len(row.context) for row in rows])
        vec_out_lens = np.array([len(utt) for utt in out_utts])

        vec_context = np.zeros(
            (n_rows, np.max(vec_context_lens), self.max_utt_size),
            dtype=np.int32)
        vec_outs = np.zeros((n_rows, np.max(vec_out_lens)), dtype=np.int32)

        for b_id in range(n_rows):
            vec_outs[b_id, :vec_out_lens[b_id]] = out_utts[b_id]
            # Write each context turn straight into its zero-padded slot
            # instead of copying element-wise through a float scratch array.
            for t_id, utt in enumerate(context_utts[b_id]):
                vec_context[b_id, t_id, :len(utt)] = utt

        return Pack(contexts=vec_context,
                    context_lens=vec_context_lens,
                    outputs=vec_outs,
                    output_lens=vec_out_lens,
                    metas=metas)
コード例 #5
0
ファイル: data_loaders.py プロジェクト: wenxianxian/demvae
    def _prepare_batch(self, selected_index):
        """Assemble a zero-padded batch with context, response and neighbours.

        Each selected row is read as ``row.context`` (list of turns with
        ``.utt``), ``row.response`` (turn with ``.utt`` and ``.meta``),
        ``row.prev_resp`` and ``row.next_resp`` (turns with ``.utt``).

        :param selected_index: indices into ``self.data``.
        :return: ``Pack`` with ``contexts`` (int32
            ``(n, max_context_len, self.max_utt_size)``), ``context_lens``,
            ``outputs``/``output_lens``, ``metas``, ``prevs``/``prev_lens``,
            ``nexts``/``next_lens`` (int32 ``(n, max_len)`` arrays plus
            lengths) and ``z_labels`` (int32 ``(n, 2)``), where
            ``n == len(selected_index)``; unused slots are zero.
        """
        def pad_sequences(seqs):
            # Right-pad variable-length token lists into an int32 matrix.
            lens = np.array([len(seq) for seq in seqs])
            mat = np.zeros((len(seqs), np.max(lens)), dtype=np.int32)
            for b_id, seq in enumerate(seqs):
                mat[b_id, :lens[b_id]] = seq
            return mat, lens

        rows = [self.data[idx] for idx in selected_index]
        # Size every array by len(rows) rather than self.batch_size: a short
        # selection used to raise IndexError and an oversized one produced
        # arrays whose first dimension disagreed with the *_lens / metas.
        n_rows = len(rows)

        context_utts = [[turn.utt for turn in row.context] for row in rows]
        metas = [row.response.meta for row in rows]

        vec_outs, vec_out_lens = pad_sequences(
            [row.response.utt for row in rows])
        vec_prevs, vec_prev_lens = pad_sequences(
            [row.prev_resp.utt for row in rows])
        vec_nexts, vec_next_lens = pad_sequences(
            [row.next_resp.utt for row in rows])

        vec_context_lens = np.array([len(row.context) for row in rows])
        vec_context = np.zeros(
            (n_rows, np.max(vec_context_lens), self.max_utt_size),
            dtype=np.int32)
        for b_id, turns in enumerate(context_utts):
            # Write each context turn straight into its zero-padded slot.
            for t_id, utt in enumerate(turns):
                vec_context[b_id, t_id, :len(utt)] = utt

        # Column 0 = int(meta["emotion"]), column 1 = int(meta["act"]).
        z_labels = np.zeros((n_rows, 2), dtype=np.int32)
        for b_id, meta in enumerate(metas):
            z_labels[b_id, 0] = int(meta["emotion"])
            z_labels[b_id, 1] = int(meta["act"])

        return Pack(contexts=vec_context,
                    context_lens=vec_context_lens,
                    outputs=vec_outs,
                    output_lens=vec_out_lens,
                    metas=metas,
                    prevs=vec_prevs,
                    prev_lens=vec_prev_lens,
                    nexts=vec_nexts,
                    next_lens=vec_next_lens,
                    z_labels=z_labels)
コード例 #6
0
ファイル: sup_models.py プロジェクト: wenxianxian/demvae
    def forward(self,
                data_feed,
                mode,
                gen_type='greedy',
                sample_n=1,
                return_latent=False):
        """Encode the output utterance into a categorical latent and decode it.

        :param data_feed: batch mapping (or a tuple whose first element is
            one) with ``outputs`` and ``output_lens`` and optionally
            ``z_labels`` / ``c_labels``.
        :param mode: ``GEN`` returns ``(dec_ctx, labels)`` right after
            decoding; any other mode computes the loss terms.
        :param gen_type: decoding strategy forwarded to ``self.decoder``.
        :param sample_n: accepted for interface compatibility; unused here.
        :param return_latent: when true, the loss ``Pack`` additionally holds
            ``log_qy``, ``dec_init_state`` and ``y_ids``.
        :return: ``(dec_ctx, labels)`` in ``GEN`` mode, otherwise a ``Pack``
            with ``nll``, ``reg_kl``, ``mi``, ``bpr``, ``real_ckl``,
            ``ce_z``, ``elbo`` and, when ``config.avg_type == "seq"``,
            ``PPL``.
        """
        feed = data_feed[0] if isinstance(data_feed, tuple) else data_feed
        batch_size = len(feed['output_lens'])
        out_utts = self.np2var(feed['outputs'], LONG)

        raw_z = feed.get("z_labels", None)
        raw_c = feed.get("c_labels", None)
        z_labels = None if raw_z is None else self.np2var(raw_z, LONG)
        # c labels are converted but not used further in this model.
        c_labels = None if raw_c is None else self.np2var(raw_c, LONG)

        # Encode the response; only the final encoder state is needed.
        _, enc_last = self.x_encoder(self.embedding(out_utts))
        if type(enc_last) is tuple:
            enc_last = enc_last[0]
        x_last = enc_last.transpose(0, 1).contiguous().view(
            -1, self.enc_out_size)

        k = self.config.k
        mult_k = self.config.mult_k

        # Posterior logits over k categories, reshaped to rows of size k.
        qy_logits = self.q_y(x_last).view(-1, k)
        log_qy = F.log_softmax(qy_logits, dim=-1)

        # Hard samples outside training, soft ones while training.
        sample_y, y_ids = self.cat_connector(qy_logits,
                                             1.0,
                                             self.use_gpu,
                                             hard=not self.training,
                                             return_max_id=True)
        sample_y = sample_y.view(-1, k * mult_k)
        y_ids = y_ids.view(-1, mult_k)

        dec_init_state = self.dec_init_connector(sample_y)

        # Decoder inputs are the tokens one step behind the labels.
        dec_inputs = out_utts[:, 0:-1]
        labels = out_utts[:, 1:].contiguous()

        dec_outs, _, dec_ctx = self.decoder(batch_size,
                                            dec_inputs,
                                            dec_init_state,
                                            mode=mode,
                                            gen_type=gen_type,
                                            beam_size=self.beam_size)
        if mode == GEN:
            return dec_ctx, labels

        seq_avg = self.config.avg_type == "seq"
        nll = self.nll_loss(dec_outs, labels)
        ppl = self.ppl(dec_outs, labels) if seq_avg else None

        # Log of the batch-averaged posterior, per latent group.
        log_qy_groups = log_qy.view(-1, mult_k, k)
        avg_log_qy = torch.log(
            torch.mean(torch.exp(log_qy_groups), dim=0) + 1e-15)
        b_pr = self.cat_kl_loss(avg_log_qy,
                                self.log_uniform_y,
                                batch_size,
                                unit_average=True)

        ckl_terms = self.cat_kl_loss(log_qy,
                                     self.log_uniform_y,
                                     batch_size,
                                     average=False)
        real_ckl = torch.mean(torch.sum(ckl_terms.view(-1, mult_k), dim=-1))

        reg_kl = b_pr if self.config.use_mutual else real_ckl

        # Mutual information estimate: H(Z) - H(Z|X).
        h_z = self.entropy_loss(avg_log_qy, unit_average=True)
        h_z_given_x = self.entropy_loss(log_qy, unit_average=True)
        mi = h_z - h_z_given_x

        if z_labels is None:
            ce_z = None
        else:
            ce_z = self.suploss_for_z(log_qy_groups, z_labels)

        results = Pack(nll=nll,
                       reg_kl=reg_kl,
                       mi=mi,
                       bpr=b_pr,
                       real_ckl=real_ckl,
                       ce_z=ce_z,
                       elbo=nll + real_ckl)
        if seq_avg:
            results['PPL'] = ppl
        if return_latent:
            results['log_qy'] = log_qy
            results['dec_init_state'] = dec_init_state
            results['y_ids'] = y_ids
        return results
コード例 #7
0
ファイル: sup_models.py プロジェクト: wenxianxian/demvae
    def forward(self,
                data_feed,
                mode,
                gen_type='greedy',
                sample_n=1,
                return_latent=False):
        """Two-level latent forward pass: a class ``c`` and grouped latents ``z``.

        ``q_c`` gives a softmax over ``config.k`` classes. ``q_z`` gives
        logits of shape ``(-1, config.mult_k, config.latent_size)``. In
        ``GEN`` mode with ``gen_type == "sample"``, one class index per
        example is drawn with ``torch.randint(0, config.k, ...)``, and ``z``
        is sampled from ``log(eta2theta(self._eta[c]))``. Otherwise ``z`` is
        sampled from the posterior logits with ``hard=True``. The decoder is
        always invoked with ``gen_type="greedy"``.

        :param data_feed: batch mapping (or a tuple whose first element is
            one) with ``outputs`` and ``output_lens`` and optionally
            ``z_labels`` / ``c_labels``.
        :param mode: ``GEN`` returns ``(dec_ctx, labels)``; in that case
            ``dec_ctx`` gains ``KEY_LATENT`` and, when sampling,
            ``KEY_CLASS``. Any other mode returns a loss ``Pack``.
        :param gen_type: ``"sample"`` selects prior sampling in ``GEN`` mode.
        :param sample_n: accepted for interface compatibility; unused here.
        :param return_latent: add ``log_qy``, ``dec_init_state``, ``y_ids``
            and ``z`` to the loss ``Pack``.
        :return: ``(dec_ctx, labels)`` or a ``Pack`` of loss terms. Note that
            ``c_entropy`` is ``mean(sum(qc * log qc))``, i.e. the *negative*
            entropy of q(c|x).
        """
        feed = data_feed[0] if type(data_feed) is tuple else data_feed
        batch_size = len(feed['output_lens'])
        out_utts = self.np2var(feed['outputs'], LONG)

        raw_z = feed.get("z_labels", None)
        raw_c = feed.get("c_labels", None)
        z_labels = None if raw_z is None else self.np2var(raw_z, LONG)
        c_labels = None if raw_c is None else self.np2var(raw_c, LONG)

        # Encode the response; only the final encoder state is needed.
        _, enc_last = self.x_encoder(self.embedding(out_utts))
        if type(enc_last) is tuple:
            enc_last = enc_last[0]
        x_last = enc_last.transpose(0, 1).contiguous().view(
            -1, self.enc_out_size)

        cfg = self.config

        # Posteriors: q(c|x) over cfg.k classes, q(z|x) per latent group.
        qc_logits = self.q_c(x_last)
        qc = torch.softmax(qc_logits, dim=-1)
        qz_logits = self.q_z(x_last).view(-1, cfg.mult_k, cfg.latent_size)

        from_prior = mode == GEN and gen_type == "sample"
        if from_prior:
            # One class index per example, shape (batch_size,).
            sample_c = torch.randint(0,
                                     cfg.k, (batch_size, ),
                                     dtype=torch.long)
            pz = self.eta2theta(self._eta[sample_c])
            z_logits, hard = torch.log(pz), not self.training
        else:
            z_logits, hard = qz_logits, True

        sample_y, y_ids = self.cat_connector(z_logits.view(
            -1, cfg.latent_size),
                                             1.0,
                                             self.use_gpu,
                                             hard=hard,
                                             return_max_id=True)
        sample_y = sample_y.view(-1, cfg.mult_k * cfg.latent_size)
        y_ids = y_ids.view(-1, cfg.mult_k)

        dec_init_state = self.dec_init_connector(sample_y)

        # Decoder inputs are the tokens one step behind the labels.
        dec_inputs = out_utts[:, 0:-1]
        labels = out_utts[:, 1:].contiguous()

        latent_input = sample_y if self.concat_decoder_input else None
        dec_outs, _, dec_ctx = self.decoder(batch_size,
                                            dec_inputs,
                                            dec_init_state,
                                            mode=mode,
                                            gen_type="greedy",
                                            beam_size=self.beam_size,
                                            latent_variable=latent_input)
        if mode == GEN:
            dec_ctx[DecoderRNN.KEY_LATENT] = y_ids
            if from_prior:
                dec_ctx[DecoderRNN.KEY_CLASS] = sample_c
            return dec_ctx, labels

        nll = self.nll_loss(dec_outs, labels)
        ppl = self.ppl(dec_outs, labels)

        # KL of q(c|x) against the stored uniform log-prior, batch-averaged.
        avg_log_qc = torch.log(torch.mean(qc, dim=0) + 1e-15)
        log_qc = torch.log(qc + 1e-15)
        ckl = torch.mean(
            torch.sum(qc * (log_qc - self.log_uniform_y), dim=-1))

        log_qz = torch.log_softmax(qz_logits, dim=-1)
        qz = torch.exp(log_qz)
        zkl = self.zkl_loss(qc, log_qz, mean_z=True)
        zkl_ori = self.zkl_loss(qc, log_qz, mean_z=False)

        # MI for z: H(batch-averaged q(z)) - mean H(q(z|x)), over all groups.
        avg_log_qz = torch.log(torch.mean(qz, dim=0) + 1e-15)
        neg_h_z = torch.mean(torch.sum(qz * log_qz, dim=(-1, -2)))
        mi = neg_h_z - torch.sum(torch.exp(avg_log_qz) * avg_log_qz)

        # Same construction for c; neg_h_c is also reported as c_entropy.
        neg_h_c = torch.mean(torch.sum(qc * log_qc, dim=-1))
        mi_of_c = neg_h_c - torch.sum(torch.exp(avg_log_qc) * avg_log_qc)

        dispersion = self.dispersion(qc)
        if cfg.beta > 0:
            zkl = zkl + cfg.beta * dispersion

        if c_labels is None:
            ce_c, klz_sup = None, None
        else:
            ce_c, klz_sup = self.suploss_for_c(log_qc, c_labels, log_qz)

        if z_labels is None:
            ce_z = None
        else:
            ce_z = self.suploss_for_z(log_qz, z_labels)

        results = Pack(nll=nll,
                       mi=mi,
                       ckl=ckl,
                       zkl=zkl,
                       dispersion=dispersion,
                       PPL=ppl,
                       real_zkl=zkl_ori,
                       real_ckl=ckl,
                       ce_z=ce_z,
                       ce_c=ce_c,
                       klz_sup=klz_sup,
                       elbo=nll + zkl_ori + ckl,
                       c_entropy=neg_h_c,
                       mi_of_c=mi_of_c,
                       param_var=self.mean_of_params(tgt_probs=qc))

        if return_latent:
            results['log_qy'] = log_qz
            results['dec_init_state'] = dec_init_state
            results['y_ids'] = y_ids
            results['z'] = sample_y

        return results