Example #1
    def approx_policy(self, s_vec, a=None):
        if self.discrete_actions:
            # use a softmax policy for discrete action spaces
            prefs = np.array(
                [self.preferences(s_vec, aa) for aa in range(self.n_actions)])
            prefs -= prefs.max()  # shift by the max preference for numerical stability
            action_probs = np.exp(prefs) / np.sum(np.exp(prefs))

            if a is not None:
                return action_probs[a]

            action_num = np.random.multinomial(1, action_probs).argmax()
            action = self.num2action[action_num]
        else:
            # use a gaussian policy for continuous action spaces
            mu_weights = self.policy_weights[:self.n_tiles, :].T
            sigma_weights = self.policy_weights[self.n_tiles:, :].T
            mu = np.dot(mu_weights, s_vec)
            sigma = np.exp(np.dot(sigma_weights, s_vec))

            if a is not None:
                return sample_gaussian(mu, sigma, a)

            action = sample_gaussian(mu, sigma)
            # TODO: should check if the action is outside of a bounded
            # continuous action space here

        if isinstance(action, tuple):
            if any([np.isnan(a) for a in action]):
                import ipdb
                ipdb.set_trace()
        elif np.isnan(action):
            import ipdb
            ipdb.set_trace()

        return action
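The `sample_gaussian` helper used by the continuous branch above is not shown. A minimal NumPy sketch, inferred only from how it is called here (it returns a probability when the action `a` is supplied and a sample otherwise), could look like this; the helper in the original codebase may differ.

    import numpy as np

    def sample_gaussian(mu, sigma, a=None):
        # Hypothetical helper matching the calls in Example #1: with `a` given,
        # return the density of `a` under N(mu, diag(sigma^2)); otherwise sample.
        mu = np.asarray(mu, dtype=float)
        sigma = np.asarray(sigma, dtype=float)
        if a is not None:
            a = np.asarray(a, dtype=float)
            dens = np.exp(-0.5 * ((a - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))
            return np.prod(dens)
        return np.random.normal(mu, sigma)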
Example #2
 def forward(self, inputs):
     ret = {}
     x, y_class = inputs['x'], inputs['y_class']
     m, v = self.encoder(x)
     flow_loss = 0.0
     if self.use_deterministic_encoder:
         y = self.decoder(m)
         kl_loss = torch.zeros(1)
     elif self.use_flow:
         z = ut.sample_gaussian(m, v)
         decoder_input = z if not self.use_encoding_in_decoder else \
         torch.cat((z,m),dim=-1) #BUGBUG: Ideally the encodings before passing to mu and sigma should be here.
         # decoder_input, log_jacobians = self.flow(decoder_input)
         # flow_loss = self.bound(decoder_input, log_jacobians)
         flow_decoder_input = torch.zeros_like(decoder_input)
         for i in range(self.z_dim):
             flow = self.flow[i]
             single_input = decoder_input[:, i].unsqueeze(1)
             single_output, log_jacobians = flow(single_input)
             flow_decoder_input[:, i] = single_output.squeeze(1)
             flow_loss += self.bound(single_output, log_jacobians)
         flow_loss /= self.z_dim
         y = self.decoder(flow_decoder_input)
         #compute KL divergence loss :
         p_m = self.z_prior[0].expand(m.size())
         p_v = self.z_prior[1].expand(v.size())
         kl_loss = ut.kl_normal(m, v, p_m, p_v)
     else:
         z = ut.sample_gaussian(m, v)
         decoder_input = z if not self.use_encoding_in_decoder else \
         torch.cat((z,m),dim=-1) #BUGBUG: Ideally the encodings before passing to mu and sigma should be here.
         y = self.decoder(decoder_input)
         #compute KL divergence loss :
         p_m = self.z_prior[0].expand(m.size())
         p_v = self.z_prior[1].expand(v.size())
         kl_loss = ut.kl_normal(m, v, p_m, p_v)
     # compute reconstruction loss
     if self.loss_type == 'chamfer':
         x_reconst = CD_loss(y, x)
     # mean or sum
     if self.loss_sum_mean == "mean":
         x_reconst = x_reconst.mean()
         kl_loss = kl_loss.mean()
     else:
         x_reconst = x_reconst.sum()
         kl_loss = kl_loss.sum()
     nelbo = x_reconst + kl_loss + flow_loss
     ret = {
         'nelbo': nelbo,
         'kl_loss': kl_loss,
         'x_reconst': x_reconst,
         'flow_loss': flow_loss
     }
     # classifier network
     mv = torch.cat((m, v), dim=1)
     y_logits = self.z_classifer(mv)
     z_cl_loss = self.z_classifer.cross_entropy_loss(y_logits, y_class)
     ret['z_cl_loss'] = z_cl_loss
     return ret
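Several of the PyTorch examples (#2-#4, #8, #9, #11, #12, #18) rely on a small `ut` module with `sample_gaussian` and `kl_normal`, which is not included here. A plausible minimal sketch, assuming the second argument is a per-dimension variance rather than a standard deviation:

    import torch

    def sample_gaussian(m, v):
        # Reparameterization trick: z = m + sqrt(v) * eps with eps ~ N(0, I),
        # so gradients flow back through m and v.
        return m + torch.sqrt(v) * torch.randn_like(m)

    def kl_normal(qm, qv, pm, pv):
        # Analytic KL(N(qm, diag(qv)) || N(pm, diag(pv))), summed over the
        # latent dimension to give one value per batch element.
        element_wise = 0.5 * (torch.log(pv) - torch.log(qv)
                              + qv / pv
                              + (qm - pm).pow(2) / pv
                              - 1)
        return element_wise.sum(-1)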
Example #3
 def fetch_latent_z(self,x):
     m, v = self.encoder(x)
     if self.use_deterministic_encoder:
         z = m
     else:
         z =  ut.sample_gaussian(m,v)
     return z
Example #4
 def forward(self, inputs):
     x,y_class = inputs['x'], inputs['y_class']
     m, v = self.encoder(x)
     if self.use_deterministic_encoder:
         z = m  # keep z defined so the classifier head below can still run
         y = self.decoder(m)
         kl_loss = torch.zeros(1)
     else:
         z =  ut.sample_gaussian(m,v).to(device)
         y = self.decoder(z)
         #compute KL divergence loss :
         p_m = self.z_prior[0].expand(m.size())
         p_v = self.z_prior[1].expand(v.size())
         kl_loss = ut.kl_normal(m,v,p_m,p_v)
     # compute reconstruction loss
     if self.loss_type == 'chamfer':
         x_reconst = CD_loss(y, x)
     
     x_reconst = x_reconst.mean()
     kl_loss = kl_loss.mean()
     # compute classifiers
     y_logits = self.z_classifer(z)
     cl_loss = self.z_classifer.cross_entropy_loss(y_logits,y_class)
     nelbo = x_reconst + kl_loss 
     loss = nelbo + cl_loss
     ret = {'loss':loss, 'nelbo':nelbo, 'kl_loss':kl_loss, 'x_reconst':x_reconst, 'cl_loss':cl_loss}
     return ret
Example #5
        def encode(self):

            self.W_0_mu = utils.unif_weight_init(shape=[self.n_nodes, self.n_hidden])
            self.b_0_mu = tf.Variable(tf.constant(0.01, dtype=tf.float32, shape=[self.n_hidden]))
            self.W_1_mu = utils.unif_weight_init(shape=[self.n_hidden, self.n_embedding])
            self.b_1_mu = tf.Variable(tf.constant(0.01, dtype=tf.float32, shape=[self.n_embedding]))

            self.W_0_sigma = utils.unif_weight_init(shape=[self.n_nodes, self.n_hidden])
            self.b_0_sigma = tf.Variable(tf.constant(0.01, dtype=tf.float32, shape=[self.n_hidden]))
            self.W_1_sigma = utils.unif_weight_init(shape=[self.n_hidden, self.n_embedding])
            self.b_1_sigma = tf.Variable(tf.constant(0.01, dtype=tf.float32, shape=[self.n_embedding]))

            hidden_0_mu_ = utils.gcn_layer_id(self.norm_adj_mat, self.W_0_mu, self.b_0_mu)
            if self.dropout:
                hidden_0_mu = tf.nn.dropout(hidden_0_mu_, self.keep_prob)
            else:
                hidden_0_mu = hidden_0_mu_
            self.mu = utils.gcn_layer(self.norm_adj_mat, hidden_0_mu, self.W_1_mu, self.b_1_mu)
            
            hidden_0_sigma_ = utils.gcn_layer_id(self.norm_adj_mat, self.W_0_sigma, self.b_0_sigma)
            if self.dropout:
                hidden_0_sigma = tf.nn.dropout(hidden_0_sigma_, self.keep_prob)
            else:
                hidden_0_sigma = hidden_0_sigma_
            log_sigma = utils.gcn_layer(self.norm_adj_mat, hidden_0_sigma, self.W_1_sigma, self.b_1_sigma)
            self.sigma = tf.exp(log_sigma)

            return utils.sample_gaussian(self.mu, self.sigma)
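Examples #5 and #17 build a TensorFlow 1.x variational graph encoder and finish with `utils.sample_gaussian(self.mu, self.sigma)`, where `self.sigma = tf.exp(log_sigma)` is already a standard deviation. A reparameterized sketch of such a helper, under that assumption:

    import tensorflow as tf  # TF 1.x API, matching the surrounding code

    def sample_gaussian(mu, sigma):
        # z = mu + sigma * eps with eps ~ N(0, I); sigma is a standard deviation,
        # so no exp/sqrt is applied inside the helper.
        eps = tf.random_normal(tf.shape(mu))
        return mu + sigma * eps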
Example #6
    def visualize(self, epoch, ims_np):
        Igen = self.netG(self.fixed_noise)
        z = utils.sample_gaussian(self.netZ.emb.weight.clone().cpu(),
                                  self.vis_n)
        Igauss = self.netG(z)
        idx = torch.from_numpy(np.arange(self.vis_n)).cuda()
        Irec = self.netG(self.netZ(idx))
        Iact = torch.from_numpy(ims_np[:self.vis_n]).cuda()

        epoch = 0  # note: overrides the epoch argument, so filenames always use epoch 000
        # Generated images
        vutils.save_image(Igen.data,
                          'runs/ims_%s/generations_epoch_%03d.png' %
                          (self.rn, epoch),
                          normalize=False)
        # Reconstructed images
        vutils.save_image(Irec.data,
                          'runs/ims_%s/reconstructions_epoch_%03d.png' %
                          (self.rn, epoch),
                          normalize=False)
        vutils.save_image(Iact.data,
                          'runs/ims_%s/act.png' % (self.rn),
                          normalize=False)
        vutils.save_image(Igauss.data,
                          'runs/ims_%s/gaussian_epoch_%03d.png' %
                          (self.rn, epoch),
                          normalize=False)
Example #7
def main():
    counter = 20
    rn = "golf"
    nz = 100
    parallel = True

    W = torch.load('runs%d/nets_%s/netZ_glo.pth' % (counter, rn))
    W = W['emb.weight'].data.cpu().numpy()

    netG = model_video_orig.netG_new(nz).cuda()
    if torch.cuda.device_count() > 1:
        parallel = True
        print("Using", torch.cuda.device_count(), "GPUs!")
        netG = nn.DataParallel(netG)

    Zs = utils.sample_gaussian(torch.from_numpy(W), 10000)
    Zs = Zs.data.cpu().numpy()

    state_dict = torch.load('runs%d/nets_%s/netG_glo.pth' % (counter, rn))
    netG.load_state_dict(state_dict)  # load the weights for generator (GLO)
    if parallel:
        netG = netG.module

    gmm = GaussianMixture(n_components=100,
                          covariance_type='full',
                          max_iter=100,
                          n_init=10)
    gmm.fit(W)

    z = torch.from_numpy(gmm.sample(100)[0]).float().cuda()
    video = netG(z)
    utils.make_gif(video, 'runs%d/ims_%s/sample' % (counter, rn), 16)
    return video
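In Examples #6, #7, #13, and #19, `utils.sample_gaussian` is called with a whole matrix of learned latent codes (`netZ.emb.weight` or `W`) and a sample count, rather than a mean/variance pair. One hypothetical reading, consistent with how GLO-style code usually resamples its latent space, is that the helper fits a Gaussian to the rows of the matrix and draws that many new codes; a diagonal sketch:

    import torch

    def sample_gaussian(w, n):
        # Hypothetical sketch: fit a diagonal Gaussian to the rows of the latent
        # matrix w (one code per training item) and draw n fresh codes from it.
        mean = w.mean(dim=0)
        std = w.std(dim=0)
        return mean + std * torch.randn(n, w.size(1))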
Example #8
 def reconstruct_input(self,x):
     m, v = self.encoder(x)
     if self.use_deterministic_encoder:
         y = self.decoder(m)
     else:
         z =  ut.sample_gaussian(m,v)
         y = self.decoder(z)
     return y
Example #9
 def sample_point(self, batch):
     p_m = self.z_prior[0].expand(batch, self.z_dim).to(device)
     p_v = self.z_prior[1].expand(batch, self.z_dim).to(device)
     z = ut.sample_gaussian(p_m, p_v)
     decoder_input = z if not self.use_encoding_in_decoder else \
     torch.cat((z,p_m),dim=-1) #BUGBUG: Ideally the encodings before passing to mu and sigma should be here.
     y = self.decoder(decoder_input)
     return y
Example #10
 def sample_z(self, batch):
     m, v = utils.gaussian_parameters(self.z_pre.squeeze(0), dim=0)
     idx = torch.distributions.categorical.Categorical(self.pi).sample(
         (batch, ))
     m, v = m[idx], v[idx]
     x = utils.sample_gaussian(m, v)
     if self.n_flows > 0:
         for flow in self.nf[::-1]:
             x, _ = flow.inverse(x)
     return x
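Examples #10, #12, #15, #16, and #18 also use a `gaussian_parameters` helper that splits one tensor (e.g. `self.z_pre`) into a mean and a strictly positive variance along a given dimension. A common implementation, shown here as an assumed sketch, halves the tensor and passes the second half through a softplus:

    import torch
    import torch.nn.functional as F

    def gaussian_parameters(h, dim=-1):
        # First half of h along `dim` is the mean; the second half is mapped to
        # a positive variance via softplus (plus a small floor for stability).
        m, h_var = torch.split(h, h.size(dim) // 2, dim=dim)
        v = F.softplus(h_var) + 1e-8
        return m, v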
Example #11
 def reconstruct_input(self, x):
     m, v = self.encoder(x)
     if self.use_deterministic_encoder:
         y = self.decoder(m)
     else:
         z = ut.sample_gaussian(m, v)
         decoder_input = z if not self.use_encoding_in_decoder else \
         torch.cat((z,m),dim=-1) #BUGBUG: Ideally the encodings before passing to mu and sigma should be here.
         y = self.decoder(decoder_input)
     return y
Example #12
 def sample_point(self, batch):
     m, v = ut.gaussian_parameters(self.z_pre.squeeze(0), dim=0)
     idx = torch.distributions.categorical.Categorical(self.pi).sample(
         (batch, ))
     m, v = m[idx], v[idx]
     z = ut.sample_gaussian(m, v)
     decoder_input = z if not self.use_encoding_in_decoder else \
     torch.cat((z,m),dim=-1) #BUGBUG: Ideally the encodings before passing to mu and sigma should be here.
     y = self.sample_x_given(decoder_input)
     return y
Example #13
    def visualize(self, epoch, opt_params):
        Igen = self.netG(self.fixed_noise)  # GLO on a noise
        utils.make_gif(
            Igen,
            'runs%d/ims_%s/generations_epoch_%03d' % (counter, self.rn, epoch),
            opt_params.batch_size)

        z = utils.sample_gaussian(self.netZ.emb.weight.clone().cpu(),
                                  self.vis_n)
        Igauss = self.netG(z)  # GLO on gaussian
        utils.make_gif(
            Igauss,
            'runs%d/ims_%s/gaussian_epoch_%03d' % (counter, self.rn, epoch),
            opt_params.batch_size)
Example #14
def eval(dataloader):
	lst_loss_vae=list()
	lst_loss_ae=list()

	for name_model,model in dict_model.items():
	    model.eval()
	with torch.no_grad():
	    for ind_batch, X_batch in enumerate(dataloader):
	        cur_batch_size=X_batch.shape[0]
	        #print(X_batch.shape)
	        #reconstruct
	        out,_,_=dict_model['PointNet'](X_batch.permute(0,2,1)) #out: batch, 1024
	        set_rep=dict_model['FeatureCompression'](out) #set_rep: batch, dim_rep

	        #encoding. dim: batch, dim_z
	        qm,qv=ut.gaussian_parameters(dict_model['EncoderVAE'](set_rep),dim=1)
	        #sample z
	        z=ut.sample_gaussian(qm,qv,device=device) #batch_size, dim_z
	        #z to rep
	        rep_m,rep_v=ut.gaussian_parameters(dict_model['LatentToRep'](z))
	        X_rec=dict_model['DecoderAE'](ut.sample_gaussian(rep_m,rep_v,device=device)).reshape(cur_batch_size,-1,3)

	        #ae
	        X_rec_ae=dict_model['DecoderAE'](set_rep).reshape(cur_batch_size,-1,3)

	        dist_1,dist_2 = chamfer_dist(X_batch, X_rec)
	        loss_vae = (torch.mean(dist_1,axis=1)) + (torch.mean(dist_2,axis=1))

	        dist_1,dist_2 = chamfer_dist(X_batch, X_rec_ae)
	        loss_ae = (torch.mean(dist_1,axis=1)) + (torch.mean(dist_2,axis=1))

	        lst_loss_vae.append(loss_vae)
	        lst_loss_ae.append(loss_ae)
	avg_loss_vae=torch.cat(lst_loss_vae).mean().item()
	avg_loss_ae=torch.cat(lst_loss_ae).mean().item()
	return avg_loss_vae, avg_loss_ae
Example #15
 def get_mmd(self, z):
     m_mixture, s_mixture = utils.gaussian_parameters(self.z_pre, dim=1)
     num_sample = 200
     z_pri = utils.sample_gaussian(m_mixture,
                                   s_mixture,
                                   repeat=num_sample // self.k)
     if self.n_flows > 0:
         for flow in self.nf:
             z_pri, _ = flow.forward(z_pri)
     # (Monte Carlo) sample the z otherwise GPU will be out of memory
     z_post_samples = z[random.sample(range(z.shape[0]), num_sample)]
     prior_z_kernel = self.compute_kernel(z_pri, z_pri)
     posterior_z_kernel = self.compute_kernel(z_post_samples,
                                              z_post_samples)
     mixed_kernel = self.compute_kernel(z_pri, z_post_samples)
     mmd = prior_z_kernel.mean() + posterior_z_kernel.mean(
     ) - 2 * mixed_kernel.mean()
     return mmd
Example #16
 def forward(self, g, h, r, norm):
     self.node_id = h.squeeze()
     h = self.input_layer(g, h, r, norm)
     h = self.rconv_layer_1(g, h, r, norm)
     h = self.rconv_layer_2(g, h, r, norm)
     self.z_mean, self.z_sigma = utils.gaussian_parameters(h)
     z = utils.sample_gaussian(self.z_mean, self.z_sigma)
     # flow transformed samples
     if self.n_flows > 0:
         self.flow_log_prob = torch.zeros(1, )
         log_det_sum = torch.zeros((z.shape[0], ))
         for flow in self.nf:
             z, log_det = flow.forward(z)
             if log_det.is_cuda:
                 log_det_sum = log_det_sum.cuda()
             log_det_sum = log_det_sum + log_det.view(-1)
         self.flow_log_prob = torch.mean(log_det_sum.view(-1, 1))
     return z
Example #17
    def encode(self):

        self.W_0_mu = utils.unif_weight_init(
            shape=[self.n_nodes, self.n_hidden])
        self.b_0_mu = tf.Variable(
            tf.constant(0.01, dtype=tf.float32, shape=[self.n_hidden]))
        self.W_1_mu = utils.unif_weight_init(
            shape=[self.n_hidden, self.n_embedding])
        self.b_1_mu = tf.Variable(
            tf.constant(0.01, dtype=tf.float32, shape=[self.n_embedding]))

        self.W_0_sigma = utils.unif_weight_init(
            shape=[self.n_nodes, self.n_hidden])
        self.b_0_sigma = tf.Variable(
            tf.constant(0.01, dtype=tf.float32, shape=[self.n_hidden]))
        self.W_1_sigma = utils.unif_weight_init(
            shape=[self.n_hidden, self.n_embedding])
        self.b_1_sigma = tf.Variable(
            tf.constant(0.01, dtype=tf.float32, shape=[self.n_embedding]))

        hidden_0_mu_ = utils.gcn_layer_id(self.norm_adj_mat, self.W_0_mu,
                                          self.b_0_mu)
        if self.dropout:
            hidden_0_mu = tf.nn.dropout(hidden_0_mu_, self.keep_prob)
        else:
            hidden_0_mu = hidden_0_mu_
        self.mu = utils.gcn_layer(self.norm_adj_mat, hidden_0_mu, self.W_1_mu,
                                  self.b_1_mu)

        hidden_0_sigma_ = utils.gcn_layer_id(self.norm_adj_mat, self.W_0_sigma,
                                             self.b_0_sigma)
        if self.dropout:
            hidden_0_sigma = tf.nn.dropout(hidden_0_sigma_, self.keep_prob)
        else:
            hidden_0_sigma = hidden_0_sigma_
        log_sigma = utils.gcn_layer(self.norm_adj_mat, hidden_0_sigma,
                                    self.W_1_sigma, self.b_1_sigma)
        self.sigma = tf.exp(log_sigma)

        return utils.sample_gaussian(self.mu, self.sigma)
Example #18
 def forward(self, inputs):
     ret = {}
     x, y_class = inputs['x'], inputs['y_class']
     m, v = self.encoder(x)
     # Compute the mixture of Gaussian prior
     prior = ut.gaussian_parameters(self.z_pre, dim=1)
     if self.use_deterministic_encoder:
         y = self.decoder(m)
         kl_loss = torch.zeros(1)
     else:
         z = ut.sample_gaussian(m, v)
         decoder_input = z if not self.use_encoding_in_decoder else \
         torch.cat((z,m),dim=-1) #BUGBUG: Ideally the encodings before passing to mu and sigma should be here.
         y = self.decoder(decoder_input)
         #compute KL divergence loss :
         z_prior_m, z_prior_v = prior[0], prior[1]
         kl_loss = ut.log_normal(z, m, v) - ut.log_normal_mixture(
             z, z_prior_m, z_prior_v)
     # compute reconstruction loss
     if self.loss_type == 'chamfer':
         x_reconst = CD_loss(y, x)
     # mean or sum
     if self.loss_sum_mean == "mean":
         x_reconst = x_reconst.mean()
         kl_loss = kl_loss.mean()
     else:
         x_reconst = x_reconst.sum()
         kl_loss = kl_loss.sum()
     nelbo = x_reconst + kl_loss
     ret = {'nelbo': nelbo, 'kl_loss': kl_loss, 'x_reconst': x_reconst}
     # classifier network
     mv = torch.cat((m, v), dim=1)
     y_logits = self.z_classifer(mv)
     z_cl_loss = self.z_classifer.cross_entropy_loss(y_logits, y_class)
     ret['z_cl_loss'] = z_cl_loss
     return ret
Example #19
rn = params['name']
train_path = params['train_path']
test_path = params['test_path']
d = params['icp']['dim']
nz = params['glo']['nz']
do_bn = params['glo']['do_bn']
nc = params['nc']
sz = params['sz']
batch_size = params['fid']['batch_size']
total_n = params['fid']['n_images']

W = torch.load('runs/nets_%s/netZ_nag.pth' % (rn))
W = W['emb.weight'].data.cpu().numpy()

Zs = utils.sample_gaussian(torch.from_numpy(W), total_n)
Zs = Zs.data.cpu().numpy()

state_dict = torch.load('runs/nets_%s/netT_nag.pth' % rn)
netT = icp._netT(d, nz).cuda()
netT.load_state_dict(state_dict)

netG = model._netG(nz, sz, nc, do_bn).cuda()
state_dict = torch.load('runs/nets_%s/netG_nag.pth' % (rn))
netG.load_state_dict(state_dict)

train_ims = np.load(train_path).astype('float')
test_ims = np.load(test_path).astype('float')

rp = np.random.permutation(len(train_ims))[:total_n]
train_ims = train_ims[rp]
Example #20
    def save_synth_cross_modal(self, iters):

        decoderA = self.decoderA
        decoderB = self.decoderB

        # sample a mini-batch
        iterator1 = iter(self.data_loader)
        XA, XB, idsA, idsB, ids = next(iterator1)  # (n x C x H x W)
        if self.use_cuda:
            XA = XA.cuda()
            XB = XB.cuda()

        # z = encA(xA)
        mu_infA, std_infA, _ = self.encoderA(XA)

        # z = encB(xB)
        mu_infB, std_infB, _ = self.encoderB(XB)

        # encoder samples (for cross-modal prediction)
        Z_infA = sample_gaussian(self.use_cuda, mu_infA, std_infA)
        Z_infB = sample_gaussian(self.use_cuda, mu_infB, std_infB)

        WS = torch.ones(XA.shape)
        if self.use_cuda:
            WS = WS.cuda()

        n = XA.shape[0]

        perm = torch.arange(0, (1 + 2) * n).view(1 + 2, n).transpose(1, 0)
        perm = perm.contiguous().view(-1)

        # 1) generate xB from given xA (A2B)

        merged = torch.cat([XA], dim=0)
        XB_synth = torch.sigmoid(decoderB(Z_infA)).data  # given XA
        merged = torch.cat([merged, XB_synth], dim=0)
        merged = torch.cat([merged, WS], dim=0)
        merged = merged[perm, :].cpu()

        # save the results as image
        fname = os.path.join(self.output_dir_synth,
                             'synth_cross_modal_A2B_%s.jpg' % iters)
        mkdirs(self.output_dir_synth)
        save_image(tensor=merged,
                   filename=fname,
                   nrow=(1 + 2) * int(np.sqrt(n)),
                   pad_value=1)

        # 2) generate xA from given xB (B2A)

        merged = torch.cat([XB], dim=0)
        XA_synth = torch.sigmoid(decoderA(Z_infB)).data  # given XB
        merged = torch.cat([merged, XA_synth], dim=0)
        merged = torch.cat([merged, WS], dim=0)
        merged = merged[perm, :].cpu()

        # save the results as image
        fname = os.path.join(self.output_dir_synth,
                             'synth_cross_modal_B2A_%s.jpg' % iters)
        mkdirs(self.output_dir_synth)
        save_image(tensor=merged,
                   filename=fname,
                   nrow=(1 + 2) * int(np.sqrt(n)),
                   pad_value=1)
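Examples #20, #22, and #23 use yet another calling convention, `sample_gaussian(self.use_cuda, mu, std)`, where the first argument only selects the device and `std` is a standard deviation. A minimal sketch under that assumption:

    import torch

    def sample_gaussian(use_cuda, mu, std):
        # Reparameterized draw z = mu + std * eps with eps ~ N(0, I); the noise
        # is moved to the GPU when use_cuda is set, matching mu and std.
        eps = torch.randn(std.size())
        if use_cuda:
            eps = eps.cuda()
        return mu + std * eps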
Example #21
    def __init__(self,
                 sess,
                 config,
                 api,
                 log_dir,
                 forward,
                 scope=None):  # forward???
        self.vocab = api.vocab
        self.rev_vocab = api.rev_vocab
        self.vocab_size = len(self.vocab)
        self.seen_intent = api.seen_intent
        self.rev_seen_intent = api.rev_seen_intent
        self.seen_intent_size = len(self.rev_seen_intent)
        self.unseen_intent = api.unseen_intent
        self.rev_unseen_intent = api.rev_unseen_intent
        self.unseen_intent_size = len(self.rev_unseen_intent)
        self.sess = sess
        self.scope = scope
        self.max_utt_len = config.max_utt_len
        self.go_id = self.rev_vocab["<s>"]
        self.eos_id = self.rev_vocab["</s>"]
        self.sent_cell_size = config.sent_cell_size
        self.dec_cell_size = config.dec_cell_size
        self.label_embed_size = config.label_embed_size
        self.latent_size = config.latent_size

        self.seed = config.seed
        self.use_ot_label = config.use_ot_label
        self.use_rand_ot_label = config.use_rand_ot_label  # Only valid if use_ot_label is true, whether use all other label
        self.use_rand_fixed_ot_label = config.use_rand_fixed_ot_label  # valid when use_ot_label=true and use_rand_ot_label=true
        if self.use_ot_label:
            self.rand_ot_label_num = config.rand_ot_label_num  # valid when use_ot_label=true and use_rand_ot_label=true
        else:
            self.rand_ot_label_num = self.seen_intent_size - 1

        with tf.name_scope("io"):
            # all dialog context and known attributes
            self.labels = tf.placeholder(
                dtype=tf.int32, shape=(None, ),
                name="labels")  # each utterance have a label, [batch_size,]
            self.ot_label_rand = tf.placeholder(dtype=tf.int32,
                                                shape=(None, None),
                                                name="ot_labels_rand")
            self.ot_labels_all = tf.placeholder(
                dtype=tf.int32, shape=(None, None),
                name="ot_labels_all")  #(batch_size, len(api.label_vocab)-1)

            # target response given the dialog context
            self.io_tokens = tf.placeholder(dtype=tf.int32,
                                            shape=(None, None),
                                            name="output_tokens")
            self.io_lens = tf.placeholder(dtype=tf.int32,
                                          shape=(None, ),
                                          name="output_lens")
            self.output_labels = tf.placeholder(dtype=tf.int32,
                                                shape=(None, ),
                                                name="output_labels")

            # optimization related variables
            self.learning_rate = tf.Variable(float(config.init_lr),
                                             trainable=False,
                                             name="learning_rate")
            self.learning_rate_decay_op = self.learning_rate.assign(
                tf.multiply(self.learning_rate, config.lr_decay))
            self.global_t = tf.placeholder(dtype=tf.int32, name="global_t")
            self.use_prior = tf.placeholder(
                dtype=tf.bool, name="use_prior")  # whether use prior
            self.prior_mulogvar = tf.placeholder(
                dtype=tf.float32,
                shape=(None, config.latent_size * 2),
                name="prior_mulogvar")

            self.batch_size = tf.placeholder(dtype=tf.int32, name="batch_size")

        max_out_len = array_ops.shape(self.io_tokens)[1]
        # batch_size = array_ops.shape(self.io_tokens)[0]
        batch_size = self.batch_size

        with variable_scope.variable_scope("labelEmbedding",
                                           reuse=tf.AUTO_REUSE):
            self.la_embedding = tf.get_variable(
                "embedding", [self.seen_intent_size, config.label_embed_size],
                dtype=tf.float32)
            label_embedding = embedding_ops.embedding_lookup(
                self.la_embedding, self.output_labels)  # not used

        with variable_scope.variable_scope("wordEmbedding",
                                           reuse=tf.AUTO_REUSE):
            self.embedding = tf.get_variable(
                "embedding", [self.vocab_size, config.embed_size],
                dtype=tf.float32,
                trainable=False)
            embedding_mask = tf.constant(
                [0 if i == 0 else 1 for i in range(self.vocab_size)],
                dtype=tf.float32,
                shape=[self.vocab_size, 1])
            embedding = self.embedding * embedding_mask  # broadcast; the first row stays all zeros

            io_embedding = embedding_ops.embedding_lookup(
                embedding, self.io_tokens)  # 3 dim

            if config.sent_type == "bow":
                io_embedding, _ = get_bow(io_embedding)

            elif config.sent_type == "rnn":
                sent_cell = self.get_rnncell("gru", self.sent_cell_size,
                                             config.keep_prob, 1)
                io_embedding, _ = get_rnn_encode(io_embedding,
                                                 sent_cell,
                                                 self.io_lens,
                                                 scope="sent_rnn",
                                                 reuse=tf.AUTO_REUSE)
            elif config.sent_type == "bi_rnn":
                fwd_sent_cell = self.get_rnncell("gru",
                                                 self.sent_cell_size,
                                                 keep_prob=1.0,
                                                 num_layer=1)
                bwd_sent_cell = self.get_rnncell("gru",
                                                 self.sent_cell_size,
                                                 keep_prob=1.0,
                                                 num_layer=1)
                io_embedding, _ = get_bi_rnn_encode(
                    io_embedding,
                    fwd_sent_cell,
                    bwd_sent_cell,
                    self.io_lens,
                    scope="sent_bi_rnn",
                    reuse=tf.AUTO_REUSE
                )  # equal to x of the graph, (batch_size, 300*2)
            else:
                raise ValueError(
                    "Unknown sent_type. Must be one of [bow, rnn, bi_rnn]")

            # print('==========================', io_embedding) # Tensor("models_2/wordEmbedding/sent_bi_rnn/concat:0", shape=(?, 600), dtype=float32)

            # convert label into 1 hot
            my_label_one_hot = tf.one_hot(tf.reshape(self.labels, [-1]),
                                          depth=self.seen_intent_size,
                                          dtype=tf.float32)  # 2 dim
            if config.use_ot_label:
                if config.use_rand_ot_label:
                    ot_label_one_hot = tf.one_hot(tf.reshape(
                        self.ot_label_rand, [-1]),
                                                  depth=self.seen_intent_size,
                                                  dtype=tf.float32)
                    ot_label_one_hot = tf.reshape(
                        ot_label_one_hot,
                        [-1, self.seen_intent_size * self.rand_ot_label_num])
                else:
                    ot_label_one_hot = tf.one_hot(tf.reshape(
                        self.ot_labels_all, [-1]),
                                                  depth=self.seen_intent_size,
                                                  dtype=tf.float32)
                    ot_label_one_hot = tf.reshape(
                        ot_label_one_hot, [
                            -1, self.seen_intent_size *
                            (self.seen_intent_size - 1)
                        ]
                    )  # (batch_size, len(api.label_vocab)*(len(api.label_vocab)-1))

        with variable_scope.variable_scope("recognitionNetwork",
                                           reuse=tf.AUTO_REUSE):
            recog_input = io_embedding
            self.recog_mulogvar = recog_mulogvar = layers.fully_connected(
                recog_input,
                config.latent_size * 2,
                activation_fn=None,
                scope="muvar")  # config.latent_size=200
            recog_mu, recog_logvar = tf.split(
                recog_mulogvar, 2, axis=1
            )  # recognition network output. (batch_size, config.latent_size)

        with variable_scope.variable_scope("priorNetwork",
                                           reuse=tf.AUTO_REUSE):
            # p(xyz) = p(z)p(x|z)p(y|xz)
            # prior network parameters; assume a normal distribution
            # prior_mulogvar = tf.constant([[1] * config.latent_size + [0] * config.latent_size]*batch_size,
            #                              dtype=tf.float32, name="muvar")  # cannot be built this way
            prior_mulogvar = self.prior_mulogvar
            prior_mu, prior_logvar = tf.split(prior_mulogvar, 2, axis=1)

            # use sampled Z or posterior Z
            latent_sample = tf.cond(
                self.use_prior,  # bool input
                lambda: sample_gaussian(prior_mu, prior_logvar
                                        ),  # equal to shape(prior_logvar)
                lambda: sample_gaussian(recog_mu, recog_logvar)
            )  # if ... else ..., (batch_size, config.latent_size)
            self.z = latent_sample

        with variable_scope.variable_scope("generationNetwork",
                                           reuse=tf.AUTO_REUSE):
            bow_loss_inputs = latent_sample  # (part of) response network input
            label_inputs = latent_sample
            dec_inputs = latent_sample

            # BOW loss
            if config.use_bow_loss:
                bow_fc1 = layers.fully_connected(
                    bow_loss_inputs,
                    400,
                    activation_fn=tf.tanh,
                    scope="bow_fc1")  # MLPb network fc layer
                # error1:ValueError: The last dimension of the inputs to `Dense` should be defined. Found `None`.
                if config.keep_prob < 1.0:
                    bow_fc1 = tf.nn.dropout(bow_fc1, config.keep_prob)
                self.bow_logits = layers.fully_connected(
                    bow_fc1,
                    self.vocab_size,
                    activation_fn=None,
                    scope="bow_project")  # MLPb network fc output

            # Y loss, include the other y.
            my_label_fc1 = layers.fully_connected(label_inputs,
                                                  400,
                                                  activation_fn=tf.tanh,
                                                  scope="my_label_fc1")
            if config.keep_prob < 1.0:
                my_label_fc1 = tf.nn.dropout(my_label_fc1, config.keep_prob)

            # my_label_fc2 = layers.fully_connected(my_label_fc1, 400, activation_fn=tf.tanh, scope="my_label_fc2")
            # if config.keep_prob < 1.0:
            #     my_label_fc2 = tf.nn.dropout(my_label_fc2, config.keep_prob)

            self.my_label_logits = layers.fully_connected(
                my_label_fc1, self.seen_intent_size,
                scope="my_label_project")  # MLPy fc output
            my_label_prob = tf.nn.softmax(
                self.my_label_logits
            )  # softmax output, (batch_size, label_vocab_size)
            self.my_label_prob = my_label_prob
            pred_my_label_embedding = tf.matmul(
                my_label_prob, self.la_embedding
            )  # predicted my label y. (batch_size, label_embed_size)

            if config.use_ot_label:
                if config.use_rand_ot_label:  # use one random other label
                    ot_label_fc1 = layers.fully_connected(
                        label_inputs,
                        400,
                        activation_fn=tf.tanh,
                        scope="ot_label_fc1")
                    if config.keep_prob < 1.0:
                        ot_label_fc1 = tf.nn.dropout(ot_label_fc1,
                                                     config.keep_prob)
                    self.ot_label_logits = layers.fully_connected(
                        ot_label_fc1,
                        self.rand_ot_label_num * self.seen_intent_size,
                        scope="ot_label_rand_project")
                    ot_label_logits_split = tf.reshape(
                        self.ot_label_logits,
                        [-1, self.rand_ot_label_num, self.seen_intent_size])
                    ot_label_prob_short = tf.nn.softmax(ot_label_logits_split)
                    ot_label_prob = tf.reshape(
                        ot_label_prob_short,
                        [-1, self.rand_ot_label_num * self.seen_intent_size]
                    )  # (batch_size, self.rand_ot_label_num*self.label_vocab_size)
                    pred_ot_label_embedding = tf.reshape(
                        tf.matmul(ot_label_prob_short, self.la_embedding),
                        [self.label_embed_size * self.rand_ot_label_num
                         ])  # predicted other label y2.
                else:
                    ot_label_fc1 = layers.fully_connected(
                        label_inputs,
                        400,
                        activation_fn=tf.tanh,
                        scope="ot_label_fc1")
                    if config.keep_prob < 1.0:
                        ot_label_fc1 = tf.nn.dropout(ot_label_fc1,
                                                     config.keep_prob)
                    self.ot_label_logits = layers.fully_connected(
                        ot_label_fc1,
                        self.seen_intent_size * (self.seen_intent_size - 1),
                        scope="ot_label_all_project")
                    ot_label_logits_split = tf.reshape(
                        self.ot_label_logits,
                        [-1, self.seen_intent_size - 1, self.seen_intent_size])
                    ot_label_prob_short = tf.nn.softmax(ot_label_logits_split)
                    ot_label_prob = tf.reshape(
                        ot_label_prob_short, [
                            -1, self.seen_intent_size *
                            (self.seen_intent_size - 1)
                        ]
                    )  # (batch_size, self.label_vocab_size*(self.label_vocab_size-1))
                    pred_ot_label_embedding = tf.reshape(
                        tf.matmul(ot_label_prob_short, self.la_embedding),
                        [self.label_embed_size * (self.seen_intent_size - 1)]
                    )  # predicted other all label y. (batch_size, self.label_embed_size*(self.label_vocab_size-1))
                    # note:matmul can calc (3, 4, 5) × (5, 4) = (3, 4, 4)
            else:  # only use label y.
                self.ot_label_logits = None
                pred_ot_label_embedding = None

            # Decoder, Response Network
            if config.num_layer > 1:
                dec_init_state = []
                for i in range(config.num_layer):
                    temp_init = layers.fully_connected(dec_inputs,
                                                       self.dec_cell_size,
                                                       activation_fn=None,
                                                       scope="init_state-%d" %
                                                       i)
                    if config.cell_type == 'lstm':
                        temp_init = rnn_cell.LSTMStateTuple(
                            temp_init, temp_init)

                    dec_init_state.append(temp_init)

                dec_init_state = tuple(dec_init_state)
            else:
                dec_init_state = layers.fully_connected(dec_inputs,
                                                        self.dec_cell_size,
                                                        activation_fn=None,
                                                        scope="init_state")
                if config.cell_type == 'lstm':
                    dec_init_state = rnn_cell.LSTMStateTuple(
                        dec_init_state, dec_init_state)

        with variable_scope.variable_scope("decoder", reuse=tf.AUTO_REUSE):
            dec_cell = self.get_rnncell(config.cell_type, self.dec_cell_size,
                                        config.keep_prob, config.num_layer)
            dec_cell = OutputProjectionWrapper(dec_cell, self.vocab_size)

            if forward:  # test
                loop_func = decoder_fn_lib.context_decoder_fn_inference(
                    None,
                    dec_init_state,
                    embedding,
                    start_of_sequence_id=self.go_id,
                    end_of_sequence_id=self.eos_id,
                    maximum_length=self.max_utt_len,
                    num_decoder_symbols=self.vocab_size,
                    context_vector=None)  # a function
                dec_input_embedding = None
                dec_seq_lens = None
            else:  # train
                loop_func = decoder_fn_lib.context_decoder_fn_train(
                    dec_init_state, None)
                dec_input_embedding = embedding_ops.embedding_lookup(
                    embedding, self.io_tokens
                )  # x 's embedding (batch_size, utt_len, embed_size)
                dec_input_embedding = dec_input_embedding[:, 0:
                                                          -1, :]  # ignore the last </s>
                dec_seq_lens = self.io_lens - 1  # input placeholder

                if config.keep_prob < 1.0:
                    dec_input_embedding = tf.nn.dropout(
                        dec_input_embedding, config.keep_prob)

                # apply word dropping. Set dropped word to 0
                if config.dec_keep_prob < 1.0:
                    keep_mask = tf.less_equal(
                        tf.random_uniform((batch_size, max_out_len - 1),
                                          minval=0.0,
                                          maxval=1.0), config.dec_keep_prob)
                    keep_mask = tf.expand_dims(tf.to_float(keep_mask), 2)
                    dec_input_embedding = dec_input_embedding * keep_mask
                    dec_input_embedding = tf.reshape(
                        dec_input_embedding,
                        [-1, max_out_len - 1, config.embed_size])

                # print("=======", dec_input_embedding) # Tensor("models/decoder/strided_slice:0", shape=(?, ?, 200), dtype=float32)

            dec_outs, _, final_context_state = dynamic_rnn_decoder(
                dec_cell,
                loop_func,
                inputs=dec_input_embedding,
                sequence_length=dec_seq_lens
            )  # dec_outs [batch_size, seq, features]

            if final_context_state is not None:
                final_context_state = final_context_state[:, 0:array_ops.
                                                          shape(dec_outs)[1]]
                mask = tf.to_int32(tf.sign(tf.reduce_max(
                    dec_outs, axis=2)))  # get softmax vec's max index
                self.dec_out_words = tf.multiply(
                    tf.reverse(final_context_state, axis=[1]), mask)
            else:
                self.dec_out_words = tf.argmax(
                    dec_outs,
                    2)  # (batch_size, utt_len), each element is index of word

        if not forward:
            with variable_scope.variable_scope("loss", reuse=tf.AUTO_REUSE):
                labels = self.io_tokens[:,
                                        1:]  # not include the first word <s>, (batch_size, utt_len)
                label_mask = tf.to_float(tf.sign(labels))

                labels = tf.one_hot(labels,
                                    depth=self.vocab_size,
                                    dtype=tf.float32)

                print(dec_outs)
                print(labels)
                # Tensor("models_1/decoder/dynamic_rnn_decoder/transpose_1:0", shape=(?, ?, 892), dtype=float32)
                # Tensor("models_1/loss/strided_slice:0", shape=(?, ?), dtype=int32)
                # rc_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=dec_outs, labels=labels) # response network loss
                rc_loss = tf.nn.softmax_cross_entropy_with_logits(
                    logits=dec_outs, labels=labels)  # response network loss
                # logits_size=[390,892] labels_size=[1170,892]
                rc_loss = tf.reduce_sum(
                    rc_loss * label_mask,
                    reduction_indices=1)  # (batch_size,), except the word unk
                self.avg_rc_loss = tf.reduce_mean(rc_loss)  # scalar
                # used only for perplexity calculation, not for optimization
                self.rc_ppl = tf.exp(
                    tf.reduce_sum(rc_loss) / tf.reduce_sum(label_mask))
                """ as n-trial multimodal distribution. """
                tile_bow_logits = tf.tile(
                    tf.expand_dims(self.bow_logits, 1),
                    [1, max_out_len - 1, 1
                     ])  # (batch_size, max_out_len-1, vocab_size)
                bow_loss = tf.nn.softmax_cross_entropy_with_logits(
                    logits=tile_bow_logits, labels=labels
                ) * label_mask  # labels shape less than logits shape, (batch_size, max_out_len-1)
                bow_loss = tf.reduce_sum(bow_loss,
                                         reduction_indices=1)  # (batch_size, )
                self.avg_bow_loss = tf.reduce_mean(bow_loss)  # scalar

                # the label y
                my_label_loss = tf.nn.softmax_cross_entropy_with_logits(
                    logits=my_label_prob,
                    labels=my_label_one_hot)  # label (batch_size,)
                self.avg_my_label_loss = tf.reduce_mean(my_label_loss)
                if config.use_ot_label:
                    ot_label_loss = -tf.nn.softmax_cross_entropy_with_logits(
                        logits=ot_label_prob, labels=ot_label_one_hot)
                    self.avg_ot_label_loss = tf.reduce_mean(ot_label_loss)
                else:
                    self.avg_ot_label_loss = 0.0

                kld = gaussian_kld(
                    recog_mu, recog_logvar, prior_mu,
                    prior_logvar)  # kl divergence, (batch_size,)
                self.avg_kld = tf.reduce_mean(kld)  # scalar
                if log_dir is not None:
                    kl_weights = tf.minimum(
                        tf.to_float(self.global_t) / config.full_kl_step, 1.0)
                else:
                    kl_weights = tf.constant(1.0)

                self.kl_w = kl_weights
                self.elbo = self.avg_rc_loss + kl_weights * self.avg_kld  # reconstruction loss + weighted KL divergence
                #=====================================================================================================total loss====================================================#
                if config.use_rand_ot_label:
                    aug_elbo = self.avg_bow_loss + 1000 * self.avg_my_label_loss + 10 * self.avg_ot_label_loss + self.elbo  # augmented loss
                    # (1/self.rand_ot_label_num)*
                else:
                    aug_elbo = self.avg_bow_loss + 1000 * self.avg_my_label_loss + 10 * self.avg_ot_label_loss + self.elbo  # augmented loss
                    # (1/(self.label_vocab_size-1))*

                tf.summary.scalar("rc_loss", self.avg_rc_loss)
                tf.summary.scalar("elbo", self.elbo)
                tf.summary.scalar("kld", self.avg_kld)
                tf.summary.scalar("bow_loss", self.avg_bow_loss)
                tf.summary.scalar("my_label_loss", self.avg_my_label_loss)
                tf.summary.scalar("ot_label_loss", self.avg_ot_label_loss)

                self.summary_op = tf.summary.merge_all()

                self.log_p_z = norm_log_liklihood(latent_sample, prior_mu,
                                                  prior_logvar)  # probability
                self.log_q_z_xy = norm_log_liklihood(
                    latent_sample, recog_mu, recog_logvar)  # probability
                self.est_marginal = tf.reduce_mean(rc_loss + bow_loss -
                                                   self.log_p_z +
                                                   self.log_q_z_xy)

            self.optimize(sess, config, aug_elbo, log_dir)

        self.saver = tf.train.Saver(tf.global_variables(),
                                    write_version=tf.train.SaverDef.V2)
        print('model build finished!')
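Unlike the earlier helpers, the CVAE-style code in Example #21 passes a log-variance (`recog_logvar`, `prior_logvar`) as the second argument of `sample_gaussian`, and it also relies on `gaussian_kld` for the KL term. Plausible TF 1.x sketches, assuming that log-variance parameterization:

    import tensorflow as tf  # TF 1.x API

    def sample_gaussian(mu, logvar):
        # z = mu + exp(logvar / 2) * eps with eps ~ N(0, I).
        eps = tf.random_normal(tf.shape(mu))
        return mu + tf.exp(0.5 * logvar) * eps

    def gaussian_kld(recog_mu, recog_logvar, prior_mu, prior_logvar):
        # KL(q || p) per example for diagonal Gaussians given as (mu, logvar).
        return -0.5 * tf.reduce_sum(
            1 + (recog_logvar - prior_logvar)
            - tf.pow(prior_mu - recog_mu, 2) / tf.exp(prior_logvar)
            - tf.exp(recog_logvar) / tf.exp(prior_logvar),
            reduction_indices=1)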
Example #22
    def save_recon(self, iters):
        self.set_mode(train=False)

        mkdirs(self.output_dir_recon)

        fixed_idxs = [3246, 7000, 14305, 19000, 27444, 33100, 38000, 45231, 51000, 55121]

        fixed_idxs60 = []
        for idx in fixed_idxs:
            for i in range(6):
                fixed_idxs60.append(idx + i)

        XA = [0] * len(fixed_idxs60)
        XB = [0] * len(fixed_idxs60)

        for i, idx in enumerate(fixed_idxs60):
            XA[i], XB[i] = \
                self.data_loader.dataset.__getitem__(idx)[0:2]

            if self.use_cuda:
                XA[i] = XA[i].cuda()
                XB[i] = XB[i].cuda()

        XA = torch.stack(XA)
        XB = torch.stack(XB)

        muA_infA, stdA_infA, logvarA_infA, cate_prob_infA = self.encoderA(XA)

        # zB, zS = encB(xB)
        muB_infB, stdB_infB, logvarB_infB, cate_prob_infB = self.encoderB(XB)

        # zS = encAB(xA,xB) via POE
        cate_prob_POE = torch.exp(
            torch.log(torch.tensor(1 / 10)) + torch.log(cate_prob_infA) + torch.log(cate_prob_infB))

        # encoder samples (for training)
        ZA_infA = sample_gaussian(self.use_cuda, muA_infA, stdA_infA)
        ZB_infB = sample_gaussian(self.use_cuda, muB_infB, stdB_infB)
        ZS_POE = sample_gumbel_softmax(self.use_cuda, cate_prob_POE, train=False)

        # encoder samples (for cross-modal prediction)
        ZS_infA = sample_gumbel_softmax(self.use_cuda, cate_prob_infA, train=False)
        ZS_infB = sample_gumbel_softmax(self.use_cuda, cate_prob_infB, train=False)

        # reconstructed samples (given joint modal observation)
        XA_POE_recon = torch.sigmoid(self.decoderA(ZA_infA, ZS_POE))
        XB_POE_recon = torch.sigmoid(self.decoderB(ZB_infB, ZS_POE))

        # reconstructed samples (given single modal observation)
        XA_infA_recon = torch.sigmoid(self.decoderA(ZA_infA, ZS_infA))
        XB_infB_recon = torch.sigmoid(self.decoderB(ZB_infB, ZS_infB))

        WS = torch.ones(XA.shape)
        if self.use_cuda:
            WS = WS.cuda()

        n = XA.shape[0]
        perm = torch.arange(0, 4 * n).view(4, n).transpose(1, 0)
        perm = perm.contiguous().view(-1)

        ## img
        # merged = torch.cat(
        #     [ XA, XB, XA_infA_recon, XB_infB_recon,
        #       XA_POE_recon, XB_POE_recon, WS ], dim=0
        # )
        merged = torch.cat(
            [XA, XA_infA_recon, XA_POE_recon, WS], dim=0
        )
        merged = merged[perm, :].cpu()

        # save the results as image
        fname = os.path.join(self.output_dir_recon, 'reconA_%s.jpg' % iters)
        mkdirs(self.output_dir_recon)
        save_image(
            tensor=merged, filename=fname, nrow=4 * int(np.sqrt(n)),
            pad_value=1
        )

        WS = torch.ones(XB.shape)
        if self.use_cuda:
            WS = WS.cuda()

        n = XB.shape[0]
        perm = torch.arange(0, 4 * n).view(4, n).transpose(1, 0)
        perm = perm.contiguous().view(-1)

        ## ingr
        merged = torch.cat(
            [XB, XB_infB_recon, XB_POE_recon, WS], dim=0
        )
        merged = merged[perm, :].cpu()

        # save the results as image
        fname = os.path.join(self.output_dir_recon, 'reconB_%s.jpg' % iters)
        mkdirs(self.output_dir_recon)
        save_image(
            tensor=merged, filename=fname, nrow=4 * int(np.sqrt(n)),
            pad_value=1
        )
        self.set_mode(train=True)
Example #23
    def train(self):

        self.set_mode(train=True)

        # prepare dataloader (iterable)
        print('Start loading data...')
        dset = DIGIT('./data', train=True)
        self.data_loader = torch.utils.data.DataLoader(dset, batch_size=self.batch_size, shuffle=True)
        test_dset = DIGIT('./data', train=False)
        self.test_data_loader = torch.utils.data.DataLoader(test_dset, batch_size=self.batch_size, shuffle=True)
        print('test: ', len(test_dset))
        self.N = len(self.data_loader.dataset)
        print('...done')

        # iterators from dataloader
        iterator1 = iter(self.data_loader)
        iterator2 = iter(self.data_loader)

        iter_per_epoch = min(len(iterator1), len(iterator2))

        start_iter = self.ckpt_load_iter + 1
        epoch = int(start_iter / iter_per_epoch)

        for iteration in range(start_iter, self.max_iter + 1):

            # reset data iterators for each epoch
            if iteration % iter_per_epoch == 0:
                print('==== epoch %d done ====' % epoch)
                epoch += 1
                iterator1 = iter(self.data_loader)
                iterator2 = iter(self.data_loader)

            # ============================================
            #          TRAIN THE VAE (ENC & DEC)
            # ============================================

            # sample a mini-batch
            XA, XB, index = next(iterator1)  # (n x C x H x W)

            index = index.cpu().detach().numpy()
            if self.use_cuda:
                XA = XA.cuda()
                XB = XB.cuda()

            # zA, zS = encA(xA)
            muA_infA, stdA_infA, logvarA_infA, cate_prob_infA = self.encoderA(XA)

            # zB, zS = encB(xB)
            muB_infB, stdB_infB, logvarB_infB, cate_prob_infB = self.encoderB(XB)

            # read current values

            # zS = encAB(xA,xB) via POE
            cate_prob_POE = torch.exp(
                torch.log(torch.tensor(1 / 10)) + torch.log(cate_prob_infA) + torch.log(cate_prob_infB))

            # latent_dist = {'cont': (muA_infA, logvarA_infA), 'disc': [cate_prob_infA]}
            # (kl_cont_loss, kl_disc_loss, cont_capacity_loss, disc_capacity_loss) = kl_loss_function(self.use_cuda, iteration, latent_dist)

            # kl losses
            #A
            latent_dist_infA = {'cont': (muA_infA, logvarA_infA), 'disc': [cate_prob_infA]}
            (kl_cont_loss_infA, kl_disc_loss_infA, cont_capacity_loss_infA, disc_capacity_loss_infA) = kl_loss_function(
                self.use_cuda, iteration, latent_dist_infA)

            loss_kl_infA = kl_cont_loss_infA + kl_disc_loss_infA
            capacity_loss_infA = cont_capacity_loss_infA + disc_capacity_loss_infA

            #B
            latent_dist_infB = {'cont': (muB_infB, logvarB_infB), 'disc': [cate_prob_infB]}
            (kl_cont_loss_infB, kl_disc_loss_infB, cont_capacity_loss_infB, disc_capacity_loss_infB) = kl_loss_function(
                self.use_cuda, iteration, latent_dist_infB, cont_capacity=[0.0, 5.0, 50000, 100.0] , disc_capacity=[0.0, 10.0, 50000, 100.0])

            loss_kl_infB = kl_cont_loss_infB + kl_disc_loss_infB
            capacity_loss_infB = cont_capacity_loss_infB + disc_capacity_loss_infB


            loss_capa = capacity_loss_infB

            # encoder samples (for training)
            ZA_infA = sample_gaussian(self.use_cuda, muA_infA, stdA_infA)
            ZB_infB = sample_gaussian(self.use_cuda, muB_infB, stdB_infB)
            ZS_POE = sample_gumbel_softmax(self.use_cuda, cate_prob_POE)

            # encoder samples (for cross-modal prediction)
            ZS_infA = sample_gumbel_softmax(self.use_cuda, cate_prob_infA)
            ZS_infB = sample_gumbel_softmax(self.use_cuda, cate_prob_infB)

            # reconstructed samples (given joint modal observation)
            XA_POE_recon = self.decoderA(ZA_infA, ZS_POE)
            XB_POE_recon = self.decoderB(ZB_infB, ZS_POE)

            # reconstructed samples (given single modal observation)
            XA_infA_recon = self.decoderA(ZA_infA, ZS_infA)
            XB_infB_recon = self.decoderB(ZB_infB, ZS_infB)

            # loss_recon_infA = F.l1_loss(torch.sigmoid(XA_infA_recon), XA, reduction='sum').div(XA.size(0))
            loss_recon_infA = reconstruction_loss(XA, torch.sigmoid(XA_infA_recon), distribution="bernoulli")
            #
            loss_recon_infB = reconstruction_loss(XB, torch.sigmoid(XB_infB_recon), distribution="bernoulli")
            #
            loss_recon_POE = \
                F.l1_loss(torch.sigmoid(XA_POE_recon), XA, reduction='sum').div(XA.size(0)) + \
                F.l1_loss(torch.sigmoid(XB_POE_recon), XB, reduction='sum').div(XB.size(0))
            #

            loss_recon = loss_recon_infB

            # total loss for vae
            vae_loss = loss_recon + loss_capa

            # update vae
            self.optim_vae.zero_grad()
            vae_loss.backward()
            self.optim_vae.step()



            # print the losses
            if iteration % self.print_iter == 0:
                prn_str = ( \
                                      '[iter %d (epoch %d)] vae_loss: %.3f ' + \
                                      '(recon: %.3f, capa: %.3f)\n' + \
                                      '    rec_infA = %.3f, rec_infB = %.3f, rec_POE = %.3f\n' + \
                                      '    kl_infA = %.3f, kl_infB = %.3f' + \
                                      '    cont_capacity_loss_infA = %.3f, disc_capacity_loss_infA = %.3f\n' + \
                                      '    cont_capacity_loss_infB = %.3f, disc_capacity_loss_infB = %.3f\n'
                          ) % \
                          (iteration, epoch,
                           vae_loss.item(), loss_recon.item(), loss_capa.item(),
                           loss_recon_infA.item(), loss_recon_infB.item(), loss_recon.item(),
                           loss_kl_infA.item(), loss_kl_infB.item(),
                           cont_capacity_loss_infA.item(), disc_capacity_loss_infA.item(),
                           cont_capacity_loss_infB.item(), disc_capacity_loss_infB.item(),
                           )
                print(prn_str)
                if self.record_file:
                    record = open(self.record_file, 'a')
                    record.write('%s\n' % (prn_str,))
                    record.close()

            # save model parameters
            if iteration % self.ckpt_save_iter == 0:
                self.save_checkpoint(iteration)

            # save output images (recon, synth, etc.)
            if iteration % self.output_save_iter == 0:
                # self.save_embedding(iteration, index, muA_infA, muB_infB, muS_infA, muS_infB, muS_POE)

                # 1) save the recon images
                self.save_recon(iteration)

                # self.save_recon2(iteration, index, XA, XB,
                #     torch.sigmoid(XA_infA_recon).data,
                #     torch.sigmoid(XB_infB_recon).data,
                #     torch.sigmoid(XA_POE_recon).data,
                #     torch.sigmoid(XB_POE_recon).data,
                #     muA_infA, muB_infB, muS_infA, muS_infB, muS_POE,
                #     logalpha, logalphaA, logalphaB
                # )
                z_A, z_B, z_S = self.get_stat()

                #
                #
                #
                # # 2) save the pure-synthesis images
                # # self.save_synth_pure( iteration, howmany=100 )
                # #
                # # 3) save the cross-modal-synthesis images
                # self.save_synth_cross_modal(iteration, z_A, z_B, howmany=3)
                #
                # # 4) save the latent traversed images
                self.save_traverseB(iteration, z_A, z_B, z_S)

                # self.get_loglike(logalpha, logalphaA, logalphaB)

                # # 3) save the latent traversed images
                # if self.dataset.lower() == '3dchairs':
                #     self.save_traverse(iteration, limb=-2, limu=2, inter=0.5)
                # else:
                #     self.save_traverse(iteration, limb=-3, limu=3, inter=0.1)

            if iteration % self.eval_metrics_iter == 0:
                # z_A, z_B come from the output-save branch above, so this assumes
                # eval_metrics_iter is a multiple of output_save_iter
                self.save_synth_cross_modal(iteration, z_A, z_B, train=False, howmany=3)

            # (visdom) insert current line stats
            if self.viz_on and (iteration % self.viz_ll_iter == 0):
                self.line_gather.insert(iter=iteration,
                                        recon_both=loss_recon_POE.item(),
                                        recon_A=loss_recon_infA.item(),
                                        recon_B=loss_recon_infB.item(),
                                        kl_A=loss_kl_infA.item(),
                                        kl_B=loss_kl_infB.item(),
                                        cont_capacity_loss_infA=cont_capacity_loss_infA.item(),
                                        disc_capacity_loss_infA=disc_capacity_loss_infA.item(),
                                        cont_capacity_loss_infB=cont_capacity_loss_infB.item(),
                                        disc_capacity_loss_infB=disc_capacity_loss_infB.item()
                                        )

            # (visdom) visualize line stats (then flush out)
            if self.viz_on and (iteration % self.viz_la_iter == 0):
                self.visualize_line()
                self.line_gather.flush()
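
The training step above pairs a Bernoulli reconstruction term with the capacity-constrained KL term loss_capa. As a rough, hypothetical sketch (the project's reconstruction_loss is defined elsewhere), a Bernoulli reconstruction loss of this kind is usually the per-batch mean of the binary cross-entropy summed over all elements, taken against the sigmoid-activated reconstruction:

import torch.nn.functional as F

def reconstruction_loss_sketch(x, x_recon, distribution="bernoulli"):
    # hypothetical helper, not the project's exact implementation
    batch_size = x.size(0)
    if distribution == "bernoulli":
        # sum binary cross-entropy over all elements, then average over the batch
        return F.binary_cross_entropy(x_recon, x, reduction='sum') / batch_size
    if distribution == "gaussian":
        # a fixed-variance Gaussian likelihood reduces to a summed MSE term
        return F.mse_loss(x_recon, x, reduction='sum') / batch_size
    raise ValueError("unknown distribution: %s" % distribution)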
Example #24
0
    def generate_cf(self, x, y_de, mean_y, args):
        """
        :param x:
        :param mean_y: list, the class-wise feature y
        """
        if mean_y.dim() == 2:
            class_num = mean_y.size(0)
        elif mean_y.dim() == 3:
            class_num = mean_y.size(1)
        bs = x.size(0)

        enc1_1, mu_up1_1, var_up1_1 = self.CONV1_1.encode(x)
        enc1_2, mu_up1_2, var_up1_2 = self.CONV1_2.encode(enc1_1)

        enc2_1, mu_up2_1, var_up2_1 = self.CONV2_1.encode(enc1_2)
        enc2_2, mu_up2_2, var_up2_2 = self.CONV2_2.encode(enc2_1)

        enc3_1, mu_up3_1, var_up3_1 = self.CONV3_1.encode(enc2_2)
        enc3_2, mu_up3_2, var_up3_2 = self.CONV3_2.encode(enc3_1)

        enc4_1, mu_up4_1, var_up4_1 = self.CONV4_1.encode(enc3_2)
        enc4_2, mu_up4_2, var_up4_2 = self.CONV4_2.encode(enc4_1)

        enc5_1, mu_up5_1, var_up5_1 = self.CONV5_1.encode(enc4_2)
        enc5_2, mu_latent, var_latent = self.CONV5_2.encode(enc5_1)

        z_latent_mu, y_latent_mu = mu_latent.split([args.encode_z, 32], dim=1)
        z_latent_var, y_latent_var = var_latent.split([args.encode_z, 32],
                                                      dim=1)

        z_latent_mu = z_latent_mu.unsqueeze(1).repeat(1, class_num, 1)
        if mean_y.dim() == 2:
            y_mu = mean_y.unsqueeze(0).repeat(bs, 1, 1)
        elif mean_y.dim() == 3:
            y_mu = mean_y
        latent_zy = torch.cat([z_latent_mu, y_mu],
                              dim=2).view(bs * class_num, mu_latent.size(1))

        # latent = ut.sample_gaussian(mu_latent, var_latent)

        # partially downwards
        dec5_1, mu_dn5_1, var_dn5_1 = self.TCONV5_2.decode(latent_zy)
        prec_up5_1 = (var_up5_1**(-1)).repeat(class_num, 1)
        prec_dn5_1 = var_dn5_1**(-1)
        qmu5_1 = (mu_up5_1.repeat(class_num, 1) * prec_up5_1 +
                  mu_dn5_1 * prec_dn5_1) / (prec_up5_1 + prec_dn5_1)
        qvar5_1 = (prec_up5_1 + prec_dn5_1)**(-1)
        de_latent5_1 = ut.sample_gaussian(qmu5_1, qvar5_1)

        dec4_2, mu_dn4_2, var_dn4_2 = self.TCONV5_1.decode(de_latent5_1)
        prec_up4_2 = (var_up4_2**(-1)).repeat(class_num, 1)
        prec_dn4_2 = var_dn4_2**(-1)
        qmu4_2 = (mu_up4_2.repeat(class_num, 1) * prec_up4_2 +
                  mu_dn4_2 * prec_dn4_2) / (prec_up4_2 + prec_dn4_2)
        qvar4_2 = (prec_up4_2 + prec_dn4_2)**(-1)
        de_latent4_2 = ut.sample_gaussian(qmu4_2, qvar4_2)

        dec4_1, mu_dn4_1, var_dn4_1 = self.TCONV4_2.decode(de_latent4_2)
        prec_up4_1 = (var_up4_1**(-1)).repeat(class_num, 1)
        prec_dn4_1 = var_dn4_1**(-1)
        qmu4_1 = (mu_up4_1.repeat(class_num, 1) * prec_up4_1 +
                  mu_dn4_1 * prec_dn4_1) / (prec_up4_1 + prec_dn4_1)
        qvar4_1 = (prec_up4_1 + prec_dn4_1)**(-1)
        de_latent4_1 = ut.sample_gaussian(qmu4_1, qvar4_1)

        dec3_2, mu_dn3_2, var_dn3_2 = self.TCONV4_1.decode(de_latent4_1)
        prec_up3_2 = (var_up3_2**(-1)).repeat(class_num, 1)
        prec_dn3_2 = var_dn3_2**(-1)
        qmu3_2 = (mu_up3_2.repeat(class_num, 1) * prec_up3_2 +
                  mu_dn3_2 * prec_dn3_2) / (prec_up3_2 + prec_dn3_2)
        qvar3_2 = (prec_up3_2 + prec_dn3_2)**(-1)
        de_latent3_2 = ut.sample_gaussian(qmu3_2, qvar3_2)

        dec3_1, mu_dn3_1, var_dn3_1 = self.TCONV3_2.decode(de_latent3_2)
        prec_up3_1 = (var_up3_1**(-1)).repeat(class_num, 1)
        prec_dn3_1 = var_dn3_1**(-1)
        qmu3_1 = (mu_up3_1.repeat(class_num, 1) * prec_up3_1 +
                  mu_dn3_1 * prec_dn3_1) / (prec_up3_1 + prec_dn3_1)
        qvar3_1 = (prec_up3_1 + prec_dn3_1)**(-1)
        de_latent3_1 = ut.sample_gaussian(qmu3_1, qvar3_1)

        dec2_2, mu_dn2_2, var_dn2_2 = self.TCONV3_1.decode(de_latent3_1)
        prec_up2_2 = (var_up2_2**(-1)).repeat(class_num, 1)
        prec_dn2_2 = var_dn2_2**(-1)
        qmu2_2 = (mu_up2_2.repeat(class_num, 1) * prec_up2_2 +
                  mu_dn2_2 * prec_dn2_2) / (prec_up2_2 + prec_dn2_2)
        qvar2_2 = (prec_up2_2 + prec_dn2_2)**(-1)
        de_latent2_2 = ut.sample_gaussian(qmu2_2, qvar2_2)

        dec2_1, mu_dn2_1, var_dn2_1 = self.TCONV2_2.decode(de_latent2_2)
        prec_up2_1 = (var_up2_1**(-1)).repeat(class_num, 1)
        prec_dn2_1 = var_dn2_1**(-1)
        qmu2_1 = (mu_up2_1.repeat(class_num, 1) * prec_up2_1 +
                  mu_dn2_1 * prec_dn2_1) / (prec_up2_1 + prec_dn2_1)
        qvar2_1 = (prec_up2_1 + prec_dn2_1)**(-1)
        de_latent2_1 = ut.sample_gaussian(qmu2_1, qvar2_1)

        dec1_2, mu_dn1_2, var_dn1_2 = self.TCONV2_1.decode(de_latent2_1)
        prec_up1_2 = (var_up1_2**(-1)).repeat(class_num, 1)
        prec_dn1_2 = var_dn1_2**(-1)
        qmu1_2 = (mu_up1_2.repeat(class_num, 1) * prec_up1_2 +
                  mu_dn1_2 * prec_dn1_2) / (prec_up1_2 + prec_dn1_2)
        qvar1_2 = (prec_up1_2 + prec_dn1_2)**(-1)
        de_latent1_2 = ut.sample_gaussian(qmu1_2, qvar1_2)

        dec1_1, mu_dn1_1, var_dn1_1 = self.TCONV1_2.decode(de_latent1_2)
        prec_up1_1 = (var_up1_1**(-1)).repeat(class_num, 1)
        prec_dn1_1 = var_dn1_1**(-1)
        qmu1_1 = (mu_up1_1.repeat(class_num, 1) * prec_up1_1 +
                  mu_dn1_1 * prec_dn1_1) / (prec_up1_1 + prec_dn1_1)
        qvar1_1 = (prec_up1_1 + prec_dn1_1)**(-1)
        de_latent1_1 = ut.sample_gaussian(qmu1_1, qvar1_1)

        x_re = self.TCONV1_1.final_decode(de_latent1_1)

        return x_re.view(bs, class_num, *x.size()[1:])
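
Each ladder level of generate_cf above repeats the same precision-weighted fusion of the bottom-up Gaussian (mu_up, var_up) and the top-down Gaussian (mu_dn, var_dn), i.e. a product of diagonal Gaussians, before drawing a sample. Assuming ut.sample_gaussian(mu, var) is the usual reparameterized draw, the repetition could be collapsed into a helper like this (a refactoring sketch, not part of the original code):

def fuse_and_sample(mu_up, var_up, mu_dn, var_dn, class_num):
    # product of two diagonal Gaussians: precision-weighted mean, summed precisions
    prec_up = (var_up ** -1).repeat(class_num, 1)
    prec_dn = var_dn ** -1
    q_mu = (mu_up.repeat(class_num, 1) * prec_up + mu_dn * prec_dn) / (prec_up + prec_dn)
    q_var = (prec_up + prec_dn) ** -1
    return ut.sample_gaussian(q_mu, q_var)

# for example: de_latent5_1 = fuse_and_sample(mu_up5_1, var_up5_1, mu_dn5_1, var_dn5_1, class_num)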
Example #25
0
    def __init__(self,
            num_symbols,
            num_embed_units,
            num_units,
            is_train,
            vocab=None,
            content_pos=None,
            rhetoric_pos=None,
            embed=None,
            learning_rate=0.1,
            learning_rate_decay_factor=0.9995,
            max_gradient_norm=5.0,
            max_length=30,
            latent_size=128,
            use_lstm=False,
            num_classes=3,
            full_kl_step=80000,
            mem_slot_num=4,
            mem_size=128):
        
        self.ori_sents = tf.placeholder(tf.string, shape=(None, None))
        self.ori_sents_length = tf.placeholder(tf.int32, shape=(None,))
        self.rep_sents = tf.placeholder(tf.string, shape=(None, None))
        self.rep_sents_length = tf.placeholder(tf.int32, shape=(None,))
        self.labels = tf.placeholder(tf.float32, shape=(None, num_classes))
        self.use_prior = tf.placeholder(tf.bool)
        self.global_t = tf.placeholder(tf.int32)
        self.content_mask = tf.reduce_sum(tf.one_hot(content_pos, num_symbols, 1.0, 0.0), axis = 0)
        self.rhetoric_mask = tf.reduce_sum(tf.one_hot(rhetoric_pos, num_symbols, 1.0, 0.0), axis = 0)

        # tf.zeros needs a fully defined shape, so take the batch dimension
        # from the input placeholder instead of None
        topic_memory = tf.zeros(
            tf.stack([tf.shape(self.ori_sents)[0], mem_slot_num, mem_size]),
            dtype=tf.float32, name="topic_memory")

        w_topic_memory = tf.get_variable(name="w_topic_memory", dtype=tf.float32,
                                    initializer=tf.random_uniform([mem_size, mem_size], -0.1, 0.1))

        # build the vocab table (string to index)
        if is_train:
            self.symbols = tf.Variable(vocab, trainable=False, name="symbols")
        else:
            self.symbols = tf.Variable(np.array(['.']*num_symbols), name="symbols")
        self.symbol2index = HashTable(KeyValueTensorInitializer(self.symbols, 
            tf.Variable(np.array([i for i in range(num_symbols)], dtype=np.int32), False)), 
            default_value=UNK_ID, name="symbol2index")

        self.ori_sents_input = self.symbol2index.lookup(self.ori_sents)
        self.rep_sents_target = self.symbol2index.lookup(self.rep_sents)
        batch_size, decoder_len = tf.shape(self.rep_sents)[0], tf.shape(self.rep_sents)[1]
        self.rep_sents_input = tf.concat([tf.ones([batch_size, 1], dtype=tf.int32)*GO_ID,
            tf.split(self.rep_sents_target, [decoder_len-1, 1], 1)[0]], 1)
        self.decoder_mask = tf.reshape(tf.cumsum(tf.one_hot(self.rep_sents_length-1,
            decoder_len), reverse=True, axis=1), [-1, decoder_len])        
        
        # build the embedding table (index to vector)
        if embed is None:
            # initialize the embedding randomly
            self.embed = tf.get_variable('embed', [num_symbols, num_embed_units], tf.float32)
        else:
            # initialize the embedding by pre-trained word vectors
            self.embed = tf.get_variable('embed', dtype=tf.float32, initializer=embed)

        self.pattern_embed = tf.get_variable('pattern_embed', [num_classes, num_embed_units], tf.float32)
        
        self.encoder_input = tf.nn.embedding_lookup(self.embed, self.ori_sents_input)
        self.decoder_input = tf.nn.embedding_lookup(self.embed, self.rep_sents_input)

        if use_lstm:
            cell_fw = LSTMCell(num_units)
            cell_bw = LSTMCell(num_units)
            cell_dec = LSTMCell(2*num_units)
        else:
            cell_fw = GRUCell(num_units)
            cell_bw = GRUCell(num_units)
            cell_dec = GRUCell(2*num_units)

        # origin sentence encoder
        with variable_scope.variable_scope("encoder"):
            encoder_output, encoder_state = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, self.encoder_input, 
                self.ori_sents_length, dtype=tf.float32)
            post_sum_state = tf.concat(encoder_state, 1)
            encoder_output = tf.concat(encoder_output, 2)

        # response sentence encoder
        with variable_scope.variable_scope("encoder", reuse = True):
            decoder_state, decoder_last_state = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, self.decoder_input, 
                self.rep_sents_length, dtype=tf.float32)
            response_sum_state = tf.concat(decoder_last_state, 1)

        # recognition network
        with variable_scope.variable_scope("recog_net"):
            recog_input = tf.concat([post_sum_state, response_sum_state], 1)
            recog_mulogvar = tf.contrib.layers.fully_connected(recog_input, latent_size * 2, activation_fn=None, scope="muvar")
            recog_mu, recog_logvar = tf.split(recog_mulogvar, 2, axis=1)

        # prior network
        with variable_scope.variable_scope("prior_net"):
            prior_fc1 = tf.contrib.layers.fully_connected(post_sum_state, latent_size * 2, activation_fn=tf.tanh, scope="fc1")
            prior_mulogvar = tf.contrib.layers.fully_connected(prior_fc1, latent_size * 2, activation_fn=None, scope="muvar")
            prior_mu, prior_logvar = tf.split(prior_mulogvar, 2, axis=1)

        latent_sample = tf.cond(self.use_prior,
                                lambda: sample_gaussian(prior_mu, prior_logvar),
                                lambda: sample_gaussian(recog_mu, recog_logvar))


        # classifier
        with variable_scope.variable_scope("classifier"):
            classifier_input = latent_sample
            pattern_fc1 = tf.contrib.layers.fully_connected(classifier_input, latent_size, activation_fn=tf.tanh, scope="pattern_fc1")
            self.pattern_logits = tf.contrib.layers.fully_connected(pattern_fc1, num_classes, activation_fn=None, scope="pattern_logits")

        self.label_embedding = tf.matmul(self.labels, self.pattern_embed)

        output_fn, my_sequence_loss = output_projection_layer(2*num_units, num_symbols, latent_size, num_embed_units, self.content_mask, self.rhetoric_mask)

        attention_keys, attention_values, attention_score_fn, attention_construct_fn = my_attention_decoder_fn.prepare_attention(encoder_output, 'luong', 2*num_units)

        with variable_scope.variable_scope("dec_start"):
            temp_start = tf.concat([post_sum_state, self.label_embedding, latent_sample], 1)
            dec_fc1 = tf.contrib.layers.fully_connected(temp_start, 2*num_units, activation_fn=tf.tanh, scope="dec_start_fc1")
            dec_fc2 = tf.contrib.layers.fully_connected(dec_fc1, 2*num_units, activation_fn=None, scope="dec_start_fc2")

        if is_train:
            # rnn decoder
            topic_memory = self.update_memory(topic_memory, encoder_output)
            extra_info = tf.concat([self.label_embedding, latent_sample, topic_memory], 1)

            decoder_fn_train = my_attention_decoder_fn.attention_decoder_fn_train(dec_fc2, 
                attention_keys, attention_values, attention_score_fn, attention_construct_fn, extra_info)
            self.decoder_output, _, _ = my_seq2seq.dynamic_rnn_decoder(cell_dec, decoder_fn_train, 
                self.decoder_input, self.rep_sents_length, scope = "decoder")

            # calculate the loss
            self.decoder_loss = my_loss.sequence_loss(logits = self.decoder_output, 
                targets = self.rep_sents_target, weights = self.decoder_mask,
                extra_information = latent_sample, label_embedding = self.label_embedding, softmax_loss_function = my_sequence_loss)
            temp_klloss = tf.reduce_mean(gaussian_kld(recog_mu, recog_logvar, prior_mu, prior_logvar))
            self.kl_weight = tf.minimum(tf.to_float(self.global_t)/full_kl_step, 1.0)
            self.klloss = self.kl_weight * temp_klloss
            temp_labels = tf.argmax(self.labels, 1)
            self.classifierloss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.pattern_logits, labels=temp_labels))
            self.loss = self.decoder_loss + self.klloss + self.classifierloss  # need to anneal the kl_weight
            
            # building graph finished and get all parameters
            self.params = tf.trainable_variables()
        
            # initialize the training process
            self.learning_rate = tf.Variable(float(learning_rate), trainable=False, dtype=tf.float32)
            self.learning_rate_decay_op = self.learning_rate.assign(self.learning_rate * learning_rate_decay_factor)
            self.global_step = tf.Variable(0, trainable=False)
            
            # calculate the gradient of parameters
            opt = tf.train.MomentumOptimizer(self.learning_rate, 0.9)
            gradients = tf.gradients(self.loss, self.params)
            clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(gradients, 
                    max_gradient_norm)
            self.update = opt.apply_gradients(zip(clipped_gradients, self.params), 
                    global_step=self.global_step)

        else:
            # rnn decoder
            topic_memory = self.update_memory(topic_memory, encoder_output)
            extra_info = tf.concat([self.label_embedding, latent_sample, topic_memory], 1)
            decoder_fn_inference = my_attention_decoder_fn.attention_decoder_fn_inference(output_fn, 
                dec_fc2, attention_keys, attention_values, attention_score_fn, 
                attention_construct_fn, self.embed, GO_ID, EOS_ID, max_length, num_symbols, extra_info)
            self.decoder_distribution, _, _ = my_seq2seq.dynamic_rnn_decoder(cell_dec, decoder_fn_inference, scope="decoder")
            self.generation_index = tf.argmax(tf.split(self.decoder_distribution,
                [2, num_symbols-2], 2)[1], 2) + 2 # for removing UNK
            self.generation = tf.nn.embedding_lookup(self.symbols, self.generation_index)
            
            self.params = tf.trainable_variables()

        self.saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V2, 
                max_to_keep=3, pad_step_number=True, keep_checkpoint_every_n_hours=1.0)
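
The model above draws latent_sample through tf.cond, switching between the prior and the recognition Gaussian via sample_gaussian(mu, logvar). A minimal sketch of such a helper, assuming the (mean, log-variance) parameterization implied by the recog_mu / recog_logvar split, is the standard reparameterization trick:

def sample_gaussian(mu, logvar):
    # z = mu + sigma * eps with sigma = exp(logvar / 2) and eps ~ N(0, I)
    epsilon = tf.random_normal(tf.shape(logvar), name="epsilon")
    return mu + tf.exp(0.5 * logvar) * epsilon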
Example #26
0
    def __init__(self, sess, config, api, log_dir, forward, scope=None):
        self.vocab = api.vocab
        self.rev_vocab = api.rev_vocab
        self.vocab_size = len(self.vocab)
        self.topic_vocab = api.topic_vocab
        self.topic_vocab_size = len(self.topic_vocab)

        self.sess = sess
        self.scope = scope
        self.max_utt_len = config.max_utt_len
        self.go_id = self.rev_vocab["<s>"]
        self.eos_id = self.rev_vocab["</s>"]
        self.context_cell_size = config.cxt_cell_size
        self.sent_cell_size = config.sent_cell_size
        self.dec_cell_size = config.dec_cell_size
        self.bow_weights = config.bow_weights

        with tf.name_scope("io"):
            # all dialog context and known attributes
            self.input_contexts = tf.placeholder(dtype=tf.int32,
                                                 shape=(None, None,
                                                        self.max_utt_len),
                                                 name="context")
            self.context_lens = tf.placeholder(dtype=tf.int32,
                                               shape=(None, ),
                                               name="context_lens")

            # target response given the dialog context
            self.output_tokens = tf.placeholder(dtype=tf.int32,
                                                shape=(None, None),
                                                name="output_token")
            self.output_lens = tf.placeholder(dtype=tf.int32,
                                              shape=(None, ),
                                              name="output_lens")
            self.output_topics = tf.placeholder(dtype=tf.int32,
                                                shape=(None, ),
                                                name="output_topic")

            # optimization related variables
            self.learning_rate = tf.Variable(float(config.init_lr),
                                             trainable=False,
                                             name="learning_rate")
            self.learning_rate_decay_op = self.learning_rate.assign(
                tf.multiply(self.learning_rate, config.lr_decay))
            self.global_t = tf.placeholder(dtype=tf.int32, name="global_t")
            self.use_prior = tf.placeholder(dtype=tf.bool, name="use_prior")

        max_context_len = array_ops.shape(self.input_contexts)[1]
        max_out_len = array_ops.shape(self.output_tokens)[1]
        batch_size = array_ops.shape(self.input_contexts)[0]

        if config.use_hcf:
            with variable_scope.variable_scope("topicEmbedding"):
                t_embedding = tf.get_variable(
                    "embedding",
                    [self.topic_vocab_size, config.topic_embed_size],
                    dtype=tf.float32)
                topic_embedding = embedding_ops.embedding_lookup(
                    t_embedding, self.output_topics)

        with variable_scope.variable_scope("wordEmbedding"):
            self.embedding = tf.get_variable(
                "embedding", [self.vocab_size, config.embed_size],
                dtype=tf.float32)
            embedding_mask = tf.constant(
                [0 if i == 0 else 1 for i in range(self.vocab_size)],
                dtype=tf.float32,
                shape=[self.vocab_size, 1])
            embedding = self.embedding * embedding_mask

            input_embedding = embedding_ops.embedding_lookup(
                embedding, tf.reshape(self.input_contexts, [-1]))
            input_embedding = tf.reshape(
                input_embedding, [-1, self.max_utt_len, config.embed_size])
            output_embedding = embedding_ops.embedding_lookup(
                embedding, self.output_tokens)

            # context nn
            if config.sent_type == "bow":
                input_embedding, sent_size = get_bow(input_embedding)
                output_embedding, _ = get_bow(output_embedding)

            elif config.sent_type == "rnn":
                sent_cell = self.get_rnncell("gru", self.sent_cell_size,
                                             config.keep_prob, 1)
                input_embedding, sent_size = get_rnn_encode(input_embedding,
                                                            sent_cell,
                                                            scope="sent_rnn")
                output_embedding, _ = get_rnn_encode(output_embedding,
                                                     sent_cell,
                                                     self.output_lens,
                                                     scope="sent_rnn",
                                                     reuse=True)
            elif config.sent_type == "bi_rnn":
                fwd_sent_cell = self.get_rnncell("gru",
                                                 self.sent_cell_size,
                                                 keep_prob=1.0,
                                                 num_layer=1)
                bwd_sent_cell = self.get_rnncell("gru",
                                                 self.sent_cell_size,
                                                 keep_prob=1.0,
                                                 num_layer=1)
                input_embedding, sent_size = get_bi_rnn_encode(
                    input_embedding,
                    fwd_sent_cell,
                    bwd_sent_cell,
                    scope="sent_bi_rnn")
                output_embedding, _ = get_bi_rnn_encode(output_embedding,
                                                        fwd_sent_cell,
                                                        bwd_sent_cell,
                                                        self.output_lens,
                                                        scope="sent_bi_rnn",
                                                        reuse=True)
            else:
                raise ValueError(
                    "Unknown sent_type. Must be one of [bow, rnn, bi_rnn]")

            # reshape input into dialogs
            input_embedding = tf.reshape(input_embedding,
                                         [-1, max_context_len, sent_size])
            if config.keep_prob < 1.0:
                input_embedding = tf.nn.dropout(input_embedding,
                                                config.keep_prob)

        with variable_scope.variable_scope("contextRNN"):
            enc_cell = self.get_rnncell(config.cell_type,
                                        self.context_cell_size,
                                        keep_prob=1.0,
                                        num_layer=config.num_layer)
            # sequence_length masks the padding, so enc_last_state is the true last state
            _, enc_last_state = tf.nn.dynamic_rnn(
                enc_cell,
                input_embedding,
                dtype=tf.float32,
                sequence_length=self.context_lens)

            if config.num_layer > 1:
                enc_last_state = tf.concat(enc_last_state, 1)

        # combine with other attributes
        if config.use_hcf:
            attribute_embedding = topic_embedding
            attribute_fc1 = layers.fully_connected(attribute_embedding,
                                                   30,
                                                   activation_fn=tf.tanh,
                                                   scope="attribute_fc1")

        cond_embedding = enc_last_state

        with variable_scope.variable_scope("recognitionNetwork"):
            if config.use_hcf:
                recog_input = tf.concat(
                    [cond_embedding, output_embedding, attribute_fc1], 1)
            else:
                recog_input = tf.concat([cond_embedding, output_embedding], 1)
            self.recog_mulogvar = recog_mulogvar = layers.fully_connected(
                recog_input,
                config.latent_size * 2,
                activation_fn=None,
                scope="muvar")
            recog_mu, recog_logvar = tf.split(recog_mulogvar, 2, axis=1)

        with variable_scope.variable_scope("priorNetwork"):
            prior_fc1 = layers.fully_connected(cond_embedding,
                                               np.maximum(
                                                   config.latent_size * 2,
                                                   100),
                                               activation_fn=tf.tanh,
                                               scope="fc1")
            prior_mulogvar = layers.fully_connected(prior_fc1,
                                                    config.latent_size * 2,
                                                    activation_fn=None,
                                                    scope="muvar")
            prior_mu, prior_logvar = tf.split(prior_mulogvar, 2, axis=1)

            # use sampled Z or posterior Z
            latent_sample = tf.cond(
                self.use_prior,
                lambda: sample_gaussian(prior_mu, prior_logvar),
                lambda: sample_gaussian(recog_mu, recog_logvar))

        with variable_scope.variable_scope("generationNetwork"):
            gen_inputs = tf.concat([cond_embedding, latent_sample], 1)

            # BOW loss
            bow_fc1 = layers.fully_connected(gen_inputs,
                                             400,
                                             activation_fn=tf.tanh,
                                             scope="bow_fc1")
            if config.keep_prob < 1.0:
                bow_fc1 = tf.nn.dropout(bow_fc1, config.keep_prob)
            self.bow_logits = layers.fully_connected(bow_fc1,
                                                     self.vocab_size,
                                                     activation_fn=None,
                                                     scope="bow_project")

            # Y loss
            if config.use_hcf:
                meta_fc1 = layers.fully_connected(latent_sample,
                                                  400,
                                                  activation_fn=tf.tanh,
                                                  scope="meta_fc1")
                if config.keep_prob < 1.0:
                    meta_fc1 = tf.nn.dropout(meta_fc1, config.keep_prob)
                self.topic_logits = layers.fully_connected(
                    meta_fc1, self.topic_vocab_size, scope="topic_project")
                topic_prob = tf.nn.softmax(self.topic_logits)
                #pred_attribute_embedding = tf.matmul(topic_prob, t_embedding)
                pred_topic = tf.argmax(topic_prob, 1)
                pred_attribute_embedding = embedding_ops.embedding_lookup(
                    t_embedding, pred_topic)
                if forward:
                    selected_attribute_embedding = pred_attribute_embedding
                else:
                    selected_attribute_embedding = attribute_embedding
                dec_inputs = tf.concat(
                    [gen_inputs, selected_attribute_embedding], 1)
            else:
                self.topic_logits = tf.zeros(
                    (batch_size, self.topic_vocab_size))
                selected_attribute_embedding = None
                dec_inputs = gen_inputs

            # Decoder
            if config.num_layer > 1:
                dec_init_state = [
                    layers.fully_connected(dec_inputs,
                                           self.dec_cell_size,
                                           activation_fn=None,
                                           scope="init_state-%d" % i)
                    for i in range(config.num_layer)
                ]
                dec_init_state = tuple(dec_init_state)
            else:
                dec_init_state = layers.fully_connected(dec_inputs,
                                                        self.dec_cell_size,
                                                        activation_fn=None,
                                                        scope="init_state")

        with variable_scope.variable_scope("decoder"):
            dec_cell = self.get_rnncell(config.cell_type, self.dec_cell_size,
                                        config.keep_prob, config.num_layer)
            dec_cell = OutputProjectionWrapper(dec_cell, self.vocab_size)

            if forward:
                loop_func = decoder_fn_lib.context_decoder_fn_inference(
                    None,
                    dec_init_state,
                    embedding,
                    start_of_sequence_id=self.go_id,
                    end_of_sequence_id=self.eos_id,
                    maximum_length=self.max_utt_len,
                    num_decoder_symbols=self.vocab_size,
                    context_vector=selected_attribute_embedding)
                dec_input_embedding = None
                dec_seq_lens = None
            else:
                loop_func = decoder_fn_lib.context_decoder_fn_train(
                    dec_init_state, selected_attribute_embedding)
                dec_input_embedding = embedding_ops.embedding_lookup(
                    embedding, self.output_tokens)
                dec_input_embedding = dec_input_embedding[:, 0:-1, :]
                dec_seq_lens = self.output_lens - 1

                if config.keep_prob < 1.0:
                    dec_input_embedding = tf.nn.dropout(
                        dec_input_embedding, config.keep_prob)

                # apply word dropping. Set dropped word to 0
                if config.dec_keep_prob < 1.0:
                    keep_mask = tf.less_equal(
                        tf.random_uniform((batch_size, max_out_len - 1),
                                          minval=0.0,
                                          maxval=1.0), config.dec_keep_prob)
                    keep_mask = tf.expand_dims(tf.to_float(keep_mask), 2)
                    dec_input_embedding = dec_input_embedding * keep_mask
                    dec_input_embedding = tf.reshape(
                        dec_input_embedding,
                        [-1, max_out_len - 1, config.embed_size])

            dec_outs, _, final_context_state = dynamic_rnn_decoder(
                dec_cell,
                loop_func,
                inputs=dec_input_embedding,
                sequence_length=dec_seq_lens)
            if final_context_state is not None:
                final_context_state = final_context_state[:, 0:array_ops.
                                                          shape(dec_outs)[1]]
                mask = tf.to_int32(tf.sign(tf.reduce_max(dec_outs, axis=2)))
                self.dec_out_words = tf.multiply(
                    tf.reverse(final_context_state, axis=[1]), mask)
            else:
                self.dec_out_words = tf.argmax(dec_outs, 2)

        if not forward:
            with variable_scope.variable_scope("loss"):
                labels = self.output_tokens[:, 1:]
                label_mask = tf.to_float(tf.sign(labels))

                rc_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=dec_outs, labels=labels)
                rc_loss = tf.reduce_sum(rc_loss * label_mask,
                                        reduction_indices=1)
                self.avg_rc_loss = tf.reduce_mean(rc_loss)
                # used only for perplexity calculation; not used for optimization
                self.rc_ppl = tf.exp(
                    tf.reduce_sum(rc_loss) / tf.reduce_sum(label_mask))
                """ as n-trial multimodal distribution. """
                tile_bow_logits = tf.tile(tf.expand_dims(self.bow_logits, 1),
                                          [1, max_out_len - 1, 1])
                bow_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=tile_bow_logits, labels=labels) * label_mask
                bow_loss = tf.reduce_sum(bow_loss, reduction_indices=1)
                self.avg_bow_loss = tf.reduce_mean(bow_loss)
                bow_weights = tf.to_float(self.bow_weights)

                # reconstruct the meta info about X
                if config.use_hcf:
                    topic_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                        logits=self.topic_logits, labels=self.output_topics)
                    self.avg_topic_loss = tf.reduce_mean(topic_loss)
                else:
                    self.avg_topic_loss = 0.0

                kld = gaussian_kld(recog_mu, recog_logvar, prior_mu,
                                   prior_logvar)
                self.avg_kld = tf.reduce_mean(kld)
                if log_dir is not None:
                    kl_weights = tf.minimum(
                        tf.to_float(self.global_t) / config.full_kl_step, 1.0)
                else:
                    kl_weights = tf.constant(1.0)

                self.kl_w = kl_weights
                self.elbo = self.avg_rc_loss + kl_weights * self.avg_kld
                aug_elbo = bow_weights * self.avg_bow_loss + self.avg_topic_loss + self.elbo

                tf.summary.scalar("topic_loss", self.avg_topic_loss)
                tf.summary.scalar("rc_loss", self.avg_rc_loss)
                tf.summary.scalar("elbo", self.elbo)
                tf.summary.scalar("kld", self.avg_kld)
                tf.summary.scalar("bow_loss", self.avg_bow_loss)

                self.summary_op = tf.summary.merge_all()

                self.log_p_z = norm_log_liklihood(latent_sample, prior_mu,
                                                  prior_logvar)
                self.log_q_z_xy = norm_log_liklihood(latent_sample, recog_mu,
                                                     recog_logvar)
                self.est_marginal = tf.reduce_mean(rc_loss + bow_loss -
                                                   self.log_p_z +
                                                   self.log_q_z_xy)

            self.optimize(sess, config, aug_elbo, log_dir)

        self.saver = tf.train.Saver(tf.global_variables(),
                                    write_version=tf.train.SaverDef.V2)
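
In the loss above, gaussian_kld measures how far the recognition posterior drifts from the prior network's Gaussian, and the result is annealed into the ELBO by kl_weights. Assuming the same (mean, log-variance) parameterization as the sampling code, a sketch of the per-example KL divergence between the two diagonal Gaussians looks like this:

def gaussian_kld(recog_mu, recog_logvar, prior_mu, prior_logvar):
    # KL( N(recog_mu, exp(recog_logvar)) || N(prior_mu, exp(prior_logvar)) ),
    # summed over latent dimensions
    kld = -0.5 * tf.reduce_sum(
        1 + (recog_logvar - prior_logvar)
        - tf.div(tf.pow(prior_mu - recog_mu, 2), tf.exp(prior_logvar))
        - tf.div(tf.exp(recog_logvar), tf.exp(prior_logvar)),
        reduction_indices=1)
    return kld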
Example #27
0
    def __init__(self,
                 sess,
                 config,
                 api,
                 log_dir,
                 forward,
                 scope=None,
                 name=None):
        self.vocab = api.vocab
        self.rev_vocab = api.rev_vocab
        self.vocab_size = len(self.vocab)
        self.idf = api.index2idf
        self.gen_vocab_size = api.gen_vocab_size
        self.topic_vocab = api.topic_vocab
        self.topic_vocab_size = len(self.topic_vocab)
        self.da_vocab = api.dialog_act_vocab
        self.da_vocab_size = len(self.da_vocab)
        self.sess = sess
        self.scope = scope
        self.max_utt_len = config.max_utt_len
        self.max_per_len = config.max_per_len
        self.max_per_line = config.max_per_line
        self.max_per_words = config.max_per_words
        self.go_id = self.rev_vocab["<s>"]
        self.eos_id = self.rev_vocab["</s>"]
        self.context_cell_size = config.cxt_cell_size
        self.sent_cell_size = config.sent_cell_size
        self.memory_cell_size = config.memory_cell_size
        self.dec_cell_size = config.dec_cell_size
        self.hops = config.hops
        self.batch_size = config.batch_size
        self.test_samples = config.test_samples
        self.balance_factor = config.balance_factor

        with tf.name_scope("io"):
            self.first_dimension_size = self.batch_size
            self.input_contexts = tf.placeholder(
                dtype=tf.int32,
                shape=(self.first_dimension_size, None, self.max_utt_len),
                name="dialog_context")
            self.floors = tf.placeholder(dtype=tf.int32,
                                         shape=(self.first_dimension_size,
                                                None),
                                         name="floor")
            self.context_lens = tf.placeholder(
                dtype=tf.int32,
                shape=(self.first_dimension_size, ),
                name="context_lens")
            self.topics = tf.placeholder(dtype=tf.int32,
                                         shape=(self.first_dimension_size, ),
                                         name="topics")
            self.personas = tf.placeholder(dtype=tf.int32,
                                           shape=(self.first_dimension_size,
                                                  self.max_per_line,
                                                  self.max_per_len),
                                           name="personas")
            self.persona_words = tf.placeholder(
                dtype=tf.int32,
                shape=(self.first_dimension_size, self.max_per_line,
                       self.max_per_len),
                name="persona_words")
            self.persona_position = tf.placeholder(
                dtype=tf.int32,
                shape=(self.first_dimension_size, None),
                name="persona_position")
            self.selected_persona = tf.placeholder(
                dtype=tf.int32,
                shape=(self.first_dimension_size, 1),
                name="selected_persona")

            self.query = tf.placeholder(dtype=tf.int32,
                                        shape=(self.first_dimension_size,
                                               self.max_utt_len),
                                        name="query")

            # target response given the dialog context
            self.output_tokens = tf.placeholder(
                dtype=tf.int32,
                shape=(self.first_dimension_size, None),
                name="output_token")
            self.output_lens = tf.placeholder(
                dtype=tf.int32,
                shape=(self.first_dimension_size, ),
                name="output_lens")

            # optimization related variables
            self.learning_rate = tf.Variable(float(config.init_lr),
                                             trainable=False,
                                             name="learning_rate")
            self.learning_rate_decay_op = self.learning_rate.assign(
                tf.multiply(self.learning_rate, config.lr_decay))
            self.global_t = tf.placeholder(dtype=tf.int32, name="global_t")
            self.use_prior = tf.placeholder(dtype=tf.bool, name="use_prior")

        max_context_lines = array_ops.shape(self.input_contexts)[1]
        max_out_len = array_ops.shape(self.output_tokens)[1]
        batch_size = array_ops.shape(self.input_contexts)[0]

        with variable_scope.variable_scope("wordEmbedding"):
            self.embedding = tf.get_variable(
                "embedding", [self.vocab_size, config.embed_size],
                dtype=tf.float32)
            embedding_mask = tf.constant(
                [0 if i == 0 else 1 for i in range(self.vocab_size)],
                dtype=tf.float32,
                shape=[self.vocab_size, 1])
            embedding = self.embedding * embedding_mask
            input_embedding = embedding_ops.embedding_lookup(
                embedding, tf.reshape(self.input_contexts, [-1]))
            input_embedding = tf.reshape(
                input_embedding, [-1, self.max_utt_len, config.embed_size])
            output_embedding = embedding_ops.embedding_lookup(
                embedding, self.output_tokens)
            persona_input_embedding = embedding_ops.embedding_lookup(
                embedding, tf.reshape(self.personas, [-1]))
            persona_input_embedding = tf.reshape(
                persona_input_embedding,
                [-1, self.max_per_len, config.embed_size])
            if config.sent_type == "bow":
                input_embedding, sent_size = get_bow(input_embedding)
                output_embedding, _ = get_bow(output_embedding)
                persona_input_embedding, _ = get_bow(persona_input_embedding)

            elif config.sent_type == "rnn":
                sent_cell = self.get_rnncell("gru", self.sent_cell_size,
                                             config.keep_prob, 1)
                _, input_embedding, sent_size = get_rnn_encode(
                    input_embedding, sent_cell, scope="sent_rnn")
                _, output_embedding, _ = get_rnn_encode(output_embedding,
                                                        sent_cell,
                                                        self.output_lens,
                                                        scope="sent_rnn",
                                                        reuse=True)
                _, persona_input_embedding, _ = get_rnn_encode(
                    persona_input_embedding,
                    sent_cell,
                    scope="sent_rnn",
                    reuse=True)
            elif config.sent_type == "bi_rnn":
                fwd_sent_cell = self.get_rnncell("gru",
                                                 self.sent_cell_size,
                                                 keep_prob=1.0,
                                                 num_layer=1)
                bwd_sent_cell = self.get_rnncell("gru",
                                                 self.sent_cell_size,
                                                 keep_prob=1.0,
                                                 num_layer=1)
                input_step_embedding, input_embedding, sent_size = get_bi_rnn_encode(
                    input_embedding,
                    fwd_sent_cell,
                    bwd_sent_cell,
                    scope="sent_bi_rnn")
                _, output_embedding, _ = get_bi_rnn_encode(output_embedding,
                                                           fwd_sent_cell,
                                                           bwd_sent_cell,
                                                           self.output_lens,
                                                           scope="sent_bi_rnn",
                                                           reuse=True)
                _, persona_input_embedding, _ = get_bi_rnn_encode(
                    persona_input_embedding,
                    fwd_sent_cell,
                    bwd_sent_cell,
                    scope="sent_bi_rnn",
                    reuse=True)
            else:
                raise ValueError(
                    "Unknown sent_type. Must be one of [bow, rnn, bi_rnn]")
            # reshape input into dialogs
            input_embedding = tf.reshape(input_embedding,
                                         [-1, max_context_lines, sent_size])
            self.input_step_embedding = input_step_embedding
            self.encoder_state_size = sent_size
            if config.keep_prob < 1.0:
                input_embedding = tf.nn.dropout(input_embedding,
                                                config.keep_prob)

        with variable_scope.variable_scope("personaMemory"):
            embedding_mask = tf.constant(
                [0 if i == 0 else 1 for i in range(self.vocab_size)],
                dtype=tf.float32,
                shape=[self.vocab_size, 1])
            A = tf.get_variable("persona_embedding_A",
                                [self.vocab_size, self.memory_cell_size],
                                dtype=tf.float32)
            A = A * embedding_mask
            C = []
            for hopn in range(self.hops):
                C.append(
                    tf.get_variable("persona_embedding_C_hop_{}".format(hopn),
                                    [self.vocab_size, self.memory_cell_size],
                                    dtype=tf.float32) * embedding_mask)

            q_emb = tf.nn.embedding_lookup(A, self.query)
            u_0 = tf.reduce_sum(q_emb, 1)
            u = [u_0]
            for hopn in range(self.hops):
                if hopn == 0:
                    m_emb_A = tf.nn.embedding_lookup(A, self.personas)
                    m_A = tf.reshape(m_emb_A, [
                        -1, self.max_per_len * self.max_per_line,
                        self.memory_cell_size
                    ])
                else:
                    with tf.variable_scope('persona_hop_{}'.format(hopn)):
                        m_emb_A = tf.nn.embedding_lookup(
                            C[hopn - 1], self.personas)
                        m_A = tf.reshape(m_emb_A, [
                            -1, self.max_per_len * self.max_per_line,
                            self.memory_cell_size
                        ])
                u_temp = tf.transpose(tf.expand_dims(u[-1], -1), [0, 2, 1])
                dotted = tf.reduce_sum(m_A * u_temp, 2)
                probs = tf.nn.softmax(dotted)
                probs_temp = tf.transpose(tf.expand_dims(probs, -1), [0, 2, 1])
                with tf.variable_scope('persona_hop_{}'.format(hopn)):
                    m_emb_C = tf.nn.embedding_lookup(
                        C[hopn],
                        tf.reshape(self.personas,
                                   [-1, self.max_per_len * self.max_per_line]))
                    m_emb_C = tf.expand_dims(m_emb_C, -2)
                    m_C = tf.reduce_sum(m_emb_C, axis=2)
                c_temp = tf.transpose(m_C, [0, 2, 1])
                o_k = tf.reduce_sum(c_temp * probs_temp, axis=2)
                u_k = u[-1] + o_k
                u.append(u_k)
            persona_memory = u[-1]

        with variable_scope.variable_scope("contextEmbedding"):
            context_layers = 2
            enc_cell = self.get_rnncell(config.cell_type,
                                        self.context_cell_size,
                                        keep_prob=1.0,
                                        num_layer=context_layers)
            _, enc_last_state = tf.nn.dynamic_rnn(
                enc_cell,
                input_embedding,
                dtype=tf.float32,
                sequence_length=self.context_lens)

            if context_layers > 1:
                if config.cell_type == 'lstm':
                    enc_last_state = [temp.h for temp in enc_last_state]

                enc_last_state = tf.concat(enc_last_state, 1)
            else:
                if config.cell_type == 'lstm':
                    enc_last_state = enc_last_state.h

        cond_embedding = tf.concat([persona_memory, enc_last_state], 1)

        with variable_scope.variable_scope("recognitionNetwork"):
            recog_input = tf.concat(
                [cond_embedding, output_embedding, persona_memory], 1)
            self.recog_mulogvar = recog_mulogvar = layers.fully_connected(
                recog_input,
                config.latent_size * 2,
                activation_fn=None,
                scope="muvar")
            recog_mu, recog_logvar = tf.split(recog_mulogvar, 2, axis=1)

        with variable_scope.variable_scope("priorNetwork"):
            prior_fc1 = layers.fully_connected(cond_embedding,
                                               np.maximum(
                                                   config.latent_size * 2,
                                                   100),
                                               activation_fn=tf.tanh,
                                               scope="fc1")
            prior_mulogvar = layers.fully_connected(prior_fc1,
                                                    config.latent_size * 2,
                                                    activation_fn=None,
                                                    scope="muvar")
            prior_mu, prior_logvar = tf.split(prior_mulogvar, 2, axis=1)
            latent_sample = tf.cond(
                self.use_prior,
                lambda: sample_gaussian(prior_mu, prior_logvar),
                lambda: sample_gaussian(recog_mu, recog_logvar))

        with variable_scope.variable_scope("personaSelecting"):
            condition = tf.concat([persona_memory, latent_sample], 1)

            self.persona_dist = tf.nn.log_softmax(
                layers.fully_connected(condition,
                                       self.max_per_line,
                                       activation_fn=tf.tanh,
                                       scope="persona_dist"))
            select_temp = tf.expand_dims(
                tf.argmax(self.persona_dist, 1, output_type=tf.int32), 1)
            index_temp = tf.expand_dims(
                tf.range(0, self.first_dimension_size, dtype=tf.int32), 1)
            persona_select = tf.concat([index_temp, select_temp], 1)
            selected_words_ordered = tf.reshape(
                tf.gather_nd(self.persona_words, persona_select),
                [self.max_per_len * self.first_dimension_size])
            self.selected_words = tf.gather_nd(self.persona_words,
                                               persona_select)
            label = tf.reshape(
                selected_words_ordered,
                [self.max_per_len * self.first_dimension_size, 1])
            index = tf.reshape(
                tf.range(self.first_dimension_size, dtype=tf.int32),
                [self.first_dimension_size, 1])
            index = tf.reshape(
                tf.tile(index, [1, self.max_per_len]),
                [self.max_per_len * self.first_dimension_size, 1])

            concated = tf.concat([index, label], 1)
            true_labels = tf.where(selected_words_ordered > 0)
            concated = tf.gather_nd(concated, true_labels)
            self.persona_word_mask = tf.sparse_to_dense(
                concated, [self.first_dimension_size, self.vocab_size],
                config.perw_weight, 0.0)
            self.other_word_mask = tf.sparse_to_dense(
                concated, [self.first_dimension_size, self.vocab_size], 0.0,
                config.othw_weight)
            self.persona_word_mask = self.persona_word_mask * self.idf

        with variable_scope.variable_scope("generationNetwork"):
            gen_inputs = tf.concat([cond_embedding, latent_sample], 1)

            # BOW loss
            bow_fc1 = layers.fully_connected(gen_inputs,
                                             400,
                                             activation_fn=tf.tanh,
                                             scope="bow_fc1")
            if config.keep_prob < 1.0:
                bow_fc1 = tf.nn.dropout(bow_fc1, config.keep_prob)
            self.bow_logits = layers.fully_connected(bow_fc1,
                                                     self.vocab_size,
                                                     activation_fn=None,
                                                     scope="bow_project")

            # Y loss
            dec_inputs = gen_inputs
            selected_attribute_embedding = None
            self.da_logits = tf.zeros((batch_size, self.da_vocab_size))

            # Decoder
            if config.num_layer > 1:
                dec_init_state = []
                for i in range(config.num_layer):
                    temp_init = layers.fully_connected(dec_inputs,
                                                       self.dec_cell_size,
                                                       activation_fn=None,
                                                       scope="init_state-%d" %
                                                       i)
                    if config.cell_type == 'lstm':
                        temp_init = rnn_cell.LSTMStateTuple(
                            temp_init, temp_init)

                    dec_init_state.append(temp_init)

                dec_init_state = tuple(dec_init_state)
            else:
                dec_init_state = layers.fully_connected(dec_inputs,
                                                        self.dec_cell_size,
                                                        activation_fn=None,
                                                        scope="init_state")
                if config.cell_type == 'lstm':
                    dec_init_state = rnn_cell.LSTMStateTuple(
                        dec_init_state, dec_init_state)

        with variable_scope.variable_scope("decoder"):
            dec_cell = self.get_rnncell(config.cell_type, self.dec_cell_size,
                                        config.keep_prob, config.num_layer)
            dec_cell = OutputProjectionWrapper(dec_cell, self.vocab_size)

            pos_cell = self.get_rnncell(config.cell_type, self.dec_cell_size,
                                        config.keep_prob, config.num_layer)
            pos_cell = OutputProjectionWrapper(pos_cell, self.vocab_size)

            with variable_scope.variable_scope("position"):
                self.pos_w_1 = tf.get_variable("pos_w_1",
                                               [self.dec_cell_size, 2],
                                               dtype=tf.float32)
                self.pos_b_1 = tf.get_variable("pos_b_1", [2],
                                               dtype=tf.float32)

            def position_function(states, logp=False):
                states = tf.reshape(states, [-1, self.dec_cell_size])
                if logp:
                    return tf.reshape(
                        tf.nn.log_softmax(
                            tf.matmul(states, self.pos_w_1) + self.pos_b_1),
                        [self.first_dimension_size, -1, 2])
                return tf.reshape(
                    tf.nn.softmax(
                        tf.matmul(states, self.pos_w_1) + self.pos_b_1),
                    [self.first_dimension_size, -1, 2])

            if forward:
                loop_func = self.context_decoder_fn_inference(
                    position_function,
                    self.persona_word_mask,
                    self.other_word_mask,
                    None,
                    dec_init_state,
                    embedding,
                    start_of_sequence_id=self.go_id,
                    end_of_sequence_id=self.eos_id,
                    maximum_length=self.max_utt_len,
                    num_decoder_symbols=self.vocab_size,
                    context_vector=selected_attribute_embedding,
                )
                dec_input_embedding = None
                dec_seq_lens = None
            else:
                loop_func = self.context_decoder_fn_train(
                    dec_init_state, selected_attribute_embedding)
                dec_input_embedding = embedding_ops.embedding_lookup(
                    embedding, self.output_tokens)
                dec_input_embedding = dec_input_embedding[:, 0:-1, :]
                dec_seq_lens = self.output_lens - 1
                if config.keep_prob < 1.0:
                    dec_input_embedding = tf.nn.dropout(
                        dec_input_embedding, config.keep_prob)
                if config.dec_keep_prob < 1.0:
                    keep_mask = tf.less_equal(
                        tf.random_uniform((batch_size, max_out_len - 1),
                                          minval=0.0,
                                          maxval=1.0), config.dec_keep_prob)
                    keep_mask = tf.expand_dims(tf.to_float(keep_mask), 2)
                    dec_input_embedding = dec_input_embedding * keep_mask
                    dec_input_embedding = tf.reshape(
                        dec_input_embedding,
                        [-1, max_out_len - 1, config.embed_size])

            with variable_scope.variable_scope("dec_state"):
                dec_outs, _, final_context_state, rnn_states = dynamic_rnn_decoder(
                    dec_cell,
                    loop_func,
                    inputs=dec_input_embedding,
                    sequence_length=dec_seq_lens)
            with variable_scope.variable_scope("pos_state"):
                _, _, _, pos_states = dynamic_rnn_decoder(
                    pos_cell,
                    loop_func,
                    inputs=dec_input_embedding,
                    sequence_length=dec_seq_lens)

            self.position_dist = position_function(pos_states, logp=True)

            if final_context_state is not None:
                final_context_state = final_context_state[:, 0:array_ops.
                                                          shape(dec_outs)[1]]
                mask = tf.to_int32(tf.sign(tf.reduce_max(dec_outs, axis=2)))
                self.dec_out_words = tf.multiply(
                    tf.reverse(final_context_state, axis=[1]), mask)
            else:
                self.dec_out_words = tf.argmax(dec_outs, 2)
        if not forward:
            with variable_scope.variable_scope("loss"):
                labels = self.output_tokens[:, 1:]
                label_mask = tf.to_float(tf.sign(labels))
                rc_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=dec_outs, labels=labels)
                rc_loss = tf.reduce_sum(rc_loss * label_mask,
                                        reduction_indices=1)
                self.avg_rc_loss = tf.reduce_mean(rc_loss)
                self.rc_ppl = tf.exp(
                    tf.reduce_sum(rc_loss) / tf.reduce_sum(label_mask))
                per_select_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=tf.reshape(self.persona_dist,
                                      [self.first_dimension_size, 1, -1]),
                    labels=self.selected_persona)
                per_select_loss = tf.reduce_sum(per_select_loss,
                                                reduction_indices=1)
                self.avg_per_select_loss = tf.reduce_mean(per_select_loss)
                position_labels = self.persona_position[:, 1:]
                per_pos_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=self.position_dist, labels=position_labels)
                per_pos_loss = tf.reduce_sum(per_pos_loss, reduction_indices=1)
                self.avg_per_pos_loss = tf.reduce_mean(per_pos_loss)

                tile_bow_logits = tf.tile(tf.expand_dims(self.bow_logits, 1),
                                          [1, max_out_len - 1, 1])
                bow_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=tile_bow_logits, labels=labels) * label_mask
                bow_loss = tf.reduce_sum(bow_loss, reduction_indices=1)
                self.avg_bow_loss = tf.reduce_mean(bow_loss)
                kld = gaussian_kld(recog_mu, recog_logvar, prior_mu,
                                   prior_logvar)
                self.avg_kld = tf.reduce_mean(kld)

                if log_dir is not None:
                    kl_weights = tf.minimum(
                        tf.to_float(self.global_t) / config.full_kl_step, 1.0)
                else:
                    kl_weights = tf.constant(1.0)

                self.kl_w = kl_weights
                self.elbo = self.avg_rc_loss + kl_weights * self.avg_kld
                aug_elbo = self.elbo + self.avg_bow_loss + 0.1 * self.avg_per_select_loss + 0.05 * self.avg_per_pos_loss

                tf.summary.scalar("rc_loss", self.avg_rc_loss)
                tf.summary.scalar("elbo", self.elbo)
                tf.summary.scalar("kld", self.avg_kld)
                tf.summary.scalar("per_pos_loss", self.avg_per_pos_loss)

                self.summary_op = tf.summary.merge_all()

                self.log_p_z = norm_log_liklihood(latent_sample, prior_mu,
                                                  prior_logvar)
                self.log_q_z_xy = norm_log_liklihood(latent_sample, recog_mu,
                                                     recog_logvar)
                self.est_marginal = tf.reduce_mean(rc_loss + bow_loss -
                                                   self.log_p_z +
                                                   self.log_q_z_xy)

            self.optimize(sess, config, aug_elbo, log_dir)

        self.saver = tf.train.Saver(tf.global_variables(),
                                    write_version=tf.train.SaverDef.V2)
    def decode(self):
        z = utils.sample_gaussian(self.mu, self.sigma)
        matrix_pred = tf.matmul(z, z, transpose_a=False, transpose_b=True)

        return matrix_pred
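Every TensorFlow example in this collection draws its latent variable through a `sample_gaussian` helper that is not included in the snippets. A minimal sketch of what such a helper typically looks like, assuming a mean/log-variance parameterization and the standard reparameterization trick (an illustrative reconstruction, not the original implementation):

import tensorflow as tf

def sample_gaussian(mu, logvar):
    # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I).
    # Assumes the second argument is log(sigma^2); some call sites above pass
    # a variance or a sigma instead, which would change the exp(0.5 * .) term.
    epsilon = tf.random_normal(tf.shape(logvar), name="epsilon")
    return mu + tf.exp(0.5 * logvar) * epsilon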
Example #29
0
    def __init__(self, tfFLAGS, embed=None):
        self.vocab_size = tfFLAGS.vocab_size
        self.embed_size = tfFLAGS.embed_size
        self.num_units = tfFLAGS.num_units
        self.num_layers = tfFLAGS.num_layers
        self.beam_width = tfFLAGS.beam_width
        self.use_lstm = tfFLAGS.use_lstm
        self.attn_mode = tfFLAGS.attn_mode
        self.train_keep_prob = tfFLAGS.keep_prob
        self.max_decode_len = tfFLAGS.max_decode_len
        self.bi_encode = tfFLAGS.bi_encode
        self.recog_hidden_units = tfFLAGS.recog_hidden_units
        self.prior_hidden_units = tfFLAGS.prior_hidden_units
        self.z_dim = tfFLAGS.z_dim
        self.full_kl_step = tfFLAGS.full_kl_step

        self.global_step = tf.Variable(0, name="global_step", trainable=False)
        self.max_gradient_norm = 5.0
        if tfFLAGS.opt == 'SGD':
            self.learning_rate = tf.Variable(float(tfFLAGS.learning_rate),
                                             trainable=False,
                                             dtype=tf.float32)
            self.learning_rate_decay_op = self.learning_rate.assign(
                self.learning_rate * tfFLAGS.learning_rate_decay_factor)
            self.opt = tf.train.GradientDescentOptimizer(self.learning_rate)
        elif tfFLAGS.opt == 'Momentum':
            self.opt = tf.train.MomentumOptimizer(
                learning_rate=tfFLAGS.learning_rate, momentum=tfFLAGS.momentum)
        else:
            self.learning_rate = tfFLAGS.learning_rate
            self.opt = tf.train.AdamOptimizer()

        self._make_input(embed)

        with tf.variable_scope("output_layer"):
            self.output_layer = Dense(
                self.vocab_size,
                kernel_initializer=tf.truncated_normal_initializer(stddev=0.1))

        with tf.variable_scope("encoders",
                               initializer=tf.orthogonal_initializer()):
            self.enc_post_outputs, self.enc_post_state = self._build_encoder(
                scope='post_encoder',
                inputs=self.enc_post,
                sequence_length=self.post_len)
            self.enc_ref_outputs, self.enc_ref_state = self._build_encoder(
                scope='ref_encoder',
                inputs=self.enc_ref,
                sequence_length=self.ref_len)
            self.enc_response_outputs, self.enc_response_state = self._build_encoder(
                scope='resp_encoder',
                inputs=self.enc_response,
                sequence_length=self.response_len)

            self.post_state = self._get_representation_from_enc_state(
                self.enc_post_state)
            self.ref_state = self._get_representation_from_enc_state(
                self.enc_ref_state)
            self.response_state = self._get_representation_from_enc_state(
                self.enc_response_state)
            self.cond_embed = tf.concat([self.post_state, self.ref_state],
                                        axis=-1)

        with tf.variable_scope("RecognitionNetwork"):
            recog_input = tf.concat([self.cond_embed, self.response_state],
                                    axis=-1)
            recog_hidden = tf.layers.dense(inputs=recog_input,
                                           units=self.recog_hidden_units,
                                           activation=tf.nn.tanh)
            recog_mulogvar = tf.layers.dense(inputs=recog_hidden,
                                             units=self.z_dim * 2,
                                             activation=None)
            # recog_mulogvar = tf.layers.dense(inputs=recog_input, units=self.z_dim * 2, activation=None)
            recog_mu, recog_logvar = tf.split(recog_mulogvar, 2, axis=-1)

        with tf.variable_scope("PriorNetwork"):
            prior_input = self.cond_embed
            prior_hidden = tf.layers.dense(inputs=prior_input,
                                           units=self.prior_hidden_units,
                                           activation=tf.nn.tanh)
            prior_mulogvar = tf.layers.dense(inputs=prior_hidden,
                                             units=self.z_dim * 2,
                                             activation=None)
            prior_mu, prior_logvar = tf.split(prior_mulogvar, 2, axis=-1)

        with tf.variable_scope("GenerationNetwork"):
            latent_sample = tf.cond(
                self.use_prior,
                lambda: sample_gaussian(prior_mu, prior_logvar),
                lambda: sample_gaussian(recog_mu, recog_logvar),
                name='latent_sample')

            gen_input = tf.concat([self.cond_embed, latent_sample], axis=-1)
            if self.use_lstm:
                self.dec_init_state = tuple([
                    tf.contrib.rnn.LSTMStateTuple(
                        c=tf.layers.dense(inputs=gen_input,
                                          units=self.num_units,
                                          activation=None),
                        h=tf.layers.dense(inputs=gen_input,
                                          units=self.num_units,
                                          activation=None))
                    for _ in range(self.num_layers)
                ])
                print(self.dec_init_state)
            else:
                self.dec_init_state = tuple([
                    tf.layers.dense(inputs=gen_input,
                                    units=self.num_units,
                                    activation=None)
                    for _ in range(self.num_layers)
                ])

            kld = gaussian_kld(recog_mu, recog_logvar, prior_mu, prior_logvar)
            self.avg_kld = tf.reduce_mean(kld)
            self.kl_weights = tf.minimum(
                tf.to_float(self.global_step) / self.full_kl_step, 1.0)
            self.kl_loss = self.kl_weights * self.avg_kld

        self._build_decoder()
        self.saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                    max_to_keep=1,
                                    pad_step_number=True,
                                    keep_checkpoint_every_n_hours=1.0)
        for var in tf.trainable_variables():
            print(var)
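The KL term between the recognition (posterior) and prior networks is computed by a `gaussian_kld` helper that the examples call but do not show. A plausible sketch for diagonal Gaussians parameterized by mean and log-variance (an assumption inferred from the call sites, not the original code):

import tensorflow as tf

def gaussian_kld(recog_mu, recog_logvar, prior_mu, prior_logvar):
    # KL( N(recog_mu, exp(recog_logvar)) || N(prior_mu, exp(prior_logvar)) )
    # for diagonal Gaussians, summed over the latent dimensions.
    return -0.5 * tf.reduce_sum(
        1.0 + (recog_logvar - prior_logvar)
        - tf.square(prior_mu - recog_mu) / tf.exp(prior_logvar)
        - tf.exp(recog_logvar) / tf.exp(prior_logvar),
        axis=1)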
Example #30
0
    def lnet(self, x, y_de, args):
        # ---deterministic upward pass
        # upwards
        enc1_1, mu_up1_1, var_up1_1 = self.CONV1_1.encode(x)
        enc1_2, mu_up1_2, var_up1_2 = self.CONV1_2.encode(enc1_1)

        enc2_1, mu_up2_1, var_up2_1 = self.CONV2_1.encode(enc1_2)
        enc2_2, mu_up2_2, var_up2_2 = self.CONV2_2.encode(enc2_1)

        enc3_1, mu_up3_1, var_up3_1 = self.CONV3_1.encode(enc2_2)
        enc3_2, mu_up3_2, var_up3_2 = self.CONV3_2.encode(enc3_1)

        enc4_1, mu_up4_1, var_up4_1 = self.CONV4_1.encode(enc3_2)
        enc4_2, mu_up4_2, var_up4_2 = self.CONV4_2.encode(enc4_1)

        enc5_1, mu_up5_1, var_up5_1 = self.CONV5_1.encode(enc4_2)
        enc5_2, mu_latent, var_latent = self.CONV5_2.encode(enc5_1)

        # split z and y
        if args.encode_z:
            z_latent_mu, y_latent_mu = mu_latent.split([args.encode_z, 32],
                                                       dim=1)
            z_latent_var, y_latent_var = var_latent.split([args.encode_z, 32],
                                                          dim=1)
            latent = ut.sample_gaussian(mu_latent, var_latent)
            latent_y = ut.sample_gaussian(y_latent_mu, y_latent_var)
        else:
            y_latent_mu = mu_latent
            y_latent_var = var_latent
            latent = ut.sample_gaussian(mu_latent, var_latent)
            latent_y = latent

        predict = F.log_softmax(self.classifier(latent_y), dim=1)
        predict_test = F.log_softmax(self.classifier(y_latent_mu), dim=1)
        yh = self.one_hot(y_de)

        # partially downwards
        dec5_1, mu_dn5_1, var_dn5_1 = self.TCONV5_2.decode(latent)
        prec_up5_1 = var_up5_1**(-1)
        prec_dn5_1 = var_dn5_1**(-1)
        qmu5_1 = (mu_up5_1 * prec_up5_1 +
                  mu_dn5_1 * prec_dn5_1) / (prec_up5_1 + prec_dn5_1)
        qvar5_1 = (prec_up5_1 + prec_dn5_1)**(-1)
        de_latent5_1 = ut.sample_gaussian(qmu5_1, qvar5_1)

        dec4_2, mu_dn4_2, var_dn4_2 = self.TCONV5_1.decode(de_latent5_1)
        prec_up4_2 = var_up4_2**(-1)
        prec_dn4_2 = var_dn4_2**(-1)
        qmu4_2 = (mu_up4_2 * prec_up4_2 +
                  mu_dn4_2 * prec_dn4_2) / (prec_up4_2 + prec_dn4_2)
        qvar4_2 = (prec_up4_2 + prec_dn4_2)**(-1)
        de_latent4_2 = ut.sample_gaussian(qmu4_2, qvar4_2)

        dec4_1, mu_dn4_1, var_dn4_1 = self.TCONV4_2.decode(de_latent4_2)
        prec_up4_1 = var_up4_1**(-1)
        prec_dn4_1 = var_dn4_1**(-1)
        qmu4_1 = (mu_up4_1 * prec_up4_1 +
                  mu_dn4_1 * prec_dn4_1) / (prec_up4_1 + prec_dn4_1)
        qvar4_1 = (prec_up4_1 + prec_dn4_1)**(-1)
        de_latent4_1 = ut.sample_gaussian(qmu4_1, qvar4_1)

        dec3_2, mu_dn3_2, var_dn3_2 = self.TCONV4_1.decode(de_latent4_1)
        prec_up3_2 = var_up3_2**(-1)
        prec_dn3_2 = var_dn3_2**(-1)
        qmu3_2 = (mu_up3_2 * prec_up3_2 +
                  mu_dn3_2 * prec_dn3_2) / (prec_up3_2 + prec_dn3_2)
        qvar3_2 = (prec_up3_2 + prec_dn3_2)**(-1)
        de_latent3_2 = ut.sample_gaussian(qmu3_2, qvar3_2)

        dec3_1, mu_dn3_1, var_dn3_1 = self.TCONV3_2.decode(de_latent3_2)
        prec_up3_1 = var_up3_1**(-1)
        prec_dn3_1 = var_dn3_1**(-1)
        qmu3_1 = (mu_up3_1 * prec_up3_1 +
                  mu_dn3_1 * prec_dn3_1) / (prec_up3_1 + prec_dn3_1)
        qvar3_1 = (prec_up3_1 + prec_dn3_1)**(-1)
        de_latent3_1 = ut.sample_gaussian(qmu3_1, qvar3_1)

        dec2_2, mu_dn2_2, var_dn2_2 = self.TCONV3_1.decode(de_latent3_1)
        prec_up2_2 = var_up2_2**(-1)
        prec_dn2_2 = var_dn2_2**(-1)
        qmu2_2 = (mu_up2_2 * prec_up2_2 +
                  mu_dn2_2 * prec_dn2_2) / (prec_up2_2 + prec_dn2_2)
        qvar2_2 = (prec_up2_2 + prec_dn2_2)**(-1)
        de_latent2_2 = ut.sample_gaussian(qmu2_2, qvar2_2)

        dec2_1, mu_dn2_1, var_dn2_1 = self.TCONV2_2.decode(de_latent2_2)
        prec_up2_1 = var_up2_1**(-1)
        prec_dn2_1 = var_dn2_1**(-1)
        qmu2_1 = (mu_up2_1 * prec_up2_1 +
                  mu_dn2_1 * prec_dn2_1) / (prec_up2_1 + prec_dn2_1)
        qvar2_1 = (prec_up2_1 + prec_dn2_1)**(-1)
        de_latent2_1 = ut.sample_gaussian(qmu2_1, qvar2_1)

        dec1_2, mu_dn1_2, var_dn1_2 = self.TCONV2_1.decode(de_latent2_1)
        prec_up1_2 = var_up1_2**(-1)
        prec_dn1_2 = var_dn1_2**(-1)
        qmu1_2 = (mu_up1_2 * prec_up1_2 +
                  mu_dn1_2 * prec_dn1_2) / (prec_up1_2 + prec_dn1_2)
        qvar1_2 = (prec_up1_2 + prec_dn1_2)**(-1)
        de_latent1_2 = ut.sample_gaussian(qmu1_2, qvar1_2)

        dec1_1, mu_dn1_1, var_dn1_1 = self.TCONV1_2.decode(de_latent1_2)
        prec_up1_1 = var_up1_1**(-1)
        prec_dn1_1 = var_dn1_1**(-1)
        qmu1_1 = (mu_up1_1 * prec_up1_1 +
                  mu_dn1_1 * prec_dn1_1) / (prec_up1_1 + prec_dn1_1)
        qvar1_1 = (prec_up1_1 + prec_dn1_1)**(-1)
        de_latent1_1 = ut.sample_gaussian(qmu1_1, qvar1_1)

        x_re = self.TCONV1_1.final_decode(de_latent1_1)

        if args.contrastive_loss and self.training:
            self.contra_loss = self.contrastive_loss(x, y_de, x_re, args)

        return latent, mu_latent, var_latent, \
               qmu5_1, qvar5_1, qmu4_2, qvar4_2, qmu4_1, qvar4_1, qmu3_2, qvar3_2, qmu3_1, qvar3_1, \
               qmu2_2, qvar2_2, qmu2_1, qvar2_1, qmu1_2, qvar1_2, qmu1_1, qvar1_1, \
               predict, predict_test, yh, \
               x_re, \
               mu_dn5_1, var_dn5_1, mu_dn4_2, var_dn4_2, mu_dn4_1, var_dn4_1, mu_dn3_2, var_dn3_2, mu_dn3_1, var_dn3_1, \
               mu_dn2_2, var_dn2_2, mu_dn2_1, var_dn2_1, mu_dn1_2, var_dn1_2, mu_dn1_1, var_dn1_1
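Example #30 repeats the same precision-weighted combination of a bottom-up and a top-down Gaussian at every layer of the ladder. The pattern could be factored into one helper; a sketch in PyTorch (the name `precision_weighted_merge` is illustrative and not part of the original code):

import torch

def precision_weighted_merge(mu_up, var_up, mu_dn, var_dn):
    # Combine two Gaussian estimates of the same latent by inverse-variance
    # (precision) weighting, as done layer-by-layer in lnet above.
    prec_up = var_up.reciprocal()
    prec_dn = var_dn.reciprocal()
    qvar = (prec_up + prec_dn).reciprocal()
    qmu = (mu_up * prec_up + mu_dn * prec_dn) * qvar
    return qmu, qvar

With such a helper each layer would reduce to something like de_latent5_1 = ut.sample_gaussian(*precision_weighted_merge(mu_up5_1, var_up5_1, mu_dn5_1, var_dn5_1)).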
Example #31
0
    def __init__(self, sess, config, api, log_dir, forward, scope=None):
        self.vocab = api.vocab  # index2word
        self.rev_vocab = api.rev_vocab  # word2index
        self.vocab_size = len(self.vocab)  # vocab size
        self.emotion_vocab = api.emotion_vocab  # index2emotion
        self.emotion_vocab_size = len(self.emotion_vocab)
        # self.da_vocab = api.dialog_act_vocab
        # self.da_vocab_size = len(self.da_vocab)

        self.sess = sess
        self.scope = scope
        self.max_utt_len = config.max_utt_len
        self.go_id = self.rev_vocab["<s>"]
        self.eos_id = self.rev_vocab["</s>"]

        self.context_cell_size = config.cxt_cell_size  # dont need
        self.sent_cell_size = config.sent_cell_size  # for encode
        self.dec_cell_size = config.dec_cell_size  # for decode

        with tf.name_scope("io"):
            self.input_contexts = tf.placeholder(dtype=tf.int32,
                                                 shape=(None,
                                                        self.max_utt_len),
                                                 name="input_contexts")
            # self.floors = tf.placeholder(dtype=tf.int32, shape=(None, None), name="floor")
            self.input_lens = tf.placeholder(dtype=tf.int32,
                                             shape=(None, ),
                                             name="input_lens")
            self.input_emotions = tf.placeholder(dtype=tf.int32,
                                                 shape=(None, ),
                                                 name="input_emotions")
            # self.my_profile = tf.placeholder(dtype=tf.float32, shape=(None, 4), name="my_profile")
            # self.ot_profile = tf.placeholder(dtype=tf.float32, shape=(None, 4), name="ot_profile")

            # target response given the dialog context
            self.output_tokens = tf.placeholder(dtype=tf.int32,
                                                shape=(None, None),
                                                name="output_token")
            self.output_lens = tf.placeholder(dtype=tf.int32,
                                              shape=(None, ),
                                              name="output_lens")
            self.output_emotions = tf.placeholder(dtype=tf.int32,
                                                  shape=(None, ),
                                                  name="output_emotions")

            # optimization related variables
            self.learning_rate = tf.Variable(float(config.init_lr),
                                             trainable=False,
                                             name="learning_rate")
            self.learning_rate_decay_op = self.learning_rate.assign(
                tf.multiply(self.learning_rate, config.lr_decay))
            self.global_t = tf.placeholder(dtype=tf.int32, name="global_t")
            self.use_prior = tf.placeholder(dtype=tf.bool, name="use_prior")

        max_dialog_len = array_ops.shape(self.input_contexts)[1]
        max_out_len = array_ops.shape(self.output_tokens)[1]
        batch_size = array_ops.shape(self.input_contexts)[0]

        with variable_scope.variable_scope("emotionEmbedding"):
            t_embedding = tf.get_variable(
                "embedding",
                [self.emotion_vocab_size, config.topic_embed_size],
                dtype=tf.float32)
            inp_emotion_embedding = embedding_ops.embedding_lookup(
                t_embedding, self.input_emotions)
            outp_emotion_embedding = embedding_ops.embedding_lookup(
                t_embedding, self.output_emotions)

        with variable_scope.variable_scope("wordEmbedding"):
            self.embedding = tf.get_variable(
                "embedding", [self.vocab_size, config.embed_size],
                dtype=tf.float32)
            embedding_mask = tf.constant(
                [0 if i == 0 else 1 for i in range(self.vocab_size)],
                dtype=tf.float32,
                shape=[self.vocab_size, 1])
            embedding = self.embedding * embedding_mask

            input_embedding = embedding_ops.embedding_lookup(
                embedding, tf.reshape(self.input_contexts, [-1]))
            input_embedding = tf.reshape(
                input_embedding, [-1, self.max_utt_len, config.embed_size])
            output_embedding = embedding_ops.embedding_lookup(
                embedding, self.output_tokens)

            if config.sent_type == "rnn":
                enc_cell = self.get_rnncell(config.cell_type,
                                            self.context_cell_size,
                                            keep_prob=1.0,
                                            num_layer=config.num_layer)
                _, enc_last_state = tf.nn.dynamic_rnn(
                    enc_cell,
                    input_embedding,
                    dtype=tf.float32,
                    sequence_length=self.input_lens)

                sent_cell = self.get_rnncell("gru", self.sent_cell_size,
                                             config.keep_prob, 1)
                # input_embedding, sent_size = get_rnn_encode(input_embedding, sent_cell, scope="sent_rnn")

                output_embedding, _ = get_bi_rnn_encode(output_embedding,
                                                        sent_cell,
                                                        self.output_lens,
                                                        scope="sent_rnn")

            elif config.sent_type == "bi_rnn":
                fwd_enc_cell = self.get_rnncell(config.cell_type,
                                                self.context_cell_size,
                                                keep_prob=1.0,
                                                num_layer=config.num_layer)
                bwd_enc_cell = self.get_rnncell(config.cell_type,
                                                self.context_cell_size,
                                                keep_prob=1.0,
                                                num_layer=config.num_layer)
                _, enc_last_state = tf.nn.bidirectional_dynamic_rnn(
                    fwd_enc_cell,
                    bwd_enc_cell,
                    input_embedding,
                    dtype=tf.float32,
                    sequence_length=self.input_lens)
                enc_last_state = enc_last_state[0] + enc_last_state[1]

                fwd_sent_cell = self.get_rnncell("gru",
                                                 self.sent_cell_size,
                                                 keep_prob=1.0,
                                                 num_layer=1)
                bwd_sent_cell = self.get_rnncell("gru",
                                                 self.sent_cell_size,
                                                 keep_prob=1.0,
                                                 num_layer=1)
                # input_embedding, sent_size = get_bi_rnn_encode(input_embedding, fwd_sent_cell, bwd_sent_cell, scope="sent_bi_rnn")
                output_embedding, _ = get_bi_rnn_encode(output_embedding,
                                                        fwd_sent_cell,
                                                        bwd_sent_cell,
                                                        self.output_lens,
                                                        scope="sent_bi_rnn")

            else:
                raise ValueError(
                    "Unknown sent_type. Must be one of [rnn, bi_rnn]")

            # reshape input into dialogs
            # input_embedding = tf.reshape(input_embedding, [-1, max_dialog_len, sent_size])
            # if config.keep_prob < 1.0:
            #     input_embedding = tf.nn.dropout(input_embedding, config.keep_prob)

            # convert floors into 1 hot
            # floor_one_hot = tf.one_hot(tf.reshape(self.floors, [-1]), depth=2, dtype=tf.float32)
            # floor_one_hot = tf.reshape(floor_one_hot, [-1, max_dialog_len, 2])

            # joint_embedding = tf.concat([input_embedding, floor_one_hot], 2, "joint_embedding")

        with variable_scope.variable_scope("contextRNN"):

            if config.num_layer > 1:
                if config.cell_type == 'lstm':
                    enc_last_state = [temp.h for temp in enc_last_state]

                enc_last_state = tf.concat(enc_last_state, 1)
            else:
                if config.cell_type == 'lstm':
                    enc_last_state = enc_last_state.h

            attribute_fc1 = layers.fully_connected(outp_emotion_embedding,
                                                   30,
                                                   activation_fn=tf.tanh,
                                                   scope="attribute_fc1")

            cond_embedding = tf.concat([inp_emotion_embedding, enc_last_state],
                                       1)

        with variable_scope.variable_scope("recognitionNetwork"):
            recog_input = tf.concat(
                [cond_embedding, output_embedding, attribute_fc1], 1)
            self.recog_mulogvar = recog_mulogvar = layers.fully_connected(
                recog_input,
                config.latent_size * 2,
                activation_fn=None,
                scope="muvar")
            recog_mu, recog_logvar = tf.split(recog_mulogvar, 2, axis=1)

        with variable_scope.variable_scope("priorNetwork"):
            prior_fc1 = layers.fully_connected(cond_embedding,
                                               np.maximum(
                                                   config.latent_size * 2,
                                                   100),
                                               activation_fn=tf.tanh,
                                               scope="fc1")
            prior_mulogvar = layers.fully_connected(prior_fc1,
                                                    config.latent_size * 2,
                                                    activation_fn=None,
                                                    scope="muvar")
            prior_mu, prior_logvar = tf.split(prior_mulogvar, 2, axis=1)

            # use sampled Z or posterior Z
            latent_sample = tf.cond(
                self.use_prior,
                lambda: sample_gaussian(prior_mu, prior_logvar),
                lambda: sample_gaussian(recog_mu, recog_logvar))

        with variable_scope.variable_scope("generationNetwork"):
            gen_inputs = tf.concat([cond_embedding, latent_sample], 1)

            # BOW loss
            bow_fc1 = layers.fully_connected(gen_inputs,
                                             400,
                                             activation_fn=tf.tanh,
                                             scope="bow_fc1")
            if config.keep_prob < 1.0:
                bow_fc1 = tf.nn.dropout(bow_fc1, config.keep_prob)
            self.bow_logits = layers.fully_connected(bow_fc1,
                                                     self.vocab_size,
                                                     activation_fn=None,
                                                     scope="bow_project")

            # Y loss
            meta_fc1 = layers.fully_connected(gen_inputs,
                                              400,
                                              activation_fn=tf.tanh,
                                              scope="meta_fc1")
            if config.keep_prob < 1.0:
                meta_fc1 = tf.nn.dropout(meta_fc1, config.keep_prob)
            self.da_logits = layers.fully_connected(meta_fc1,
                                                    self.emotion_vocab_size,
                                                    scope="da_project")
            da_prob = tf.nn.softmax(self.da_logits)
            pred_attribute_embedding = tf.matmul(da_prob, t_embedding)
            if forward:
                selected_attribute_embedding = pred_attribute_embedding
            else:
                selected_attribute_embedding = outp_emotion_embedding
            dec_inputs = tf.concat([gen_inputs, selected_attribute_embedding],
                                   1)

            # Decoder
            if config.num_layer > 1:
                dec_init_state = []
                for i in range(config.num_layer):
                    temp_init = layers.fully_connected(dec_inputs,
                                                       self.dec_cell_size,
                                                       activation_fn=None,
                                                       scope="init_state-%d" %
                                                       i)
                    if config.cell_type == 'lstm':
                        temp_init = rnn_cell.LSTMStateTuple(
                            temp_init, temp_init)

                    dec_init_state.append(temp_init)

                dec_init_state = tuple(dec_init_state)
            else:
                dec_init_state = layers.fully_connected(dec_inputs,
                                                        self.dec_cell_size,
                                                        activation_fn=None,
                                                        scope="init_state")
                if config.cell_type == 'lstm':
                    dec_init_state = rnn_cell.LSTMStateTuple(
                        dec_init_state, dec_init_state)

        with variable_scope.variable_scope("decoder"):
            dec_cell = self.get_rnncell(config.cell_type, self.dec_cell_size,
                                        config.keep_prob, config.num_layer)
            dec_cell = OutputProjectionWrapper(dec_cell, self.vocab_size)

            if forward:
                loop_func = decoder_fn_lib.context_decoder_fn_inference(
                    None,
                    dec_init_state,
                    embedding,
                    start_of_sequence_id=self.go_id,
                    end_of_sequence_id=self.eos_id,
                    maximum_length=self.max_utt_len,
                    num_decoder_symbols=self.vocab_size,
                    context_vector=selected_attribute_embedding)
                dec_input_embedding = None
                dec_seq_lens = None
            else:
                loop_func = decoder_fn_lib.context_decoder_fn_train(
                    dec_init_state, selected_attribute_embedding)
                dec_input_embedding = embedding_ops.embedding_lookup(
                    embedding, self.output_tokens)
                dec_input_embedding = dec_input_embedding[:, 0:-1, :]
                dec_seq_lens = self.output_lens - 1

                if config.keep_prob < 1.0:
                    dec_input_embedding = tf.nn.dropout(
                        dec_input_embedding, config.keep_prob)

                # apply word dropping. Set dropped word to 0
                if config.dec_keep_prob < 1.0:
                    keep_mask = tf.less_equal(
                        tf.random_uniform((batch_size, max_out_len - 1),
                                          minval=0.0,
                                          maxval=1.0), config.dec_keep_prob)
                    keep_mask = tf.expand_dims(tf.to_float(keep_mask), 2)
                    dec_input_embedding = dec_input_embedding * keep_mask
                    dec_input_embedding = tf.reshape(
                        dec_input_embedding,
                        [-1, max_out_len - 1, config.embed_size])

            dec_outs, _, final_context_state = dynamic_rnn_decoder(
                dec_cell,
                loop_func,
                inputs=dec_input_embedding,
                sequence_length=dec_seq_lens)

            if final_context_state is not None:
                final_context_state = final_context_state[:, 0:array_ops.
                                                          shape(dec_outs)[1]]
                mask = tf.to_int32(tf.sign(tf.reduce_max(dec_outs, axis=2)))
                self.dec_out_words = tf.multiply(
                    tf.reverse(final_context_state, axis=[1]), mask)
            else:
                self.dec_out_words = tf.argmax(dec_outs, 2)

        if not forward:
            with variable_scope.variable_scope("loss"):
                labels = self.output_tokens[:, 1:]
                label_mask = tf.to_float(tf.sign(labels))

                rc_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=dec_outs, labels=labels)
                rc_loss = tf.reduce_sum(rc_loss * label_mask,
                                        reduction_indices=1)
                self.avg_rc_loss = tf.reduce_mean(rc_loss)
                self.rc_ppl = tf.exp(
                    tf.reduce_sum(rc_loss) / tf.reduce_sum(label_mask))

                tile_bow_logits = tf.tile(tf.expand_dims(self.bow_logits, 1),
                                          [1, max_out_len - 1, 1])
                bow_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=tile_bow_logits, labels=labels) * label_mask
                bow_loss = tf.reduce_sum(bow_loss, reduction_indices=1)
                self.avg_bow_loss = tf.reduce_mean(bow_loss)

                da_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=self.da_logits, labels=self.output_emotions)
                self.avg_da_loss = tf.reduce_mean(da_loss)

                kld = gaussian_kld(recog_mu, recog_logvar, prior_mu,
                                   prior_logvar)
                self.avg_kld = tf.reduce_mean(kld)
                if log_dir is not None:
                    kl_weights = tf.minimum(
                        tf.to_float(self.global_t) / config.full_kl_step, 1.0)
                else:
                    kl_weights = tf.constant(1.0)

                self.kl_w = kl_weights
                self.elbo = self.avg_rc_loss + kl_weights * self.avg_kld
                aug_elbo = self.avg_bow_loss + self.avg_da_loss + self.elbo

                tf.summary.scalar("da_loss", self.avg_da_loss)
                tf.summary.scalar("rc_loss", self.avg_rc_loss)
                tf.summary.scalar("elbo", self.elbo)
                tf.summary.scalar("kld", self.avg_kld)
                tf.summary.scalar("bow_loss", self.avg_bow_loss)

                self.summary_op = tf.summary.merge_all()

                self.log_p_z = norm_log_liklihood(latent_sample, prior_mu,
                                                  prior_logvar)
                self.log_q_z_xy = norm_log_liklihood(latent_sample, recog_mu,
                                                     recog_logvar)
                self.est_marginal = tf.reduce_mean(rc_loss + bow_loss -
                                                   self.log_p_z +
                                                   self.log_q_z_xy)

            self.optimize(sess, config, aug_elbo, log_dir)

        self.saver = tf.train.Saver(tf.global_variables(),
                                    write_version=tf.train.SaverDef.V2)
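Several of these models also report an importance-sampling estimate of the marginal likelihood via log_p_z and log_q_z_xy, computed by a `norm_log_liklihood` helper (spelled as in the source) that is not shown. A sketch of the diagonal-Gaussian log-density it presumably evaluates, inferred from its call sites:

import numpy as np
import tensorflow as tf

def norm_log_liklihood(x, mu, logvar):
    # log N(x; mu, diag(exp(logvar))), summed over the latent dimensions.
    return -0.5 * tf.reduce_sum(
        tf.log(2.0 * np.pi) + logvar + tf.square(x - mu) / tf.exp(logvar),
        axis=1)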
Example #32
0
    def __init__(self, sess, config, api, log_dir, forward, scope=None):
        self.vocab = api.vocab
        self.rev_vocab = api.rev_vocab
        self.vocab_size = len(self.vocab)
        self.sess = sess
        self.scope = scope
        self.max_utt_len = config.max_utt_len
        self.go_id = self.rev_vocab["<s>"]
        self.eos_id = self.rev_vocab["</s>"]
        self.context_cell_size = config.cxt_cell_size
        self.sent_cell_size = config.sent_cell_size
        self.dec_cell_size = config.dec_cell_size
        self.num_topics = config.num_topics

        with tf.name_scope("io"):
            # all dialog context and known attributes
            self.input_contexts = tf.placeholder(dtype=tf.int32,
                                                 shape=(None, None,
                                                        self.max_utt_len),
                                                 name="dialog_context")
            self.floors = tf.placeholder(dtype=tf.float32,
                                         shape=(None, None),
                                         name="floor")  # TODO float
            self.floor_labels = tf.placeholder(dtype=tf.float32,
                                               shape=(None, 1),
                                               name="floor_labels")
            self.context_lens = tf.placeholder(dtype=tf.int32,
                                               shape=(None, ),
                                               name="context_lens")
            self.paragraph_topics = tf.placeholder(dtype=tf.float32,
                                                   shape=(None,
                                                          self.num_topics),
                                                   name="paragraph_topics")

            # target response given the dialog context
            self.output_tokens = tf.placeholder(dtype=tf.int32,
                                                shape=(None, None),
                                                name="output_token")
            self.output_lens = tf.placeholder(dtype=tf.int32,
                                              shape=(None, ),
                                              name="output_lens")
            self.output_das = tf.placeholder(dtype=tf.float32,
                                             shape=(None, self.num_topics),
                                             name="output_dialog_acts")

            # optimization related variables
            self.learning_rate = tf.Variable(float(config.init_lr),
                                             trainable=False,
                                             name="learning_rate")
            self.learning_rate_decay_op = self.learning_rate.assign(
                tf.multiply(self.learning_rate, config.lr_decay))
            self.global_t = tf.placeholder(dtype=tf.int32, name="global_t")
            self.use_prior = tf.placeholder(dtype=tf.bool, name="use_prior")

        max_dialog_len = array_ops.shape(self.input_contexts)[1]
        max_out_len = array_ops.shape(self.output_tokens)[1]
        batch_size = array_ops.shape(self.input_contexts)[0]

        with variable_scope.variable_scope("wordEmbedding"):
            self.embedding = tf.get_variable(
                "embedding", [self.vocab_size, config.embed_size],
                dtype=tf.float32)
            embedding_mask = tf.constant(
                [0 if i == 0 else 1 for i in range(self.vocab_size)],
                dtype=tf.float32,
                shape=[self.vocab_size, 1])
            embedding = self.embedding * embedding_mask

            # embed the input
            input_embedding = embedding_ops.embedding_lookup(
                embedding, tf.reshape(self.input_contexts, [-1]))
            # reshape embedding. -1 means that the first dimension can be whatever necessary to make the other 2 dimensions work w/the data
            input_embedding = tf.reshape(
                input_embedding, [-1, self.max_utt_len, config.embed_size])
            # embed the output so you can feed it into the VAE
            output_embedding = embedding_ops.embedding_lookup(
                embedding, self.output_tokens)

            #
            if config.sent_type == "bow":
                input_embedding, sent_size = get_bow(input_embedding)
                output_embedding, _ = get_bow(output_embedding)

            elif config.sent_type == "rnn":
                sent_cell = self.get_rnncell("gru", self.sent_cell_size,
                                             config.keep_prob, 1)
                input_embedding, sent_size = get_rnn_encode(input_embedding,
                                                            sent_cell,
                                                            scope="sent_rnn")
                output_embedding, _ = get_rnn_encode(output_embedding,
                                                     sent_cell,
                                                     self.output_lens,
                                                     scope="sent_rnn",
                                                     reuse=True)
            elif config.sent_type == "bi_rnn":
                fwd_sent_cell = self.get_rnncell("gru",
                                                 self.sent_cell_size,
                                                 keep_prob=1.0,
                                                 num_layer=1)
                bwd_sent_cell = self.get_rnncell("gru",
                                                 self.sent_cell_size,
                                                 keep_prob=1.0,
                                                 num_layer=1)
                input_embedding, sent_size = get_bi_rnn_encode(
                    input_embedding,
                    fwd_sent_cell,
                    bwd_sent_cell,
                    scope="sent_bi_rnn")
                output_embedding, _ = get_bi_rnn_encode(output_embedding,
                                                        fwd_sent_cell,
                                                        bwd_sent_cell,
                                                        self.output_lens,
                                                        scope="sent_bi_rnn",
                                                        reuse=True)
            else:
                raise ValueError(
                    "Unknown sent_type. Must be one of [bow, rnn, bi_rnn]")

            # reshape input into dialogs
            input_embedding = tf.reshape(input_embedding,
                                         [-1, max_dialog_len, sent_size])
            if config.keep_prob < 1.0:
                input_embedding = tf.nn.dropout(input_embedding,
                                                config.keep_prob)

            # reshape floors
            floor = tf.reshape(self.floors, [-1, max_dialog_len, 1])

            joint_embedding = tf.concat([input_embedding, floor], 2,
                                        "joint_embedding")

        with variable_scope.variable_scope("contextRNN"):
            enc_cell = self.get_rnncell(config.cell_type,
                                        self.context_cell_size,
                                        keep_prob=1.0,
                                        num_layer=config.num_layer)
            # enc_last_state will be the same as the true last state
            _, enc_last_state = tf.nn.dynamic_rnn(
                enc_cell,
                joint_embedding,
                dtype=tf.float32,
                sequence_length=self.context_lens)

            if config.num_layer > 1:
                if config.cell_type == 'lstm':
                    enc_last_state = [temp.h for temp in enc_last_state]

                enc_last_state = tf.concat(enc_last_state, 1)
            else:
                if config.cell_type == 'lstm':
                    enc_last_state = enc_last_state.h

        # combine with other attributes
        if config.use_hcf:
            # TODO is this reshape ok?
            attribute_embedding = tf.reshape(
                self.output_das, [-1, self.num_topics])  # da_embedding
            attribute_fc1 = layers.fully_connected(attribute_embedding,
                                                   30,
                                                   activation_fn=tf.tanh,
                                                   scope="attribute_fc1")

        # conditions include topic and rnn of all previous birnn results and metadata about the two people
        cond_list = [self.paragraph_topics, enc_last_state]
        cond_embedding = tf.concat(cond_list, 1)  #float32

        with variable_scope.variable_scope("recognitionNetwork"):
            if config.use_hcf:
                recog_input = tf.concat(
                    [cond_embedding, output_embedding, attribute_fc1], 1)
            else:
                recog_input = tf.concat([cond_embedding, output_embedding], 1)
            self.recog_mulogvar = recog_mulogvar = layers.fully_connected(
                recog_input,
                config.latent_size * 2,
                activation_fn=None,
                scope="muvar")
            # mu and logvar are both vectors of size latent_size
            recog_mu, recog_logvar = tf.split(recog_mulogvar, 2, axis=1)

        with variable_scope.variable_scope("priorNetwork"):
            # P(XYZ)=P(Z|X)P(X)P(Y|X,Z)
            prior_fc1 = layers.fully_connected(cond_embedding,
                                               np.maximum(
                                                   config.latent_size * 2,
                                                   100),
                                               activation_fn=tf.tanh,
                                               scope="fc1")
            prior_mulogvar = layers.fully_connected(prior_fc1,
                                                    config.latent_size * 2,
                                                    activation_fn=None,
                                                    scope="muvar")
            prior_mu, prior_logvar = tf.split(prior_mulogvar, 2, axis=1)

            latent_sample = tf.cond(
                self.use_prior,
                lambda: sample_gaussian(prior_mu, prior_logvar),
                lambda: sample_gaussian(recog_mu, recog_logvar))

        with variable_scope.variable_scope("generationNetwork"):
            gen_inputs = tf.concat([cond_embedding, latent_sample],
                                   1)  #float32

            # BOW loss
            bow_fc1 = layers.fully_connected(gen_inputs,
                                             400,
                                             activation_fn=tf.tanh,
                                             scope="bow_fc1")
            if config.keep_prob < 1.0:
                bow_fc1 = tf.nn.dropout(bow_fc1, config.keep_prob)
            self.bow_logits = layers.fully_connected(bow_fc1,
                                                     self.vocab_size,
                                                     activation_fn=None,
                                                     scope="bow_project")

            # Predicting Y (topic)
            if config.use_hcf:
                meta_fc1 = layers.fully_connected(gen_inputs,
                                                  400,
                                                  activation_fn=tf.tanh,
                                                  scope="meta_fc1")
                if config.keep_prob < 1.0:
                    meta_fc1 = tf.nn.dropout(meta_fc1, config.keep_prob)
                self.da_logits = layers.fully_connected(
                    meta_fc1, self.num_topics, scope="da_project")  # float32

                da_prob = tf.nn.softmax(self.da_logits)
                pred_attribute_embedding = da_prob  # TODO change the name of this to predicted sentence topic
                # pred_attribute_embedding = tf.matmul(da_prob, d_embedding)

                if forward:
                    selected_attribute_embedding = pred_attribute_embedding
                else:
                    selected_attribute_embedding = attribute_embedding
                dec_inputs = tf.concat(
                    [gen_inputs, selected_attribute_embedding], 1)

            # if use_hcf not on, the model won't predict the Y
            else:
                self.da_logits = tf.zeros((batch_size, self.num_topics))
                dec_inputs = gen_inputs
                selected_attribute_embedding = None

            # Predicting whether or not end of paragraph
            self.paragraph_end_logits = layers.fully_connected(
                gen_inputs,
                1,
                activation_fn=tf.tanh,
                scope="paragraph_end_fc1")  # float32

            # Decoder
            if config.num_layer > 1:
                dec_init_state = []
                for i in range(config.num_layer):
                    temp_init = layers.fully_connected(dec_inputs,
                                                       self.dec_cell_size,
                                                       activation_fn=None,
                                                       scope="init_state-%d" %
                                                       i)
                    if config.cell_type == 'lstm':
                        # initializer thing for lstm
                        temp_init = rnn_cell.LSTMStateTuple(
                            temp_init, temp_init)

                    dec_init_state.append(temp_init)

                dec_init_state = tuple(dec_init_state)
            else:
                dec_init_state = layers.fully_connected(dec_inputs,
                                                        self.dec_cell_size,
                                                        activation_fn=None,
                                                        scope="init_state")
                if config.cell_type == 'lstm':
                    dec_init_state = rnn_cell.LSTMStateTuple(
                        dec_init_state, dec_init_state)

        with variable_scope.variable_scope("decoder"):
            dec_cell = self.get_rnncell(config.cell_type, self.dec_cell_size,
                                        config.keep_prob, config.num_layer)
            # projects into a vocab-size output. TODO no softmax?
            dec_cell = OutputProjectionWrapper(dec_cell, self.vocab_size)

            if forward:
                loop_func = decoder_fn_lib.context_decoder_fn_inference(
                    None,
                    dec_init_state,
                    embedding,
                    start_of_sequence_id=self.go_id,
                    end_of_sequence_id=self.eos_id,
                    maximum_length=self.max_utt_len,
                    num_decoder_symbols=self.vocab_size,
                    context_vector=selected_attribute_embedding)
                dec_input_embedding = None
                dec_seq_lens = None
            else:
                loop_func = decoder_fn_lib.context_decoder_fn_train(
                    dec_init_state, selected_attribute_embedding)
                dec_input_embedding = embedding_ops.embedding_lookup(
                    embedding, self.output_tokens)
                dec_input_embedding = dec_input_embedding[:, 0:-1, :]
                dec_seq_lens = self.output_lens - 1

                if config.keep_prob < 1.0:
                    dec_input_embedding = tf.nn.dropout(
                        dec_input_embedding, config.keep_prob)

                # apply word dropping. Set dropped word to 0
                if config.dec_keep_prob < 1.0:
                    # get mask of keep/throw-away
                    keep_mask = tf.less_equal(
                        tf.random_uniform((batch_size, max_out_len - 1),
                                          minval=0.0,
                                          maxval=1.0), config.dec_keep_prob)
                    keep_mask = tf.expand_dims(tf.to_float(keep_mask), 2)
                    dec_input_embedding = dec_input_embedding * keep_mask
                    dec_input_embedding = tf.reshape(
                        dec_input_embedding,
                        [-1, max_out_len - 1, config.embed_size])

            dec_outs, _, final_context_state = dynamic_rnn_decoder(
                dec_cell,
                loop_func,
                inputs=dec_input_embedding,
                sequence_length=dec_seq_lens,
                name='output_node')

            if final_context_state is not None:
                final_context_state = final_context_state[:, 0:array_ops.
                                                          shape(dec_outs)[1]]
                mask = tf.to_int32(tf.sign(tf.reduce_max(dec_outs, axis=2)))
                self.dec_out_words = tf.multiply(
                    tf.reverse(final_context_state, axis=[1]), mask)
            else:
                self.dec_out_words = tf.argmax(dec_outs, 2)

        if not forward:
            with variable_scope.variable_scope("loss"):

                labels = self.output_tokens[:, 1:]  # correct word tokens
                label_mask = tf.to_float(tf.sign(labels))

                # Loss between words
                rc_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=dec_outs, labels=labels)
                rc_loss = tf.reduce_sum(rc_loss * label_mask,
                                        reduction_indices=1)
                self.avg_rc_loss = tf.reduce_mean(rc_loss)
                # used only for perplexity calculation. Not used for optimization
                self.rc_ppl = tf.exp(
                    tf.reduce_sum(rc_loss) / tf.reduce_sum(label_mask))

                # BOW loss
                tile_bow_logits = tf.tile(tf.expand_dims(self.bow_logits, 1),
                                          [1, max_out_len - 1, 1])
                bow_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=tile_bow_logits, labels=labels) * label_mask
                bow_loss = tf.reduce_sum(bow_loss, reduction_indices=1)
                self.avg_bow_loss = tf.reduce_mean(bow_loss)

                # Predict 0/1 (1 = last sentence in paragraph)
                end_loss = tf.nn.softmax_cross_entropy_with_logits(
                    labels=self.floor_labels, logits=self.paragraph_end_logits)
                self.avg_end_loss = tf.reduce_mean(end_loss)

                # Topic prediction loss
                if config.use_hcf:
                    div_prob = tf.divide(self.da_logits, self.output_das)
                    self.avg_da_loss = tf.reduce_mean(
                        -tf.nn.softmax_cross_entropy_with_logits(
                            logits=self.da_logits, labels=div_prob))

                else:
                    self.avg_da_loss = 0.0

                kld = gaussian_kld(recog_mu, recog_logvar, prior_mu,
                                   prior_logvar)
                self.avg_kld = tf.reduce_mean(kld)
                if log_dir is not None:
                    kl_weights = tf.minimum(
                        tf.to_float(self.global_t) / config.full_kl_step, 1.0)
                else:
                    kl_weights = tf.constant(1.0)

                self.kl_w = kl_weights
                self.elbo = self.avg_rc_loss + kl_weights * self.avg_kld
                aug_elbo = self.avg_bow_loss + self.avg_da_loss + self.elbo + self.avg_end_loss

                tf.summary.scalar("da_loss", self.avg_da_loss)
                tf.summary.scalar("rc_loss", self.avg_rc_loss)
                tf.summary.scalar("elbo", self.elbo)
                tf.summary.scalar("kld", self.avg_kld)
                tf.summary.scalar("bow_loss", self.avg_bow_loss)
                tf.summary.scalar("paragraph_end_loss", self.avg_end_loss)

                self.summary_op = tf.summary.merge_all()

                self.log_p_z = norm_log_liklihood(latent_sample, prior_mu,
                                                  prior_logvar)
                self.log_q_z_xy = norm_log_liklihood(latent_sample, recog_mu,
                                                     recog_logvar)
                self.est_marginal = tf.reduce_mean(rc_loss + bow_loss -
                                                   self.log_p_z +
                                                   self.log_q_z_xy)

            self.optimize(sess, config, aug_elbo, log_dir)

        self.saver = tf.train.Saver(tf.global_variables(),
                                    write_version=tf.train.SaverDef.V2)
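All of the dialog models above anneal the KL weight linearly from 0 to 1 over config.full_kl_step training steps (kl_weights = tf.minimum(global_t / full_kl_step, 1.0)). The same schedule in plain Python, for reference:

def kl_anneal_weight(global_t, full_kl_step):
    # Linear KL annealing: ramp the weight on the KL term from 0 to 1 over
    # `full_kl_step` steps, then hold it at 1.
    return min(float(global_t) / float(full_kl_step), 1.0)

# kl_anneal_weight(500, 10000)   -> 0.05
# kl_anneal_weight(20000, 10000) -> 1.0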