Example #1
0
    def create_trainer(self, cycloss, z_cycloss, encoder, generator,
                       encoder_loss, standard_loss, standard_discriminator,
                       encoder_discriminator):
        """Build a ConsensusTrainer for the standard adversarial loss.

        Discriminator variables come from ``standard_discriminator``;
        generator variables are those of ``generator`` followed by those of
        ``encoder``.

        Args:
            cycloss: Not used by this method. NOTE(review): an earlier,
                commented-out variant added it to the generator loss --
                confirm whether that is still intended.
            z_cycloss: Not used by this method.
            encoder: Component whose variables are trained on the generator
                side.
            generator: Component whose variables are trained on the generator
                side.
            encoder_loss: Not used by this method.
            standard_loss: Loss object exposing ``d_loss`` and ``g_loss``.
            standard_discriminator: Component whose variables are trained on
                the discriminator side.
            encoder_discriminator: Not used by this method.

        Returns:
            The constructed ``ConsensusTrainer``.
        """
        d_vars = standard_discriminator.variables()
        g_vars = generator.variables() + encoder.variables()

        d_loss = standard_loss.d_loss
        g_loss = standard_loss.g_loss
        loss = hc.Config({
            # The trainer receives the losses in [d_loss, g_loss] order.
            'sample': [d_loss, g_loss],
            'metrics': {
                'g_loss': g_loss,
                'd_loss': d_loss
            }
        })
        return ConsensusTrainer(self,
                                self.config.trainer,
                                loss=loss,
                                g_vars=g_vars,
                                d_vars=d_vars)
    def create(self):
        """Build the paired a/b generators, their discriminators and the trainer.

        ``ga`` is fed ``xb`` and ``gb`` is fed ``xa``; each generator is also
        re-applied to the other's output to form ``xa_hat``/``xb_hat``.
        Discriminators ``da``/``db`` score ``[real, generated]`` batch stacks
        per domain, and ``dz`` scores a ``[uniform sample, za, zb]`` stack of
        latent codes. When ``config.ali_z`` is set, two more discriminators
        score ``[real, generated-from-zero-input-with-uniform-z]`` stacks.
        All losses are summed and handed to a single ConsensusTrainer.

        Side effects: builds the graph under ``self.device``, calls
        ``self.initialize_variables()``, and sets ``uniform_sample``,
        ``trainer``, ``generator``, ``encoder``, ``uniform_distribution``,
        ``zb``, ``z_hat``, ``x_input``, ``autoencoded_x``, ``cyca``,
        ``cycb``, ``xba``, ``xab`` and ``generator_int`` on ``self``.
        """
        config = self.config
        ops = self.ops

        with tf.device(self.device):
            xa_input = tf.identity(self.inputs.xa, name='xa_i')
            xb_input = tf.identity(self.inputs.xb, name='xb_i')

            ga = self.create_component(config.generator,
                                       input=xb_input,
                                       name='a_generator')
            gb = self.create_component(config.generator,
                                       input=xa_input,
                                       name='b_generator')

            za = ga.controls["z"]
            zb = gb.controls["z"]

            self.uniform_sample = ga.sample

            xba = ga.sample
            xab = gb.sample
            # Cycle reconstructions: each generator applied to the other's
            # output.
            xa_hat = ga.reuse(gb.sample)
            xb_hat = gb.reuse(ga.sample)

            # uz_shape aliases z_shape, so z_shape is mutated too; it is not
            # read afterwards.
            z_shape = self.ops.shape(za)
            uz_shape = z_shape
            uz_shape[-1] = uz_shape[-1] // len(
                config.z_distribution.projections)
            ue = UniformDistribution(self,
                                     config.z_distribution,
                                     output_shape=uz_shape)
            features_a = ops.concat([ga.sample, xa_input], axis=0)
            features_b = ops.concat([gb.sample, xb_input], axis=0)
            stacked_a = ops.concat([xa_input, ga.sample], axis=0)
            stacked_b = ops.concat([xb_input, gb.sample], axis=0)
            stacked_z = ops.concat([ue.sample, za, zb], axis=0)
            da = self.create_component(config.discriminator,
                                       name='a_discriminator',
                                       input=stacked_a,
                                       features=[features_b])
            db = self.create_component(config.discriminator,
                                       name='b_discriminator',
                                       input=stacked_b,
                                       features=[features_a])
            dz = self.create_component(config.z_discriminator,
                                       name='z_discriminator',
                                       input=stacked_z)

            if config.ali_z:
                features_alia = ops.concat([za, ue.sample], axis=0)
                features_alib = ops.concat([zb, ue.sample], axis=0)
                # Generate from an all-zero image with z replaced by the
                # uniform sample.
                uga = ga.reuse(tf.zeros_like(xb_input),
                               replace_controls={"z": ue.sample})
                ugb = gb.reuse(tf.zeros_like(xa_input),
                               replace_controls={"z": ue.sample})
                stacked_alia = ops.concat([xa_input, uga], axis=0)
                stacked_alib = ops.concat([xb_input, ugb], axis=0)

                dalia = self.create_component(config.ali_discriminator,
                                              name='alia_discriminator',
                                              input=stacked_alia,
                                              features=[features_alib])
                dalib = self.create_component(config.ali_discriminator,
                                              name='alib_discriminator',
                                              input=stacked_alib,
                                              features=[features_alia])
                lalia = self.create_loss(config.loss, dalia, None, None, 2)
                lalib = self.create_loss(config.loss, dalib, None, None, 2)

            la = self.create_loss(config.loss, da, xa_input, ga.sample, 2)
            lb = self.create_loss(config.loss, db, xb_input, gb.sample, 2)
            # stacked_z holds three groups, hence a split of 3.
            lz = self.create_loss(config.loss, dz, None, None, 3)

            # dz's own variables must be listed, otherwise the latent-code
            # discriminator appears in neither d_vars nor g_vars and is never
            # updated. lz.variables() is kept for any variables owned by the
            # loss component itself.
            d_vars = (da.variables() + db.variables() + dz.variables() +
                      lz.variables())
            if config.ali_z:
                d_vars += dalia.variables() + dalib.variables()
            g_vars = ga.variables() + gb.variables()

            d_loss = la.d_loss + lb.d_loss + lz.d_loss
            g_loss = la.g_loss + lb.g_loss + lz.g_loss
            metrics = {
                'ga_loss': la.g_loss,
                'gb_loss': lb.g_loss,
                'gz_loss': lz.g_loss,
                'da_loss': la.d_loss,
                'db_loss': lb.d_loss,
                'dz_loss': lz.d_loss
            }
            if config.ali_z:
                d_loss += lalib.d_loss + lalia.d_loss
                g_loss += lalib.g_loss + lalia.g_loss
                metrics['galia_loss'] = lalia.g_loss
                metrics['galib_loss'] = lalib.g_loss
                metrics['dalia_loss'] = lalia.d_loss
                metrics['dalib_loss'] = lalib.d_loss

            loss = hc.Config({'sample': [d_loss, g_loss], 'metrics': metrics})
            trainer = ConsensusTrainer(self,
                                       config.trainer,
                                       loss=loss,
                                       g_vars=g_vars,
                                       d_vars=d_vars)
            self.initialize_variables()

        self.trainer = trainer
        self.generator = ga
        self.encoder = gb  # this is the other gan
        self.uniform_distribution = hc.Config({"sample": za})  #uniform_encoder
        self.zb = zb
        self.z_hat = gb.sample
        self.x_input = xa_input
        self.autoencoded_x = xa_hat

        self.cyca = xa_hat
        self.cycb = xb_hat
        self.xba = xba
        self.xab = xab

        # Map generator output (presumably in [-1, 1] -- confirm) to 0..255
        # and OR in an opaque alpha byte.
        rgb = tf.cast((self.generator.sample + 1) * 127.5, tf.int32)
        self.generator_int = tf.bitwise.bitwise_or(rgb,
                                                   0xFF000000,
                                                   name='generator_int')
