Ejemplo n.º 1
0
    def forward(self, data, state):
        """Sum the adversarial, identity and cycle losses for one batch.

        Args:
            data: A 4-tuple ``(real_img, fake_img, cycled_img, same_img)``.
            state: Not used by this method.

        Returns:
            ``adversarial + identity + cycle``. If ``self.average_loss`` is
            truthy, the sum is first passed through ``reduce_mean``.
        """
        real_img, fake_img, cycled_img, same_img = data
        # Evaluate the three terms in the same left-to-right order as a single
        # chained expression would.
        adversarial = self._adversarial_loss(fake_img)
        identity = self._identity_loss(real_img, same_img)
        cycle = self._cycle_loss(real_img, cycled_img)
        combined = adversarial + identity + cycle
        return reduce_mean(combined) if self.average_loss else combined
Ejemplo n.º 2
0
    def forward(self, data, state):
        """Score real predictions against ones and fake predictions against zeros.

        Each per-element loss from ``self.loss_fn`` is averaged over axes 1
        and 2, and the two results are added.

        Args:
            data: A pair ``(real_img, fake_img)``.
            state: Not used by this method.

        Returns:
            ``0.5 * (real_img_loss + fake_img_loss)``. If ``self.average_loss``
            is truthy, the sum is first passed through ``reduce_mean``.
        """
        real_img, fake_img = data
        # NOTE(review): assumes self.loss_fn uses Keras (y_true, y_pred) ordering
        # and returns an unreduced loss of rank >= 3 so axes (1, 2) exist — confirm.
        real_img_loss = tf.reduce_mean(self.loss_fn(tf.ones_like(real_img), real_img), axis=(1, 2))
        # The zero target is built from fake_img itself, so its shape and dtype
        # always match the predictions it is compared against (mirrors the
        # real branch above).
        fake_img_loss = tf.reduce_mean(self.loss_fn(tf.zeros_like(fake_img), fake_img), axis=(1, 2))
        total_loss = real_img_loss + fake_img_loss

        if self.average_loss:
            total_loss = reduce_mean(total_loss)

        return 0.5 * total_loss
Ejemplo n.º 3
0
    def forward(self, data, state):
        """Score real predictions against ones and fake predictions against zeros.

        Each per-element loss from ``self.loss_fn`` is averaged over dims 2
        and 3, and the two results are added.

        Args:
            data: A pair ``(real_img, fake_img)`` of tensors.
            state: Not used by this method.

        Returns:
            ``0.5 * (real_img_loss + fake_img_loss)``. If ``self.average_loss``
            is truthy, the sum is first passed through ``reduce_mean``.
        """
        real_img, fake_img = data
        # NOTE(review): assumes self.loss_fn follows torch's (input, target)
        # ordering and returns an unreduced loss of rank >= 4 so dims (2, 3)
        # exist — confirm.
        real_img_loss = torch.mean(self.loss_fn(real_img, torch.ones_like(real_img, device=self.device)), dim=(2, 3))
        # The zero target is built from fake_img itself, so its shape and dtype
        # match the predictions it scores. torch losses do not silently cast or
        # broadcast a mismatched target the way a real_img-shaped one could require.
        fake_img_loss = torch.mean(self.loss_fn(fake_img, torch.zeros_like(fake_img, device=self.device)), dim=(2, 3))
        total_loss = real_img_loss + fake_img_loss

        if self.average_loss:
            total_loss = reduce_mean(total_loss)

        return 0.5 * total_loss
Ejemplo n.º 4
0
    def forward(self, data, state):
        """Return the weighted sum of the style, content and total-variation terms.

        Args:
            data: A 4-tuple ``(y_pred, y_style, y_content, image_out)``.
                ``y_pred`` is indexed with ``'style'`` and ``'content'``,
                ``y_style`` with ``'style'``, and ``y_content`` with ``'content'``.
            state: Not used by this method.

        Returns:
            ``style_weight * style + content_weight * content + tv_weight * tv``.
            If ``self.average_loss`` is truthy, the result is first passed
            through ``reduce_mean``.
        """
        y_pred, y_style, y_content, image_out = data

        # Style term: one loss per paired entry, stacked along a new dim 0 and summed.
        style_pairs = zip(y_style['style'], y_pred['style'])
        style_terms = [self.calculate_style_recon_loss(ref, out) for ref, out in style_pairs]
        weighted_style = torch.stack(style_terms, dim=0).sum(dim=0)
        weighted_style *= self.style_weight

        # Content term: same stacking scheme over the 'content' entries.
        content_pairs = zip(y_content['content'], y_pred['content'])
        content_terms = [self.calculate_feature_recon_loss(ref, out) for ref, out in content_pairs]
        weighted_content = torch.stack(content_terms, dim=0).sum(dim=0)
        weighted_content *= self.content_weight

        # Total-variation regulariser on the output image.
        weighted_tv = self.calculate_total_variation(image_out)
        weighted_tv *= self.tv_weight

        combined = weighted_style + weighted_content + weighted_tv
        return reduce_mean(combined) if self.average_loss else combined
Ejemplo n.º 5
0
 def forward(self, data, state):
     """Return ``reduce_mean(data)``; ``state`` is ignored."""
     averaged = reduce_mean(data)
     return averaged
Ejemplo n.º 6
0
 def test_reduce_mean_axis(self):
     """reduce_mean over axis 0 of the NumPy fixture gives [2, 2.5]."""
     axis0_mean = reduce_mean(self.test_np, axis=0)
     expected = [2, 2.5]
     self.assertTrue(np.array_equal(axis0_mean, expected))
Ejemplo n.º 7
0
 def test_reduce_mean_torch_value(self):
     """The full reduction of the torch fixture, converted to NumPy, equals 1."""
     as_numpy = reduce_mean(self.test_torch).numpy()
     self.assertTrue(np.array_equal(as_numpy, 1))
Ejemplo n.º 8
0
 def test_reduce_mean_torch_type(self):
     """reduce_mean applied to the torch fixture returns a torch.Tensor."""
     reduced = reduce_mean(self.test_torch)
     self.assertIsInstance(reduced, torch.Tensor, 'Output type must be torch.Tensor')
Ejemplo n.º 9
0
 def test_reduce_mean_tf_type(self):
     """reduce_mean applied to the TensorFlow fixture returns a tf.Tensor."""
     reduced = reduce_mean(self.test_tf)
     self.assertIsInstance(reduced, tf.Tensor, 'Output type must be tf.Tensor')
Ejemplo n.º 10
0
 def test_reduce_mean_np_value(self):
     """The full reduction of the NumPy fixture equals 2.25."""
     mean_value = reduce_mean(self.test_np)
     self.assertEqual(mean_value, 2.25)
Ejemplo n.º 11
0
 def test_reduce_mean_np_type(self):
     """reduce_mean applied to the NumPy fixture must return a NumPy object.

     ``np.ScalarType`` also contains Python builtins such as ``int``, ``float``
     and ``str``, so checking against it would accept a plain Python number.
     Checking against ``np.generic`` (NumPy scalars) and ``np.ndarray``
     enforces what the failure message states.
     """
     reduced = reduce_mean(self.test_np)
     self.assertIsInstance(reduced, (np.generic, np.ndarray), 'Output type must be NumPy')