Example #1
0
class ModelVAE(object):
    """Two-hidden-layer dense VAE with a Normal or von Mises-Fisher posterior.

    Constructing an instance adds the whole graph: ``z_mean``/``z_var`` from
    the encoder, the posterior ``q_z``, one sample ``z`` and the decoder
    ``logits`` (``reconstruction = sigmoid(logits)``).
    """

    def __init__(self,
                 x,
                 h_dim,
                 z_dim,
                 activation=tf.nn.relu,
                 distribution='normal'):
        """
        ModelVAE initializer

        :param x: placeholder for input
        :param h_dim: dimension of the hidden layers
        :param z_dim: dimension of the latent representation
        :param activation: callable activation function
        :param distribution: string either `normal` or `vmf`, indicates which distribution to use
        :raises NotImplementedError: if `distribution` is neither `normal` nor `vmf`
        """
        # Validate up front so no layers are added to the graph for an
        # unsupported value.
        if distribution not in ('normal', 'vmf'):
            raise NotImplementedError(
                "Unsupported distribution: {!r}".format(distribution))

        self.x, self.h_dim, self.z_dim, self.activation, self.distribution = x, h_dim, z_dim, activation, distribution

        self.z_mean, self.z_var = self._encoder(self.x)

        if distribution == 'normal':
            # `z_var` is passed as the scale (standard deviation) of the Normal.
            self.q_z = tf.distributions.Normal(self.z_mean, self.z_var)
        else:
            # 'vmf': `z_var` is the concentration around the unit-norm mean.
            self.q_z = VonMisesFisher(self.z_mean, self.z_var)

        self.z = self.q_z.sample()

        self.logits = self._decoder(self.z)

    def _encoder(self, x):
        """
        Encoder network

        :param x: placeholder for input
        :return: tuple `(z_mean, z_var)`. For `normal`: an unconstrained mean
            and a softplus standard deviation, both `z_dim` wide. For `vmf`: an
            L2-normalised mean direction (`z_dim` wide) and a softplus
            concentration (1 unit).
        :raises NotImplementedError: if `self.distribution` is unsupported
        """
        # 2 hidden layers encoder: 2 * h_dim units, then h_dim units
        h0 = tf.layers.dense(x,
                             units=self.h_dim * 2,
                             activation=self.activation)
        h1 = tf.layers.dense(h0, units=self.h_dim, activation=self.activation)

        if self.distribution == 'normal':
            # compute mean and std of the normal distribution
            z_mean = tf.layers.dense(h1, units=self.z_dim, activation=None)
            z_var = tf.layers.dense(h1,
                                    units=self.z_dim,
                                    activation=tf.nn.softplus)
        elif self.distribution == 'vmf':
            # compute mean direction and concentration of the von Mises-Fisher
            z_mean = tf.layers.dense(
                h1,
                units=self.z_dim,
                activation=lambda t: tf.nn.l2_normalize(t, axis=-1))
            # NOTE(review): no `+ 1` offset is applied, so the concentration
            # can approach 0; confirm this is intended.
            z_var = tf.layers.dense(h1, units=1, activation=tf.nn.softplus)
        else:
            raise NotImplementedError(
                "Unsupported distribution: {!r}".format(self.distribution))

        return z_mean, z_var

    def _decoder(self, z):
        """
        Decoder network

        :param z: tensor, latent representation of input (x)
        :return: logits with as many units as the last dimension of `self.x`;
            `reconstruction = sigmoid(logits)`
        """
        # 2 hidden layers decoder: h_dim units, then 2 * h_dim units
        h2 = tf.layers.dense(z, units=self.h_dim, activation=self.activation)
        h2 = tf.layers.dense(h2,
                             units=self.h_dim * 2,
                             activation=self.activation)
        logits = tf.layers.dense(h2, units=self.x.shape[-1], activation=None)

        return logits