Example #3
0
    def create(self):
        """Build paired a/b generators scored by one joint discriminator.

        ``ga`` is fed ``xb`` and ``gb`` is fed ``xa``. With
        ``config.same_g``, ``gb`` is a stand-in that reuses ``ga``.
        Depending on the ``config.mess*`` flags, different channel-wise
        (axis 3) pairings of real images, translated images and images
        generated from uniform z are stacked along the batch axis and scored
        by a single discriminator. ``config.alpha`` adds a latent-code
        discriminator over ``[zua, za, zb]``. ``config.mess13`` and
        ``config.mess14`` each add an extra alignment discriminator. All
        losses feed one ConsensusTrainer.

        Side effects: builds the graph under ``self.device``, runs the global
        variable initializer in ``self.session``, and sets the trainer and
        the generated tensors on ``self``.
        """
        config = self.config
        ops = self.ops

        with tf.device(self.device):
            xa_input = tf.identity(self.inputs.xa, name='xa_i')
            xb_input = tf.identity(self.inputs.xb, name='xb_i')

            if config.same_g:
                ga = self.create_component(config.generator,
                                           input=xb_input,
                                           name='a_generator')
                # Stand-in for the second generator, built from ga. It exposes
                # only the attributes this method reads (sample, controls,
                # reuse).
                gb = hc.Config({
                    "sample": ga.reuse(xa_input),
                    "controls": {
                        "z": ga.controls['z']
                    },
                    "reuse": ga.reuse
                })
            else:
                ga = self.create_component(config.generator,
                                           input=xb_input,
                                           name='a_generator')
                gb = self.create_component(config.generator,
                                           input=xa_input,
                                           name='b_generator')

            za = ga.controls["z"]
            zb = gb.controls["z"]

            self.uniform_sample = ga.sample

            xba = ga.sample
            xab = gb.sample
            # Cycle reconstructions: each generator applied to the other's
            # output.
            xa_hat = ga.reuse(gb.sample)
            xb_hat = gb.reuse(ga.sample)

            # uz_shape aliases z_shape, so z_shape is mutated too; it is not
            # read afterwards.
            z_shape = self.ops.shape(za)
            uz_shape = z_shape
            uz_shape[-1] = uz_shape[-1] // len(
                config.z_distribution.projections)
            ue = UniformDistribution(self,
                                     config.z_distribution,
                                     output_shape=uz_shape)
            ue2 = UniformDistribution(self,
                                      config.z_distribution,
                                      output_shape=uz_shape)
            # NOTE(review): ue3/ue4 are never read. They are kept in case
            # component construction registers graph state -- confirm before
            # removing.
            ue3 = UniformDistribution(self,
                                      config.z_distribution,
                                      output_shape=uz_shape)
            ue4 = UniformDistribution(self,
                                      config.z_distribution,
                                      output_shape=uz_shape)

            zua = ue.sample
            zub = ue2.sample

            # Generate from an all-zero image with z replaced by a uniform
            # sample.
            uga = ga.reuse(tf.zeros_like(xb_input),
                           replace_controls={"z": zua})
            ugb = gb.reuse(tf.zeros_like(xa_input),
                           replace_controls={"z": zub})

            xa = xa_input
            xb = xb_input

            # Default discriminator input: channel-wise (axis 3) pairs stacked
            # along the batch axis, with matching latent-code features. Each
            # enabled mess* flag below replaces `stack`/`features`. When
            # several flags are set, the last one wins.
            t0 = ops.concat([xb, xa], axis=3)
            t1 = ops.concat([ugb, uga], axis=3)
            t2 = ops.concat([gb.sample, ga.sample], axis=3)
            f0 = ops.concat([za, zb], axis=3)
            f1 = ops.concat([zub, zua], axis=3)
            f2 = ops.concat([zb, za], axis=3)
            features = ops.concat([f0, f1, f2], axis=0)
            stack = [t0, t1, t2]

            if config.mess2:
                xbxa = ops.concat([xb_input, xa_input], axis=3)
                gbga = ops.concat([gb.sample, ga.sample], axis=3)
                # NOTE(review): fa and fb are identical ([za, zb]); confirm
                # whether fb was meant to differ.
                fa = ops.concat([za, zb], axis=3)
                fb = ops.concat([za, zb], axis=3)
                features = ops.concat([fa, fb], axis=0)
                stack = [xbxa, gbga]

            if config.mess6:
                t0 = ops.concat([xb, xa], axis=3)

                t1 = ops.concat([gb.sample, uga], axis=3)
                t2 = ops.concat([gb.sample, xa], axis=3)
                t3 = ops.concat([xb, ga.sample], axis=3)
                t4 = ops.concat([ugb, ga.sample], axis=3)
                features = None
                stack = [t0, t1, t2, t3, t4]

            if config.mess7:
                ugb = gb.reuse(tf.zeros_like(xa_input),
                               replace_controls={"z": zua})
                t0 = ops.concat([xb, ga.sample], axis=3)
                t1 = ops.concat([ugb, uga], axis=3)
                t2 = ops.concat([gb.sample, xa], axis=3)
                features = None
                stack = [t0, t1, t2]

            if config.mess8:
                ugb = gb.reuse(tf.zeros_like(xa_input),
                               replace_controls={"z": zua})
                uga2 = ga.reuse(tf.zeros_like(xa_input),
                                replace_controls={"z": zub})
                ugb2 = gb.reuse(tf.zeros_like(xa_input),
                                replace_controls={"z": zub})
                t0 = ops.concat([xb, ga.sample, xa, gb.sample], axis=3)
                t1 = ops.concat([ugb, uga, uga2, ugb2], axis=3)
                features = None
                stack = [t0, t1]

            if config.mess10:
                t0 = ops.concat([xb, xa], axis=3)

                ugb = gb.reuse(tf.zeros_like(xa_input),
                               replace_controls={"z": zua})
                t1 = ops.concat([ugb, uga], axis=3)
                t2 = ops.concat([gb.sample, xa], axis=3)
                t3 = ops.concat([xb, ga.sample], axis=3)
                features = None
                stack = [t0, t1, t2, t3]

            if config.mess11:
                t0 = ops.concat([xa, xb, ga.sample, gb.sample], axis=3)
                # NOTE(review): ugbga/ugagb are built but never used.
                ugbga = ga.reuse(ugb)
                ugagb = gb.reuse(uga)

                t1 = ops.concat([ga.sample, gb.sample, uga, ugb], axis=3)
                features = None
                stack = [t0, t1]

            if config.mess12:
                t0 = ops.concat([xb, xa], axis=3)
                t2 = ops.concat([gb.sample, ga.sample], axis=3)
                f0 = ops.concat([za, zb], axis=3)
                f2 = ops.concat([zua, zub], axis=3)
                features = ops.concat([f0, f2], axis=0)
                stack = [t0, t2]

            if config.mess13:
                ugb = gb.reuse(tf.zeros_like(xa_input),
                               replace_controls={"z": zua})
                features = None
                t0 = ops.concat([xa, gb.sample], axis=3)
                t1 = ops.concat([ga.sample, xb], axis=3)
                t2 = ops.concat([uga, ugb], axis=3)
                stack = [t0, t1, t2]

            if config.mess14:
                features = None
                t0 = ops.concat([xa, gb.sample], axis=3)
                t1 = ops.concat([ga.sample, xb], axis=3)
                stack = [t0, t1]

            stacked = ops.concat(stack, axis=0)
            d = self.create_component(config.discriminator,
                                      name='alia_discriminator',
                                      input=stacked,
                                      features=[features])
            l = self.create_loss(config.loss, d, xa_input, ga.sample,
                                 len(stack))

            d_vars = d.variables()
            if config.same_g:
                # gb is a stand-in built from ga, so only ga owns variables.
                g_vars = ga.variables()
            else:
                g_vars = ga.variables() + gb.variables()

            d_loss = l.d_loss
            g_loss = l.g_loss

            metrics = {'g_loss': l.g_loss, 'd_loss': l.d_loss}

            if config.alpha:
                # Latent-code discriminator over three stacked groups. The
                # loss split must equal the number of groups, as with
                # len(stack) above. A split of 2 would cut the 3x batch in
                # half and mix the groups.
                z_stack = [zua, za, zb]
                netzd = tf.concat(axis=0, values=z_stack)
                z_d = self.create_component(config.z_discriminator,
                                            name='z_discriminator',
                                            input=netzd)

                lz = self.create_component(config.loss,
                                           discriminator=z_d,
                                           x=xa_input,
                                           generator=ga,
                                           split=len(z_stack))
                d_loss += lz.d_loss
                g_loss += lz.g_loss
                d_vars += z_d.variables()
                metrics["a_gloss"] = lz.g_loss
                metrics["a_dloss"] = lz.d_loss

            if config.mess13:
                # Extra alignment discriminator over three channel-wise pairs.
                t0 = ops.concat([xb, ga.sample], axis=3)
                t1 = ops.concat([gb.sample, xa], axis=3)
                t2 = ops.concat([ugb, uga], axis=3)
                stack = [t0, t1, t2]
                features = None
                stacked = tf.concat(axis=0, values=stack)
                d2 = self.create_component(config.discriminator,
                                           name='align_2',
                                           input=stacked,
                                           features=[features])
                lz = self.create_loss(config.loss, d2, xa_input, ga.sample,
                                      len(stack))
                d_vars += d2.variables()

                d_loss += lz.d_loss
                g_loss += lz.g_loss
                metrics["mess13_g"] = lz.g_loss
                metrics["mess13_d"] = lz.d_loss

            if config.mess14:
                # Extra alignment discriminator: real pair vs uniform-z pair.
                t0 = ops.concat([xb, xa], axis=3)
                t2 = ops.concat([ugb, uga], axis=3)
                stack = [t0, t2]
                features = None
                stacked = tf.concat(axis=0, values=stack)
                d3 = self.create_component(config.discriminator,
                                           name='align_3',
                                           input=stacked,
                                           features=[features])
                lz = self.create_loss(config.loss, d3, xa_input, ga.sample,
                                      len(stack))
                d_vars += d3.variables()

                d_loss += lz.d_loss
                g_loss += lz.g_loss
                metrics["mess14_g"] = lz.g_loss
                metrics["mess14_d"] = lz.d_loss

            loss = hc.Config({'sample': [d_loss, g_loss], 'metrics': metrics})
            trainer = ConsensusTrainer(self,
                                       config.trainer,
                                       loss=loss,
                                       g_vars=g_vars,
                                       d_vars=d_vars)
            self.session.run(tf.global_variables_initializer())

        self.trainer = trainer
        self.generator = ga
        self.encoder = gb  # this is the other gan
        self.uniform_distribution = hc.Config({"sample": zb})  #uniform_encoder
        self.zb = zb
        self.z_hat = gb.sample
        self.x_input = xa_input
        self.autoencoded_x = xa_hat

        self.cyca = xa_hat
        self.cycb = xb_hat
        self.xba = xba
        self.xab = xab
        self.uga = uga
        self.ugb = ugb

        # Map generator output (presumably in [-1, 1] -- confirm) to 0..255
        # and OR in an opaque alpha byte.
        rgb = tf.cast((self.generator.sample + 1) * 127.5, tf.int32)
        self.generator_int = tf.bitwise.bitwise_or(rgb,
                                                   0xFF000000,
                                                   name='generator_int')
