Exemplo n.º 1
0
    def build(self, net):
        """Run a fresh DCGAN discriminator on ``net`` plus two copies of a generator sample.

        A new ``DCGANDiscriminator`` is created, a generator sample ``g2`` is
        drawn from a ``UniformDistribution`` configured by
        ``gan.config.encoder``, and the discriminator is applied to
        ``concat([net, g2, g2], axis=0)``. The discriminator output is split
        into four batch chunks which are recombined and passed to ``self.f``.

        Args:
            net: discriminator input. The original note says it is ``[x, g]``
                stacked along axis 0 -- TODO confirm against callers.

        Returns:
            Whatever ``self.f(xinput, dx, dg)`` returns.
        """
        config = self.config
        gan = self.gan
        ops = self.ops

        discriminator = DCGANDiscriminator(gan, config)
        discriminator.ops = ops
        encoder = UniformDistribution(gan, gan.config.encoder)

        # careful, this order matters (original author's note; presumably
        # graph/variable creation order -- do not reorder these statements)
        g2 = gan.generator.reuse(encoder.create())
        # Batch layout: [net, g2, g2]; presumably [x, g, g2, g2] if net is [x, g].
        double = tf.concat([net] + [g2, g2], axis=0)
        original = discriminator.build(double)
        # Four chunks along the batch; assumes split_batch splits evenly -- confirm.
        d1 = self.split_batch(original, 4)

        # dg: discriminator output on the two copies of g2 (chunks 2 and 3).
        dg = ops.concat([d1[2], d1[3]], axis=0)

        # dx: chunk 0 (presumably D(x)) repeated twice.
        dx = ops.concat([d1[0], d1[0]], axis=0)

        # xinput: chunks 0 and 1, i.e. the discriminator output on `net`.
        xinput = ops.concat([d1[0], d1[1]], axis=0)

        error = self.f(xinput, dx, dg)
        return error
Exemplo n.º 2
0
 def test_projection_gaussian(self):
     """Check the sample width equals ``z`` times the number of projections."""
     config = dict(projections=['identity', 'gaussian'], z=2, min=0, max=1)
     subject = UniformDistribution(gan, config)
     projections = subject.create()
     actual_width = int(projections.get_shape()[1])
     expected_width = config['z'] * len(config['projections'])
     assert actual_width == expected_width
Exemplo n.º 3
0
 def test_projection(self):
     """Check that a single identity projection gives a sample of width 2."""
     identity = hg.distributions.uniform_distribution.identity
     config = {"projections": [identity], "z": 2, "min": 0, "max": 1}
     subject = UniformDistribution(gan, config)
     projections = subject.create()
     width = subject.ops.shape(projections)[1]
     assert width == 2
 def test_projection_gaussian(self):
     """Inside a test session, the sample width is ``len(projections) * z``."""
     config = dict(projections=['identity', 'gaussian'], z=2, min=0, max=1)
     subject = UniformDistribution(gan, config)
     expected_width = len(config['projections']) * config['z']
     with self.test_session():
         sample = subject.create()
         width = int(sample.get_shape()[1])
         self.assertEqual(width, expected_width)
Exemplo n.º 5
0
 def test_projection(self):
     """Inside a test session, one identity projection gives a width of 2."""
     identity_projection = hg.distributions.uniform_distribution.identity
     config = dict(projections=[identity_projection], z=2, min=0, max=1)
     subject = UniformDistribution(gan, config)
     with self.test_session():
         projected = subject.create()
         self.assertEqual(subject.ops.shape(projected)[1], 2)
Exemplo n.º 6
0
 def test_projection_twice(self):
     """Two identity projections give a sample width of ``2 * z``."""
     config = dict(projections=['identity'] * 2, z=2, min=0, max=1)
     subject = UniformDistribution(gan, config)
     expected = config['z'] * len(config['projections'])
     with self.test_session():
         out = subject.create()
         self.assertEqual(int(out.get_shape()[1]), expected)
Exemplo n.º 7
0
def encoder(gan):
    """Return a UniformDistribution for ``gan``.

    It is configured with two ``'identity'`` projections, ``z=2``, ``min=0``
    and ``max=1``.
    """
    return UniformDistribution(
        gan,
        dict(projections=['identity', 'identity'], z=2, min=0, max=1),
    )
    def test_config(self):
        """GridSampler's 'generator' output has a last dimension of 1."""
        latent_config = {'z': 2, 'min': -1, 'max': 1, 'projections': ['identity']}
        with self.test_session():
            gan = mock_gan(batch_size=32)
            gan.latent = UniformDistribution(gan, latent_config)
            gan.latent.create()
            gan.create()

            samples = GridSampler(gan)._sample()
            self.assertEqual(samples['generator'].shape[-1], 1)
Exemplo n.º 9
0
            def ec(zt, cp, reuse=True):
                """Build the 'ec' component from ``zt`` conditioned on ``cp``.

                The component receives features ``'ct-1'`` and ``'n'``. ``'n'``
                depends on ``config``:

                * ``proxy``: the sample of a ``proxy_c`` component fed from
                  a fresh ``UniformDistribution``;
                * ``proxyrand``: ``random_like(zt)``;
                * ``proxyrand2``: ``random_like(zt)``, with ``'ct-1'``
                  perturbed by ``0.01 * random_like(cp)``;
                * otherwise: ``tf.zeros_like(zt)``.

                Args:
                    zt: input of the 'ec' component.
                    cp: value passed as the ``'ct-1'`` feature.
                    reuse: passed through to ``create_component``. When False,
                        the created components' variables are appended to
                        ``self._g_vars`` and the component is stored as
                        ``self.encoder``.

                Returns:
                    The ``sample`` of the created 'ec' component.

                Raises:
                    ValueError: if ``config.noise`` is falsy. No component is
                        built on that path; previously it crashed with
                        UnboundLocalError on ``c``.
                """
                if not config.noise:
                    raise ValueError(
                        "ec() requires config.noise; no 'ec' component is "
                        "built without it")

                # Created unconditionally (even on branches that do not use
                # it) to keep the original op-creation order in the graph.
                randt = random_like(cp)
                proxy_c = None
                if config.proxy:
                    dist3 = UniformDistribution(self,
                                                config.z_distribution)
                    proxy_c = self.create_component(config.proxy_c,
                                                    name='rand_ct',
                                                    input=dist3.sample,
                                                    reuse=reuse)
                    ec_features = {'ct-1': cp, 'n': proxy_c.sample}
                elif config.proxyrand:
                    ec_features = {'ct-1': cp, 'n': random_like(zt)}
                elif config.proxyrand2:
                    ec_features = {
                        'ct-1': (cp + randt * 0.01),
                        'n': random_like(zt)
                    }
                else:
                    ec_features = {'ct-1': cp, 'n': tf.zeros_like(zt)}

                c = self.create_component(config.ec,
                                          name='ec',
                                          input=zt,
                                          features=ec_features,
                                          reuse=reuse)

                if not reuse:
                    # proxy_c is only built when config.proxy is set.
                    if proxy_c is not None:
                        self._g_vars += proxy_c.variables()
                    self._g_vars += c.variables()
                    self.encoder = c
                return c.sample
Exemplo n.º 10
0
 def random_like(x):
     """Draw a uniform sample whose ``output_shape`` is the shape of ``x``.

     ``self`` and ``config`` are free variables from the enclosing scope;
     ``config.z_distribution`` configures the distribution.
     """
     target_shape = self.ops.shape(x)
     distribution = UniformDistribution(self,
                                        config.z_distribution,
                                        output_shape=target_shape)
     return distribution.sample
Exemplo n.º 11
0
    def create(self):
        """Build the training graph for this GAN variant.

        Constructs, in order:

        * an encoder on the input;
        * a latent ``UniformDistribution`` plus slider/direction controls;
        * a generator, optionally preceded by a ``u_to_z`` component and
          conditioned on a style encoder;
        * a reuse of the generator on the encoder output (``x_hat``);
        * the main discriminator, and optionally a ``z_discriminator``
          (``config.manifold_guided``);
        * the losses and the trainer.

        It then runs ``tf.global_variables_initializer()`` in
        ``self.session``. Results are stored as attributes on ``self``; the
        method returns None.

        NOTE(review): some config combinations raise NameError
        (UnboundLocalError) in this method:

        * ``style_encoder`` without ``u_to_z``: ``x_hat_style`` and
          ``style_encoder`` are never bound;
        * ``vae`` without ``u_to_z``: ``u_to_z`` is never bound;
        * any generator with a ``mask_generator`` attribute: ``mask`` is never
          bound.

        Confirm the intended behavior for these cases.
        """
        config = self.config
        ops = self.ops
        d_losses = []
        g_losses = []

        # Uniform sample configured by config.z_distribution, built with
        # output_shape equal to the shape of ``x``.
        def random_like(x):
            return UniformDistribution(self,
                                       config.z_distribution,
                                       output_shape=self.ops.shape(x)).sample

        with tf.device(self.device):
            x_input = tf.identity(self.inputs.x, name='input')

            encoder = self.create_encoder(self.inputs.x)

            # q(z|x)
            # Latent prior: with u_to_z, config.latent is used as-is;
            # otherwise output_shape is derived from the encoder output shape.
            if config.u_to_z:
                latent = UniformDistribution(self, config.latent)
            else:
                z_shape = self.ops.shape(encoder.sample)
                # uz_shape is the same object as z_shape, so both change.
                # Last dim is divided by the number of projections (1 when
                # config.latent.projections is empty/None).
                uz_shape = z_shape
                uz_shape[-1] = uz_shape[-1] // len(config.latent.projections
                                                   or [1])
                latent = UniformDistribution(self,
                                             config.latent,
                                             output_shape=uz_shape)
            self.latent = latent

            # z = latent sample shifted along `direction` by `slider`
            # (both produced by create_controls).
            direction, slider = self.create_controls(
                self.ops.shape(latent.sample))
            z = latent.sample + slider * direction
            #projected_encoder = UniformDistribution(self, config.encoder, z=encoder.sample)

            # NOTE(review): feature_dim is not used anywhere below.
            feature_dim = len(ops.shape(z)) - 1
            #stack_z = tf.concat([encoder.sample, z], feature_dim)
            #stack_encoded = tf.concat([encoder.sample, encoder.sample], feature_dim)
            stack_z = z

            if config.u_to_z:
                # z -> u_to_z -> generator, optionally conditioned on a style.
                if config.style_encoder:
                    style_encoder = self.create_component(config.style_encoder,
                                                          input=x_input,
                                                          name='style_encoder')
                    style = style_encoder.sample
                    #style_sample = tf.concat(style, axis=0)
                    style_sample = style
                    #style_sample=random_like(style_sample)
                    #x_hat_style = style_sample
                    # Random style used for the x_hat reconstruction below.
                    x_hat_style = random_like(style_sample)
                    #style_sample =  random_like(x_hat_style)
                    u_to_z = self.create_component(config.u_to_z,
                                                   name='u_to_z',
                                                   features=[style_sample],
                                                   input=z)
                    generator = self.create_component(config.generator,
                                                      input=u_to_z.sample,
                                                      features=[style_sample],
                                                      name='generator')
                else:
                    u_to_z = self.create_component(config.u_to_z,
                                                   name='u_to_z',
                                                   input=z)
                    generator = self.create_component(config.generator,
                                                      input=u_to_z.sample,
                                                      name='generator')
                stacked = [x_input, generator.sample]
                self.generator = generator

                self.encoder = encoder
                # features is a Python list on this branch.
                features = [encoder.sample, u_to_z.sample]
                self.u_to_z = u_to_z
            else:
                generator = self.create_component(config.generator,
                                                  input=stack_z,
                                                  name='generator')
                self.generator = generator
                stacked = [x_input, generator.sample]

                self.encoder = encoder
                # NOTE(review): features is a single tensor here, not a list
                # as in the u_to_z branch; `features += [...]` below would not
                # append on this branch.
                features = ops.concat([encoder.sample, z], axis=0)

            # x_hat: the generator (shared variables) applied to the encoder
            # output.
            if config.style_encoder:
                # NOTE(review): x_hat_style is only bound when config.u_to_z
                # is also set.
                x_hat = self.create_component(config.generator,
                                              input=encoder.sample,
                                              features=[x_hat_style],
                                              reuse=True,
                                              name='generator').sample
                stacked += [x_hat]
                features += [encoder.sample]
            else:
                x_hat = self.create_component(config.generator,
                                              input=encoder.sample,
                                              reuse=True,
                                              name='generator').sample
            self.autoencoded_x = x_hat
            self.uniform_sample = generator.sample

            # Discriminator input: images stacked along the batch axis, with
            # latent codes stacked the same way as features.
            stacked_xg = tf.concat(stacked, axis=0)
            features_zs = tf.concat(features, axis=0)
            self.features = features_zs

            standard_discriminator = self.create_component(
                config.discriminator,
                name='discriminator',
                input=stacked_xg,
                features=[features_zs])
            self.discriminator = standard_discriminator

            d_vars = standard_discriminator.variables()
            g_vars = generator.variables() + encoder.variables()
            if config.style_encoder:
                g_vars += style_encoder.variables()
            if config.u_to_z:
                g_vars += u_to_z.variables()

            if self.config.manifold_guided:
                # Extra discriminator on [encoder(x), encoder(G(z))].
                reencode_u_to_z = self.create_encoder(generator.sample,
                                                      reuse=True)
                stack_z = [encoder.sample, reencode_u_to_z.sample]
                stacked_zs = ops.concat(stack_z, axis=0)
                z_discriminator = self.create_component(config.z_discriminator,
                                                        name='z_discriminator',
                                                        input=stacked_zs)
                self.z_discriminator = z_discriminator
                d_vars += z_discriminator.variables()

            self._g_vars = g_vars
            self._d_vars = d_vars
            # len(stacked) is 2, or 3 when x_hat was appended above.
            standard_loss = self.create_loss(config.loss,
                                             standard_discriminator, x_input,
                                             generator, len(stacked))
            # d_vars is the same object as self._d_vars, so this in-place
            # += is also visible through self._d_vars (assuming variables()
            # returns a list).
            if self.gan.config.infogan:
                d_vars += self.gan.infogan_q.variables()

            # NOTE(review): loss1 and loss2 are not used below.
            loss1 = ["g_loss", standard_loss.g_loss]
            loss2 = ["d_loss", standard_loss.d_loss]

            if self.config.manifold_guided:
                l2 = self.create_loss(config.loss,
                                      z_discriminator,
                                      x_input,
                                      generator,
                                      len(stack_z),
                                      reuse=True)
                d_losses.append(l2.d_loss)
                g_losses.append(l2.g_loss)

            d_losses.append(standard_loss.d_loss)
            g_losses.append(standard_loss.g_loss)
            if self.config.autoencode:
                # Replaces every previously collected loss with a scaled
                # squared reconstruction error.
                l2_loss = self.ops.squash(10 * tf.square(x_hat - x_input))
                g_losses = [l2_loss]
                d_losses = [l2_loss]
            if self.config.vae:
                # Latent penalty from the (mu, sigma) pair exposed as
                # `.variational`; eps guards the log. Applied to the encoder
                # and to u_to_z.
                # NOTE(review): u_to_z is only bound when config.u_to_z is set.
                mu, sigma = self.encoder.variational
                eps = 1e-8
                lam = config.vae_lambda or 0.001
                latent_loss = lam * (0.5 * self.ops.squash(
                    tf.square(mu) - tf.square(sigma) -
                    tf.log(tf.square(sigma) + eps) - 1, tf.reduce_sum))
                g_losses.append(latent_loss)
                mu, sigma = u_to_z.variational
                latent_loss = lam * (0.5 * self.ops.squash(
                    tf.square(mu) - tf.square(sigma) -
                    tf.log(tf.square(sigma) + eps) - 1, tf.reduce_sum))
                g_losses.append(latent_loss)

            for i, l in enumerate(g_losses):
                self.add_metric('gl' + str(i), l)
            for i, l in enumerate(d_losses):
                self.add_metric('dl' + str(i), l)
            # 'sample' holds [total d loss, total g loss].
            loss = hc.Config({
                'd_fake': standard_loss.d_fake,
                'd_real': standard_loss.d_real,
                'sample': [tf.add_n(d_losses),
                           tf.add_n(g_losses)]
            })
            self.loss = loss
            trainer = self.create_component(config.trainer,
                                            g_vars=g_vars,
                                            d_vars=d_vars)

            self.session.run(tf.global_variables_initializer())

        self.trainer = trainer
        self.generator = generator
        self.slider = slider
        self.direction = direction
        self.z = z
        self.z_hat = encoder.sample
        self.x_input = x_input
        # (x + 1) * 127.5 maps [-1, 1] to [0, 255] (assuming generator output
        # lies in [-1, 1]); OR-ing 0xFF000000 sets the top byte, presumably
        # an opaque alpha channel.
        rgb = tf.cast((self.generator.sample + 1) * 127.5, tf.int32)
        self.generator_int = tf.bitwise.bitwise_or(rgb,
                                                   0xFF000000,
                                                   name='generator_int')
        self.random_z = tf.random_uniform(ops.shape(latent.sample),
                                          -1,
                                          1,
                                          name='random_z')

        if hasattr(generator, 'mask_generator'):
            self.mask_generator = generator.mask_generator
            # NOTE(review): `mask` is not defined in this method; this line
            # raises NameError whenever this branch runs. Intended value
            # unknown -- confirm.
            self.mask = mask
            self.autoencode_mask = generator.mask_generator.sample
            self.autoencode_mask_3_channel = generator.mask
