def __call__(self, z, is_train, return_top_hid=False, scope=None, reuse=None):
        weight_init = tf.truncated_normal_initializer(stddev=0.02)
        dense = functools.partial(tf.layers.dense, kernel_initializer=weight_init)

        with tf.variable_scope(scope or self.__class__.__name__, reuse=reuse):
            z_shape = mixed_shape(z)
            if len(z_shape) == 4:
                assert z_shape[1] == z_shape[2] == 1
                z = flatten_right_from(z, 1)

            h = z
            for i in range(5):
                with tf.variable_scope("layer_{}".format(i)):
                    h = dense(h, 1000, use_bias=True)
                    h = tf.nn.leaky_relu(h, 0.2)

            with tf.variable_scope("top"):
                dZ = dense(h, self.num_outputs, use_bias=True)

            if self.num_outputs == 1:
                dZ = tf.reshape(dZ, mixed_shape(dZ)[:-1])

            outputs = [dZ]
            if return_top_hid:
                outputs.append(h)

            return outputs[0] if (len(outputs) == 1) else tuple(outputs)
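
# ----------------------------------------------------------------------------- #
# Standalone usage sketch (an assumption, not code from this repo): the
# variable_scope/reuse pattern that the discriminator's __call__ above relies
# on. Building the network a second time with reuse=True applies the weights
# created by the first call instead of allocating new ones, which is what lets
# one discriminator score several different inputs.
# ----------------------------------------------------------------------------- #
import tensorflow as tf

def tiny_disc(z, reuse=None):
    # Toy 2-layer stand-in for the 5-layer MLP above; explicit layer names
    # make the reuse behaviour unambiguous.
    with tf.variable_scope("TinyDisc", reuse=reuse):
        h = tf.nn.leaky_relu(tf.layers.dense(z, 16, name="fc1"), 0.2)
        return tf.squeeze(tf.layers.dense(h, 1, name="top"), axis=-1)

logit_a = tiny_disc(tf.random_normal([4, 10]))              # creates TinyDisc/* variables
logit_b = tiny_disc(tf.random_normal([4, 10]), reuse=True)  # reuses the same variables
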
    @staticmethod
    def gp0_indv(mode, disc_indv_fn, inp_1, inp_2):
        # Gradient penalty for individual discriminator
        if mode == "interpolation":
            print("[BaseLatentModel.gp0_indv] interpolation!")
            assert (inp_1 is not None) and (inp_2 is not None), "If 'mode' is 'interpolation', " \
                                                                "both 'inp_1' and 'inp_2' must be provided!"

            alpha = tf.random_uniform(shape=mixed_shape(inp_1),
                                      minval=0.,
                                      maxval=1.)
            # (batch, z_dim)
            inp_i = alpha * inp_1 + ((1 - alpha) * inp_2)
            # (batch, z_dim)
            D_logit_i = disc_indv_fn(inp_i)

            D_grad_i = tf.gradients(D_logit_i, [inp_i])[0]
            D_grad_i = flatten_right_from(D_grad_i, 1)
            assert len(
                D_grad_i.shape) == 2, "'D_grad_i' must have 2 dimensions!"

            gp = tf.reduce_mean(tf.reduce_sum(tf.square(D_grad_i), axis=1),
                                axis=0)

        elif mode == "self":
            print("[BaseLatentModel.gp0_indv] self!")
            assert inp_1 is not None, "If 'mode' is 'self', 'inp_1' must be provided!"

            D_logit_1 = disc_indv_fn(inp_1)

            D_grad_1 = tf.gradients(D_logit_1, [inp_1])[0]
            D_grad_1 = flatten_right_from(D_grad_1, 1)
            assert len(
                D_grad_1.shape) == 2, "'D_grad_1' must have 2 dimensions!"

            gp = tf.reduce_mean(tf.reduce_sum(tf.square(D_grad_1), axis=1),
                                axis=0)

        elif mode == "other":
            print("[BaseLatentModel.gp0_indv] other!")
            assert inp_2 is not None, "If 'mode' is 'other', 'inp_2' must be provided!"

            D_logit_2 = disc_indv_fn(inp_2)

            D_grad_2 = tf.gradients(D_logit_2, [inp_2])[0]
            D_grad_2 = flatten_right_from(D_grad_2, 1)
            assert len(
                D_grad_2.shape) == 2, "'D_grad_2' must have 2 dimensions!"

            gp = tf.reduce_mean(tf.reduce_sum(tf.square(D_grad_2), axis=1),
                                axis=0)

        else:
            raise ValueError("Do not support 'mode'='{}'".format(mode))

        return gp
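
# ----------------------------------------------------------------------------- #
# Standalone sketch (toy shapes and a toy linear discriminator, both
# assumptions): the zero-centered penalty computed by the "interpolation"
# branch above, gp0 = E[ ||d D(z_i)/d z_i||^2 ], with a per-element
# interpolation coefficient exactly as in gp0_indv.
# ----------------------------------------------------------------------------- #
import tensorflow as tf

z1_toy = tf.random_normal([8, 10])   # stands in for the encoder output z1_gen
z0_toy = tf.random_normal([8, 10])   # stands in for the prior sample z0

alpha_toy = tf.random_uniform(shape=tf.shape(z1_toy), minval=0., maxval=1.)
z_i_toy = alpha_toy * z1_toy + (1. - alpha_toy) * z0_toy

# Toy discriminator: one dense layer producing a (batch,) logit.
logit_toy = tf.squeeze(tf.layers.dense(z_i_toy, 1), axis=-1)

grad_toy = tf.gradients(logit_toy, [z_i_toy])[0]                       # (batch, z_dim)
gp0_toy = tf.reduce_mean(tf.reduce_sum(tf.square(grad_toy), axis=1))   # scalar penalty
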
    def __call__(self, z, is_train, return_distribution=False, return_top_hid=False, scope=None, reuse=None):
        activation = self.activation

        weight_init = tf.truncated_normal_initializer(stddev=0.02)
        deconv2d = functools.partial(tf.layers.conv2d_transpose, kernel_initializer=weight_init)
        dense = functools.partial(tf.layers.dense, kernel_initializer=weight_init)

        with tf.variable_scope(scope or self.__class__.__name__, reuse=reuse):
            z_shape = mixed_shape(z)
            if len(z_shape) == 4:
                assert z_shape[1] == z_shape[2] == 1
                z = flatten_right_from(z, axis=1)
            batch_size = z_shape[0]

            # (z_dim,)
            h = z

            with tf.variable_scope("block_1"):
                # (128,)
                h = dense(h, 128, use_bias=True)
                h = activation(h)

            with tf.variable_scope("block_2"):
                h = dense(h, 4 * 4 * 64, use_bias=True)
                h = activation(h)

            with tf.variable_scope("block_3"):
                h = tf.reshape(h, [batch_size, 4, 4, 64])
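                # (8, 8, 64)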
                h = deconv2d(h, filters=64, kernel_size=4, strides=2, padding="same")
                h = activation(h)

            with tf.variable_scope("block_4"):
                # (16, 16, 32)
                h = deconv2d(h, filters=32, kernel_size=4, strides=2, padding="same")
                h = activation(h)

            with tf.variable_scope("block_5"):
                # (32, 32, 32)
                h = deconv2d(h, filters=32, kernel_size=4, strides=2, padding="same")
                h = activation(h)

            with tf.variable_scope("top"):
                # (64, 64, 1)
                x_logit = deconv2d(h, filters=self.x_dim[-1], kernel_size=4, strides=2, padding="same")
                x = tf.nn.sigmoid(x_logit)

            outputs = [x]
            if return_distribution:
                outputs.append({'prob': x, 'logit': x_logit})
            if return_top_hid:
                outputs.append(h)

            return outputs[0] if (len(outputs) == 1) else tuple(outputs)
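
# ----------------------------------------------------------------------------- #
# Standalone sketch (assumed toy batch, untrained layers): the spatial
# upsampling arithmetic of the decoder above. Starting from a (4, 4, 64)
# feature map, four stride-2 transposed convolutions with "same" padding
# double the resolution each time: 4 -> 8 -> 16 -> 32 -> 64.
# ----------------------------------------------------------------------------- #
import tensorflow as tf

feat = tf.zeros([2, 4, 4, 64])                 # toy batch of 2 feature maps
for n_filters in [64, 32, 32]:
    feat = tf.layers.conv2d_transpose(feat, n_filters, kernel_size=4,
                                      strides=2, padding="same")
x_logit_toy = tf.layers.conv2d_transpose(feat, 1, kernel_size=4,
                                         strides=2, padding="same")
# x_logit_toy.shape == (2, 64, 64, 1); tf.nn.sigmoid(x_logit_toy) is the
# Bernoulli mean, matching the 'prob' entry returned above.
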
    @staticmethod
    def gp_4_inp_list(center,
                      mode,
                      disc_fn,
                      inputs_1,
                      inputs_2,
                      at_D_comp=None):
        if mode == "interpolation":
            print("[BaseLatentModel.gp_4_inp_list] interpolation!")
            assert isinstance(inputs_1, (list, tuple))
            assert isinstance(inputs_2, (list, tuple))

            inputs_i = []

            for inp_1, inp_2 in zip(inputs_1, inputs_2):
                alpha = tf.random_uniform(shape=mixed_shape(inp_1),
                                          minval=0.,
                                          maxval=1.)
                # (batch, z_dim)
                inp_i = alpha * inp_1 + ((1 - alpha) * inp_2)
                inputs_i.append(inp_i)

            if at_D_comp is not None:
                D_logit_i = disc_fn(inputs_i)
                D_logit_i = D_logit_i[:, at_D_comp]
            else:
                D_logit_i = disc_fn(inputs_i)

            assert len(
                D_logit_i.shape.as_list()) == 1, D_logit_i.shape.as_list()

            # List of gradient for each input component
            D_grads = tf.gradients(D_logit_i, inputs_i)
            D_grads = [flatten_right_from(Dg, axis=1) for Dg in D_grads]

        else:
            raise ValueError("Do not support 'mode'='{}'".format(mode))

        gp = tf.constant(0, dtype=tf.float32)

        for Dg in D_grads:
            assert len(Dg.shape) == 2, "'Dg' must have 2 dimensions!"
            slope = tf.sqrt(tf.reduce_sum(tf.square(Dg), axis=1))
            gp += tf.reduce_mean((slope - center)**2, axis=0)

        return gp
    def __call__(self, x, is_train, return_distribution=False, return_top_hid=False, scope=None, reuse=None):
        activation = self.activation

        weight_init = tf.truncated_normal_initializer(stddev=0.02)
        conv2d = functools.partial(tf.layers.conv2d, kernel_initializer=weight_init)
        dense = functools.partial(tf.layers.dense, kernel_initializer=weight_init)

        with tf.variable_scope(scope or self.__class__.__name__, reuse=reuse):
            x_shape = mixed_shape(x)
            assert len(x_shape) == 4 and x_shape[1] == x_shape[2] == 64

            h = x

            # with tf.variable_scope("block_1"):
            with tf.variable_scope("conv_1"):
                # (32, 32, 32)
                h = conv2d(h, filters=32, kernel_size=4, strides=2, padding="same")
                h = activation(h)

            # with tf.variable_scope("block_2"):
            with tf.variable_scope("conv_2"):
                # (16, 16, 32)
                h = conv2d(h, filters=32, kernel_size=4, strides=2, padding="same")
                h = activation(h)

            # with tf.variable_scope("block_3"):
            with tf.variable_scope("conv_3"):
                # (8, 8, 64)
                h = conv2d(h, filters=64, kernel_size=4, strides=2, padding="same")
                h = activation(h)

            # with tf.variable_scope("block_4"):
            with tf.variable_scope("conv_4"):
                # (4, 4, 64)
                h = conv2d(h, filters=64, kernel_size=4, strides=2, padding="same")
                h = activation(h)

            # with tf.variable_scope("block_5"):
            with tf.variable_scope("block_5"):
                # (4 * 4 * 64,)
                h = flatten_right_from(h, 1)
                h = dense(h, 128, use_bias=True)
                h = activation(h)

            with tf.variable_scope("top"):
                if self.stochastic:
                    mu = dense(h, self.z_dim, use_bias=True)
                    log_sigma = dense(h, self.z_dim, use_bias=True)
                    sigma = tf.exp(log_sigma)

                    eps = tf.random_normal(shape=mixed_shape(mu), mean=0.0, stddev=1.0)
                    z = mu + eps * sigma
                else:
                    z = dense(h, self.z_dim, use_bias=True)
                    mu = log_sigma = sigma = None

                dist = {'mean': mu, 'log_stddev': log_sigma, 'stddev': sigma}

            outputs = [z]
            if return_distribution:
                outputs.append(dist)
            if return_top_hid:
                outputs.append(h)

            return outputs[0] if (len(outputs) == 1) else tuple(outputs)
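
# ----------------------------------------------------------------------------- #
# Standalone sketch (assumed z_dim=10 and a random stand-in for the top hidden
# layer): the reparameterisation used in the stochastic branch of "top" above.
# Sampling z = mu + eps * sigma with eps ~ N(0, I) keeps the sample
# differentiable with respect to mu and log_sigma.
# ----------------------------------------------------------------------------- #
import tensorflow as tf

h_toy = tf.random_normal([8, 128])             # stands in for the 128-d hidden layer
mu_toy = tf.layers.dense(h_toy, 10)
log_sigma_toy = tf.layers.dense(h_toy, 10)
sigma_toy = tf.exp(log_sigma_toy)

eps_toy = tf.random_normal(tf.shape(mu_toy))   # eps ~ N(0, I)
z_toy = mu_toy + eps_toy * sigma_toy           # differentiable sample of N(mu, sigma^2)
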
    def build(self, loss_coeff_dict):
        lc = loss_coeff_dict
        coeff_fn = self.one_if_not_exist

        batch_size = mixed_shape(self.x_ph)[0]
        ones = tf.ones([batch_size], dtype=tf.float32)
        zeros = tf.zeros([batch_size], dtype=tf.float32)

        Dz_loss = 0
        AE_loss = 0

        # Encoder / Decoder
        # ================================== #
        x0 = self.x_ph
        z1_gen, z1_gen_dist = self.encoder_fn(x0, return_distribution=True)
        x1 = self.decoder_fn(z1_gen)

        z0 = self.z_ph
        x1_gen = self.decoder_fn(z0)

        self.set_output('z1_gen', z1_gen)

        if self.stochastic_z:
            assert z1_gen_dist is not None, "'z1_gen_dist' must not be None!"
            self.set_output('z_mean', z1_gen_dist['mean'])
            self.set_output('z_stddev', z1_gen_dist['stddev'])
        else:
            assert z1_gen_dist is None, "'z1_gen_dist' must be None!"
            self.set_output('z_mean', tf.zeros_like(z1_gen))
            self.set_output('z_stddev', tf.ones_like(z1_gen))

        self.set_output('x1_gen', x1_gen)
        self.set_output('x1', x1)
        # ================================== #

        # Reconstruct x
        # ================================== #
        print("[AAE] rec_x_mode: {}".format(self.rec_x_mode))
        if self.rec_x_mode == 'mse':
            rec_x = tf.reduce_sum(tf.square(flatten_right_from(x0, 1) -
                                            flatten_right_from(x1, 1)), axis=1)
        elif self.rec_x_mode == 'l1':
            rec_x = tf.reduce_sum(tf.abs(flatten_right_from(x0, 1) -
                                         flatten_right_from(x1, 1)), axis=1)
        elif self.rec_x_mode == 'bce':
            _, x1_dist = self.decoder_fn(z1_gen, return_distribution=True)
            rec_x = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(
                labels=flatten_right_from(x0, 1),
                logits=flatten_right_from(x1_dist['logit'], 1)), axis=1)
        else:
            raise ValueError("Do not support '{}'!".format(self.rec_x_mode))

        rec_x = tf.reduce_mean(rec_x, axis=0)

        assert rec_x.shape.ndims == 0, "rec_x.shape: {}".format(rec_x.shape)
        self.set_output('rec_x', rec_x)

        AE_loss += coeff_fn(lc, 'rec_x') * rec_x
        # ================================== #

        # Discriminate z
        # ================================== #
        # E_p(z)[log D(z0)] +  E_p(x)p(z|x)[log(1 - D(z1_gen))]
        D_logit_z0 = self.disc_z_fn(z0)
        D_logit_z1_gen = self.disc_z_fn(z1_gen)

        D_prob_z0 = tf.nn.sigmoid(D_logit_z0)
        D_prob_z1_gen = tf.nn.sigmoid(D_logit_z1_gen)

        D_loss_z0 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=D_logit_z0, labels=ones), axis=0)
        D_loss_z1_gen = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=D_logit_z1_gen, labels=zeros), axis=0)
        G_loss_z1_gen = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=D_logit_z1_gen, labels=ones), axis=0)

        self.set_output('D_logit_z0', D_logit_z0)
        self.set_output('D_prob_z0', D_prob_z0)
        self.set_output('D_prob_z1_gen', D_prob_z1_gen)
        self.set_output('D_avg_prob_z0', tf.reduce_mean(D_prob_z0, axis=0))
        self.set_output('D_avg_prob_z1_gen', tf.reduce_mean(D_prob_z1_gen, axis=0))

        self.set_output('D_loss_z0', D_loss_z0)
        self.set_output('D_loss_z1_gen', D_loss_z1_gen)
        self.set_output('G_loss_z1_gen', G_loss_z1_gen)

        Dz_loss += coeff_fn(lc, "D_loss_z1_gen") * (D_loss_z0 + D_loss_z1_gen)
        AE_loss += coeff_fn(lc, "G_loss_z1_gen") * G_loss_z1_gen
        # ================================== #

        # Gradient penalty term for z
        # ================================== #
        if self.use_gp0_z:
            print("[AAE] use_gp0_z: {}".format(self.use_gp0_z))
            print("[AAE] gp0_z_mode: {}".format(self.gp0_z_mode))

            gp0_z = self.gp0("interpolation", self.disc_z_fn, z1_gen, z0)
            self.set_output('gp0_z', gp0_z)

            Dz_loss += coeff_fn(lc, 'gp0_z') * gp0_z
        # ================================== #

        self.set_output('Dz_loss', Dz_loss)
        self.set_output('AE_loss', AE_loss)
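
# ----------------------------------------------------------------------------- #
# Standalone sketch (toy one-layer networks and shapes, all assumptions): the
# adversarial bookkeeping that build() sets up. The discriminator is pushed
# towards 1 on prior samples z0 and 0 on encoder samples z1_gen (Dz_loss),
# while the encoder is pushed to fool it with flipped labels (the G_loss term
# added to AE_loss). Each loss is then minimised over its own variable set.
# ----------------------------------------------------------------------------- #
import tensorflow as tf

def toy_disc_z(z):
    with tf.variable_scope("toy_disc_z", reuse=tf.AUTO_REUSE):
        return tf.squeeze(tf.layers.dense(z, 1, name="logit"), axis=-1)

x_toy = tf.random_normal([16, 32])                      # toy data batch
z0_toy = tf.random_normal([16, 10])                     # prior sample p(z)
with tf.variable_scope("toy_enc"):
    z1_gen_toy = tf.layers.dense(x_toy, 10)             # toy deterministic "encoder"

ones_toy = tf.ones([16])
zeros_toy = tf.zeros([16])
bce = tf.nn.sigmoid_cross_entropy_with_logits

Dz_loss_toy = tf.reduce_mean(bce(logits=toy_disc_z(z0_toy), labels=ones_toy)) + \
              tf.reduce_mean(bce(logits=toy_disc_z(z1_gen_toy), labels=zeros_toy))
G_loss_toy = tf.reduce_mean(bce(logits=toy_disc_z(z1_gen_toy), labels=ones_toy))

disc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="toy_disc_z")
enc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="toy_enc")
D_op = tf.train.AdamOptimizer(1e-4, beta1=0.5).minimize(Dz_loss_toy, var_list=disc_vars)
G_op = tf.train.AdamOptimizer(1e-4, beta1=0.5).minimize(G_loss_toy, var_list=enc_vars)
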
    def __call__(self,
                 x,
                 is_train,
                 return_distribution=False,
                 return_top_hid=False,
                 scope=None):
        activation = self.activation

        weight_init = tf.truncated_normal_initializer(stddev=0.02)
        conv2d = functools.partial(tf.layers.conv2d,
                                   kernel_initializer=weight_init)
        dense = functools.partial(tf.layers.dense,
                                  kernel_initializer=weight_init)

        with tf.variable_scope(scope or self.__class__.__name__):
            x_shape = mixed_shape(x)
            assert len(x_shape) == 4 and x_shape[1] == x_shape[2] == 64

            h = x

            with tf.variable_scope("block_1"):
                # (32, 32, 32)
                h = conv2d(h,
                           filters=32,
                           kernel_size=4,
                           strides=2,
                           padding="same")
                h = activation(h)

            with tf.variable_scope("block_2"):
                # (16, 16, 32)
                h = conv2d(h,
                           filters=32,
                           kernel_size=4,
                           strides=2,
                           padding="same")
                h = activation(h)

            with tf.variable_scope("block_3"):
                # (8, 8, 64)
                h = conv2d(h,
                           filters=64,
                           kernel_size=4,
                           strides=2,
                           padding="same")
                h = activation(h)

            with tf.variable_scope("block_4"):
                # (4, 4, 64)
                h = conv2d(h,
                           filters=64,
                           kernel_size=4,
                           strides=2,
                           padding="same")
                h = activation(h)

            with tf.variable_scope("block_5"):
                # Only the filter count changes: 128 for dSprites vs. 256 for
                # Chairs3D/CelebA; everything else stays the same.
                # (1, 1, 256)
                h = conv2d(h,
                           filters=256,
                           kernel_size=4,
                           strides=1,
                           padding="valid")
                h = activation(h)

            with tf.variable_scope("top"):
                # (256,)
                h = flatten_right_from(h, 1)

                if self.stochastic:
                    mu = dense(h, self.z_dim, use_bias=True)
                    log_sigma = dense(h, self.z_dim, use_bias=True)
                    sigma = tf.exp(log_sigma)

                    eps = tf.random_normal(shape=mixed_shape(mu),
                                           mean=0.0,
                                           stddev=1.0)
                    z = mu + eps * sigma

                    mu = reshape_4_batch(mu, self.z_shape, num_batch_axes=1)
                    log_sigma = reshape_4_batch(log_sigma,
                                                self.z_shape,
                                                num_batch_axes=1)
                    z = reshape_4_batch(z, self.z_shape, num_batch_axes=1)

                else:
                    z = dense(h, self.z_dim, use_bias=True)
                    mu = log_sigma = sigma = None

                    z = reshape_4_batch(z, self.z_shape, num_batch_axes=1)

                dist = {'mean': mu, 'log_stddev': log_sigma, 'stddev': sigma}

            outputs = [z]
            if return_distribution:
                outputs.append(dist)
            if return_top_hid:
                outputs.append(h)

            return outputs[0] if (len(outputs) == 1) else tuple(outputs)
    def _consistency_loss(self):
        from tensorflow.contrib.graph_editor import graph_replace

        x = self.x_ph

        shape_x = mixed_shape(x)
        batch_size = shape_x[0]

        rand_ids = tf.random_shuffle(tf.range(batch_size, dtype=tf.int32))
        x_rand = tf.gather(x, rand_ids)

        # (batch_size)
        lam = self.beta.sample(batch_size)
        lam_x = tf.reshape(lam, [batch_size] + [1] * len(self.x_shape))
        lam_y = tf.reshape(lam, [batch_size, 1])

        x_mixed = lam_x * x + (1 - lam_x) * x_rand

        y_prob_stu = self.get_output('y_dist_stu_sto')['prob']
        y_prob_stu_on_x_mixed = graph_replace(y_prob_stu, {x: x_mixed})

        if self.cons_against_mean:
            print("Use teacher (deterministic) for consistency!")
            y_prob_tea = self.get_output('y_dist_tea_det')['prob']
        else:
            print("Use teacher (stochastic) for consistency!")
            assert 'y_dist_tea_sto' in self.output_dict, "'output_dict' must contain 'y_dist_tea_sto'!"
            y_prob_tea = self.get_output('y_dist_tea_sto')['prob']
        y_prob_tea_rand = graph_replace(y_prob_tea, {x: x_rand})
        y_prob_tea_mixed = lam_y * y_prob_tea + (1 - lam_y) * y_prob_tea_rand

        if self.cons_mode == 'mse':
            # IMPORTANT: Here, we take the sum over classes.
            # Implementations in other papers use 'mean' instead of 'sum'.
            # This means our 'cons_coeff' should be about 10 (for CIFAR-10 and SVHN),
            # not 100 as in the original papers.
            print("cons_mode=mse!")
            consistency = tf.reduce_sum(
                tf.square(y_prob_stu_on_x_mixed -
                          tf.stop_gradient(y_prob_tea_mixed)),
                axis=1)
        elif self.cons_mode == 'kld':
            print("cons_mode=kld!")
            from my_utils.tensorflow_utils.distributions import KLD_2Cats_v2
            consistency = KLD_2Cats_v2(y_prob_stu_on_x_mixed,
                                       tf.stop_gradient(y_prob_tea_mixed))
        elif self.cons_mode == 'rev_kld':
            print("cons_mode=rev_kld!")
            from my_utils.tensorflow_utils.distributions import KLD_2Cats_v2
            consistency = KLD_2Cats_v2(tf.stop_gradient(y_prob_tea_mixed),
                                       y_prob_stu_on_x_mixed)
        else:
            raise ValueError("Do not support 'cons_mode'={}!".format(
                self.cons_mode))

        if self.cons_4_unlabeled_only:
            label_flag_inv = self.get_output('label_flag_inv')
            num_unlabeled = self.get_output('num_unlabeled')
            consistency = tf.reduce_sum(consistency * label_flag_inv,
                                        axis=0) * 1.0 / (num_unlabeled + 1e-8)
        else:
            consistency = tf.reduce_mean(consistency, axis=0)

        results = {
            'cons': consistency,
        }

        return results
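
# ----------------------------------------------------------------------------- #
# Standalone sketch (toy shapes, one-layer softmax networks, and a uniform
# stand-in for the Beta-distributed lambda): the mixup consistency target built
# above. Inputs are mixed with lambda, the teacher's predictions are mixed with
# the same lambda, and the student is penalised (MSE variant, summed over
# classes) on the mixed input against the mixed, gradient-stopped teacher.
# ----------------------------------------------------------------------------- #
import tensorflow as tf

batch, num_classes = 16, 10
x_toy = tf.random_normal([batch, 32, 32, 3])
lam = tf.random_uniform([batch], 0.2, 0.8)        # stands in for self.beta.sample(batch)
lam_x = tf.reshape(lam, [batch, 1, 1, 1])
lam_y = tf.reshape(lam, [batch, 1])

rand_ids = tf.random_shuffle(tf.range(batch))
x_rand_toy = tf.gather(x_toy, rand_ids)
x_mixed_toy = lam_x * x_toy + (1. - lam_x) * x_rand_toy

def toy_net(inp, scope):
    with tf.variable_scope(scope):
        return tf.nn.softmax(tf.layers.dense(tf.layers.flatten(inp), num_classes))

y_stu_mixed = toy_net(x_mixed_toy, "toy_student")
y_tea = toy_net(x_toy, "toy_teacher")
# For this deterministic toy teacher, gathering equals re-evaluating on x_rand
# (the method above instead uses graph_replace to rebuild the teacher's graph).
y_tea_rand = tf.gather(y_tea, rand_ids)
y_tea_mixed = lam_y * y_tea + (1. - lam_y) * y_tea_rand

cons_toy = tf.reduce_mean(tf.reduce_sum(
    tf.square(y_stu_mixed - tf.stop_gradient(y_tea_mixed)), axis=1))
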
    @staticmethod
    def gp(center, mode, disc_fn, inp_1, inp_2, at_D_comp=None):
        if mode == "interpolation":
            print("[BaseLatentModel.gp] interpolation!")
            assert (inp_1 is not None) and (inp_2 is not None), "If 'mode' is 'interpolation', " \
                "both 'inp_1' and 'inp_2' must be provided!"

            # IMPORTANT: The commented-out formula below is wrong (compared to
            # the original formulation): 'alpha' must have shape
            # [batch, 1, 1, ...] so that it broadcasts per sample, not
            # mixed_shape(inp_1).
            # alpha = tf.random_uniform(shape=mixed_shape(inp_1), minval=0., maxval=1.)

            shape = mixed_shape(inp_1)
            batch_size = shape[0]
            # (batch, 1, ..., 1)
            alpha = tf.random_uniform(shape=[batch_size] + [1] *
                                      (len(shape) - 1),
                                      minval=0.,
                                      maxval=1.)

            # (batch, z_dim)
            inp_i = alpha * inp_1 + ((1 - alpha) * inp_2)

            if at_D_comp is not None:
                D_logit = disc_fn(inp_i)
                D_logit_i = D_logit[:, at_D_comp]
            else:
                D_logit_i = disc_fn(inp_i)

            assert len(
                D_logit_i.shape.as_list()) == 1, D_logit_i.shape.as_list()

            D_grad = tf.gradients(D_logit_i, [inp_i])[0]
            D_grad = flatten_right_from(D_grad, 1)

        elif mode == "self":
            print("[BaseLatentModel.gp] self!")
            assert inp_1 is not None, "If 'mode' is 'self', 'inp_1' must be provided!"

            if at_D_comp is not None:
                D_logit = disc_fn(inp_1)
                D_logit_1 = D_logit[:, at_D_comp]
            else:
                D_logit_1 = disc_fn(inp_1)

            assert len(
                D_logit_1.shape.as_list()) == 1, D_logit_1.shape.as_list()

            D_grad = tf.gradients(D_logit_1, [inp_1])[0]
            D_grad = flatten_right_from(D_grad, 1)

        elif mode == "other":
            print("[BaseLatentModel.gp] other!")
            assert inp_2 is not None, "If 'mode' is 'other', 'inp_2' must be provided!"

            if at_D_comp is not None:
                D_logit = disc_fn(inp_2)
                D_logit_2 = D_logit[:, at_D_comp]
            else:
                D_logit_2 = disc_fn(inp_2)

            assert len(
                D_logit_2.shape.as_list()) == 1, D_logit_2.shape.as_list()

            D_grad = tf.gradients(D_logit_2, [inp_2])[0]
            D_grad = flatten_right_from(D_grad, 1)

        else:
            raise ValueError("Do not support 'mode'='{}'".format(mode))

        assert len(D_grad.shape) == 2, "'D_grad' must have 2 dimensions!"

        if center == 0:
            gp = tf.reduce_mean(tf.reduce_sum(tf.square(D_grad), axis=1),
                                axis=0)
        else:
            slope = tf.sqrt(tf.reduce_sum(tf.square(D_grad), axis=1))
            gp = tf.reduce_mean((slope - center)**2, axis=0)

        return gp
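
# ----------------------------------------------------------------------------- #
# Standalone sketch (toy shapes, toy linear critic): the centered penalty that
# gp() computes when center=1, i.e. the standard WGAN-GP term
# E[(||grad D(x_i)|| - 1)^2], using the per-sample alpha of shape [batch, 1]
# that the IMPORTANT comment above insists on.
# ----------------------------------------------------------------------------- #
import tensorflow as tf

real_toy = tf.random_normal([8, 64])
fake_toy = tf.random_normal([8, 64])

alpha_gp = tf.random_uniform([8, 1], 0., 1.)        # broadcasts over features
x_i_toy = alpha_gp * real_toy + (1. - alpha_gp) * fake_toy

critic_logit = tf.squeeze(tf.layers.dense(x_i_toy, 1), axis=-1)   # toy critic, (batch,)

grad_gp = tf.gradients(critic_logit, [x_i_toy])[0]                # (batch, 64)
slope_gp = tf.sqrt(tf.reduce_sum(tf.square(grad_gp), axis=1))
gp1_toy = tf.reduce_mean((slope_gp - 1.) ** 2)                    # center = 1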