Example #4
0
    def create(self):
        """Build paired a/b generators plus auxiliary uniform-z generators.

        ``ga`` is fed ``xb`` and ``gb`` is fed ``xa``. With
        ``config.same_g``, ``gb`` is a stand-in that reuses ``ga``. Two
        further generators, ``ua_generator`` and ``ub_generator``, produce
        images from an all-zero input with z replaced by uniform samples.

        The main discriminator ``d_ab`` scores ``[xb, gb.sample, gb(uga)]``
        with ``[za, zb, gb.controls['z']]`` as features. The flags ``cyc``,
        ``alpha``, ``ug``, ``alpha2``, ``ug2`` and ``ug3`` each add one more
        discriminator/loss pair, reported under its own metric keys.
        Everything is trained by a single ConsensusTrainer wrapped in a
        MultiTrainerTrainer.

        Side effects: builds the graph under ``self.device``, runs the global
        variable initializer in ``self.session``, and sets the trainer and
        the generated tensors on ``self``.
        """
        config = self.config
        ops = self.ops

        with tf.device(self.device):
            xa_input = tf.identity(self.inputs.xa, name='xa_i')
            xb_input = tf.identity(self.inputs.xb, name='xb_i')

            if config.same_g:
                ga = self.create_component(config.generator,
                                           input=xb_input,
                                           name='a_generator')
                # Stand-in for the second generator, built from ga. It exposes
                # only sample, controls and reuse, with no variables().
                gb = hc.Config({
                    "sample": ga.reuse(xa_input),
                    "controls": {
                        "z": ga.controls['z']
                    },
                    "reuse": ga.reuse
                })
            else:
                ga = self.create_component(config.generator,
                                           input=xb_input,
                                           name='a_generator')
                gb = self.create_component(config.generator,
                                           input=xa_input,
                                           name='b_generator')

            za = ga.controls["z"]
            zb = gb.controls["z"]

            self.uniform_sample = ga.sample

            xba = ga.sample
            xab = gb.sample
            # NOTE(review): xa_hat is reassigned to ugbprime further down.
            # This reuse is kept so the built graph is unchanged.
            xa_hat = ga.reuse(gb.sample)
            xb_hat = gb.reuse(ga.sample)

            # uz_shape aliases z_shape, so z_shape is mutated too; it is not
            # read afterwards.
            z_shape = self.ops.shape(za)
            uz_shape = z_shape
            uz_shape[-1] = uz_shape[-1] // len(
                config.z_distribution.projections)
            ue = UniformDistribution(self,
                                     config.z_distribution,
                                     output_shape=uz_shape)
            ue2 = UniformDistribution(self,
                                      config.z_distribution,
                                      output_shape=uz_shape)
            # NOTE(review): ue3/ue4 are never read. They are kept in case
            # component construction registers graph state -- confirm before
            # removing.
            ue3 = UniformDistribution(self,
                                      config.z_distribution,
                                      output_shape=uz_shape)
            ue4 = UniformDistribution(self,
                                      config.z_distribution,
                                      output_shape=uz_shape)

            zua = ue.sample
            zub = ue2.sample

            ga2 = self.create_component(config.generator,
                                        input=tf.zeros_like(xb_input),
                                        name='ua_generator')
            gb2 = self.create_component(config.generator,
                                        input=tf.zeros_like(xb_input),
                                        name='ub_generator')
            uga = ga2.reuse(tf.zeros_like(xb_input),
                            replace_controls={"z": zua})
            ugb = gb2.reuse(tf.zeros_like(xb_input),
                            replace_controls={"z": zub})

            ugbprime = gb.reuse(uga)
            ugbzb = gb.controls['z']

            xa = xa_input
            xb = xb_input

            stack = [xb, gb.sample, ugbprime]
            stacked = ops.concat(stack, axis=0)
            features = ops.concat([za, zb, ugbzb], axis=0)

            xa_hat = ugbprime

            d = self.create_component(config.discriminator,
                                      name='d_ab',
                                      input=stacked,
                                      features=[features])
            l = self.create_loss(config.loss, d, xa_input, ga.sample,
                                 len(stack))
            d_loss1 = l.d_loss
            g_loss1 = l.g_loss

            d_vars1 = d.variables()
            # Same order as before: ga2, gb2, gb, ga. With same_g, gb is the
            # hc.Config stand-in, which has no variables of its own.
            g_vars1 = ga2.variables() + gb2.variables()
            if not config.same_g:
                g_vars1 += gb.variables()
            g_vars1 += ga.variables()

            metrics = {'g_loss': l.g_loss, 'd_loss': l.d_loss}

            if config.cyc:
                stack = [xa, ga.sample]
                stacked = ops.concat(stack, axis=0)
                features = ops.concat([zb, za], axis=0)

                d = self.create_component(config.discriminator,
                                          name='d2_ab',
                                          input=stacked,
                                          features=[features])
                l = self.create_loss(config.loss, d, xa_input, ga.sample,
                                     len(stack))

                d_loss1 += l.d_loss
                g_loss1 += l.g_loss
                d_vars1 += d.variables()
                metrics["gloss2"] = l.g_loss
                metrics["dloss2"] = l.d_loss

            # Each optional branch below reports under its own metric keys.
            # Previously every branch wrote "za_gloss"/"za_dloss", so
            # later-enabled branches hid the earlier ones.
            if config.alpha:
                netzd = tf.concat(axis=0, values=[zub, zb])
                z_d = self.create_component(config.z_discriminator,
                                            name='z_discriminator',
                                            input=netzd)

                loss3 = self.create_component(config.loss,
                                              discriminator=z_d,
                                              x=xa_input,
                                              generator=ga,
                                              split=2)
                d_vars1 += z_d.variables()
                metrics["za_gloss"] = loss3.g_loss
                metrics["za_dloss"] = loss3.d_loss
                d_loss1 += loss3.d_loss
                g_loss1 += loss3.g_loss

            if config.ug:
                netzd = tf.concat(axis=0, values=[xb, ugb])
                z_d = self.create_component(config.discriminator,
                                            name='ug_discriminator',
                                            input=netzd)

                loss3 = self.create_component(config.loss,
                                              discriminator=z_d,
                                              x=xa_input,
                                              generator=ga,
                                              split=2)
                d_vars1 += z_d.variables()
                metrics["ug_gloss"] = loss3.g_loss
                metrics["ug_dloss"] = loss3.d_loss
                d_loss1 += loss3.d_loss
                g_loss1 += loss3.g_loss

            if config.alpha2:
                netzd = tf.concat(axis=0, values=[zua, za])
                z_d = self.create_component(config.z_discriminator,
                                            name='z_discriminator2',
                                            input=netzd)

                loss3 = self.create_component(config.loss,
                                              discriminator=z_d,
                                              x=xa_input,
                                              generator=ga,
                                              split=2)
                d_vars1 += z_d.variables()
                metrics["za2_gloss"] = loss3.g_loss
                metrics["za2_dloss"] = loss3.d_loss
                d_loss1 += loss3.d_loss
                g_loss1 += loss3.g_loss

            if config.ug2:
                netzd = tf.concat(axis=0, values=[xa, uga])
                z_d = self.create_component(config.discriminator,
                                            name='ug_discriminator2',
                                            input=netzd)

                loss3 = self.create_component(config.loss,
                                              discriminator=z_d,
                                              x=xa_input,
                                              generator=ga,
                                              split=2)
                d_vars1 += z_d.variables()
                metrics["ug2_gloss"] = loss3.g_loss
                metrics["ug2_dloss"] = loss3.d_loss
                d_loss1 += loss3.d_loss
                g_loss1 += loss3.g_loss

            if config.ug3:
                t0 = tf.concat(values=[xa, xb], axis=3)
                t1 = tf.concat(values=[uga, ugb], axis=3)
                netzd = tf.concat(axis=0, values=[t0, t1])
                z_d = self.create_component(config.discriminator,
                                            name='ug_discriminator3',
                                            input=netzd)

                loss3 = self.create_component(config.loss,
                                              discriminator=z_d,
                                              x=xa_input,
                                              generator=ga,
                                              split=2)
                d_vars1 += z_d.variables()
                metrics["ug3_gloss"] = loss3.g_loss
                metrics["ug3_dloss"] = loss3.d_loss
                d_loss1 += loss3.d_loss
                g_loss1 += loss3.g_loss

            lossa = hc.Config({
                'sample': [d_loss1, g_loss1],
                'metrics': metrics
            })
            trainers = [
                ConsensusTrainer(self,
                                 config.trainer,
                                 loss=lossa,
                                 g_vars=g_vars1,
                                 d_vars=d_vars1)
            ]
            trainer = MultiTrainerTrainer(trainers)
            self.session.run(tf.global_variables_initializer())

        self.trainer = trainer
        self.generator = ga
        self.encoder = gb  # this is the other gan
        self.uniform_distribution = hc.Config({"sample": zb})  #uniform_encoder
        self.zb = zb
        self.z_hat = gb.sample
        self.x_input = xa_input
        self.autoencoded_x = xa_hat

        self.cyca = xa_hat
        self.cycb = xb_hat
        self.xba = xba
        self.xab = xab
        self.uga = uga
        self.ugb = ugb

        # Map generator output (presumably in [-1, 1] -- confirm) to 0..255
        # and OR in an opaque alpha byte.
        rgb = tf.cast((self.generator.sample + 1) * 127.5, tf.int32)
        self.generator_int = tf.bitwise.bitwise_or(rgb,
                                                   0xFF000000,
                                                   name='generator_int')
