def flatten_dialog(self, data, backward_size):
    """Flatten dialogs into (context, response, prev, next) training items."""
    results = []
    for dialog in data:
        # skip the first and last turns so that both a previous and a
        # next response exist
        for i in range(1, len(dialog) - 1):
            e_id = i
            s_id = max(0, e_id - backward_size)
            response = dialog[i]
            prev_resp = dialog[i - 1]
            next_resp = dialog[i + 1]  # renamed from `next` to avoid shadowing the builtin
            response['utt'] = self.pad_to(self.max_utt_size, response.utt, do_pad=False)
            prev_resp['utt'] = self.pad_to(self.max_utt_size, prev_resp.utt, do_pad=False)
            next_resp['utt'] = self.pad_to(self.max_utt_size, next_resp.utt, do_pad=False)
            contexts = []
            for turn in dialog[s_id:e_id]:
                turn['utt'] = self.pad_to(self.max_utt_size, turn.utt, do_pad=False)
                contexts.append(turn)
            results.append(Pack(context=contexts, response=response,
                                prev_resp=prev_resp, next_resp=next_resp))
    return results
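# A minimal sketch (not part of the model) of the context windowing that
# flatten_dialog performs, assuming a 5-turn dialog and backward_size=2;
# Pack and pad_to are omitted.
dialog_len, backward_size = 5, 2
for i in range(1, dialog_len - 1):
    e_id, s_id = i, max(0, i - backward_size)
    print(f"response={i} prev={i - 1} next={i + 1} context={list(range(s_id, e_id))}")
# response=1 prev=0 next=2 context=[0]
# response=2 prev=1 next=3 context=[0, 1]
# response=3 prev=2 next=4 context=[1, 2]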
def _prepare_batch(self, selected_index):
    rows = [self.data[idx] for idx in selected_index]
    input_lens = np.array([len(row.utt) for row in rows], dtype=np.int32)
    max_len = np.max(input_lens)
    # right-pad every utterance with zeros up to the batch maximum
    inputs = np.zeros((self.batch_size, max_len), dtype=np.int32)
    for idx, row in enumerate(rows):
        inputs[idx, 0:input_lens[idx]] = row.utt
    return Pack(outputs=inputs, output_lens=input_lens,
                metas=[row["meta"] for row in rows])
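# A minimal sketch of the right-padding done in _prepare_batch above,
# assuming three token-id utterances; Pack is omitted.
import numpy as np
utts = [[5, 6, 7], [8, 9], [10]]
lens = np.array([len(u) for u in utts], dtype=np.int32)
batch = np.zeros((len(utts), lens.max()), dtype=np.int32)
for b, u in enumerate(utts):
    batch[b, :lens[b]] = u
# batch -> [[5 6 7], [8 9 0], [10 0 0]]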
def flatten_dialog(self, data, backward_size):
    """Flatten dialogs into (context, response) training items."""
    results = []
    for dialog in data:
        for i in range(1, len(dialog)):
            e_id = i
            s_id = max(0, e_id - backward_size)
            response = dialog[i].copy()
            # response['utt_orisent'] = response.utt
            response['utt'] = self.pad_to(self.max_utt_size, response.utt, do_pad=False)
            contexts = []
            for turn in dialog[s_id:e_id]:
                turn['utt'] = self.pad_to(self.max_utt_size, turn.utt, do_pad=False)
                contexts.append(turn)
            results.append(Pack(context=contexts, response=response))
    return results
def _prepare_batch(self, selected_index):
    rows = [self.data[idx] for idx in selected_index]
    context_lens, context_utts, out_utts, out_lens = [], [], [], []
    metas = []
    for row in rows:
        ctx = row.context
        resp = row.response
        out_utt = resp.utt
        context_lens.append(len(ctx))
        context_utts.append([turn.utt for turn in ctx])
        out_utts.append(out_utt)
        out_lens.append(len(out_utt))
        metas.append(resp.meta)

    vec_context_lens = np.array(context_lens)
    vec_context = np.zeros((self.batch_size, np.max(vec_context_lens), self.max_utt_size),
                           dtype=np.int32)
    vec_outs = np.zeros((self.batch_size, np.max(out_lens)), dtype=np.int32)
    vec_out_lens = np.array(out_lens)

    for b_id in range(self.batch_size):
        vec_outs[b_id, 0:vec_out_lens[b_id]] = out_utts[b_id]
        # fill the context tensor, one padded utterance per row
        new_array = np.zeros((vec_context_lens[b_id], self.max_utt_size), dtype=np.int32)
        for i, utt in enumerate(context_utts[b_id]):
            new_array[i, 0:len(utt)] = utt
        vec_context[b_id, 0:vec_context_lens[b_id], :] = new_array

    return Pack(contexts=vec_context, context_lens=vec_context_lens,
                outputs=vec_outs, output_lens=vec_out_lens, metas=metas)
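# A minimal sketch of the 3-D context fill above: contexts of different
# lengths share one [batch, max_ctx_len, max_utt_size] tensor, assuming
# max_utt_size=4 and already-truncated token ids.
import numpy as np
max_utt_size = 4
batch_ctx = [[[1, 2], [3, 4, 5]], [[6]]]   # two dialogs: 2 and 1 context turns
ctx_lens = np.array([len(c) for c in batch_ctx])
vec = np.zeros((len(batch_ctx), ctx_lens.max(), max_utt_size), dtype=np.int32)
for b, ctx in enumerate(batch_ctx):
    for i, turn in enumerate(ctx):
        vec[b, i, :len(turn)] = turn
# vec[0] -> [[1 2 0 0], [3 4 5 0]];  vec[1] -> [[6 0 0 0], [0 0 0 0]]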
def _prepare_batch(self, selected_index):
    rows = [self.data[idx] for idx in selected_index]
    context_lens, context_utts, out_utts, out_lens = [], [], [], []
    prev_utts, prev_lens = [], []
    next_utts, next_lens = [], []
    metas = []
    for row in rows:
        ctx = row.context
        resp = row.response
        out_utt = resp.utt
        context_lens.append(len(ctx))
        context_utts.append([turn.utt for turn in ctx])
        out_utts.append(out_utt)
        out_lens.append(len(out_utt))
        metas.append(resp.meta)
        prev_utts.append(row.prev_resp.utt)
        prev_lens.append(len(row.prev_resp.utt))
        next_utts.append(row.next_resp.utt)
        next_lens.append(len(row.next_resp.utt))

    vec_context_lens = np.array(context_lens)
    vec_context = np.zeros((self.batch_size, np.max(vec_context_lens), self.max_utt_size),
                           dtype=np.int32)
    vec_outs = np.zeros((self.batch_size, np.max(out_lens)), dtype=np.int32)
    vec_prevs = np.zeros((self.batch_size, np.max(prev_lens)), dtype=np.int32)
    vec_nexts = np.zeros((self.batch_size, np.max(next_lens)), dtype=np.int32)
    vec_out_lens = np.array(out_lens)
    vec_prev_lens = np.array(prev_lens)
    vec_next_lens = np.array(next_lens)

    for b_id in range(self.batch_size):
        vec_outs[b_id, 0:vec_out_lens[b_id]] = out_utts[b_id]
        vec_prevs[b_id, 0:vec_prev_lens[b_id]] = prev_utts[b_id]
        vec_nexts[b_id, 0:vec_next_lens[b_id]] = next_utts[b_id]
        # fill the context tensor, one padded utterance per row
        new_array = np.zeros((vec_context_lens[b_id], self.max_utt_size), dtype=np.int32)
        for i, utt in enumerate(context_utts[b_id]):
            new_array[i, 0:len(utt)] = utt
        vec_context[b_id, 0:vec_context_lens[b_id], :] = new_array

    # integer supervision for the discrete latents: one emotion label and
    # one dialog-act label per response
    z_labels = np.zeros((self.batch_size, 2), dtype=np.int32)
    for b_id in range(self.batch_size):
        z_labels[b_id][0] = int(metas[b_id]["emotion"])
        z_labels[b_id][1] = int(metas[b_id]["act"])

    return Pack(contexts=vec_context, context_lens=vec_context_lens,
                outputs=vec_outs, output_lens=vec_out_lens, metas=metas,
                prevs=vec_prevs, prev_lens=vec_prev_lens,
                nexts=vec_nexts, next_lens=vec_next_lens, z_labels=z_labels)
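# A minimal sketch of the z_labels construction above, assuming each meta
# dict carries integer-coded "emotion" and "act" annotations.
import numpy as np
metas = [{"emotion": 3, "act": 1}, {"emotion": 0, "act": 2}]
z_labels = np.array([[m["emotion"], m["act"]] for m in metas], dtype=np.int32)
# z_labels -> [[3 1], [0 2]]  (column 0: emotion, column 1: act)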
def forward(self, data_feed, mode, gen_type='greedy', sample_n=1, return_latent=False):
    if isinstance(data_feed, tuple):
        data_feed = data_feed[0]
    batch_size = len(data_feed['output_lens'])
    out_utts = self.np2var(data_feed['outputs'], LONG)
    z_labels = data_feed.get("z_labels", None)
    c_labels = data_feed.get("c_labels", None)
    if z_labels is not None:
        z_labels = self.np2var(z_labels, LONG)
    if c_labels is not None:
        c_labels = self.np2var(c_labels, LONG)

    # output encoder
    output_embedding = self.embedding(out_utts)
    x_outs, x_last = self.x_encoder(output_embedding)
    if isinstance(x_last, tuple):  # LSTM returns (h, c); keep the hidden state
        x_last = x_last[0].transpose(0, 1).contiguous().view(-1, self.enc_out_size)
    else:
        x_last = x_last.transpose(0, 1).contiguous().view(-1, self.enc_out_size)

    # posterior network
    qy_logits = self.q_y(x_last).view(-1, self.config.k)
    log_qy = F.log_softmax(qy_logits, dim=-1)

    # switch that controls the sampling: soft samples while training,
    # hard (one-hot) samples at evaluation time
    sample_y, y_ids = self.cat_connector(qy_logits, 1.0, self.use_gpu,
                                         hard=not self.training, return_max_id=True)
    sample_y = sample_y.view(-1, self.config.k * self.config.mult_k)
    y_ids = y_ids.view(-1, self.config.mult_k)

    # map sample to initial state of decoder
    dec_init_state = self.dec_init_connector(sample_y)

    # get decoder inputs: predict token t+1 from tokens up to t
    labels = out_utts[:, 1:].contiguous()
    dec_inputs = out_utts[:, 0:-1]

    # decode
    dec_outs, dec_last, dec_ctx = self.decoder(batch_size, dec_inputs, dec_init_state,
                                               mode=mode, gen_type=gen_type,
                                               beam_size=self.beam_size)

    # compute loss or return results
    if mode == GEN:
        return dec_ctx, labels

    # RNN reconstruction
    nll = self.nll_loss(dec_outs, labels)
    if self.config.avg_type == "seq":
        ppl = self.ppl(dec_outs, labels)

    # regularize q(y) towards the uniform prior
    qy = torch.exp(log_qy.view(-1, self.config.mult_k, self.config.k))
    avg_log_qy = torch.log(torch.mean(qy, dim=0) + 1e-15)
    b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size,
                            unit_average=True)
    real_ckl = self.cat_kl_loss(log_qy, self.log_uniform_y, batch_size, average=False)
    real_ckl = torch.mean(torch.sum(real_ckl.view(-1, self.config.mult_k), dim=-1))
    reg_kl = b_pr if self.config.use_mutual else real_ckl

    # mutual information between x and y: H(Y) - H(Y|X)
    mi = self.entropy_loss(avg_log_qy, unit_average=True) \
        - self.entropy_loss(log_qy, unit_average=True)

    # optional supervision on the discrete latent codes
    ce_z = self.suploss_for_z(log_qy.view(-1, self.config.mult_k, self.config.k),
                              z_labels) if z_labels is not None else None

    results = Pack(nll=nll, reg_kl=reg_kl, mi=mi, bpr=b_pr, real_ckl=real_ckl,
                   ce_z=ce_z, elbo=nll + real_ckl)
    if self.config.avg_type == "seq":
        results['PPL'] = ppl
    if return_latent:
        results['log_qy'] = log_qy
        results['dec_init_state'] = dec_init_state
        results['y_ids'] = y_ids
    return results
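# cat_connector is not shown in this section; below is a minimal sketch of
# the straight-through Gumbel-Softmax sampling it presumably implements
# (the repo's connector may differ, e.g. in temperature handling).
import torch
import torch.nn.functional as F
logits = torch.randn(4, 10)                              # [batch * mult_k, k]
sample_y = F.gumbel_softmax(logits, tau=1.0, hard=True)  # one-hot forward, soft gradient
y_ids = sample_y.argmax(dim=-1, keepdim=True)            # discrete code ids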
def forward(self, data_feed, mode, gen_type='greedy', sample_n=1, return_latent=False):
    if isinstance(data_feed, tuple):
        data_feed = data_feed[0]
    batch_size = len(data_feed['output_lens'])
    out_utts = self.np2var(data_feed['outputs'], LONG)
    z_labels = data_feed.get("z_labels", None)
    c_labels = data_feed.get("c_labels", None)
    if z_labels is not None:
        z_labels = self.np2var(z_labels, LONG)
    if c_labels is not None:
        c_labels = self.np2var(c_labels, LONG)

    # output encoder
    output_embedding = self.embedding(out_utts)
    x_outs, x_last = self.x_encoder(output_embedding)
    if isinstance(x_last, tuple):  # LSTM returns (h, c); keep the hidden state
        x_last = x_last[0].transpose(0, 1).contiguous().view(-1, self.enc_out_size)
    else:
        x_last = x_last.transpose(0, 1).contiguous().view(-1, self.enc_out_size)
    # x_last = torch.mean(x_outs, dim=1)

    # posterior networks
    qc_logits = self.q_c(x_last)           # [batch_size, k]
    qc = torch.softmax(qc_logits, dim=-1)  # [batch_size, k]
    qz_logits = self.q_z(x_last).view(
        -1, self.config.mult_k, self.config.latent_size)  # [batch_size, mult_k, latent_size]

    if mode == GEN and gen_type == "sample":
        # sample a mixture component c, then sample z from that component's prior
        sample_c = torch.randint(0, self.config.k, (batch_size,), dtype=torch.long)
        pz = self.eta2theta(self._eta[sample_c])  # [k, mult_k, latent_size] -> [sample_n, mult_k, latent_size]
        sample_y, y_ids = self.cat_connector(
            torch.log(pz).view(-1, self.config.latent_size),
            1.0, self.use_gpu, hard=not self.training, return_max_id=True)
    else:
        # sample z from the posterior q(z|x)
        sample_y, y_ids = self.cat_connector(
            qz_logits.view(-1, self.config.latent_size),
            1.0, self.use_gpu, hard=True, return_max_id=True)
    # sample_y: [batch * mult_k, latent_size], y_ids: [batch * mult_k, 1]
    sample_y = sample_y.view(-1, self.config.mult_k * self.config.latent_size)
    y_ids = y_ids.view(-1, self.config.mult_k)

    # decode: map sample to initial state of decoder
    dec_init_state = self.dec_init_connector(sample_y)
    labels = out_utts[:, 1:].contiguous()
    dec_inputs = out_utts[:, 0:-1]
    dec_outs, dec_last, dec_ctx = self.decoder(
        batch_size, dec_inputs, dec_init_state, mode=mode, gen_type="greedy",
        beam_size=self.beam_size,
        latent_variable=sample_y if self.concat_decoder_input else None)

    # compute loss or return results
    if mode == GEN:
        dec_ctx[DecoderRNN.KEY_LATENT] = y_ids
        if gen_type == "sample":
            dec_ctx[DecoderRNN.KEY_CLASS] = sample_c
        return dec_ctx, labels

    # RNN reconstruction
    nll = self.nll_loss(dec_outs, labels)
    ppl = self.ppl(dec_outs, labels)

    # regularization terms
    # CKL: KL(q(c|x) || Uniform(k)), averaged over the batch
    avg_log_qc = torch.log(torch.mean(qc, dim=0) + 1e-15)  # [k]
    log_qc = torch.log(qc + 1e-15)
    ckl = torch.mean(torch.sum(qc * (log_qc - self.log_uniform_y), dim=-1))

    # ZKL: KL between q(z|x) and the mixture prior
    log_qz = torch.log_softmax(qz_logits, dim=-1)
    qz = torch.exp(log_qz)
    zkl = self.zkl_loss(qc, log_qz, mean_z=True)
    zkl_ori = self.zkl_loss(qc, log_qz, mean_z=False)  # ZKL (original)

    # MI: in this model, the mutual information is calculated for z
    avg_log_qz = torch.log(torch.mean(qz, dim=0) + 1e-15)  # [mult_k, latent_size]
    mi = torch.mean(torch.sum(qz * log_qz, dim=(-1, -2))) \
        - torch.sum(torch.exp(avg_log_qz) * avg_log_qz)
    mi_of_c = torch.mean(torch.sum(qc * log_qc, dim=-1)) \
        - torch.sum(torch.exp(avg_log_qc) * avg_log_qc)

    # dispersion term keeps the mixture components apart
    dispersion = self.dispersion(qc)
    if self.config.beta > 0:
        zkl = zkl + self.config.beta * dispersion

    # optional supervision on c and z
    if c_labels is not None:
        ce_c, klz_sup = self.suploss_for_c(log_qc, c_labels, log_qz)
    else:
        ce_c, klz_sup = None, None
    ce_z = self.suploss_for_z(log_qz, z_labels) if z_labels is not None else None
    c_entropy = torch.mean(torch.sum(qc * log_qc, dim=-1))

    results = Pack(nll=nll, mi=mi, ckl=ckl, zkl=zkl, dispersion=dispersion, PPL=ppl,
                   real_zkl=zkl_ori, real_ckl=ckl, ce_z=ce_z, ce_c=ce_c,
                   klz_sup=klz_sup, elbo=nll + zkl_ori + ckl, c_entropy=c_entropy,
                   mi_of_c=mi_of_c, param_var=self.mean_of_params(tgt_probs=qc))
    if return_latent:
        results['log_qy'] = log_qz
        results['dec_init_state'] = dec_init_state
        results['y_ids'] = y_ids
        results['z'] = sample_y
    return results
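# A worked example (under assumed shapes) of the CKL term above: the batch
# mean of KL(q(c|x) || Uniform(k)), here with k=4 and a random posterior.
import torch
k = 4
qc = torch.softmax(torch.randn(2, k), dim=-1)   # [batch, k]
log_uniform = torch.log(torch.ones(k) / k)      # log 1/k for each class
ckl = torch.mean(torch.sum(qc * (torch.log(qc + 1e-15) - log_uniform), dim=-1))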