Exemplo n.º 12
0
    def create(self):
        """Build the training graph: encoders, generator, discriminators, losses, trainer.

        Builds a z encoder and a style encoder on the input, a latent
        ``UniformDistribution`` with slider/direction controls, and a
        generator conditioned on the style sample. It also builds:

        * the main discriminator on ``[x, G(z)]`` (or ``[x, G(u_to_z(u))]``
          when ``config.u_to_z`` is set);
        * a "forcerandom" discriminator comparing random noise with the style
          encoding;
        * a ``ConsensusTrainer``.

        It then runs ``tf.global_variables_initializer()`` in
        ``self.session``. Results are stored as attributes on ``self``; the
        method returns None.
        """
        config = self.config
        ops = self.ops

        with tf.device(self.device):
            x_input = tf.identity(self.inputs.x, name='input')

            def random_like(x):
                # Uniform sample (config.z_distribution) with output_shape
                # equal to the shape of ``x``.
                return UniformDistribution(
                    self,
                    config.z_distribution,
                    output_shape=self.ops.shape(x)).sample

            # q(z|x)
            encoder = self.create_component(config.encoder,
                                            input=x_input,
                                            name='z_encoder')
            print("ENCODER ", encoder.sample)

            self.encoder = encoder
            z_shape = self.ops.shape(encoder.sample)
            style = self.create_component(config.style_encoder,
                                          input=x_input,
                                          name='style')
            self.style = style
            self.styleb = style  # hack for sampler
            self.random_style = random_like(style.sample)

            # uz_shape is the same object as z_shape; its last dim is divided
            # by the number of projections (presumably the distribution
            # widens its output by that factor).
            uz_shape = z_shape
            uz_shape[-1] = uz_shape[-1] // len(
                config.z_distribution.projections)
            # Bound to a new local name. The previous code rebound
            # `UniformDistribution` itself, which made the class name local to
            # this method: both this call and random_like() above then failed
            # with NameError/UnboundLocalError.
            uniform_encoder = UniformDistribution(self,
                                                  config.z_distribution,
                                                  output_shape=uz_shape)
            direction, slider = self.create_controls(
                self.ops.shape(uniform_encoder.sample))
            z = uniform_encoder.sample + slider * direction

            generator = self.create_component(config.generator,
                                              input=z,
                                              features=[style.sample])
            self.uniform_sample = generator.sample
            # Generator (shared variables) applied to the encoder output.
            x_hat = generator.reuse(encoder.sample)

            features_zs = ops.concat([encoder.sample, z], axis=0)
            stacked_xg = ops.concat([x_input, generator.sample], axis=0)

            if config.u_to_z:
                # Feed the generator through u_to_z from fresh noise instead.
                u_to_z = self.create_component(config.u_to_z,
                                               name='u_to_z',
                                               input=random_like(z))
                gu = generator.reuse(u_to_z.sample)
                stacked_xg = ops.concat([x_input, gu], axis=0)
                features_zs = ops.concat([encoder.sample, u_to_z.sample],
                                         axis=0)

            standard_discriminator = self.create_component(
                config.discriminator,
                name='discriminator',
                input=stacked_xg,
                features=[features_zs])
            standard_loss = self.create_loss(config.loss,
                                             standard_discriminator, x_input,
                                             generator, 2)

            d_loss1 = standard_loss.d_loss
            g_loss1 = standard_loss.g_loss

            d_vars1 = standard_discriminator.variables()
            g_vars1 = (generator.variables() + encoder.variables() +
                       style.variables())

            metrics = {
                'g_loss': standard_loss.g_loss,
                'd_loss': standard_loss.d_loss
            }

            # "forcerandom" discriminator: random noise vs. the style sample.
            # Always built and reported in metrics; only added to the training
            # objective when config.forcerandom is set.
            noise_style = random_like(style.sample)
            stacked = ops.concat([noise_style, style.sample], axis=0)
            z_d = self.create_component(config.z_discriminator,
                                        name='forcerandom_discriminator',
                                        input=stacked)
            loss3 = self.create_component(config.loss,
                                          discriminator=z_d,
                                          x=x_input,
                                          generator=generator,
                                          split=2)
            metrics["forcerandom_gloss"] = loss3.g_loss
            metrics["forcerandom_dloss"] = loss3.d_loss
            if config.forcerandom:
                d_loss1 += loss3.d_loss
                g_loss1 += loss3.g_loss
                d_vars1 += z_d.variables()

            if config.u_to_z:
                g_vars1 += u_to_z.variables()

            lossa = hc.Config({
                'sample': [d_loss1, g_loss1],
                'metrics': metrics
            })
            trainer = ConsensusTrainer(self,
                                       config.trainer,
                                       loss=lossa,
                                       g_vars=g_vars1,
                                       d_vars=d_vars1)

            self.session.run(tf.global_variables_initializer())

        self.trainer = trainer
        self.generator = generator
        self.uniform_distribution = uniform_encoder
        self.slider = slider
        self.direction = direction
        self.z = z
        self.z_hat = encoder.sample
        self.x_input = x_input
        self.autoencoded_x = x_hat
        # (x + 1) * 127.5 maps [-1, 1] to [0, 255] (assuming generator output
        # lies in [-1, 1]); OR-ing 0xFF000000 sets the top byte, presumably
        # an opaque alpha channel.
        rgb = tf.cast((self.generator.sample + 1) * 127.5, tf.int32)
        self.generator_int = tf.bitwise.bitwise_or(rgb,
                                                   0xFF000000,
                                                   name='generator_int')
        self.random_z = tf.random_uniform(ops.shape(uniform_encoder.sample),
                                          -1,
                                          1,
                                          name='random_z')

        if hasattr(generator, 'mask_generator'):
            self.mask_generator = generator.mask_generator
            # NOTE(review): `mask` is not defined in this method; this line
            # raises NameError whenever this branch runs. Intended value
            # unknown -- confirm.
            self.mask = mask
            self.autoencode_mask = generator.mask_generator.sample
            self.autoencode_mask_3_channel = generator.mask
Exemplo n.º 13
0
    def create(self):
        """Build a two-domain (xa/xb) training graph with generators ga and gb.

        ``ga`` is created with input ``xb`` and ``gb`` with input ``xa``; they
        share variables when ``config.same_g`` is set. Then:

        * a discriminator ``d_ab`` is built on ``[xb, gb(xa)]`` with the
          generators' ``z`` controls as features;
        * the optional ``alpha``, ``ug``, ``alpha2`` and ``ug2`` flags add
          extra discriminators/losses;
        * the trainer is created and ``tf.global_variables_initializer()`` is
          run.

        Results are stored as attributes on ``self``; the method returns None.

        NOTE(review): ``self._d_vars`` only includes ``d_vars1``. The
        discriminators added under ``alpha2`` / ``ug2`` contribute to
        ``d_loss2`` (and thus the 'sample' loss), but their variables are not
        listed. Confirm this is intended.
        """
        config = self.config
        ops = self.ops

        with tf.device(self.device):
            #x_input = tf.identity(self.inputs.x, name='input')
            xa_input = tf.identity(self.inputs.xa, name='xa_i')
            xb_input = tf.identity(self.inputs.xb, name='xb_i')

            # ga takes xb, gb takes xa; same_g shares one set of variables.
            if config.same_g:
                ga = self.create_component(config.generator, input=xb_input, name='a_generator')
                gb = self.create_component(config.generator, input=xa_input, name='a_generator', reuse=True)
            else:
                ga = self.create_component(config.generator, input=xb_input, name='a_generator')
                gb = self.create_component(config.generator, input=xa_input, name='b_generator')

            za = ga.controls["z"]
            zb = gb.controls["z"]

            self.uniform_sample = ga.sample

            # Cross-domain outputs and cycle reconstructions.
            xba = ga.sample
            xab = gb.sample
            xa_hat = ga.reuse(gb.sample)
            xb_hat = gb.reuse(ga.sample)

            # uz_shape is the same object as z_shape; last dim divided by the
            # number of projections.
            z_shape = self.ops.shape(za)
            uz_shape = z_shape
            uz_shape[-1] = uz_shape[-1] // len(config.z_distribution.projections)
            # NOTE(review): ue3 and ue4 are never used below.
            ue = UniformDistribution(self, config.z_distribution, output_shape=uz_shape)
            ue2 = UniformDistribution(self, config.z_distribution, output_shape=uz_shape)
            ue3 = UniformDistribution(self, config.z_distribution, output_shape=uz_shape)
            ue4 = UniformDistribution(self, config.z_distribution, output_shape=uz_shape)
            print('ue', ue.sample)

            zua = ue.sample
            zub = ue2.sample

            # Generators re-run on zero images with their "z" control replaced
            # by uniform samples.
            uga = ga.reuse(tf.zeros_like(xb_input), replace_controls={"z":zua})
            ugb = gb.reuse(tf.zeros_like(xa_input), replace_controls={"z":zub})

            xa = xa_input
            xb = xb_input

            # NOTE(review): t2 and f2 are never used.
            t0 = xb
            t1 = gb.sample
            t2 = ga.sample
            f0 = za
            f1 = zb
            f2 = za
            stack = [t0, t1]
            stacked = ops.concat(stack, axis=0)
            features = ops.concat([f0, f1], axis=0)

            # Main discriminator d_ab on [xb, gb.sample] with features [za, zb].
            d = self.create_component(config.discriminator, name='d_ab', input=stacked, features=[features])
            self.za = za
            self.discriminator = d
            l = self.create_loss(config.loss, d, xa_input, ga.sample, len(stack))
            loss1 = l
            d_loss1 = l.d_loss
            g_loss1 = l.g_loss

            d_vars1 = d.variables()
            print('--', ga, gb)

            # NOTE(review): d_loss / g_loss are accumulated but never used.
            d_loss = l.d_loss
            g_loss = l.g_loss

            metrics = {
                    'g_loss': l.g_loss,
                    'd_loss': l.d_loss
                }

            # NOTE(review): these inputs are identical to the first d_ab pass
            # above (xb vs gb.sample, za vs zb). The "ga_*" metric names
            # suggest xa / ga.sample may have been intended -- confirm.
            t0 = xb
            t1 = gb.sample
            f0 = za
            f1 = zb
            stack = [t0, t1]
            stacked = ops.concat(stack, axis=0)
            features = ops.concat([f0, f1], axis=0)

            d = self.create_component(config.discriminator, name='d_ab', input=stacked, features=[features], reuse=True)
            l = self.create_loss(config.loss, d, xa_input, ga.sample, len(stack))

            d_loss += l.d_loss
            g_loss += l.g_loss
            d_vars2 = d.variables()
            metrics["ga_gloss"]=l.g_loss
            metrics["ga_dloss"]=l.d_loss
            loss2=l
            g_loss2 = loss2.g_loss
            d_loss2 = loss2.d_loss
            # NOTE(review): `trainers` is never used.
            trainers = []

            # Optional extra discriminators. NOTE(review): every branch below
            # writes metrics["za_gloss"] / metrics["za_dloss"], so later
            # branches overwrite earlier ones.
            if(config.alpha):
                # Uniform sample zub vs gb's z control.
                t0 = zub
                t1 = zb
                netzd = tf.concat(axis=0, values=[t0,t1])
                z_d = self.create_component(config.z_discriminator, name='z_discriminator', input=netzd)

                loss3 = self.create_component(config.loss, discriminator = z_d, x=xa_input, generator=ga, split=2)
                d_vars1 += z_d.variables()
                metrics["za_gloss"]=loss3.g_loss
                metrics["za_dloss"]=loss3.d_loss
                d_loss1 += loss3.d_loss
                g_loss1 += loss3.g_loss


            if(config.ug):
                # Real xb vs gb re-run with a uniform z (ugb).
                t0 = xb
                t1 = ugb
                netzd = tf.concat(axis=0, values=[t0,t1])
                z_d = self.create_component(config.discriminator, name='ug_discriminator', input=netzd)

                loss3 = self.create_component(config.loss, discriminator = z_d, x=xa_input, generator=ga, split=2)
                d_vars1 += z_d.variables()
                metrics["za_gloss"]=loss3.g_loss
                metrics["za_dloss"]=loss3.d_loss
                d_loss1 += loss3.d_loss
                g_loss1 += loss3.g_loss

            if(config.alpha2):
                # Uniform sample zua vs ga's z control; variables go to
                # d_vars2, which is not included in self._d_vars below.
                t0 = zua
                t1 = za
                netzd = tf.concat(axis=0, values=[t0,t1])
                z_d = self.create_component(config.z_discriminator, name='z_discriminator2', input=netzd)

                loss3 = self.create_component(config.loss, discriminator = z_d, x=xa_input, generator=ga, split=2)
                d_vars2 += z_d.variables()
                metrics["za_gloss"]=loss3.g_loss
                metrics["za_dloss"]=loss3.d_loss
                d_loss2 += loss3.d_loss
                g_loss2 += loss3.g_loss

            if config.ug2:
                # Real xa vs ga re-run with a uniform z (uga); variables go to
                # d_vars2, which is not included in self._d_vars below.
                t0 = xa
                t1 = uga
                netzd = tf.concat(axis=0, values=[t0,t1])
                z_d = self.create_component(config.discriminator, name='ug_discriminator2', input=netzd)

                loss3 = self.create_component(config.loss, discriminator = z_d, x=xa_input, generator=ga, split=2)
                d_vars2 += z_d.variables()
                metrics["za_gloss"]=loss3.g_loss
                metrics["za_dloss"]=loss3.d_loss
                d_loss2 += loss3.d_loss
                g_loss2 += loss3.g_loss


            # 'sample' holds [total d loss, total g loss].
            loss = hc.Config({
                'd_fake':loss1.d_fake,
                'd_real':loss1.d_real,
                'sample': [tf.add_n([d_loss1, d_loss2]), tf.add_n([g_loss1,g_loss2])]
            })
            self._g_vars = ga.variables() + gb.variables()
            self._d_vars = d_vars1# + d_vars2
            self.loss=loss
            # Presumably the trainer reads self.loss / self._g_vars /
            # self._d_vars, set just above -- confirm in the trainer.
            trainer = self.create_component(config.trainer)
            self.session.run(tf.global_variables_initializer())

        self.trainer = trainer
        self.generator = ga
        self.encoder = gb # this is the other gan
        self.uniform_distribution = hc.Config({"sample":zb})#uniform_encoder
        self.zb = zb
        self.z_hat = gb.sample
        self.x_input = xa_input
        self.autoencoded_x = xa_hat

        self.cyca = xa_hat
        self.cycb = xb_hat
        self.xba = xba
        self.xab = xab
        self.uga = uga
        self.ugb = ugb

        # (x + 1) * 127.5 maps [-1, 1] to [0, 255] (assuming generator output
        # lies in [-1, 1]); OR-ing 0xFF000000 sets the top byte, presumably
        # an opaque alpha channel.
        rgb = tf.cast((self.generator.sample+1)*127.5, tf.int32)
        self.generator_int = tf.bitwise.bitwise_or(rgb, 0xFF000000, name='generator_int')
