def resume_experiment(ctxt, saved_dir):
    """Reload a saved PyTorch experiment and continue running it.

    Args:
        ctxt (garage.experiment.ExperimentContext): Experiment configuration
            that Trainer uses to construct its snapshotter.
        saved_dir (str): Directory containing the previously saved
            snapshots to restore from.

    """
    runner = Trainer(snapshot_config=ctxt)
    # Load state from disk first, then continue training from that point.
    runner.restore(from_dir=saved_dir)
    runner.resume()
def test_dqn_cartpole(setup):
    """Train DQN, check the average return, then resume from the snapshot.

    Args:
        setup (tuple): Presumably a pytest fixture; unpacked here as
            ``(algo, env, _, n_epochs, batch_size)``.

    """
    # Use the context manager so the snapshot directory is always removed,
    # even when an assertion fails. A bare TemporaryDirectory() is only
    # cleaned up when its finalizer happens to run, and then emits a
    # ResourceWarning.
    with tempfile.TemporaryDirectory() as snapshot_dir:
        config = SnapshotConfig(snapshot_dir=snapshot_dir,
                                snapshot_mode='last',
                                snapshot_gap=1)
        trainer = Trainer(config)
        algo, env, _, n_epochs, batch_size = setup
        trainer.setup(algo, env, sampler_cls=LocalSampler)
        try:
            last_avg_return = trainer.train(n_epochs=n_epochs,
                                            batch_size=batch_size)
            assert last_avg_return > 10
        finally:
            # Close the env even if training or the assertion fails.
            env.close()

        # Test resume from snapshot. This must run inside the `with` block,
        # before the snapshot directory is deleted.
        trainer.restore(snapshot_dir)
        trainer.resume(n_epochs=1, batch_size=batch_size)