示例#1
0
def resume_experiment(ctxt, saved_dir):
    """Reload a saved PyTorch experiment and continue running it.

    A fresh Trainer is created from ``ctxt``, its state is restored from the
    snapshots found in ``saved_dir``, and training is then resumed.

    Args:
        ctxt (garage.experiment.ExperimentContext): Experiment configuration
            handed to the Trainer, which uses it to build its snapshotter.
        saved_dir (str): Directory holding the previously written snapshots
            to restore from.

    """
    runner = Trainer(snapshot_config=ctxt)
    # Restore the trainer's state from the snapshots stored in saved_dir.
    runner.restore(from_dir=saved_dir)
    # Called without overrides; presumably Trainer.resume falls back to the
    # arguments captured in the snapshot -- confirm against Trainer.resume.
    runner.resume()
示例#2
0
def test_dqn_cartpole(setup):
    """Train DQN on CartPole, then restore from the snapshot and resume.

    The ``setup`` fixture supplies ``(algo, env, _, n_epochs, batch_size)``.
    Snapshots are written to a temporary directory that is removed once the
    resume step has finished.

    Args:
        setup (tuple): Fixture providing the algorithm, environment, an
            unused element, the number of epochs and the batch size.

    """
    algo, env, _, n_epochs, batch_size = setup
    # A context manager guarantees the snapshot directory is deleted even if
    # the test fails, instead of relying on the finalizer (ResourceWarning).
    with tempfile.TemporaryDirectory() as snapshot_dir:
        config = SnapshotConfig(snapshot_dir=snapshot_dir,
                                snapshot_mode='last',
                                snapshot_gap=1)
        trainer = Trainer(config)
        try:
            trainer.setup(algo, env, sampler_cls=LocalSampler)
            last_avg_return = trainer.train(n_epochs=n_epochs,
                                            batch_size=batch_size)
            assert last_avg_return > 10
        finally:
            # Close the env even when training or the assertion fails.
            env.close()

        # test resume from snapshot; must run before the temporary directory
        # is removed on leaving the `with` block.
        trainer.restore(snapshot_dir)
        trainer.resume(n_epochs=1, batch_size=batch_size)