Esempio n. 1
0
    def data_loop(self, loader, phase, step_method):
        """Run ``step_method`` on every batch of ``loader`` and collect results.

        Every ``nn.Module`` stored as an attribute of this object is first
        switched into ``phase`` by calling the method of that name on it.
        Each batch is then moved to ``self.device`` and passed to
        ``step_method``.

        Args:
            loader: Sized iterable of batches (``len(loader)`` is used). A
                batch must be a ``torch.Tensor``, or a list, tuple or dict
                whose items all provide ``.to(device)``.
            phase: Name of the zero-argument method to call on each
                submodule; also used as the progress-bar description.
            step_method: Callable invoked as ``step_method(batch, batch_idx)``.

        Returns:
            The value returned by ``Result.collect`` for the list of
            per-batch step results.

        Raises:
            ValueError: If a batch is not a tensor, list, tuple or dict.
        """
        for submodule in vars(self).values():
            if isinstance(submodule, nn.Module):
                getattr(submodule, phase)()

        all_results = [None] * len(loader)
        for batch_idx, batch in enumerate(tqdm(loader, ascii=True,
                                               desc=phase)):
            if isinstance(batch, torch.Tensor):
                batch = batch.to(self.device)

            elif isinstance(batch, list):
                for i, item in enumerate(batch):
                    batch[i] = item.to(self.device)

            elif isinstance(batch, tuple):
                # Tuples are immutable, so item assignment would raise
                # TypeError; build a new tuple instead. Tuple subclasses
                # come back as a plain tuple.
                batch = tuple(item.to(self.device) for item in batch)

            elif isinstance(batch, dict):
                # Only existing keys are reassigned, so the dict size does
                # not change while it is being iterated.
                for key, val in batch.items():
                    batch[key] = val.to(self.device)

            else:
                raise ValueError(
                    'Compatible Datasets have output elements of '
                    'type: torch.Tensor, list, tuple, or dict. All '
                    'collection compatible types must have '
                    'torch.Tensor items.')

            step_result = step_method(batch, batch_idx)
            # Reset the shared Result phase after every batch so the next
            # step starts from the initial phase.
            Result.reset_phase()

            all_results[batch_idx] = step_result

        collected_results = Result.collect(all_results)
        return collected_results
