# --- Example #1 ---
    def create(self):
        """Build a CycleGAN-style graph over two image domains A and B.

        Creates per-domain encoders, cross-domain generators (g_ab: A->B,
        g_ba: B->A), per-domain discriminators and losses, a cycle-consistency
        loss, and a MultiStepTrainer that alternates generator/discriminator
        steps for each direction. Finishes by initializing all TF variables.
        """
        config = self.config
        ops = self.ops

        self.session = self.ops.new_session(self.ops_config)

        # Both encoders share one config dict; ops.describe gives each its
        # own variable scope so they do not share weights.
        encoder_config = dict(config.input_encoder)
        encode_a = self.create_component(encoder_config)
        encode_a.ops.describe("encode_a")
        encode_b = self.create_component(encoder_config)
        encode_b.ops.describe("encode_b")

        g_ab = self.create_component(config.generator)
        g_ab.ops.describe("g_ab")
        g_ba = self.create_component(config.generator)
        g_ba.ops.describe("g_ba")

        encode_a.create(self.inputs.xa)
        encode_b.create(self.inputs.xb)

        # g_ab translates encoded A samples into domain B, and vice versa.
        g_ab.create(encode_a.sample)
        g_ba.create(encode_b.sample)

        self.xba = g_ba.sample
        self.xab = g_ab.sample

        # Each discriminator judges real samples of its domain against the
        # generator output mapped into that domain.
        discriminator_a = self.create_component(config.discriminator)
        discriminator_b = self.create_component(config.discriminator)
        discriminator_a.ops.describe("discriminator_a")
        discriminator_b.ops.describe("discriminator_b")
        discriminator_a.create(x=self.inputs.xa, g=g_ba.sample)
        discriminator_b.create(x=self.inputs.xb, g=g_ab.sample)

        # Re-encode the translated samples (reusing encoder weights), then
        # map back to the source domain to close the cycle.
        encode_g_ab = encode_b.reuse(g_ab.sample)
        encode_g_ba = encode_a.reuse(g_ba.sample)

        cyca = g_ba.reuse(encode_g_ab)
        cycb = g_ab.reuse(encode_g_ba)

        lossa = self.create_component(config.loss,
                                      discriminator=discriminator_a,
                                      generator=g_ba)
        lossb = self.create_component(config.loss,
                                      discriminator=discriminator_b,
                                      generator=g_ab)

        lossa.create()
        lossb.create()

        # L1 cycle-consistency: xa -> B -> A should reconstruct xa (and the
        # symmetric path for xb).
        cycloss = tf.reduce_mean(tf.abs(self.inputs.xa-cyca)) + \
                       tf.reduce_mean(tf.abs(self.inputs.xb-cycb))

        # loss terms

        cycloss_lambda = config.cycloss_lambda
        if cycloss_lambda is None:
            cycloss_lambda = 10  # conventional CycleGAN default weight
        cycloss *= cycloss_lambda
        loss1 = ('generator', cycloss + lossb.g_loss)
        loss2 = ('discriminator', lossb.d_loss)
        loss3 = ('generator', cycloss + lossa.g_loss)
        loss4 = ('discriminator', lossa.d_loss)

        # Variable lists are positionally paired with [loss1..loss4]:
        # g_ab (fed by encode_a) is trained against lossb, g_ba against lossa.
        var_lists = []
        var_lists.append(encode_a.variables() + g_ab.variables())
        var_lists.append(discriminator_b.variables())
        var_lists.append(encode_b.variables() + g_ba.variables())
        var_lists.append(discriminator_a.variables())

        # FIX: metrics must align with the loss steps above. Step 0 uses
        # lossb (g_ab direction) and step 2 uses lossa; the previous order
        # (lossa first) reported each direction's metrics on the wrong step.
        metrics = []
        metrics.append(lossb.metrics)
        metrics.append(None)
        metrics.append(lossa.metrics)
        metrics.append(None)

        self.trainer = MultiStepTrainer(self,
                                        self.config.trainer,
                                        [loss1, loss2, loss3, loss4],
                                        var_lists=var_lists,
                                        metrics=metrics)
        self.trainer.create()

        self.cyca = cyca
        self.cycb = cycb
        self.cycloss = cycloss
        self.encoder = encode_a
        self.generator = g_ab

        self.session.run(tf.global_variables_initializer())
