Beispiel #1
0
    def __init__(self, latent_spec, **kwargs):
        """Derive the latent distribution products and register samplers.

        Args:
            latent_spec (list): Latent distributions given as
                ``[(Distribution, bool)]`` pairs. The boolean indicates
                whether the distribution should be used for regularization.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        self.latent_spec = latent_spec

        # Partition the spec: everything, regularized only, the remainder.
        every_dist = [dist for dist, _ in latent_spec]
        regularized = [dist for dist, use_reg in latent_spec if use_reg]
        unregularized = [dist for dist, use_reg in latent_spec if not use_reg]
        self.latent_dist = Product(every_dist)
        self.reg_latent_dist = Product(regularized)
        self.nonreg_latent_dist = Product(unregularized)

        reg_dists = self.reg_latent_dist.dists
        # Only these three distribution families may be regularized.
        assert all(
            isinstance(dist, (Gaussian, Categorical, Bernoulli))
            for dist in reg_dists)

        # Separate the regularized codes into continuous (Gaussian) and
        # discrete (Categorical / Bernoulli) groups, preserving their order.
        continuous = []
        discrete = []
        for dist in reg_dists:
            if isinstance(dist, Gaussian):
                continuous.append(dist)
            if isinstance(dist, (Categorical, Bernoulli)):
                discrete.append(dist)
        self.reg_cont_latent_dist = Product(continuous)
        self.reg_disc_latent_dist = Product(discrete)

        super(RegularizedGAN, self).__init__(**kwargs)

        # Extend the sampling-function table (an attribute this method does
        # not create itself) with the two latent-code specific entries.
        self.sampling_functions.update({
            'latent_code_influence': {
                'tensor': 'get_latent_code_influence_g_input_tensor',
            },
            'linear_interpolation': {
                'tensor': 'get_linear_interpolation_g_input_tensor',
            },
        })
Beispiel #2
0
    def __init__(self, output_dist, latent_spec, batch_size, image_shape, network_type):
        """Store the configuration and build the network templates.

        The discriminator and encoder templates are created under the
        ``"d_net"`` variable scope and share a common trunk; the generator
        template is created under ``"g_net"``.

        :param output_dist: stored as-is on the instance
        :type output_dist: Distribution
        :param latent_spec: pairs of (distribution, regularize-flag)
        :type latent_spec: list[(Distribution, bool)]
        :type batch_size: int
        :param image_shape: sequence describing one image; its first element
            is used as the spatial size when sizing the generator layers
        :param network_type: only ``"mnist"`` is implemented
        :type network_type: string
        :raises NotImplementedError: for any other ``network_type``
        """
        self.output_dist = output_dist
        self.latent_spec = latent_spec
        self.latent_dist = Product([x for x, _ in latent_spec])
        self.reg_latent_dist = Product([x for x, reg in latent_spec if reg])
        self.nonreg_latent_dist = Product([x for x, reg in latent_spec if not reg])
        self.batch_size = batch_size
        self.network_type = network_type
        self.image_shape = image_shape
        # Only these three distribution families may be regularized.
        assert all(isinstance(x, (Gaussian, Categorical, Bernoulli)) for x in self.reg_latent_dist.dists)

        self.reg_cont_latent_dist = Product([x for x in self.reg_latent_dist.dists if isinstance(x, Gaussian)])
        self.reg_disc_latent_dist = Product([x for x in self.reg_latent_dist.dists if isinstance(x, (Categorical, Bernoulli))])

        image_size = image_shape[0]
        # Floor division keeps the layer sizes integral on Python 3 as well
        # as Python 2; true division would hand floats to the shape arguments.
        half = image_size // 2
        quarter = image_size // 4
        if network_type == "mnist":
            with tf.variable_scope("d_net"):
                # Trunk shared by the discriminator head and the encoder head.
                shared_template = \
                    (pt.template("input").
                     reshape([-1] + list(image_shape)).
                     custom_conv2d(64, k_h=4, k_w=4).
                     apply(leaky_rectify).
                     custom_conv2d(128, k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(leaky_rectify).
                     custom_fully_connected(1024).
                     fc_batch_norm().
                     apply(leaky_rectify))
                self.discriminator_template = shared_template.custom_fully_connected(1)
                self.encoder_template = \
                    (shared_template.
                     custom_fully_connected(128).
                     fc_batch_norm().
                     apply(leaky_rectify).
                     custom_fully_connected(self.reg_latent_dist.dist_flat_dim))

            with tf.variable_scope("g_net"):
                # The second fully connected layer must produce exactly
                # quarter * quarter * 128 values so the reshape that follows
                # it is consistent.
                self.generator_template = \
                    (pt.template("input").
                     custom_fully_connected(1024).
                     fc_batch_norm().
                     apply(tf.nn.relu).
                     custom_fully_connected(quarter * quarter * 128).
                     fc_batch_norm().
                     apply(tf.nn.relu).
                     reshape([-1, quarter, quarter, 128]).
                     custom_deconv2d([0, half, half, 64], k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(tf.nn.relu).
                     custom_deconv2d([0] + list(image_shape), k_h=4, k_w=4).
                     flatten())
        else:
            raise NotImplementedError
Beispiel #3
0
    def __init__(self,
                 output_dist,
                 latent_spec,
                 is_reg,
                 batch_size,
                 image_shape,
                 network_type,
                 impr=False):
        """Store the configuration, derive latent products, build both nets.

        :param output_dist: stored as-is on the instance
        :type output_dist: Distribution
        :param latent_spec: pairs of (distribution, regularize-flag)
        :type latent_spec: list[(Distribution, bool)]
        :param is_reg: when truthy, ``encoder_dim`` is taken from the
            regularized latent product and ``'reg_dist_info'`` is added to
            ``keys``; otherwise ``encoder_dim`` is ``None``
        :type batch_size: int
        :param image_shape: stored as-is on the instance
        :type network_type: string
        :param impr: stored as-is on the instance (defaults to ``False``)
        """
        # Plain configuration values.
        self.impr = impr
        self.output_dist = output_dist
        self.latent_spec = latent_spec
        self.is_reg = is_reg

        # Products over all / regularized / non-regularized distributions.
        every_dist = [dist for dist, _ in latent_spec]
        self.latent_dist = Product(every_dist)
        regularized = [dist for dist, use_reg in latent_spec if use_reg]
        self.reg_latent_dist = Product(regularized)
        unregularized = [dist for dist, use_reg in latent_spec if not use_reg]
        self.nonreg_latent_dist = Product(unregularized)

        self.batch_size = batch_size
        self.network_type = network_type
        self.image_shape = image_shape

        base_keys = ['prob', 'logits', 'features']
        if self.is_reg:
            self.keys = base_keys + ['reg_dist_info']
            self.encoder_dim = self.reg_latent_dist.dist_flat_dim
        else:
            self.keys = base_keys
            self.encoder_dim = None

        reg_dists = self.reg_latent_dist.dists
        # Only these three distribution families may be regularized.
        assert all(
            isinstance(dist, (Gaussian, Categorical, Bernoulli))
            for dist in reg_dists)

        # Continuous (Gaussian) versus discrete (Categorical / Bernoulli).
        continuous = []
        discrete = []
        for dist in reg_dists:
            if isinstance(dist, Gaussian):
                continuous.append(dist)
            if isinstance(dist, (Categorical, Bernoulli)):
                discrete.append(dist)
        self.reg_cont_latent_dist = Product(continuous)
        self.reg_disc_latent_dist = Product(discrete)

        # Network construction is delegated; discriminator first.
        self.set_D_net()
        self.set_G_net()
Beispiel #4
0
    def __init__(self, output_dist, latent_spec, batch_size, image_shape, network_type):
        """Store the configuration and derive the latent distribution products.

        :param output_dist: stored as-is on the instance
        :type output_dist: Distribution
        :param latent_spec: pairs of (distribution, regularize-flag)
        :type latent_spec: list[(Distribution, bool)]
        :type batch_size: int
        :param image_shape: stored as-is on the instance
        :type network_type: string
        """
        self.output_dist = output_dist
        self.latent_spec = latent_spec

        # Products over all / regularized / non-regularized distributions.
        every_dist = [dist for dist, _ in latent_spec]
        self.latent_dist = Product(every_dist)
        regularized = [dist for dist, use_reg in latent_spec if use_reg]
        self.reg_latent_dist = Product(regularized)
        unregularized = [dist for dist, use_reg in latent_spec if not use_reg]
        self.nonreg_latent_dist = Product(unregularized)

        self.batch_size = batch_size
        self.network_type = network_type
        self.image_shape = image_shape

        reg_dists = self.reg_latent_dist.dists
        # Only these three distribution families may be regularized.
        assert all(
            isinstance(dist, (Gaussian, Categorical, Bernoulli))
            for dist in reg_dists)

        # Continuous (Gaussian) versus discrete (Categorical / Bernoulli).
        continuous = []
        discrete = []
        for dist in reg_dists:
            if isinstance(dist, Gaussian):
                continuous.append(dist)
            if isinstance(dist, (Categorical, Bernoulli)):
                discrete.append(dist)
        self.reg_cont_latent_dist = Product(continuous)
        self.reg_disc_latent_dist = Product(discrete)
Beispiel #5
0
    def __init__(self, output_dist, latent_spec, batch_size, image_shape,
                 network_type):
        """Store the configuration, log it via ``pstr``, and build templates.

        The discriminator and encoder templates are created under the
        ``"d_net"`` variable scope and share a convolutional trunk; the
        generator template is created under ``"g_net"``.

        :param output_dist: stored as-is on the instance
        :type output_dist: Distribution
        :param latent_spec: pairs of (distribution, regularize-flag)
        :type latent_spec: list[(Distribution, bool)]
        :type batch_size: int
        :param image_shape: sequence describing one image; its first element
            is used as the spatial size when sizing the generator layers
        :param network_type: only ``"mnist"`` is implemented
        :type network_type: string
        :raises NotImplementedError: for any other ``network_type``
        """
        self.output_dist = output_dist
        pstr('output_dist', self.output_dist)
        self.latent_spec = latent_spec
        self.latent_dist = Product([x for x, _ in latent_spec])
        pstr('latent_dist', self.latent_dist)
        pstr('x in latent_spec', [x for x, _ in self.latent_spec])
        pstr('xreg in latent_spec', [xreg for _, xreg in self.latent_spec])

        self.reg_latent_dist = Product([x for x, reg in latent_spec if reg])
        self.nonreg_latent_dist = Product(
            [x for x, reg in latent_spec if not reg])
        self.batch_size = batch_size
        self.network_type = network_type
        self.image_shape = image_shape
        # Only these three distribution families may be regularized.
        assert all(
            isinstance(x, (Gaussian, Categorical, Bernoulli))
            for x in self.reg_latent_dist.dists)

        self.reg_cont_latent_dist = Product(
            [x for x in self.reg_latent_dist.dists if isinstance(x, Gaussian)])
        self.reg_disc_latent_dist = Product([
            x for x in self.reg_latent_dist.dists
            if isinstance(x, (Categorical, Bernoulli))
        ])

        pstr('image_shape', image_shape)
        pstr('image_shape[0]', image_shape[0])

        if network_type == "mnist":
            with tf.variable_scope("d_net"):
                # Four-convolution trunk shared by the discriminator head
                # and the encoder head; the first and last convolutions are
                # not followed by batch normalisation.
                shared_template = \
                    (pt.template("input").
                     reshape([-1] + list(image_shape)).
                     custom_conv2d(64, k_h=4, k_w=4).
                     apply(leaky_rectify).
                     custom_conv2d(128, k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(leaky_rectify).
                     custom_conv2d(256, k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(leaky_rectify).
                     custom_conv2d(512, k_h=4, k_w=4))
                self.discriminator_template = shared_template.custom_fully_connected(
                    1)
                self.encoder_template = \
                    (shared_template.
                     custom_fully_connected(128).
                     fc_batch_norm().
                     apply(leaky_rectify).
                     custom_fully_connected(self.reg_latent_dist.dist_flat_dim))

            with tf.variable_scope("g_net"):
                # Spatial sizes of the successive generator stages; floor
                # division keeps them integral.
                s = self.image_shape[0]
                s2, s4, s8, s16 = s // 2, s // 4, s // 8, s // 16
                self.generator_template = \
                    (pt.template("input").
                     custom_fully_connected(s16 * s16 * 512).
                     fc_batch_norm().
                     apply(tf.nn.relu).
                     reshape([-1, s16, s16, 512]).
                     custom_deconv2d([0, s8, s8, 256], k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(tf.nn.relu).
                     custom_deconv2d([0, s4, s4, 128], k_h=4, k_w=4).
                     apply(tf.nn.relu).
                     custom_deconv2d([0, s2, s2, 64], k_h=4, k_w=4).
                     apply(tf.nn.relu).
                     custom_deconv2d([0] + list(image_shape), k_h=4, k_w=4).
                     apply(tf.nn.tanh))

        else:
            raise NotImplementedError