Exemplo n.º 14
0
    def create(self):
        config = self.config
        ops = self.ops

        with tf.device(self.device):
            #x_input = tf.identity(self.inputs.x, name='input')
            xa_input = tf.identity(self.inputs.xa, name='xa_i')
            xb_input = tf.identity(self.inputs.xb, name='xb_i')

            if config.same_g:
                ga = self.create_component(config.generator,
                                           input=xb_input,
                                           name='a_generator')
                gb = self.create_component(config.generator,
                                           input=xa_input,
                                           name='a_generator',
                                           reuse=True)
            elif config.two_g:
                ga = self.create_component(config.generator1,
                                           input=xb_input,
                                           name='a_generator')
                gb = self.create_component(config.generator2,
                                           input=xa_input,
                                           name='b_generator')
            else:
                ga = self.create_component(config.generator,
                                           input=xb_input,
                                           name='a_generator')
                gb = self.create_component(config.generator,
                                           input=xa_input,
                                           name='b_generator')

            za = ga.controls["z"]
            zb = gb.controls["z"]

            self.uniform_sample = ga.sample

            xba = ga.sample
            xab = gb.sample
            xa_hat = ga.sample
            xb_hat = gb.sample
            #xa_hat = ga.reuse(gb.sample)
            #xb_hat = gb.reuse(ga.sample)

            z_shape = self.ops.shape(za)
            uz_shape = z_shape
            uz_shape[-1] = uz_shape[-1] // len(config.latent.projections
                                               or [1])
            ue = UniformDistribution(self,
                                     config.latent,
                                     output_shape=uz_shape)
            ue2 = UniformDistribution(self,
                                      config.latent,
                                      output_shape=uz_shape)
            ue3 = UniformDistribution(self,
                                      config.latent,
                                      output_shape=uz_shape)
            ue4 = UniformDistribution(self,
                                      config.latent,
                                      output_shape=uz_shape)
            print('ue', ue.sample)

            zua = ue.sample
            zub = ue2.sample

            uga = ga.sample  #ga.reuse(tf.zeros_like(xb_input), replace_controls={"z":zua})
            ugb = gb.sample  #gb.reuse(tf.zeros_like(xa_input), replace_controls={"z":zub})

            xa = xa_input
            xb = xb_input

            re_ga = self.create_component(config.generator,
                                          input=gb.sample,
                                          name='a_generator',
                                          reuse=True)
            re_gb = self.create_component(config.generator,
                                          input=ga.sample,
                                          name='b_generator',
                                          reuse=True)
            re_zb = zb  #re_gb.controls['z']

            t0 = tf.concat([xb, xb], axis=3)
            t1 = tf.concat([gb.sample, gb.sample], axis=3)
            zaxis = len(self.ops.shape(za)) - 1

            f0 = tf.concat([za, za], axis=zaxis)
            f1 = tf.concat([zb, zb], axis=zaxis)
            #f0 = tf.concat([zb, za], axis=zaxis)
            #f1 = tf.concat([za, zb], axis=zaxis)
            stack = [t0, t1]
            stacked = ops.concat(stack, axis=0)
            features = ops.concat([f0, f1], axis=0)

            d = self.create_component(config.discriminator,
                                      name='d_ab',
                                      input=stacked,
                                      features=[features])

            self.za = za
            self.discriminator = d
            l = self.create_loss(config.loss, d, None, None, len(stack))
            loss = l
            d1_lambda = config.d1_lambda
            d2_lambda = config.d2_lambda
            d_loss1 = d1_lambda * l.d_loss
            g_loss1 = d1_lambda * l.g_loss

            d_vars1 = d.variables()

            d_loss = l.d_loss
            g_loss = l.g_loss

            metrics = {'g_loss': l.g_loss, 'd_loss': l.d_loss}

            self._g_vars = ga.variables() + gb.variables()
            self._d_vars = d_vars1
            self.loss = loss
            self.generator = gb
            trainer = self.create_component(config.trainer)
            self.initialize_variables()

        self.trainer = trainer
        self.latent = hc.Config({'sample': zb})
        self.generator = gb
        self.encoder = gb  # this is the other gan
        self.uniform_distribution = hc.Config({"sample": zb})  #uniform_encoder
        self.zb = zb
        self.z_hat = gb.sample
        self.x_input = xa_input
        self.autoencoded_x = xa_hat

        self.cyca = xa_hat
        self.cycb = xb_hat
        self.xba = xba
        self.xab = xab
        self.uga = uga
        self.ugb = ugb

        rgb = tf.cast((self.generator.sample + 1) * 127.5, tf.int32)
        self.generator_int = tf.bitwise.bitwise_or(rgb,
                                                   0xFF000000,
                                                   name='generator_int')
Exemplo n.º 15
0
 def random_t(shape):
     """Return a uniform latent sample whose projected width matches ``shape``.

     ``shape`` is the desired *output* shape (after projections). The base
     sample is drawn with its last dimension divided by the number of
     configured projections, because the projected output's last dimension
     is ``len(projections) * z``.

     The caller's ``shape`` is not modified (the original divided it in
     place). ``None``/empty projections are treated as a single projection,
     matching the ``or [1]`` guard used elsewhere in this file.
     """
     projections = config.z_distribution.projections or [1]
     base_shape = list(shape)  # copy: never mutate the caller's list
     base_shape[-1] //= len(projections)
     return UniformDistribution(self, config.z_distribution, output_shape=base_shape).sample
Exemplo n.º 16
0
    def create(self):
        """Build an encoder/generator graph trained against one (x, z) discriminator.

        Creates an encoder q(z|x), a uniform latent sampler, a ``u_to_z``
        component and a generator. It then stacks three (x, z) pairs for a
        single discriminator:
        (real x, E(x)), (G(u_to_z(u)), u_to_z(u)) and
        (G(E(G(u_to_z(u)))), E(G(u_to_z(u)))).

        Sets ``self.trainer``, ``self.generator``, ``self.encoder``,
        ``self.uniform_distribution``, ``self.z``, ``self.z_hat``,
        ``self.autoencoded_x``, ``self.random_z`` and related attributes.
        """
        config = self.config
        ops = self.ops
        d_losses = []
        g_losses = []

        with tf.device(self.device):
            x_input = tf.identity(self.inputs.x, name='input')

            # q(z|x). Built first because the non-u_to_z branch below sizes the
            # latent distribution from the encoder's output shape (the original
            # read ``encoder`` before assigning it).
            encoder = self.create_encoder(self.inputs.x)
            self.encoder = encoder

            # Bound to ``uniform_encoder`` instead of rebinding the
            # ``UniformDistribution`` class name: assigning to that name anywhere
            # in this function made it a local and raised UnboundLocalError.
            if config.u_to_z:
                uniform_encoder = UniformDistribution(self,
                                                      config.z_distribution)
            else:
                uz_shape = self.ops.shape(encoder.sample)
                # Projections multiply the last dimension by their count, so
                # draw the base sample at 1/len(projections) of the width.
                uz_shape[-1] = uz_shape[-1] // len(
                    config.z_distribution.projections)
                uniform_encoder = UniformDistribution(
                    self, config.z_distribution, output_shape=uz_shape)
            self.uniform_distribution = uniform_encoder

            direction, slider = self.create_controls(
                self.ops.shape(uniform_encoder.sample))
            z = uniform_encoder.sample + slider * direction

            u_to_z = self.create_component(config.u_to_z,
                                           name='u_to_z',
                                           input=z)
            generator = self.create_component(config.generator,
                                              input=u_to_z.sample,
                                              name='generator')
            self.generator = generator

            # Cycle: re-encode the generated image, then regenerate from it.
            reencode_u_to_z = self.create_encoder(generator.sample, reuse=True)
            reencode_u_to_z_to_g = self.create_component(
                config.generator,
                input=reencode_u_to_z.sample,
                name='generator',
                reuse=True)

            # Element i of ``stacked`` is paired with element i of ``features``.
            stacked = [x_input, generator.sample, reencode_u_to_z_to_g.sample]
            features = [encoder.sample, u_to_z.sample, reencode_u_to_z.sample]

            x_hat = self.create_component(config.generator,
                                          input=encoder.sample,
                                          reuse=True,
                                          name='generator').sample
            self.uniform_sample = generator.sample

            stacked_xg = tf.concat(stacked, axis=0)
            features_zs = tf.concat(features, axis=0)

            standard_discriminator = self.create_component(
                config.discriminator,
                name='discriminator',
                input=stacked_xg,
                features=[features_zs])
            self.discriminator = standard_discriminator
            d_vars = standard_discriminator.variables()
            g_vars = generator.variables() + encoder.variables()
            g_vars += u_to_z.variables()

            self._g_vars = g_vars
            self._d_vars = d_vars
            standard_loss = self.create_loss(config.loss,
                                             standard_discriminator, x_input,
                                             generator, len(stacked))
            if self.gan.config.infogan:
                # In-place extend: ``self._d_vars`` is the same list object.
                d_vars += self.gan.infogan_q.variables()

            d_losses.append(standard_loss.d_loss)
            g_losses.append(standard_loss.g_loss)
            if self.config.autoencode:
                # Autoencode mode replaces the adversarial terms entirely.
                l2_loss = self.ops.squash(10 * tf.square(x_hat - x_input))
                g_losses = [l2_loss]
                d_losses = [l2_loss]

            for i, l in enumerate(g_losses):
                self.add_metric('gl' + str(i), l)
            for i, l in enumerate(d_losses):
                self.add_metric('dl' + str(i), l)
            loss = hc.Config({
                'd_fake': standard_loss.d_fake,
                'd_real': standard_loss.d_real,
                'sample': [tf.add_n(d_losses),
                           tf.add_n(g_losses)]
            })
            self.loss = loss
            trainer = self.create_component(config.trainer,
                                            loss=loss,
                                            g_vars=g_vars,
                                            d_vars=d_vars)

            self.session.run(tf.global_variables_initializer())

        self.trainer = trainer
        self.generator = generator
        self.slider = slider
        self.direction = direction
        self.z = z
        self.z_hat = encoder.sample
        self.x_input = x_input
        self.autoencoded_x = x_hat
        # Map [-1, 1] pixels to [0, 255] ints and set the alpha byte.
        rgb = tf.cast((self.generator.sample + 1) * 127.5, tf.int32)
        self.generator_int = tf.bitwise.bitwise_or(rgb,
                                                   0xFF000000,
                                                   name='generator_int')
        self.random_z = tf.random_uniform(ops.shape(uniform_encoder.sample),
                                          -1,
                                          1,
                                          name='random_z')

        if hasattr(generator, 'mask_generator'):
            self.mask_generator = generator.mask_generator
            # ``mask`` was never defined in this method (NameError). Use the
            # generator's single-channel mask when it exposes one.
            # NOTE(review): assumed to be the intended mask — confirm.
            self.mask = getattr(generator, 'mask_single_channel', None)
            self.autoencode_mask = generator.mask_generator.sample
            self.autoencode_mask_3_channel = generator.mask