# --- Example #2 ---
    def create(self):
        """Build a multi-generator GAN graph.

        One shared generator is applied to `config.number_generators`
        projections of the encoder output; a single discriminator (reused
        across generators) scores each sample. Depending on
        `config.class_loss_type`, an auxiliary classification loss ('svm')
        or a secondary GAN over discriminator features ('gan') encourages
        the generators to diversify. Finishes by building a MultiStepTrainer
        and initializing all TF variables.
        """
        BaseGAN.create(self)
        if self.session is None:
            self.session = self.ops.new_session(self.ops_config)
        with tf.device(self.device):
            config = self.config
            print(config)
            print("________")
            ops = self.ops

            self.encoder = self.create_component(config.encoder)
            self.encoder.create()
            generator_samples = []
            # The linear projection into the generator is done manually
            # below, so the generator skips its own initial linear layer.
            config.generator.skip_linear = True

            print("!!!!!!!!!!!!!!!!!!!!!!! Creataing generator",
                  config.generator)
            generator = self.create_component(config.generator)
            generator.ops.describe("generator")
            self.generator = generator
            for i in range(config.number_generators):
                primes = config.generator.initial_dimensions or [4, 4]
                initial_depth = generator.depths(primes[0])[0]
                # Flatten the encoder sample, project it to the generator's
                # initial spatial volume, then reshape to NHWC.
                net = ops.reshape(self.encoder.sample,
                                  [ops.shape(self.encoder.sample)[0], -1])
                new_shape = [
                    ops.shape(net)[0], primes[0], primes[1], initial_depth
                ]
                net = ops.linear(net, initial_depth * primes[0] * primes[1])
                pi = ops.reshape(net, new_shape)

                print("[MultiGeneratorGAN] Creating generator ", i, pi)
                # Generator weights are created once and reused for every
                # subsequent projection.
                if i == 0:
                    gi = generator.create(pi)
                else:
                    gi = generator.reuse(pi)
                generator_samples.append(gi)

            self.discriminator = self.create_component(config.discriminator)
            self.discriminator.ops.describe("discriminator")

            # FIX: removed dead `self.loss = self`, which was immediately
            # overwritten by the real loss component below.
            self.loss = self.create_component(config.loss)

            g_loss = tf.constant(0.0)
            d_loss = tf.constant(0.0)
            metrics = []
            d_fake_features = []

            for i in range(config.number_generators):
                # One discriminator, reused per generator sample.
                if i == 0:
                    di = self.discriminator.create(x=self.inputs.x,
                                                   g=generator_samples[i])
                else:
                    di = self.discriminator.reuse(x=self.inputs.x,
                                                  g=generator_samples[i])
                d_real, d_fake = self.split_batch(di, 2)
                # after the divergence measure or before ? TODO
                d_fake_features.append(d_fake)

                # loss.create returns (d_loss, g_loss); accumulate both.
                loss = self.loss.create(d_real=d_real, d_fake=d_fake)
                g_loss += loss[1]
                d_loss += loss[0]

            var_lists = []
            steps = []

            if config.class_loss_type == 'svm':
                # classifier loss: push each generator's fake features toward
                # its own one-hot generator label.
                for i in range(config.number_generators):
                    features = tf.reshape(d_fake_features[i],
                                          [self.batch_size(), -1])
                    c_loss = ops.lookup('crelu')(features)
                    print("C LOSS 1", c_loss)
                    c_loss = ops.linear(c_loss, config.number_generators)
                    label = tf.one_hot([i], config.number_generators)
                    label = tf.tile(label, [self.batch_size(), 1])
                    c_loss = tf.nn.sigmoid_cross_entropy_with_logits(
                        logits=c_loss, labels=label)

                    g_loss += c_loss * (config.c_loss_lambda or 1)
                # NOTE(review): only the *last* generator's c_loss is
                # reported as the metric here — confirm this is intended.
                metrics.append({"class loss": self.ops.squash(c_loss)})
                metrics.append(self.loss.metrics)

                var_lists.append(self.generator.variables() + self.variables())
                var_lists.append(self.discriminator.variables())
                steps = [('generator', g_loss), ('discriminator', d_loss)]

            if config.class_loss_type == 'gan':
                d2 = None
                g2 = None
                l2 = None
                d2_loss_sum = tf.constant(0.)
                g2_loss_sum = tf.constant(0.)

                var_lists.append(self.generator.variables() + self.variables())
                var_lists.append(self.discriminator.variables())
                # classifier as gan loss: a secondary GAN (G2/D2) over the
                # concatenated generator samples, conditioned per generator.
                for i in range(config.number_generators):
                    if i != 0:
                        self.ops.reuse()
                    features = tf.reshape(d_fake_features[i],
                                          [self.batch_size(), -1])
                    label = tf.one_hot([i], config.number_generators)
                    label = tf.tile(label, [self.batch_size(), 1])

                    # D2(G2(gx), label)
                    g2config = dict(config.generator2)
                    g2 = self.create_component(g2config)
                    # this is the generator
                    g2.ops.describe("G2")

                    if i == 0:
                        g2sample = g2.create(
                            tf.concat(generator_samples, axis=3))
                    else:
                        g2sample = g2.reuse(
                            tf.concat(generator_samples, axis=3))

                    d2config = dict(config.discriminator2)
                    d2 = self.create_component(d2config)

                    d2.ops.describe("D2")
                    if i == 0:
                        d2.create(x=label, g=g2sample)
                        var_lists.append(g2.variables())
                        var_lists.append(d2.variables())
                    else:
                        d2.reuse(x=label, g=g2sample)

                    l2config = dict(config.loss)
                    l2 = self.create_component(l2config,
                                               discriminator=d2,
                                               generator=g2)
                    d2_loss, g2_loss = l2.create()

                    g2_loss_sum += g2_loss
                    d2_loss_sum += d2_loss

                    if i != 0:
                        self.ops.stop_reuse()

                # FIX: the discriminator-2 step previously used only the last
                # iteration's d2_loss, silently discarding the accumulated
                # d2_loss_sum; train on the sum, matching g2_loss_sum.
                steps = [('generator 1', g_loss + g2_loss_sum),
                         ('discriminator 1', d_loss),
                         ('generator 2', g2_loss_sum + g_loss),
                         ('discriminator 2', d2_loss_sum)]

                metrics.append(None)
                metrics.append(self.loss.metrics)
                metrics.append(None)
                metrics.append(l2.metrics)

            print("T", self.config.trainer, steps, metrics)
            self.trainer = MultiStepTrainer(self,
                                            self.config.trainer,
                                            steps,
                                            var_lists=var_lists,
                                            metrics=metrics)
            self.trainer.create()

            self.session.run(tf.global_variables_initializer())
            self.uniform_sample = tf.concat(generator_samples, axis=1)
