コード例 #1
0
ファイル: test_pytorch_runner.py プロジェクト: wwxFromTju/ray
 def testSingleLoader(self):
     """Validation must fail when only a single data loader is supplied.

     ``single_loader`` (defined elsewhere in the test module) presumably
     yields just a training loader. Setup and one training step should
     still succeed, but ``validate()`` is expected to raise ``ValueError``.
     """
     train_only_runner = PyTorchRunner(
         model_creator, single_loader, optimizer_creator, loss_creator)
     train_only_runner.setup()
     train_only_runner.step()
     with self.assertRaises(ValueError):
         train_only_runner.validate()
コード例 #2
0
ファイル: test_pytorch_runner.py プロジェクト: wwxFromTju/ray
 def testStep(self):
     """Each ``step()`` calls the custom train function once and advances the epoch.

     After three steps the mocked ``train_function`` has been called three
     times, and both the last step's result and ``stats()`` report epoch 3.
     """
     train_fn = MagicMock(return_value=dict(mean_accuracy=10))
     runner = PyTorchRunner(model_creator,
                            create_dataloaders,
                            optimizer_creator,
                            loss_creator,
                            train_function=train_fn)
     runner.setup()

     num_steps = 3
     result = None
     for _ in range(num_steps):
         # Only the final step's result is inspected below.
         result = runner.step()

     self.assertEqual(train_fn.call_count, num_steps)
     self.assertEqual(result["epoch"], num_steps)
     self.assertEqual(runner.stats()["epoch"], num_steps)
コード例 #3
0
 def testNativeLoss(self):
     """A torch loss class (``nn.MSELoss``) can be passed directly as ``loss_creator``.

     Smoke test: setup and a single training step must complete without
     raising; no result is inspected.
     """
     mse_runner = PyTorchRunner(model_creator, single_loader,
                                optimizer_creator, loss_creator=nn.MSELoss)
     mse_runner.setup()
     mse_runner.step()
コード例 #4
0
    def testMultiLoaders(self):
        """``setup()`` rejects a data creator that returns three datasets.

        The runner is expected to raise ``ValueError`` during ``setup()``
        when the data creator returns more items than it supports
        (presumably a train/validation pair at most).
        """
        def three_data_loader(config):
            # One dataset with LinearDataset's default size, plus two of size 400.
            return (LinearDataset(2, 5), LinearDataset(2, 5, size=400),
                    LinearDataset(2, 5, size=400))

        # NOTE(review): both runners are built with identical arguments, so the
        # second check repeats the first; confirm whether a different
        # configuration was intended for it.
        for _ in range(2):
            runner = PyTorchRunner(model_creator, three_data_loader,
                                   optimizer_creator, loss_creator)
            with self.assertRaises(ValueError):
                runner.setup()
コード例 #5
0
ファイル: test_pytorch_runner.py プロジェクト: wwxFromTju/ray
    def testMultiLoaders(self):
        """``setup()`` rejects a data creator that returns three DataLoaders.

        The runner is expected to raise ``ValueError`` during ``setup()``
        when the creator returns a third loader in addition to (presumably)
        the train and validation loaders it supports.
        """
        def three_data_loader(batch_size, config):
            train_ds = LinearDataset(2, 5)
            val_ds = LinearDataset(2, 5, size=400)
            train_dl = torch.utils.data.DataLoader(train_ds)
            val_dl = torch.utils.data.DataLoader(val_ds)
            # The validation loader appears twice so the result has three items.
            return train_dl, val_dl, val_dl

        # NOTE(review): both runners are built with identical arguments, so the
        # second check repeats the first; confirm whether a different
        # configuration was intended for it.
        for _ in range(2):
            runner = PyTorchRunner(model_creator, three_data_loader,
                                   optimizer_creator, loss_creator)
            with self.assertRaises(ValueError):
                runner.setup()
コード例 #6
0
ファイル: test_pytorch_runner.py プロジェクト: wwxFromTju/ray
    def testMultiModel(self):
        """Three models/optimizers set up successfully but cannot be stepped.

        ``setup()`` must succeed, while ``step()`` is expected to raise
        ``ValueError`` (presumably because the default training routine
        handles only a single model/optimizer pair).
        """
        def multi_model_creator(config):
            return tuple(nn.Linear(1, 1) for _ in range(3))

        def multi_optimizer_creator(models, config):
            optimizers = [
                torch.optim.SGD(net.parameters(), lr=0.1) for net in models
            ]
            return optimizers[0], optimizers[1], optimizers[2]

        multi_runner = PyTorchRunner(multi_model_creator, single_loader,
                                     multi_optimizer_creator, loss_creator)
        multi_runner.setup()
        with self.assertRaises(ValueError):
            multi_runner.step()
コード例 #7
0
ファイル: test_pytorch_runner.py プロジェクト: wwxFromTju/ray
    def testGivens(self):
        """Check the runner's ``given_models`` / ``given_optimizers`` attributes.

        With three models and three optimizers, both ``given_*`` attributes
        have length 3. With a single model and optimizer, they differ from
        ``models`` / ``optimizers`` (presumably the runner stores the single
        objects as given and wraps them in containers internally).
        """
        def three_model_creator(config):
            return tuple(nn.Linear(1, 1) for _ in range(3))

        def three_optimizer_creator(models, config):
            optimizers = [
                torch.optim.SGD(net.parameters(), lr=0.1) for net in models
            ]
            return optimizers[0], optimizers[1], optimizers[2]

        multi = PyTorchRunner(three_model_creator, single_loader,
                              three_optimizer_creator, loss_creator)
        multi.setup()
        self.assertEqual(len(multi.given_models), 3)
        self.assertEqual(len(multi.given_optimizers), 3)

        single = PyTorchRunner(model_creator, single_loader,
                               optimizer_creator, loss_creator)
        single.setup()
        # NOTE(review): assertNotEqual only proves inequality; it would also
        # pass for unrelated objects. Consider asserting the expected shapes.
        self.assertNotEqual(single.given_models, single.models)
        self.assertNotEqual(single.given_optimizers, single.optimizers)
コード例 #8
0
ファイル: test_pytorch_runner.py プロジェクト: wwxFromTju/ray
 def testValidate(self):
     """``validate()`` invokes the custom validation function; ``step()`` does not.

     Three training steps leave the mocked ``validation_function`` uncalled.
     A subsequent ``validate()`` calls it, and ``stats()`` still reports
     epoch 3.
     """
     # ``return_value`` is the MagicMock keyword that sets what a call returns.
     # An unrecognised keyword such as ``returns=`` would only set an attribute
     # on the mock, and calls would return a fresh MagicMock instead of the
     # metrics dict (mirrors the setup in testStep).
     mock_function = MagicMock(return_value=dict(mean_accuracy=10))
     runner = PyTorchRunner(model_creator,
                            create_dataloaders,
                            optimizer_creator,
                            loss_creator,
                            validation_function=mock_function)
     runner.setup()
     for _ in range(3):
         runner.step()
     # Training steps must not trigger validation.
     self.assertEqual(mock_function.call_count, 0)
     runner.validate()
     self.assertTrue(mock_function.called)
     self.assertEqual(runner.stats()["epoch"], 3)