Exemplo n.º 17
0
    def create(self):
        """Build an encoder/generator graph with ALI-style joint discriminators.

        Three joint (x, z) discriminators share weights: the full pair, x-only
        (z zeroed) and z-only (x zeroed). Unless ``config.mutual_only`` or
        ``config.alternate`` is set, separate ``x_discriminator`` and
        ``z_discriminator`` components are added as well.

        Loss modes (checked in this priority order):
          * ``mutual_only``: ``2*l1 - l2 - l3``.
          * ``alternate``: ``beta*(l1 - l2 - l3) + l2 + 2*l3``.
          * otherwise: ``beta*(l1 - l2 - l3) + l4 + 2*l5``.

        ``beta`` defaults to 0.9. Sets ``self.trainer``, ``self.generator``,
        ``self.encoder``, ``self.latent``, ``self.random_z`` and related
        attributes.
        """
        config = self.config
        ops = self.ops
        d_losses = []
        g_losses = []
        encoder = self.create_encoder(self.inputs.x)

        with tf.device(self.device):
            x_input = tf.identity(self.inputs.x, name='input')

            if config.u_to_z:
                latent = UniformDistribution(self, config.latent)
            else:
                uz_shape = self.ops.shape(encoder.sample)
                # Projections multiply the last dimension by their count.
                uz_shape[-1] = uz_shape[-1] // len(config.latent.projections)
                latent = UniformDistribution(self,
                                             config.latent,
                                             output_shape=uz_shape)
            self.uniform_distribution = latent
            self.latent = latent
            direction, slider = self.create_controls(
                self.ops.shape(latent.sample))
            z = latent.sample + slider * direction

            u_to_z = self.create_component(config.u_to_z,
                                           name='u_to_z',
                                           input=z)
            generator = self.create_component(config.generator,
                                              input=u_to_z.sample,
                                              name='generator')
            self.generator = generator
            self.encoder = encoder

            # Cycle: re-encode the generated image, then regenerate from it.
            reencode_u_to_z = self.create_encoder(generator.sample, reuse=True)
            reencode_u_to_z_to_g = self.create_component(
                config.generator,
                input=reencode_u_to_z.sample,
                name='generator',
                reuse=True)
            self.reencode_g = reencode_u_to_z_to_g

            x_hat = self.create_component(config.generator,
                                          input=encoder.sample,
                                          reuse=True,
                                          name='generator').sample
            # Not referenced again in this method; kept so the graph built
            # here matches the original.
            reencode_x_hat_to_z = self.create_encoder(x_hat, reuse=True)
            self.uniform_sample = generator.sample

            d_vars = []
            g_vars = generator.variables() + encoder.variables()
            g_vars += u_to_z.variables()

            def ali(*stack, reuse=False):
                # Each element of ``stack`` is an [x, z] pair. Pairs are
                # concatenated along the batch axis for one discriminator.
                xs = tf.concat([x for x, _ in stack], axis=0)
                zs = tf.concat([zz for _, zz in stack], axis=0)

                discriminator = self.create_component(config.discriminator,
                                                      name='discriminator',
                                                      input=xs,
                                                      features=[zs],
                                                      reuse=reuse)
                loss = self.create_loss(config.loss, discriminator, None, None,
                                        len(stack))
                return loss, discriminator

            def single_d(name, stack):
                # Discriminator over one modality, configured by ``config[name]``.
                stacked = tf.concat(stack, axis=0)
                discriminator = self.create_component(config[name],
                                                      name=name,
                                                      input=stacked)
                loss = self.create_loss(config.loss, discriminator, None, None,
                                        len(stack))
                return loss, discriminator

            l1, d1 = ali([self.inputs.x, encoder.sample],
                         [generator.sample, u_to_z.sample],
                         [reencode_u_to_z_to_g.sample, reencode_u_to_z.sample])
            l2, d2 = ali(
                [self.inputs.x, tf.zeros_like(encoder.sample)],
                [generator.sample,
                 tf.zeros_like(u_to_z.sample)], [
                     reencode_u_to_z_to_g.sample,
                     tf.zeros_like(reencode_u_to_z.sample)
                 ],
                reuse=True)
            l3, d3 = ali([tf.zeros_like(self.inputs.x), encoder.sample],
                         [tf.zeros_like(generator.sample), u_to_z.sample], [
                             tf.zeros_like(reencode_u_to_z_to_g.sample),
                             reencode_u_to_z.sample
                         ],
                         reuse=True)

            # Defined before any branch reads it. The original only assigned
            # it in the final branch, so ``alternate`` raised UnboundLocalError.
            beta = config.beta or 0.9
            discriminators = [d1]

            if config.mutual_only:
                d_losses = [2 * l1.d_loss - l2.d_loss - l3.d_loss]
                g_losses = [2 * l1.g_loss - l2.g_loss - l3.g_loss]
            elif config.alternate:
                d_losses = [
                    beta * (l1.d_loss - l2.d_loss - l3.d_loss) + l2.d_loss +
                    2 * l3.d_loss
                ]
                g_losses = [
                    beta * (l1.g_loss - l2.g_loss - l3.g_loss) + l2.g_loss +
                    2 * l3.g_loss
                ]
            else:
                l4, d4 = single_d('x_discriminator', [
                    self.inputs.x, generator.sample,
                    reencode_u_to_z_to_g.sample
                ])
                l5, d5 = single_d(
                    'z_discriminator',
                    [encoder.sample, u_to_z.sample, reencode_u_to_z.sample])
                discriminators += [d4, d5]

                d_losses = [
                    beta * (l1.d_loss - l2.d_loss - l3.d_loss) + l4.d_loss +
                    2 * l5.d_loss
                ]
                g_losses = [
                    beta * (l1.g_loss - l2.g_loss - l3.g_loss) + l4.g_loss +
                    2 * l5.g_loss
                ]
            self.discriminator = d1

            self.add_metric("ld", d_losses[0])
            self.add_metric("lg", g_losses[0])

            # d2/d3 reuse d1's weights, so d1's variables cover all three.
            for disc in discriminators:
                d_vars += disc.variables()

            self._g_vars = g_vars
            self._d_vars = d_vars

            loss = hc.Config({
                'd_fake': l1.d_fake,
                'd_real': l1.d_real,
                'sample': [tf.add_n(d_losses),
                           tf.add_n(g_losses)]
            })
            self.loss = loss
            self.uniform_distribution = latent
            trainer = self.create_component(config.trainer,
                                            g_vars=g_vars,
                                            d_vars=d_vars)

            self.session.run(tf.global_variables_initializer())

        self.trainer = trainer
        self.generator = generator
        self.slider = slider
        self.direction = direction
        self.z = z
        self.z_hat = encoder.sample
        self.x_input = x_input
        self.autoencoded_x = x_hat
        # Map [-1, 1] pixels to [0, 255] ints and set the alpha byte.
        rgb = tf.cast((self.generator.sample + 1) * 127.5, tf.int32)
        self.generator_int = tf.bitwise.bitwise_or(rgb,
                                                   0xFF000000,
                                                   name='generator_int')
        self.random_z = tf.random_uniform(ops.shape(latent.sample),
                                          -1,
                                          1,
                                          name='random_z')

        if hasattr(generator, 'mask_generator'):
            self.mask_generator = generator.mask_generator
            # ``mask`` was never defined in this method (NameError). Use the
            # generator's single-channel mask when it exposes one.
            # NOTE(review): assumed to be the intended mask — confirm.
            self.mask = getattr(generator, 'mask_single_channel', None)
            self.autoencode_mask = generator.mask_generator.sample
            self.autoencode_mask_3_channel = generator.mask
Exemplo n.º 18
0
    def create(self):
        """Build a two-domain (a <-> b) translation graph with cycle outputs.

        ``ga`` maps domain-b inputs to domain a and ``gb`` maps domain-a inputs
        to domain b. Three discriminators score the images and latents:
        ``da`` and ``db`` score images, and ``dz`` compares a uniform latent
        sample against the two generators' ``z`` controls. With
        ``config.ali_z``, two ALI-style discriminators are added that pair
        images with latents.

        Sets ``self.trainer``, ``self.generator`` (= ``ga``),
        ``self.encoder`` (= ``gb``), the cycle reconstructions
        ``self.cyca``/``self.cycb`` and related attributes.
        """
        config = self.config
        ops = self.ops

        with tf.device(self.device):
            xa_input = tf.identity(self.inputs.xa, name='xa_i')
            xb_input = tf.identity(self.inputs.xb, name='xb_i')

            ga = self.create_component(config.generator,
                                       input=xb_input,
                                       name='a_generator')
            gb = self.create_component(config.generator,
                                       input=xa_input,
                                       name='b_generator')

            za = ga.controls["z"]
            zb = gb.controls["z"]

            self.uniform_sample = ga.sample

            xba = ga.sample
            xab = gb.sample
            # Cycle reconstructions: a -> b -> a and b -> a -> b.
            xa_hat = ga.reuse(gb.sample)
            xb_hat = gb.reuse(ga.sample)

            uz_shape = self.ops.shape(za)
            # Projections multiply the last dimension by their count.
            uz_shape[-1] = uz_shape[-1] // len(
                config.z_distribution.projections)
            ue = UniformDistribution(self,
                                     config.z_distribution,
                                     output_shape=uz_shape)
            features_a = ops.concat([ga.sample, xa_input], axis=0)
            features_b = ops.concat([gb.sample, xb_input], axis=0)
            stacked_a = ops.concat([xa_input, ga.sample], axis=0)
            stacked_b = ops.concat([xb_input, gb.sample], axis=0)
            stacked_z = ops.concat([ue.sample, za, zb], axis=0)
            da = self.create_component(config.discriminator,
                                       name='a_discriminator',
                                       input=stacked_a,
                                       features=[features_b])
            db = self.create_component(config.discriminator,
                                       name='b_discriminator',
                                       input=stacked_b,
                                       features=[features_a])
            dz = self.create_component(config.z_discriminator,
                                       name='z_discriminator',
                                       input=stacked_z)

            if config.ali_z:
                features_alia = ops.concat([za, ue.sample], axis=0)
                features_alib = ops.concat([zb, ue.sample], axis=0)
                uga = ga.reuse(tf.zeros_like(xb_input),
                               replace_controls={"z": ue.sample})
                ugb = gb.reuse(tf.zeros_like(xa_input),
                               replace_controls={"z": ue.sample})
                stacked_alia = ops.concat([xa_input, uga], axis=0)
                stacked_alib = ops.concat([xb_input, ugb], axis=0)

                dalia = self.create_component(config.ali_discriminator,
                                              name='alia_discriminator',
                                              input=stacked_alia,
                                              features=[features_alib])
                dalib = self.create_component(config.ali_discriminator,
                                              name='alib_discriminator',
                                              input=stacked_alib,
                                              features=[features_alia])
                lalia = self.create_loss(config.loss, dalia, None, None, 2)
                lalib = self.create_loss(config.loss, dalib, None, None, 2)

            la = self.create_loss(config.loss, da, xa_input, ga.sample, 2)
            lb = self.create_loss(config.loss, db, xb_input, gb.sample, 2)
            lz = self.create_loss(config.loss, dz, None, None, 3)

            # Trainable discriminator weights. The original collected
            # ``lz.variables()`` (the loss) instead of the z discriminator's
            # variables, so ``dz`` was never passed to the trainer.
            d_vars = da.variables() + db.variables() + dz.variables()
            if config.ali_z:
                d_vars += dalia.variables() + dalib.variables()
            g_vars = ga.variables() + gb.variables()

            d_loss = la.d_loss + lb.d_loss + lz.d_loss
            g_loss = la.g_loss + lb.g_loss + lz.g_loss
            metrics = {
                'ga_loss': la.g_loss,
                'gb_loss': lb.g_loss,
                'gz_loss': lz.g_loss,
                'da_loss': la.d_loss,
                'db_loss': lb.d_loss,
                'dz_loss': lz.d_loss
            }
            if config.ali_z:
                d_loss += lalib.d_loss + lalia.d_loss
                g_loss += lalib.g_loss + lalia.g_loss
                metrics['galia_loss'] = lalia.g_loss
                metrics['galib_loss'] = lalib.g_loss
                metrics['dalia_loss'] = lalia.d_loss
                metrics['dalib_loss'] = lalib.d_loss

            loss = hc.Config({'sample': [d_loss, g_loss], 'metrics': metrics})
            trainer = ConsensusTrainer(self,
                                       config.trainer,
                                       loss=loss,
                                       g_vars=g_vars,
                                       d_vars=d_vars)
            self.initialize_variables()

        self.trainer = trainer
        self.generator = ga
        self.encoder = gb  # this is the other gan
        self.uniform_distribution = hc.Config({"sample": za})
        self.zb = zb
        self.z_hat = gb.sample
        self.x_input = xa_input
        self.autoencoded_x = xa_hat

        self.cyca = xa_hat
        self.cycb = xb_hat
        self.xba = xba
        self.xab = xab

        # Map [-1, 1] pixels to [0, 255] ints and set the alpha byte.
        rgb = tf.cast((self.generator.sample + 1) * 127.5, tf.int32)
        self.generator_int = tf.bitwise.bitwise_or(rgb,
                                                   0xFF000000,
                                                   name='generator_int')
Exemplo n.º 19
0
    def create(self):
        """Build an encoder/generator graph with z, image and cycle losses.

        Samples a uniform latent sized from the encoder output and trains a
        z discriminator (uniform vs. encoded) and an image discriminator over
        a stack chosen by config:
          * ``segments_included``: real, masked regeneration, generated,
            reconstruction and the generator's ``g1x``/``g2x``/``g3x``.
          * ``simple_d``: real and generated.
          * otherwise: real, generated and reconstruction.

        When the generator exposes ``mask``, mask regularizers are added to
        the cycle loss. Sets ``self.trainer``, ``self.generator``,
        ``self.encoder``, ``self.uniform_distribution``, ``self.random_z``
        and related attributes.
        """
        config = self.config
        ops = self.ops

        with tf.device(self.device):
            x_input = tf.identity(self.inputs.x, name='input')

            encoder = self.create_encoder(x_input)
            self.encoder = encoder

            uz_shape = self.ops.shape(encoder.sample)
            # Projections multiply the last dimension by their count.
            uz_shape[-1] = uz_shape[-1] // len(config.encoder.projections)
            # Bound to ``uniform_encoder`` rather than rebinding the
            # ``UniformDistribution`` class name, which made it a local and
            # raised UnboundLocalError on this very call.
            uniform_encoder = UniformDistribution(self, config.encoder, output_shape=uz_shape)
            direction, slider = self.create_controls(self.ops.shape(uniform_encoder.sample))
            z = uniform_encoder.sample + slider * direction

            z_discriminator = self.create_z_discriminator(uniform_encoder.sample, encoder.sample)

            stack_z = z

            generator = self.create_component(config.generator, input=stack_z)
            self.uniform_sample = generator.sample
            x_hat = generator.reuse(encoder.sample)

            # Default so the ``mask_generator`` branch below never hits an
            # unbound name when the generator has no single-channel mask.
            mask = None
            if hasattr(generator, 'mask_single_channel'):
                mask = generator.mask_single_channel

            encoder_loss = self.create_loss(config.eloss or config.loss, z_discriminator, z, encoder, 2)
            if config.segments_included:
                newsample = generator.reuse(stack_z, mask=generator.mask_single_channel)
                stacked = [x_input, newsample, generator.sample, x_hat, generator.g1x, generator.g2x, generator.g3x]
            elif config.simple_d:
                stacked = [x_input, self.uniform_sample]
            else:
                stacked = [x_input, self.uniform_sample, x_hat]

            stacked_xg = ops.concat(stacked, axis=0)
            standard_discriminator = self.create_component(config.discriminator, name='discriminator', input=stacked_xg)
            standard_loss = self.create_loss(config.loss, standard_discriminator, x_input, generator, len(stacked))

            # loss terms
            cycloss = self.create_cycloss(x_input, x_hat)
            z_cycloss = self.create_z_cycloss(uniform_encoder.sample, encoder.sample, encoder, generator)

            if hasattr(generator, 'mask'):  # TODO only segment
                # Penalize mask values away from 0/1 (0.5 - |m - 0.5| peaks at 0.5).
                cycloss_whitening_lambda = config.cycloss_whitening_lambda or 0.01
                cycloss += tf.reduce_mean(tf.reshape(0.5-tf.abs(generator.mask-0.5), [-1]), axis=0) * cycloss_whitening_lambda

            if hasattr(generator, 'mask'):  # TODO only multisegment
                # Mean over axes 1 and 2, then penalize whenever the smallest
                # value along axis 3 falls below c.
                m = tf.reduce_mean(generator.mask, 1, keep_dims=True)
                m = tf.reduce_mean(m, 2, keep_dims=True)
                c = 0.1
                cycloss += (c - tf.minimum(tf.reduce_min(m, 3, keep_dims=True), c))

            trainer = self.create_trainer(cycloss, z_cycloss, encoder, generator, encoder_loss, standard_loss, standard_discriminator, z_discriminator)
            self.session.run(tf.global_variables_initializer())

        self.trainer = trainer
        self.generator = generator
        self.uniform_distribution = uniform_encoder
        self.slider = slider
        self.direction = direction
        self.z = z
        self.z_hat = encoder.sample
        self.x_input = x_input
        self.autoencoded_x = x_hat
        # Map [-1, 1] pixels to [0, 255] ints and set the alpha byte.
        rgb = tf.cast((self.generator.sample+1)*127.5, tf.int32)
        self.generator_int = tf.bitwise.bitwise_or(rgb, 0xFF000000, name='generator_int')
        self.random_z = tf.random_uniform(ops.shape(uniform_encoder.sample), -1, 1, name='random_z')

        if hasattr(generator, 'mask_generator'):
            self.mask_generator = generator.mask_generator
            self.mask = mask
            self.autoencode_mask = generator.mask_generator.sample
            self.autoencode_mask_3_channel = generator.mask
