# NOTE(review): the next two statements appear to be the tail of a definition
# that begins before this chunk (they read `start_time`, which is not set
# here). Their original nesting is not visible -- confirm the indentation
# against the enclosing function before merging.
logger.log_tabular('Time', time.time() - start_time)
logger.dump_tabular()


if __name__ == '__main__':
    # Command-line entry point: parse the run options, fork the MPI workers,
    # build the logger configuration and start VPG training.
    import argparse

    # (option flags, value type, default) -- registered in this order.
    cli_options = (
        (('--env',), str, 'HalfCheetah-v2'),
        (('--hid',), int, 64),
        (('--l',), int, 2),
        (('--gamma',), float, 0.99),
        (('--seed', '-s'), int, 0),
        (('--cpu',), int, 4),
        (('--steps',), int, 4000),
        (('--epochs',), int, 50),
        (('--exp_name',), str, 'vpg'),
    )

    parser = argparse.ArgumentParser()
    for option_flags, option_type, option_default in cli_options:
        parser.add_argument(*option_flags, type=option_type,
                            default=option_default)
    args = parser.parse_args()

    mpi_fork(args.cpu)  # run parallel code with mpi

    from spinup.utils.run_utils import setup_logger_kwargs
    logger_kwargs = setup_logger_kwargs(args.exp_name, args.seed)

    def make_env():
        """Create a fresh instance of the environment named by --env."""
        return gym.make(args.env)

    # --l hidden layers, each --hid units wide.
    network_kwargs = {'hidden_sizes': [args.hid] * args.l}

    vpg(
        make_env,
        actor_critic=core.MLPActorCritic,
        ac_kwargs=network_kwargs,
        gamma=args.gamma,
        seed=args.seed,
        steps_per_epoch=args.steps,
        epochs=args.epochs,
        logger_kwargs=logger_kwargs,
    )
if __name__ == '__main__':
    # Command-line entry point: extend the parser returned by
    # get_arg_parser() with script-specific options, optionally fork MPI
    # workers, then resolve the experiment name and the saved-model location.
    # NOTE(review): `args.cpu` and `args.seed` are not added here; presumably
    # get_arg_parser() defines them -- confirm.
    parser = get_arg_parser()

    # (option flags, add_argument keyword arguments) -- registered in order.
    extra_options = (
        (('--env',), {'type': str, 'default': 'CartPole-v1'}),
        (('--exp_name',), {'type': str, 'default': None}),
        (('--allow_run_as_root',), {'action': 'store_true'}),
        (('--hidden_size',), {'type': int, 'default': 256}),
        (('--num_hidden',), {'type': int, 'default': 2}),
        (('--continue_training', '-c'), {'action': 'store_true'}),
        (('--saved_model_file', '-f'), {'type': str, 'default': None}),
    )
    for option_flags, option_kwargs in extra_options:
        parser.add_argument(*option_flags, **option_kwargs)
    args = parser.parse_args()

    if args.cpu > 1:
        # run parallel code with mpi
        mpi_tools.mpi_fork(args.cpu,
                           allow_run_as_root=args.allow_run_as_root)

    # Setup experiment name: an explicit --exp_name wins, otherwise fall back
    # to the id of the environment spec.
    env = gym.make(args.env)
    from spinup.utils.run_utils import setup_logger_kwargs
    experiment_name = args.exp_name if args.exp_name else env.spec.id
    logger_kwargs = setup_logger_kwargs(experiment_name, args.seed)

    # Load or create model
    saved_model_file = None
    if args.saved_model_file:
        saved_model_file = pathlib.Path(args.saved_model_file)
    elif args.continue_training:
        # NOTE(review): the visible chunk ends after this statement; this
        # branch may continue beyond it -- confirm before merging.
        save_dir = pathlib.Path(logger_kwargs['output_dir'], 'pyt_save')