Exemplo n.º 1
0
    def test_proxy_methods(self):
        """Exercise ``MultiComponent.proxy`` over two mock components.

        Two ``MockLoss`` instances sharing one constant sample are combined
        with ``combine='concat'``.  After ``multi.proxy()`` the test expects
        ``multi.proxy_called`` to equal ``[True, True]``.
        """
        with self.test_session():
            gan = mock_gan()
            shared_sample = tf.constant(1., shape=[1, 1])
            # Both components are fed the same tensor.
            losses = [MockLoss(gan, sample=shared_sample) for _ in range(2)]
            multi = MultiComponent(combine='concat', components=losses)

            multi.proxy()
            expected_calls = [True, True]
            self.assertEqual(multi.proxy_called, expected_calls)
    def test_proxy_methods(self):
        """Verify that ``proxy`` on a two-component MultiComponent sets ``proxy_called``.

        The components are two ``MockLoss`` objects built from the same
        ``[1, 1]`` constant and combined with ``'concat'``; the expected value
        of ``multi.proxy_called`` afterwards is ``[True, True]``.
        """
        # NOTE(review): a method with this exact name also appears earlier in
        # this listing; if both live in the same class, the later definition
        # replaces the earlier one — confirm which copy is intended.
        with self.test_session():
            gan = mock_gan()
            sample_tensor = tf.constant(1., shape=[1, 1])
            first_loss = MockLoss(gan, sample=sample_tensor)
            second_loss = MockLoss(gan, sample=sample_tensor)
            multi = MultiComponent(
                combine='concat',
                components=[first_loss, second_loss],
            )

            multi.proxy()
            self.assertEqual(multi.proxy_called, [True, True])
Exemplo n.º 3
0
    def build(self, net):
        """Build every configured discriminator on ``net`` and combine them.

        Discriminators come from two config sources, in this order:

        1. ``config.discriminator_count`` instances of
           ``config['discriminator_class']``, each built with ``config``.
        2. One instance per entry of ``config.discriminators``, where each
           entry supplies its own ``'class'`` and is used as that
           discriminator's config.

        Each discriminator's variables are registered through
        ``self.ops.add_weights`` and collected (one list entry per
        discriminator) in ``self.d_variables``.  The built discriminators are
        stored in ``self.children``.

        Args:
            net: Passed to every discriminator as the ``input`` keyword
                argument (via ``self.kwargs``, which is mutated in place).

        Returns:
            The ``sample`` of the ``MultiComponent`` that combines all
            discriminators, using ``config.combine`` or ``"concat"`` when
            that is unset.  The same value is stored in ``self.sample``.
        """
        gan = self.gan
        config = self.config
        self.d_variables = []

        discs = []
        self.kwargs["input"] = net
        self.kwargs["reuse"] = self.ops._reuse
        for i in range(config.discriminator_count or 0):
            self.kwargs["name"] = self.ops.description + "_d_" + str(i)
            print(">>CREATING ", i)
            disc = config['discriminator_class'](gan, config, **self.kwargs)
            self.ops.add_weights(disc.variables())
            self.d_variables.append(disc.variables())
            discs.append(disc)

        # `discriminators` may be unset; treat that as "no extra
        # discriminators" instead of failing on enumerate(None).  This mirrors
        # the `discriminator_count or 0` guard above.
        # NOTE(review): the index restarts at 0 here, so the generated
        # "_d_<i>" names coincide with those of the loop above when both
        # config sources are used — confirm whether that sharing is intended.
        for i, dconfig in enumerate(config.discriminators or []):
            self.kwargs["name"] = self.ops.description + "_d_" + str(i)
            disc = dconfig['class'](gan, dconfig, **self.kwargs)
            self.ops.add_weights(disc.variables())
            self.d_variables.append(disc.variables())
            discs.append(disc)

        combine = MultiComponent(combine=config.combine or "concat",
                                 components=discs)
        self.sample = combine.sample
        self.children = discs
        return self.sample
Exemplo n.º 4
0
 def add_supervised_loss(self):
     """Optionally fold a supervised loss into ``self.gan.loss``.

     When ``self.args.classloss`` is falsy, only a status message is printed.
     Otherwise a ``SupervisedLoss`` is created from ``self.gan`` and
     ``self.gan.config.loss``, and ``self.gan.loss`` is replaced by a
     ``MultiComponent`` that combines the supervised loss with the previous
     ``self.gan.loss`` using ``combine='add'``.
     """
     if not self.args.classloss:
         print("[discriminator] Class loss is off.  Unsupervised learning mode activated.")
         return

     print("[discriminator] Class loss is on.  Semi-supervised learning mode activated.")
     supervised_loss = SupervisedLoss(self.gan, self.gan.config.loss)
     # Replaces the GAN's loss object in place (flagged as "EWW" by the
     # original author).
     combined_parts = [supervised_loss, self.gan.loss]
     self.gan.loss = MultiComponent(components=combined_parts, combine='add')
Exemplo n.º 5
0
    def test_sample(self):
        """Check the shape of a ``'concat'`` MultiComponent sample.

        Two ``MockLoss`` components, each given the same ``[1, 1]`` constant,
        are combined with ``combine='concat'``.  The test expects
        ``gan.ops.shape(multi.sample)`` to equal ``[1, 2]``.
        """
        with self.test_session():
            gan = mock_gan()
            unit_sample = tf.constant(1., shape=[1, 1])
            parts = [
                MockLoss(gan, sample=unit_sample),
                MockLoss(gan, sample=unit_sample),
            ]
            multi = MultiComponent(combine='concat', components=parts)

            gan.encoder = multi
            combined_shape = gan.ops.shape(multi.sample)
            self.assertEqual(combined_shape, [1, 2])
Exemplo n.º 6
0
 def test_combine_dict(self):
     """Check that ``combine='add'`` handles dict-valued component samples.

     Two ``MockLoss`` components carry the dict samples ``{"a": "b"}`` and
     ``{"b": "c"}``.  The combined ``multi.sample`` is expected to have
     length 2 and to map ``'a'`` to ``'b'`` and ``'b'`` to ``'c'``.
     """
     with self.test_session():
         gan = mock_gan()
         # The previously unused `ops` and `mock_sample` locals were dropped;
         # no assertion below depended on them.
         multi = MultiComponent(combine='add',
                                components=[
                                    MockLoss(gan, sample={"a": "b"}),
                                    MockLoss(gan, sample={"b": "c"})
                                ])
         self.assertEqual(len(multi.sample), 2)
         self.assertEqual(multi.sample['a'], 'b')
         self.assertEqual(multi.sample['b'], 'c')
Exemplo n.º 7
0
 def test_sample_loss(self):
     """Check ``combine='add'`` on list-valued component samples.

     The first ``MockLoss`` carries ``[mock_sample, None]`` and the second
     ``[mock_sample, mock_sample]``, where ``mock_sample`` is a constant of
     shape ``[1]``.  The combined ``multi.sample`` is expected to have two
     entries, each with shape ``[1]`` as reported by ``gan.ops.shape``.
     """
     with self.test_session():
         gan = mock_gan()
         ops = gan.ops
         mock_sample = tf.constant(1., shape=[1])
         # One slot of the first component is None on purpose.
         partial_loss = MockLoss(gan, sample=[mock_sample, None])
         complete_loss = MockLoss(gan, sample=[mock_sample, mock_sample])
         multi = MultiComponent(combine='add',
                                components=[partial_loss, complete_loss])
         self.assertEqual(len(multi.sample), 2)
         for position in (0, 1):
             self.assertEqual(ops.shape(multi.sample[position]), [1])