Beispiel #1
0
def get_model(model_checkpoint_path):
    """Rebuild a ResNet18 with a replaced first convolution and restore weights.

    Args:
        model_checkpoint_path: Location of a checkpoint readable by
            ``Trainer.load_checkpoint_from_path``. The mapping it returns
            must contain a ``"model_state_dict"`` entry.

    Returns:
        The ``ResNet18`` instance after ``load_state_dict`` has been applied
        with the checkpoint's saved state.
    """
    # Read the checkpoint and pull out the model weights first, so a missing
    # file or key fails before any model is constructed.
    saved_weights = Trainer.load_checkpoint_from_path(model_checkpoint_path)[
        "model_state_dict"
    ]

    net = ResNet18(None)
    # Swap the stem convolution before restoring weights. Assuming ``nn`` is
    # ``torch.nn``, this layer takes 1 input channel and emits 64 channels.
    # The checkpoint was presumably saved from a model with this same conv1
    # replacement (TODO confirm).
    stem_conv = nn.Conv2d(1, 64, kernel_size=7, stride=1, padding=3, bias=False)
    net.conv1 = stem_conv
    net.load_state_dict(saved_weights)
    return net
Beispiel #2
0
def test_load_checkpoint_from_path(ray_start_2_cpus, tmpdir):
    """Round-trip the best checkpoint through ``Trainer.load_checkpoint_from_path``.

    The training function reports two checkpoints, ``loss=3`` and then
    ``loss=7``. Ranking them by ``"loss"`` with order ``"min"`` should make
    the ``loss=3`` checkpoint the best one. Reading that checkpoint back from
    ``trainer.best_checkpoint_path`` should give the same contents as
    ``trainer.best_checkpoint``.
    """
    # ``ray_start_2_cpus`` is a fixture defined elsewhere (presumably it starts
    # a 2-CPU Ray instance). ``tmpdir`` is pytest's per-test temporary
    # directory and is used as the trainer's log directory.
    test_config = TestConfig()
    strategy = CheckpointStrategy(
        checkpoint_score_attribute="loss", checkpoint_score_order="min"
    )

    def report_two_checkpoints():
        for loss_value in (3, 7):
            train.save_checkpoint(loss=loss_value)

    trainer = Trainer(test_config, num_workers=2, logdir=tmpdir)
    trainer.start()
    trainer.run(report_two_checkpoints, checkpoint_strategy=strategy)

    best = trainer.best_checkpoint
    assert best["loss"] == 3

    reloaded = Trainer.load_checkpoint_from_path(trainer.best_checkpoint_path)
    assert reloaded == trainer.best_checkpoint