def forward(self, data, state):
    """Combine the three loss terms produced by this op's private helpers.

    Args:
        data: Sequence unpacked as (real_img, fake_img, cycled_img, same_img).
        state: Not used by this method.

    Returns:
        `_adversarial_loss(fake_img) + _identity_loss(real_img, same_img) +
        _cycle_loss(real_img, cycled_img)`, passed through `reduce_mean` when
        `self.average_loss` is truthy.
    """
    real_img, fake_img, cycled_img, same_img = data

    adversarial_term = self._adversarial_loss(fake_img)
    identity_term = self._identity_loss(real_img, same_img)
    cycle_term = self._cycle_loss(real_img, cycled_img)
    combined = adversarial_term + identity_term + cycle_term

    return reduce_mean(combined) if self.average_loss else combined
def forward(self, data, state):
    """Compute half the sum of the real-vs-ones and fake-vs-zeros losses (TensorFlow).

    Args:
        data: Pair unpacked as (real_img, fake_img).
        state: Not used by this method.

    Returns:
        0.5 * (real loss + fake loss). Each loss is `self.loss_fn(target, value)`
        averaged over axes 1 and 2; the sum is passed through `reduce_mean`
        first when `self.average_loss` is truthy.
    """
    real_img, fake_img = data
    reduce_axes = (1, 2)

    # Targets: all ones for the real input.
    real_targets = tf.ones_like(real_img)
    real_img_loss = tf.reduce_mean(self.loss_fn(real_targets, real_img), axis=reduce_axes)

    # Targets: all zeros for the fake input (shaped like real_img, as in the original).
    fake_targets = tf.zeros_like(real_img)
    fake_img_loss = tf.reduce_mean(self.loss_fn(fake_targets, fake_img), axis=reduce_axes)

    combined = real_img_loss + fake_img_loss
    if self.average_loss:
        combined = reduce_mean(combined)
    return 0.5 * combined
def forward(self, data, state):
    """Compute half the sum of the real-vs-ones and fake-vs-zeros losses (PyTorch).

    Args:
        data: Pair unpacked as (real_img, fake_img).
        state: Not used by this method.

    Returns:
        0.5 * (real loss + fake loss). Each loss is `self.loss_fn(value, target)`
        averaged over dims 2 and 3; the sum is passed through `reduce_mean`
        first when `self.average_loss` is truthy.
    """
    real_img, fake_img = data
    reduce_dims = (2, 3)

    # Targets are created on self.device: ones for the real input.
    real_targets = torch.ones_like(real_img, device=self.device)
    real_img_loss = torch.mean(self.loss_fn(real_img, real_targets), dim=reduce_dims)

    # Zeros for the fake input (shaped like real_img, as in the original).
    fake_targets = torch.zeros_like(real_img, device=self.device)
    fake_img_loss = torch.mean(self.loss_fn(fake_img, fake_targets), dim=reduce_dims)

    combined = real_img_loss + fake_img_loss
    if self.average_loss:
        combined = reduce_mean(combined)
    return 0.5 * combined
def forward(self, data, state):
    """Compute the weighted sum of style, content and total-variation terms.

    Args:
        data: Sequence unpacked as (y_pred, y_style, y_content, image_out).
            `y_style['style']` / `y_pred['style']` and `y_content['content']` /
            `y_pred['content']` are iterated pairwise.
        state: Not used by this method.

    Returns:
        `style_weight * style + content_weight * content + tv_weight * tv`,
        passed through `reduce_mean` when `self.average_loss` is truthy.
    """
    y_pred, y_style, y_content, image_out = data

    # Style term: one value per (target, prediction) pair, stacked and summed.
    style_terms = []
    for target, predicted in zip(y_style['style'], y_pred['style']):
        style_terms.append(self.calculate_style_recon_loss(target, predicted))
    style_total = torch.stack(style_terms, dim=0).sum(dim=0)
    style_total *= self.style_weight

    # Content term: same stack-and-sum pattern over the content pairs.
    content_terms = []
    for target, predicted in zip(y_content['content'], y_pred['content']):
        content_terms.append(self.calculate_feature_recon_loss(target, predicted))
    content_total = torch.stack(content_terms, dim=0).sum(dim=0)
    content_total *= self.content_weight

    # Total-variation term computed from the output image.
    tv_total = self.calculate_total_variation(image_out)
    tv_total *= self.tv_weight

    combined = style_total + content_total + tv_total
    return reduce_mean(combined) if self.average_loss else combined
def forward(self, data, state):
    """Return `reduce_mean` applied to `data`.

    Args:
        data: Value forwarded unchanged to `reduce_mean`.
        state: Not used by this method.

    Returns:
        Whatever `reduce_mean(data)` returns.
    """
    averaged = reduce_mean(data)
    return averaged
def test_reduce_mean_axis(self):
    """`reduce_mean` over axis 0 of the NumPy fixture must equal [2, 2.5]."""
    expected = [2, 2.5]
    axis_means = reduce_mean(self.test_np, axis=0)
    self.assertTrue(np.array_equal(axis_means, expected))
def test_reduce_mean_torch_value(self):
    """`reduce_mean` of the torch fixture, converted to NumPy, must equal 1."""
    mean_as_numpy = reduce_mean(self.test_torch).numpy()
    self.assertTrue(np.array_equal(mean_as_numpy, 1))
def test_reduce_mean_torch_type(self):
    """`reduce_mean` on the torch fixture must return a `torch.Tensor`."""
    result = reduce_mean(self.test_torch)
    self.assertIsInstance(result, torch.Tensor, 'Output type must be torch.Tensor')
def test_reduce_mean_tf_type(self):
    """`reduce_mean` on the TensorFlow fixture must return a `tf.Tensor`."""
    result = reduce_mean(self.test_tf)
    self.assertIsInstance(result, tf.Tensor, 'Output type must be tf.Tensor')
def test_reduce_mean_np_value(self):
    """`reduce_mean` of the whole NumPy fixture must equal 2.25."""
    overall_mean = reduce_mean(self.test_np)
    self.assertEqual(overall_mean, 2.25)
def test_reduce_mean_np_type(self):
    """`reduce_mean` on the NumPy fixture must return one of `np.ScalarType`."""
    result = reduce_mean(self.test_np)
    self.assertIsInstance(result, np.ScalarType, 'Output type must be NumPy')