コード例 #1
0
    def test_elapsed_time_serialization(self):
        """Elapsed time keeps increasing after a save/load round trip.

        Runs ``self.trainer``, records its ``elapsed_time``, passes it
        together with a fresh mock trainer to
        ``testing.save_and_load_pth``, runs the fresh trainer and asserts
        that its elapsed time is greater than the recorded one.
        """
        finished = self.trainer
        finished.run()
        time_at_save = finished.elapsed_time

        resumed = self._create_mock_trainer(5)
        testing.save_and_load_pth(finished, resumed)
        resumed.run()

        # The resumed run must report more time than was recorded earlier.
        self.assertGreater(resumed.elapsed_time, time_at_save)
コード例 #2
0
    def test_serialize_overwrite_different_names(self):
        """Load into a summary that already holds a differently named entry.

        ``self.summary`` receives two observations for ``'a'`` and ``'b'``;
        the target summary receives one for ``'c'`` before
        ``testing.save_and_load_pth`` is applied.  The target is then handed
        to ``self.check`` with expectations for ``'a'`` and ``'b'`` only.
        """
        source = self.summary
        for observation in ({'a': 3., 'b': 1.}, {'a': 1., 'b': 5.}):
            source.add(observation)

        # Pre-populate the destination with a name the source never saw.
        target = pytorch_trainer.reporter.DictSummary()
        target.add({'c': 5.})
        testing.save_and_load_pth(source, target)

        expected = {
            'a': (3., 1.),
            'b': (1., 5.),
        }
        self.check(target, expected)
コード例 #3
0
    def test_serialize_names_with_slash(self):
        """Round-trip a summary whose entry names contain ``/``.

        Three observations are added to ``self.summary`` under the names
        ``'a/b'``, ``'/a/b'`` and ``'a/b/'``.  After
        ``testing.save_and_load_pth`` a fourth observation is added to the
        target summary, which is then passed to ``self.check`` together with
        all four values per name.
        """
        recorded = (
            {'a/b': 3., '/a/b': 1., 'a/b/': 4.},
            {'a/b': 1., '/a/b': 5., 'a/b/': 9.},
            {'a/b': 2., '/a/b': 6., 'a/b/': 5.},
        )
        for observation in recorded:
            self.summary.add(observation)

        target = pytorch_trainer.reporter.DictSummary()
        testing.save_and_load_pth(self.summary, target)
        # One more observation goes in only after the state was loaded.
        target.add({'a/b': 3., '/a/b': 5., 'a/b/': 8.})

        expected = {
            'a/b': (3., 1., 2., 3.),
            '/a/b': (1., 5., 6., 5.),
            'a/b/': (4., 9., 5., 8.),
        }
        self.check(target, expected)
コード例 #4
0
    def test_serialize(self):
        """Round-trip a summary holding numpy, int and float observations.

        Three observations are added to ``self.summary``; after
        ``testing.save_and_load_pth`` a fourth is added to the target
        summary, which is then passed to ``self.check`` together with all
        four values per name.
        """
        recorded = ((3, 1, 4.), (1, 5, 9.), (2, 6, 5.))
        for array_value, int_value, float_value in recorded:
            self.summary.add({
                'numpy': numpy.array(array_value, 'f'),
                'int': int_value,
                'float': float_value,
            })

        target = pytorch_trainer.reporter.DictSummary()
        testing.save_and_load_pth(self.summary, target)
        # One more observation goes in only after the state was loaded.
        target.add({'numpy': numpy.array(3, 'f'), 'int': 5, 'float': 8.})

        expected = {
            'numpy': (3., 1., 2., 3.),
            'int': (1, 5, 6, 5),
            'float': (4., 9., 5., 8.),
        }
        self.check(target, expected)
コード例 #5
0
    def test_resume(self):
        """A new ``LinearShift`` picks up the value of a finished run.

        ``self.trainer`` is run with ``self.extension``; a second trainer
        carrying a new ``LinearShift`` bound to a mock optimizer is passed
        with it to ``testing.save_and_load_pth``.  After the new extension
        is initialized, the mock optimizer's ``'x'`` must equal the value
        held by ``self.optimizer`` and must be a ``float``.
        """
        # Mock optimizer with a single parameter group whose 'x' is unset.
        fresh_optimizer = mock.Mock(param_groups=[{'x': None}])
        fresh_extension = extensions.LinearShift(
            'x', self.value_range, self.time_range, fresh_optimizer)

        finished = self.trainer
        finished.extend(self.extension)
        finished.run()

        resumed = testing.get_trainer_with_mock_updater((5, 'iteration'))
        resumed.extend(fresh_extension)
        testing.save_and_load_pth(finished, resumed)

        fresh_extension.initialize(resumed)
        restored_value = fresh_optimizer.param_groups[0]['x']
        self.assertEqual(restored_value,
                         self.optimizer.param_groups[0]['x'])
        self.assertIsInstance(restored_value, float)
コード例 #6
0
    def test_resume(self):
        """A new ``InverseShift`` picks up the value of a finished run.

        ``self.trainer`` is run with ``self.extension``; a second trainer
        carrying a new ``InverseShift`` bound to a mock optimizer is passed
        with it to ``testing.save_and_load_pth``.  After the new extension
        is initialized, the mock optimizer's ``'x'`` must equal the value
        held by ``self.optimizer`` and must be a ``float``.
        """
        # Mock optimizer with a single parameter group whose 'x' is unset.
        fresh_optimizer = mock.Mock(param_groups=[{'x': None}])
        fresh_extension = extensions.InverseShift(
            'x', self.gamma, self.power, self.init, self.target,
            fresh_optimizer)

        finished = self.trainer
        finished.extend(self.extension)
        finished.run()

        resumed = testing.get_trainer_with_mock_updater((3, 'iteration'))
        resumed.extend(fresh_extension)
        testing.save_and_load_pth(finished, resumed)

        fresh_extension.initialize(resumed)
        restored_value = fresh_optimizer.param_groups[0]['x']
        self.assertEqual(restored_value,
                         self.optimizer.param_groups[0]['x'])
        self.assertIsInstance(restored_value, float)
コード例 #7
0
    def check_serialize(self, value1, value2, value3):
        """Check summary statistics across a save/load round trip.

        ``value1`` and ``value2`` are added to ``self.summary``; after
        ``testing.save_and_load_pth`` ``value3`` is added to the target
        summary.  The target's mean and standard deviation are compared
        with values computed directly from the three inputs.

        Args:
            value1: First observation; must support ``+``, ``**`` and
                ``.to(dtype=torch.float)``.
            value2: Second observation, same requirements.
            value3: Third observation, added after loading.
        """
        self.summary.add(value1)
        self.summary.add(value2)

        restored = pytorch_trainer.reporter.Summary()
        testing.save_and_load_pth(self.summary, restored)
        restored.add(value3)

        # Reference statistics: mean, and sqrt(E[x^2] - mean^2) for the std.
        total = value1 + value2 + value3
        total_of_squares = value1 ** 2 + value2 ** 2 + value3 ** 2
        expected_mean = total.to(dtype=torch.float) / 3.
        mean_of_squares = total_of_squares.to(dtype=torch.float) / 3.
        expected_std = (mean_of_squares - expected_mean ** 2).sqrt()

        testing.assert_allclose(restored.compute_mean(), expected_mean)

        actual_mean, actual_std = restored.make_statistics()
        testing.assert_allclose(actual_mean, expected_mean)
        testing.assert_allclose(actual_std, expected_std)