Example #1
    def __init__(self, latent_spec, **kwargs):
        """
        Args:
            latent_spec (list): List of latent distributions.
             [(Distribution, bool)]
             The boolean indicates if the distribution should be used for
             regularization.
        """

        self.latent_spec = latent_spec
        self.latent_dist = Product([x for x, _ in latent_spec])
        self.reg_latent_dist = Product([x for x, reg in latent_spec if reg])
        self.nonreg_latent_dist = Product(
            [x for x, reg in latent_spec if not reg])
        assert all(
            isinstance(x, (Gaussian, Categorical, Bernoulli))
            for x in self.reg_latent_dist.dists)

        self.reg_cont_latent_dist = Product(
            [x for x in self.reg_latent_dist.dists if isinstance(x, Gaussian)])
        self.reg_disc_latent_dist = Product([
            x for x in self.reg_latent_dist.dists
            if isinstance(x, (Categorical, Bernoulli))
        ])

        super(RegularizedGAN, self).__init__(**kwargs)
        d = {
            'latent_code_influence': {
                'tensor': 'get_latent_code_influence_g_input_tensor',
            },
            'linear_interpolation': {
                'tensor': 'get_linear_interpolation_g_input_tensor',
            },
        }
        self.sampling_functions.update(d)
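
A minimal construction sketch (not part of the example): latent_spec pairs each latent Distribution with a flag marking whether it is regularized. The Uniform and Categorical classes are assumed to come from the surrounding codebase (infogan.misc.distributions in the original InfoGAN repository), where Uniform subclasses Gaussian and therefore passes the assert above.

# Hedged sketch; class names and constructor arguments are assumptions
# borrowed from the original InfoGAN MNIST configuration.
from infogan.misc.distributions import Uniform, Categorical

latent_spec = [
    (Uniform(62), False),              # 62 noise dimensions, not regularized
    (Categorical(10), True),           # 10-way discrete code, regularized
    (Uniform(1, fix_std=True), True),  # continuous code, regularized
    (Uniform(1, fix_std=True), True),  # continuous code, regularized
]
model = RegularizedGAN(latent_spec=latent_spec)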
Example #2
    def __init__(self, output_dist, latent_spec, batch_size, image_shape, network_type):
        """
        :type output_dist: Distribution
        :type latent_spec: list[(Distribution, bool)]
        :type batch_size: int
        :type network_type: string
        """
        self.output_dist = output_dist
        self.latent_spec = latent_spec
        self.latent_dist = Product([x for x, _ in latent_spec])
        self.reg_latent_dist = Product([x for x, reg in latent_spec if reg])
        self.nonreg_latent_dist = Product([x for x, reg in latent_spec if not reg])
        self.batch_size = batch_size
        self.network_type = network_type
        self.image_shape = image_shape
        assert all(isinstance(x, (Gaussian, Categorical, Bernoulli)) for x in self.reg_latent_dist.dists)

        self.reg_cont_latent_dist = Product([x for x in self.reg_latent_dist.dists if isinstance(x, Gaussian)])
        self.reg_disc_latent_dist = Product([x for x in self.reg_latent_dist.dists if isinstance(x, (Categorical, Bernoulli))])

        image_size = image_shape[0]
        if network_type == "mnist":
            with tf.variable_scope("d_net"):
                shared_template = \
                    (pt.template("input").
                     reshape([-1] + list(image_shape)).
                     custom_conv2d(64, k_h=4, k_w=4).
                     apply(leaky_rectify).
                     custom_conv2d(128, k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(leaky_rectify).
                     custom_fully_connected(1024).
                     fc_batch_norm().
                     apply(leaky_rectify))
                self.discriminator_template = shared_template.custom_fully_connected(1)
                self.encoder_template = \
                    (shared_template.
                     custom_fully_connected(128).
                     fc_batch_norm().
                     apply(leaky_rectify).
                     custom_fully_connected(self.reg_latent_dist.dist_flat_dim))

            with tf.variable_scope("g_net"):
                self.generator_template = \
                    (pt.template("input").
                     custom_fully_connected(1024).
                     fc_batch_norm().
                     apply(tf.nn.relu).
                     custom_fully_connected(image_size // 4 * image_size // 4 * 128).
                     fc_batch_norm().
                     apply(tf.nn.relu).
                     reshape([-1, image_size // 4, image_size // 4, 128]).
                     custom_deconv2d([0, image_size // 2, image_size // 2, 64], k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(tf.nn.relu).
                     custom_deconv2d([0] + list(image_shape), k_h=4, k_w=4).
                     flatten())
        else:
            raise NotImplementedError
Example #3
    def __init__(self, output_dist, latent_spec, batch_size, image_shape, network_type):
        """
        :type output_dist: Distribution
        :type latent_spec: list[(Distribution, bool)]
        :type batch_size: int
        :type network_type: string
        """
        self.output_dist = output_dist
        self.latent_spec = latent_spec
        self.latent_dist = Product([x for x, _ in latent_spec])
        self.reg_latent_dist = Product([x for x, reg in latent_spec if reg])
        self.nonreg_latent_dist = Product([x for x, reg in latent_spec if not reg])
        self.batch_size = batch_size
        self.network_type = network_type
        self.image_shape = image_shape
        assert all(isinstance(x, (Gaussian, Categorical, Bernoulli)) for x in self.reg_latent_dist.dists)

        self.reg_cont_latent_dist = Product([x for x in self.reg_latent_dist.dists if isinstance(x, Gaussian)])
        self.reg_disc_latent_dist = Product([x for x in self.reg_latent_dist.dists if isinstance(x, (Categorical, Bernoulli))])

        image_size = image_shape[0]
        if network_type == "mnist":
            with tf.variable_scope("d_net"):
                shared_template = \
                    (pt.template("input").
                     reshape([-1] + list(image_shape)).
                     custom_conv2d(64, k_h=4, k_w=4).
                     apply(leaky_rectify).
                     custom_conv2d(128, k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(leaky_rectify).
                     custom_fully_connected(1024).
                     fc_batch_norm().
                     apply(leaky_rectify))
                self.discriminator_template = shared_template.custom_fully_connected(1)
                self.encoder_template = \
                    (shared_template.
                     custom_fully_connected(128).
                     fc_batch_norm().
                     apply(leaky_rectify).
                     custom_fully_connected(self.reg_latent_dist.dist_flat_dim))

            with tf.variable_scope("g_net"):
                self.generator_template = \
                    (pt.template("input").
                     custom_fully_connected(1024).
                     fc_batch_norm().
                     apply(tf.nn.relu).
                     custom_fully_connected(image_size // 4 * image_size // 4 * 128).
                     fc_batch_norm().
                     apply(tf.nn.relu).
                     reshape([-1, image_size // 4, image_size // 4, 128]).
                     custom_deconv2d([0, image_size // 2, image_size // 2, 64], k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(tf.nn.relu).
                     custom_deconv2d([0] + list(image_shape), k_h=4, k_w=4).
                     flatten())
        else:
            raise NotImplementedError
Example #4
    def __init__(self,
                 output_dist,
                 latent_spec,
                 is_reg,
                 batch_size,
                 image_shape,
                 network_type,
                 impr=False):
        """
        :type output_dist: Distribution
        :type latent_spec: list[(Distribution, bool)]
        :type batch_size: int
        :type network_type: string
        """
        self.impr = impr
        self.output_dist = output_dist
        self.latent_spec = latent_spec
        self.is_reg = is_reg
        self.latent_dist = Product([x for x, _ in latent_spec])
        self.reg_latent_dist = Product([x for x, reg in latent_spec if reg])
        self.nonreg_latent_dist = Product(
            [x for x, reg in latent_spec if not reg])
        self.batch_size = batch_size
        self.network_type = network_type
        self.image_shape = image_shape
        self.keys = ['prob', 'logits', 'features']

        if self.is_reg:
            self.encoder_dim = self.reg_latent_dist.dist_flat_dim
            self.keys = self.keys + ['reg_dist_info']
        else:
            self.encoder_dim = None

        assert all(
            isinstance(x, (Gaussian, Categorical, Bernoulli))
            for x in self.reg_latent_dist.dists)

        self.reg_cont_latent_dist = Product(
            [x for x in self.reg_latent_dist.dists if isinstance(x, Gaussian)])
        self.reg_disc_latent_dist = Product([
            x for x in self.reg_latent_dist.dists
            if isinstance(x, (Categorical, Bernoulli))
        ])

        self.set_D_net()
        self.set_G_net()
Example #5
    def __init__(self, output_dist, latent_spec, batch_size, image_shape, network_type):
        """
        :type output_dist: Distribution
        :type latent_spec: list[(Distribution, bool)]
        :type batch_size: int
        :type network_type: string
        """
        self.output_dist = output_dist
        self.latent_spec = latent_spec
        self.latent_dist = Product([x for x, _ in latent_spec])
        self.reg_latent_dist = Product([x for x, reg in latent_spec if reg])
        self.nonreg_latent_dist = Product([x for x, reg in latent_spec if not reg])
        self.batch_size = batch_size
        self.network_type = network_type
        self.image_shape = image_shape
        assert all(isinstance(x, (Gaussian, Categorical, Bernoulli)) for x in self.reg_latent_dist.dists)

        self.reg_cont_latent_dist = Product([x for x in self.reg_latent_dist.dists if isinstance(x, Gaussian)])
        self.reg_disc_latent_dist = Product([x for x in self.reg_latent_dist.dists if isinstance(x, (Categorical, Bernoulli))])
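
Because this variant only sets up the latent bookkeeping, it is a convenient place to check the reg/nonreg partition. A hedged sketch, assuming the latent_spec from the earlier sketch and MeanBernoulli from the same distributions module:

# Sketch only; output_dist/batch_size/image_shape values are assumptions
# modeled on the original InfoGAN MNIST setup.
from infogan.misc.distributions import MeanBernoulli

gan = RegularizedGAN(
    output_dist=MeanBernoulli(28 * 28),
    latent_spec=latent_spec,
    batch_size=128,
    image_shape=(28, 28, 1),
    network_type="mnist",
)
# The regularized and unregularized Products partition the full latent code.
assert gan.reg_latent_dist.dim + gan.nonreg_latent_dist.dim == gan.latent_dist.dim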
Example #6
class RegularizedGAN(object):
    def __init__(self, output_dist, latent_spec, batch_size, image_shape,
                 network_type):
        """
        :type output_dist: Distribution
        :type latent_spec: list[(Distribution, bool)]
        :type batch_size: int
        :type network_type: string
        """
        self.output_dist = output_dist
        self.latent_spec = latent_spec
        self.latent_dist = Product([x for x, _ in latent_spec])
        self.reg_latent_dist = Product([x for x, reg in latent_spec if reg])
        self.nonreg_latent_dist = Product(
            [x for x, reg in latent_spec if not reg])
        self.batch_size = batch_size
        self.network_type = network_type
        self.image_shape = image_shape
        assert all(
            isinstance(x, (Gaussian, Categorical, Bernoulli))
            for x in self.reg_latent_dist.dists)

        self.reg_cont_latent_dist = Product(
            [x for x in self.reg_latent_dist.dists if isinstance(x, Gaussian)])
        self.reg_disc_latent_dist = Product([
            x for x in self.reg_latent_dist.dists
            if isinstance(x, (Categorical, Bernoulli))
        ])

        image_size = image_shape[0]
        if network_type == "mnist":
            with tf.variable_scope("d_net"):
                shared_template = \
                    (pt.template("input").
                     reshape([-1] + list(image_shape)).
                     custom_conv2d(64, k_h=4, k_w=4).
                     apply(leaky_rectify).
                     custom_conv2d(128, k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(leaky_rectify).
                     custom_fully_connected(1024).
                     fc_batch_norm().
                     apply(leaky_rectify))
                self.discriminator_template = shared_template.custom_fully_connected(
                    1)
                self.encoder_template = \
                    (shared_template.
                     custom_fully_connected(128).
                     fc_batch_norm().
                     apply(leaky_rectify).
                     custom_fully_connected(self.reg_latent_dist.dist_flat_dim))

            with tf.variable_scope("g_net"):
                self.generator_template = \
                    (pt.template("input").
                     custom_fully_connected(1024).
                     fc_batch_norm().
                     apply(tf.nn.relu).
                     custom_fully_connected(image_size // 4 * image_size // 4 * 128).
                     fc_batch_norm().
                     apply(tf.nn.relu).
                     reshape([-1, image_size // 4, image_size // 4, 128]).
                     custom_deconv2d([0, image_size // 2, image_size // 2, 64], k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(tf.nn.relu).
                     custom_deconv2d([0] + list(image_shape), k_h=4, k_w=4).
                     flatten())

        elif network_type == "celebA":
            with tf.variable_scope("d_net"):
                shared_template = \
                    (pt.template("input").
                     reshape([-1] + list(image_shape)).
                     custom_conv2d(64, k_h=4, k_w=4).
                     apply(leaky_rectify).
                     custom_conv2d(128, k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(leaky_rectify).
                     custom_conv2d(256, k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(leaky_rectify))
                self.discriminator_template = shared_template.custom_fully_connected(
                    1)
                self.encoder_template = \
                    (shared_template.
                     custom_fully_connected(128).
                     fc_batch_norm().
                     apply(leaky_rectify).
                     custom_fully_connected(self.reg_latent_dist.dist_flat_dim))

            with tf.variable_scope("g_net"):
                self.generator_template = \
                    (pt.template("input").
                     custom_fully_connected(image_size // 16 * image_size // 16 * 448).
                     fc_batch_norm().
                     apply(tf.nn.relu).
                     reshape([-1, image_size // 16, image_size // 16, 448]).
                     # I am *pretty sure* each of these dimensions grows by 2x
                     # because the stride==2.
                     custom_deconv2d([0, image_size // 8, image_size // 8, 256], k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(tf.nn.relu).
                     custom_deconv2d([0, image_size // 4, image_size // 4, 128], k_h=4, k_w=4).
                     apply(tf.nn.relu).
                     custom_deconv2d([0, image_size // 2, image_size // 2, 64], k_h=4, k_w=4).
                     apply(tf.nn.relu).
                     custom_deconv2d([0, image_size, image_size, 3], k_h=4, k_w=4).
                     apply(tf.nn.tanh).
                     flatten())

        elif network_type == "face":
            with tf.variable_scope("d_net"):
                shared_template = \
                    (pt.template("input").
                     reshape([-1] + list(image_shape)).
                     custom_conv2d(64, k_h=4, k_w=4).
                     apply(leaky_rectify).
                     custom_conv2d(128, k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(leaky_rectify).
                     custom_fully_connected(1024).
                     fc_batch_norm().
                     apply(leaky_rectify))
                self.discriminator_template = shared_template.custom_fully_connected(
                    1)
                self.encoder_template = \
                    (shared_template.
                     custom_fully_connected(self.reg_latent_dist.dist_flat_dim))

            with tf.variable_scope("g_net"):
                self.generator_template = \
                    (pt.template("input").
                     custom_fully_connected(1024).
                     fc_batch_norm().
                     apply(tf.nn.relu).
                     custom_fully_connected(image_size // 4 * image_size // 4 * 128).
                     fc_batch_norm().
                     apply(tf.nn.relu).
                     reshape([-1, image_size // 4, image_size // 4, 128]).
                     custom_deconv2d([0, image_size // 2, image_size // 2, 64], k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(tf.nn.relu).
                     custom_deconv2d([0] + list(image_shape), k_h=4, k_w=4).
                     apply(tf.nn.sigmoid).
                     flatten())
        else:
            raise NotImplementedError

    def discriminate(self, x_var):
        d_out = self.discriminator_template.construct(input=x_var)
        d = tf.nn.sigmoid(d_out[:, 0])
        reg_dist_flat = self.encoder_template.construct(input=x_var)
        reg_dist_info = self.reg_latent_dist.activate_dist(reg_dist_flat)
        return d, self.reg_latent_dist.sample(
            reg_dist_info), reg_dist_info, reg_dist_flat

    def generate(self, z_var):
        x_dist_flat = self.generator_template.construct(input=z_var)
        x_dist_info = self.output_dist.activate_dist(x_dist_flat)
        return self.output_dist.sample(x_dist_info), x_dist_info

    def disc_reg_z(self, reg_z_var):
        ret = []
        for dist_i, z_i in zip(self.reg_latent_dist.dists,
                               self.reg_latent_dist.split_var(reg_z_var)):
            if isinstance(dist_i, (Categorical, Bernoulli)):
                ret.append(z_i)
        return self.reg_disc_latent_dist.join_vars(ret)

    def cont_reg_z(self, reg_z_var):
        ret = []
        for dist_i, z_i in zip(self.reg_latent_dist.dists,
                               self.reg_latent_dist.split_var(reg_z_var)):
            if isinstance(dist_i, Gaussian):
                ret.append(z_i)
        return self.reg_cont_latent_dist.join_vars(ret)

    def disc_reg_dist_info(self, reg_dist_info):
        ret = []
        for dist_i, dist_info_i in zip(
                self.reg_latent_dist.dists,
                self.reg_latent_dist.split_dist_info(reg_dist_info)):
            if isinstance(dist_i, (Categorical, Bernoulli)):
                ret.append(dist_info_i)
        return self.reg_disc_latent_dist.join_dist_infos(ret)

    def cont_reg_dist_info(self, reg_dist_info):
        ret = []
        for dist_i, dist_info_i in zip(
                self.reg_latent_dist.dists,
                self.reg_latent_dist.split_dist_info(reg_dist_info)):
            if isinstance(dist_i, Gaussian):
                ret.append(dist_info_i)
        return self.reg_cont_latent_dist.join_dist_infos(ret)

    def reg_z(self, z_var):
        ret = []
        for (_, reg_i), z_i in zip(self.latent_spec,
                                   self.latent_dist.split_var(z_var)):
            if reg_i:
                ret.append(z_i)
        return self.reg_latent_dist.join_vars(ret)

    def nonreg_z(self, z_var):
        ret = []
        for (_, reg_i), z_i in zip(self.latent_spec,
                                   self.latent_dist.split_var(z_var)):
            if not reg_i:
                ret.append(z_i)
        return self.nonreg_latent_dist.join_vars(ret)

    def reg_dist_info(self, dist_info):
        ret = []
        for (_, reg_i), dist_info_i in zip(
                self.latent_spec, self.latent_dist.split_dist_info(dist_info)):
            if reg_i:
                ret.append(dist_info_i)
        return self.reg_latent_dist.join_dist_infos(ret)

    def nonreg_dist_info(self, dist_info):
        ret = []
        for (_, reg_i), dist_info_i in zip(
                self.latent_spec, self.latent_dist.split_dist_info(dist_info)):
            if not reg_i:
                ret.append(dist_info_i)
        return self.nonreg_latent_dist.join_dist_infos(ret)

    def combine_reg_nonreg_z(self, reg_z_var, nonreg_z_var):
        reg_z_vars = self.reg_latent_dist.split_var(reg_z_var)
        reg_idx = 0
        nonreg_z_vars = self.nonreg_latent_dist.split_var(nonreg_z_var)
        nonreg_idx = 0
        ret = []
        for idx, (dist_i, reg_i) in enumerate(self.latent_spec):
            if reg_i:
                ret.append(reg_z_vars[reg_idx])
                reg_idx += 1
            else:
                ret.append(nonreg_z_vars[nonreg_idx])
                nonreg_idx += 1
        return self.latent_dist.join_vars(ret)

    def combine_reg_nonreg_dist_info(self, reg_dist_info, nonreg_dist_info):
        reg_dist_infos = self.reg_latent_dist.split_dist_info(reg_dist_info)
        reg_idx = 0
        nonreg_dist_infos = self.nonreg_latent_dist.split_dist_info(
            nonreg_dist_info)
        nonreg_idx = 0
        ret = []
        for idx, (dist_i, reg_i) in enumerate(self.latent_spec):
            if reg_i:
                ret.append(reg_dist_infos[reg_idx])
                reg_idx += 1
            else:
                ret.append(nonreg_dist_infos[nonreg_idx])
                nonreg_idx += 1
        return self.latent_dist.join_dist_infos(ret)
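
The splitting helpers can be sanity-checked with a round trip. A sketch, assuming an instance gan of this class (constructed as in the sketch after Example #5), a TF1 graph, and that sample_prior is provided by the Distribution classes used here:

# Hedged sketch: reg_z/nonreg_z slice a sampled prior into its regularized
# and unregularized parts; combine_reg_nonreg_z restores the original order.
z_var = gan.latent_dist.sample_prior(gan.batch_size)
reg_z_var = gan.reg_z(z_var)
nonreg_z_var = gan.nonreg_z(z_var)
z_roundtrip = gan.combine_reg_nonreg_z(reg_z_var, nonreg_z_var)

x_sample, x_dist_info = gan.generate(z_roundtrip)
d, reg_z_sample, reg_dist_info, reg_dist_flat = gan.discriminate(x_sample)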
Example #7
class RegularizedGAN(object):
    def __init__(self, output_dist, latent_spec, batch_size, image_shape, network_type):
        """
        :type output_dist: Distribution
        :type latent_spec: list[(Distribution, bool)]
        :type batch_size: int
        :type network_type: string
        """
        self.output_dist = output_dist
        self.latent_spec = latent_spec
        self.latent_dist = Product([x for x, _ in latent_spec])
        self.reg_latent_dist = Product([x for x, reg in latent_spec if reg])
        self.nonreg_latent_dist = Product([x for x, reg in latent_spec if not reg])
        self.batch_size = batch_size
        self.network_type = network_type
        self.image_shape = image_shape
        assert all(isinstance(x, (Gaussian, Categorical, Bernoulli)) for x in self.reg_latent_dist.dists)

        self.reg_cont_latent_dist = Product([x for x in self.reg_latent_dist.dists if isinstance(x, Gaussian)])
        self.reg_disc_latent_dist = Product([x for x in self.reg_latent_dist.dists if isinstance(x, (Categorical, Bernoulli))])

    def custom_batch_norm(self, input_layer, epsilon=1e-5):
        # Normalize over the batch dimension only; no learned scale/offset.
        mean, variance = tf.nn.moments(input_layer, [0])
        return tf.nn.batch_normalization(input_layer, mean, variance, None, None, epsilon)

    def create_discriminator(self, inp):
        image_size = self.image_shape[0]
        if self.network_type == "mnist":
            with tf.variable_scope("d_net"):
                with tf.variable_scope("conv1") as scope:
                    inp_ = tf.reshape(inp,[-1]+list(self.image_shape))
                    kernel = tf.get_variable('weight',[4,4,self.image_shape[-1],64], \
                        initializer = tf.truncated_normal_initializer(stddev=5e-2))
                    conv = tf.nn.conv2d(inp_,kernel,[1,1,1,1], padding='SAME')
                    bias = tf.get_variable('bias',[64], initializer=tf.constant_initializer(0.1))
                    pre_act = tf.nn.bias_add(conv,bias)
                    # applying batch_norm
                    pre_act = self.custom_batch_norm(pre_act)
                    # leaky ReLu with alpha = 0.01
                    conv1 = tf.maximum(0.01*pre_act, pre_act, name=scope.name)

                pool1 = tf.nn.max_pool(conv1, ksize=[1,3,3,1], strides=[1,2,2,1], padding='SAME', name='pool1')

                with tf.variable_scope("conv2") as scope:
                    kernel = tf.get_variable('weight',[4,4,64,128], \
                        initializer = tf.truncated_normal_initializer(stddev=5e-2))
                    conv = tf.nn.conv2d(pool1,kernel,[1,1,1,1], padding='SAME')
                    bias = tf.get_variable('bias',[128], initializer=tf.constant_initializer(0.1))
                    pre_act = tf.nn.bias_add(conv,bias)
                    # applying batch_norm
                    pre_act = self.custom_batch_norm(pre_act)
                    # leaky ReLu with alpha = 0.01
                    conv2 = tf.maximum(0.01*pre_act, pre_act, name=scope.name)

                pool2 = tf.nn.max_pool(conv2, ksize=[1,3,3,1], strides=[1,2,2,1], padding='SAME', name='pool2')

                with tf.variable_scope("fc3") as scope:
                    inp_ = tf.reshape(pool2, [self.batch_size, -1])
                    dim = inp_.get_shape()[1].value
                    weights = tf.get_variable('weight',[dim,1024],\
                        initializer=tf.truncated_normal_initializer(stddev=4e-3))
                    bias = tf.get_variable('bias',[1024],initializer=tf.constant_initializer(0.1))
                    pre_act = tf.matmul(inp_,weights) + bias
                    # applying batch_norm
                    pre_act = self.custom_batch_norm(pre_act)
                    # leaky ReLu with alpha = 0.01
                    fc3 = tf.maximum(0.01*pre_act, pre_act, name=scope.name)

                with tf.variable_scope("fc4") as scope:
                    weights = tf.get_variable('weight',[1024,128],\
                        initializer=tf.truncated_normal_initializer(stddev=4e-3))
                    bias = tf.get_variable('bias',[128],initializer=tf.constant_initializer(0.1))
                    pre_act = tf.nn.bias_add(tf.matmul(fc3,weights),bias,name=scope.name)
                    # applying batch_norm
                    pre_act = self.custom_batch_norm(pre_act)
                    # leaky ReLu with alpha = 0.01
                    fc4 = tf.maximum(0.01*pre_act, pre_act, name=scope.name)

                with tf.variable_scope("d_temp") as scope:
                    weights = tf.get_variable('weight',[1024,1],\
                        initializer=tf.truncated_normal_initializer(stddev=4e-3))
                    bias = tf.get_variable('bias',[1],initializer=tf.constant_initializer(0.1))
                    d_temp = tf.nn.bias_add(tf.matmul(fc3,weights),bias,name=scope.name)
                    
                self.discriminator_template = d_temp

                with tf.variable_scope("enc_temp") as scope:
                    weights = tf.get_variable('weight',[128,self.reg_latent_dist.dist_flat_dim],\
                        initializer=tf.truncated_normal_initializer(stddev=4e-3))
                    bias = tf.get_variable('bias',[self.reg_latent_dist.dist_flat_dim],\
                        initializer=tf.constant_initializer(0.1))
                    enc_temp = tf.nn.bias_add(tf.matmul(fc4,weights),bias,name=scope.name)

                self.encoder_template = enc_temp
        else:
            raise NotImplementedError

    def create_generator(self, inp):
        image_size = self.image_shape[0]
        if self.network_type == "mnist":
            with tf.variable_scope("g_net"):
                with tf.variable_scope("fc1") as scope:
                    weight_shape = [inp.shape[-1].value,1024]
                    weights = tf.get_variable('weight', weight_shape,\
                        initializer=tf.truncated_normal_initializer(stddev=0.1))
                    bias = tf.get_variable('bias',[1024],initializer=tf.constant_initializer(0.1))
                    pre_act = tf.nn.bias_add(tf.matmul(inp,weights),bias)
                    # applying batch_norm
                    pre_act = self.custom_batch_norm(pre_act)
                    # applying activation
                    fc1 = tf.nn.relu(pre_act, name=scope.name)

                with tf.variable_scope("fc2") as scope:
                    weight_shape = [1024, image_size // 4 * image_size // 4 * 128]
                    weights = tf.get_variable('weight',weight_shape,initializer=tf.truncated_normal_initializer(stddev=0.1))
                    bias = tf.get_variable('bias',[weight_shape[-1]],initializer=tf.constant_initializer(0.1))
                    pre_act = tf.nn.bias_add(tf.matmul(fc1,weights),bias)
                    # applying batch_norm
                    pre_act = self.custom_batch_norm(pre_act)
                    # applying activation
                    fc2 = tf.nn.relu(pre_act, name=scope.name)
                    fc2_r = tf.reshape(fc2, [-1, image_size // 4, image_size // 4, 128])

                with tf.variable_scope("deconv3") as scope:
                    output_shape = [inp.shape[0].value, image_size // 2, image_size // 2, 64]
                    kernel = tf.get_variable('weight', [4, 4, 64, 128],\
                              initializer=tf.random_normal_initializer(stddev=0.02))
                    deconv = tf.nn.conv2d_transpose(fc2_r, kernel,\
                        output_shape=output_shape,strides=[1, 2, 2, 1])

                    bias = tf.get_variable('bias', [output_shape[-1]], initializer=tf.constant_initializer(0.0))
                    deconv = tf.reshape(tf.nn.bias_add(deconv, bias), output_shape)
                    # applying batch_norm
                    pre_act = self.custom_batch_norm(deconv)
                    # applying activation
                    deconv3 = tf.nn.relu(pre_act, name=scope.name)

                with tf.variable_scope("deconv4") as scope:
                    output_shape = [inp.shape[0].value, image_size, image_size, 1]
                    kernel = tf.get_variable('weight', [4, 4, 1, 64],\
                              initializer=tf.random_normal_initializer(stddev=0.02))
                    deconv = tf.nn.conv2d_transpose(deconv3, kernel,\
                        output_shape=output_shape,strides=[1, 2, 2, 1])

                    bias = tf.get_variable('bias', [output_shape[-1]], initializer=tf.constant_initializer(0.0))
                    deconv = tf.reshape(tf.nn.bias_add(deconv, bias), output_shape)
                    # applying batch_norm
                    pre_act = self.custom_batch_norm(deconv)
                    # applying activation
                    deconv4 = tf.nn.relu(pre_act, name=scope.name)

                # Flatten so generate() can feed this into
                # output_dist.activate_dist, as the template-based variants do.
                self.generator_template = tf.reshape(deconv4, [self.batch_size, -1])
        else:
            raise NotImplementedError


    def discriminate(self, x_var):
        self.create_discriminator(inp=x_var)
        d_out = self.discriminator_template
        d = tf.nn.sigmoid(d_out[:, 0])
        reg_dist_flat = self.encoder_template
        reg_dist_info = self.reg_latent_dist.activate_dist(reg_dist_flat)
        return d, self.reg_latent_dist.sample(reg_dist_info), reg_dist_info, reg_dist_flat

    def generate(self, z_var):
        self.create_generator(inp=z_var)
        x_dist_flat = self.generator_template
        x_dist_info = self.output_dist.activate_dist(x_dist_flat)
        return self.output_dist.sample(x_dist_info), x_dist_info

    def disc_reg_z(self, reg_z_var):
        ret = []
        for dist_i, z_i in zip(self.reg_latent_dist.dists, self.reg_latent_dist.split_var(reg_z_var)):
            if isinstance(dist_i, (Categorical, Bernoulli)):
                ret.append(z_i)
        return self.reg_disc_latent_dist.join_vars(ret)

    def cont_reg_z(self, reg_z_var):
        ret = []
        for dist_i, z_i in zip(self.reg_latent_dist.dists, self.reg_latent_dist.split_var(reg_z_var)):
            if isinstance(dist_i, Gaussian):
                ret.append(z_i)
        return self.reg_cont_latent_dist.join_vars(ret)

    def disc_reg_dist_info(self, reg_dist_info):
        ret = []
        for dist_i, dist_info_i in zip(self.reg_latent_dist.dists, self.reg_latent_dist.split_dist_info(reg_dist_info)):
            if isinstance(dist_i, (Categorical, Bernoulli)):
                ret.append(dist_info_i)
        return self.reg_disc_latent_dist.join_dist_infos(ret)

    def cont_reg_dist_info(self, reg_dist_info):
        ret = []
        for dist_i, dist_info_i in zip(self.reg_latent_dist.dists, self.reg_latent_dist.split_dist_info(reg_dist_info)):
            if isinstance(dist_i, Gaussian):
                ret.append(dist_info_i)
        return self.reg_cont_latent_dist.join_dist_infos(ret)

    def reg_z(self, z_var):
        ret = []
        for (_, reg_i), z_i in zip(self.latent_spec, self.latent_dist.split_var(z_var)):
            if reg_i:
                ret.append(z_i)
        return self.reg_latent_dist.join_vars(ret)

    def nonreg_z(self, z_var):
        ret = []
        for (_, reg_i), z_i in zip(self.latent_spec, self.latent_dist.split_var(z_var)):
            if not reg_i:
                ret.append(z_i)
        return self.nonreg_latent_dist.join_vars(ret)

    def reg_dist_info(self, dist_info):
        ret = []
        for (_, reg_i), dist_info_i in zip(self.latent_spec, self.latent_dist.split_dist_info(dist_info)):
            if reg_i:
                ret.append(dist_info_i)
        return self.reg_latent_dist.join_dist_infos(ret)

    def nonreg_dist_info(self, dist_info):
        ret = []
        for (_, reg_i), dist_info_i in zip(self.latent_spec, self.latent_dist.split_dist_info(dist_info)):
            if not reg_i:
                ret.append(dist_info_i)
        return self.nonreg_latent_dist.join_dist_infos(ret)

    def combine_reg_nonreg_z(self, reg_z_var, nonreg_z_var):
        reg_z_vars = self.reg_latent_dist.split_var(reg_z_var)
        reg_idx = 0
        nonreg_z_vars = self.nonreg_latent_dist.split_var(nonreg_z_var)
        nonreg_idx = 0
        ret = []
        for idx, (dist_i, reg_i) in enumerate(self.latent_spec):
            if reg_i:
                ret.append(reg_z_vars[reg_idx])
                reg_idx += 1
            else:
                ret.append(nonreg_z_vars[nonreg_idx])
                nonreg_idx += 1
        return self.latent_dist.join_vars(ret)

    def combine_reg_nonreg_dist_info(self, reg_dist_info, nonreg_dist_info):
        reg_dist_infos = self.reg_latent_dist.split_dist_info(reg_dist_info)
        reg_idx = 0
        nonreg_dist_infos = self.nonreg_latent_dist.split_dist_info(nonreg_dist_info)
        nonreg_idx = 0
        ret = []
        for idx, (dist_i, reg_i) in enumerate(self.latent_spec):
            if reg_i:
                ret.append(reg_dist_infos[reg_idx])
                reg_idx += 1
            else:
                ret.append(nonreg_dist_infos[nonreg_idx])
                nonreg_idx += 1
        return self.latent_dist.join_dist_infos(ret)
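
One caveat with this hand-rolled variant (a TF1 observation, not something the example handles itself): discriminate() rebuilds the network through create_discriminator() on every call, so a second call on the same graph would raise inside tf.get_variable. A sketch of the usual workaround; real_images and fake_images are placeholder tensors:

# Hedged sketch: tf.AUTO_REUSE (TF >= 1.4) lets repeated calls share the
# variables created on the first call; reuse propagates to nested scopes.
with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
    d_real, _, _, _ = gan.discriminate(real_images)
    d_fake, _, _, _ = gan.discriminate(fake_images)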
Example #8
class RegularizedGAN(object):
    def __init__(self, output_dist, latent_spec, is_reg, batch_size,
                 image_shape, network_type):
        """
        :type output_dist: Distribution
        :type latent_spec: list[(Distribution, bool)]
        :type batch_size: int
        :type network_type: string
        """
        self.output_dist = output_dist
        self.latent_spec = latent_spec
        self.is_reg = is_reg
        self.latent_dist = Product([x for x, _ in latent_spec])
        self.reg_latent_dist = Product([x for x, reg in latent_spec if reg])
        self.nonreg_latent_dist = Product(
            [x for x, reg in latent_spec if not reg])
        self.batch_size = batch_size
        self.network_type = network_type
        self.image_shape = image_shape
        self.keys = ['prob', 'logits', 'features']

        if self.is_reg:
            self.encoder_dim = self.reg_latent_dist.dist_flat_dim
            self.keys = self.keys + ['reg_dist_info']
        else:
            self.encoder_dim = None

        assert all(
            isinstance(x, (Gaussian, Categorical, Bernoulli))
            for x in self.reg_latent_dist.dists)

        self.reg_cont_latent_dist = Product(
            [x for x in self.reg_latent_dist.dists if isinstance(x, Gaussian)])
        self.reg_disc_latent_dist = Product([
            x for x in self.reg_latent_dist.dists
            if isinstance(x, (Categorical, Bernoulli))
        ])

        self.set_D_net()
        self.set_G_net()

    def set_D_net(self):
        with tf.variable_scope("d_net"):
            if self.network_type == "mnist":
                self.D_model = D.InfoGAN_MNIST_net(
                    image_shape=self.image_shape,
                    is_reg=self.is_reg,
                    encoder_dim=self.encoder_dim)
                self.shared_template = self.D_model.shared_template
                self.discriminator_template = self.shared_template.custom_fully_connected(
                    1)
                self.encoder_template = self.D_model.encoder_template
            else:
                if self.network_type == 'dcgan':
                    self.D_model = D.dcgan_net(image_shape=self.image_shape,
                                               is_reg=self.is_reg,
                                               encoder_dim=self.encoder_dim)
                elif self.network_type == 'deeper_dcgan':
                    self.D_model = D.deeper_dcgan_net(
                        image_shape=self.image_shape,
                        is_reg=self.is_reg,
                        encoder_dim=self.encoder_dim)
                else:
                    raise NotImplementedError
                self.shared_template = self.D_model.shared_template
                self.discriminator_template = self.shared_template.custom_fully_connected(
                    1)
                self.encoder_template = self.D_model.encoder_template

    def set_G_net(self):
        with tf.variable_scope("g_net"):
            if self.network_type == 'mnist':
                self.gen_model = G.InfoGAN_mnist_net()
                self.generator_template = self.gen_model.infoGAN_mnist_net(
                    self.image_shape)
            else:
                if self.network_type == 'dcgan':
                    self.gen_model = G.dcgan_net()
                elif self.network_type == 'deeper_dcgan':
                    self.gen_model = G.deeper_dcgan_net()
                else:
                    raise NotImplementedError
                self.generator_template = self.gen_model.gen_net(
                    self.image_shape)

    def discriminate(self, x_var):
        d_features = self.shared_template.construct(input=x_var)
        d_logits = self.discriminator_template.construct(input=x_var)[:, 0]
        d_prob = tf.nn.sigmoid(d_logits)

        d_dict = dict.fromkeys(self.keys)

        d_dict['features'] = d_features
        d_dict['logits'] = d_logits
        d_dict['prob'] = d_prob

        if self.is_reg:
            reg_dist_flat = self.encoder_template.construct(input=x_var)
            reg_dist_info = self.reg_latent_dist.activate_dist(reg_dist_flat)
            d_dict['reg_dist_info'] = reg_dist_info

        return d_dict

    def generate(self, z_var):
        x_dist_flat = self.generator_template.construct(input=z_var)
        if self.network_type == "mnist":
            x_dist_info = self.output_dist.activate_dist(x_dist_flat)
            return self.output_dist.sample(x_dist_info), x_dist_info
        else:
            return x_dist_flat, None

    def disc_reg_z(self, reg_z_var):
        ret = []
        for dist_i, z_i in zip(self.reg_latent_dist.dists,
                               self.reg_latent_dist.split_var(reg_z_var)):
            if isinstance(dist_i, (Categorical, Bernoulli)):
                ret.append(z_i)
        return self.reg_disc_latent_dist.join_vars(ret)

    def cont_reg_z(self, reg_z_var):
        ret = []
        for dist_i, z_i in zip(self.reg_latent_dist.dists,
                               self.reg_latent_dist.split_var(reg_z_var)):
            if isinstance(dist_i, Gaussian):
                ret.append(z_i)
        return self.reg_cont_latent_dist.join_vars(ret)

    def disc_reg_dist_info(self, reg_dist_info):
        ret = []
        for dist_i, dist_info_i in zip(
                self.reg_latent_dist.dists,
                self.reg_latent_dist.split_dist_info(reg_dist_info)):
            if isinstance(dist_i, (Categorical, Bernoulli)):
                ret.append(dist_info_i)
        return self.reg_disc_latent_dist.join_dist_infos(ret)

    def cont_reg_dist_info(self, reg_dist_info):
        ret = []
        for dist_i, dist_info_i in zip(
                self.reg_latent_dist.dists,
                self.reg_latent_dist.split_dist_info(reg_dist_info)):
            if isinstance(dist_i, Gaussian):
                ret.append(dist_info_i)
        return self.reg_cont_latent_dist.join_dist_infos(ret)

    def reg_z(self, z_var):
        ret = []
        for (_, reg_i), z_i in zip(self.latent_spec,
                                   self.latent_dist.split_var(z_var)):
            if reg_i:
                ret.append(z_i)
        return self.reg_latent_dist.join_vars(ret)

    def nonreg_z(self, z_var):
        ret = []
        for (_, reg_i), z_i in zip(self.latent_spec,
                                   self.latent_dist.split_var(z_var)):
            if not reg_i:
                ret.append(z_i)
        return self.nonreg_latent_dist.join_vars(ret)

    def reg_dist_info(self, dist_info):
        ret = []
        for (_, reg_i), dist_info_i in zip(
                self.latent_spec, self.latent_dist.split_dist_info(dist_info)):
            if reg_i:
                ret.append(dist_info_i)
        return self.reg_latent_dist.join_dist_infos(ret)

    def nonreg_dist_info(self, dist_info):
        ret = []
        for (_, reg_i), dist_info_i in zip(
                self.latent_spec, self.latent_dist.split_dist_info(dist_info)):
            if not reg_i:
                ret.append(dist_info_i)
        return self.nonreg_latent_dist.join_dist_infos(ret)

    def combine_reg_nonreg_z(self, reg_z_var, nonreg_z_var):
        reg_z_vars = self.reg_latent_dist.split_var(reg_z_var)
        reg_idx = 0
        nonreg_z_vars = self.nonreg_latent_dist.split_var(nonreg_z_var)
        nonreg_idx = 0
        ret = []
        for idx, (dist_i, reg_i) in enumerate(self.latent_spec):
            if reg_i:
                ret.append(reg_z_vars[reg_idx])
                reg_idx += 1
            else:
                ret.append(nonreg_z_vars[nonreg_idx])
                nonreg_idx += 1
        return self.latent_dist.join_vars(ret)

    def combine_reg_nonreg_dist_info(self, reg_dist_info, nonreg_dist_info):
        reg_dist_infos = self.reg_latent_dist.split_dist_info(reg_dist_info)
        reg_idx = 0
        nonreg_dist_infos = self.nonreg_latent_dist.split_dist_info(
            nonreg_dist_info)
        nonreg_idx = 0
        ret = []
        for idx, (dist_i, reg_i) in enumerate(self.latent_spec):
            if reg_i:
                ret.append(reg_dist_infos[reg_idx])
                reg_idx += 1
            else:
                ret.append(nonreg_dist_infos[nonreg_idx])
                nonreg_idx += 1
        return self.latent_dist.join_dist_infos(ret)
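
Unlike the tuple-returning variants above, discriminate() here returns a dict keyed by self.keys. A consumption sketch; x_var is assumed to be an image batch tensor:

d_dict = gan.discriminate(x_var)
d_prob = d_dict['prob']          # sigmoid probabilities
d_logits = d_dict['logits']      # raw discriminator logits
d_features = d_dict['features']  # shared-template features
if gan.is_reg:
    reg_dist_info = d_dict['reg_dist_info']  # present only when is_reg=True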
Example #9
class RegularizedGAN(object):
    def __init__(self, output_dist, latent_spec, batch_size, image_shape, network_type):
        """
        :type output_dist: Distribution
        :type latent_spec: list[(Distribution, bool)]
        :type batch_size: int
        :type network_type: string
        """
        self.output_dist = output_dist
        self.latent_spec = latent_spec
        self.latent_dist = Product([x for x, _ in latent_spec])
        self.reg_latent_dist = Product([x for x, reg in latent_spec if reg])
        self.nonreg_latent_dist = Product([x for x, reg in latent_spec if not reg])
        self.batch_size = batch_size
        self.network_type = network_type
        self.image_shape = image_shape
        assert all(isinstance(x, (Gaussian, Categorical, Bernoulli)) for x in self.reg_latent_dist.dists)

        self.reg_cont_latent_dist = Product([x for x in self.reg_latent_dist.dists if isinstance(x, Gaussian)])
        self.reg_disc_latent_dist = Product([x for x in self.reg_latent_dist.dists if isinstance(x, (Categorical, Bernoulli))])

        image_size = image_shape[0]
        if network_type == "mnist":
            with tf.variable_scope("d_net"):
                shared_template = \
                    (pt.template("input").
                     reshape([-1] + list(image_shape)).
                     custom_conv2d(64, k_h=4, k_w=4).
                     apply(leaky_rectify).
                     custom_conv2d(128, k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(leaky_rectify).
                     custom_fully_connected(1024).
                     fc_batch_norm().
                     apply(leaky_rectify))
                self.discriminator_template = shared_template.custom_fully_connected(1)
                self.encoder_template = \
                    (shared_template.
                     custom_fully_connected(128).
                     fc_batch_norm().
                     apply(leaky_rectify).
                     custom_fully_connected(self.reg_latent_dist.dist_flat_dim))

            with tf.variable_scope("g_net"):
                self.generator_template = \
                    (pt.template("input").
                     custom_fully_connected(1024).
                     fc_batch_norm().
                     apply(tf.nn.relu).
                     custom_fully_connected(image_size // 4 * image_size // 4 * 128).
                     fc_batch_norm().
                     apply(tf.nn.relu).
                     reshape([-1, image_size // 4, image_size // 4, 128]).
                     custom_deconv2d([0, image_size // 2, image_size // 2, 64], k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(tf.nn.relu).
                     custom_deconv2d([0] + list(image_shape), k_h=4, k_w=4).
                     flatten())
        else:
            raise NotImplementedError

    def discriminate(self, x_var):
        d_out = self.discriminator_template.construct(input=x_var)
        d = tf.nn.sigmoid(d_out[:, 0])
        reg_dist_flat = self.encoder_template.construct(input=x_var)
        reg_dist_info = self.reg_latent_dist.activate_dist(reg_dist_flat)
        return d, self.reg_latent_dist.sample(reg_dist_info), reg_dist_info, reg_dist_flat

    def generate(self, z_var):
        x_dist_flat = self.generator_template.construct(input=z_var)
        x_dist_info = self.output_dist.activate_dist(x_dist_flat)
        return self.output_dist.sample(x_dist_info), x_dist_info

    def disc_reg_z(self, reg_z_var):
        ret = []
        for dist_i, z_i in zip(self.reg_latent_dist.dists, self.reg_latent_dist.split_var(reg_z_var)):
            if isinstance(dist_i, (Categorical, Bernoulli)):
                ret.append(z_i)
        return self.reg_disc_latent_dist.join_vars(ret)

    def cont_reg_z(self, reg_z_var):
        ret = []
        for dist_i, z_i in zip(self.reg_latent_dist.dists, self.reg_latent_dist.split_var(reg_z_var)):
            if isinstance(dist_i, Gaussian):
                ret.append(z_i)
        return self.reg_cont_latent_dist.join_vars(ret)

    def disc_reg_dist_info(self, reg_dist_info):
        ret = []
        for dist_i, dist_info_i in zip(self.reg_latent_dist.dists, self.reg_latent_dist.split_dist_info(reg_dist_info)):
            if isinstance(dist_i, (Categorical, Bernoulli)):
                ret.append(dist_info_i)
        return self.reg_disc_latent_dist.join_dist_infos(ret)

    def cont_reg_dist_info(self, reg_dist_info):
        ret = []
        for dist_i, dist_info_i in zip(self.reg_latent_dist.dists, self.reg_latent_dist.split_dist_info(reg_dist_info)):
            if isinstance(dist_i, Gaussian):
                ret.append(dist_info_i)
        return self.reg_cont_latent_dist.join_dist_infos(ret)

    def reg_z(self, z_var):
        ret = []
        for (_, reg_i), z_i in zip(self.latent_spec, self.latent_dist.split_var(z_var)):
            if reg_i:
                ret.append(z_i)
        return self.reg_latent_dist.join_vars(ret)

    def nonreg_z(self, z_var):
        ret = []
        for (_, reg_i), z_i in zip(self.latent_spec, self.latent_dist.split_var(z_var)):
            if not reg_i:
                ret.append(z_i)
        return self.nonreg_latent_dist.join_vars(ret)

    def reg_dist_info(self, dist_info):
        ret = []
        for (_, reg_i), dist_info_i in zip(self.latent_spec, self.latent_dist.split_dist_info(dist_info)):
            if reg_i:
                ret.append(dist_info_i)
        return self.reg_latent_dist.join_dist_infos(ret)

    def nonreg_dist_info(self, dist_info):
        ret = []
        for (_, reg_i), dist_info_i in zip(self.latent_spec, self.latent_dist.split_dist_info(dist_info)):
            if not reg_i:
                ret.append(dist_info_i)
        return self.nonreg_latent_dist.join_dist_infos(ret)

    def combine_reg_nonreg_z(self, reg_z_var, nonreg_z_var):
        reg_z_vars = self.reg_latent_dist.split_var(reg_z_var)
        reg_idx = 0
        nonreg_z_vars = self.nonreg_latent_dist.split_var(nonreg_z_var)
        nonreg_idx = 0
        ret = []
        for idx, (dist_i, reg_i) in enumerate(self.latent_spec):
            if reg_i:
                ret.append(reg_z_vars[reg_idx])
                reg_idx += 1
            else:
                ret.append(nonreg_z_vars[nonreg_idx])
                nonreg_idx += 1
        return self.latent_dist.join_vars(ret)

    def combine_reg_nonreg_dist_info(self, reg_dist_info, nonreg_dist_info):
        reg_dist_infos = self.reg_latent_dist.split_dist_info(reg_dist_info)
        reg_idx = 0
        nonreg_dist_infos = self.nonreg_latent_dist.split_dist_info(nonreg_dist_info)
        nonreg_idx = 0
        ret = []
        for idx, (dist_i, reg_i) in enumerate(self.latent_spec):
            if reg_i:
                ret.append(reg_dist_infos[reg_idx])
                reg_idx += 1
            else:
                ret.append(nonreg_dist_infos[nonreg_idx])
                nonreg_idx += 1
        return self.latent_dist.join_dist_infos(ret)
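
The dist-info splitters mirror the z splitters and matter for the InfoGAN mutual-information terms, which treat discrete and continuous codes differently. A hedged sketch, assuming an instance gan of this class and a generated batch x_sample as in the earlier round-trip sketch:

# Sketch: split the encoder's dist_info into its discrete
# (Categorical/Bernoulli) and continuous (Gaussian) parts.
_, _, reg_dist_info, _ = gan.discriminate(x_sample)
disc_info = gan.disc_reg_dist_info(reg_dist_info)
cont_info = gan.cont_reg_dist_info(reg_dist_info)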
Example #10
class RegularizedGAN(object):
    def __init__(self, output_dist, latent_spec, batch_size, image_shape,
                 network_type):
        """
        :type output_dist: Distribution
        :type latent_spec: list[(Distribution, bool)]
        :type batch_size: int
        :type network_type: string
        """
        self.output_dist = output_dist
        self.latent_spec = latent_spec
        self.latent_dist = Product([x for x, _ in latent_spec])
        self.reg_latent_dist = Product([x for x, reg in latent_spec if reg])
        self.nonreg_latent_dist = Product(
            [x for x, reg in latent_spec if not reg])
        self.batch_size = batch_size
        self.network_type = network_type
        self.image_shape = image_shape
        assert all(
            isinstance(x, (Gaussian, Categorical, Bernoulli))
            for x in self.reg_latent_dist.dists)

        self.reg_cont_latent_dist = Product(
            [x for x in self.reg_latent_dist.dists if isinstance(x, Gaussian)])
        self.reg_disc_latent_dist = Product([
            x for x in self.reg_latent_dist.dists
            if isinstance(x, (Categorical, Bernoulli))
        ])

        image_size = image_shape[0]
        if network_type == "mnist":
            with tf.variable_scope("d_net"):
                shared_template = \
                    (pt.template("input").
                     reshape([-1] + list(image_shape)).
                     custom_conv2d(64, k_h=4, k_w=4).
                     apply(leaky_rectify).
                     custom_conv2d(128, k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(leaky_rectify).
                     custom_fully_connected(1024).
                     fc_batch_norm().
                     apply(leaky_rectify))
                self.discriminator_template = shared_template.custom_fully_connected(
                    1)

                self.encoder_template = \
                    (shared_template.
                     custom_fully_connected(128).
                     fc_batch_norm().
                     apply(leaky_rectify).
                     custom_fully_connected(self.reg_latent_dist.dist_flat_dim))

            with tf.variable_scope("g_net"):
                #SOMEWHAT CONSISTENT. MIGHT CHANGE
                self.generator_template = \
                    (pt.template("input").
                     custom_fully_connected(1024).
                     fc_batch_norm().
                     apply(tf.nn.relu).
                     custom_fully_connected(image_size // 4 * image_size // 4 * 128).
                     fc_batch_norm().
                     apply(tf.nn.relu).
                     reshape([-1, image_size // 4, image_size // 4, 128]).
                     custom_deconv2d([0, image_size // 2, image_size // 2, 64], k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(tf.nn.relu).
                     custom_deconv2d([0] + list(image_shape), k_h=4, k_w=4).
                     flatten())

        #HEART!!!
        elif network_type == 'heart':
            with tf.variable_scope("d_net"):
                shared_template = \
                    (pt.template("input").
                     reshape([-1] + list(image_shape)).
                     custom_conv2d(64, k_h=4, k_w=4).
                     apply(leaky_rectify).
                     custom_conv2d(128, k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(leaky_rectify).
                     custom_fully_connected(1024).
                     fc_batch_norm().
                     apply(leaky_rectify))
                self.discriminator_template = shared_template.custom_fully_connected(
                    1)
                #THIS ENCODER DOESN'T SEEM CONSISTENT WITH FACES. THAT'S OKAY. WILL
                #TRY ANYWAY.
                self.encoder_template = \
                    (shared_template.
                     custom_fully_connected(128).
                     fc_batch_norm().
                     apply(leaky_rectify).
                     custom_fully_connected(self.reg_latent_dist.dist_flat_dim))

            with tf.variable_scope("g_net"):
                self.generator_template = \
                    (pt.template("input").
                     custom_fully_connected(1024).
                     fc_batch_norm().
                     apply(tf.nn.relu).
                     custom_fully_connected(image_size // 4 * image_size // 4 * 128).
                     fc_batch_norm().
                     apply(tf.nn.relu).
                     reshape([-1, image_size // 4, image_size // 4, 128]).
                     custom_deconv2d([0, image_size // 2, image_size // 2, 64], k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(tf.nn.relu).
                     #THIS CONV APPEARS TO BE EXTRA. WILL KEEP ANYWAY
                     custom_deconv2d([0] + list(image_shape), k_h=4, k_w=4).
                     flatten())
        else:
            raise NotImplementedError

    def discriminate(self, x_var):
        d_out = self.discriminator_template.construct(input=x_var)
        d = tf.nn.sigmoid(d_out[:, 0])
        reg_dist_flat = self.encoder_template.construct(input=x_var)
        reg_dist_info = self.reg_latent_dist.activate_dist(reg_dist_flat)
        return d, self.reg_latent_dist.sample(
            reg_dist_info), reg_dist_info, reg_dist_flat

    def generate(self, z_var):
        x_dist_flat = self.generator_template.construct(input=z_var)
        x_dist_info = self.output_dist.activate_dist(x_dist_flat)
        return self.output_dist.sample(x_dist_info), x_dist_info

    def disc_reg_z(self, reg_z_var):
        ret = []
        for dist_i, z_i in zip(self.reg_latent_dist.dists,
                               self.reg_latent_dist.split_var(reg_z_var)):
            if isinstance(dist_i, (Categorical, Bernoulli)):
                ret.append(z_i)
        return self.reg_disc_latent_dist.join_vars(ret)

    def cont_reg_z(self, reg_z_var):
        ret = []
        for dist_i, z_i in zip(self.reg_latent_dist.dists,
                               self.reg_latent_dist.split_var(reg_z_var)):
            if isinstance(dist_i, Gaussian):
                ret.append(z_i)
        return self.reg_cont_latent_dist.join_vars(ret)

    def disc_reg_dist_info(self, reg_dist_info):
        ret = []
        for dist_i, dist_info_i in zip(
                self.reg_latent_dist.dists,
                self.reg_latent_dist.split_dist_info(reg_dist_info)):
            if isinstance(dist_i, (Categorical, Bernoulli)):
                ret.append(dist_info_i)
        return self.reg_disc_latent_dist.join_dist_infos(ret)

    def cont_reg_dist_info(self, reg_dist_info):
        ret = []
        for dist_i, dist_info_i in zip(
                self.reg_latent_dist.dists,
                self.reg_latent_dist.split_dist_info(reg_dist_info)):
            if isinstance(dist_i, Gaussian):
                ret.append(dist_info_i)
        return self.reg_cont_latent_dist.join_dist_infos(ret)

    def reg_z(self, z_var):
        ret = []
        for (_, reg_i), z_i in zip(self.latent_spec,
                                   self.latent_dist.split_var(z_var)):
            if reg_i:
                ret.append(z_i)
        return self.reg_latent_dist.join_vars(ret)

    def nonreg_z(self, z_var):
        ret = []
        for (_, reg_i), z_i in zip(self.latent_spec,
                                   self.latent_dist.split_var(z_var)):
            if not reg_i:
                ret.append(z_i)
        return self.nonreg_latent_dist.join_vars(ret)

    def reg_dist_info(self, dist_info):
        ret = []
        for (_, reg_i), dist_info_i in zip(
                self.latent_spec, self.latent_dist.split_dist_info(dist_info)):
            if reg_i:
                ret.append(dist_info_i)
        return self.reg_latent_dist.join_dist_infos(ret)

    def nonreg_dist_info(self, dist_info):
        ret = []
        for (_, reg_i), dist_info_i in zip(
                self.latent_spec, self.latent_dist.split_dist_info(dist_info)):
            if not reg_i:
                ret.append(dist_info_i)
        return self.nonreg_latent_dist.join_dist_infos(ret)

    def combine_reg_nonreg_z(self, reg_z_var, nonreg_z_var):
        reg_z_vars = self.reg_latent_dist.split_var(reg_z_var)
        reg_idx = 0
        nonreg_z_vars = self.nonreg_latent_dist.split_var(nonreg_z_var)
        nonreg_idx = 0
        ret = []
        for idx, (dist_i, reg_i) in enumerate(self.latent_spec):
            if reg_i:
                ret.append(reg_z_vars[reg_idx])
                reg_idx += 1
            else:
                ret.append(nonreg_z_vars[nonreg_idx])
                nonreg_idx += 1
        return self.latent_dist.join_vars(ret)

    def combine_reg_nonreg_dist_info(self, reg_dist_info, nonreg_dist_info):
        reg_dist_infos = self.reg_latent_dist.split_dist_info(reg_dist_info)
        reg_idx = 0
        nonreg_dist_infos = self.nonreg_latent_dist.split_dist_info(
            nonreg_dist_info)
        nonreg_idx = 0
        ret = []
        for idx, (dist_i, reg_i) in enumerate(self.latent_spec):
            if reg_i:
                ret.append(reg_dist_infos[reg_idx])
                reg_idx += 1
            else:
                ret.append(nonreg_dist_infos[nonreg_idx])
                nonreg_idx += 1
        return self.latent_dist.join_dist_infos(ret)
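# Usage sketch showing how the split/join helpers above fit together. The
# Uniform/Categorical constructors and the MeanBernoulli output distribution
# follow InfoGAN's conventions and are assumptions here, not part of this
# class (Uniform subclasses Gaussian there, satisfying the assert in
# __init__):
#
#   latent_spec = [
#       (Uniform(62), False),              # non-regularized noise
#       (Categorical(10), True),           # regularized discrete code
#       (Uniform(1, fix_std=True), True),  # regularized continuous code
#   ]
#   model = RegularizedGAN(output_dist=MeanBernoulli(28 * 28),
#                          latent_spec=latent_spec, batch_size=128,
#                          image_shape=(28, 28, 1), network_type="mnist")
#   z_var = model.latent_dist.sample_prior(model.batch_size)
#   reg = model.reg_z(z_var)        # columns flagged True in latent_spec
#   nonreg = model.nonreg_z(z_var)  # columns flagged False
#   z_again = model.combine_reg_nonreg_z(reg, nonreg)  # original layout back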
Example #11
    def __init__(self, output_dist, latent_spec, batch_size, image_shape,
                 network_type):
        """
        :type output_dist: Distribution
        :type latent_spec: list[(Distribution, bool)]
        :type batch_size: int
        :type network_type: string
        """
        self.output_dist = output_dist
        pstr('output_dist', self.output_dist)
        self.latent_spec = latent_spec
        self.latent_dist = Product([x for x, _ in latent_spec])
        pstr('latent_dist', self.latent_dist)
        pstr('x in latent_spec', [x for x, _ in self.latent_spec])
        pstr('xreg in latent_spec', [xreg for _, xreg in self.latent_spec])
        #for x in enumerate(self.latent_spec):
        #   print '------------------------'
        #   for y in enumerate(x):
        #      pstrall('x----reg',y)

        self.reg_latent_dist = Product([x for x, reg in latent_spec if reg])
        self.nonreg_latent_dist = Product(
            [x for x, reg in latent_spec if not reg])
        self.batch_size = batch_size
        self.network_type = network_type
        self.image_shape = image_shape
        assert all(
            isinstance(x, (Gaussian, Categorical, Bernoulli))
            for x in self.reg_latent_dist.dists)
        #for x in self.reg_latent_dist.dists:
        #   pstr('x in reg_latent_dist.dists',x)

        self.reg_cont_latent_dist = Product(
            [x for x in self.reg_latent_dist.dists if isinstance(x, Gaussian)])
        self.reg_disc_latent_dist = Product([
            x for x in self.reg_latent_dist.dists
            if isinstance(x, (Categorical, Bernoulli))
        ])

        pstr('image_shape', image_shape)
        pstr('image_shape[0]', image_shape[0])
        image_size = image_shape[0]

        #self.image_shape = (178, 218, 1)

        if network_type == "mnist":
            with tf.variable_scope("d_net"):
                shared_template = \
                    (pt.template("input").
                     reshape([-1] + list(image_shape)).
                     custom_conv2d(64, k_h=4, k_w=4).
                     #conv_batch_norm().
                     apply(leaky_rectify).

                     custom_conv2d(128, k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(leaky_rectify).

                     custom_conv2d(256, k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(leaky_rectify).

                     #custom_fully_connected(1024).
                     #fc_batch_norm().
                     #apply(leaky_rectify).
                     custom_conv2d(512, k_h=4, k_w=4))
                # Alternative discriminator heads that were tried and
                # disabled: conv_batch_norm().apply(leaky_rectify2), a plain
                # linear output, or apply(tf.nn.sigmoid).
                self.discriminator_template = shared_template.custom_fully_connected(
                    1)
                self.encoder_template = \
                    (shared_template.
                     custom_fully_connected(128).
                     fc_batch_norm().
                     apply(leaky_rectify).
                     custom_fully_connected(self.reg_latent_dist.dist_flat_dim))

            with tf.variable_scope("g_net"):
                s = self.image_shape[0]
                s2, s4, s8, s16, s32 = s // 2, s // 4, s // 8, s // 16, s // 32
                self.generator_template = \
                    (pt.template("input").

                     custom_fully_connected(s16 * s16 * 512).
                     fc_batch_norm().
                     apply(tf.nn.relu).
                     reshape([-1, s16, s16,  512]).

                     #custom_fully_connected(s32 * s32 * 1024).
                     #fc_batch_norm().
                     #apply(tf.nn.relu).
                     #reshape([-1, s32, s32,  1024]).

                     #custom_deconv2d([0, s16, s16,  512], k_h=4, k_w=4).
                     #conv_batch_norm().
                     #apply(tf.nn.relu).

                     custom_deconv2d([0, s8, s8,  256], k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(tf.nn.relu).

                     custom_deconv2d([0, s4, s4, 128], k_h=4, k_w=4).
                     #conv_batch_norm().
                     apply(tf.nn.relu).

                     custom_deconv2d([0, s2, s2, 64], k_h=4, k_w=4).
                     #conv_batch_norm().
                     apply(tf.nn.relu).
                     custom_deconv2d([0] + list(image_shape), k_h=4, k_w=4).
                     apply(tf.nn.tanh))

        else:
            raise NotImplementedError
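# Shape check (sketch; assumes custom_deconv2d doubles the spatial size per
# step with k_h=k_w=4 and stride 2, as in InfoGAN's custom ops). The
# generator above reshapes the fully connected output to (s16, s16, 512) and
# four deconvolutions upsample s16 -> s8 -> s4 -> s2 -> s, e.g. for s = 64:
#   s = 64
#   assert [s // 16, s // 8, s // 4, s // 2, s] == [4, 8, 16, 32, 64]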
Example #12
class RegularizedGAN(object):
    def __init__(self, output_dist, latent_spec, batch_size, image_shape,
                 network_type):
        """
        :type output_dist: Distribution
        :type latent_spec: list[(Distribution, bool)]
        :type batch_size: int
        :type network_type: string
        """
        self.output_dist = output_dist
        pstr('output_dist', self.output_dist)
        self.latent_spec = latent_spec
        self.latent_dist = Product([x for x, _ in latent_spec])
        pstr('latent_dist', self.latent_dist)
        pstr('x in latent_spec', [x for x, _ in self.latent_spec])
        pstr('xreg in latent_spec', [xreg for _, xreg in self.latent_spec])
        #for x in enumerate(self.latent_spec):
        #   print '------------------------'
        #   for y in enumerate(x):
        #      pstrall('x----reg',y)

        self.reg_latent_dist = Product([x for x, reg in latent_spec if reg])
        self.nonreg_latent_dist = Product(
            [x for x, reg in latent_spec if not reg])
        self.batch_size = batch_size
        self.network_type = network_type
        self.image_shape = image_shape
        assert all(
            isinstance(x, (Gaussian, Categorical, Bernoulli))
            for x in self.reg_latent_dist.dists)
        #for x in self.reg_latent_dist.dists:
        #   pstr('x in reg_latent_dist.dists',x)

        self.reg_cont_latent_dist = Product(
            [x for x in self.reg_latent_dist.dists if isinstance(x, Gaussian)])
        self.reg_disc_latent_dist = Product([
            x for x in self.reg_latent_dist.dists
            if isinstance(x, (Categorical, Bernoulli))
        ])

        pstr('image_shape', image_shape)
        pstr('image_shape[0]', image_shape[0])
        image_size = image_shape[0]

        #self.image_shape = (178, 218, 1)

        if network_type == "mnist":
            with tf.variable_scope("d_net"):
                shared_template = \
                    (pt.template("input").
                     reshape([-1] + list(image_shape)).
                     custom_conv2d(64, k_h=4, k_w=4).
                     #conv_batch_norm().
                     apply(leaky_rectify).

                     custom_conv2d(128, k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(leaky_rectify).

                     custom_conv2d(256, k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(leaky_rectify).

                     #custom_fully_connected(1024).
                     #fc_batch_norm().
                     #apply(leaky_rectify).
                     custom_conv2d(512, k_h=4, k_w=4))
                # Alternative discriminator heads that were tried and
                # disabled: conv_batch_norm().apply(leaky_rectify2), a plain
                # linear output, or apply(tf.nn.sigmoid).
                self.discriminator_template = shared_template.custom_fully_connected(
                    1)
                self.encoder_template = \
                    (shared_template.
                     custom_fully_connected(128).
                     fc_batch_norm().
                     apply(leaky_rectify).
                     custom_fully_connected(self.reg_latent_dist.dist_flat_dim))

            with tf.variable_scope("g_net"):
                s = self.image_shape[0]
                s2, s4, s8, s16, s32 = s // 2, s // 4, s // 8, s // 16, s // 32
                self.generator_template = \
                    (pt.template("input").

                     custom_fully_connected(s16 * s16 * 512).
                     fc_batch_norm().
                     apply(tf.nn.relu).
                     reshape([-1, s16, s16,  512]).

                     #custom_fully_connected(s32 * s32 * 1024).
                     #fc_batch_norm().
                     #apply(tf.nn.relu).
                     #reshape([-1, s32, s32,  1024]).

                     #custom_deconv2d([0, s16, s16,  512], k_h=4, k_w=4).
                     #conv_batch_norm().
                     #apply(tf.nn.relu).

                     custom_deconv2d([0, s8, s8,  256], k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(tf.nn.relu).

                     custom_deconv2d([0, s4, s4, 128], k_h=4, k_w=4).
                     #conv_batch_norm().
                     apply(tf.nn.relu).

                     custom_deconv2d([0, s2, s2, 64], k_h=4, k_w=4).
                     #conv_batch_norm().
                     apply(tf.nn.relu).
                     custom_deconv2d([0] + list(image_shape), k_h=4, k_w=4).
                     apply(tf.nn.tanh))

        else:
            raise NotImplementedError

    def discriminate(self, x_var):
        d_out = self.discriminator_template.construct(input=x_var)
        d = tf.nn.sigmoid(d_out[:, 0])
        #d = tf.nn.sigmoid(d_out)
        reg_dist_flat = self.encoder_template.construct(input=x_var)
        reg_dist_info = self.reg_latent_dist.activate_dist(reg_dist_flat)
        return d, self.reg_latent_dist.sample(
            reg_dist_info), reg_dist_info, reg_dist_flat, d_out

    def generate(self, z_var):
        x_dist_flat = self.generator_template.construct(input=z_var)
        x_dist_info = self.output_dist.activate_dist(x_dist_flat)
        return self.output_dist.sample(x_dist_info), x_dist_info

    def disc_reg_z(self, reg_z_var):
        ret = []
        for dist_i, z_i in zip(self.reg_latent_dist.dists,
                               self.reg_latent_dist.split_var(reg_z_var)):
            if isinstance(dist_i, (Categorical, Bernoulli)):
                ret.append(z_i)
        return self.reg_disc_latent_dist.join_vars(ret)

    def cont_reg_z(self, reg_z_var):
        ret = []
        for dist_i, z_i in zip(self.reg_latent_dist.dists,
                               self.reg_latent_dist.split_var(reg_z_var)):
            if isinstance(dist_i, Gaussian):
                ret.append(z_i)
        return self.reg_cont_latent_dist.join_vars(ret)

    def disc_reg_dist_info(self, reg_dist_info):
        ret = []
        for dist_i, dist_info_i in zip(
                self.reg_latent_dist.dists,
                self.reg_latent_dist.split_dist_info(reg_dist_info)):
            if isinstance(dist_i, (Categorical, Bernoulli)):
                ret.append(dist_info_i)
        return self.reg_disc_latent_dist.join_dist_infos(ret)

    def cont_reg_dist_info(self, reg_dist_info):
        ret = []
        for dist_i, dist_info_i in zip(
                self.reg_latent_dist.dists,
                self.reg_latent_dist.split_dist_info(reg_dist_info)):
            if isinstance(dist_i, Gaussian):
                ret.append(dist_info_i)
        return self.reg_cont_latent_dist.join_dist_infos(ret)

    def reg_z(self, z_var):
        ret = []
        for (_, reg_i), z_i in zip(self.latent_spec,
                                   self.latent_dist.split_var(z_var)):
            pstr('reg_z split_var  z_i', z_i)
            if reg_i:
                ret.append(z_i)
        return self.reg_latent_dist.join_vars(ret)

    def nonreg_z(self, z_var):
        ret = []
        for (_, reg_i), z_i in zip(self.latent_spec,
                                   self.latent_dist.split_var(z_var)):
            if not reg_i:
                ret.append(z_i)
        return self.nonreg_latent_dist.join_vars(ret)

    def reg_dist_info(self, dist_info):
        ret = []
        for (_, reg_i), dist_info_i in zip(
                self.latent_spec, self.latent_dist.split_dist_info(dist_info)):
            if reg_i:
                ret.append(dist_info_i)
        return self.reg_latent_dist.join_dist_infos(ret)

    def nonreg_dist_info(self, dist_info):
        ret = []
        for (_, reg_i), dist_info_i in zip(
                self.latent_spec, self.latent_dist.split_dist_info(dist_info)):
            if not reg_i:
                ret.append(dist_info_i)
        return self.nonreg_latent_dist.join_dist_infos(ret)

    def combine_reg_nonreg_z(self, reg_z_var, nonreg_z_var):
        reg_z_vars = self.reg_latent_dist.split_var(reg_z_var)
        reg_idx = 0
        nonreg_z_vars = self.nonreg_latent_dist.split_var(nonreg_z_var)
        nonreg_idx = 0
        ret = []
        for idx, (dist_i, reg_i) in enumerate(self.latent_spec):
            if reg_i:
                ret.append(reg_z_vars[reg_idx])
                reg_idx += 1
            else:
                ret.append(nonreg_z_vars[nonreg_idx])
                nonreg_idx += 1
        return self.latent_dist.join_vars(ret)

    def combine_reg_nonreg_dist_info(self, reg_dist_info, nonreg_dist_info):
        reg_dist_infos = self.reg_latent_dist.split_dist_info(reg_dist_info)
        reg_idx = 0
        nonreg_dist_infos = self.nonreg_latent_dist.split_dist_info(
            nonreg_dist_info)
        nonreg_idx = 0
        ret = []
        for idx, (dist_i, reg_i) in enumerate(self.latent_spec):
            if reg_i:
                ret.append(reg_dist_infos[reg_idx])
                reg_idx += 1
            else:
                ret.append(nonreg_dist_infos[nonreg_idx])
                nonreg_idx += 1
        return self.latent_dist.join_dist_infos(ret)
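# Usage sketch: unlike the earlier variant, this discriminate() also returns
# the raw logits d_out, which pairs naturally with TensorFlow's numerically
# stable cross-entropy (the training code is an assumption, not shown in the
# original):
#   d, reg_z, reg_dist_info, reg_dist_flat, d_out = model.discriminate(x_var)
#   real_loss = tf.reduce_mean(
#       tf.nn.sigmoid_cross_entropy_with_logits(
#           labels=tf.ones_like(d_out), logits=d_out))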
Example #13
class RegularizedGAN(GANModel):
    def __init__(self, latent_spec, **kwargs):
        """
        Args:
            latent_spec (list): List of latent distributions.
             [(Distribution, bool)]
             The boolean indicates if the distribution should be used for
             regularization.
        """

        self.latent_spec = latent_spec
        self.latent_dist = Product([x for x, _ in latent_spec])
        self.reg_latent_dist = Product([x for x, reg in latent_spec if reg])
        self.nonreg_latent_dist = Product(
            [x for x, reg in latent_spec if not reg])
        assert all(
            isinstance(x, (Gaussian, Categorical, Bernoulli))
            for x in self.reg_latent_dist.dists)

        self.reg_cont_latent_dist = Product(
            [x for x in self.reg_latent_dist.dists if isinstance(x, Gaussian)])
        self.reg_disc_latent_dist = Product([
            x for x in self.reg_latent_dist.dists
            if isinstance(x, (Categorical, Bernoulli))
        ])

        super(RegularizedGAN, self).__init__(**kwargs)
        d = {
            'latent_code_influence': {
                'tensor': 'get_latent_code_influence_g_input_tensor',
            },
            'linear_interpolation': {
                'tensor': 'get_linear_interpolation_g_input_tensor',
            },
        }
        self.sampling_functions.update(d)

    def __getstate__(self):
        pickling_dict = super(RegularizedGAN, self).__getstate__()
        del pickling_dict['encoder_template']
        return pickling_dict
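    # Likely rationale (not stated in the original): prettytensor templates
    # hold references to graph objects and are not picklable, so
    # encoder_template is dropped here and must be rebuilt after unpickling.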

    def build_g_input(self):
        return self.latent_dist.sample_prior(self.batch_size)

    def get_g_feed_dict(self):
        return None

    def get_reg_dist_info(self, x_var):
        reg_dist_flat = self.encoder_template.construct(input=x_var)
        reg_dist_info = self.reg_latent_dist.activate_dist(reg_dist_flat)
        return reg_dist_info
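    # Note: get_reg_dist_info runs only the encoder template, i.e. the
    # recognition distribution Q(c|x); unlike the discriminate() variants in
    # the earlier examples it does not touch the discriminator head.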

    def disc_reg_z(self, reg_z_var):
        ret = []
        for dist_i, z_i in zip(self.reg_latent_dist.dists,
                               self.reg_latent_dist.split_var(reg_z_var)):
            if isinstance(dist_i, (Categorical, Bernoulli)):
                ret.append(z_i)
        return self.reg_disc_latent_dist.join_vars(ret)

    def cont_reg_z(self, reg_z_var):
        ret = []
        for dist_i, z_i in zip(self.reg_latent_dist.dists,
                               self.reg_latent_dist.split_var(reg_z_var)):
            if isinstance(dist_i, Gaussian):
                ret.append(z_i)
        return self.reg_cont_latent_dist.join_vars(ret)

    def disc_reg_dist_info(self, reg_dist_info):
        ret = []
        for dist_i, dist_info_i in zip(
                self.reg_latent_dist.dists,
                self.reg_latent_dist.split_dist_info(reg_dist_info)):
            if isinstance(dist_i, (Categorical, Bernoulli)):
                ret.append(dist_info_i)
        return self.reg_disc_latent_dist.join_dist_infos(ret)

    def cont_reg_dist_info(self, reg_dist_info):
        ret = []
        for dist_i, dist_info_i in zip(
                self.reg_latent_dist.dists,
                self.reg_latent_dist.split_dist_info(reg_dist_info)):
            if isinstance(dist_i, Gaussian):
                ret.append(dist_info_i)
        return self.reg_cont_latent_dist.join_dist_infos(ret)

    def reg_z(self, z_var=None):
        """ Return the variables with distribution bool == True (concatenated). """
        ret = []
        if z_var is None:
            z_var = self.g_input
        for (_, reg_i), z_i in zip(self.latent_spec,
                                   self.latent_dist.split_var(z_var)):
            if reg_i:
                ret.append(z_i)
        return self.reg_latent_dist.join_vars(ret)

    def nonreg_z(self, z_var):
        ret = []
        for (_, reg_i), z_i in zip(self.latent_spec,
                                   self.latent_dist.split_var(z_var)):
            if not reg_i:
                ret.append(z_i)
        return self.nonreg_latent_dist.join_vars(ret)

    def reg_dist_info(self, dist_info):
        ret = []
        for (_, reg_i), dist_info_i in zip(
                self.latent_spec, self.latent_dist.split_dist_info(dist_info)):
            if reg_i:
                ret.append(dist_info_i)
        return self.reg_latent_dist.join_dist_infos(ret)

    def nonreg_dist_info(self, dist_info):
        ret = []
        for (_, reg_i), dist_info_i in zip(
                self.latent_spec, self.latent_dist.split_dist_info(dist_info)):
            if not reg_i:
                ret.append(dist_info_i)
        return self.nonreg_latent_dist.join_dist_infos(ret)

    def combine_reg_nonreg_z(self, reg_z_var, nonreg_z_var):
        reg_z_vars = self.reg_latent_dist.split_var(reg_z_var)
        reg_idx = 0
        nonreg_z_vars = self.nonreg_latent_dist.split_var(nonreg_z_var)
        nonreg_idx = 0
        ret = []
        for idx, (dist_i, reg_i) in enumerate(self.latent_spec):
            if reg_i:
                ret.append(reg_z_vars[reg_idx])
                reg_idx += 1
            else:
                ret.append(nonreg_z_vars[nonreg_idx])
                nonreg_idx += 1
        return self.latent_dist.join_vars(ret)

    def combine_reg_nonreg_dist_info(self, reg_dist_info, nonreg_dist_info):
        reg_dist_infos = self.reg_latent_dist.split_dist_info(reg_dist_info)
        reg_idx = 0
        nonreg_dist_infos = self.nonreg_latent_dist.split_dist_info(
            nonreg_dist_info)
        nonreg_idx = 0
        ret = []
        for idx, (dist_i, reg_i) in enumerate(self.latent_spec):
            if reg_i:
                ret.append(reg_dist_infos[reg_idx])
                reg_idx += 1
            else:
                ret.append(nonreg_dist_infos[nonreg_idx])
                nonreg_idx += 1
        return self.latent_dist.join_dist_infos(ret)

    ###########################################################################
    # SAMPLING
    ###########################################################################
    def get_train_g_input_tensor(self):
        if len(self.reg_latent_dist.dists) > 0:
            sampling_type = 'latent_code_influence'
        else:
            sampling_type = 'random'
        return make_list(self.get_g_input_tensor(sampling_type))

    def get_g_input_value(self, sampling_type, **kwargs):
        # TODO: self.get_g_input_tensor should return a Tensor and
        # this method should .eval() its result with a tf.Session()
        if sampling_type == 'random':
            return self.get_g_input(sampling_type, 'value', **kwargs)
        return self.get_g_input_tensor(sampling_type, **kwargs)

    def get_linear_interpolation_g_input_tensor(self,
                                                n_samples=10,
                                                n_variations=10):
        """
        Returns:
            (ndarray, str):
        """
        # TODO: return a tensor instead of ndarray
        with tf.Session():
            n = n_samples * n_variations
            all_z_start = self.latent_dist.sample_prior(n_samples).eval()
            all_z_end = self.latent_dist.sample_prior(n_samples).eval()
            coefficients = np.linspace(start=0, stop=1, num=n_variations)
            z_var = []
            for z_start, z_end in zip(all_z_start, all_z_end):
                for coeff in coefficients:
                    z_var.append(coeff * z_start + (1 - coeff) * z_end)
            other = self.latent_dist.sample_prior(self.batch_size - n).eval()
            z_var = np.concatenate([z_var, other], axis=0)
            z_var = np.asarray(z_var, dtype=np.float32)
            return z_var, 'linear_interpolations'
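    # The interpolation pattern in plain numpy terms (sketch, independent of
    # TensorFlow): with coeffs = np.linspace(0, 1, n_variations), each row is
    #   coeff * z_start + (1 - coeff) * z_end
    # so every sample contributes n_variations rows sweeping from z_end
    # (coeff=0) to z_start (coeff=1), and the batch is padded with fresh
    # prior samples up to batch_size.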

    def get_latent_code_influence_g_input_tensor(self,
                                                 n_samples=10,
                                                 n_variations=10,
                                                 min_continuous=-2.,
                                                 max_continuous=2.):
        """
        Args:
            n_samples (int): The number of different samples (n_columns).
            n_variations (int): The number of variations for each latent code
             (n_rows).
            min_continuous (float): The minimum value for a continuous latent
             code.
            max_continuous (float): The maximum value for a continuous latent
             code.
        Returns:
            list[(ndarray, str)]: One (z_var, name) pair per regularized
             latent code.
        """
        # TODO: return a tensor instead of ndarray
        if len(self.reg_latent_dist.dists) == 0:
            raise ValueError('The model must have at least one regularization '
                             'latent distribution.')
        with tf.Session():
            # (n, d) with n_samples * n_variations varied samples, padded
            # with extra prior samples up to batch_size
            n = n_samples * n_variations
            fixed_noncat = self.nonreg_latent_dist.sample_prior(n_samples)
            fixed_noncat = np.tile(fixed_noncat.eval(), [n_variations, 1])
            other = self.nonreg_latent_dist.sample_prior(self.batch_size - n)
            other = other.eval()
            fixed_noncat = np.concatenate([fixed_noncat, other], axis=0)

            fixed_cat = self.reg_latent_dist.sample_prior(n_samples).eval()
            fixed_cat = np.tile(fixed_cat, [n_variations, 1])
            other = self.reg_latent_dist.sample_prior(self.batch_size - n)
            other = other.eval()
            fixed_cat = np.concatenate([fixed_cat, other], axis=0)

        offset = 0
        z_vars_and_names = []
        for dist_idx, dist in enumerate(self.reg_latent_dist.dists):
            if isinstance(dist, Gaussian):
                assert dist.dim == 1, "Only dim=1 is currently supported"
                vary_cat = dist.varying_values(min_continuous, max_continuous,
                                               n_variations)
                vary_cat = np.repeat(vary_cat, n_samples)
                other = np.zeros(self.batch_size - n)
                vary_cat = np.concatenate((vary_cat, other)).reshape((-1, 1))
                vary_cat = np.asarray(vary_cat, dtype=np.float32)

                cur_cat = np.copy(fixed_cat)
                cur_cat[:, offset:offset + 1] = vary_cat
                offset += 1
            elif isinstance(dist, Categorical):
                lookup = np.eye(dist.dim, dtype=np.float32)
                cat_ids = []
                for idx in range(n_variations):
                    cat_ids.extend([idx] * n_samples)
                cat_ids.extend([0] * (self.batch_size - n))
                cur_cat = np.copy(fixed_cat)
                cur_cat[:, offset:offset + dist.dim] = lookup[cat_ids]
                offset += dist.dim
            elif isinstance(dist, Bernoulli):
                assert dist.dim == 1, "Only dim=1 is currently supported"
                cat_ids = []
                for idx in range(n_variations):
                    cat_ids.extend([idx // 5] * n_samples)
                cat_ids.extend([0] * (self.batch_size - n))
                cur_cat = np.copy(fixed_cat)
                cur_cat[:, offset:offset + dist.dim] = np.expand_dims(
                    np.array(cat_ids), axis=-1)
                offset += dist.dim
            else:
                raise NotImplementedError
            # (n, d)
            # The first n_samples rows have different z and c; they are tiled
            # n_variations times, except for the varying code, which is
            # constant within each block of n_samples rows and varies between
            # blocks.
            z_var = np.concatenate([fixed_noncat, cur_cat], axis=1)

            # Images where each column has a different fixed z and c;
            # the varying code changes along each column
            # (a different value for each row).
            name = 'image_{}_{}'.format(dist_idx, dist.__class__.__name__)
            z_vars_and_names.append((z_var, name))
        return z_vars_and_names
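    # Usage sketch (hypothetical driver, not part of the original class):
    # each returned (z_var, name) pair fills one generator batch; rows vary a
    # single latent code while the other coordinates stay fixed per column,
    # so rendering the batch as an n_variations x n_samples grid visualizes
    # that code's influence:
    #   for z_var, name in model.get_latent_code_influence_g_input_tensor():
    #       x_sample, _ = model.generate(z_var)  # z_var is a float32 ndarray
    #       save_image_grid(x_sample, name)      # save_image_grid: hypothetical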