Example #5
0
    def create(self):
        """Build the style-conditioned encoder/generator GAN and its trainer.

        An encoder (``z_encoder``) and a style encoder (``style``) both read
        the input image. The generator is fed ``z`` (a uniform sample shifted
        by ``slider * direction``) with the style code as a feature. A
        discriminator scores ``[x, generated]`` with ``[encoded z, z]`` as
        features. With ``config.u_to_z``, a learned map from fresh uniform
        noise supplies the fake branch instead. A second discriminator
        scores ``[random style, encoded style]``. Its metrics are always
        reported, and its loss is trained only when ``config.forcerandom``
        is set.

        Side effects: builds the graph under ``self.device``, runs the global
        variable initializer in ``self.session``, and sets ``encoder``,
        ``style``, ``styleb``, ``random_style``, ``uniform_sample``,
        ``trainer``, ``generator``, ``uniform_distribution``, ``slider``,
        ``direction``, ``z``, ``z_hat``, ``x_input`` and ``autoencoded_x``
        on ``self``.
        """
        config = self.config
        ops = self.ops

        with tf.device(self.device):
            x_input = tf.identity(self.inputs.x, name='input')

            def random_like(x):
                """Return a fresh ``config.z_distribution`` sample shaped like ``x``."""
                return UniformDistribution(
                    self,
                    config.z_distribution,
                    output_shape=self.ops.shape(x)).sample

            # q(z|x)
            encoder = self.create_component(config.encoder,
                                            input=x_input,
                                            name='z_encoder')

            self.encoder = encoder
            z_shape = self.ops.shape(encoder.sample)
            style = self.create_component(config.style_encoder,
                                          input=x_input,
                                          name='style')
            self.style = style
            self.styleb = style  # hack for sampler
            self.random_style = random_like(style.sample)

            # uz_shape aliases z_shape, so z_shape is mutated too; it is not
            # read afterwards.
            uz_shape = z_shape
            uz_shape[-1] = uz_shape[-1] // len(
                config.z_distribution.projections)
            # Bind the instance to its own name. Rebinding
            # `UniformDistribution` itself would make the class name local to
            # create(). random_like() above and this very call would then
            # fail with NameError/UnboundLocalError.
            uniform_encoder = UniformDistribution(self,
                                                  config.z_distribution,
                                                  output_shape=uz_shape)
            direction, slider = self.create_controls(
                self.ops.shape(uniform_encoder.sample))
            z = uniform_encoder.sample + slider * direction

            generator = self.create_component(config.generator,
                                              input=z,
                                              features=[style.sample])
            self.uniform_sample = generator.sample
            # Reconstruction: decode the encoder's code.
            x_hat = generator.reuse(encoder.sample)

            features_zs = ops.concat([encoder.sample, z], axis=0)
            stacked_xg = ops.concat([x_input, generator.sample], axis=0)

            if config.u_to_z:
                # A learned map from fresh uniform noise to z. Its decoding
                # replaces the generator sample in the discriminator input.
                u_to_z = self.create_component(config.u_to_z,
                                               name='u_to_z',
                                               input=random_like(z))
                gu = generator.reuse(u_to_z.sample)
                stacked_xg = ops.concat([x_input, gu], axis=0)
                features_zs = ops.concat([encoder.sample, u_to_z.sample],
                                         axis=0)

            standard_discriminator = self.create_component(
                config.discriminator,
                name='discriminator',
                input=stacked_xg,
                features=[features_zs])
            standard_loss = self.create_loss(config.loss,
                                             standard_discriminator, x_input,
                                             generator, 2)

            d_loss1 = standard_loss.d_loss
            g_loss1 = standard_loss.g_loss

            d_vars1 = standard_discriminator.variables()
            g_vars1 = (generator.variables() + encoder.variables() +
                       style.variables())

            metrics = {'g_loss': standard_loss.g_loss,
                       'd_loss': standard_loss.d_loss}

            # Discriminator over [random style, encoded style].
            stacked = ops.concat([random_like(style.sample), style.sample],
                                 axis=0)
            z_d = self.create_component(config.z_discriminator,
                                        name='forcerandom_discriminator',
                                        input=stacked)
            loss3 = self.create_component(config.loss,
                                          discriminator=z_d,
                                          x=x_input,
                                          generator=generator,
                                          split=2)
            metrics["forcerandom_gloss"] = loss3.g_loss
            metrics["forcerandom_dloss"] = loss3.d_loss
            if config.forcerandom:
                d_loss1 += loss3.d_loss
                g_loss1 += loss3.g_loss
                d_vars1 += z_d.variables()

            if config.u_to_z:
                g_vars1 += u_to_z.variables()

            lossa = hc.Config({
                'sample': [d_loss1, g_loss1],
                'metrics': metrics
            })
            trainer = ConsensusTrainer(self,
                                       config.trainer,
                                       loss=lossa,
                                       g_vars=g_vars1,
                                       d_vars=d_vars1)

            self.session.run(tf.global_variables_initializer())

        self.trainer = trainer
        self.generator = generator
        self.uniform_distribution = uniform_encoder
        self.slider = slider
        self.direction = direction
        self.z = z
        self.z_hat = encoder.sample
        self.x_input = x_input
        self.autoencoded_x = x_hat
        rgb = tf.cast((self.generator.sample + 1) * 127.5, tf.int32)
        self.generator_int = tf.bitwise.bitwise_or(rgb,
                                                   0xFF000000,
                                                   name='generator_int')
        self.random_z = tf.random_uniform(ops.shape(
            UniformDistribution.sample),
                                          -1,
                                          1,
                                          name='random_z')

        if hasattr(generator, 'mask_generator'):
            self.mask_generator = generator.mask_generator
            self.mask = mask
            self.autoencode_mask = generator.mask_generator.sample
            self.autoencode_mask_3_channel = generator.mask