Example #2
0
class VariationalAutoEncoder(object):
    """Fully connected VAE with a beta schedule and optional staged layer growth.

    ``serial_layering`` splits ``n_hidden_layers`` into training stages. For
    ``[1, 2]``, the model is first built with 1 hidden layer and later rebuilt
    with 3. Already-trained layers are reused as initializers when the model
    is rebuilt. ``distribution`` selects a Gaussian (``'normal'``) or a
    von Mises-Fisher (``'vmf'``) latent space.
    """

    def __init__(self,
                 n_input_units,
                 n_hidden_layers,
                 n_hidden_units,
                 n_latent_units,
                 learning_rate=0.05,
                 batch_size=100,
                 min_beta=1.0,
                 max_beta=1.0,
                 distribution='normal',
                 serial_layering=None):
        """
        :param n_input_units: number of features per input row
        :param n_hidden_layers: number of hidden layers in the encoder and in
            the decoder once every stage has been built
        :param n_hidden_units: width of every hidden layer
        :param n_latent_units: latent size (the vMF mean layer has one extra unit)
        :param learning_rate: Adam learning rate
        :param batch_size: batch size; the input placeholder has this fixed size
        :param min_beta: lower bound of the KL-weight schedule
        :param max_beta: upper bound of the KL-weight schedule
        :param distribution: string either `normal` or `vmf`
        :param serial_layering: optional list/tuple of ints summing to
            `n_hidden_layers`, giving the number of layers added per stage
        :raises TypeError: if `serial_layering` is not a list/tuple of ints
        :raises ValueError: if `serial_layering` does not sum to `n_hidden_layers`
        """
        self.n_input_units = n_input_units
        self.n_hidden_layers = n_hidden_layers
        self.n_hidden_units = n_hidden_units
        self.n_latent_units = n_latent_units
        self.learning_rate = learning_rate
        self.batch_size = int(batch_size)
        self.min_beta = min_beta
        self.max_beta = max_beta
        self.distribution = distribution
        if serial_layering:
            if not isinstance(serial_layering, (list, tuple)) or not all(
                    isinstance(x, int) for x in serial_layering):
                raise TypeError(
                    "Argument 'serial_layering' must be a list or tuple of integers."
                )
            if sum(serial_layering) != self.n_hidden_layers:
                raise ValueError(
                    "Groupings in 'serial_layering' must sum to 'n_hidden_layers'."
                )
        self.serial_layering = serial_layering or [self.n_hidden_layers]
        # Cumulative layer counts, e.g. [1, 2] -> [1, 3].
        self.layer_sequence = [
            sum(self.serial_layering[:i + 1])
            for i in range(len(self.serial_layering))
        ]

    class Encoder(object):
        """Dense encoder: sigmoid hidden layers followed by mu and sigma heads."""

        def __init__(self,
                     n_hidden_layers,
                     n_hidden_units,
                     n_latent_units,
                     distribution,
                     initializers=None):
            """
            :param initializers: optional dict with keys 'layers' (list of
                (kernel, bias) initializer pairs, consumed front to back),
                'mu' and 'sigma' (each a (kernel, bias) pair)
            """
            self.n_hidden_layers = n_hidden_layers
            self.n_hidden_units = n_hidden_units
            self.n_latent_units = n_latent_units
            self.distribution = distribution
            self.initializers = initializers

        def init_hidden_layers(self):
            """Reset the lists of layer objects and their output tensors."""
            self.hidden_layers = []
            self.applied_hidden_layers = []

        def add_hidden_layer(self, inputs):
            """Append a sigmoid Dense layer applied to `inputs`; return its output."""
            if self.initializers and self.initializers.get('layers', None):
                print("initializing encoder layer...")
                kernel_initializer, bias_initializer = self.initializers[
                    'layers'].pop(0)
            else:
                kernel_initializer, bias_initializer = None, None

            self.hidden_layers.append(
                tf.layers.Dense(units=self.n_hidden_units,
                                activation=tf.nn.sigmoid,
                                kernel_initializer=kernel_initializer,
                                bias_initializer=bias_initializer))
            self.applied_hidden_layers.append(
                self.hidden_layers[-1].apply(inputs))
            return self.applied_hidden_layers[-1]

        def add_mu(self, inputs):
            """Build the mean head.

            It is linear for 'normal'. For 'vmf' it is L2-normalised with
            `n_latent_units + 1` units.

            :raises NotImplementedError: for any other distribution
            """
            if self.initializers and self.initializers.get('mu', None):
                print("initializing encoder mu...")
                kernel_initializer, bias_initializer = self.initializers['mu']
            else:
                kernel_initializer, bias_initializer = None, None

            if self.distribution == 'normal':
                self.mu = tf.layers.Dense(
                    units=self.n_latent_units,
                    kernel_initializer=kernel_initializer,
                    bias_initializer=bias_initializer)
            elif self.distribution == 'vmf':
                self.mu = tf.layers.Dense(
                    units=self.n_latent_units + 1,
                    activation=lambda x: tf.nn.l2_normalize(x, axis=-1),
                    kernel_initializer=kernel_initializer,
                    bias_initializer=bias_initializer)
            else:
                raise NotImplementedError(
                    "Unsupported distribution: {!r}".format(self.distribution))

            self.applied_mu = self.mu.apply(inputs)
            return self.applied_mu

        def add_sigma(self, inputs):
            """Build the spread head.

            For 'normal' it is linear, interpreted as a log-variance by
            `sampled_z`. For 'vmf' it is a 1-unit softplus concentration
            plus 1.

            :raises NotImplementedError: for any other distribution
            """
            if self.initializers and self.initializers.get('sigma', None):
                print("initializing encoder sigma...")
                kernel_initializer, bias_initializer = self.initializers[
                    'sigma']
            else:
                kernel_initializer, bias_initializer = None, None

            if self.distribution == 'normal':
                self.sigma = tf.layers.Dense(
                    units=self.n_latent_units,
                    kernel_initializer=kernel_initializer,
                    bias_initializer=bias_initializer)
                self.applied_sigma = self.sigma.apply(inputs)
            elif self.distribution == 'vmf':
                self.sigma = tf.layers.Dense(
                    units=1,
                    activation=tf.nn.softplus,
                    kernel_initializer=kernel_initializer,
                    bias_initializer=bias_initializer)
                self.applied_sigma = self.sigma.apply(inputs) + 1
            else:
                raise NotImplementedError(
                    "Unsupported distribution: {!r}".format(self.distribution))
            return self.applied_sigma

        def build(self, inputs):
            """Build `n_hidden_layers` hidden layers on `inputs`; return (mu, sigma)."""
            self.init_hidden_layers()

            layer = self.add_hidden_layer(inputs)
            for _ in range(self.n_hidden_layers - 1):
                layer = self.add_hidden_layer(layer)

            mu = self.add_mu(layer)
            sigma = self.add_sigma(layer)

            return mu, sigma

        def eval(self, sess):
            """Fetch weights as (hidden [[kernel, bias], ...], [mu_k, mu_b], [sigma_k, sigma_b])."""
            layers = [sess.run([l.kernel, l.bias]) for l in self.hidden_layers]
            mu = sess.run([self.mu.kernel, self.mu.bias])
            sigma = sess.run([self.sigma.kernel, self.sigma.bias])
            return layers, mu, sigma

    class Decoder(object):
        """Dense decoder: sigmoid hidden layers followed by a linear output layer."""

        def __init__(self,
                     n_hidden_layers,
                     n_hidden_units,
                     n_output_units,
                     initializers=None):
            """
            :param initializers: optional dict with keys 'layers' (list of
                (kernel, bias) initializer pairs, consumed front to back) and
                'output' (a (kernel, bias) pair)
            """
            self.n_hidden_layers = n_hidden_layers
            self.n_hidden_units = n_hidden_units
            self.n_output_units = n_output_units
            self.initializers = initializers

        def init_hidden_layers(self):
            """Reset the lists of layer objects and their output tensors."""
            self.hidden_layers = []
            self.applied_hidden_layers = []

        def add_hidden_layer(self, inputs):
            """Append a sigmoid Dense layer applied to `inputs`; return its output."""
            if self.initializers and self.initializers.get('layers', None):
                print("initializing decoder layer...")
                kernel_initializer, bias_initializer = self.initializers[
                    'layers'].pop(0)
            else:
                kernel_initializer, bias_initializer = None, None

            self.hidden_layers.append(
                tf.layers.Dense(units=self.n_hidden_units,
                                activation=tf.nn.sigmoid,
                                kernel_initializer=kernel_initializer,
                                bias_initializer=bias_initializer))
            self.applied_hidden_layers.append(
                self.hidden_layers[-1].apply(inputs))
            return self.applied_hidden_layers[-1]

        def add_output(self, inputs):
            """Add the linear output layer applied to `inputs`; return its output."""
            if self.initializers and self.initializers.get('output', None):
                print("initializing decoder output...")
                kernel_initializer, bias_initializer = self.initializers[
                    'output']
            else:
                kernel_initializer, bias_initializer = None, None

            self.output = tf.layers.Dense(
                units=self.n_output_units,
                kernel_initializer=kernel_initializer,
                bias_initializer=bias_initializer)
            self.applied_output = self.output.apply(inputs)
            return self.applied_output

        def build(self, inputs):
            """Build `n_hidden_layers` hidden layers and the output on `inputs`."""
            self.init_hidden_layers()

            layer = self.add_hidden_layer(inputs)
            for _ in range(self.n_hidden_layers - 1):
                layer = self.add_hidden_layer(layer)

            return self.add_output(layer)

        def eval(self, sess):
            """Fetch weights as (hidden [[kernel, bias], ...], [out_kernel, out_bias])."""
            layers = [sess.run([l.kernel, l.bias]) for l in self.hidden_layers]
            output = sess.run([self.output.kernel, self.output.bias])
            return layers, output

    def sampled_z(self, mu, sigma, batch_size):
        """Sample the latent code and build the beta-weighted KL loss.

        For 'normal', `sigma` is treated as a log-variance and `z` is drawn
        with the reparameterisation trick. For 'vmf', `sigma` is the
        concentration and the prior is a hyperspherical uniform.

        :return: tuple `(z, loss)` where `loss` is the batch-mean KL term
        :raises NotImplementedError: for any other distribution
        """
        if self.distribution == 'normal':
            epsilon = tf.random_normal(
                tf.stack([int(batch_size), self.n_latent_units]))
            z = mu + tf.multiply(epsilon, tf.exp(0.5 * sigma))
            loss = tf.reduce_mean(
                -0.5 * self.beta *
                tf.reduce_sum(1.0 + sigma - tf.square(mu) - tf.exp(sigma), 1))
        elif self.distribution == 'vmf':
            self.q_z = VonMisesFisher(mu,
                                      sigma,
                                      validate_args=True,
                                      allow_nan_stats=False)
            z = self.q_z.sample()
            self.p_z = HypersphericalUniform(self.n_latent_units,
                                             validate_args=True,
                                             allow_nan_stats=False)
            # KL(q || p) is non-negative and must be *added* to the loss
            # (subtracting it would make training maximise the divergence).
            # It is weighted by beta like the Gaussian branch so the beta
            # schedule applies to both distributions.
            loss = tf.reduce_mean(self.beta *
                                  self.q_z.kl_divergence(self.p_z))
        else:
            raise NotImplementedError(
                "Unsupported distribution: {!r}".format(self.distribution))

        return z, loss

    def build_feature_loss(self, x, output):
        """Batch mean of the per-row sum of squared reconstruction errors."""
        return tf.reduce_mean(
            tf.reduce_sum(tf.squared_difference(x, output), 1))

    def _new_layer_initializer(self):
        """Return (kernel, bias) initializers for a hidden layer with no trained weights.

        Such layers come after the previously trained ones, so they map
        `n_hidden_units` -> `n_hidden_units`. The kernel is that-sized
        identity and every bias entry is 1.0.
        """
        return (tf.constant_initializer(np.eye(self.n_hidden_units)),
                tf.constant_initializer(np.ones(self.n_hidden_units)))

    def _layer_initializers(self, trained_layers, n_hidden_layers):
        """Build `n_hidden_layers` initializer pairs.

        Trained `[kernel, bias]` pairs are reused in order. Any extra layers
        get `_new_layer_initializer()`.
        """
        result = []
        for index in range(n_hidden_layers):
            if index < len(trained_layers):
                kernel, bias = trained_layers[index]
                result.append((tf.constant_initializer(kernel),
                               tf.constant_initializer(bias)))
            else:
                result.append(self._new_layer_initializer())
        return result

    def build_encoder_initializers(self, sess, n_hidden_layers):
        """Initializers for a rebuilt encoder, or None before the first build."""
        if not hasattr(self, 'encoder'):
            return None
        layers, mu, sigma = self.encoder.eval(sess)
        return {
            'layers': self._layer_initializers(layers, n_hidden_layers),
            'mu': (tf.constant_initializer(mu[0]),
                   tf.constant_initializer(mu[1])),
            'sigma': (tf.constant_initializer(sigma[0]),
                      tf.constant_initializer(sigma[1])),
        }

    def build_decoder_initializers(self, sess, n_hidden_layers):
        """Initializers for a rebuilt decoder, or None before the first build."""
        if not hasattr(self, 'decoder'):
            return None
        layers, output = self.decoder.eval(sess)
        return {
            'layers': self._layer_initializers(layers, n_hidden_layers),
            'output': (tf.constant_initializer(output[0]),
                       tf.constant_initializer(output[1])),
        }

    def build_initializers(self, attr_name, sess, n_hidden_layers):
        """Hidden-layer initializers taken from attribute `attr_name`, or None if absent."""
        if not hasattr(self, attr_name):
            return None
        layers = getattr(self, attr_name).eval(sess)[0]
        return self._layer_initializers(layers, n_hidden_layers)

    def initialize_tensors(self, sess, n_hidden_layers=None):
        """(Re)build placeholders, encoder, sampler, decoder and losses.

        When an encoder/decoder already exists, its current weights (read
        through `sess`) seed the new layers.
        """
        n_hidden_layers = n_hidden_layers or self.n_hidden_layers

        self.x = tf.placeholder("float32",
                                [self.batch_size, self.n_input_units])
        self.beta = tf.placeholder("float32", [1, 1])
        self.encoder = self.Encoder(
            n_hidden_layers,
            self.n_hidden_units,
            self.n_latent_units,
            self.distribution,
            initializers=self.build_encoder_initializers(
                sess, n_hidden_layers))
        self.mu, self.sigma = self.encoder.build(self.x)

        self.z, self.latent_loss = self.sampled_z(self.mu, self.sigma,
                                                  self.batch_size)

        self.decoder = self.Decoder(
            n_hidden_layers,
            self.n_hidden_units,
            self.n_input_units,
            initializers=self.build_decoder_initializers(
                sess, n_hidden_layers))
        self.output = self.decoder.build(self.z)

        self.feature_loss = self.build_feature_loss(self.x, self.output)
        self.loss = self.feature_loss + self.latent_loss

    def total_steps(self, data_count, epochs):
        """Return `num_full_batches * epochs - epochs`, the length of the beta schedule."""
        num_batches = int(data_count / self.batch_size)
        return (num_batches * epochs) - epochs

    def generate_beta_values(self, total_steps):
        """Return `total_steps` KL weights (empty list when `total_steps <= 0`).

        Value i is `min_beta + (max_beta - min_beta) * (1 - exp(-5 + 5 * i / total_steps))`.
        It starts near `min_beta + 0.993 * (max_beta - min_beta)` and decays
        towards `min_beta`.
        NOTE(review): this is a decreasing schedule; confirm that is intended.
        """
        if total_steps <= 0:
            # Avoid dividing by zero (e.g. a dataset of exactly one batch).
            return []
        beta_delta = self.max_beta - self.min_beta
        log_beta_step = 5 / float(total_steps)
        return [
            self.min_beta + (beta_delta * (1 - math.exp(-5 +
                                                        (i * log_beta_step))))
            for i in range(total_steps)
        ]

    def train_from_rdd(self, data_rdd, epochs=1):
        """Train from a Spark-style RDD streamed through `toLocalIterator()`.

        Each epoch stops at the first incomplete batch.

        :return: VariationalAutoEncoderModel built from the trained weights
        """
        data_count = data_rdd.count()
        total_steps = self.total_steps(data_count, epochs)
        beta_iter = iter(self.generate_beta_values(total_steps))

        # At least 1, so the modulo below never divides by zero.
        layer_sequence_step = max(
            1, int(total_steps / len(self.layer_sequence)))
        layer_sequence = self.layer_sequence.copy()

        with tf.Session() as sess:
            batch_index = 0
            for epoch_index in range(epochs):
                iterator = data_rdd.toLocalIterator()
                while True:
                    if (not batch_index %
                            layer_sequence_step) and layer_sequence:
                        n_hidden_layers = layer_sequence.pop(0)
                        self.initialize_tensors(sess, n_hidden_layers)
                        optimizer = tf.train.AdamOptimizer(
                            self.learning_rate).minimize(self.loss)
                        sess.run(tf.global_variables_initializer())

                    batch = np.array(list(islice(iterator, self.batch_size)))
                    if batch.shape[0] != self.batch_size:
                        print("incomplete batch: {}".format(batch.shape))
                        break

                    beta = next(beta_iter, self.min_beta)
                    feed_dict = {
                        self.x: np.array(batch),
                        self.beta: np.array([[beta]])
                    }

                    if not batch_index % 1000:
                        print("beta: {}".format(beta))
                        print("number of hidden layers: {}".format(
                            n_hidden_layers))
                        ls, f_ls, d_ls = sess.run(
                            [self.loss, self.feature_loss, self.latent_loss],
                            feed_dict=feed_dict)
                        print(
                            "loss={}, avg_feature_loss={}, avg_latent_loss={}"
                            .format(ls, np.mean(f_ls), np.mean(d_ls)))
                        print('running batch {} (epoch {})'.format(
                            batch_index, epoch_index))
                    sess.run(optimizer, feed_dict=feed_dict)
                    batch_index += 1

            print("evaluating model...")
            encoder_layers, eval_mu, eval_sigma = self.encoder.eval(sess)
            decoder_layers, eval_output = self.decoder.eval(sess)

        return VariationalAutoEncoderModel(encoder_layers, eval_mu, eval_sigma,
                                           decoder_layers, eval_output)

    def train(self, data, visualize=False, epochs=1):
        """Train on an in-memory array and return the learned weights.

        :param data: array sliced by rows; presumably (rows, n_input_units),
            and `visualize` plots its first two columns
        :param visualize: if True, scatter-plot inputs and reconstructions
            about three times per epoch and on the last batch
        :param epochs: number of passes over `data`
        :return: VariationalAutoEncoderModel built from the trained weights
        :raises ValueError: if `data` has fewer rows than `batch_size`
        """
        data_size = data.shape[0]
        batch_size = self.batch_size
        # The input placeholder has a fixed batch dimension, so only complete
        # batches can be fed. A trailing partial batch is skipped, just as
        # `train_from_rdd` stops at the first incomplete batch.
        n_batches = data_size // batch_size
        if n_batches == 0:
            raise ValueError(
                "Need at least batch_size={} rows, got {}.".format(
                    batch_size, data_size))
        total_steps = self.total_steps(data_size, epochs)
        beta_iter = iter(self.generate_beta_values(total_steps))

        # Both intervals are at least 1 so the modulos never divide by zero.
        layer_sequence_step = max(
            1, int(total_steps / len(self.layer_sequence)))
        layer_sequence = self.layer_sequence.copy()
        plot_every = max(1, int(n_batches / 3))

        with tf.Session() as sess:
            for epoch_index in range(epochs):
                for i in range(n_batches):
                    if (not i % layer_sequence_step) and layer_sequence:
                        n_hidden_layers = layer_sequence.pop(0)
                        self.initialize_tensors(sess, n_hidden_layers)
                        optimizer = tf.train.AdamOptimizer(
                            self.learning_rate).minimize(self.loss)
                        sess.run(tf.global_variables_initializer())

                    batch = data[i * batch_size:(i + 1) * batch_size]
                    beta = next(beta_iter, self.min_beta)
                    feed_dict = {self.x: batch, self.beta: np.array([[beta]])}
                    sess.run(optimizer, feed_dict=feed_dict)
                    if visualize and (not i % plot_every
                                      or i == n_batches - 1):
                        ls, d, f_ls, d_ls = sess.run(
                            [
                                self.loss, self.output, self.feature_loss,
                                self.latent_loss
                            ],
                            feed_dict=feed_dict)
                        plt.scatter(batch[:, 0], batch[:, 1])
                        plt.show()
                        plt.scatter(d[:, 0], d[:, 1])
                        plt.show()
                        print(i, ls, np.mean(f_ls), np.mean(d_ls))

            encoder_layers, eval_mu, eval_sigma = self.encoder.eval(sess)
            decoder_layers, eval_output = self.decoder.eval(sess)

        return VariationalAutoEncoderModel(encoder_layers, eval_mu, eval_sigma,
                                           decoder_layers, eval_output)
