def test_proxy_methods(self):
    """Check that ``proxy()`` on a concat MultiComponent marks both children.

    Two MockLoss components share one [1, 1] sample; after ``multi.proxy()``
    the test expects ``proxy_called`` to report one ``True`` per component.
    """
    with self.test_session():
        gan = mock_gan()
        shared = tf.constant(1., shape=[1, 1])
        children = [MockLoss(gan, sample=shared) for _ in range(2)]
        multi = MultiComponent(combine='concat', components=children)
        multi.proxy()
        expected = [True, True]
        self.assertEqual(multi.proxy_called, expected)
def test_proxy_methods(self):
    """Verify that proxying a two-component MultiComponent flags each component.

    NOTE(review): this definition duplicates another ``test_proxy_methods``;
    if both live in the same class, only the later one is collected.
    """
    with self.test_session():
        gan = mock_gan()
        value = tf.constant(1., shape=[1, 1])
        left = MockLoss(gan, sample=value)
        right = MockLoss(gan, sample=value)
        multi = MultiComponent(combine='concat', components=[left, right])
        multi.proxy()
        self.assertEqual(multi.proxy_called, [True] * 2)
def build(self, net):
    """Create every configured discriminator on ``net`` and combine their samples.

    Discriminators are built from two config sources, in this order:

    * ``config.discriminator_count`` instances of
      ``config['discriminator_class']``, each given the shared ``config``;
    * one instance per entry of ``config.discriminators``, built from that
      entry's ``dconfig['class']`` with ``dconfig`` as its config.

    Each discriminator is named ``<ops.description>_d_<k>``, where ``k`` is its
    running position across both sources, so names stay unique even when both
    sources are set. Every discriminator's variables are registered via
    ``self.ops.add_weights`` and appended (one list per discriminator) to
    ``self.d_variables``. The discriminators are merged with ``MultiComponent``
    using ``config.combine`` (default ``"concat"``).

    Args:
        net: passed to every discriminator as the ``input`` keyword argument.

    Returns:
        The combined ``sample``, also stored on ``self.sample``. The individual
        discriminators are stored on ``self.children``.
    """
    gan = self.gan
    config = self.config
    self.d_variables = []
    discs = []
    self.kwargs["input"] = net
    self.kwargs["reuse"] = self.ops._reuse

    def _create(disc_class, disc_config):
        # Instantiate one discriminator, register its weights, and record it.
        # Numbering by len(discs) keeps names unique across both loops below;
        # for a single source it equals the loop index, so names are unchanged.
        self.kwargs["name"] = self.ops.description + "_d_" + str(len(discs))
        disc = disc_class(gan, disc_config, **self.kwargs)
        self.ops.add_weights(disc.variables())
        self.d_variables += [disc.variables()]
        discs.append(disc)

    for i in range(config.discriminator_count or 0):
        print(">>CREATING ", i)
        _create(config['discriminator_class'], config)

    # Same guard as `discriminator_count or 0` above: presumably unset config
    # keys read as None, and iterating None would raise TypeError.
    for dconfig in config.discriminators or []:
        _create(dconfig['class'], dconfig)

    combine = MultiComponent(combine=self.config.combine or "concat", components=discs)
    self.sample = combine.sample
    self.children = discs
    return self.sample
def add_supervised_loss(self):
    """Optionally add a supervised class loss on top of the GAN loss.

    If ``self.args.classloss`` is falsy, a status line is printed and nothing
    else changes. Otherwise ``self.gan.loss`` is replaced with a
    ``MultiComponent`` that adds (``combine='add'``) a ``SupervisedLoss`` built
    from ``self.gan.config.loss`` to the existing loss.
    """
    if not self.args.classloss:
        print("[discriminator] Class loss is off. Unsupervised learning mode activated.")
        return
    print("[discriminator] Class loss is on. Semi-supervised learning mode activated.")
    gan = self.gan
    # NOTE(review): replacing gan.loss from outside the GAN is fragile; the
    # original author flagged it too. Consider moving this into GAN setup.
    class_loss = SupervisedLoss(gan, gan.config.loss)
    gan.loss = MultiComponent(components=[class_loss, gan.loss], combine='add')
def test_sample(self):
    """Check that concatenating two [1, 1] samples gives a sample of shape [1, 2].

    The MultiComponent is also attached to ``gan.encoder`` before its sample
    shape is inspected.
    """
    with self.test_session():
        gan = mock_gan()
        one = tf.constant(1., shape=[1, 1])
        parts = [MockLoss(gan, sample=one), MockLoss(gan, sample=one)]
        multi = MultiComponent(combine='concat', components=parts)
        gan.encoder = multi
        actual_shape = gan.ops.shape(multi.sample)
        self.assertEqual(actual_shape, [1, 2])
def test_combine_dict(self):
    """Check that combining dict samples with 'add' keeps both components' keys.

    The two components carry disjoint single-key dicts, so the combined sample
    is expected to have exactly two entries: ``'a' -> 'b'`` and ``'b' -> 'c'``.
    """
    with self.test_session():
        gan = mock_gan()
        multi = MultiComponent(combine='add', components=[
            MockLoss(gan, sample={"a": "b"}),
            MockLoss(gan, sample={"b": "c"}),
        ])
        self.assertEqual(len(multi.sample), 2)
        self.assertEqual(multi.sample['a'], 'b')
        self.assertEqual(multi.sample['b'], 'c')
def test_sample_loss(self):
    """Check that 'add' combines list samples slot by slot.

    The first component contributes ``None`` in its second slot. The combined
    sample is still expected to have two entries, each of shape [1].
    """
    with self.test_session():
        gan = mock_gan()
        gan_ops = gan.ops
        x = tf.constant(1., shape=[1])
        first = MockLoss(gan, sample=[x, None])
        second = MockLoss(gan, sample=[x, x])
        multi = MultiComponent(combine='add', components=[first, second])
        self.assertEqual(len(multi.sample), 2)
        for idx in (0, 1):
            self.assertEqual(gan_ops.shape(multi.sample[idx]), [1])