Exemplo n.º 20
0
    def create(self):
        config = self.config
        ops = self.ops

        with tf.device(self.device):
            #x_input = tf.identity(self.inputs.x, name='input')
            xa_input = tf.identity(self.inputs.xa, name='xa_i')
            xb_input = tf.identity(self.inputs.xb, name='xb_i')

            ue = UniformDistribution(self, config.latent)
            ue2 = UniformDistribution(self, config.latent)
            ue3 = UniformDistribution(self, config.latent)
            ue4 = UniformDistribution(self, config.latent)
            if config.same_g:
                gb = self.create_component(config.generator,
                                           input=ue3.sample,
                                           name='a_generator')
                gb = self.create_component(config.generator,
                                           input=ue4.sample,
                                           name='a_generator',
                                           reuse=True)
            else:

                def _append_xa(gan, config, net):
                    _x = gan.inputs.xa
                    s = [int(x) for x in net.get_shape()]
                    shape = [s[1], s[2]]
                    _x = tf.image.resize_images(_x, shape, 1)
                    return _x

                def _append_xb(gan, config, net):
                    _x = gan.inputs.xb
                    s = [int(x) for x in net.get_shape()]
                    shape = [s[1], s[2]]
                    _x = tf.image.resize_images(_x, shape, 1)
                    return _x

                #config.generator['layer_filter'] = _append_xb
                #gb = self.create_component(config.generator, input=ue3.sample, name='a_generator')
                config.generator['layer_filter'] = _append_xa
                gb = self.create_component(config.generator,
                                           input=ue4.sample,
                                           name='b_generator')

            za = gb.controls["z"]
            zb = gb.controls["z"]

            self.uniform_sample = gb.sample

            zua = ue.sample
            zub = ue2.sample

            xa = xa_input
            xb = xb_input

            t0 = tf.concat([xb, xb], axis=3)
            t1 = tf.concat([gb.sample, xb], axis=3)
            t0 = xb
            t1 = gb.sample
            f0 = za
            f1 = zb
            f2 = zub
            stack = [t0, t1]
            stacked = ops.concat(stack, axis=0)
            features = ops.concat([f0, f1], axis=0)

            config.discriminator['layer_filter'] = _append_xa
            d = self.create_component(config.discriminator,
                                      name='d_ab',
                                      input=stacked)

            self.za = za
            self.discriminator = d
            l = self.create_loss(config.loss, d, xb_input, gb.sample,
                                 len(stack))
            loss1 = l
            d1_lambda = config.d1_lambda
            d2_lambda = config.d2_lambda
            d_loss1 = d1_lambda * l.d_loss
            g_loss1 = d1_lambda * l.g_loss

            d_vars1 = d.variables()

            d_loss = l.d_loss
            g_loss = l.g_loss

            metrics = {'g_loss': l.g_loss, 'd_loss': l.d_loss}

            if d2_lambda > -1:
                t0 = xa
                t1 = gb.sample
                f0 = zb
                f1 = za
                stack = [t0, t1]
                stacked = ops.concat(stack, axis=0)
                features = ops.concat([f0, f1], axis=0)

                config.discriminator['layer_filter'] = _append_xa
                d = self.create_component(config.discriminator,
                                          name='d2_ab',
                                          input=stacked,
                                          features=[features])
                self.discriminator2 = d
                l = self.create_loss(config.loss, d, xb_input, gb.sample,
                                     len(stack))

                d_vars2 = d.variables()
                metrics["gb_gloss"] = l.g_loss
                metrics["gb_dloss"] = l.d_loss
                loss2 = l
                g_loss2 = d2_lambda * loss2.g_loss
                d_loss2 = d2_lambda * loss2.d_loss
            else:
                d_loss2 = 0.0
                g_loss2 = 0.0
                d_vars2 = []

            if (config.alpha):
                t0 = zub
                t1 = zb
                netzd = tf.concat(axis=0, values=[t0, t1])
                z_d = self.create_component(config.z_discriminator,
                                            name='z_discriminator',
                                            input=netzd)

                loss3 = self.create_component(config.loss,
                                              discriminator=z_d,
                                              x=xa_input,
                                              generator=gb,
                                              split=2)
                d_vars1 += z_d.variables()
                metrics["za_gloss"] = loss3.g_loss
                metrics["za_dloss"] = loss3.d_loss
                d_loss1 += loss3.d_loss
                g_loss1 += loss3.g_loss

            if (config.ug):
                t0 = xb
                t1 = ugb
                netzd = tf.concat(axis=0, values=[t0, t1])
                z_d = self.create_component(config.discriminator,
                                            name='ug_discriminator',
                                            input=netzd)

                loss3 = self.create_component(config.loss,
                                              discriminator=z_d,
                                              x=xa_input,
                                              generator=gb,
                                              split=2)
                d_vars1 += z_d.variables()
                metrics["za_gloss"] = loss3.g_loss
                metrics["za_dloss"] = loss3.d_loss
                d_loss1 += loss3.d_loss
                g_loss1 += loss3.g_loss

            if (config.alpha2):
                t0 = zua
                t1 = za
                netzd = tf.concat(axis=0, values=[t0, t1])
                z_d = self.create_component(config.z_discriminator,
                                            name='z_discriminator2',
                                            input=netzd)

                loss3 = self.create_component(config.loss,
                                              discriminator=z_d,
                                              x=xa_input,
                                              generator=gb,
                                              split=2)
                d_vars2 += z_d.variables()
                metrics["za_gloss"] = loss3.g_loss
                metrics["za_dloss"] = loss3.d_loss
                d_loss2 += loss3.d_loss
                g_loss2 += loss3.g_loss

            if config.ug2:
                t0 = xa
                t1 = ugb
                netzd = tf.concat(axis=0, values=[t0, t1])
                z_d = self.create_component(config.discriminator,
                                            name='ug_discriminator2',
                                            input=netzd)

                loss3 = self.create_component(config.loss,
                                              discriminator=z_d,
                                              x=xa_input,
                                              generator=gb,
                                              split=2)
                d_vars2 += z_d.variables()
                metrics["za_gloss"] = loss3.g_loss
                metrics["za_dloss"] = loss3.d_loss
                d_loss2 += loss3.d_loss
                g_loss2 += loss3.g_loss

            loss = hc.Config({
                'd_fake':
                loss1.d_fake,
                'd_real':
                loss1.d_real,
                'sample':
                [tf.add_n([d_loss1, d_loss2]),
                 tf.add_n([g_loss1, g_loss2])]
            })
            self._g_vars = gb.variables() + gb.variables()
            self._d_vars = d_vars1 + d_vars2
            self.loss = loss
            self.generator = gb
            trainer = self.create_component(config.trainer)
            self.session.run(tf.global_variables_initializer())

        self.trainer = trainer
        self.latent = hc.Config({'sample': zb})
        self.generator = gb
        self.encoder = gb  # this is the other gan
        self.uniform_distribution = hc.Config({"sample": zb})  #uniform_encoder
        self.zb = zb
        self.z_hat = gb.sample
        self.x_input = xa_input

        self.uga = gb.sample
        self.ugb = gb.sample
        self.cyca = gb.sample
        self.cycb = gb.sample
        self.xba = gb.sample
        self.xab = gb.sample

        rgb = tf.cast((self.generator.sample + 1) * 127.5, tf.int32)
        self.generator_int = tf.bitwise.bitwise_or(rgb,
                                                   0xFF000000,
                                                   name='generator_int')
