Throughout these examples, ktf appears to be the TensorFlow 1.x module exposed through the Keras backend, and K is keras.backend.

Example #1
 def additional_generator_losses(self):
     # Auxiliary-classifier term: generated samples should be classified
     # as the labels the generator was conditioned on.
     cls_real = self.ce_weight_generator * ktf.reduce_mean(
         ktf.nn.sparse_softmax_cross_entropy_with_logits(
             labels=ktf.squeeze(self.generator_input[1], axis=1),
             logits=self.discriminator_fake_output[1]))
     self.generator_metric_names.append('cls')
     return [cls_real]
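The core op here is sparse softmax cross-entropy, which takes integer class ids rather than one-hot labels. A minimal standalone sketch, assuming the TensorFlow 1.x API that ktf exposes (the tensor values are made up for illustration):

import tensorflow as tf

logits = tf.constant([[2.0, 0.5, -1.0]])  # one sample, three class scores
labels = tf.constant([0])                 # integer class ids, not one-hot
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))

with tf.Session() as sess:
    print(sess.run(loss))  # cross-entropy of class 0 under softmax(logits)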
Example #2
    def get_gradient_penalty_loss(self, for_discriminator=True):
        if self.gradient_penalty_weight == 0:
            return []

        inp = self.discriminator_input if for_discriminator else self.generator_input
        if isinstance(inp, list):
            batch_size = ktf.shape(inp[0])[0]
        else:
            batch_size = ktf.shape(inp)[0]

        points = self.grad_generator_output
        print(K.int_shape(points))

        gp_list = []
        disc_out = self.discriminator([points])
        if not isinstance(disc_out, list):
            disc_out = [disc_out]
        gradients = ktf.gradients(disc_out[0], points)

        for gradient in gradients:
            if gradient is None:
                continue
            gradient = ktf.reshape(gradient, (batch_size, -1))
            gradient_l2_norm = ktf.sqrt(ktf.reduce_sum(ktf.square(gradient), axis=1))
            if for_discriminator:
                # two-sided penalty: pull the gradient norm toward 1 (WGAN-GP)
                gradient_penalty = self.gradient_penalty_weight * ktf.square(1 - gradient_l2_norm)
            else:
                # generator variant: reward large discriminator gradients
                gradient_penalty = -self.gradient_penalty_weight_generator * gradient_l2_norm
            gp_list.append(ktf.reduce_mean(gradient_penalty))

        if for_discriminator:
            for i in range(len(gp_list)):
                self.discriminator_metric_names.append('gp_loss_' + str(i))
        return gp_list
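The loop above penalizes the per-sample L2 norm of the discriminator's gradient at the evaluation points. A minimal standalone sketch of the same computation in plain TensorFlow 1.x, with a toy quadratic standing in for the discriminator (the names and the weight 10.0 are illustrative, not from the library):

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 64, 64, 3])
disc_logits = tf.reduce_sum(tf.square(x), axis=[1, 2, 3])   # toy discriminator
grad = tf.gradients(disc_logits, [x])[0]                    # dD/dx, same shape as x
grad = tf.reshape(grad, (tf.shape(x)[0], -1))               # flatten per sample
grad_l2 = tf.sqrt(tf.reduce_sum(tf.square(grad), axis=1))   # per-sample L2 norm
gp_loss = tf.reduce_mean(10.0 * tf.square(1.0 - grad_l2))   # two-sided penalty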
Example #3
 def additional_discriminator_losses(self):
     losses = []
     cls_real = self.ce_weight_discriminator * ktf.reduce_mean(
         ktf.nn.sparse_softmax_cross_entropy_with_logits(
             labels=ktf.squeeze(
                 self.additional_inputs_for_discriminator_train[0], axis=1),
             logits=self.discriminator_real_output[1]))
     self.discriminator_metric_names.append('cls_real')
     losses.append(cls_real)
     if self.classify_generated:
         cls_fake = self.ce_weight_discriminator * ktf.reduce_mean(
             ktf.nn.sparse_softmax_cross_entropy_with_logits(
                 labels=ktf.squeeze(self.generator_input[1], axis=1),
                 logits=self.discriminator_fake_output[1]))
         losses.append(cls_fake)
         self.discriminator_metric_names.append('cls_fake')
     return losses
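Example #3 is the discriminator-side counterpart of Example #1: the classification head is always trained on real samples against their ground-truth labels, and, when classify_generated is set, also on generated samples against the generator's conditioning labels, which is the usual AC-GAN setup.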
Example #4
    def get_gradient_penalty_loss(self):
        if self.gradient_penalty_weight == 0:
            return []

        if isinstance(self.discriminator_input, list):
            batch_size = ktf.shape(self.discriminator_input[0])[0]
            ranks = [len(inp.get_shape().as_list()) for inp in self.discriminator_input]
        else:
            batch_size = ktf.shape(self.discriminator_input)[0]
            ranks = [len(self.discriminator_input.get_shape().as_list())]

        def cast_all(values, reference_type_vals):
            return [ktf.cast(alpha, dtype=ref.dtype) for alpha, ref in zip(values, reference_type_vals)]

        def std_if_not_int(val):
            if val.dtype.is_integer:
                return 0
            else:
                return ktf.stop_gradient(K.std(val, keepdims=True))

        def points_for_wgan_gp():
            weights = ktf.random_uniform((batch_size, 1), minval=0, maxval=1)
            weights = [ktf.reshape(weights, (-1, ) + (1, ) * (rank - 1)) for rank in ranks]
            weights = cast_all(weights, self.discriminator_input)
            points = [(w * r) + ((1 - w) * f) for r, f, w in zip(self.discriminator_input, self.generator_output, weights)]
            return points

        def points_for_dragan():
            alphas = ktf.random_uniform((batch_size, 1), minval=0, maxval=1)
            alphas = [ktf.reshape(alphas, (-1, ) + (1, ) * (rank - 1)) for rank in ranks]
            alphas = cast_all(alphas, self.discriminator_input)
            fake = [ktf.random_uniform(ktf.shape(t), minval=0, maxval=1) * std_if_not_int(t) * 0.5
                       for t in self.discriminator_input]
            fake = cast_all(fake, self.discriminator_input)

            points = [(w * r) + ((1 - w) * f) for r, f, w in zip(self.discriminator_input, fake, alphas)]
            return points

        # call only the selected branch, so the unused graph ops are never built
        points_fn = {'wgan-gp': points_for_wgan_gp, 'dragan': points_for_dragan}
        points = points_fn[self.gradient_penalty_type]()

        gp_list = []
        disc_out = self.discriminator(points)
        if not isinstance(disc_out, list):
            disc_out = [disc_out]
        gradients = ktf.gradients(disc_out[0], points)

        for gradient in gradients:
            if gradient is None:
                continue
            gradient = ktf.reshape(gradient, (batch_size, -1))
            gradient_l2_norm = ktf.sqrt(ktf.reduce_sum(ktf.square(gradient), axis=1))
            gradient_penalty = self.gradient_penalty_weight * ktf.square(1 - gradient_l2_norm)
            gp_list.append(ktf.reduce_mean(gradient_penalty))

        for i in range(len(gp_list)):
            self.discriminator_metric_names.append('gp_loss_' + str(i))
        return gp_list
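The two point constructions differ only in what the real batch is mixed with: WGAN-GP interpolates between real and generated samples, while DRAGAN perturbs the real batch with scaled uniform noise. A minimal sketch of the WGAN-GP case for a single 4-D input tensor, assuming TensorFlow 1.x (the tensor names are hypothetical):

import tensorflow as tf

real = tf.placeholder(tf.float32, [None, 64, 64, 3])
fake = tf.placeholder(tf.float32, [None, 64, 64, 3])

# one mixing coefficient per sample, broadcast over the remaining axes
alpha = tf.random_uniform([tf.shape(real)[0], 1, 1, 1], minval=0., maxval=1.)
points = alpha * real + (1. - alpha) * fake  # where the penalty is evaluated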
Example #5
    def call(self, inputs, training=None):
        input_shape = K.int_shape(inputs)
        _, w, h, c = input_shape
        bs = K.shape(inputs)[0]

        # move channels first, then flatten per channel:
        # (bs, w, h, c) -> (c, bs, w, h) -> (c, bs*w*h)
        x_t = ktf.transpose(inputs, (3, 0, 1, 2))
        x_flat = ktf.reshape(x_t, (c, -1))

        # Covariance
        m = ktf.reduce_mean(x_flat, axis=1, keepdims=True)
        m = K.in_train_phase(m, self.moving_mean)
        f = x_flat - m

        if self.decomposition == 'cholesky':

            def get_inv_sqrt(ff):
                sqrt = ktf.cholesky(ff)
                inv_sqrt = ktf.matrix_triangular_solve(sqrt, ktf.eye(c))
                return sqrt, inv_sqrt
        elif self.decomposition in ['zca', 'zca-cor']:

            def get_inv_sqrt(ff):
                with ktf.device('/cpu:0'):
                    S, U, _ = ktf.svd(ff + ktf.eye(c) * self.epsilon,
                                      full_matrices=True)
                D = ktf.diag(ktf.pow(S, -0.5))
                inv_sqrt = ktf.matmul(ktf.matmul(U, D), U, transpose_b=True)
                D = ktf.diag(ktf.pow(S, 0.5))
                sqrt = ktf.matmul(ktf.matmul(U, D), U, transpose_b=True)
                return sqrt, inv_sqrt
        elif self.decomposition in ['pca', 'pca-cor']:

            def get_inv_sqrt(ff):
                with ktf.device('/cpu:0'):
                    S, U, _ = ktf.svd(ff + ktf.eye(c) * self.epsilon,
                                      full_matrices=True)
                D = ktf.diag(ktf.pow(S, -0.5))
                inv_sqrt = ktf.matmul(D, U, transpose_b=True)
                D = ktf.diag(ktf.pow(S, 0.5))
                sqrt = ktf.matmul(D, U, transpose_b=True)
                return sqrt, inv_sqrt
        else:
            raise ValueError('Unknown decomposition: ' + str(self.decomposition))

        def train():
            ff_apr = ktf.matmul(f, f, transpose_b=True) / (
                ktf.cast(bs * w * h, ktf.float32) - 1.)
            if self.decomposition in ['pca-cor', 'zca-cor']:
                dinv = ktf.diag(ktf.sqrt(ktf.diag_part(ff_apr)))
                ff_apr = ktf.matmul(ktf.matmul(dinv, ff_apr),
                                    ktf.matrix_inverse(dinv),
                                    transpose_b=True)
            self.add_update([
                K.moving_average_update(self.moving_mean, m, self.momentum),
                K.moving_average_update(self.moving_cov, ff_apr, self.momentum)
            ], inputs)
            ff_apr_shrinked = (
                1 - self.epsilon) * ff_apr + ktf.eye(c) * self.epsilon

            if self.renorm:
                l, l_inv = get_inv_sqrt(ff_apr_shrinked)
                ff_mov = (1 - self.epsilon
                          ) * self.moving_cov + ktf.eye(c) * self.epsilon
                _, l_mov_inverse = get_inv_sqrt(ff_mov)
                l_ndiff = K.stop_gradient(l)
                return ktf.matmul(ktf.matmul(l_mov_inverse, l_ndiff), l_inv)

            return get_inv_sqrt(ff_apr_shrinked)[1]

        def test():
            ff_mov = (
                1 - self.epsilon) * self.moving_cov + ktf.eye(c) * self.epsilon
            return get_inv_sqrt(ff_mov)[1]

        inv_sqrt = K.in_train_phase(train, test)
        f_hat = ktf.matmul(inv_sqrt, f)

        decorrelated = K.reshape(f_hat, [c, bs, w, h])
        decorrelated = ktf.transpose(decorrelated, [1, 2, 3, 0])

        broadcast_shape = [1] * len(input_shape)
        if self.axis is not None:
            broadcast_shape[self.axis] = input_shape[self.axis]
        if self.scale:
            broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
            decorrelated = decorrelated * broadcast_gamma
        if self.center:
            broadcast_beta = K.reshape(self.beta, broadcast_shape)
            decorrelated = decorrelated + broadcast_beta

        return decorrelated
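The effect of the ZCA branch can be checked outside the layer: whitening with U diag(S^-0.5) U^T maps the centered features to (near-)identity covariance. A small NumPy sketch of that check, using eigh, which for the symmetric covariance matrix yields the same U and S as the svd call above (all names here are illustrative):

import numpy as np

rng = np.random.RandomState(0)
x = rng.randn(4, 10000)                      # 4 channels, 10000 samples
x[1] += 0.8 * x[0]                           # introduce cross-channel correlation
f = x - x.mean(axis=1, keepdims=True)        # center, like f = x_flat - m

cov = f @ f.T / (f.shape[1] - 1)
S, U = np.linalg.eigh(cov)
inv_sqrt = U @ np.diag(S ** -0.5) @ U.T      # ZCA whitening matrix
f_hat = inv_sqrt @ f

print(np.round(f_hat @ f_hat.T / (f.shape[1] - 1), 3))  # ~ identity matrix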
Example #6
 def ls_loss_true(logits):
     # LSGAN loss on real samples: mean((D(x) - 1)^2)
     return ktf.reduce_mean((logits - 1) ** 2)
Example #7
 def ns_loss_fake(logits):
     # non-saturating GAN discriminator loss on fakes:
     # sigmoid cross-entropy against all-zero labels
     labels = ktf.zeros_like(logits)
     return ktf.reduce_mean(ktf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits))
Example #8
 def hinge(logits):
     # hinge-GAN generator loss: -mean(D(G(z)))
     return -ktf.reduce_mean(logits)
Example #9
 def wgan(logits):
     # WGAN generator loss: -mean(D(G(z)))
     return -ktf.reduce_mean(logits)
Example #10
 def hinge_loss_fake(logits):
     # hinge discriminator loss on fakes: mean(max(0, 1 + D(G(z))))
     return ktf.reduce_mean(ktf.maximum(0.0, 1.0 + logits))
Example #11
 def hinge_loss_true(logits):
     # hinge discriminator loss on reals: mean(max(0, 1 - D(x)))
     return ktf.reduce_mean(ktf.maximum(0.0, 1.0 - logits))
Example #12
 def wgan_loss_fake(logits):
     # WGAN discriminator loss on fakes: +mean(D(G(z)))
     return ktf.reduce_mean(logits)
Example #13
 def wgan_loss_true(logits):
     # WGAN discriminator loss on reals: -mean(D(x))
     return -ktf.reduce_mean(logits)
Example #14
 def ls_loss_fake(logits):
     # LSGAN discriminator loss on fakes: mean(D(G(z))^2)
     return ktf.reduce_mean(logits ** 2)
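For reference, these per-logit helpers pair up into complete objectives. A sketch of the pairing, assuming d_real and d_fake are the discriminator's logit tensors on real and generated batches (hypothetical names; for WGAN-GP, the gradient penalty from Examples #2/#4 would be added to the discriminator loss):

# hinge GAN
d_loss = hinge_loss_true(d_real) + hinge_loss_fake(d_fake)
g_loss = hinge(d_fake)                  # generator pushes D(G(z)) upward

# WGAN
d_loss = wgan_loss_true(d_real) + wgan_loss_fake(d_fake)
g_loss = wgan(d_fake)

# least-squares GAN
d_loss = ls_loss_true(d_real) + ls_loss_fake(d_fake)
g_loss = ls_loss_true(d_fake)           # generator wants D(G(z)) near 1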