Esempio n. 2
0
    def fit(self, training_dataset, validation_dataset):
        """Train for ``self.n_epochs`` epochs, validating after each one.

        Builds the training (shuffled) and validation (unshuffled) loaders,
        registers the configured optimizers on ``Result``, then alternates a
        training pass and a gradient-free validation pass per epoch. The
        state with the lowest value returned by ``validation_epoch_end`` is
        remembered and stored in ``self.best_state`` once training finishes.

        Args:
            training_dataset: Dataset wrapped in the training ``DataLoader``.
            validation_dataset: Dataset wrapped in the validation
                ``DataLoader``.

        Warns:
            UserWarning: On every epoch in which ``validation_epoch_end``
                returns ``None``; no best state can be tracked then.
        """
        self.training_loader = DataLoader(training_dataset,
                                          batch_size=self.batch_size,
                                          shuffle=True,
                                          num_workers=self.n_cpus,
                                          pin_memory=(self.use_gpu
                                                      and self.pin_memory))

        self.validation_loader = DataLoader(validation_dataset,
                                            batch_size=self.batch_size,
                                            shuffle=False,
                                            num_workers=self.n_cpus,
                                            pin_memory=(self.use_gpu
                                                        and self.pin_memory))

        optimizers = self.configure_optimizers()
        self.optim_list, self.sched_list = dfs_detree_optimizer_list(
            optimizers)

        Result.optim_list = self.optim_list
        Result.reset_phase()

        # Start from the largest float so any real validation value wins.
        self.best_validation = sys.float_info.max
        self.best_state = None
        leading_state = None

        print(self)

        self.on_fit_start()
        self.to_device()
        # The loop variable is an attribute so the current epoch is visible
        # on the instance while the epoch runs.
        for self.current_epoch in range(self.n_epochs):
            train_outputs = self.data_loop(self.training_loader, 'train',
                                           self.training_step)
            self.training_epoch_end(train_outputs)

            for scheduler in self.sched_list:
                scheduler.step()

            with torch.no_grad():
                valid_outputs = self.data_loop(self.validation_loader, 'eval',
                                               self.validation_step)
                computed_valid = self.validation_epoch_end(valid_outputs)

                # Identity test: `== None` would dispatch to the tensor's
                # own __eq__ rather than checking for a missing value.
                if computed_valid is None:
                    warnings.warn('\n If no aggregate loss is returned by '
                                  'validation_epoch_end, then Trainer can\'t '
                                  'keep track of best validation state.')

                elif computed_valid.item() < self.best_validation:
                    # NOTE(review): best_validation starts as a float but
                    # holds the returned object from here on — confirm
                    # whether a plain float is preferred.
                    self.best_validation = computed_valid
                    # NOTE(review): if get_state_dicts returns live
                    # references rather than copies, this snapshot will
                    # follow later training — confirm.
                    leading_state = self.get_state_dicts()

            self.save_state()

        self.best_state = leading_state
        self.save_state()
        self.on_fit_end()
    def test_step(self):
        """Check that ``Result.step`` changes weights and updates ``optim_phase``.

        Two Adam optimizers over the same linear layer are registered on
        ``Result.optim_list``. After each ``step`` call the layer weights
        must differ from the snapshot taken beforehand, and ``optim_phase``
        must read the same on the instance that stepped and on an instance
        created afterwards.
        """
        layer = torch.nn.Linear(5, 5)
        ones = torch.ones(5)
        first_opt = torch.optim.Adam(layer.parameters(), lr=0.1)
        second_opt = torch.optim.Adam(layer.parameters(), lr=0.1)
        Result.optim_list = [first_opt, second_opt]

        def mean_square():
            # Fresh forward pass so each step gets its own graph.
            out = layer(ones)
            return torch.mean(out * out)

        def check(expected_phase, weights_before):
            self.assertEqual(stepper.optim_phase, expected_phase)
            self.assertEqual(bystander.optim_phase, expected_phase)
            self.assertFalse(torch.equal(layer.weight.data, weights_before))

        # First step: phase is expected to read 1 afterwards.
        snapshot = layer.weight.data.clone()
        stepper = Result()
        stepper.step(mean_square())
        bystander = Result()
        check(1, snapshot)

        # Second step: phase is expected to read 0 again.
        snapshot = layer.weight.data.clone()
        stepper.step(mean_square())
        check(0, snapshot)

        # Step with an explicit second argument of 0: phase stays at 0.
        snapshot = layer.weight.data.clone()
        stepper.step(mean_square(), 0)
        check(0, snapshot)
Esempio n. 4
0
    def training_step(self, batch, batch_idx):
        """Run one training step on an ``(x, y)`` batch.

        The input goes through ``self.model1``, is offset by ``self.const``
        and goes through ``self.model2``; the criterion compares that output
        with the target. The loss is passed to ``Result.step`` and a
        two-tensor list is attached as ``abc``.

        Args:
            batch: Two-element iterable unpacked as ``(input, target)``.
            batch_idx: Index of the batch (unused here).

        Returns:
            The ``Result`` carrying the step and the ``abc`` attribute.
        """
        inputs, targets = batch
        shifted = self.model1(inputs) + self.const
        prediction = self.model2(shifted)

        batch_loss = self.crit(prediction, targets)

        outcome = Result()
        outcome.step(batch_loss)
        outcome.abc = [torch.ones(2), torch.zeros(2)]
        return outcome
    def test_setattr(self):
        """Check how ``Result`` stores values assigned as attributes.

        Asserts that an assigned tensor is readable at index 0 of the
        attribute, that an assigned tuple reads back as a one-element list
        holding that tuple, and that CUDA tensors assigned in a tuple are no
        longer on the GPU when read back. Requires a CUDA device.
        """
        result = Result()

        cube = torch.ones(5, 5, 5)
        result.abc = cube
        stored_first = result.abc[0]
        self.assertTrue(torch.equal(stored_first, torch.ones(5, 5, 5)))

        scalar = torch.tensor(3.14)
        result.abc = (scalar, scalar)
        self.assertTrue(iterable_torch_eq(result.abc, [(scalar, scalar)]))

        result.abc = (torch.ones(5).cuda(), torch.ones(5).cuda())
        second_of_pair = result.abc[0][1]
        self.assertFalse(second_of_pair.is_cuda)