Exemplo n.º 21
0
    def create(self):
        """Build the video GAN graph, its losses and its trainer.

        Encodes ``self.inputs.frames`` into per-frame latent (``zs``) and
        context (``cs``) sequences, rolls the generator forward from both the
        encoded and the uniformly sampled starting points, scores the real
        frame sequence against the configured generated sequences with one
        discriminator, and creates the trainer over the collected
        generator/discriminator variables.  Many attributes are set on
        ``self`` as a side effect (``zs``, ``cs``, ``x_hats``,
        ``discriminator``, ``loss``, ``trainer``, ``generator``, ...).

        Raises:
            ValueError: if ``config.noise`` is not set (``ec`` has no code
                path without it).
        """
        config = self.config
        ops = self.ops
        self._g_vars = []
        d_vars = []

        with tf.device(self.device):
            def random_t(shape):
                # UniformDistribution emits one copy per projection along the
                # last axis, so request 1/len(projections) of the width.
                shape[-1] //= len(config.z_distribution.projections)
                return UniformDistribution(self, config.z_distribution, output_shape=shape).sample

            def random_like(x):
                shape = self.ops.shape(x)
                return random_t(shape)

            self.frame_count = len(self.inputs.frames)
            self.frames = self.inputs.frames

            dist = UniformDistribution(self, config.z_distribution)
            dist2 = UniformDistribution(self, config.z_distribution)
            dist3 = UniformDistribution(self, config.z_distribution)
            dist4 = UniformDistribution(self, config.z_distribution)
            dist5 = UniformDistribution(self, config.z_distribution)
            uz = self.create_component(config.uz, name='u_to_z', input=dist.sample)
            uc = self.create_component(config.uc, name='u_to_c', input=dist2.sample)
            # Weight-shared (reuse=True) copies of u_to_z / u_to_c on fresh noise.
            uz2 = self.create_component(config.uz, name='u_to_z', input=dist3.sample, reuse=True)
            uc2 = self.create_component(config.uc, name='u_to_c', input=dist4.sample, reuse=True)
            # NOTE(review): uc3 is never read below.
            uc3 = self.create_component(config.uc, name='u_to_c', input=dist5.sample, reuse=True)

            self._g_vars += uz.variables()
            self._g_vars += uc.variables()

            def ec(zt, cp, reuse=True):
                """Map latent ``zt`` and previous context ``cp`` to the next context sample.

                Variables are registered in ``self._g_vars`` (and
                ``self.encoder`` is set) only when ``reuse`` is False.
                """
                if not config.noise:
                    # Every branch that builds ``c`` requires config.noise;
                    # without it ``c.sample`` below raised UnboundLocalError.
                    raise ValueError("ec() requires config.noise to be set")
                randt = random_like(cp)
                if config.proxy:
                    proxy_dist = UniformDistribution(self, config.z_distribution)
                    proxy_c = self.create_component(config.proxy_c, name='rand_ct', input=proxy_dist.sample, reuse=reuse)
                    randt = proxy_c.sample
                    c = self.create_component(config.ec, name='ec', input=zt, features={'ct-1':cp, 'n':randt}, reuse=reuse)
                elif config.proxyrand:
                    c = self.create_component(config.ec, name='ec', input=zt, features={'ct-1':cp, 'n':random_like(zt)}, reuse=reuse)
                elif config.proxyrand2:
                    c = self.create_component(config.ec, name='ec', input=zt, features={'ct-1':(cp+randt*0.01), 'n':random_like(zt) }, reuse=reuse)
                else:
                    c = self.create_component(config.ec, name='ec', input=zt, features={'ct-1':cp, 'n':tf.zeros_like(zt)}, reuse=reuse)

                if not reuse:
                    if config.proxy:
                        self._g_vars += proxy_c.variables()
                    self._g_vars += c.variables()
                    self.encoder = c
                return c.sample

            def ez(ft, zp, reuse=True):
                """Map frame ``ft`` and previous latent ``zp`` to the next latent sample."""
                z = self.create_component(config.ez, name='ez', input=ft, features=[zp], reuse=reuse)
                if not reuse:
                    self._g_vars += z.variables()
                return z.sample

            def build_g(zt, ct, reuse=True):
                """Generate a frame from context ``ct`` with latent ``zt`` as a feature."""
                g = self.create_component(config.generator, name='generator', input=ct, features=[zt], reuse=reuse)
                if not reuse:
                    self._g_vars += g.variables()
                return g.sample

            def encode_frames(fs, c0, z0, reuse=True):
                """Encode frames ``fs`` step by step starting from ``(c0, z0)``.

                Returns ``(cs, zs, x_hats)``: the context and latent sequences
                (each starting with its seed) and the generator output per step.
                """
                cs = [c0]
                zs = [z0]
                x_hats = [build_g(zs[-1], cs[-1], reuse=reuse)]
                for i, frame in enumerate(fs):
                    # Only the first ez/ec call may create variables.
                    _reuse = reuse or (i != 0)
                    z = ez(frame, zs[-1], reuse=_reuse)
                    c = ec(z, cs[-1], reuse=_reuse)
                    x_hat = build_g(z, c, reuse=True)
                    zs.append(z)
                    cs.append(c)
                    x_hats.append(x_hat)
                return cs, zs, x_hats

            def build_sim(z0, c0, steps, reuse=True):
                """Roll the generator forward ``steps`` times from ``(z0, c0)``.

                Each step re-encodes the previous generated frame.  Returns
                ``(gs, cs, zs)``.
                """
                zs = [z0]
                cs = [c0]
                gs = [build_g(zs[-1], cs[-1], reuse=reuse)]
                for i in range(steps):
                    _reuse = reuse or (i != 0)
                    z = ez(gs[-1], zs[-1], reuse=_reuse)
                    c = ec(z, cs[-1], reuse=_reuse)
                    g = build_g(z, c, reuse=True)
                    zs.append(z)
                    cs.append(c)
                    gs.append(g)

                return gs, cs, zs

            def rotate(first, second, offset=None):
                """Concatenate sliding windows over ``first`` extended by ``second``.

                Reads ``axis`` from the enclosing scope at call time; it is
                assigned below before the first call.
                """
                rotations = [tf.concat(first[:offset], axis=axis)]
                elem = first
                for e in second:
                    elem = elem[1:]+[e]
                    rotations.append(tf.concat(elem[:offset], axis=axis))
                return rotations

            def disc(metric, name, _inputs, _features, reuse=False):
                """Build discriminator ``config[name]`` over the stacked inputs and its loss.

                Returns ``(loss, discriminator_variables, discriminator)``.
                """
                _is = tf.concat(_inputs, axis=0)
                _fs = tf.concat(_features, axis=0)
                component = self.create_component(config[name], name=name, input=_is, features=[_fs], reuse=reuse)
                l2 = self.create_loss(config.loss, component, None, None, len(_inputs))
                # NOTE(review): d_loss and g_loss are reported under the same
                # metric name; depending on add_metric one may hide the other.
                self.add_metric(metric, l2.d_loss)
                self.add_metric(metric, l2.g_loss)
                return l2, component.variables(), component

            def mi(metric, name, _inputs, _features):
                """Combine three discriminator passes into an information-bottleneck style loss.

                Returns ``(g_loss, d_loss, discriminator_variables)``.  Not
                called in this method.
                """
                beta = config.bottleneck_beta or 1
                ib_2_c = config.ib_2_c or 1
                # ``disc`` returns a 3-tuple; the original unpacked four values
                # here and raised ValueError.
                l1, d_vars, _ = disc(metric+'1', name, _inputs, _features)
                l2, _, _ = disc(metric+'2', name, _inputs, _features, reuse=True)
                l3, _, _ = disc(metric+'3', name, _inputs, _features, reuse=True)
                # NOTE(review): all three passes see identical inputs.  Unused
                # zeroed copies of the inputs/features in the original suggest
                # passes 2 and 3 were meant to use them — confirm.
                dls = ib_2_c * beta * tf.add_n([l1.d_loss, -l2.d_loss, -l3.d_loss])
                gls = ib_2_c * beta * tf.add_n([l1.g_loss, -l2.g_loss, -l3.g_loss])

                return gls, dls, d_vars

            if config.randd:
                cs, zs, x_hats = encode_frames(self.frames, uc2.sample, tf.zeros_like(uz2.sample), reuse=False)
            elif config.zerod:
                cs, zs, x_hats = encode_frames(self.frames, tf.zeros_like(uc2.sample), tf.zeros_like(uz2.sample), reuse=False)
            else:
                cs, zs, x_hats = encode_frames(self.frames, uc2.sample, uz2.sample, reuse=False)
            extra_frames = config.extra_frames or 2
            self.zs = zs
            self.cs = cs
            # The original's config.zerod branches were identical here.
            ugs, ucs, uzs = build_sim(uz.sample, uc.sample, len(self.frames))
            alt_gs, alt_cs, alt_zs = build_sim(zs[1], cs[1], len(self.frames))
            self.ucs = ucs
            self.alt_cs = alt_cs
            self.alt_zs = alt_zs
            ugs_next, ucs_next, uzs_next = build_sim(uzs[-1], ucs[-1], len(self.frames))
            re_ucs_next, re_uzs_next, re_ugs_next = encode_frames(ugs_next[1:len(self.frames)], ucs_next[0], uzs_next[0])
            gs_next, cs_next, zs_next = build_sim(zs[-1], cs[-1], len(self.frames)+extra_frames)
            re_ucs, re_uzs, ugs_hat = encode_frames(ugs[1:len(self.frames)], ucs[0], uzs[0])
            re_cs, re_zs, re_ugs = encode_frames(x_hats[1:len(self.frames)], cs[0], zs[0])
            re_cs_next, re_zs_next, re_gs_next = encode_frames(gs_next[1:len(self.frames)], cs_next[0], zs_next[0])
            self.x_hats = x_hats
            axis = len(ops.shape(zs[1]))-1
            # NOTE(review): t1, t2 and t3 are never read below, and t0 is
            # reassigned right after; kept so the built graph is unchanged.
            t0 = tf.concat(zs[1:-1], axis=axis)
            t1 = tf.concat(uzs[1:-1], axis=axis)
            t2 = tf.concat(zs_next[1:len(cs)-1], axis=axis)
            if config.manifold_guided:
                t1 = tf.concat(re_uzs[1:], axis=axis)
                t2 = tf.concat(re_zs_next[1:len(cs)-1], axis=axis)
            t3 = re_uzs_next

            # Real sequence: every frame after the first, with its context.
            t0 = tf.concat(self.frames[1:], axis=axis)
            f0 = tf.concat(cs[1:-1], axis=axis)
            self.x0 = t0

            stack = [t0]
            features = [f0]
            if config.encode_x_hat:
                stack += [tf.concat(x_hats[2:], axis=axis)]
                features += [tf.concat(cs[1:-1], axis=axis)]

            if config.encode_alternate_path:
                stack.append(tf.concat([self.frames[1]]+alt_gs[1:-2], axis=axis))
                features.append(tf.concat(alt_cs[:-2], axis=axis))

            if config.encode_ug:
                self.g0 = tf.concat(ugs[1:-1], axis=axis)
                self.c0 = tf.concat(ucs[1:-1], axis=axis)
                stack.append(self.g0)
                features.append(self.c0)
            if config.encode_re_ug:
                stack.append(tf.concat(re_ugs[1:], axis=axis))
                features.append(tf.concat(re_ucs[1:], axis=axis))

            if config.encode_forward:
                stack += rotate(self.frames[2:]+[gs_next[0]], gs_next[1:])
                features += rotate(cs[2:], cs_next[1:])

            self.video_generator_last_z = uzs[0]
            self.video_generator_last_c = ucs[0]
            self.gs_next = gs_next
            ztn = uzs[1]
            ctn = ucs[1]
            self.video_generator_last_zn = ztn
            self.video_generator_last_cn = ctn
            self.c_drift = self.ops.squash(tf.abs(ucs[1]-ucs[0]))
            gen = hc.Config({"sample":ugs[0]})

            # Named so it does not shadow the ``disc`` helper above.
            l, dvs, discriminator = disc('m1', 'discriminator', stack, features)
            self.discriminator = discriminator
            g_loss = l.g_loss
            d_loss = l.d_loss
            d_vars += dvs

            if config.vae:
                if hasattr(self, "variational"):
                    for i, var in enumerate(self.variational):
                        mu, sigma = var
                        eps = 1e-8
                        lam = config.vae_lambda or 0.01
                        # NOTE(review): the usual Gaussian KL term is
                        # mu^2 + sigma^2 - log(sigma^2) - 1 (note +sigma^2),
                        # and it is normally added to the loss, not subtracted.
                        # Left unchanged pending confirmation of intent.
                        latent_loss = lam*(0.5 *self.ops.squash(tf.square(mu)-tf.square(sigma) - tf.log(tf.square(sigma)+eps) - 1, tf.reduce_sum ))
                        self.add_metric("vae"+str(i), latent_loss)
                        g_loss -= latent_loss

            gx_sample = gen.sample
            gy_sample = gen.sample
            gx = hc.Config({"sample":gx_sample})
            gy = hc.Config({"sample":gy_sample})

            # Keeps the first 3 entries of the last axis of the sample.
            last_frame = tf.slice(gy_sample, [0,0,0,0], [-1, -1, -1, 3])
            self.y = hc.Config({"sample":last_frame})
            self.gy = self.y
            self.gx = self.y
            self.uniform_sample = gen.sample

            self.preview = tf.concat(self.inputs.frames[:-1] + [gen.sample], axis=1)

            lossa = hc.Config({'sample': [d_loss, g_loss], 'd_fake': l.d_fake, 'd_real': l.d_real, 'config': l.config})
            self.loss = lossa
            self._d_vars = d_vars
            trainer = self.create_component(config.trainer, loss = lossa, g_vars = self._g_vars, d_vars = d_vars)
            self.session.run(tf.global_variables_initializer())

        self.trainer = trainer
        self.generator = gx
        self.z_hat = gy.sample
        self.x_input = self.inputs.frames[0]

        self.uga = self.y.sample
        self.uniform_distribution = dist
Exemplo n.º 22
0
    def create(self):
        """Build an encoder/generator GAN with a uniform latent prior and its trainer.

        Encodes ``self.inputs.x`` to ``encoder.sample``, generates from a
        uniform sample shifted by ``slider * direction``, and scores a joint
        (x, z) discriminator plus a z-discriminator.  Sets the trainer,
        generator, latent controls and preview tensors on ``self``.
        """
        config = self.config
        ops = self.ops

        with tf.device(self.device):
            x_input = tf.identity(self.inputs.x, name='input')

            # q(z|x)
            encoder = self.create_encoder(x_input)

            self.encoder = encoder
            z_shape = self.ops.shape(encoder.sample)

            # UniformDistribution emits one copy per projection along the last
            # axis, so request 1/len(projections) of the encoder width.
            uz_shape = z_shape
            uz_shape[-1] = uz_shape[-1] // len(config.encoder.projections)
            # Bound to a new name: assigning to ``UniformDistribution`` itself
            # made the class name local to this method, so the call raised
            # UnboundLocalError.
            uniform_distribution = UniformDistribution(self, config.encoder, output_shape=uz_shape)
            direction, slider = self.create_controls(self.ops.shape(uniform_distribution.sample))
            z = uniform_distribution.sample + slider * direction

            generator = self.create_component(config.generator, input=z)
            self.uniform_sample = generator.sample
            # ``mask`` was previously undefined (NameError below).
            # NOTE(review): captured before ``generator.reuse`` on the
            # assumption that it should describe the uniform sample, while
            # ``autoencode_mask`` (read after reuse) describes x_hat — confirm.
            mask = generator.mask_generator.sample if hasattr(generator, 'mask_generator') else None
            x_hat = generator.reuse(encoder.sample)

            # z = random uniform
            # z_hat = z of x
            # g = random generated
            # x_input

            stacked_xg = ops.concat([generator.sample, x_input], axis=0)
            stacked_zs = ops.concat([z, encoder.sample], axis=0)
            standard_discriminator = self.create_component(config.discriminator, name='discriminator', input=stacked_xg, features=[stacked_zs])
            z_discriminator = self.create_z_discriminator(uniform_distribution.sample, encoder.sample)
            standard_loss = self.create_loss(config.loss, standard_discriminator, x_input, generator, 2)

            encoder_loss = self.create_loss(config.eloss or config.loss, z_discriminator, z, encoder, 2)

            trainer = self.create_trainer(None, None, encoder, generator, encoder_loss, standard_loss, standard_discriminator, z_discriminator)
            self.session.run(tf.global_variables_initializer())

        self.trainer = trainer
        self.generator = generator
        # ``uniform_encoder`` was undefined here (NameError); the latent
        # distribution built above is presumably the intended object.
        self.uniform_distribution = uniform_distribution
        self.slider = slider
        self.direction = direction
        self.z = z
        self.z_hat = encoder.sample
        self.x_input = x_input
        self.autoencoded_x = x_hat
        rgb = tf.cast((self.generator.sample+1)*127.5, tf.int32)
        self.generator_int = tf.bitwise.bitwise_or(rgb, 0xFF000000, name='generator_int')
        self.random_z = tf.random_uniform(ops.shape(uniform_distribution.sample), -1, 1, name='random_z')

        if hasattr(generator, 'mask_generator'):
            self.mask_generator = generator.mask_generator
            self.mask = mask
            self.autoencode_mask = generator.mask_generator.sample
            self.autoencode_mask_3_channel = generator.mask
Exemplo n.º 23
0
import hyperchamber as hc
import numpy as np
import hypergan as hg
from hypergan.distributions.uniform_distribution import UniformDistribution
from hypergan.gan_component import ValidationException

from unittest.mock import MagicMock
from tests.mocks import MockDiscriminator, mock_gan

gan = mock_gan()
# Shared fixture: a UniformDistribution whose config carries an extra
# ``test`` flag alongside the sampling settings; read by test_config.
distribution = UniformDistribution(
    gan,
    {"test": True, "z": 2, "min": 0, "max": 1},
)
class TestUniformDistribution:
    """Tests for ``hypergan.distributions.uniform_distribution.UniformDistribution``."""

    def test_config(self):
        # Config keys passed at construction are readable as attributes.
        assert distribution.config.test == True

    def test_projection(self):
        # A single identity projection keeps the sample width equal to ``z``.
        subject = UniformDistribution(gan, {
            "projections": [hg.distributions.uniform_distribution.identity],
            "z": 2,
            "min": 0,
            "max": 1,
        })
        sample = subject.create()
        width = subject.ops.shape(sample)[1]
        assert width == 2