Exemple #6
0
    def create(self):
        """Build the paired a/b translation GAN graph and its trainer.

        Components (names as passed to ``create_component``):
          * ``ga`` ('a_generator') is fed ``xb``; ``gb`` ('b_generator') is
            fed ``xa``. With ``config.same_g`` no separate b generator is
            built: ``gb`` is an ``hc.Config`` wrapper whose ``sample`` and
            ``reuse`` come from ``ga`` (shared weights).
          * ``d_ab`` scores ``[xb, gb.sample]`` stacked on axis 0. Unless
            ``config.ignore_random`` is set, it also scores ``ugb``, which is
            ``gb`` driven by the mapped random code ``zub``. The matching
            ``z`` controls/codes are stacked the same way as features.
          * ``config.inline_alpha`` adds a z-discriminator loss on
            ``[zub, zb]`` to the main losses.
          * ``config.separate_alpha`` trains such a z-discriminator with its
            own ConsensusTrainer over ``uz_to_gz``.

        All trainers are wrapped in a MultiTrainerTrainer stored on
        ``self.trainer``. The method also:
          * stores many graph tensors on ``self``;
          * overwrites ``self.inputs.x`` (``ignore_random`` only);
          * runs ``tf.global_variables_initializer()`` in ``self.session``.

        Raises:
            ValueError: if ``config.separate_alpha`` and
                ``config.ignore_random`` are both set. ``uz_to_gz``, which the
                separate trainer optimizes, is only built when
                ``ignore_random`` is off.
        """
        config = self.config
        ops = self.ops

        if config.separate_alpha and config.ignore_random:
            # Previously this surfaced as a NameError on `uz_to_gz` after
            # most of the graph had been built; fail fast instead.
            raise ValueError(
                "config.separate_alpha requires config.ignore_random to be "
                "off: uz_to_gz is only created when ignore_random is off")

        with tf.device(self.device):
            xa_input = tf.identity(self.inputs.xa, name='xa_i')
            xb_input = tf.identity(self.inputs.xb, name='xb_i')

            ga = self.create_component(config.generator,
                                       input=xb_input,
                                       name='a_generator')
            if config.same_g:
                # Weight-shared stand-in for the b generator. It exposes the
                # sample/controls/reuse interface used below but owns no
                # variables (and has no `variables()` method).
                gb = hc.Config({
                    "sample": ga.reuse(xa_input),
                    "controls": {
                        "z": ga.controls['z']
                    },
                    "reuse": ga.reuse
                })
            else:
                gb = self.create_component(config.generator,
                                           input=xa_input,
                                           name='b_generator')

            def generator_variables():
                """Return a new list of the generators' trainable variables.

                With ``same_g`` only ``ga`` owns variables; calling
                ``gb.variables()`` on the ``hc.Config`` stand-in would fail.
                A fresh list is returned so later ``+=`` cannot mutate a list
                a component might hold internally.
                """
                if config.same_g:
                    return list(ga.variables())
                return gb.variables() + ga.variables()

            za = ga.controls["z"]
            zb = gb.controls["z"]

            self.uniform_sample = gb.sample

            xba = ga.sample
            xab = gb.sample
            xa_hat = ga.reuse(gb.sample)
            xb_hat = gb.reuse(ga.sample)
            xb = xb_input

            if config.ignore_random:
                stack = [xb, gb.sample]
                stacked = ops.concat(stack, axis=0)
                features = ops.concat([za, zb], axis=0)
                self.inputs.x = xb
                ugb = gb.sample
                zub = zb
                sourcezub = zb
            else:
                z_shape = self.ops.shape(za)
                # Copy so the width adjustment below does not also mutate
                # the shape list returned for `za`.
                uz_shape = list(z_shape)
                uz_shape[-1] = uz_shape[-1] // len(
                    config.z_distribution.projections)
                # NOTE(review): only `ue` (printed) and the second `ue2` below
                # are used. The other distributions are kept because they
                # still create graph ops, and removing them could shift
                # op-level random seeds derived from construction order.
                ue = UniformDistribution(self,
                                         config.z_distribution,
                                         output_shape=uz_shape)
                UniformDistribution(self,
                                    config.z_distribution,
                                    output_shape=uz_shape)
                UniformDistribution(self,
                                    config.z_distribution,
                                    output_shape=uz_shape)
                UniformDistribution(self,
                                    config.z_distribution,
                                    output_shape=uz_shape)
                print('ue', ue.sample)

                ue2 = UniformDistribution(
                    self,
                    config.z_distribution,
                    output_shape=[self.ops.shape(za)[0], config.source_linear])
                # Map the raw uniform code into the generator's z space.
                uz_to_gz = self.create_component(config.uz_to_gz,
                                                 name='uzb_to_gzb',
                                                 input=ue2.sample)
                zub = uz_to_gz.sample
                # NOTE(review): this is the *mapped* code (same as zub), not
                # the raw ue2.sample; confirm which one "source" should be.
                sourcezub = zub
                ugb = gb.reuse(tf.zeros_like(xa_input),
                               replace_controls={"z": zub})

                stack = [xb, gb.sample, ugb]
                stacked = ops.concat(stack, axis=0)
                features = ops.concat([za, zb, zub], axis=0)

            d = self.create_component(config.discriminator,
                                      name='d_ab',
                                      input=stacked,
                                      features=[features])
            l = self.create_loss(config.loss, d, xa_input, ga.sample,
                                 len(stack))
            d_loss1 = l.d_loss
            g_loss1 = l.g_loss

            d_vars1 = list(d.variables())
            g_vars1 = generator_variables()
            if not config.ignore_random:
                g_vars1 += uz_to_gz.variables()

            metrics = {'g_loss': l.g_loss, 'd_loss': l.d_loss}

            if config.inline_alpha:
                netzd = tf.concat(axis=0, values=[zub, zb])
                z_d = self.create_component(config.z_discriminator,
                                            name='z_discriminator',
                                            input=netzd)
                loss3 = self.create_component(config.loss,
                                              discriminator=z_d,
                                              x=xa_input,
                                              generator=ga,
                                              split=2)
                metrics["za_gloss"] = loss3.g_loss
                metrics["za_dloss"] = loss3.d_loss
                d_vars1 += z_d.variables()
                d_loss1 += loss3.d_loss
                g_loss1 += loss3.g_loss

            trainers = []
            if config.separate_alpha:
                # NOTE(review): if inline_alpha is also set, this builds a
                # second component under the same 'z_discriminator' name;
                # confirm create_component tolerates that.
                netzd = tf.concat(axis=0, values=[zub, zb])
                z_d = self.create_component(config.z_discriminator,
                                            name='z_discriminator',
                                            input=netzd)
                loss3 = self.create_component(config.loss,
                                              discriminator=z_d,
                                              x=xa_input,
                                              generator=ga,
                                              split=2)
                metrics["za_gloss"] = loss3.g_loss
                metrics["za_dloss"] = loss3.d_loss
                # uz_to_gz gets its own trainer below, so the main trainer
                # only updates the generators.
                g_vars1 = generator_variables()
                trainers += [
                    ConsensusTrainer(self,
                                     config.trainer,
                                     loss=loss3,
                                     g_vars=uz_to_gz.variables(),
                                     d_vars=z_d.variables())
                ]

            lossa = hc.Config({
                'sample': [d_loss1, g_loss1],
                'metrics': metrics
            })
            trainers += [
                ConsensusTrainer(self,
                                 config.trainer,
                                 loss=lossa,
                                 g_vars=g_vars1,
                                 d_vars=d_vars1)
            ]
            trainer = MultiTrainerTrainer(trainers)
            self.session.run(tf.global_variables_initializer())

        self.trainer = trainer
        self.generator = ga
        self.encoder = hc.Config({"sample": ugb})  # this is the other gan
        self.uniform_distribution = hc.Config({"sample": zub})
        self.uniform_distribution_source = hc.Config({"sample": sourcezub})
        self.zb = zb
        self.z_hat = gb.sample
        self.x_input = xa_input
        self.autoencoded_x = xa_hat

        self.cyca = xa_hat
        self.cycb = xb_hat
        self.xba = xba
        self.xab = xab
        # NOTE(review): uga is the same tensor as ugb; confirm intended.
        self.uga = ugb
        self.ugb = ugb

        # Scale generator output to ints via (x + 1) * 127.5 and OR in
        # 0xFF000000, presumably an opaque alpha byte for packed pixels.
        rgb = tf.cast((self.generator.sample + 1) * 127.5, tf.int32)
        self.generator_int = tf.bitwise.bitwise_or(rgb,
                                                   0xFF000000,
                                                   name='generator_int')
