Ejemplo n.º 1
0
    def test_loss_init_WITH_NON_LOSS(self):
        """Constructing a Loss whose fifth argument is not a loss object raises ValueError.

        The string ``"loss"`` is passed where the other tests pass a
        ``torch.nn`` loss instance.
        """
        with self.assertRaises(ValueError):
            Loss("name", "shortname",
                 "decoder_outputs", "decoder_targets", "loss")

    def test_loss_backward_WITH_NO_LOSS(self):
        """Calling backward() on a freshly constructed Loss raises ValueError.

        Nothing is evaluated between construction and ``backward()``, so
        there is no accumulated loss to back-propagate.
        """
        loss = Loss("name", "shortname", "decoder_output",
                    "decoder_output", torch.nn.NLLLoss())
        with self.assertRaises(ValueError):
            loss.backward()

    def test_nllloss_init(self):
        """A default NLLLoss reports its class-level names and wraps torch.nn.NLLLoss."""
        loss = NLLLoss()
        self.assertEqual(loss.name, NLLLoss._NAME)
        self.assertEqual(loss.log_name, NLLLoss._SHORTNAME)
        # isinstance semantics (rather than exact type identity) also accept a
        # subclass of torch.nn.NLLLoss; assertIsInstance reports the actual
        # type on failure instead of a bare "False is not true".
        self.assertIsInstance(loss.criterion, torch.nn.NLLLoss)

    def test_nllloss(self):
        loss = NLLLoss()
        pytorch_loss = 0
        pytorch_criterion = torch.nn.NLLLoss()
Ejemplo n.º 2
0
 def test_loss_backward_WITH_NO_LOSS(self):
     """backward() on a Loss that has just been built must raise ValueError."""
     fresh_loss = Loss("name", "shortname", "decoder_output", "decoder_output",
                       torch.nn.NLLLoss())
     # Nothing has been evaluated since construction, so backward() is expected to fail.
     with self.assertRaises(ValueError):
         fresh_loss.backward()