Example #1
    def __init__(self, obs_dim, action_dim, hidden_dim, hidden_depth):
        super().__init__()

        self.Q1 = utils.mlp(obs_dim + action_dim, hidden_dim, 1, hidden_depth)
        self.Q2 = utils.mlp(obs_dim + action_dim, hidden_dim, 1, hidden_depth)
        self.encoder = None
        self.outputs = dict()
        self.apply(utils.weight_init)
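
The PyTorch snippets on this page (Examples #1 through #8, #11, and #14 through #19) all rely on a utils.mlp(input_dim, hidden_dim, output_dim, hidden_depth) helper that is not shown here. A minimal sketch of what such a helper commonly looks like, offered as an assumption rather than any repository's actual code:

import torch.nn as nn

def mlp(input_dim, hidden_dim, output_dim, hidden_depth, output_mod=None):
    # hidden_depth == 0 degenerates to a single linear layer
    if hidden_depth == 0:
        mods = [nn.Linear(input_dim, output_dim)]
    else:
        mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
        for _ in range(hidden_depth - 1):
            mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)]
        mods.append(nn.Linear(hidden_dim, output_dim))
    if output_mod is not None:
        mods.append(output_mod)
    return nn.Sequential(*mods)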
Example #2
    def __init__(self, obs_dim, action_dim, hidden_channels):
        super().__init__()

        self.Q1 = utils.mlp(obs_dim + action_dim, hidden_channels, 1)
        self.Q2 = utils.mlp(obs_dim + action_dim, hidden_channels, 1)

        self.outputs = dict()
        self.apply(utils.weight_init)
Example #3
    def __init__(self, fusion_dim, action_dim, hidden_dim, hidden_depth):
        super().__init__()

        self.Q1 = utils.mlp(fusion_dim + action_dim, hidden_dim, 1,
                            hidden_depth)
        self.Q2 = utils.mlp(fusion_dim + action_dim, hidden_dim, 1,
                            hidden_depth)

        self.outputs = dict()
        self.apply(utils.weight_init)
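
The constructors above only build the two Q heads; the forward pass (not shown on this page) conventionally concatenates observation and action and evaluates both heads. A hedged sketch of that step, assuming the double-Q layout of Examples #1 through #3:

import torch

def double_q_forward(Q1, Q2, obs, action):
    # hypothetical helper mirroring a typical forward() of the critics above
    obs_action = torch.cat([obs, action], dim=-1)
    return Q1(obs_action), Q2(obs_action)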
Example #4
    def __init__(self, encoder_cfg, action_shape, hidden_dim, hidden_depth):
        super().__init__()

        self.encoder = hydra.utils.instantiate(encoder_cfg)

        self.Q1 = utils.mlp(self.encoder.feature_dim + action_shape[0],
                            hidden_dim, 1, hidden_depth)
        self.Q2 = utils.mlp(self.encoder.feature_dim + action_shape[0],
                            hidden_dim, 1, hidden_depth)

        self.outputs = dict()
        self.apply(utils.weight_init)
Example #5
    def __init__(self, repr_dim, feature_dim, action_shape, hidden_dim,
                 hidden_depth):
        super().__init__()

        self.pre_fc = nn.Sequential(nn.Linear(repr_dim, feature_dim),
                                    nn.LayerNorm(feature_dim))
        self.Q1 = utils.mlp(feature_dim + action_shape[0], hidden_dim, 1,
                            hidden_depth)
        self.Q2 = utils.mlp(feature_dim + action_shape[0], hidden_dim, 1,
                            hidden_depth)

        self.apply(utils.weight_init)
Example #6
    def __init__(self, obs_shape, feature_dim, action_shape, hidden_dim,
                 hidden_depth):
        super().__init__()

        self.encoder = Encoder(obs_shape, feature_dim)

        self.Q1 = utils.mlp(self.encoder.feature_dim + action_shape[0],
                            hidden_dim, 1, hidden_depth)
        self.Q2 = utils.mlp(self.encoder.feature_dim + action_shape[0],
                            hidden_dim, 1, hidden_depth)

        self.outputs = dict()
        self.apply(utils.weight_init)
Example #7
    def __init__(self, encoder_cfg, action_shape, state_input_shape, hidden_dim, hidden_depth, cat_fc_size):
        super().__init__()

        self.encoder = hydra.utils.instantiate(encoder_cfg)

        self.Q1_state = utils.mlp(state_input_shape + action_shape[0], hidden_dim, hidden_dim, 1, nn.Tanh(), nn.Tanh())
        self.Q2_state = utils.mlp(state_input_shape + action_shape[0], hidden_dim, hidden_dim, 1, nn.Tanh(), nn.Tanh())

        self.Q1 = utils.mlp(self.encoder.feature_dim + hidden_dim,
                            cat_fc_size, 1, 1, activation=nn.Tanh())
        self.Q2 = utils.mlp(self.encoder.feature_dim + hidden_dim,
                            cat_fc_size, 1, 1, activation=nn.Tanh())

        self.outputs = dict()
        self.apply(utils.weight_init)
Example #8
    def __init__(self, encoder_cfg, action_shape, state_input_shape, hidden_dim, hidden_depth, cat_fc_size,
                 log_std_bounds):
        super().__init__()

        self.encoder = hydra.utils.instantiate(encoder_cfg)

        self.log_std_bounds = log_std_bounds

        self.state_input = utils.mlp(state_input_shape, hidden_dim, hidden_dim, 1, nn.Tanh(), nn.Tanh())

        self.trunk = utils.mlp(self.encoder.feature_dim + hidden_dim, cat_fc_size,
                               2 * action_shape[0], 1, activation=nn.Tanh())

        self.outputs = dict()
        self.apply(utils.weight_init)
Example #9
def ddpg_mlp_actor_critic(obs, 
                          a, 
                          act_dim,
                          act_lim,
                          hidden_size=(256, 256), 
                          activation=tf.nn.relu,
                          out_activation=tf.nn.tanh,
                          ):
    with tf.variable_scope('pi'):
        pi = act_lim * mlp(obs, list(hidden_size)+[act_dim], activation, out_activation)
    with tf.variable_scope('q'):
        q = tf.squeeze(mlp(tf.concat([obs, a], -1), list(hidden_size)+[1], activation), -1)
    with tf.variable_scope('q', reuse=True):
        q_pi = tf.squeeze(mlp(tf.concat([obs, pi], -1), list(hidden_size)+[1], activation), -1)
    return pi, q, q_pi
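
Example #9 is a TF1-style actor-critic builder. A hedged usage sketch under assumed placeholder shapes; obs_dim, act_dim, and act_lim are illustrative values (not from the original source), and the mlp helper it calls internally is assumed to be importable:

import tensorflow as tf

obs_dim, act_dim, act_lim = 17, 6, 1.0  # illustrative dimensions only
obs_ph = tf.placeholder(tf.float32, shape=(None, obs_dim), name='obs')
act_ph = tf.placeholder(tf.float32, shape=(None, act_dim), name='act')
with tf.variable_scope('main'):
    pi, q, q_pi = ddpg_mlp_actor_critic(obs_ph, act_ph, act_dim, act_lim)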
Example #10
    def __init__(self, from_images, observation_space, action_space,
                 feature_dim, hidden_sizes):
        super().__init__()

        if from_images:
            # images
            self.encoder = ConvolutionalEncoder(observation_space, feature_dim)
        else:
            # states
            self.encoder = ConcatenationEncoder(observation_space)

        self.Q1 = utils.mlp(self.encoder.output_dim + action_space.shape[0], 1,
                            hidden_sizes)
        self.Q2 = utils.mlp(self.encoder.output_dim + action_space.shape[0], 1,
                            hidden_sizes)
Example #11
    def __init__(self, obs_shape, action_shape, hidden_dim, hidden_depth):
        super().__init__()

        self.Q1 = utils.mlp(obs_shape[0] + action_shape[0],
                            hidden_dim,
                            1,
                            hidden_depth,
                            use_ln=True)
        self.Q2 = utils.mlp(obs_shape[0] + action_shape[0],
                            hidden_dim,
                            1,
                            hidden_depth,
                            use_ln=True)

        self.outputs = dict()
        self.apply(utils.weight_init)
Example #12
    def __init__(
        self,
        obs_dim,
        act_dim,
        hidden_sizes,
        activation=nn.relu,
        output_activation=nn.relu,  # DEPLOY: Put back to identity when deploy
        name="lyapunov_critic",
        **kwargs,
    ):
        """Constructs all the necessary attributes for the Soft Q critic object.
        Args:
            obs_dim (int): Dimension of the observation space.

            act_dim (int): Dimension of the action space.

            hidden_sizes (list): Sizes of the hidden layers.

            activation (function): The hidden layer activation function.

            output_activation (function, optional): The activation function used for
                the output layers. Defaults to tf.keras.activations.linear.

            name (str, optional): The Lyapunov critic name. Defaults to
                "lyapunov_critic".
        """
        super().__init__(name=name, **kwargs)
        self.lya = mlp(
            [obs_dim + act_dim] + list(hidden_sizes),
            activation,
            output_activation,
            name=name,
        )
Example #13
    def __init__(
            self,
            obs_dim,
            act_dim,
            hidden_sizes,
            activation=nn.ReLU,
            output_activation=nn.ReLU,  # DEPLOY: Put back to identity when deploy
    ):
        """Constructs all the necessary attributes for the Soft Q critic object.

        Args:
            obs_dim (int): Dimension of the observation space.

            act_dim (int): Dimension of the action space.

            hidden_sizes (list): Sizes of the hidden layers.

            activation (torch.nn.modules.activation): The hidden layer activation
                function.

            output_activation (torch.nn.modules.activation, optional): The activation
                function used for the output layers. Defaults to torch.nn.Identity.
        """
        super().__init__()
        self.lya = mlp([obs_dim + act_dim] + list(hidden_sizes), activation,
                       output_activation)
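
The two Lyapunov-critic constructors above (and the actor in Example #22 below) pass a full list of layer widths to mlp instead of a (hidden_dim, hidden_depth) pair. A plausible Spinning-Up-style helper with that signature, stated as an assumption (shown in PyTorch; the first constructor above is TensorFlow-based and would need the Keras analogue):

import torch.nn as nn

def mlp(sizes, activation, output_activation=nn.Identity):
    # sizes is the complete list of layer widths, e.g. [obs_dim + act_dim, 256, 256]
    layers = []
    for j in range(len(sizes) - 1):
        act = activation if j < len(sizes) - 2 else output_activation
        layers += [nn.Linear(sizes[j], sizes[j + 1]), act()]
    return nn.Sequential(*layers)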
Example #14
    def __init__(self, obs_dim, action_dim, hidden_dim, hidden_depth,
                 log_std_bounds):
        super().__init__()

        self.log_std_bounds = log_std_bounds
        self.trunk = utils.mlp(obs_dim, hidden_dim, 2 * action_dim,
                               hidden_depth)
        self.encoder = None
        self.outputs = dict()
        self.apply(utils.weight_init)
Example #15
    def __init__(self, repr_dim, feature_dim, action_shape, hidden_dim,
                 hidden_depth, log_std_bounds):
        super().__init__()

        self.log_std_bounds = log_std_bounds
        self.pre_fc = nn.Sequential(nn.Linear(repr_dim, feature_dim),
                                    nn.LayerNorm(feature_dim))
        self.fc = utils.mlp(feature_dim, hidden_dim, 2 * action_shape[0],
                            hidden_depth)

        self.apply(utils.weight_init)
Example #16
    def __init__(self, fusion_dim,
                 action_dim, hidden_dim, hidden_depth,
                 log_std_bounds):
        super().__init__()

        self.log_std_bounds = log_std_bounds

        self.trunk = utils.mlp(fusion_dim, hidden_dim, 2 * action_dim,
                               hidden_depth)
        # print(self.trunk)
        self.outputs = dict()
        self.apply(utils.weight_init)
Example #17
    def __init__(self, obs_shape, feature_dim, action_shape, hidden_dim,
                 hidden_depth, log_std_bounds):
        super().__init__()

        self.encoder = Encoder(obs_shape, feature_dim)

        self.log_std_bounds = log_std_bounds
        self.trunk = utils.mlp(self.encoder.feature_dim, hidden_dim,
                               2 * action_shape[0], hidden_depth)

        self.outputs = dict()
        self.apply(utils.weight_init)
Example #18
    def __init__(self, encoder_cfg, action_shape, hidden_dim, hidden_depth,
                 log_std_bounds):
        super().__init__()

        self.encoder = hydra.utils.instantiate(encoder_cfg)

        self.log_std_bounds = log_std_bounds
        self.trunk = utils.mlp(self.encoder.feature_dim, hidden_dim,
                               2 * action_shape[0], hidden_depth)

        self.outputs = dict()
        self.apply(utils.weight_init)
Example #19
    def __init__(self, obs_dim, action_dim, hidden_channels,
                 log_std_bounds):
        super().__init__()

        # self.log_std_bounds = log_std_bounds
        log_std_min, log_std_max = log_std_bounds
        # conversion x:[-1,1]->y:[log_std_min, log_std_max] with y = k*x+b
        self.k = (log_std_max - log_std_min)/2.
        self.b = log_std_min + self.k
        
        self.trunk = utils.mlp(obs_dim, hidden_channels, 2 * action_dim)

        self.outputs = dict()
        self.apply(utils.weight_init)
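
Several actor examples (#14 through #19) let the trunk emit 2 * action_dim values that are later split into a mean and a bounded log standard deviation; Example #19 precomputes the k and b coefficients for that bounding. A hedged standalone restatement of the split-and-rescale step (a hypothetical helper, not taken from the original repository):

import torch

def split_and_bound(trunk_out, log_std_min, log_std_max):
    # split a [batch, 2 * action_dim] trunk output into mean and std,
    # mapping a tanh-squashed log-std from [-1, 1] to [log_std_min, log_std_max]
    mu, log_std = trunk_out.chunk(2, dim=-1)
    k = (log_std_max - log_std_min) / 2.0
    b = log_std_min + k
    log_std = k * torch.tanh(log_std) + b
    return mu, log_std.exp()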
Example #20
    def __init__(self, obs_shape, action_shape, hidden_dim, hidden_depth,
                 stddev, parameterization):
        super().__init__()

        assert parameterization in ['clipped', 'squashed']
        self.stddev = stddev
        self.dist_type = utils.SquashedNormal if parameterization == 'squashed' else utils.ClippedNormal

        self.trunk = utils.mlp(obs_shape[0],
                               hidden_dim,
                               action_shape[0],
                               hidden_depth,
                               use_ln=True)

        self.outputs = dict()
        self.apply(utils.weight_init)
Example #21
    def __init__(
            self,
            from_images,
            observation_space,
            action_space,
            feature_dim,  # ignored for state-based
            hidden_sizes,
            log_std_bounds):
        super().__init__()

        if from_images:
            self.encoder = ConvolutionalEncoder(observation_space, feature_dim)
        else:
            self.encoder = ConcatenationEncoder(observation_space)

        self.log_std_bounds = log_std_bounds
        self.trunk = utils.mlp(
            self.encoder.output_dim, 2 * action_space.shape[0], hidden_sizes
        )  # twice the dimensions for outputs to split into mu and log_sigma
Example #22
    def __init__(
        self,
        obs_dim,
        act_dim,
        hidden_sizes,
        act_limits=None,
        activation=nn.ReLU,
        output_activation=nn.ReLU,
    ):
        """Constructs all the necessary attributes for the Squashed Gaussian Actor
        object.

        Args:
            obs_dim (int): Dimension of the observation space.

            act_dim (int): Dimension of the action space.

            hidden_sizes (list): Sizes of the hidden layers.

            activation (torch.nn.modules.activation): The hidden layer activation
                function.

            output_activation (torch.nn.modules.activation, optional): The activation
                function used for the output layers. Defaults to torch.nn.Identity.

            act_limits (dict or , optional): The "high" and "low" action bounds of the
                environment. Used for rescaling the actions that come out of the network
                from (-1, 1) to (low, high). Defaults to (-1, 1).
        """
        super().__init__()
        # Set class attributes
        self.act_limits = act_limits

        # Create networks
        self.net = mlp([obs_dim] + list(hidden_sizes), activation,
                       output_activation)
        self.mu = nn.Linear(hidden_sizes[-1], act_dim)
        self.log_sigma = nn.Linear(hidden_sizes[-1], act_dim)
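
Example #22's actor produces separate mu and log_sigma heads; the sampling step that usually follows (reparameterized Gaussian, tanh squash, rescale to the action bounds) is not shown on this page. A hedged sketch of that step, with the log-sigma clamping range chosen here as an assumption:

import torch
from torch.distributions.normal import Normal

def squashed_sample(mu, log_sigma, act_low, act_high):
    # reparameterized sample, squashed to (-1, 1), then rescaled to (act_low, act_high)
    sigma = torch.exp(torch.clamp(log_sigma, -20.0, 2.0))  # clamp range is an assumption
    a = torch.tanh(Normal(mu, sigma).rsample())
    return act_low + 0.5 * (a + 1.0) * (act_high - act_low)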
Example #23
    def __init__(self, vocab_size, mlp_arr, n_topic, learning_rate, batch_size,
                 non_linearity, adam_beta1, adam_beta2, dir_prior, n_class, N,
                 seed):
        np.random.seed(seed)
        tf.set_random_seed(seed)

        tf.reset_default_graph()
        self.vocab_size = vocab_size
        self.n_hidden = mlp_arr[0]
        self.n_topic = n_topic
        self.non_linearity = non_linearity
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.n_class = n_class

        self.y = tf.placeholder(tf.float32, [None, n_class], name='input_y')
        self.lab = tf.placeholder(tf.float32, [1], name='input_with_lab')
        self.x = tf.placeholder(tf.float32, [None, vocab_size], name='input')
        self.idx = tf.placeholder(tf.int32, [1], name="index")
        self.mask = tf.placeholder(tf.float32, [None],
                                   name='mask')  # mask paddings
        self.warm_up = tf.placeholder(tf.float32, (),
                                      name='warm_up')  # warm up
        self.training = tf.placeholder(tf.bool, (), name="training")
        self.adam_beta1 = adam_beta1
        self.adam_beta2 = adam_beta2
        self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')
        self.prob = tf.placeholder(tf.float32, name='prob')

        self.min_alpha = tf.placeholder(tf.float32, (), name='min_alpha')
        back_mlp = False
        cat_distribution = False
        use_gaus = False

        # encoder
        with tf.variable_scope('encoder'):
            self.enc_input = tf.concat([self.x, self.y], axis=1)
            self.enc_vec = utils.mlp(self.enc_input, mlp_arr,
                                     self.non_linearity)

            self.enc_vec = tf.nn.dropout(self.enc_vec, self.keep_prob)
            self.mean = tf.contrib.layers.batch_norm(utils.linear(
                self.enc_vec, self.n_topic, scope='mean'),
                                                     is_training=self.training)
            if use_gaus:
                self.logsigm = tf.contrib.layers.batch_norm(
                    utils.linear(self.enc_vec, self.n_topic, scope='logsigm'),
                    is_training=self.training)

                self.kld = -0.5 * tf.reduce_sum(
                    1 - tf.square(self.mean) + 2 * self.logsigm -
                    tf.exp(2 * self.logsigm), 1)
                self.analytical_kld = self.mask * self.kld  # mask paddings
            else:

                self.alpha = tf.maximum(self.min_alpha,
                                        tf.log(1. + tf.exp(self.mean)))

                self.prior = tf.ones(
                    (batch_size, self.n_topic), dtype=tf.float32,
                    name='prior') * dir_prior

                self.analytical_kld = tf.lgamma(
                    tf.reduce_sum(self.alpha, axis=1)) - tf.lgamma(
                        tf.reduce_sum(self.prior, axis=1))
                self.analytical_kld -= tf.reduce_sum(tf.lgamma(self.alpha),
                                                     axis=1)
                self.analytical_kld += tf.reduce_sum(tf.lgamma(self.prior),
                                                     axis=1)
                minus = self.alpha - self.prior
                # test = tf.reduce_sum(minus,1)
                test = tf.reduce_sum(
                    tf.multiply(
                        minus,
                        tf.digamma(self.alpha) -
                        tf.reshape(tf.digamma(tf.reduce_sum(self.alpha, 1)),
                                   (batch_size, 1))), 1)
                self.analytical_kld += test
                self.analytical_kld = self.mask * self.analytical_kld  # mask paddings

            self.clss_mlp = utils.mlp(self.x,
                                      mlp_arr,
                                      self.non_linearity,
                                      scope="classifier_mlp")
            self.clss_mlp = tf.nn.dropout(self.clss_mlp, self.prob)
            self.phi = tf.contrib.layers.batch_norm(
                utils.linear(self.clss_mlp, n_class, scope='phi'),
                is_training=self.training)  #y logits

            # class probabilities
            one_hot_dist = tfd.OneHotCategorical(logits=self.phi)
            hot_out = tf.squeeze(one_hot_dist.sample(1))
            hot_out.set_shape(self.phi.get_shape())
            self.dec_input = tf.cast(hot_out, dtype=tf.float32)

            self.out_y = tf.nn.softmax(self.phi, name="probabilities_y")
            # do NOT use output of softmax here!
            self.clss_loss = tf.nn.softmax_cross_entropy_with_logits_v2(
                labels=self.y, logits=self.phi) * 0.1 * N * self.mask

        with tf.variable_scope('decoder'):
            with tf.variable_scope('prob'):
                # Dirichlet
                if use_gaus:
                    eps = tf.random_normal((batch_size, self.n_topic), 0, 1)
                    self.doc_vec = tf.multiply(tf.exp(self.logsigm),
                                               eps) + self.mean
                else:
                    self.doc_vec = tf.squeeze(
                        tfd.Dirichlet(
                            self.alpha).sample(1))  # tf.shape(self.alpha)
                    self.doc_vec.set_shape(self.alpha.get_shape())
                    dir1 = tf.contrib.distributions.Dirichlet(self.prior)
                    dir2 = tf.contrib.distributions.Dirichlet(self.alpha)

                    self.kld = dir2.log_prob(self.doc_vec) - dir1.log_prob(
                        self.doc_vec)
            # reconstruction
            if cat_distribution:
                self.merge = tf.cond(
                    (tf.reduce_sum(self.lab)) > 0,
                    lambda: tf.concat([self.doc_vec, self.y], axis=1),
                    lambda: tf.concat([self.doc_vec, self.dec_input], axis=1))
            else:
                self.merge = tf.cond(
                    (tf.reduce_sum(self.lab)) > 0,
                    lambda: tf.concat([self.doc_vec, self.y], axis=1),
                    lambda: tf.concat([self.doc_vec, self.out_y], axis=1))

            if not back_mlp:
                logits = tf.nn.log_softmax(
                    tf.contrib.layers.batch_norm(utils.linear(
                        self.merge,
                        self.vocab_size,
                        scope='projection',
                        no_bias=True),
                                                 is_training=self.training))
            else:
                # this might
                logits = tf.nn.log_softmax(
                    utils.mlp(self.merge,
                              list(reversed(mlp_arr)) + [self.vocab_size],
                              scope='projection'))

            self.recons_loss = -tf.reduce_sum(tf.multiply(logits, self.x), 1)

        self.min_l = self.recons_loss + self.warm_up * self.analytical_kld
        self.min_l_analytical = self.recons_loss + self.analytical_kld
        self.out_y_col = tf.transpose(
            tf.gather(tf.transpose(self.out_y), indices=self.idx))
        self.cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(
            logits=self.phi, labels=self.out_y) * self.mask

        self.objective = tf.cond((tf.reduce_sum(self.lab)) > 0,
                                 lambda: self.min_l + self.clss_loss,
                                 lambda: self.min_l + self.cross_entropy
                                 )  #tf.multiply(self.min_l,self.out_y_col)

        self.analytical_objective = tf.cond(
            (tf.reduce_sum(self.lab)) > 0,
            lambda: self.min_l_analytical + self.clss_loss,
            lambda: self.min_l_analytical + self.cross_entropy
        )  #tf.multiply(self.min_l_analytical,self.out_y_col)

        fullvars = tf.trainable_variables()

        enc_vars = utils.variable_parser(fullvars, 'encoder')
        dec_vars = utils.variable_parser(fullvars, 'decoder')

        # this is the standard gradient for the reconstruction network
        dec_grads = tf.gradients(self.objective, dec_vars)
        enc_grads = tf.gradients(self.objective, enc_vars)

        optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate,
                                           beta1=self.adam_beta1,
                                           beta2=self.adam_beta2)
        self.optim_enc = optimizer.apply_gradients(zip(enc_grads, enc_vars))
        self.optim_dec = optimizer.apply_gradients(zip(dec_grads, dec_vars))

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            self.optim_all = optimizer.apply_gradients(
                list(zip(enc_grads, enc_vars)) +
                list(zip(dec_grads, dec_vars)))
            #self.optim_all=optimizer.minimize(self.objective)
        self.merged = tf.summary.merge_all()
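
The analytical_kld terms assembled above are the closed-form KL divergence between two Dirichlet distributions, KL(Dir(alpha) || Dir(prior)). A compact restatement as a standalone TF1-style function, mirroring the individual terms in the code (a sketch, not part of the original model):

import tensorflow as tf

def dirichlet_kl(alpha, prior):
    # KL(Dir(alpha) || Dir(prior)) computed per row of [batch, n_topic] tensors
    kld = tf.lgamma(tf.reduce_sum(alpha, 1)) - tf.lgamma(tf.reduce_sum(prior, 1))
    kld -= tf.reduce_sum(tf.lgamma(alpha), 1)
    kld += tf.reduce_sum(tf.lgamma(prior), 1)
    kld += tf.reduce_sum(
        (alpha - prior) *
        (tf.digamma(alpha) -
         tf.reshape(tf.digamma(tf.reduce_sum(alpha, 1)), (-1, 1))), 1)
    return kld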
Example #24
    def __init__(self, vocab_size, n_hidden, n_topic, n_sample, learning_rate,
                 batch_size, non_linearity):
        self.vocab_size = vocab_size
        self.n_hidden = n_hidden
        self.n_topic = n_topic
        self.n_sample = n_sample
        self.non_linearity = non_linearity
        self.learning_rate = learning_rate
        self.batch_size = batch_size

        self.x = tf.placeholder(tf.float32, [None, vocab_size], name='input')
        self.mask = tf.placeholder(tf.float32, [None],
                                   name='mask')  # mask paddings

        # encoder
        with tf.variable_scope('encoder'):
            self.enc_vec = utils.mlp(self.x, [self.n_hidden],
                                     self.non_linearity)
            self.mean = utils.linear(self.enc_vec, self.n_topic, scope='mean')
            self.logsigm = utils.linear(self.enc_vec,
                                        self.n_topic,
                                        bias_start_zero=True,
                                        matrix_start_zero=True,
                                        scope='logsigm')
            self.kld = -0.5 * tf.reduce_sum(
                1 - tf.square(self.mean) + 2 * self.logsigm -
                tf.exp(2 * self.logsigm), 1)
            self.kld = self.mask * self.kld  # mask paddings

        with tf.variable_scope('decoder'):
            if self.n_sample == 1:  # single sample
                eps = tf.random_normal((batch_size, self.n_topic), 0, 1)
                doc_vec = tf.multiply(tf.exp(self.logsigm), eps) + self.mean
                logits = tf.nn.log_softmax(
                    utils.linear(doc_vec, self.vocab_size, scope='projection'))
                self.recons_loss = -tf.reduce_sum(tf.multiply(logits, self.x),
                                                  1)
            # multiple samples
            else:
                eps = tf.random_normal(
                    (self.n_sample * batch_size, self.n_topic), 0, 1)
                eps_list = tf.split(0, self.n_sample, eps)
                recons_loss_list = []
                for i in range(self.n_sample):
                    if i > 0: tf.get_variable_scope().reuse_variables()
                    curr_eps = eps_list[i]
                    doc_vec = tf.multiply(tf.exp(self.logsigm),
                                          curr_eps) + self.mean
                    logits = tf.nn.log_softmax(
                        utils.linear(doc_vec,
                                     self.vocab_size,
                                     scope='projection'))
                    recons_loss_list.append(
                        -tf.reduce_sum(tf.multiply(logits, self.x), 1))
                self.recons_loss = tf.add_n(recons_loss_list) / self.n_sample

        self.objective = self.recons_loss + self.kld

        optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
        fullvars = tf.trainable_variables()

        enc_vars = utils.variable_parser(fullvars, 'encoder')
        dec_vars = utils.variable_parser(fullvars, 'decoder')

        enc_grads = tf.gradients(self.objective, enc_vars)
        dec_grads = tf.gradients(self.objective, dec_vars)

        self.optim_enc = optimizer.apply_gradients(zip(enc_grads, enc_vars))
        self.optim_dec = optimizer.apply_gradients(zip(dec_grads, dec_vars))
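
The TensorFlow examples call a different utils.mlp that operates eagerly on a tensor, e.g. utils.mlp(self.x, [self.n_hidden], self.non_linearity). A minimal TF1-style sketch assuming that (inputs, layer_sizes, activation) call pattern (hypothetical, not the repository's actual utils module):

import tensorflow as tf

def mlp(inputs, mlp_hidden, activation, scope='mlp'):
    # stack of dense layers applied to inputs, one per entry of mlp_hidden
    with tf.variable_scope(scope):
        out = inputs
        for i, size in enumerate(mlp_hidden):
            out = activation(tf.layers.dense(out, size, name='layer_{}'.format(i)))
        return out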
Example #25
    def __init__(self,
                 vocab_size,
                 n_hidden,
                 n_topic, 
                 n_sample,
                 learning_rate, 
                 batch_size,
                 non_linearity,
                 constrained,
                 adam_beta1,
                 adam_beta2,
                 B,
                 dir_prior,
                 correction):
        tf.reset_default_graph()
        self.vocab_size = vocab_size
        self.n_hidden = n_hidden
        self.n_topic = n_topic
        self.n_sample = n_sample
        self.non_linearity = non_linearity
        self.learning_rate = learning_rate
        self.batch_size = batch_size

        lda=False
        self.x = tf.placeholder(tf.float32, [None, vocab_size], name='input')
        self.mask = tf.placeholder(tf.float32, [None], name='mask')  # mask paddings
        self.warm_up = tf.placeholder(tf.float32, (), name='warm_up')  # warm up
        self.B=tf.placeholder(tf.int32, (), name='B')
        self.adam_beta1=adam_beta1
        self.adam_beta2=adam_beta2
        self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')
        self.min_alpha = tf.placeholder(tf.float32,(), name='min_alpha')
        self.constrained = constrained
        # encoder
        with tf.variable_scope('encoder'): 
          self.enc_vec = utils.mlp(self.x, [self.n_hidden], self.non_linearity)
          self.enc_vec = tf.nn.dropout(self.enc_vec,self.keep_prob)
          self.mean = tf.contrib.layers.batch_norm(utils.linear(self.enc_vec, self.n_topic, scope='mean'))
          zero = tf.constant(0, dtype=tf.float32)

          self.bernoulli = diff_round(tf.nn.sigmoid(tf.contrib.layers.batch_norm(utils.linear(self.enc_vec, self.n_topic, scope='bernoulli'))))
          tf.summary.histogram('mean', self.mean)
          if constrained:
            self.alpha =tf.maximum(self.mean,1e-2)
          else:
            self.alpha = tf.maximum(0.01,tf.log(1.+tf.exp(self.mean)))
          
          #Dirichlet prior alpha0
          self.prior = tf.ones((batch_size,self.n_topic), dtype=tf.float32, name='prior')*dir_prior
         
          self.analytical_kld = tf.lgamma(tf.reduce_sum(self.bernoulli*self.alpha,axis=1)+1e-10)-tf.lgamma(tf.reduce_sum(self.bernoulli*self.prior,axis=1)+1e-10)
          self.analytical_kld-=tf.reduce_sum(self.bernoulli*tf.lgamma(self.alpha),axis=1)
          self.analytical_kld+=tf.reduce_sum(self.bernoulli*tf.lgamma(self.prior),axis=1)
          minus = self.alpha-self.prior
          test = tf.reduce_sum(tf.multiply(self.bernoulli*minus,self.bernoulli*tf.digamma(self.alpha)-tf.reshape(tf.digamma(tf.reduce_sum(self.alpha*self.bernoulli,1)+1e-10),(batch_size,1))),1)
          self.analytical_kld+=test
          self.analytical_kld = self.mask*self.analytical_kld  # mask paddings
          
        with tf.variable_scope('decoder'):
          if self.n_sample ==1:  # single sample
            #sample gammas
            gam = tf.squeeze(tf.random_gamma(shape = (1,),alpha=self.alpha+tf.to_float(self.B)))
            #reverse engineer the random variables used in the gamma rejection sampler
            eps = tf.stop_gradient(calc_epsilon(gam,self.alpha+tf.to_float(self.B)))
            #uniform variables for shape augmentation of gamma
            u = tf.random_uniform((self.B,batch_size,self.n_topic))
            with tf.variable_scope('prob'):
                #this is the sampled gamma for this document, boosted to reduce the variance of the gradient
                self.doc_vec = self.bernoulli*gamma_h_boosted(eps,u,self.alpha,self.B)
                #normalize
                self.doc_vec = tf.div(self.doc_vec,tf.reshape(tf.reduce_sum(self.doc_vec,1)+1e-10, (-1, 1)))
                tf.summary.histogram('doc_vec', self.doc_vec)
                self.doc_vec.set_shape(self.alpha.get_shape())
            #reconstruction
            if lda:
              logits = tf.log(tf.clip_by_value(utils.linear_LDA(self.doc_vec, self.vocab_size, scope='projection',no_bias=True),1e-10,1.0))
            else:
              logits = tf.nn.log_softmax(tf.contrib.layers.batch_norm(utils.linear(self.doc_vec, self.vocab_size, scope='projection',no_bias=True)))
            self.recons_loss = -tf.reduce_sum(tf.multiply(logits, self.x), 1)
            self.kld = log_dirichlet(self.doc_vec,self.alpha,self.bernoulli)-log_dirichlet(self.doc_vec,self.prior,self.bernoulli)
          # multiple samples
          else:
            gam = tf.squeeze(tf.random_gamma(shape = (self.n_sample,),alpha=self.alpha+tf.to_float(self.B)))
            u = tf.random_uniform((self.n_sample,self.B,batch_size,self.n_topic))
            recons_loss_list = []
            kld_list = []
            for i in range(self.n_sample):
              if i > 0: tf.get_variable_scope().reuse_variables()
              curr_gam = gam[i]
              eps = tf.stop_gradient(calc_epsilon(curr_gam,self.alpha+tf.to_float(self.B)))
              curr_u = u[i]
              self.doc_vec = self.bernoulli*gamma_h_boosted(eps,curr_u,self.alpha,self.B)
              self.doc_vec = tf.div(self.doc_vec,tf.reshape(tf.reduce_sum(self.doc_vec,1), (-1, 1)))
              self.doc_vec.set_shape(self.alpha.get_shape())
              
              if lda:
                logits = tf.log(tf.clip_by_value(utils.linear_LDA(self.doc_vec, self.vocab_size, scope='projection',no_bias=True),1e-10,1.0))
              else:
                logits = tf.nn.log_softmax(tf.contrib.layers.batch_norm(utils.linear(self.doc_vec, self.vocab_size, scope='projection',no_bias=True),scope ='projection'))
              loss = -tf.reduce_sum(tf.multiply(logits, self.x), 1)
              loss2 = tf.stop_gradient(-tf.reduce_sum(tf.multiply(logits, self.x), 1))
              recons_loss_list.append(loss)
              kld = log_dirichlet(self.doc_vec,self.alpha,self.bernoulli)-log_dirichlet(self.doc_vec,self.prior,self.bernoulli)
              kld_list.append(kld)
            self.recons_loss = tf.add_n(recons_loss_list) / self.n_sample
            self.kld = tf.add_n(kld_list) / self.n_sample
        
        self.objective = self.recons_loss + self.warm_up*self.kld
        self.true_objective = self.recons_loss + self.kld
        self.analytical_objective = self.recons_loss+self.analytical_kld
        tf.summary.scalar('objective', tf.exp(tf.reduce_sum(self.true_objective)/tf.reduce_sum(self.x)))
        fullvars = tf.trainable_variables()

        enc_vars = utils.variable_parser(fullvars, 'encoder')
        dec_vars = utils.variable_parser(fullvars, 'decoder')
        
        #this is the standard gradient for the reconstruction network
        dec_grads = tf.gradients(self.objective, dec_vars)
        
        #####################################################
        #Now calculate the gradient for the encoding network#
        #####################################################
       
        #again redefine some stuff for proper gradient back propagation
        if self.n_sample ==1:
          gammas = self.bernoulli*gamma_h_boosted(eps,u,self.alpha,self.B)
          self.doc_vec = tf.div(gammas,tf.reshape(tf.reduce_sum(gammas,1), (-1, 1))+1e-10)
          self.doc_vec.set_shape(self.alpha.get_shape())
          with tf.variable_scope("decoder", reuse=True):
              logits2 = tf.nn.log_softmax(tf.contrib.layers.batch_norm(utils.linear(self.doc_vec, self.vocab_size, scope='projection',no_bias=True)))
              self.recons_loss2 = -tf.reduce_sum(tf.multiply(logits2, self.x), 1)
              self.kld2 = log_dirichlet(self.doc_vec,self.alpha,self.bernoulli)-log_dirichlet(self.doc_vec,self.prior,self.bernoulli)
        else:
          with tf.variable_scope("decoder", reuse=True):
            recons_loss_list2 = []
            kld_list2 = []
            
            for i in range(self.n_sample):
              curr_gam = gam[i]
              eps = tf.stop_gradient(calc_epsilon(curr_gam,self.alpha+tf.to_float(self.B)))
              curr_u = u[i]
              self.doc_vec = self.bernoulli*gamma_h_boosted(eps,curr_u,self.alpha,self.B)
              self.doc_vec = tf.div(self.doc_vec,tf.reshape(tf.reduce_sum(self.doc_vec,1), (-1, 1)))
              self.doc_vec.set_shape(self.alpha.get_shape())
              if lda:
                logits2 = tf.log(tf.clip_by_value(utils.linear_LDA(self.doc_vec, self.vocab_size, scope='projection',no_bias=True),1e-10,1.0))
              else:
                logits2 = tf.nn.log_softmax(tf.contrib.layers.batch_norm(utils.linear(self.doc_vec, self.vocab_size, scope='projection',no_bias=True),scope ='projection'))
              loss = -tf.reduce_sum(tf.multiply(logits2, self.x), 1)
              recons_loss_list2.append(loss)
              prior_sample = tf.squeeze(tf.random_gamma(shape = (1,),alpha=self.prior))
              prior_sample = tf.div(prior_sample,tf.reshape(tf.reduce_sum(prior_sample,1), (-1, 1)))
              kld2 = log_dirichlet(self.doc_vec,self.alpha,self.bernoulli)-log_dirichlet(self.doc_vec,self.prior,self.bernoulli)
              kld_list2.append(kld2)
            self.recons_loss2 = tf.add_n(recons_loss_list2) / self.n_sample
            
            self.kld2 = tf.add_n(kld_list2)/self.n_sample
            
        kl_grad = tf.gradients(self.kld2,enc_vars)
            
        #this is the gradient we would use if the rejection sampler for the Gamma always accepted
        g_rep = tf.gradients(self.recons_loss2,enc_vars)
        
        #now define the gradient for the correction part
        reshaped1 = tf.reshape(self.recons_loss,(batch_size,1))
        reshaped2 = tf.reshape(self.recons_loss,(batch_size,1,1))
        reshaped21 = tf.reshape(self.kld,(batch_size,1))
        reshaped22 = tf.reshape(self.kld,(batch_size,1,1))
        if not correction:
          enc_grads = [g_r+self.warm_up*g_e for g_r,g_e in zip(g_rep,kl_grad)]#+g_c
       
        optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate,beta1=self.adam_beta1,beta2=self.adam_beta2)
        self.optim_enc = optimizer.apply_gradients(zip(enc_grads, enc_vars))
        self.optim_dec = optimizer.apply_gradients(zip(dec_grads, dec_vars))
        self.optim_all = optimizer.apply_gradients(list(zip(enc_grads, enc_vars))+list(zip(dec_grads, dec_vars)))
Example #26
    def __init__(self, vocab_size, n_hidden, n_topic, learning_rate,
                 batch_size, non_linearity, adam_beta1, adam_beta2, dir_prior):
        tf.reset_default_graph()
        self.vocab_size = vocab_size
        self.n_hidden = n_hidden
        self.n_topic = n_topic
        self.non_linearity = non_linearity
        self.learning_rate = learning_rate
        self.batch_size = batch_size

        lda = False
        self.x = tf.placeholder(tf.float32, [None, vocab_size], name='input')
        self.mask = tf.placeholder(tf.float32, [None],
                                   name='mask')  # mask paddings
        self.warm_up = tf.placeholder(tf.float32, (),
                                      name='warm_up')  # warm up
        self.adam_beta1 = adam_beta1
        self.adam_beta2 = adam_beta2
        self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')
        self.min_alpha = tf.placeholder(tf.float32, (), name='min_alpha')
        # encoder
        with tf.variable_scope('encoder'):
            self.enc_vec = utils.mlp(self.x, [self.n_hidden],
                                     self.non_linearity)
            self.enc_vec = tf.nn.dropout(self.enc_vec, self.keep_prob)
            self.mean = tf.contrib.layers.batch_norm(
                utils.linear(self.enc_vec, self.n_topic, scope='mean'))
            self.alpha = tf.maximum(self.min_alpha,
                                    tf.log(1. + tf.exp(self.mean)))
            #Dirichlet prior alpha0
            self.prior = tf.ones(
                (batch_size, self.n_topic), dtype=tf.float32,
                name='prior') * dir_prior

            self.analytical_kld = tf.lgamma(tf.reduce_sum(
                self.alpha, axis=1)) - tf.lgamma(
                    tf.reduce_sum(self.prior, axis=1))
            self.analytical_kld -= tf.reduce_sum(tf.lgamma(self.alpha), axis=1)
            self.analytical_kld += tf.reduce_sum(tf.lgamma(self.prior), axis=1)
            minus = self.alpha - self.prior
            test = tf.reduce_sum(
                tf.multiply(
                    minus,
                    tf.digamma(self.alpha) -
                    tf.reshape(tf.digamma(tf.reduce_sum(self.alpha, 1)),
                               (batch_size, 1))), 1)
            self.analytical_kld += test
            self.analytical_kld = self.mask * self.analytical_kld  # mask paddings
            max_kld = tf.argmax(self.analytical_kld, 0)

        with tf.variable_scope('decoder'):
            with tf.variable_scope('prob'):
                #Dirichlet
                self.doc_vec = tf.squeeze(tfd.Dirichlet(
                    self.alpha).sample(1))  #tf.shape(self.alpha)
                self.doc_vec.set_shape(self.alpha.get_shape())
            #reconstruction
            if lda:
                logits = tf.log(
                    tf.clip_by_value(
                        utils.linear_LDA(self.doc_vec,
                                         self.vocab_size,
                                         scope='projection',
                                         no_bias=True), 1e-10, 1.0))
            else:
                logits = tf.nn.log_softmax(
                    tf.contrib.layers.batch_norm(
                        utils.linear(self.doc_vec,
                                     self.vocab_size,
                                     scope='projection',
                                     no_bias=True)))
            self.recons_loss = -tf.reduce_sum(tf.multiply(logits, self.x), 1)

            dir1 = tf.contrib.distributions.Dirichlet(self.prior)
            dir2 = tf.contrib.distributions.Dirichlet(self.alpha)

            self.kld = dir2.log_prob(self.doc_vec) - dir1.log_prob(
                self.doc_vec)
            max_kld_sampled = tf.arg_max(self.kld, 0)

        self.objective = self.recons_loss + self.warm_up * self.analytical_kld
        self.true_objective = self.recons_loss + self.kld

        self.analytical_objective = self.recons_loss + self.analytical_kld

        fullvars = tf.trainable_variables()

        enc_vars = utils.variable_parser(fullvars, 'encoder')
        dec_vars = utils.variable_parser(fullvars, 'decoder')

        #this is the standard gradient for the reconstruction network
        dec_grads = tf.gradients(self.objective, dec_vars)

        #####################################################
        #Now calculate the gradient for the encoding network#
        #####################################################

        kl_grad = tf.gradients(self.analytical_kld, enc_vars)
        g_rep = tf.gradients(self.recons_loss, enc_vars)
        enc_grads = [
            g_r + self.warm_up * g_e for g_r, g_e in zip(g_rep, kl_grad)
        ]

        optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate,
                                           beta1=self.adam_beta1,
                                           beta2=self.adam_beta2)
        self.optim_enc = optimizer.apply_gradients(zip(enc_grads, enc_vars))
        self.optim_dec = optimizer.apply_gradients(zip(dec_grads, dec_vars))
        self.optim_all = optimizer.apply_gradients(
            list(zip(enc_grads, enc_vars)) + list(zip(dec_grads, dec_vars)))
Example #27
    def __init__(self, vocab_size, n_hidden, n_topic, n_sample, learning_rate,
                 batch_size, n_householder, non_linearity):
        self.vocab_size = vocab_size
        self.n_hidden = n_hidden
        self.n_topic = n_topic
        self.n_sample = n_sample
        self.non_linearity = non_linearity
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.n_householder = n_householder

        self.x = tf.placeholder(tf.float32, [None, vocab_size], name='input')
        self.mask = tf.placeholder(tf.float32, [None],
                                   name='mask')  # mask paddings

        # encoder
        with tf.variable_scope('encoder'):
            self.enc_vec = utils.mlp(self.x, [self.n_hidden],
                                     self.non_linearity)
            self.mean = utils.linear(self.enc_vec, self.n_topic, scope='mean')
            self.logsigm = utils.linear(self.enc_vec,
                                        self.n_topic,
                                        bias_start_zero=True,
                                        matrix_start_zero=True,
                                        scope='logsigm')

            # -------------------- cal the householder matrix -------------------------------
            self.tmp_mean = tf.expand_dims(
                tf.expand_dims(
                    tf.rsqrt(tf.reduce_sum(tf.square(self.mean), 1)), 1) *
                self.mean, 2)
            self.tmp_mean_t = tf.transpose(self.tmp_mean, perm=[0, 2, 1])
            self.vk = self.tmp_mean
            self.Hk = tf.expand_dims(tf.eye(self.n_topic), 0) - \
                      2 * tf.matmul(self.tmp_mean, self.tmp_mean_t)

            self.U = self.Hk
            self.tmp_vk = self.vk
            self.invalid = []
            self.vk_show = tf.constant(-1.0)
            for k in range(1, self.n_householder + 1):
                self.tmp_vk = self.vk
                self.tmp_vk = tf.expand_dims(
                    tf.rsqrt(tf.reduce_sum(tf.square(self.tmp_vk), 1)) *
                    tf.squeeze(self.tmp_vk, [2, 2]), 2)
                self.vk = tf.matmul(self.Hk, self.vk)
                self.Hk = tf.expand_dims(tf.eye(self.n_topic), 0) - \
                          2 * tf.matmul(self.tmp_vk, tf.transpose(self.tmp_vk, perm=[0, 2, 1]))
                self.U = tf.matmul(self.U, self.Hk)

            self.Umean = tf.squeeze(tf.matmul(self.U, self.tmp_mean), [2, 2])

            # ------------------------ KL divergence after Householder -------------------------------------
            self.kld = -0.5 * (tf.reduce_sum(
                1 - tf.square(self.Umean) + 2 * self.logsigm, 1) - \
                               tf.trace(tf.matmul(tf.transpose(tf.multiply(tf.expand_dims(tf.exp(2 * self.logsigm), 2),
                                                                           tf.transpose(self.U, perm=[0, 2, 1])),
                                                               perm=[0, 2, 1]), tf.transpose(self.U, perm=[0, 2, 1]))))
            # kk = tf.trace(tf.matmul(tf.transpose(tf.multiply(tf.expand_dims(tf.exp(2 * self.logsigm), 2), tf.transpose(self.U, perm=[0,2,1])), perm=[0,2,1]), tf.transpose(self.U, perm=[0,2,1])))
            self.log_squre = tf.trace(
                tf.matmul(
                    tf.transpose(tf.multiply(
                        tf.expand_dims(tf.exp(2 * self.logsigm), 2),
                        tf.transpose(self.U, perm=[0, 2, 1])),
                                 perm=[0, 2, 1]),
                    tf.transpose(self.U, perm=[0, 2, 1])))
            self.mean_squre = tf.reduce_sum(tf.square(self.Umean), 1)
            self.kld = self.mask * self.kld  # mask paddings

            if self.n_sample == 1:  # single sample
                eps = tf.random_normal((batch_size, self.n_topic), 0, 1)
                doc_vec = tf.multiply(tf.exp(self.logsigm), eps) + self.mean
            else:
                doc_vec_list = []
                for i in range(self.n_sample):
                    epsilon = tf.random_normal((self.batch_size, self.n_topic),
                                               0, 1)
                    doc_vec_list.append(
                        self.mean + tf.multiply(epsilon, tf.exp(self.logsigm)))
                doc_vec = tf.add_n(doc_vec_list) / self.n_sample

            doc_vec = tf.squeeze(tf.matmul(self.U, tf.expand_dims(doc_vec, 2)))
            self.theta = tf.nn.softmax(doc_vec)

        with tf.variable_scope('decoder'):
            topic_vec = tf.get_variable('topic_vec',
                                        shape=[self.n_topic, self.n_hidden])
            word_vec = tf.get_variable('word_vec',
                                       shape=[self.vocab_size, self.n_hidden])

            # n_topic x vocab_size
            beta = tf.matmul(topic_vec, tf.transpose(word_vec))

            logits = tf.nn.log_softmax(tf.matmul(doc_vec, beta))

            self.beta = tf.nn.softmax(beta)

            mean = tf.reduce_mean(self.theta, -1, keep_dims=True)  # bs x 1
            self.variance = tf.sqrt(
                tf.reduce_sum(
                    tf.square(self.theta - tf.tile(mean, [1, self.n_topic])),
                    -1) / self.n_topic)

            self.recons_loss = -tf.reduce_sum(tf.multiply(logits, self.x), 1)

        self.objective = self.recons_loss + self.kld
        self.loss_func = self.objective

        optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
        fullvars = tf.trainable_variables()

        enc_vars = utils.variable_parser(fullvars, 'encoder')
        dec_vars = utils.variable_parser(fullvars, 'decoder')

        enc_grads, _ = tf.clip_by_global_norm(
            tf.gradients(self.loss_func, enc_vars), 5)
        dec_grads, _ = tf.clip_by_global_norm(
            tf.gradients(self.loss_func, dec_vars), 5)

        self.optim_enc = optimizer.apply_gradients(zip(enc_grads, enc_vars))
        self.optim_dec = optimizer.apply_gradients(zip(dec_grads, dec_vars))
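
Example #27 applies a chain of Householder reflections H_k = I - 2 v_k v_k^T (with v_k normalized) to the posterior mean. A hedged standalone helper for one batched reflection, mirroring the Hk construction above (illustrative only):

import tensorflow as tf

def householder_matrix(v, n_topic):
    # v: [batch, n_topic, 1]; returns H = I - 2 * v v^T with v normalized per row
    v = v * tf.rsqrt(tf.reduce_sum(tf.square(v), axis=1, keep_dims=True) + 1e-10)
    eye = tf.expand_dims(tf.eye(n_topic), 0)
    return eye - 2.0 * tf.matmul(v, tf.transpose(v, perm=[0, 2, 1]))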
Example #28
    def __init__(self, latent_dim, hidden_dim, hidden_depth, action_dim):
        super().__init__()

        self.linear_module = utils.mlp(latent_dim, hidden_dim, action_dim,
                                       hidden_depth)
Example #29
    def __init__(self, fusion_dim, hidden_dim, hidden_depth, latent_dim):
        super().__init__()

        self.linear_module = utils.mlp(fusion_dim * 2, hidden_dim, latent_dim,
                                       hidden_depth)
Example #30
    def tower(self, x, y, reuse=False, scope="mlp_tower", labeled=False):
        cat_dis = False
        with tf.variable_scope(scope) as sc:
            if reuse:
                sc.reuse_variables()
            enc_input = tf.concat([x, y], axis=1)
            enc_vec = utils.mlp(enc_input,
                                self.mlp_arr,
                                self.non_linearity,
                                is_training=self.training)

            enc_vec = tf.nn.dropout(enc_vec, self.keep_prob)
            mean = tf.contrib.layers.batch_norm(utils.linear(enc_vec,
                                                             self.n_topic,
                                                             scope='mean'),
                                                is_training=self.training)
            alpha = tf.maximum(self.min_alpha, tf.log(1. + tf.exp(mean)))

            prior = tf.ones((self.batch_size, self.n_topic),
                            dtype=tf.float32,
                            name='prior') * self.dir_prior

            analytical_kld = tf.lgamma(tf.reduce_sum(
                alpha, axis=1)) - tf.lgamma(tf.reduce_sum(prior, axis=1))
            analytical_kld -= tf.reduce_sum(tf.lgamma(alpha), axis=1)
            analytical_kld += tf.reduce_sum(tf.lgamma(prior), axis=1)
            minus = alpha - prior

            test = tf.reduce_sum(
                tf.multiply(
                    minus,
                    tf.digamma(alpha) -
                    tf.reshape(tf.digamma(tf.reduce_sum(alpha, 1)),
                               (self.batch_size, 1))), 1)
            analytical_kld += test
            analytical_kld = self.mask * analytical_kld  # mask paddings

            clss_mlp = utils.mlp(x,
                                 self.mlp_arr,
                                 self.non_linearity,
                                 scope="classifier_mlp",
                                 is_training=self.training)
            clss_mlp = tf.nn.dropout(clss_mlp, self.keep_prob)
            phi = tf.contrib.layers.batch_norm(
                utils.linear(clss_mlp, self.n_class, scope='phi'),
                is_training=self.training)  # y logits
            out_y = tf.nn.softmax(phi, name="probabilities_y")

            if cat_dis:
                one_hot_dist = tfd.OneHotCategorical(logits=phi)
                hot_out = tf.squeeze(one_hot_dist.sample(1))
                hot_out.set_shape(phi.get_shape())
                cat_input = tf.cast(hot_out, dtype=tf.float32)
            else:
                cat_input = out_y

            # Dirichlet
            doc_vec = tf.squeeze(
                tfd.Dirichlet(alpha).sample(1))  # tf.shape(self.alpha)
            doc_vec.set_shape(alpha.get_shape())
            # reconstruction
            if labeled:
                merge = tf.concat([doc_vec, y], axis=1)
            else:
                merge = tf.concat([doc_vec, cat_input], axis=1)

            logits = tf.nn.log_softmax(
                tf.contrib.layers.batch_norm(utils.linear(merge,
                                                          self.vocab_size,
                                                          scope='projection',
                                                          no_bias=True),
                                             is_training=self.training))

            recons_loss = -tf.reduce_sum(tf.multiply(logits, x), 1)

            L_x_y = recons_loss + self.warm_up * analytical_kld

        return L_x_y, phi, out_y