Example #3
0
class ExplicitAE(object):
    """Two-hidden-layer dense VAE whose layers live in named variable scopes.

    Encoder layers are created in the ``ENCODER`` scope and decoder layers in
    the ``DECODER`` scope, both with ``reuse=AUTO_REUSE``. Constructing an
    instance adds ``z_mean``/``z_var``, the posterior ``q_z``, a sample ``z``
    and the decoder ``logits``.
    """

    def __init__(self,
                 x,
                 h_dim,
                 z_dim,
                 activation=tf.nn.relu,
                 distribution='normal',
                 rescale_sph_latent=False):
        """
        :param x: placeholder for input
        :param h_dim: dimension of the hidden layers
        :param z_dim: dimension of the latent representation
        :param activation: callable activation function
        :param distribution: string either `normal` or `vmf`, indicates which distribution to use
        :param rescale_sph_latent: for `vmf` only, multiply the latent sample by
            sqrt(z_dim) before decoding
        :raises NotImplementedError: if `distribution` is neither `normal` nor `vmf`
        """
        # Validate up front so no layers are added to the graph for an
        # unsupported value.
        if distribution not in ('normal', 'vmf'):
            raise NotImplementedError(
                "Unsupported distribution: {!r}".format(distribution))

        self.x, self.h_dim, self.z_dim, self.activation, self.distribution = x, h_dim, z_dim, \
            activation, distribution
        self.rescale_sph_latent = rescale_sph_latent

        self.z_mean, self.z_var = self._encoder(self.x)

        if distribution == 'normal':
            # `z_var` is passed as the scale (standard deviation) of the Normal.
            self.q_z = tf.distributions.Normal(self.z_mean, self.z_var)
        else:
            # 'vmf': `z_var` is the concentration around the unit-norm mean.
            self.q_z = VonMisesFisher(self.z_mean, self.z_var)

        self.z = self.q_z.sample()

        self.logits = self._decoder(self.z)

    def _encoder(self, x):
        """
        Encoder network

        :param x: placeholder for input
        :return: tuple `(z_mean, z_var)`. For `normal`: an unconstrained mean
            and a softplus standard deviation, both `z_dim` wide. For `vmf`: an
            L2-normalised mean direction (`z_dim` wide) and a concentration of
            `softplus(.) + 1` (1 unit).
        :raises NotImplementedError: if `self.distribution` is unsupported
        """
        with tf.variable_scope(ENCODER, reuse=AUTO_REUSE):
            # 2 hidden layers encoder, h_dim units each
            h0 = tf.layers.dense(x,
                                 units=self.h_dim,
                                 activation=self.activation)
            h1 = tf.layers.dense(h0,
                                 units=self.h_dim,
                                 activation=self.activation)

            if self.distribution == 'normal':
                # compute mean and std of the normal distribution
                z_mean = tf.layers.dense(h1, units=self.z_dim, activation=None)
                z_var = tf.layers.dense(h1,
                                        units=self.z_dim,
                                        activation=tf.nn.softplus)
            elif self.distribution == 'vmf':
                # compute mean direction and concentration of the von Mises-Fisher
                z_mean = tf.layers.dense(
                    h1,
                    units=self.z_dim,
                    activation=lambda t: tf.nn.l2_normalize(t, axis=-1))
                # the `+ 1` keeps the concentration >= 1 to prevent collapsing behaviors
                z_var = tf.layers.dense(h1, units=1,
                                        activation=tf.nn.softplus) + 1
            else:
                raise NotImplementedError(
                    "Unsupported distribution: {!r}".format(self.distribution))

            return z_mean, z_var

    def _decoder(self, z):
        """
        Decoder network

        :param z: tensor, latent representation of input (x)
        :return: logits with as many units as the last dimension of `self.x`;
            `reconstruction = sigmoid(logits)`
        """
        if self.distribution == 'vmf' and self.rescale_sph_latent:
            # The vMF mean is unit-norm. Scaling by sqrt(z_dim) presumably
            # brings the sample to a norm comparable to a standard Gaussian
            # vector; confirm.
            z = z * tf.sqrt(tf.to_float(self.z_dim))
        with tf.variable_scope(DECODER, reuse=AUTO_REUSE):
            # 2 hidden layers decoder, h_dim units each
            h2 = tf.layers.dense(z,
                                 units=self.h_dim,
                                 activation=self.activation)
            h2 = tf.layers.dense(h2,
                                 units=self.h_dim,
                                 activation=self.activation)
            logits = tf.layers.dense(h2,
                                     units=self.x.shape[-1],
                                     activation=None)

        return logits
