Example #1
    def gp0_indv(mode, disc_indv_fn, inp_1, inp_2):
        # Gradient penalty for individual discriminator
        if mode == "interpolation":
            print("[BaseLatentModel.gp0_indv] interpolation!")
            assert (inp_1 is not None) and (inp_2 is not None), "If 'mode' is 'interpolation', " \
                                                                "both 'inp_1' and 'inp_2' must be provided!"

            alpha = tf.random_uniform(shape=mixed_shape(inp_1),
                                      minval=0.,
                                      maxval=1.)
            # (batch, z_dim)
            inp_i = alpha * inp_1 + ((1 - alpha) * inp_2)
            # (batch, z_dim)
            D_logit_i = disc_indv_fn(inp_i)

            D_grad_i = tf.gradients(D_logit_i, [inp_i])[0]
            D_grad_i = flatten_right_from(D_grad_i, 1)
            assert len(
                D_grad_i.shape) == 2, "'D_grad_i' must have 2 dimensions!"

            gp = tf.reduce_mean(tf.reduce_sum(tf.square(D_grad_i), axis=1),
                                axis=0)

        elif mode == "self":
            print("[BaseLatentModel.gp0_indv] self!")
            assert inp_1 is not None, "If 'mode' is 'self', 'inp_1' must be provided!"

            D_logit_1 = disc_indv_fn(inp_1)

            D_grad_1 = tf.gradients(D_logit_1, [inp_1])[0]
            D_grad_1 = flatten_right_from(D_grad_1, 1)
            assert len(
                D_grad_1.shape) == 2, "'D_grad_1' must have 2 dimensions!"

            gp = tf.reduce_mean(tf.reduce_sum(tf.square(D_grad_1), axis=1),
                                axis=0)

        elif mode == "other":
            print("[BaseLatentModel.gp0_indv] other!")
            assert inp_2 is not None, "If 'mode' is 'other', 'inp_2' must be provided!"

            D_logit_2 = disc_indv_fn(inp_2)

            D_grad_2 = tf.gradients(D_logit_2, [inp_2])[0]
            D_grad_2 = flatten_right_from(D_grad_2, 1)
            assert len(
                D_grad_2.shape) == 2, "'D_grad_2' must have 2 dimensions!"

            gp = tf.reduce_mean(tf.reduce_sum(tf.square(D_grad_2), axis=1),
                                axis=0)

        else:
            raise ValueError("Do not support 'mode'='{}'".format(mode))

        return gp
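
For reference, here is a minimal standalone sketch of the zero-centered penalty computed by the "interpolation" branch above, assuming TensorFlow 1.x. toy_disc is a hypothetical stand-in for disc_indv_fn, plain tf.shape replaces the repo's mixed_shape helper, and the dimensions are illustrative only.

    import tensorflow as tf

    def toy_disc(z):
        # Hypothetical stand-in for disc_indv_fn: one linear logit per sample.
        w = tf.get_variable("w_toy_disc", shape=[z.shape[-1].value, 1],
                            initializer=tf.truncated_normal_initializer(stddev=0.02))
        return tf.squeeze(tf.matmul(z, w), axis=1)

    z_real = tf.placeholder(tf.float32, [None, 10])
    z_fake = tf.placeholder(tf.float32, [None, 10])

    # Element-wise interpolation, mirroring gp0_indv's per-element alpha.
    alpha = tf.random_uniform(shape=tf.shape(z_real), minval=0., maxval=1.)
    z_interp = alpha * z_real + (1. - alpha) * z_fake

    D_logit = toy_disc(z_interp)
    D_grad = tf.gradients(D_logit, [z_interp])[0]                    # (batch, z_dim)
    gp0 = tf.reduce_mean(tf.reduce_sum(tf.square(D_grad), axis=1))   # E[||grad||^2]
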
    def __call__(self, z, is_train, return_top_hid=False, scope=None, reuse=None):
        weight_init = tf.truncated_normal_initializer(stddev=0.02)
        dense = functools.partial(tf.layers.dense, kernel_initializer=weight_init)

        with tf.variable_scope(scope or self.__class__.__name__, reuse=reuse):
            z_shape = mixed_shape(z)
            if len(z_shape) == 4:
                assert z_shape[1] == z_shape[2] == 1
                z = flatten_right_from(z, 1)

            h = z
            for i in range(5):
                with tf.variable_scope("layer_{}".format(i)):
                    h = dense(h, 1000, use_bias=True)
                    h = tf.nn.leaky_relu(h, 0.2)

            with tf.variable_scope("top"):
                dZ = dense(h, self.num_outputs, use_bias=True)

            if self.num_outputs == 1:
                dZ = tf.reshape(dZ, mixed_shape(dZ)[:-1])

            outputs = [dZ]
            if return_top_hid:
                outputs.append(h)

            return outputs[0] if (len(outputs) == 1) else tuple(outputs)
    def __call__(self, z, is_train, return_distribution=False, return_top_hid=False, scope=None, reuse=None):
        activation = self.activation

        weight_init = tf.truncated_normal_initializer(stddev=0.02)
        deconv2d = functools.partial(tf.layers.conv2d_transpose, kernel_initializer=weight_init)
        dense = functools.partial(tf.layers.dense, kernel_initializer=weight_init)

        with tf.variable_scope(scope or self.__class__.__name__, reuse=reuse):
            z_shape = mixed_shape(z)
            if len(z_shape) == 4:
                assert z_shape[1] == z_shape[2] == 1
                z = flatten_right_from(z, axis=1)
            batch_size = z_shape[0]

            # (z_dim,)
            h = z

            with tf.variable_scope("block_1"):
                # (128,)
                h = dense(h, 128, use_bias=True)
                h = activation(h)

            with tf.variable_scope("block_2"):
                h = dense(h, 4 * 4 * 64, use_bias=True)
                h = activation(h)

            with tf.variable_scope("block_3"):
                h = tf.reshape(h, [batch_size, 4, 4, 64])
                # (8, 8, 64)
                h = deconv2d(h, filters=64, kernel_size=4, strides=2, padding="same")
                h = activation(h)

            with tf.variable_scope("block_4"):
                # (16, 16, 32)
                h = deconv2d(h, filters=32, kernel_size=4, strides=2, padding="same")
                h = activation(h)

            with tf.variable_scope("block_5"):
                # (32, 32, 32)
                h = deconv2d(h, filters=32, kernel_size=4, strides=2, padding="same")
                h = activation(h)

            with tf.variable_scope("top"):
                # (64, 64, 1)
                x_logit = deconv2d(h, filters=self.x_dim[-1], kernel_size=4, strides=2, padding="same")
                x = tf.nn.sigmoid(x_logit)

            outputs = [x]
            if return_distribution:
                outputs.append({'prob': x, 'logit': x_logit})
            if return_top_hid:
                outputs.append(h)

            return outputs[0] if (len(outputs) == 1) else tuple(outputs)
Example #4
    def gp_4_inp_list(center,
                      mode,
                      disc_fn,
                      inputs_1,
                      inputs_2,
                      at_D_comp=None):
        if mode == "interpolation":
            print("[BaseLatentModel.gp_4_inp_list] interpolation!")
            assert isinstance(inputs_1, (list, tuple))
            assert isinstance(inputs_2, (list, tuple))

            inputs_i = []

            for inp_1, inp_2 in zip(inputs_1, inputs_2):
                alpha = tf.random_uniform(shape=mixed_shape(inp_1),
                                          minval=0.,
                                          maxval=1.)
                # (batch, z_dim)
                inp_i = alpha * inp_1 + ((1 - alpha) * inp_2)
                inputs_i.append(inp_i)

            if at_D_comp is not None:
                D_logit_i = disc_fn(inputs_i)
                D_logit_i = D_logit_i[:, at_D_comp]
            else:
                D_logit_i = disc_fn(inputs_i)

            assert len(
                D_logit_i.shape.as_list()) == 1, D_logit_i.shape.as_list()

            # List of gradient for each input component
            D_grads = tf.gradients(D_logit_i, inputs_i)
            D_grads = [flatten_right_from(Dg, axis=1) for Dg in D_grads]

        else:
            raise ValueError("Do not support 'mode'='{}'".format(mode))

        gp = tf.constant(0, dtype=tf.float32)

        for Dg in D_grads:
            assert len(Dg.shape) == 2, "'Dg' must have 2 dimensions!"
            slope = tf.sqrt(tf.reduce_sum(tf.square(Dg), axis=1))
            gp += tf.reduce_mean((slope - center)**2, axis=0)

        return gp
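
A similar standalone sketch for the per-component penalty over a list of inputs, assuming TensorFlow 1.x. toy_list_disc is a hypothetical stand-in for disc_fn that consumes a list of tensors; the interpolation step is omitted, and the penalty is summed over the per-component gradients, exactly as in the loop above.

    import tensorflow as tf

    def toy_list_disc(inputs):
        # Hypothetical stand-in for disc_fn: concatenate the components and
        # map them to one scalar logit per sample.
        h = tf.concat(inputs, axis=1)
        w = tf.get_variable("w_toy_list_disc", shape=[h.shape[-1].value, 1],
                            initializer=tf.truncated_normal_initializer(stddev=0.02))
        return tf.squeeze(tf.matmul(h, w), axis=1)

    a = tf.placeholder(tf.float32, [None, 4])
    b = tf.placeholder(tf.float32, [None, 6])
    center = 1.0

    D_logit = toy_list_disc([a, b])
    D_grads = tf.gradients(D_logit, [a, b])      # one gradient tensor per input component
    gp = tf.add_n([
        tf.reduce_mean((tf.sqrt(tf.reduce_sum(tf.square(g), axis=1)) - center) ** 2)
        for g in D_grads
    ])
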
    def build(self, loss_coeff_dict):
        lc = loss_coeff_dict
        coeff_fn = self.one_if_not_exist

        vae_loss = 0
        Dz_loss = 0

        # Encoder / Decoder
        # ================================== #
        x0 = self.x_ph
        z1_gen, z1_gen_dist = self.encoder_fn(x0, return_distribution=True)
        x1 = self.decoder_fn(z1_gen)

        z0 = self.z_ph
        x1_gen = self.decoder_fn(z0)

        self.set_output('z1_gen', z1_gen)
        self.set_output('z_mean', z1_gen_dist['mean'])
        self.set_output('z_stddev', z1_gen_dist['stddev'])

        self.set_output('x1_gen', x1_gen)
        self.set_output('x1', x1)
        # ================================== #

        # Reconstruction loss
        # ================================== #
        print("[FactorVAE] rec_x_mode: {}".format(self.rec_x_mode))
        if self.rec_x_mode == 'mse':
            rec_x = tf.reduce_sum(tf.square(
                flatten_right_from(x0, 1) - flatten_right_from(x1, 1)),
                                  axis=1)
        elif self.rec_x_mode == 'l1':
            rec_x = tf.reduce_sum(
                tf.abs(flatten_right_from(x0, 1) - flatten_right_from(x1, 1)),
                axis=1)
        elif self.rec_x_mode == 'bce':
            _, x1_dist = self.decoder_fn(z1_gen, return_distribution=True)
            rec_x = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(
                labels=flatten_right_from(x0, 1),
                logits=flatten_right_from(x1_dist['logit'], 1)),
                                  axis=1)
        else:
            raise ValueError("Do not support '{}'!".format(self.rec_x_mode))

        rec_x = tf.reduce_mean(rec_x, axis=0)

        assert rec_x.shape.ndims == 0, "rec_x.shape: {}".format(rec_x.shape)
        self.set_output('rec_x', rec_x)

        vae_loss += coeff_fn(lc, 'rec_x') * rec_x
        # ================================== #

        # KL divergence loss
        # ================================== #
        kld_loss = KLD_DiagN_N01(z1_gen_dist['mean'],
                                 z1_gen_dist['log_stddev'],
                                 from_axis=1)
        kld_loss = tf.reduce_mean(kld_loss, axis=0)

        assert kld_loss.shape.ndims == 0, "kld_loss.shape: {}".format(
            kld_loss.shape)
        self.set_output('kld_loss', kld_loss)

        vae_loss += coeff_fn(lc, 'kld_loss') * kld_loss
        # ================================== #

        # Discriminator loss to estimate total correlation
        # ================================== #
        xa = self.xa_ph
        z1a_gen = self.encoder_fn(xa, return_distribution=False)
        z1a_gen_perm = tf.stop_gradient(shuffle_batch_4_each_feature(z1a_gen))

        D_logit2_z1_gen = self.disc_z_fn(z1_gen)
        D_logit2_z1a_gen_perm = self.disc_z_fn(z1a_gen_perm)

        # Class probabilities for q(z) samples ('normal') and permuted samples ('factor')
        D_prob2_z1_gen = tf.nn.softmax(D_logit2_z1_gen)
        D_prob2_z1a_gen_perm = tf.nn.softmax(D_logit2_z1a_gen_perm)

        D_loss_z1_gen_tc = -tf.reduce_mean(
            tf.nn.log_softmax(D_logit2_z1_gen)[:, 1], axis=0)
        D_loss_z1a_gen_perm = -tf.reduce_mean(
            tf.nn.log_softmax(D_logit2_z1a_gen_perm)[:, 0], axis=0)

        self.set_output('Dz_prob2_normal', D_prob2_z1_gen)
        self.set_output('Dz_prob2_factor', D_prob2_z1a_gen_perm)
        self.set_output('Dz_avg_prob_normal',
                        tf.reduce_mean(D_prob2_z1_gen[:, 0], axis=0))
        self.set_output('Dz_avg_prob_factor',
                        tf.reduce_mean(D_prob2_z1a_gen_perm[:, 0], axis=0))
        self.set_output('Dz_loss_normal', D_loss_z1_gen_tc)
        self.set_output('Dz_loss_factor', D_loss_z1a_gen_perm)

        Dz_tc_loss = 0.5 * (D_loss_z1_gen_tc + D_loss_z1a_gen_perm)
        assert Dz_tc_loss.shape.ndims == 0, "Dz_tc_loss.shape: {}".format(
            Dz_tc_loss.shape)
        self.set_output('Dz_tc_loss', Dz_tc_loss)

        tc_loss = -tf.reduce_mean(
            D_logit2_z1_gen[:, 0] - D_logit2_z1_gen[:, 1], axis=0)
        assert tc_loss.shape.ndims == 0, "tc_loss.shape: {}".format(
            tc_loss.shape)
        self.set_output('tc_loss', tc_loss)

        Dz_loss += coeff_fn(lc, 'Dz_tc_loss') * Dz_tc_loss
        vae_loss += coeff_fn(lc, 'tc_loss') * tc_loss
        # ================================== #

        # Gradient penalty term for z
        # ================================== #
        if self.use_gp0_z_tc:
            print("[FactorVAE] use_gp0_z_tc: {}".format(self.use_gp0_z_tc))
            print("[FactorVAE] gp0_z_tc_mode: {}".format(self.gp0_z_tc_mode))

            gp0_z_tc = self.gp0(self.gp0_z_tc_mode,
                                self.disc_z_fn,
                                z1_gen,
                                z1a_gen_perm,
                                at_D_comp=0)

            self.set_output('gp0_z_tc', gp0_z_tc)

            Dz_loss += coeff_fn(lc, 'gp0_z_tc') * gp0_z_tc
        # ================================== #

        print("All loss coefficients:")
        pp.pprint(self.loss_coeff_dict)

        self.set_output('Dz_loss', Dz_loss)
        self.set_output('vae_loss', vae_loss)
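
build() above only defines the two losses; below is a hedged sketch of the usual FactorVAE training split, where the encoder/decoder and the latent discriminator are updated by separate optimizers. The scope names and optimizer settings are hypothetical placeholders, not values taken from this code, and must match whatever scopes encoder_fn, decoder_fn and disc_z_fn actually create.

    import tensorflow as tf

    def build_train_ops(vae_loss, Dz_loss,
                        vae_scopes=("Encoder", "Decoder"),  # hypothetical scope names
                        disc_scope="DiscZ"):                # hypothetical scope name
        vae_vars = []
        for s in vae_scopes:
            vae_vars += tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=s)
        disc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=disc_scope)

        # The VAE part and the discriminator are optimized on disjoint variable sets.
        vae_op = tf.train.AdamOptimizer(1e-4).minimize(vae_loss, var_list=vae_vars)
        Dz_op = tf.train.AdamOptimizer(1e-4, beta1=0.5, beta2=0.9).minimize(
            Dz_loss, var_list=disc_vars)
        return vae_op, Dz_op
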
Example #6
    def __call__(self, x, is_train, stochastic=None, scope=None, reuse=None):
        weight_init = get_weight_initializer('he_normal')
        conv2d = partial(conv2d_,
                         weight_initializer=weight_init,
                         activation=None)
        fc = partial(fc_, weight_initializer=weight_init, activation=None)
        bn = partial(batch_norm_, momentum=self.bn_momentum, epsilon=1e-8)
        act = partial(tf.nn.leaky_relu, alpha=0.1)

        if stochastic is None:
            stochastic = is_train

        with tf.variable_scope(self.scope, reuse=reuse):
            x_shape = x.shape.as_list()
            assert len(x_shape) == 4 and x_shape[1] == x_shape[2] == 32 and x_shape[3] == 3, \
                "x.shape={}".format(x_shape)

            h = x

            if self.use_gauss_noise:
                h = gauss_noise(h, stochastic, std=0.15)

            # conv 1
            # --------------------------------------- #
            with tf.variable_scope("conv_1a"):
                # (32, 32, 3) => (32, 32, 128)
                h = conv2d(h,
                           filters=128,
                           kernel_size=3,
                           strides=1,
                           padding="SAME",
                           use_bias=False)
                h = bn(h, is_train)
                h = act(h)

            with tf.variable_scope("conv_1b"):
                # (32, 32, 128) => (32, 32, 128)
                h = conv2d(h,
                           filters=128,
                           kernel_size=3,
                           strides=1,
                           padding="SAME",
                           use_bias=False)
                h = bn(h, is_train)
                h = act(h)

            with tf.variable_scope("conv_1c"):
                # (32, 32, 128) => (32, 32, 128)
                h = conv2d(h,
                           filters=128,
                           kernel_size=3,
                           strides=1,
                           padding="SAME",
                           use_bias=False)
                h = bn(h, is_train)
                h = act(h)
            # --------------------------------------- #

            # (32, 32, 128) => (16, 16, 128)
            h = max_pool_2D(h, pool_size=2)
            h = dropout_2D(h, stochastic, drop_rate=0.5)

            # conv 2
            # --------------------------------------- #
            with tf.variable_scope("conv_2a"):
                # (16, 16, 128) => (16, 16, 256)
                h = conv2d(h,
                           filters=256,
                           kernel_size=3,
                           strides=1,
                           padding="SAME",
                           use_bias=False)
                h = bn(h, is_train)
                h = act(h)

            with tf.variable_scope("conv_2b"):
                # (16, 16, 256) => (16, 16, 256)
                h = conv2d(h,
                           filters=256,
                           kernel_size=3,
                           strides=1,
                           padding="SAME",
                           use_bias=False)
                h = bn(h, is_train)
                h = act(h)

            with tf.variable_scope("conv_2c"):
                # (16, 16, 256) => (16, 16, 256)
                h = conv2d(h,
                           filters=256,
                           kernel_size=3,
                           strides=1,
                           padding="SAME",
                           use_bias=False)
                h = bn(h, is_train)
                h = act(h)
            # --------------------------------------- #

            # (16, 16, 256) => (8, 8, 256)
            h = max_pool_2D(h, pool_size=2)
            h = dropout_2D(h, stochastic, drop_rate=0.5)

            # conv 3
            # --------------------------------------- #
            with tf.variable_scope("conv_3a"):
                # (8, 8, 256) => (6, 6, 512)
                h = conv2d(h,
                           filters=512,
                           kernel_size=3,
                           strides=1,
                           padding="VALID",
                           use_bias=False)
                h = bn(h, is_train)
                h = act(h)

            with tf.variable_scope("conv_3b"):
                # (6, 6, 512) => (6, 6, 256)
                h = conv2d(h,
                           filters=256,
                           kernel_size=1,
                           strides=1,
                           padding="VALID",
                           use_bias=False)
                h = bn(h, is_train)
                h = act(h)

            with tf.variable_scope("conv_3c"):
                # (6, 6, 256) => (6, 6, 128)
                h = conv2d(h,
                           filters=128,
                           kernel_size=1,
                           strides=1,
                           padding="VALID",
                           use_bias=False)
                h = bn(h, is_train)
                h = act(h)
            # --------------------------------------- #

            # (1, 1, 128)
            h = avg_pool_2D(h, pool_size=6)
            # (128,)
            h = flatten_right_from(h, 1)

            y_logit = fc(h, hid_dim=self.num_classes, use_bias=True)
            y_prob = tf.nn.softmax(y_logit)

            return {'logit': y_logit, 'prob': y_prob, 'hid': h}
    def __call__(self, x, is_train, return_distribution=False, return_top_hid=False, scope=None, reuse=None):
        activation = self.activation

        weight_init = tf.truncated_normal_initializer(stddev=0.02)
        conv2d = functools.partial(tf.layers.conv2d, kernel_initializer=weight_init)
        dense = functools.partial(tf.layers.dense, kernel_initializer=weight_init)

        with tf.variable_scope(scope or self.__class__.__name__, reuse=reuse):
            x_shape = mixed_shape(x)
            assert len(x_shape) == 4 and x_shape[1] == x_shape[2] == 64

            h = x

            # with tf.variable_scope("block_1"):
            with tf.variable_scope("conv_1"):
                # (32, 32, 32)
                h = conv2d(h, filters=32, kernel_size=4, strides=2, padding="same")
                h = activation(h)

            # with tf.variable_scope("block_2"):
            with tf.variable_scope("conv_2"):
                # (16, 16, 32)
                h = conv2d(h, filters=32, kernel_size=4, strides=2, padding="same")
                h = activation(h)

            # with tf.variable_scope("block_3"):
            with tf.variable_scope("conv_3"):
                # (8, 8, 64)
                h = conv2d(h, filters=64, kernel_size=4, strides=2, padding="same")
                h = activation(h)

            # with tf.variable_scope("block_4"):
            with tf.variable_scope("conv_4"):
                # (4, 4, 64)
                h = conv2d(h, filters=64, kernel_size=4, strides=2, padding="same")
                h = activation(h)

            # with tf.variable_scope("block_5"):
            with tf.variable_scope("block_5"):
                # (4 * 4 * 64,)
                h = flatten_right_from(h, 1)
                h = dense(h, 128, use_bias=True)
                h = activation(h)

            with tf.variable_scope("top"):
                if self.stochastic:
                    mu = dense(h, self.z_dim, use_bias=True)
                    log_sigma = dense(h, self.z_dim, use_bias=True)
                    sigma = tf.exp(log_sigma)

                    eps = tf.random_normal(shape=mixed_shape(mu), mean=0.0, stddev=1.0)
                    z = mu + eps * sigma
                else:
                    z = dense(h, self.z_dim, use_bias=True)
                    mu = log_sigma = sigma = None

                dist = {'mean': mu, 'log_stddev': log_sigma, 'stddev': sigma}

            outputs = [z]
            if return_distribution:
                outputs.append(dist)
            if return_top_hid:
                outputs.append(h)

            return outputs[0] if (len(outputs) == 1) else tuple(outputs)
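
A quick NumPy sanity check of the reparameterization step in the stochastic branch above (not part of the model code): z = mu + eps * sigma with eps ~ N(0, 1) yields samples distributed as N(mu, sigma^2).

    import numpy as np

    rng = np.random.RandomState(0)
    mu, sigma = 2.0, 0.5
    eps = rng.randn(100000)
    z = mu + eps * sigma

    print(z.mean(), z.std())   # approximately 2.0 and 0.5
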
Example #8
    def build(self, loss_coeff_dict):
        lc = loss_coeff_dict
        coeff_fn = self.one_if_not_exist

        batch_size = mixed_shape(self.x_ph)[0]
        ones = tf.ones([batch_size], dtype=tf.float32)
        zeros = tf.zeros([batch_size], dtype=tf.float32)

        Dz_loss = 0
        AE_loss = 0

        # Encoder / Decoder
        # ================================== #
        x0 = self.x_ph
        z1_gen, z1_gen_dist = self.encoder_fn(x0, return_distribution=True)
        x1 = self.decoder_fn(z1_gen)

        z0 = self.z_ph
        x1_gen = self.decoder_fn(z0)

        self.set_output('z1_gen', z1_gen)

        if self.stochastic_z:
            assert z1_gen_dist is not None, "'z1_gen_dist' must not be None!"
            self.set_output('z_mean', z1_gen_dist['mean'])
            self.set_output('z_stddev', z1_gen_dist['stddev'])
        else:
            assert z1_gen_dist is None, "'z1_gen_dist' must be None!"
            self.set_output('z_mean', tf.zeros_like(z1_gen))
            self.set_output('z_stddev', tf.ones_like(z1_gen))

        self.set_output('x1_gen', x1_gen)
        self.set_output('x1', x1)
        # ================================== #

        # Reconstruct x
        # ================================== #
        print("[AAE] rec_x_mode: {}".format(self.rec_x_mode))
        if self.rec_x_mode == 'mse':
            rec_x = tf.reduce_sum(tf.square(flatten_right_from(x0, 1) -
                                            flatten_right_from(x1, 1)), axis=1)
        elif self.rec_x_mode == 'l1':
            rec_x = tf.reduce_sum(tf.abs(flatten_right_from(x0, 1) -
                                         flatten_right_from(x1, 1)), axis=1)
        elif self.rec_x_mode == 'bce':
            _, x1_dist = self.decoder_fn(z1_gen, return_distribution=True)
            rec_x = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(
                labels=flatten_right_from(x0, 1),
                logits=flatten_right_from(x1_dist['logit'], 1)), axis=1)
        else:
            raise ValueError("Do not support '{}'!".format(self.rec_x_mode))

        rec_x = tf.reduce_mean(rec_x, axis=0)

        assert rec_x.shape.ndims == 0, "rec_x.shape: {}".format(rec_x.shape)
        self.set_output('rec_x', rec_x)

        AE_loss += coeff_fn(lc, 'rec_x') * rec_x
        # ================================== #

        # Discriminate z
        # ================================== #
        # E_p(z)[log D(z0)] +  E_p(x)p(z|x)[log(1 - D(z1_gen))]
        D_logit_z0 = self.disc_z_fn(z0)
        D_logit_z1_gen = self.disc_z_fn(z1_gen)

        D_prob_z0 = tf.nn.sigmoid(D_logit_z0)
        D_prob_z1_gen = tf.nn.sigmoid(D_logit_z1_gen)

        D_loss_z0 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=D_logit_z0, labels=ones), axis=0)
        D_loss_z1_gen = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=D_logit_z1_gen, labels=zeros), axis=0)
        G_loss_z1_gen = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=D_logit_z1_gen, labels=ones), axis=0)

        self.set_output('D_logit_z0', D_logit_z0)
        self.set_output('D_prob_z0', D_prob_z0)
        self.set_output('D_prob_z1_gen', D_prob_z1_gen)
        self.set_output('D_avg_prob_z0', tf.reduce_mean(D_prob_z0, axis=0))
        self.set_output('D_avg_prob_z1_gen', tf.reduce_mean(D_prob_z1_gen, axis=0))

        self.set_output('D_loss_z0', D_loss_z0)
        self.set_output('D_loss_z1_gen', D_loss_z1_gen)
        self.set_output('G_loss_z1_gen', G_loss_z1_gen)

        Dz_loss += coeff_fn(lc, "D_loss_z1_gen") * (D_loss_z0 + D_loss_z1_gen)
        AE_loss += coeff_fn(lc, "G_loss_z1_gen") * G_loss_z1_gen
        # ================================== #

        # Gradient penalty term for z
        # ================================== #
        if self.use_gp0_z:
            print("[AAE] use_gp0_z: {}".format(self.use_gp0_z))
            print("[AAE] gp0_z_mode: {}".format(self.gp0_z_mode))

            gp0_z = self.gp0("interpolation", self.disc_z_fn, z1_gen, z0)
            self.set_output('gp0_z', gp0_z)

            Dz_loss += coeff_fn(lc, 'gp0_z') * gp0_z
        # ================================== #

        self.set_output('Dz_loss', Dz_loss)
        self.set_output('AE_loss', AE_loss)
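
A hedged sketch of one AAE training iteration on top of the losses above, assuming TensorFlow 1.x and a standard normal prior over z. Here sess, x_batch and the two train ops are hypothetical names, while x_ph and z_ph are the placeholders read by build().

    import numpy as np

    def train_step(sess, model, AE_train_op, Dz_train_op, x_batch, z_dim):
        # Sample z0 from the prior p(z) = N(0, I); build() reads it from z_ph.
        z_batch = np.random.randn(x_batch.shape[0], z_dim).astype(np.float32)
        feed = {model.x_ph: x_batch, model.z_ph: z_batch}

        # Update the autoencoder (reconstruction + adversarial generator term) ...
        sess.run(AE_train_op, feed_dict=feed)
        # ... then the latent discriminator on the same batch.
        sess.run(Dz_train_op, feed_dict=feed)
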
    def __call__(self,
                 x,
                 is_train,
                 return_distribution=False,
                 return_top_hid=False,
                 scope=None):
        activation = self.activation

        weight_init = tf.truncated_normal_initializer(stddev=0.02)
        conv2d = functools.partial(tf.layers.conv2d,
                                   kernel_initializer=weight_init)
        dense = functools.partial(tf.layers.dense,
                                  kernel_initializer=weight_init)

        with tf.variable_scope(scope or self.__class__.__name__):
            x_shape = mixed_shape(x)
            assert len(x_shape) == 4 and x_shape[1] == x_shape[2] == 64

            h = x

            with tf.variable_scope("block_1"):
                # (32, 32, 32)
                h = conv2d(h,
                           filters=32,
                           kernel_size=4,
                           strides=2,
                           padding="same")
                h = activation(h)

            with tf.variable_scope("block_2"):
                # (16, 16, 32)
                h = conv2d(h,
                           filters=32,
                           kernel_size=4,
                           strides=2,
                           padding="same")
                h = activation(h)

            with tf.variable_scope("block_3"):
                # (8, 8, 64)
                h = conv2d(h,
                           filters=64,
                           kernel_size=4,
                           strides=2,
                           padding="same")
                h = activation(h)

            with tf.variable_scope("block_4"):
                # (4, 4, 64)
                h = conv2d(h,
                           filters=64,
                           kernel_size=4,
                           strides=2,
                           padding="same")
                h = activation(h)

            with tf.variable_scope("block_5"):
                # Only the width changes: 128 (dSprites) vs. 256 (Chairs3D, CelebA); everything else is unchanged
                # (1, 1, 256)
                h = conv2d(h,
                           filters=256,
                           kernel_size=4,
                           strides=1,
                           padding="valid")
                h = activation(h)

            with tf.variable_scope("top"):
                # (256,)
                h = flatten_right_from(h, 1)

                if self.stochastic:
                    mu = dense(h, self.z_dim, use_bias=True)
                    log_sigma = dense(h, self.z_dim, use_bias=True)
                    sigma = tf.exp(log_sigma)

                    eps = tf.random_normal(shape=mixed_shape(mu),
                                           mean=0.0,
                                           stddev=1.0)
                    z = mu + eps * sigma

                    mu = reshape_4_batch(mu, self.z_shape, num_batch_axes=1)
                    log_sigma = reshape_4_batch(log_sigma,
                                                self.z_shape,
                                                num_batch_axes=1)
                    z = reshape_4_batch(z, self.z_shape, num_batch_axes=1)

                else:
                    z = dense(h, self.z_dim, use_bias=True)
                    mu = log_sigma = sigma = None

                    z = reshape_4_batch(z, self.z_shape, num_batch_axes=1)

                dist = {'mean': mu, 'log_stddev': log_sigma, 'stddev': sigma}

            outputs = [z]
            if return_distribution:
                outputs.append(dist)
            if return_top_hid:
                outputs.append(h)

            return outputs[0] if (len(outputs) == 1) else tuple(outputs)
    def __call__(self,
                 z,
                 is_train,
                 return_distribution=False,
                 return_top_hid=False,
                 scope=None):
        activation = self.activation

        weight_init = tf.truncated_normal_initializer(stddev=0.02)
        deconv2d = functools.partial(tf.layers.conv2d_transpose,
                                     kernel_initializer=weight_init)
        dense = functools.partial(tf.layers.dense,
                                  kernel_initializer=weight_init)

        with tf.variable_scope(scope or self.__class__.__name__):
            # z_shape = mixed_shape(z)
            # if len(z_shape) == 4:
            #     assert z_shape[1] == z_shape[2] == 1
            #     z = flatten_right_from(z, axis=1)
            # batch_size = z_shape[0]

            # (z_dim,)
            z = flatten_right_from(z, axis=1)
            h = z

            with tf.variable_scope("block_1"):
                # Only the width changes: 128 (dSprites) vs. 256 (Chairs3D, CelebA); everything else is unchanged
                # (256,)
                h = dense(h, 256, use_bias=True)
                h = activation(h)

            with tf.variable_scope("block_2"):
                h = reshape_4_batch(h, [1, 1, 256], num_batch_axes=1)
                # (4, 4, 64)
                h = deconv2d(h,
                             filters=64,
                             kernel_size=4,
                             strides=1,
                             padding="valid")
                h = activation(h)

            with tf.variable_scope("block_3"):
                # First, we up-sample to achieve image size of (8, 8)
                # Then, we do transposed convolution
                # (8, 8, 64)
                h = deconv2d(h,
                             filters=64,
                             kernel_size=4,
                             strides=2,
                             padding="same")
                h = activation(h)

            with tf.variable_scope("block_4"):
                # (16, 16, 32)
                h = deconv2d(h,
                             filters=32,
                             kernel_size=4,
                             strides=2,
                             padding="same")
                h = activation(h)

            with tf.variable_scope("block_5"):
                # (32, 32, 32)
                h = deconv2d(h,
                             filters=32,
                             kernel_size=4,
                             strides=2,
                             padding="same")
                h = activation(h)

            with tf.variable_scope("top"):
                # (64, 64, 1) or (64, 64, 3)
                x_logit = deconv2d(h,
                                   filters=self.x_shape[-1],
                                   kernel_size=4,
                                   strides=2,
                                   padding="same")
                x = self.output_activation(x_logit)

            outputs = [x]
            if return_distribution:
                outputs.append({'prob': x, 'logit': x_logit})
            if return_top_hid:
                outputs.append(h)

            return outputs[0] if (len(outputs) == 1) else tuple(outputs)
Example #11
    def gp(center, mode, disc_fn, inp_1, inp_2, at_D_comp=None):
        if mode == "interpolation":
            print("[BaseLatentModel.gp] interpolation!")
            assert (inp_1 is not None) and (inp_2 is not None), "If 'mode' is 'interpolation', " \
                "both 'inp_1' and 'inp_2' must be provided!"

            # IMPORTANT: The commented-out formula below is wrong (it differs
            # from the original gradient penalty): alpha must have shape
            # [batch, 1, 1, ...] so it broadcasts per sample, not mixed_shape(inp_1).
            # alpha = tf.random_uniform(shape=mixed_shape(inp_1), minval=0., maxval=1.)

            shape = mixed_shape(inp_1)
            batch_size = shape[0]
            # (batch, 1, ..., 1)
            alpha = tf.random_uniform(shape=[batch_size] + [1] * (len(shape) - 1),
                                      minval=0.,
                                      maxval=1.)

            # (batch, z_dim)
            inp_i = alpha * inp_1 + ((1 - alpha) * inp_2)

            if at_D_comp is not None:
                D_logit = disc_fn(inp_i)
                D_logit_i = D_logit[:, at_D_comp]
            else:
                D_logit_i = disc_fn(inp_i)

            assert len(
                D_logit_i.shape.as_list()) == 1, D_logit_i.shape.as_list()

            D_grad = tf.gradients(D_logit_i, [inp_i])[0]
            D_grad = flatten_right_from(D_grad, 1)

        elif mode == "self":
            print("[BaseLatentModel.gp] self!")
            assert inp_1 is not None, "If 'mode' is 'self', 'inp_1' must be provided!"

            if at_D_comp is not None:
                D_logit = disc_fn(inp_1)
                D_logit_1 = D_logit[:, at_D_comp]
            else:
                D_logit_1 = disc_fn(inp_1)

            assert len(
                D_logit_1.shape.as_list()) == 1, D_logit_1.shape.as_list()

            D_grad = tf.gradients(D_logit_1, [inp_1])[0]
            D_grad = flatten_right_from(D_grad, 1)

        elif mode == "other":
            print("[BaseLatentModel.gp] other!")
            assert inp_2 is not None, "If 'mode' is 'other', 'inp_2' must be provided!"

            if at_D_comp is not None:
                D_logit = disc_fn(inp_2)
                D_logit_2 = D_logit[:, at_D_comp]
            else:
                D_logit_2 = disc_fn(inp_2)

            assert len(
                D_logit_2.shape.as_list()) == 1, D_logit_2.shape.as_list()

            D_grad = tf.gradients(D_logit_2, [inp_2])[0]
            D_grad = flatten_right_from(D_grad, 1)

        else:
            raise ValueError("Do not support 'mode'='{}'".format(mode))

        assert len(D_grad.shape) == 2, "'D_grad' must have 2 dimensions!"

        if center == 0:
            gp = tf.reduce_mean(tf.reduce_sum(tf.square(D_grad), axis=1),
                                axis=0)
        else:
            slope = tf.sqrt(tf.reduce_sum(tf.square(D_grad), axis=1))
            gp = tf.reduce_mean((slope - center)**2, axis=0)

        return gp
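
A small NumPy check of the two penalty forms at the bottom of gp(): the zero-centered E[||grad||^2] and the centered E[(||grad|| - center)^2] (WGAN-GP style when center = 1). The numbers are illustrative only.

    import numpy as np

    D_grad = np.array([[0.6, 0.8],    # ||grad|| = 1.0
                       [3.0, 4.0]])   # ||grad|| = 5.0

    sq_norm = np.sum(D_grad ** 2, axis=1)        # [1., 25.]
    gp_center0 = sq_norm.mean()                  # 13.0
    slope = np.sqrt(sq_norm)
    gp_center1 = ((slope - 1.0) ** 2).mean()     # 8.0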