Esempio n. 6
0
    def test(self, testing_dataset, chckpt_suffix=None):
        """Load saved state and run the testing pass over ``testing_dataset``.

        Args:
            testing_dataset: Dataset wrapped in an unshuffled ``DataLoader``.
            chckpt_suffix: Passed straight to ``self.load_state``.
        """
        self.load_state(chckpt_suffix)

        loader_options = dict(
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.n_cpus,
            pin_memory=(self.use_gpu and self.pin_memory),
        )
        self.testing_loader = DataLoader(testing_dataset, **loader_options)
        Result.reset_phase()

        self.on_test_start()
        self.to_device()
        # The whole pass runs with gradients disabled.
        with torch.no_grad():
            outputs = self.data_loop(
                self.testing_loader, 'eval', self.testing_step)
            self.testing_epoch_end(outputs)

        self.on_test_end()
    def _shared_step(self, batch, save_img, is_train):
        """Compute discriminator and generator losses for one batch.

        When training, the discriminator is shown fakes obtained from
        ``self.image_pool.query`` on the detached generator output, and each
        loss is passed to ``Result.step`` before being recorded. Otherwise
        the discriminator sees the generator output directly and a
        reconstruction error against the real target is recorded.

        Args:
            batch: Two-element iterable unpacked as ``(X, Y_real)``.
            save_img: If true, attach un-normalized slices of the input,
                the generated output and the real target as ``img``.
            is_train: Selects the training behaviour described above.

        Returns:
            The ``Result`` holding ``disc_loss``, ``gen_loss`` and,
            depending on the flags, ``recon_error`` and ``img``.
        """
        res = Result()

        X, Y_real = batch
        Y_fake = self.gen(X)

        if not is_train:
            Y_pool = Y_fake
            res.recon_error = self.crit(Y_real, Y_fake)
        else:
            Y_pool = self.image_pool.query(Y_fake.detach())

        real_predict = self.disc(Y_real)
        fake_predict = self.disc(Y_pool)

        real_label = self.real_label.expand_as(real_predict)
        fake_label = self.fake_label.expand_as(fake_predict)

        real_term = self.crit(real_predict, real_label)
        fake_term = self.crit(fake_predict, fake_label)
        disc_loss = 0.5 * (real_term + fake_term)

        # The discriminator step runs before the generator forward pass
        # below; keep this ordering.
        if is_train:
            res.step(disc_loss)
        res.disc_loss = disc_loss

        gen_loss = self.crit(self.disc(Y_fake), real_label)

        if is_train:
            res.step(gen_loss)
        res.gen_loss = gen_loss

        if save_img:
            count = self.n_vis
            res.img = [
                self.un_normalize(X[:count]),
                self.un_normalize(Y_fake[:count]),
                self.un_normalize(Y_real[:count]),
            ]
        return res
    def test_collect(self):
        """Check ``Result.collect`` on lists of results.

        Three results with partly different attribute names are collected
        and each collected attribute is compared with the expected list of
        per-result values. A list containing ``None`` is expected to collect
        to ``None``.
        """
        first = Result()
        first.a = torch.tensor(5.)
        first.b = [torch.zeros(3, 3)]
        first.c = [torch.tensor(5.), torch.tensor(4.)]

        second = Result()
        second.a = torch.tensor(5.)
        second.b = torch.zeros(3, 3)
        second.c = [torch.tensor(5.), torch.tensor(4.)]

        third = Result()
        third.f = torch.zeros(5)
        third.ab = [torch.ones(4)]

        merged = Result.collect([first, second, third])

        expected_a = [torch.tensor(5.), torch.tensor(5.)]
        self.assertTrue(iterable_torch_eq(merged.a, expected_a))

        expected_b = [[torch.zeros(3, 3)], torch.zeros(3, 3)]
        self.assertTrue(iterable_torch_eq(merged.b, expected_b))

        expected_c = [[torch.tensor(5.), torch.tensor(4.)],
                      [torch.tensor(5.), torch.tensor(4.)]]
        self.assertTrue(iterable_torch_eq(merged.c, expected_c))

        self.assertTrue(iterable_torch_eq(merged.f, [torch.zeros(5)]))
        self.assertTrue(iterable_torch_eq(merged.ab, [[torch.ones(4)]]))

        # A None entry among the inputs makes the whole collection None.
        merged = Result.collect([first, None, second])
        self.assertEqual(merged, None)
Esempio n. 9
0
 def validation_step(self, batch, batch_idx):
     """Return a ``Result`` whose ``loss`` is a fixed squared ones-vector.

     The batch and its index are ignored; the loss is built from a
     five-element tensor of ones that requires grad, multiplied by itself.
     """
     ones = torch.ones(5, requires_grad=True)
     squared = ones * ones
     outcome = Result()
     outcome.loss = squared
     return outcome