Example #4
0
class VAE(CAE):
    """Convolutional variational autoencoder built on ``CAE``.

    The last fully connected encoder layer is split into a "_mu" and a
    "_sigma" head:

    - ``vtype == "gauss"``: a Normal distribution. The sigma head predicts a
      log-variance clipped to [-10, 5].
    - ``vtype == "vmf"``: a von Mises-Fisher distribution with an
      L2-normalised mean and a concentration of ``softplus(.) + 200``.

    The central states are a sample from that distribution.
    """

    def __init__(
            self,
            vtype,
            output_low_bound,
            output_up_bound,
            # relu bounds
            nonlinear_low_bound,
            nonlinear_up_bound,
            # conv layers
            conv_filter_sizes=[3, 3],  #[[3,3], [3,3], [3,3], [3,3], [3,3]], 
            conv_strides=[1, 1],  #[[1,1], [1,1], [1,1], [1,1], [1,1]],
            conv_padding="SAME",  #["SAME", "SAME", "SAME", "SAME", "SAME"],
            conv_channel_sizes=[128, 128, 128, 64, 64, 64,
                                3],  # [128, 128, 128, 128, 1]
            conv_leaky_ratio=[0.4, 0.4, 0.4, 0.2, 0.2, 0.2, 0.1],
            # deconv layers
            decv_filter_sizes=[3, 3],  #[[3,3], [3,3], [3,3], [3,3], [3,3]], 
            decv_strides=[1, 1],  #[[1,1], [1,1], [1,1], [1,1], [1,1]],
            decv_padding="SAME",  #["SAME", "SAME", "SAME", "SAME", "SAME"],
            decv_channel_sizes=[3, 64, 64, 64, 128, 128,
                                128],  # [1, 128, 128, 128, 128]
            decv_leaky_ratio=[0.1, 0.2, 0.2, 0.2, 0.4, 0.4, 0.01],
            # encoder fc layers
            enfc_state_sizes=[4096],
            enfc_leaky_ratio=[0.2, 0.2],
            enfc_drop_rate=[0, 0.75],
            # bottleneck
            central_state_size=2048,
            # decoder fc layers
            defc_state_sizes=[4096],
            defc_leaky_ratio=[0.2, 0.2],
            defc_drop_rate=[0.75, 0],
            # img channel
            img_channel=None,
            # switch
            use_norm=None):
        """
        :param vtype: latent distribution, either "gauss" or "vmf"

        All remaining arguments are forwarded positionally to ``CAE.__init__``.
        """
        # Set before the parent initializer runs, in case the parent already
        # calls methods that read it.
        self.vtype = vtype
        super().__init__(
            output_low_bound, output_up_bound, nonlinear_low_bound,
            nonlinear_up_bound, conv_filter_sizes, conv_strides, conv_padding,
            conv_channel_sizes, conv_leaky_ratio, decv_filter_sizes,
            decv_strides, decv_padding, decv_channel_sizes, decv_leaky_ratio,
            enfc_state_sizes, enfc_leaky_ratio, enfc_drop_rate,
            central_state_size, defc_state_sizes, defc_leaky_ratio,
            defc_drop_rate, img_channel, use_norm)

    @lazy_method
    def enfc_weights_biases(self):
        """Create the encoder fc weights; the last layer is split per `vtype`."""
        in_size = self.conv_out_shape[0] * self.conv_out_shape[
            1] * self.conv_out_shape[2]
        state_sizes = self.enfc_state_sizes + [self.central_state_size]
        return self._fc_weights_biases("W_enfc",
                                       "b_enfc",
                                       in_size,
                                       state_sizes,
                                       sampling=True)

    def _fc_weights_biases(self,
                           W_name,
                           b_name,
                           in_size,
                           state_sizes,
                           sampling=False):
        """Create weight/bias variables for a stack of fully connected layers.

        Variables are keyed "<W_name><idx><postfix>" / "<b_name><idx><postfix>".
        With `sampling`, the last layer is split into "_mu" and "_sigma"
        variants. For "gauss" both are `state_sizes[-1]` wide. For "vmf",
        "_mu" is `state_sizes[-1]` wide and "_sigma" has 1 unit.

        :return: tuple `(weights_dict, biases_dict, num_layer)`
        :raises NotImplementedError: if `sampling` and `vtype` is unsupported
        """
        num_layer = len(state_sizes)
        _weights = {}
        _biases = {}

        def _func(in_size, out_size, idx, postfix=""):
            W_key = "{}{}{}".format(W_name, idx, postfix)
            W_shape = [in_size, out_size]
            _weights[W_key] = ne.weight_variable(W_shape, name=W_key)

            b_key = "{}{}{}".format(b_name, idx, postfix)
            b_shape = [out_size]
            _biases[b_key] = ne.bias_variable(b_shape, name=b_key)

            # tensorboard
            tf.summary.histogram("Weight_" + W_key, _weights[W_key])
            tf.summary.histogram("Bias_" + b_key, _biases[b_key])

            # The next layer's input size is this layer's output size.
            return out_size

        for idx in range(num_layer - 1):
            in_size = _func(in_size, state_sizes[idx], idx)

        # Last layer
        last_idx = num_layer - 1
        if sampling:
            if self.vtype == "gauss":
                for postfix in ["_mu", "_sigma"]:
                    _func(in_size, state_sizes[last_idx], last_idx, postfix)
            elif self.vtype == "vmf":
                _func(in_size, state_sizes[last_idx], last_idx, "_mu")
                _func(in_size, 1, last_idx, "_sigma")
            else:
                raise NotImplementedError(
                    "Unsupported vtype: {!r}".format(self.vtype))
        else:
            _func(in_size, state_sizes[last_idx], last_idx)

        return _weights, _biases, num_layer

    @lazy_method
    def enfc_layers(self, inputs, W_name="W_enfc", b_name="b_enfc"):
        """Apply the encoder's fully connected layers to the conv features.

        :return: `(net_mu, net_sigma)`, named "output_mu" / "output_sigma".
            For "gauss", `net_sigma` is the standard deviation obtained from a
            log-variance clipped to [-10, 5]. That log-variance is also stored
            as `self.central_log_sigma_sq`. For "vmf", `net_mu` is
            L2-normalised and `net_sigma` is `softplus(.) + 200`.
        :raises NotImplementedError: for any other `vtype`
        """
        net = tf.reshape(inputs, [
            -1, self.conv_out_shape[0] * self.conv_out_shape[1] *
            self.conv_out_shape[2]
        ])

        def _func(net, layer_id, postfix="", act_func="leaky"):
            weight_name = "{}{}{}".format(W_name, layer_id, postfix)
            bias_name = "{}{}{}".format(b_name, layer_id, postfix)
            curr_weight = self.enfc_weights[weight_name]
            curr_bias = self.enfc_biases[bias_name]
            net = ne.fully_conn(net, weights=curr_weight, biases=curr_bias)
            # normalization
            if self.use_norm == "BATCH":
                net = ne.batch_norm(net, self.is_training, axis=1)
            elif self.use_norm == "LAYER":
                net = ne.layer_norm(net, self.is_training)
            # activation (act_func=None leaves the output linear)
            if act_func == "leaky":
                net = ne.leaky_relu(net, self.enfc_leaky_ratio[layer_id])
            elif act_func == "soft":
                net = tf.nn.softplus(net)
            return net

        for layer_id in range(self.num_enfc - 1):
            net = _func(net, layer_id)

        # Last layer
        last_id = self.num_enfc - 1
        if self.vtype == "gauss":
            # mean and clipped log-variance of the normal distribution
            net_mu = _func(net, last_id, "_mu")
            net_log_sigma_sq = tf.minimum(
                tf.maximum(-10.0, _func(net, last_id, "_sigma")), 5.0)
            # Stored so `gauss_kl_distance` can use the log-variance directly.
            self.central_log_sigma_sq = net_log_sigma_sq
            net_sigma = tf.sqrt(tf.exp(net_log_sigma_sq))
        elif self.vtype == "vmf":
            # unit-norm mean direction of the von Mises-Fisher
            net_mu = _func(net, last_id, "_mu", None)
            net_mu = tf.nn.l2_normalize(net_mu, axis=-1)
            # softplus keeps the concentration non-negative; the +200 offset
            # sets its minimum value.
            net_sigma = _func(net, last_id, "_sigma", "soft") + 200.0
        else:
            raise NotImplementedError(
                "Unsupported vtype: {!r}".format(self.vtype))

        net_mu = tf.identity(net_mu, name="output_mu")
        net_sigma = tf.identity(net_sigma, name="output_sigma")
        return net_mu, net_sigma

    @lazy_method
    def encoder(self, inputs):
        """Encode `inputs` and return a sample from the latent distribution.

        Side effects: sets `central_mu`, `central_sigma`,
        `central_distribution` and `central_states`.
        """
        conv = self.conv_layers(inputs)
        assert conv.get_shape().as_list()[1:] == self.conv_out_shape
        self.central_mu, self.central_sigma = self.enfc_layers(conv)
        if self.vtype == "gauss":
            assert self.central_mu.get_shape().as_list()[1:] == [
                self.central_state_size
            ]
            self.central_distribution = tf.distributions.Normal(
                self.central_mu, self.central_sigma)
        elif self.vtype == "vmf":
            assert self.central_sigma.get_shape().as_list()[1:] == [1]
            self.central_distribution = VonMisesFisher(self.central_mu,
                                                       self.central_sigma)
        self.central_states = self.central_distribution.sample()
        return self.central_states

    @lazy_method
    def kl_distance(self):
        """Batch-mean KL divergence between the posterior and its prior.

        The prior is a standard Normal for "gauss" (the KL is summed over
        latent units) and a hyperspherical uniform for "vmf".

        :raises NotImplementedError: for any other `vtype`
        """
        if self.vtype == "gauss":
            self.prior = tf.distributions.Normal(
                tf.zeros(self.central_state_size),
                tf.ones(self.central_state_size))
            self.kl = self.central_distribution.kl_divergence(self.prior)
            loss_kl = tf.reduce_mean(tf.reduce_sum(self.kl, axis=1))
        elif self.vtype == 'vmf':
            self.prior = HypersphericalUniform(self.central_state_size - 1,
                                               dtype=tf.float32)
            self.kl = self.central_distribution.kl_divergence(self.prior)
            loss_kl = tf.reduce_mean(self.kl)
        else:
            raise NotImplementedError(
                "Unsupported vtype: {!r}".format(self.vtype))
        return loss_kl

    @lazy_method
    def gauss_kl_distance(self):
        """Closed-form per-sample KL(N(mu, sigma^2) || N(0, I)) for "gauss".

        Requires `self.central_log_sigma_sq`, which `enfc_layers` sets in the
        "gauss" branch.
        """
        loss = -0.5 * tf.reduce_sum(
            1 + self.central_log_sigma_sq - tf.square(self.central_mu) -
            tf.exp(self.central_log_sigma_sq), 1)

        return loss

    def _make_autoencoder_saver(self):
        """Return a Saver over the global variables in the 'autoencoder' scope."""
        return tf.train.Saver(var_list=tf.get_collection(
            tf.GraphKeys.GLOBAL_VARIABLES, scope='autoencoder'))

    def tf_load(self, sess, path, name='deep_vcae.ckpt', spec=""):
        """Restore the 'autoencoder' variables from `path/name+spec`."""
        saver = self._make_autoencoder_saver()
        saver.restore(sess, path + '/' + name + spec)

    def tf_save(self, sess, path, name='deep_vcae.ckpt', spec=""):
        """Save the 'autoencoder' variables to `path/name+spec`."""
        saver = self._make_autoencoder_saver()
        saver.save(sess, path + '/' + name + spec)