Exemplo n.º 24
0
    def create(self):
        """Build the two-generator a/b alignment GAN, its losses and trainer.

        ``ga`` is built on ``xb`` (its output is kept as ``xba``) and ``gb`` on
        ``xa`` (kept as ``xab``); with ``config.same_g`` both share one
        generator.  Real, translated and uniformly-sampled images are paired
        by concatenation along axis 3 and scored by one discriminator.  The
        ``config.mess*`` flags select alternative pairings, and
        ``config.alpha`` adds a latent-space discriminator.  Sets the
        trainer, generators and preview tensors on ``self``.
        """
        config = self.config
        ops = self.ops

        with tf.device(self.device):
            xa_input = tf.identity(self.inputs.xa, name='xa_i')
            xb_input = tf.identity(self.inputs.xb, name='xb_i')

            if config.same_g:
                ga = self.create_component(config.generator,
                                           input=xb_input,
                                           name='a_generator')
                # Wrap the shared generator so gb exposes the same fields.
                gb = hc.Config({
                    "sample": ga.reuse(xa_input),
                    "controls": {
                        "z": ga.controls['z']
                    },
                    "reuse": ga.reuse
                })
            else:
                ga = self.create_component(config.generator,
                                           input=xb_input,
                                           name='a_generator')
                gb = self.create_component(config.generator,
                                           input=xa_input,
                                           name='b_generator')

            za = ga.controls["z"]
            zb = gb.controls["z"]

            self.uniform_sample = ga.sample

            xba = ga.sample
            xab = gb.sample
            # Cycle reconstructions.
            xa_hat = ga.reuse(gb.sample)
            xb_hat = gb.reuse(ga.sample)

            # UniformDistribution emits one copy per projection along the last
            # axis, so request 1/len(projections) of the z width.
            z_shape = self.ops.shape(za)
            uz_shape = z_shape
            uz_shape[-1] = uz_shape[-1] // len(
                config.z_distribution.projections)
            ue = UniformDistribution(self,
                                     config.z_distribution,
                                     output_shape=uz_shape)
            ue2 = UniformDistribution(self,
                                      config.z_distribution,
                                      output_shape=uz_shape)
            # NOTE(review): ue3 and ue4 are never read below.
            ue3 = UniformDistribution(self,
                                      config.z_distribution,
                                      output_shape=uz_shape)
            ue4 = UniformDistribution(self,
                                      config.z_distribution,
                                      output_shape=uz_shape)

            zua = ue.sample
            zub = ue2.sample

            # Generator outputs driven only by uniform z (zero image input).
            uga = ga.reuse(tf.zeros_like(xb_input),
                           replace_controls={"z": zua})
            ugb = gb.reuse(tf.zeros_like(xa_input),
                           replace_controls={"z": zub})

            xa = xa_input
            xb = xb_input

            t0 = ops.concat([xb, xa], axis=3)
            t1 = ops.concat([ugb, uga], axis=3)
            t2 = ops.concat([gb.sample, ga.sample], axis=3)
            f0 = ops.concat([za, zb], axis=3)
            f1 = ops.concat([zub, zua], axis=3)
            f2 = ops.concat([zb, za], axis=3)
            features = ops.concat([f0, f1, f2], axis=0)
            stack = [t0, t1, t2]

            if config.mess2:
                xbxa = ops.concat([xb_input, xa_input], axis=3)
                gbga = ops.concat([gb.sample, ga.sample], axis=3)
                fa = ops.concat([za, zb], axis=3)
                fb = ops.concat([za, zb], axis=3)
                features = ops.concat([fa, fb], axis=0)
                stack = [xbxa, gbga]

            if config.mess6:
                t0 = ops.concat([xb, xa], axis=3)

                t1 = ops.concat([gb.sample, uga], axis=3)
                t2 = ops.concat([gb.sample, xa], axis=3)
                t3 = ops.concat([xb, ga.sample], axis=3)
                t4 = ops.concat([ugb, ga.sample], axis=3)
                features = None
                stack = [t0, t1, t2, t3, t4]

            if config.mess7:
                ugb = gb.reuse(tf.zeros_like(xa_input),
                               replace_controls={"z": zua})
                t0 = ops.concat([xb, ga.sample], axis=3)
                t1 = ops.concat([ugb, uga], axis=3)
                t2 = ops.concat([gb.sample, xa], axis=3)
                features = None
                stack = [t0, t1, t2]

            if config.mess8:
                ugb = gb.reuse(tf.zeros_like(xa_input),
                               replace_controls={"z": zua})
                uga2 = ga.reuse(tf.zeros_like(xa_input),
                                replace_controls={"z": zub})
                ugb2 = gb.reuse(tf.zeros_like(xa_input),
                                replace_controls={"z": zub})
                t0 = ops.concat([xb, ga.sample, xa, gb.sample], axis=3)
                t1 = ops.concat([ugb, uga, uga2, ugb2], axis=3)
                features = None
                stack = [t0, t1]

            if config.mess10:
                t0 = ops.concat([xb, xa], axis=3)

                ugb = gb.reuse(tf.zeros_like(xa_input),
                               replace_controls={"z": zua})
                t1 = ops.concat([ugb, uga], axis=3)
                t2 = ops.concat([gb.sample, xa], axis=3)
                t3 = ops.concat([xb, ga.sample], axis=3)
                features = None
                stack = [t0, t1, t2, t3]

            if config.mess11:
                t0 = ops.concat([xa, xb, ga.sample, gb.sample], axis=3)
                # NOTE(review): ugbga / ugagb are built but never used.
                ugbga = ga.reuse(ugb)
                ugagb = gb.reuse(uga)

                t1 = ops.concat([ga.sample, gb.sample, uga, ugb], axis=3)
                features = None
                stack = [t0, t1]

            if config.mess12:
                t0 = ops.concat([xb, xa], axis=3)
                t2 = ops.concat([gb.sample, ga.sample], axis=3)
                f0 = ops.concat([za, zb], axis=3)
                f2 = ops.concat([zua, zub], axis=3)
                features = ops.concat([f0, f2], axis=0)
                stack = [t0, t2]

            if config.mess13:
                ugb = gb.reuse(tf.zeros_like(xa_input),
                               replace_controls={"z": zua})
                features = None
                t0 = ops.concat([xa, gb.sample], axis=3)
                t1 = ops.concat([ga.sample, xb], axis=3)
                t2 = ops.concat([uga, ugb], axis=3)
                stack = [t0, t1, t2]

            if config.mess14:
                features = None
                t0 = ops.concat([xa, gb.sample], axis=3)
                t1 = ops.concat([ga.sample, xb], axis=3)
                stack = [t0, t1]

            stacked = ops.concat(stack, axis=0)
            d = self.create_component(config.discriminator,
                                      name='alia_discriminator',
                                      input=stacked,
                                      features=[features])
            l = self.create_loss(config.loss, d, xa_input, ga.sample,
                                 len(stack))

            # Copy so the in-place ``+=`` below cannot mutate a list owned by
            # the discriminator component.
            d_vars = list(d.variables())
            if config.same_g:
                g_vars = ga.variables()
            else:
                g_vars = ga.variables() + gb.variables()

            d_loss = l.d_loss
            g_loss = l.g_loss

            metrics = {'g_loss': l.g_loss, 'd_loss': l.d_loss}

            if config.alpha:
                t0 = zua
                t1 = za
                t2 = zb
                netzd = tf.concat(axis=0, values=[t0, t1, t2])
                z_d = self.create_component(config.z_discriminator,
                                            name='z_discriminator',
                                            input=netzd)

                # NOTE(review): netzd stacks three tensors but the loss is
                # created with split=2 — confirm this is intended.
                lz = self.create_component(config.loss,
                                           discriminator=z_d,
                                           x=xa_input,
                                           generator=ga,
                                           split=2)
                d_loss += lz.d_loss
                g_loss += lz.g_loss
                d_vars += z_d.variables()
                metrics["a_gloss"] = lz.g_loss
                metrics["a_dloss"] = lz.d_loss

            if config.mess13:
                t0 = ops.concat([xb, ga.sample], axis=3)
                t1 = ops.concat([gb.sample, xa], axis=3)
                t2 = ops.concat([ugb, uga], axis=3)
                stack = [t0, t1, t2]
                features = None
                stacked = tf.concat(axis=0, values=stack)
                d2 = self.create_component(config.discriminator,
                                           name='align_2',
                                           input=stacked,
                                           features=[features])
                lz = self.create_loss(config.loss, d2, xa_input, ga.sample,
                                      len(stack))
                d_vars += d2.variables()

                d_loss += lz.d_loss
                g_loss += lz.g_loss
                metrics["mess13_g"] = lz.g_loss
                metrics["mess13_d"] = lz.d_loss

            if config.mess14:
                t0 = ops.concat([xb, xa], axis=3)
                t2 = ops.concat([ugb, uga], axis=3)
                stack = [t0, t2]
                features = None
                stacked = tf.concat(axis=0, values=stack)
                d3 = self.create_component(config.discriminator,
                                           name='align_3',
                                           input=stacked,
                                           features=[features])
                lz = self.create_loss(config.loss, d3, xa_input, ga.sample,
                                      len(stack))
                d_vars += d3.variables()

                d_loss += lz.d_loss
                g_loss += lz.g_loss
                metrics["mess14_g"] = lz.g_loss
                metrics["mess14_d"] = lz.d_loss

            loss = hc.Config({'sample': [d_loss, g_loss], 'metrics': metrics})
            trainer = ConsensusTrainer(self,
                                       config.trainer,
                                       loss=loss,
                                       g_vars=g_vars,
                                       d_vars=d_vars)
            self.session.run(tf.global_variables_initializer())

        self.trainer = trainer
        self.generator = ga
        self.encoder = gb  # this is the other gan
        self.uniform_distribution = hc.Config({"sample": zb})  #uniform_encoder
        self.zb = zb
        self.z_hat = gb.sample
        self.x_input = xa_input
        self.autoencoded_x = xa_hat

        self.cyca = xa_hat
        self.cycb = xb_hat
        self.xba = xba
        self.xab = xab
        self.uga = uga
        self.ugb = ugb

        # Map [-1, 1] samples to 0..255 and set the 0xFF000000 bits.
        rgb = tf.cast((self.generator.sample + 1) * 127.5, tf.int32)
        self.generator_int = tf.bitwise.bitwise_or(rgb,
                                                   0xFF000000,
                                                   name='generator_int')
Exemplo n.º 25
0
    def create(self):
        """Build the ALI-style graph: latent prior, generator, encoder and losses.

        Everything is constructed under ``tf.device(self.device)``:

        * a ``UniformDistribution`` latent, shifted by ``slider * direction``
          from ``self.create_controls``. When ``config.u_to_z`` is falsy, its
          output shape is derived from the encoder output: the last dimension
          is divided by ``len(config.z_distribution.projections)``.
        * the ``u_to_z`` component and the ``generator`` fed from it.
        * an encoder over ``self.inputs.x``, plus re-encode / re-generate paths.
        * three passes of one shared joint (x, z) ``discriminator`` plus the
          ``x_discriminator`` / ``z_discriminator`` components. The latter two
          are always built, but they only contribute losses and trainable
          variables when neither ``config.alternate`` nor
          ``config.mutual_only`` is set.

        Sets ``uniform_distribution``, ``latent``, ``generator``, ``encoder``,
        ``reencode_g``, ``uniform_sample``, ``discriminator``, ``loss``,
        ``_g_vars`` and ``_d_vars`` on ``self``, and records the ``ld`` / ``lg``
        metrics. Returns ``None``.
        """
        config = self.config

        with tf.device(self.device):
            # Only the op named 'input' is kept; its value is not read below.
            tf.identity(self.inputs.x, name='input')

            encoder = None
            if config.u_to_z:
                latent = UniformDistribution(self, config.z_distribution)
            else:
                # The latent shape comes from the encoder output, so the encoder
                # must exist before the latent in this branch.
                encoder = self.create_encoder(self.inputs.x)
                uz_shape = list(self.ops.shape(encoder.sample))
                uz_shape[-1] = uz_shape[-1] // len(config.z_distribution.projections)
                latent = UniformDistribution(self, config.z_distribution,
                                             output_shape=uz_shape)
            self.uniform_distribution = latent
            self.latent = latent
            direction, slider = self.create_controls(self.ops.shape(latent.sample))
            z = latent.sample + slider * direction

            u_to_z = self.create_component(config.u_to_z, name='u_to_z', input=z)
            generator = self.create_component(config.generator, input=u_to_z.sample,
                                              name='generator')
            self.generator = generator

            if encoder is None:
                encoder = self.create_encoder(self.inputs.x)
            self.encoder = encoder

            reencode_u_to_z = self.create_encoder(generator.sample, reuse=True)
            reencode_u_to_z_to_g = self.create_component(
                config.generator, input=reencode_u_to_z.sample, name='generator',
                reuse=True)
            self.reencode_g = reencode_u_to_z_to_g

            # NOTE(review): this x -> z -> x -> z cycle is built, but its result
            # is never read in this method.
            x_hat = self.create_component(config.generator, input=encoder.sample,
                                          reuse=True, name='generator').sample
            self.create_encoder(x_hat, reuse=True)
            self.uniform_sample = generator.sample

            d_vars = []
            g_vars = generator.variables() + encoder.variables()
            g_vars += u_to_z.variables()

            def ali(*pairs, reuse=False):
                """Score ``[x, z]`` pairs with the shared joint discriminator.

                The x parts of all pairs are stacked along axis 0 as the input,
                and the z parts likewise as ``features``. The loss is split into
                ``len(pairs)`` groups.
                """
                xs = [x_part for x_part, _ in pairs]
                zs = [z_part for _, z_part in pairs]
                xs = tf.concat(xs, axis=0)
                zs = tf.concat(zs, axis=0)
                discriminator = self.create_component(
                    config.discriminator, name='discriminator', input=xs,
                    features=[zs], reuse=reuse)
                loss = self.create_loss(config.loss, discriminator, None, None,
                                        len(pairs))
                return loss, discriminator

            def single_d(name, stack):
                """Run the discriminator configured under ``config[name]`` on ``stack``."""
                stacked = tf.concat(stack, axis=0)
                discriminator = self.create_component(config[name], name=name,
                                                      input=stacked)
                loss = self.create_loss(config.loss, discriminator, None, None,
                                        len(stack))
                return loss, discriminator

            x = self.inputs.x
            l1, d1 = ali([x, encoder.sample],
                         [generator.sample, u_to_z.sample],
                         [reencode_u_to_z_to_g.sample, reencode_u_to_z.sample])
            # Same discriminator, with the z half zeroed out...
            l2, _ = ali([x, tf.zeros_like(encoder.sample)],
                        [generator.sample, tf.zeros_like(u_to_z.sample)],
                        [reencode_u_to_z_to_g.sample,
                         tf.zeros_like(reencode_u_to_z.sample)],
                        reuse=True)
            # ...and with the x half zeroed out.
            l3, _ = ali([tf.zeros_like(x), encoder.sample],
                        [tf.zeros_like(generator.sample), u_to_z.sample],
                        [tf.zeros_like(reencode_u_to_z_to_g.sample),
                         reencode_u_to_z.sample],
                        reuse=True)

            l4, d4 = single_d('x_discriminator',
                              [x, generator.sample, reencode_u_to_z_to_g.sample])
            l5, d5 = single_d('z_discriminator',
                              [encoder.sample, u_to_z.sample, reencode_u_to_z.sample])
            self.discriminator = d1
            self.loss = l1

            # Only a missing beta falls back to the default; 0 is a valid weight.
            beta = 0.9 if config.beta is None else config.beta

            # Precedence matches the historical overrides: mutual_only wins over
            # alternate, which wins over the default weighting.
            if config.mutual_only:
                d_losses = [2 * l1.d_loss - l2.d_loss - l3.d_loss]
                g_losses = [2 * l1.g_loss - l2.g_loss - l3.g_loss]
            elif config.alternate:
                d_losses = [beta * (l1.d_loss - l2.d_loss - l3.d_loss)
                            + l2.d_loss + 2 * l3.d_loss]
                g_losses = [beta * (l1.g_loss - l2.g_loss - l3.g_loss)
                            + l2.g_loss + 2 * l3.g_loss]
            else:
                d_losses = [beta * (l1.d_loss - l2.d_loss - l3.d_loss)
                            + l4.d_loss + 2 * l5.d_loss]
                g_losses = [beta * (l1.g_loss - l2.g_loss - l3.g_loss)
                            + l4.g_loss + 2 * l5.g_loss]

            self.add_metric("ld", d_losses[0])
            self.add_metric("lg", g_losses[0])

            if config.mutual_only or config.alternate:
                trained_discriminators = [d1]
            else:
                trained_discriminators = [d1, d4, d5]
            for disc in trained_discriminators:
                d_vars += disc.variables()

            self._g_vars = g_vars
            self._d_vars = d_vars

            # NOTE(review): this Config is neither stored nor returned here;
            # confirm whether it was meant to feed a trainer.
            loss = hc.Config({
                'd_fake': l1.d_fake,
                'd_real': l1.d_real,
                'sample': [tf.add_n(d_losses), tf.add_n(g_losses)]
            })
