def run(args):
    """Train SAC+AE on a DeepMind Control Suite task described by *args*.

    Builds separate training and evaluation environments (same domain/task,
    same action repeat), instantiates the agent, and hands everything to a
    Trainer whose logs land under logs/<domain>-<task>/<algo>-seed<seed>-<stamp>.
    """
    train_env = make_dmc_env(args.domain_name, args.task_name, args.action_repeat)
    eval_env = make_dmc_env(args.domain_name, args.task_name, args.action_repeat)

    agent = SAC_AE(
        num_agent_steps=args.num_agent_steps,
        state_space=train_env.observation_space,
        action_space=train_env.action_space,
        seed=args.seed,
    )

    # Timestamp keeps repeated runs with the same seed from clobbering each other.
    timestamp = datetime.now().strftime("%Y%m%d-%H%M")
    log_path = os.path.join(
        "logs",
        f"{args.domain_name}-{args.task_name}",
        f"{str(agent)}-seed{args.seed}-{timestamp}",
    )

    trainer = Trainer(
        env=train_env,
        env_test=eval_env,
        algo=agent,
        log_dir=log_path,
        num_agent_steps=args.num_agent_steps,
        action_repeat=args.action_repeat,
        eval_interval=args.eval_interval,
        seed=args.seed,
    )
    trainer.train()
def run(args):
    """Train DDPG on the continuous-control environment named by *args.env_id*.

    Creates independent training and evaluation copies of the environment,
    builds the agent, and runs a Trainer that writes logs to
    logs/<env_id>/<algo>-seed<seed>-<stamp>.
    """
    train_env = make_continuous_env(args.env_id)
    eval_env = make_continuous_env(args.env_id)

    agent = DDPG(
        num_agent_steps=args.num_agent_steps,
        state_space=train_env.observation_space,
        action_space=train_env.action_space,
        seed=args.seed,
    )

    # Timestamp keeps repeated runs with the same seed from clobbering each other.
    timestamp = datetime.now().strftime("%Y%m%d-%H%M")
    log_path = os.path.join(
        "logs",
        args.env_id,
        f"{str(agent)}-seed{args.seed}-{timestamp}",
    )

    trainer = Trainer(
        env=train_env,
        env_test=eval_env,
        algo=agent,
        log_dir=log_path,
        num_agent_steps=args.num_agent_steps,
        eval_interval=args.eval_interval,
        seed=args.seed,
    )
    trainer.train()
def run(args):
    """Train discrete SAC on the Atari game named by *args.env_id*.

    The training env clips rewards; the evaluation env disables episodic-life
    termination and reward signing so test scores reflect raw game reward.
    Logs go to logs/<env_id>/<algo>-seed<seed>-<stamp>; the Trainer is run
    with the standard Atari action repeat of 4.
    """
    train_env = make_atari_env(args.env_id, sign_rewards=False, clip_rewards=True)
    eval_env = make_atari_env(args.env_id, episode_life=False, sign_rewards=False)

    agent = SAC_Discrete(
        num_agent_steps=args.num_agent_steps,
        state_space=train_env.observation_space,
        action_space=train_env.action_space,
        seed=args.seed,
    )

    # Timestamp keeps repeated runs with the same seed from clobbering each other.
    timestamp = datetime.now().strftime("%Y%m%d-%H%M")
    log_path = os.path.join(
        "logs",
        args.env_id,
        f"{str(agent)}-seed{args.seed}-{timestamp}",
    )

    trainer = Trainer(
        env=train_env,
        env_test=eval_env,
        algo=agent,
        log_dir=log_path,
        num_agent_steps=args.num_agent_steps,
        action_repeat=4,
        eval_interval=args.eval_interval,
        seed=args.seed,
    )
    trainer.train()