Example no. 1
0
    def training_step(self, batch, batch_idx):
        """Run one training step through the two-stage model.

        The inputs go through ``self.model1``, are shifted by ``self.const``,
        and then go through ``self.model2``. The output is scored against the
        targets with ``self.crit``, and that loss is passed to
        ``Result.step``.

        Args:
            batch: An ``(inputs, targets)`` pair.
            batch_idx: Index of the batch (not used here).

        Returns:
            The populated ``Result``. Its ``abc`` field holds two length-2
            tensors (ones, then zeros); this presumably exercises list-valued
            result fields (TODO confirm intent).
        """
        inputs, targets = batch
        hidden = self.model1(inputs) + self.const
        prediction = self.model2(hidden)

        # Compute the loss before building the Result, matching the original order.
        objective = self.crit(prediction, targets)

        result = Result()
        result.step(objective)
        result.abc = [torch.ones(2), torch.zeros(2)]
        return result
    def test_step(self):
        """Check ``Result.step`` against two optimizers over one layer.

        Two Adam optimizers over the same ``Linear(5, 5)`` layer are
        installed on ``Result.optim_list``. Every ``step`` call must change
        the layer weights.

        ``optim_phase`` is shared by all ``Result`` instances: a second
        instance that was never stepped sees the same value.
        - It is 1 after the first implicit step.
        - It is 0 after the second implicit step.
        - It is still 0 after ``step(loss, 0)``.
        """
        layer = torch.nn.Linear(5, 5)
        ones = torch.ones(5)
        Result.optim_list = [
            torch.optim.Adam(layer.parameters(), lr=0.1),
            torch.optim.Adam(layer.parameters(), lr=0.1),
        ]

        def objective():
            # Run a fresh forward pass so each step gets its own graph.
            out = layer(ones)
            return torch.mean(out * out)

        def expect(phase, weights_before):
            # Both instances must report the same (shared) phase.
            for res in (first, second):
                self.assertEqual(res.optim_phase, phase)
            self.assertFalse(torch.equal(layer.weight.data, weights_before))

        before = layer.weight.data.clone()
        first = Result()
        first.step(objective())
        # Built after the first step; it is never stepped itself.
        second = Result()
        expect(1, before)

        before = layer.weight.data.clone()
        first.step(objective())
        expect(0, before)

        # Explicit optimizer index 0.
        before = layer.weight.data.clone()
        first.step(objective(), 0)
        expect(0, before)
    def _shared_step(self, batch, save_img, is_train):
        """Compute the discriminator and generator losses for one batch.

        The training and evaluation paths share this method. When
        ``is_train`` is true, the discriminator loss and then the generator
        loss are each passed to ``Result.step``. Otherwise the losses are
        only recorded, and ``recon_error`` is added as
        ``self.crit(real, generated)``.

        Args:
            batch: An ``(inputs, real_targets)`` pair.
            save_img: If true, store un-normalized previews of the first
                ``self.n_vis`` inputs, generated images and real targets,
                in that order, in ``res.img``.
            is_train: Selects training behaviour: image-pool fakes and
                optimizer steps.

        Returns:
            A ``Result`` with the following fields:
            - ``disc_loss`` and ``gen_loss``;
            - ``recon_error`` (evaluation only);
            - ``img`` (only when ``save_img`` is true).
        """
        result = Result()

        source, target = batch
        generated = self.gen(source)

        # In training, the discriminator sees fakes routed through the image
        # pool. They are detached, so the discriminator loss does not reach
        # the generator. The pool presumably mixes in past fakes (confirm
        # against image_pool). In evaluation, the current fakes are used
        # directly.
        pooled = (self.image_pool.query(generated.detach())
                  if is_train else generated)
        if not is_train:
            result.recon_error = self.crit(target, generated)

        pred_real = self.disc(target)
        pred_fake = self.disc(pooled)
        label_real = self.real_label.expand_as(pred_real)
        label_fake = self.fake_label.expand_as(pred_fake)

        # Discriminator loss: the mean of the real-image and fake-image terms.
        loss_d = 0.5 * (self.crit(pred_real, label_real)
                        + self.crit(pred_fake, label_fake))
        if is_train:
            result.step(loss_d)
        result.disc_loss = loss_d

        # Generator loss: score the un-pooled, non-detached fakes against the
        # "real" label. The discriminator is re-run here, after its step.
        # This reuses label_real, which assumes disc(generated) has the same
        # shape as disc(target) (TODO confirm).
        loss_g = self.crit(self.disc(generated), label_real)
        if is_train:
            result.step(loss_g)
        result.gen_loss = loss_g

        if save_img:
            count = self.n_vis
            result.img = [self.un_normalize(images[:count])
                          for images in (source, generated, target)]
        return result