Exemple #7
0
    def create(self):
        """Build the encoder/generator translation graph and its trainer.

        Construction, in order:
          * ``zgx`` ('xa_to_x') encodes ``xa``; ``zgy`` ('xb_to_y') encodes
            ``xb``. Their samples ``zx`` / ``zy`` drive the generators.
          * ``gy`` ('gy_generator') is fed ``zy``. ``gx`` ('gx_generator') is
            fed the code of ``xa`` (the 'xa_to_x' encoder re-applied with
            ``reuse=True``).
          * With ``config.style``, style components built from
            ``config.style_discriminator`` supply generator features, and
            uniform noise is concatenated onto the codes on axis 3.
            Otherwise both generators get the same ``z_noise`` as features.
          * A single discriminator 'd_ab' sees ``[xb, gx.sample]`` stacked on
            axis 0, with features ``[gy.sample, xa]`` stacked the same way.
          * Its loss is optimized by one ConsensusTrainer over both
            generators and both encoders, wrapped in a MultiTrainerTrainer.

        Side effects:
          * overwrites ``self.inputs.x`` with ``xa``;
          * stores many graph tensors on ``self``;
          * runs ``tf.global_variables_initializer()`` in ``self.session``.
        """
        config = self.config
        ops = self.ops

        with tf.device(self.device):
            #x_input = tf.identity(self.inputs.x, name='input')
            xa_input = tf.identity(self.inputs.xa, name='xa_i')
            xb_input = tf.identity(self.inputs.xb, name='xb_i')

            def random_like(x):
                # Fresh uniform sample from config.z_distribution with the
                # same shape as `x`.
                return UniformDistribution(
                    self,
                    config.z_distribution,
                    output_shape=self.ops.shape(x)).sample

            #y=a
            #x=b
            # i.e. "y" names domain a (xa_input) and "x" names domain b
            # (xb_input). zx is encoded from xa, zy from xb.
            zgx = self.create_component(config.encoder,
                                        input=xa_input,
                                        name='xa_to_x')
            zgy = self.create_component(config.encoder,
                                        input=xb_input,
                                        name='xb_to_y')
            zx = zgx.sample
            zy = zgy.sample

            z_noise = random_like(zx)
            # NOTE(review): n_noise only feeds the dead zx value below.
            n_noise = random_like(zx)
            if config.style:
                stylex = self.create_component(config.style_discriminator,
                                               input=xb_input,
                                               name='xb_style')
                styley = self.create_component(config.style_discriminator,
                                               input=xa_input,
                                               name='xa_style')
                # Axis-3 concat assumes rank-4 codes (e.g. NHWC) - TODO confirm.
                zy = tf.concat(values=[zy, z_noise], axis=3)
                # NOTE(review): this zx is never used. zx is re-encoded and
                # re-concatenated (with z_noise, not n_noise) before gx is built.
                zx = tf.concat(values=[zx, n_noise], axis=3)
                gy = self.create_component(config.generator,
                                           features=[styley.sample],
                                           input=zy,
                                           name='gy_generator')
                y = hc.Config({"sample": xa_input})
                # Re-applies the 'xa_to_x' encoder (reuse=True) to xa_input,
                # the same input zgx already encoded.
                zx = self.create_component(config.encoder,
                                           input=y.sample,
                                           name='xa_to_x',
                                           reuse=True).sample
                zx = tf.concat(values=[zx, z_noise], axis=3)
                gx = self.create_component(config.generator,
                                           features=[stylex.sample],
                                           input=zx,
                                           name='gx_generator')
            else:
                gy = self.create_component(config.generator,
                                           features=[z_noise],
                                           input=zy,
                                           name='gy_generator')
                y = hc.Config({"sample": xa_input})
                zx = self.create_component(config.encoder,
                                           input=y.sample,
                                           name='xa_to_x',
                                           reuse=True).sample
                gx = self.create_component(config.generator,
                                           features=[z_noise],
                                           input=zx,
                                           name='gx_generator')
                # No style network in this branch: random noise shaped like
                # y.sample stands in as the "style".
                stylex = hc.Config({"sample": random_like(y.sample)})

            self.y = y
            self.gy = gy
            self.gx = gx

            # Aliases: ga is gy, gb is gx.
            ga = gy
            gb = gx

            self.uniform_sample = gb.sample

            xba = ga.sample
            xab = gb.sample
            # Generators re-applied to the codes: ga on zx, gb on zy.
            xa_hat = ga.reuse(zx)
            xb_hat = gb.reuse(zy)
            xa = xa_input
            xb = xb_input

            self.styleb = stylex
            self.random_style = random_like(stylex.sample)

            # Discriminator input: real xb and generated gx.sample, stacked on
            # axis 0. Each is paired with a feature stacked the same way:
            # xb with gy.sample, gx.sample with y.sample (== xa_input).
            t0 = xb
            t1 = gx.sample
            f0 = gy.sample
            f1 = y.sample
            stack = [t0, t1]
            stacked = ops.concat(stack, axis=0)
            features = ops.concat([f0, f1], axis=0)
            self.inputs.x = xa
            # gb driven by a random code shaped like zy.
            ugb = gb.reuse(random_like(zy))
            zub = zy
            sourcezub = zy

            d = self.create_component(config.discriminator,
                                      name='d_ab',
                                      input=stacked,
                                      features=[features])
            l = self.create_loss(config.loss, d, xa_input, ga.sample,
                                 len(stack))
            loss1 = l
            d_loss1 = l.d_loss
            g_loss1 = l.g_loss

            d_vars1 = d.variables()
            # Both generators and both encoders are trained as the "G" side.
            g_vars1 = gb.variables() + ga.variables() + zgx.variables(
            ) + zgy.variables()

            d_loss = l.d_loss
            g_loss = l.g_loss

            metrics = {'g_loss': l.g_loss, 'd_loss': l.d_loss}

            trainers = []

            lossa = hc.Config({
                'sample': [d_loss1, g_loss1],
                'metrics': metrics
            })
            #lossb = hc.Config({'sample': [d_loss2, g_loss2], 'metrics': metrics})
            trainers += [
                ConsensusTrainer(self,
                                 config.trainer,
                                 loss=lossa,
                                 g_vars=g_vars1,
                                 d_vars=d_vars1)
            ]
            #trainers += [ConsensusTrainer(self, config.trainer, loss = lossb, g_vars = g_vars2, d_vars = d_vars2)]
            trainer = MultiTrainerTrainer(trainers)
            self.session.run(tf.global_variables_initializer())

        self.trainer = trainer
        # The exposed "generator" is gb (== gx).
        self.generator = gb
        self.encoder = hc.Config({"sample": ugb})  # this is the other gan
        self.uniform_distribution = hc.Config({"sample":
                                               zub})  #uniform_encoder
        self.uniform_distribution_source = hc.Config({"sample": sourcezub
                                                      })  #uniform_encoder
        self.zb = zy
        self.z_hat = gb.sample
        self.x_input = xa_input
        self.autoencoded_x = xb_hat

        self.cyca = xa_hat
        self.cycb = xb_hat
        self.xba = xba
        self.xab = xab
        self.uga = y.sample
        self.ugb = ugb

        # Scale generator output to ints via (x + 1) * 127.5 and OR in
        # 0xFF000000, presumably an opaque alpha byte for packed pixels.
        rgb = tf.cast((self.generator.sample + 1) * 127.5, tf.int32)
        self.generator_int = tf.bitwise.bitwise_or(rgb,
                                                   0xFF000000,
                                                   name='generator_int')