# --- Example #3 ---
    def create_trainer(self, cycloss, z_cycloss, encoder, generator, encoder_loss, standard_loss, standard_discriminator, encoder_discriminator):
        """Construct the trainer for the encoder/generator GAN pair.

        Chooses between ConsensusTrainer, a two-trainer MultiTrainerTrainer,
        or the default MultiStepTrainer, based on the configured trainer
        class. Returns the constructed (but not yet created) trainer.
        """
        # The generator steps share the cycle-consistency term; fold in the
        # latent cycle loss when it was provided.
        if z_cycloss is not None:
            cycle_term = z_cycloss + cycloss
        else:
            cycle_term = cycloss

        loss1 = ('generator encoder', cycle_term + encoder_loss.g_loss)
        loss2 = ('generator image', cycle_term + standard_loss.g_loss)
        loss3 = ('discriminator image', standard_loss.d_loss)
        loss4 = ('discriminator encoder', encoder_loss.d_loss)

        # Positionally paired with [loss1..loss4] for MultiStepTrainer.
        var_lists = [
            encoder.variables(),
            generator.variables(),
            standard_discriminator.variables(),
            encoder_discriminator.variables(),
        ]
        metrics = [
            None,
            None,
            standard_loss.metrics,
            encoder_loss.metrics,
        ]

        def _reporting_metrics():
            # Shared metrics dict for the hc.Config-based trainers.
            return {'g_loss': loss2[1], 'd_loss': loss3[1],
                    'e_g_loss': loss1[1], 'e_d_loss': loss4[1]}

        trainer_class = self.config.trainer['class']

        if trainer_class == ConsensusTrainer:
            # Single consensus trainer over the combined losses.
            d_vars = standard_discriminator.variables() + encoder_discriminator.variables()
            g_vars = generator.variables() + encoder.variables()
            d_loss = standard_loss.d_loss + encoder_loss.d_loss
            g_loss = encoder_loss.g_loss + standard_loss.g_loss + cycloss
            loss = hc.Config({'sample': [d_loss, g_loss],
                              'metrics': _reporting_metrics()})
            return ConsensusTrainer(self, self.config.trainer, loss=loss,
                                    g_vars=g_vars, d_vars=d_vars)

        if trainer_class == MultiTrainerTrainer:
            # Trainer 1: the image GAN (generator vs standard discriminator).
            g_loss = standard_loss.g_loss
            if self.config.cycloss_on_g:
                g_loss += cycloss * self.config.cycloss_on_g
            loss = hc.Config({'sample': [standard_loss.d_loss, g_loss],
                              'metrics': _reporting_metrics()})
            trainer1 = ConsensusTrainer(self, self.config.trainer, loss=loss,
                                        g_vars=generator.variables(),
                                        d_vars=standard_discriminator.variables())

            # Trainer 2: the encoder GAN (encoder vs encoder discriminator),
            # with the cycle loss on the encoder's generator objective.
            loss = hc.Config({'sample': [encoder_loss.d_loss,
                                         encoder_loss.g_loss + cycloss],
                              'metrics': _reporting_metrics()})
            trainer2 = ConsensusTrainer(self, self.config.trainer, loss=loss,
                                        g_vars=encoder.variables(),
                                        d_vars=encoder_discriminator.variables())

            return MultiTrainerTrainer([trainer1, trainer2])

        # Default: one alternating step per loss term.
        return MultiStepTrainer(self, self.config.trainer,
                                [loss1, loss2, loss3, loss4],
                                var_lists=var_lists, metrics=metrics)
