Пример #1
0
    def test_save_sup_load_rl(self):
        """Load a checkpoint saved from a core-SPINN trainer into an RL one.

        A ``ModelTrainer`` wrapping a ``MockModel`` of
        ``spinn.spinn_core_model.BaseModel`` saves a checkpoint to a
        temporary file. A second ``ModelTrainer`` wrapping a ``MockModel``
        of ``spinn.rl_spinn.BaseModel`` loads it with ``cpu=True``. The two
        models are then checked with ``compare_models``.
        """
        model_to_save = MockModel(spinn.spinn_core_model.BaseModel,
                                  default_args())

        # Parse command line flags.
        # NOTE(review): this parses the runner's sys.argv and the body below
        # mutates the global FLAGS object (ckpt_path); confirm that is safe
        # for other tests sharing FLAGS.
        get_flags()
        FLAGS(sys.argv)

        # The context managers close (and therefore delete) both temporary
        # files even if saving, loading or the comparison raises.
        with tempfile.NamedTemporaryFile() as log_temp, \
                tempfile.NamedTemporaryFile() as ckpt_temp:
            logger = afs_safe_logger.ProtoLogger(log_temp.name)
            FLAGS.ckpt_path = '.'

            trainer_to_save = ModelTrainer(model_to_save, logger, FLAGS)

            model_to_load = MockModel(spinn.rl_spinn.BaseModel,
                                      default_args())
            trainer_to_load = ModelTrainer(model_to_load, logger, FLAGS)

            # Save to and load from the temporary checkpoint file.
            trainer_to_save.save(ckpt_temp.name)
            trainer_to_load.load(ckpt_temp.name, cpu=True)

            compare_models(model_to_save, model_to_load)
Пример #2
0
    def test_save_load_model(self):
        """Round-trip a ``BaseModel`` mock's state_dict through a temp file.

        The state_dict of one freshly built ``MockModel(BaseModel, ...)`` is
        written with ``torch.save`` and loaded into a second instance with
        ``load_state_dict``. The two models are then checked with
        ``compare_models``.
        """
        model_to_save = MockModel(BaseModel, default_args())
        model_to_load = MockModel(BaseModel, default_args())

        # Save to and load from a temporary file. The ``with`` block closes
        # (and thereby deletes) the file even if any step below raises.
        with tempfile.NamedTemporaryFile() as temp:
            torch.save(model_to_save.state_dict(), temp.name)
            model_to_load.load_state_dict(torch.load(temp.name))

            compare_models(model_to_save, model_to_load)
Пример #3
0
    def test_save_load_model(self):
        """Round-trip a scalar ``MockModel``'s state_dict through a temp file.

        The model is saved with ``scalar`` and loaded into a model built
        with a different ``other_scalar``. After loading, the models are
        checked with ``compare_models``. The test also asserts that the
        loaded model's ``scalar[0]`` equals the saved value.
        """
        scalar = 11
        other_scalar = 0
        model_to_save = MockModel(scalar)
        model_to_load = MockModel(other_scalar)

        # Save to and load from a temporary file. The ``with`` block closes
        # (and thereby deletes) the file even if any step below raises.
        with tempfile.NamedTemporaryFile() as temp:
            torch.save(model_to_save.state_dict(), temp.name)
            model_to_load.load_state_dict(torch.load(temp.name))

        compare_models(model_to_save, model_to_load)

        # Check value of scalars: compare against the ``scalar`` local rather
        # than a duplicated literal, so the two cannot drift apart.
        assert model_to_save.scalar[0] == scalar
        assert model_to_save.scalar[0] == model_to_load.scalar[0]
Пример #4
0
    def test_custom_init(self):
        """Round-trip a module containing a ``CustomLinear`` layer.

        Two instances of a small ``nn.Module`` holding a
        ``CustomLinear(10, 10)`` are built. The first one's state_dict is
        saved with ``torch.save`` and loaded into the second. The two are
        then checked with ``compare_models``.
        """

        # Concrete class that uses custom init.
        class MyModel(nn.Module):
            def __init__(self):
                super(MyModel, self).__init__()
                self.l = CustomLinear(10, 10)

        model_to_save = MyModel()
        model_to_load = MyModel()

        # Save to and load from a temporary file. The ``with`` block closes
        # (and thereby deletes) the file even if any step below raises.
        with tempfile.NamedTemporaryFile() as temp:
            torch.save(model_to_save.state_dict(), temp.name)
            model_to_load.load_state_dict(torch.load(temp.name))

            compare_models(model_to_save, model_to_load)
Пример #5
0
    def test_save_sup_load_rl(self):
        """Save and reload a ``Pyramid`` model through ``ModelTrainer``.

        Two ``MockModel(spinn.pyramid.Pyramid, ...)`` instances are each
        paired with an SGD optimizer (lr=0.1) inside a ``ModelTrainer``. The
        first trainer saves to a temporary file via ``save(path, 0, 0)``, and
        the second restores from it via ``load(path)``. The models are then
        checked with ``compare_models``.
        """
        model_to_save = MockModel(spinn.pyramid.Pyramid, default_args())
        opt_to_save = optim.SGD(model_to_save.parameters(), lr=0.1)
        trainer_to_save = ModelTrainer(model_to_save, opt_to_save)

        model_to_load = MockModel(spinn.pyramid.Pyramid, default_args())
        opt_to_load = optim.SGD(model_to_load.parameters(), lr=0.1)
        trainer_to_load = ModelTrainer(model_to_load, opt_to_load)

        # Save to and load from a temporary file. The ``with`` block closes
        # (and thereby deletes) the file even if any step below raises.
        with tempfile.NamedTemporaryFile() as temp:
            trainer_to_save.save(temp.name, 0, 0)
            trainer_to_load.load(temp.name)

            compare_models(model_to_save, model_to_load)
Пример #6
0
    def test_save_sup_load_rl(self):
        """Load a checkpoint saved from a fat-stack trainer into an RL one.

        A ``ModelTrainer`` wrapping ``MockModel(spinn.fat_stack.BaseModel,
        ...)`` with an SGD optimizer (lr=0.1) saves to a temporary file via
        ``save(path, 0, 0)``. A trainer wrapping
        ``MockModel(spinn.rl_spinn.BaseModel, ...)`` then loads that file.
        The two models are checked with ``compare_models``.
        """
        model_to_save = MockModel(spinn.fat_stack.BaseModel, default_args())
        opt_to_save = optim.SGD(model_to_save.parameters(), lr=0.1)
        trainer_to_save = ModelTrainer(model_to_save, opt_to_save)

        model_to_load = MockModel(spinn.rl_spinn.BaseModel, default_args())
        opt_to_load = optim.SGD(model_to_load.parameters(), lr=0.1)
        trainer_to_load = ModelTrainer(model_to_load, opt_to_load)

        # Save to and load from a temporary file. The ``with`` block closes
        # (and thereby deletes) the file even if any step below raises.
        with tempfile.NamedTemporaryFile() as temp:
            trainer_to_save.save(temp.name, 0, 0)
            trainer_to_load.load(temp.name)

            compare_models(model_to_save, model_to_load)