コード例 #1
0
ファイル: test_torch_runner.py プロジェクト: zqxyz73/ray
    def testValidate(self):
        """Check that validation runs only when explicitly requested.

        Three training epochs must not invoke the operator's ``validate``;
        an explicit ``runner.validate()`` call must, and the runner's epoch
        counter must still reflect only the three training epochs.
        """
        class MockOperator(TrainingOperator):
            def setup(self, config):
                # ``return_value`` (not ``returns``) is the MagicMock keyword
                # that sets what the mock yields when called; ``returns=``
                # only creates an unused attribute, so the mocks would hand
                # back MagicMock objects instead of these stats dicts.
                self.train_epoch = MagicMock(
                    return_value=dict(mean_accuracy=10))
                self.validate = MagicMock(return_value=dict(mean_accuracy=10))

        runner = TorchRunner(model_creator,
                             create_dataloaders,
                             optimizer_creator,
                             loss_creator,
                             training_operator_cls=MockOperator)
        runner.setup()
        runner.train_epoch()
        runner.train_epoch()
        runner.train_epoch()
        # Each runner epoch delegates to the operator's train_epoch once.
        self.assertEqual(runner.training_operator.train_epoch.call_count, 3)
        # Training alone must never trigger validation.
        self.assertEqual(runner.training_operator.validate.call_count, 0)
        runner.validate()
        self.assertTrue(runner.training_operator.validate.called)
        # Validation must not advance the epoch counter.
        self.assertEqual(runner.stats()["epoch"], 3)
コード例 #2
0
ファイル: test_torch_runner.py プロジェクト: zqxyz73/ray
    def testtrain_epoch(self):
        """Verify the runner forwards every epoch to the training operator.

        The operator counts its own invocations. After three runner epochs
        (two with ``num_steps=1``, one without), three values must all equal
        three: the operator's count, the stats returned by the last epoch,
        and the runner's epoch counter.
        """
        class CountingOperator(TrainingOperator):
            def setup(self, config):
                self.count = 0

            def train_epoch(self, *args, **kwargs):
                # Report the running invocation total as this epoch's stats.
                self.count = self.count + 1
                return dict(count=self.count)

        creators = (model_creator, create_dataloaders, optimizer_creator,
                    loss_creator)
        runner = TorchRunner(*creators,
                             training_operator_cls=CountingOperator)
        runner.setup()
        for _ in range(2):
            runner.train_epoch(num_steps=1)
        last_stats = runner.train_epoch()

        expected = 3
        self.assertEqual(runner.training_operator.count, expected)
        self.assertEqual(last_stats["count"], expected)
        self.assertEqual(runner.stats()["epoch"], expected)