# --- Example #4 ---
    def create(self):
        """Build an ALI/autoencoder-style GAN graph.

        An encoder maps x to a latent sample; a uniform encoder supplies a
        prior z of matching size. The generator produces both a prior sample
        g = G(z) and a reconstruction x_hat = G(z_hat), where z_hat is the
        projected encoder output. One discriminator matches z against z_hat
        in latent space; another scores [x, x_hat, g] in image space. An L1
        cycle loss ties x to x_hat. Trains with a four-step MultiStepTrainer
        and initializes all TF variables.
        """
        BaseGAN.create(self)
        if self.session is None:
            self.session = self.ops.new_session(self.ops_config)
        with tf.device(self.device):
            config = self.config
            ops = self.ops

            # Encoder defaults to the discriminator config when no dedicated
            # g_encoder config is present.
            g_encoder = dict(config.g_encoder or config.discriminator)
            encoder = self.create_component(g_encoder)
            encoder.ops.describe("g_encoder")
            encoder.create(self.inputs.x)
            # NOTE(review): encoder.z is set to an empty tensor here —
            # presumably a placeholder some consumer expects; confirm.
            encoder.z = tf.zeros(0)
            # Promote a rank-2 latent sample to NHWC so later reshapes and
            # the axis-3 concat of projections work uniformly.
            if (len(encoder.sample.get_shape()) == 2):
                s = ops.shape(encoder.sample)
                encoder.sample = tf.reshape(encoder.sample, [s[0], s[1], 1, 1])

            z_discriminator = dict(config.z_discriminator
                                   or config.discriminator)
            z_discriminator['layer_filter'] = None

            encoder_discriminator = self.create_component(z_discriminator)
            encoder_discriminator.ops.describe("z_discriminator")
            standard_discriminator = self.create_component(
                config.discriminator)
            standard_discriminator.ops.describe("discriminator")

            # The uniform prior must match the flattened encoder output size.
            uniform_encoder_config = config.encoder
            z_size = 1
            for size in ops.shape(encoder.sample)[1:]:
                z_size *= size
            uniform_encoder_config.z = z_size
            uniform_encoder = UniformEncoder(self, uniform_encoder_config)
            uniform_encoder.create()

            self.generator = self.create_component(config.generator)

            z = uniform_encoder.sample
            x = self.inputs.x

            # project the output of the autoencoder through each configured
            # projection, then concat along channels to form z_hat.
            projection_input = ops.reshape(encoder.sample,
                                           [ops.shape(encoder.sample)[0], -1])
            projections = []
            for projection in uniform_encoder.config.projections:
                projection = uniform_encoder.lookup(projection)(
                    uniform_encoder.config, self.gan, projection_input)
                projection = ops.reshape(projection, ops.shape(encoder.sample))
                projections.append(projection)
            z_hat = tf.concat(axis=3, values=projections)

            z = ops.reshape(z, ops.shape(z_hat))
            # end encoding

            g = self.generator.create(z)
            sample = self.generator.sample
            self.uniform_sample = self.generator.sample
            # Reconstruction path reuses the generator weights.
            x_hat = self.generator.reuse(z_hat)

            # Latent-space discriminator: prior z vs encoded z_hat.
            encoder_discriminator.create(x=z, g=z_hat)

            eloss = dict(config.loss)
            eloss['gradient_penalty'] = False
            encoder_loss = self.create_component(
                eloss, discriminator=encoder_discriminator)
            encoder_loss.create()

            # Image-space discriminator sees real, reconstruction, and prior
            # sample stacked along the batch; the loss splits them back out.
            stacked_xg = ops.concat([x, x_hat, g], axis=0)
            standard_discriminator.create(stacked_xg)

            standard_loss = self.create_component(
                config.loss, discriminator=standard_discriminator)
            standard_loss.create(split=3)

            # FIX: removed `self.trainer = self.create_component(config.trainer)`
            # here — its result was discarded and unconditionally overwritten
            # by the MultiStepTrainer constructed below.

            #loss terms
            distance = config.distance or ops.lookup('l1_distance')
            cycloss = tf.reduce_mean(distance(self.inputs.x, x_hat))
            cycloss_lambda = config.cycloss_lambda
            if cycloss_lambda is None:
                cycloss_lambda = 10  # conventional cycle-loss default weight
            cycloss *= cycloss_lambda
            loss1 = ('generator', cycloss + encoder_loss.g_loss)
            loss2 = ('generator', cycloss + standard_loss.g_loss)
            loss3 = ('discriminator', standard_loss.d_loss)
            loss4 = ('discriminator', encoder_loss.d_loss)

            # Positionally paired with [loss1..loss4].
            var_lists = []
            var_lists.append(encoder.variables())
            var_lists.append(self.generator.variables())
            var_lists.append(standard_discriminator.variables())
            var_lists.append(encoder_discriminator.variables())

            metrics = []
            metrics.append(encoder_loss.metrics)
            metrics.append(standard_loss.metrics)
            metrics.append(None)
            metrics.append(None)

            # trainer

            self.trainer = MultiStepTrainer(self,
                                            self.config.trainer,
                                            [loss1, loss2, loss3, loss4],
                                            var_lists=var_lists,
                                            metrics=metrics)
            self.trainer.create()

            self.session.run(tf.global_variables_initializer())

            self.encoder = encoder
            self.uniform_encoder = uniform_encoder