Exemplo n.º 26
0
    def create(self):
        """Build the aligned two-domain GAN graph and its consensus trainer.

        Generators:

        * ``ga`` (``a_generator``) is fed ``xb``.
        * ``gb`` (``b_generator``) is fed ``xa``. When ``config.same_g`` is set,
          ``gb`` instead reuses ``ga``'s weights.
        * ``ga2`` / ``gb2`` (``ua_generator`` / ``ub_generator``) run on zeros
          with z drawn from uniform distributions, giving ``uga`` / ``ugb``.

        Discriminators:

        * A joint (image, z) discriminator ``d_ab`` is trained on
          ``[xb, gb(xa), gb(uga)]``.
        * Extra discriminators are added by ``config.cyc``, ``config.alpha``,
          ``config.ug``, ``config.alpha2``, ``config.ug2`` and ``config.ug3``.
        * Each extra discriminator reports its own ``<prefix>_gloss`` /
          ``<prefix>_dloss`` metric (``za`` for ``alpha``), so enabling
          several no longer hides all but the last.

        All losses are summed into one ``ConsensusTrainer`` wrapped in a
        ``MultiTrainerTrainer``.

        Side effects: runs the global-variable initializer in ``self.session``
        and sets ``trainer``, ``generator``, ``encoder``,
        ``uniform_distribution``, ``zb``, ``z_hat``, ``x_input``,
        ``autoencoded_x``, ``cyca``, ``cycb``, ``xba``, ``xab``, ``uga``,
        ``ugb``, ``uniform_sample`` and ``generator_int`` on ``self``.
        Returns ``None``.
        """
        config = self.config
        ops = self.ops

        with tf.device(self.device):
            xa_input = tf.identity(self.inputs.xa, name='xa_i')
            xb_input = tf.identity(self.inputs.xb, name='xb_i')

            ga = self.create_component(config.generator,
                                       input=xb_input,
                                       name='a_generator')
            if config.same_g:
                # gb shares ga's weights; it owns no variables of its own.
                gb = hc.Config({
                    "sample": ga.reuse(xa_input),
                    "controls": {
                        "z": ga.controls['z']
                    },
                    "reuse": ga.reuse
                })
            else:
                gb = self.create_component(config.generator,
                                           input=xa_input,
                                           name='b_generator')

            za = ga.controls["z"]
            zb = gb.controls["z"]

            self.uniform_sample = ga.sample

            xba = ga.sample
            xab = gb.sample
            # NOTE(review): ga(gb(xa)) is built, but xa_hat is taken from
            # ugbprime further down; this result is not read.
            ga.reuse(gb.sample)
            xb_hat = gb.reuse(ga.sample)

            uz_shape = list(self.ops.shape(za))
            uz_shape[-1] = uz_shape[-1] // len(
                config.z_distribution.projections)
            ue = UniformDistribution(self,
                                     config.z_distribution,
                                     output_shape=uz_shape)
            ue2 = UniformDistribution(self,
                                      config.z_distribution,
                                      output_shape=uz_shape)
            # NOTE(review): the next two distributions are never read; they are
            # kept so graph construction is unchanged.
            UniformDistribution(self, config.z_distribution, output_shape=uz_shape)
            UniformDistribution(self, config.z_distribution, output_shape=uz_shape)

            zua = ue.sample
            zub = ue2.sample

            ga2 = self.create_component(config.generator,
                                        input=tf.zeros_like(xb_input),
                                        name='ua_generator')
            gb2 = self.create_component(config.generator,
                                        input=tf.zeros_like(xb_input),
                                        name='ub_generator')
            uga = ga2.reuse(tf.zeros_like(xb_input),
                            replace_controls={"z": zua})
            ugb = gb2.reuse(tf.zeros_like(xb_input),
                            replace_controls={"z": zub})

            ugbprime = gb.reuse(uga)
            ugbzb = gb.controls['z']

            xa = xa_input
            xb = xb_input

            stack = [xb, gb.sample, ugbprime]
            stacked = ops.concat(stack, axis=0)
            features = ops.concat([za, zb, ugbzb], axis=0)

            xa_hat = ugbprime

            d = self.create_component(config.discriminator,
                                      name='d_ab',
                                      input=stacked,
                                      features=[features])
            l = self.create_loss(config.loss, d, xa_input, ga.sample,
                                 len(stack))
            d_loss1 = l.d_loss
            g_loss1 = l.g_loss

            # Copy so later in-place extends never mutate the list held by ``d``.
            d_vars1 = list(d.variables())
            g_vars1 = list(ga2.variables()) + list(gb2.variables())
            if not config.same_g:
                # The same_g wrapper has no variables() of its own.
                g_vars1 += gb.variables()
            g_vars1 += ga.variables()

            metrics = {'g_loss': l.g_loss, 'd_loss': l.d_loss}

            if config.cyc:
                stack = [xa, ga.sample]
                stacked = ops.concat(stack, axis=0)
                features = ops.concat([zb, za], axis=0)

                d = self.create_component(config.discriminator,
                                          name='d2_ab',
                                          input=stacked,
                                          features=[features])
                l = self.create_loss(config.loss, d, xa_input, ga.sample,
                                     len(stack))

                d_loss1 += l.d_loss
                g_loss1 += l.g_loss
                d_vars1 += d.variables()
                metrics["gloss2"] = l.g_loss
                metrics["dloss2"] = l.d_loss

            def add_aux_discriminator(d_config, name, real, fake, metric_prefix):
                """Add a two-way discriminator over ``[real, fake]`` to the summed losses.

                The discriminator's variables go into ``d_vars1``, and its losses
                are reported as ``<metric_prefix>_gloss`` / ``<metric_prefix>_dloss``.
                """
                nonlocal d_loss1, g_loss1
                netzd = tf.concat(axis=0, values=[real, fake])
                aux_d = self.create_component(d_config, name=name, input=netzd)
                aux_loss = self.create_component(config.loss,
                                                 discriminator=aux_d,
                                                 x=xa_input,
                                                 generator=ga,
                                                 split=2)
                d_vars1.extend(aux_d.variables())
                metrics[metric_prefix + "_gloss"] = aux_loss.g_loss
                metrics[metric_prefix + "_dloss"] = aux_loss.d_loss
                d_loss1 += aux_loss.d_loss
                g_loss1 += aux_loss.g_loss

            if config.alpha:
                add_aux_discriminator(config.z_discriminator, 'z_discriminator',
                                      zub, zb, 'za')
            if config.ug:
                add_aux_discriminator(config.discriminator, 'ug_discriminator',
                                      xb, ugb, 'ug')
            if config.alpha2:
                add_aux_discriminator(config.z_discriminator, 'z_discriminator2',
                                      zua, za, 'za2')
            if config.ug2:
                add_aux_discriminator(config.discriminator, 'ug_discriminator2',
                                      xa, uga, 'ug2')
            if config.ug3:
                # Channel-wise (axis 3) pairing of both domains.
                add_aux_discriminator(config.discriminator, 'ug_discriminator3',
                                      tf.concat(values=[xa, xb], axis=3),
                                      tf.concat(values=[uga, ugb], axis=3),
                                      'ug3')

            lossa = hc.Config({
                'sample': [d_loss1, g_loss1],
                'metrics': metrics
            })
            trainers = [
                ConsensusTrainer(self,
                                 config.trainer,
                                 loss=lossa,
                                 g_vars=g_vars1,
                                 d_vars=d_vars1)
            ]
            trainer = MultiTrainerTrainer(trainers)
            self.session.run(tf.global_variables_initializer())

        self.trainer = trainer
        self.generator = ga
        self.encoder = gb  # this is the other gan
        self.uniform_distribution = hc.Config({"sample": zb})
        self.zb = zb
        self.z_hat = gb.sample
        self.x_input = xa_input
        self.autoencoded_x = xa_hat

        self.cyca = xa_hat
        self.cycb = xb_hat
        self.xba = xba
        self.xab = xab
        self.uga = uga
        self.ugb = ugb

        # Map [-1, 1] samples to 0..255 ints and set the top byte (0xFF000000).
        rgb = tf.cast((self.generator.sample + 1) * 127.5, tf.int32)
        self.generator_int = tf.bitwise.bitwise_or(rgb,
                                                   0xFF000000,
                                                   name='generator_int')
Exemplo n.º 27
0
    def create(self):
        """Build the aligned GAN graph with a learned uniform-to-z mapping.

        Generators:

        * ``ga`` (``a_generator``) is fed ``xb``.
        * ``gb`` (``b_generator``) is fed ``xa``. When ``config.same_g`` is set,
          ``gb`` reuses ``ga``'s weights.

        The joint (image, z) discriminator ``d_ab`` is trained on:

        * ``[xb, gb(xa)]`` when ``config.ignore_random`` is set, or
        * ``[xb, gb(xa), ugb]`` otherwise. Here ``ugb`` is ``gb`` re-run on
          zeros, with z replaced by ``uz_to_gz`` (``uzb_to_gzb``) applied to a
          uniform sample of shape ``[batch, config.source_linear]``.

        Optional z discriminator ``z_discriminator`` (``zub`` vs ``zb``):

        * ``config.inline_alpha`` adds it to the main losses.
        * ``config.separate_alpha`` trains it, together with ``uz_to_gz``, in
          its own ``ConsensusTrainer``.

        Raises:
            ValueError: if ``separate_alpha`` is combined with
                ``ignore_random`` (no ``uz_to_gz`` is built in that case).

        Side effects: may overwrite ``self.inputs.x`` with ``xb`` (when
        ``ignore_random`` is set). Runs the global-variable initializer in
        ``self.session``. Sets ``trainer``, ``generator``, ``encoder``,
        ``uniform_distribution``, ``uniform_distribution_source``, ``zb``,
        ``z_hat``, ``x_input``, ``autoencoded_x``, ``cyca``, ``cycb``, ``xba``,
        ``xab``, ``uga``, ``ugb``, ``uniform_sample`` and ``generator_int`` on
        ``self``. Returns ``None``.
        """
        config = self.config
        ops = self.ops

        if config.separate_alpha and config.ignore_random:
            raise ValueError(
                "separate_alpha requires ignore_random to be disabled "
                "(uz_to_gz is only built on the random branch)")

        with tf.device(self.device):
            xa_input = tf.identity(self.inputs.xa, name='xa_i')
            xb_input = tf.identity(self.inputs.xb, name='xb_i')

            ga = self.create_component(config.generator,
                                       input=xb_input,
                                       name='a_generator')
            if config.same_g:
                # gb shares ga's weights; it owns no variables of its own.
                gb = hc.Config({
                    "sample": ga.reuse(xa_input),
                    "controls": {
                        "z": ga.controls['z']
                    },
                    "reuse": ga.reuse
                })
            else:
                gb = self.create_component(config.generator,
                                           input=xa_input,
                                           name='b_generator')

            def generator_vars():
                """Return a fresh list of gb's (unless shared) and ga's variables."""
                found = [] if config.same_g else list(gb.variables())
                found += ga.variables()
                return found

            za = ga.controls["z"]
            zb = gb.controls["z"]

            self.uniform_sample = gb.sample

            xba = ga.sample
            xab = gb.sample
            xa_hat = ga.reuse(gb.sample)
            xb_hat = gb.reuse(ga.sample)
            xa = xa_input
            xb = xb_input

            uz_to_gz = None
            if config.ignore_random:
                stack = [xb, gb.sample]
                stacked = ops.concat(stack, axis=0)
                features = ops.concat([za, zb], axis=0)
                self.inputs.x = xb
                ugb = gb.sample
                zub = zb
                sourcezub = zb
            else:
                uz_shape = list(self.ops.shape(za))
                uz_shape[-1] = uz_shape[-1] // len(
                    config.z_distribution.projections)
                # NOTE(review): these four distributions are never read; they are
                # kept so graph construction is unchanged.
                for _ in range(4):
                    UniformDistribution(self,
                                        config.z_distribution,
                                        output_shape=uz_shape)

                source = UniformDistribution(
                    self,
                    config.z_distribution,
                    output_shape=[self.ops.shape(za)[0], config.source_linear])
                uz_to_gz = self.create_component(config.uz_to_gz,
                                                 name='uzb_to_gzb',
                                                 input=source.sample)
                zub = uz_to_gz.sample
                # NOTE(review): this is the *mapped* z, not source.sample; confirm
                # this is what uniform_distribution_source should expose.
                sourcezub = zub
                ugb = gb.reuse(tf.zeros_like(xa_input),
                               replace_controls={"z": zub})

                stack = [xb, gb.sample, ugb]
                stacked = ops.concat(stack, axis=0)
                features = ops.concat([za, zb, zub], axis=0)

            d = self.create_component(config.discriminator,
                                      name='d_ab',
                                      input=stacked,
                                      features=[features])
            l = self.create_loss(config.loss, d, xa_input, ga.sample,
                                 len(stack))
            d_loss1 = l.d_loss
            g_loss1 = l.g_loss

            # Copy so later in-place extends never mutate the list held by ``d``.
            d_vars1 = list(d.variables())
            g_vars1 = generator_vars()
            if uz_to_gz is not None:
                g_vars1 += uz_to_gz.variables()

            metrics = {'g_loss': l.g_loss, 'd_loss': l.d_loss}

            def z_discriminator_loss():
                """Build 'z_discriminator' over ``[zub, zb]`` and record its metrics."""
                netzd = tf.concat(axis=0, values=[zub, zb])
                z_d = self.create_component(config.z_discriminator,
                                            name='z_discriminator',
                                            input=netzd)
                loss3 = self.create_component(config.loss,
                                              discriminator=z_d,
                                              x=xa_input,
                                              generator=ga,
                                              split=2)
                metrics["za_gloss"] = loss3.g_loss
                metrics["za_dloss"] = loss3.d_loss
                return z_d, loss3

            if config.inline_alpha:
                z_d, loss3 = z_discriminator_loss()
                d_vars1 += z_d.variables()
                d_loss1 += loss3.d_loss
                g_loss1 += loss3.g_loss

            trainers = []
            if config.separate_alpha:
                z_d, loss3 = z_discriminator_loss()
                # uz_to_gz gets its own trainer, so the main trainer keeps only
                # the gb/ga variables.
                g_vars1 = generator_vars()
                trainers.append(
                    ConsensusTrainer(self,
                                     config.trainer,
                                     loss=loss3,
                                     g_vars=uz_to_gz.variables(),
                                     d_vars=z_d.variables()))

            lossa = hc.Config({
                'sample': [d_loss1, g_loss1],
                'metrics': metrics
            })
            trainers.append(
                ConsensusTrainer(self,
                                 config.trainer,
                                 loss=lossa,
                                 g_vars=g_vars1,
                                 d_vars=d_vars1))
            trainer = MultiTrainerTrainer(trainers)
            self.session.run(tf.global_variables_initializer())

        self.trainer = trainer
        self.generator = ga
        self.encoder = hc.Config({"sample": ugb})  # this is the other gan
        self.uniform_distribution = hc.Config({"sample": zub})
        self.uniform_distribution_source = hc.Config({"sample": sourcezub})
        self.zb = zb
        self.z_hat = gb.sample
        self.x_input = xa_input
        self.autoencoded_x = xa_hat

        self.cyca = xa_hat
        self.cycb = xb_hat
        self.xba = xba
        self.xab = xab
        # No separate uga exists in this variant; both slots expose ugb.
        self.uga = ugb
        self.ugb = ugb

        # Map [-1, 1] samples to 0..255 ints and set the top byte (0xFF000000).
        rgb = tf.cast((self.generator.sample + 1) * 127.5, tf.int32)
        self.generator_int = tf.bitwise.bitwise_or(rgb,
                                                   0xFF000000,
                                                   name='generator_int')