def debug(args, model_fn, act_update_fns, multi_thread):
    # single-process debug run: build the models and run one worker inline
    create_if_need(args.logdir)
    env = create_env(args)
    if args.flip_state_action and hasattr(env, "state_transform"):
        args.flip_states = env.state_transform.flip_states
    args.n_action = env.action_space.shape[0]
    args.n_observation = env.observation_space.shape[0]
    args.actor_layers = str2params(args.actor_layers)
    args.critic_layers = str2params(args.critic_layers)
    args.actor_activation = activations[args.actor_activation]
    args.critic_activation = activations[args.critic_activation]

    actor, critic = model_fn(args)
    if args.restore_actor_from is not None:
        actor.load_state_dict(torch.load(args.restore_actor_from))
    if args.restore_critic_from is not None:
        critic.load_state_dict(torch.load(args.restore_critic_from))
    actor.train()
    critic.train()
    actor.share_memory()
    critic.share_memory()

    target_actor = copy.deepcopy(actor)
    target_critic = copy.deepcopy(critic)
    hard_update(target_actor, actor)
    hard_update(target_critic, critic)
    target_actor.train()
    target_critic.train()
    target_actor.share_memory()
    target_critic.share_memory()

    _, _, save_fn = act_update_fns(actor, critic, target_actor, target_critic, args)

    args.thread = 0
    best_reward = Value("f", 0.0)
    multi_thread(actor, critic, target_actor, target_critic, args, act_update_fns, best_reward)
    save_fn()
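
# For reference: hard_update (used above and in train below) follows the usual
# DDPG convention of copying source parameters into the target network. This is
# a minimal sketch under that assumption; the real helper is defined elsewhere
# in the repo.
def _hard_update_sketch(target, source):
    for target_param, source_param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(source_param.data)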

def submit_or_test(args, model_fn, act_update_fn, submit_fn, test_fn):
    args = restore_args(args)
    env = create_env(args)
    args.n_action = env.action_space.shape[0]
    args.n_observation = env.observation_space.shape[0]
    args.actor_layers = str2params(args.actor_layers)
    args.critic_layers = str2params(args.critic_layers)
    args.actor_activation = activations[args.actor_activation]
    args.critic_activation = activations[args.critic_activation]

    actor, critic = model_fn(args)
    actor.load_state_dict(torch.load(args.restore_actor_from))
    critic.load_state_dict(torch.load(args.restore_critic_from))

    if args.submit:
        submit_fn(actor, critic, args, act_update_fn)
    else:
        test_fn(actor, critic, args, act_update_fn)
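
# For reference: str2params (used above) presumably parses a layer-size string
# such as "64-64" into [64, 64]. A minimal sketch under that assumption -- the
# delimiter and the real implementation may differ:
def _str2params_sketch(params_str, delimiter="-"):
    return [int(x) for x in params_str.split(delimiter)]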

def play_single_thread(
        actor, critic, target_actor, target_critic, args,
        prepare_fn, global_episode, global_update_step,
        episodes_queue, best_reward):
    # play-only worker: generates episodes and feeds them to the trainers
    # through episodes_queue; never calls update_fn itself
    workerseed = args.seed + 241 * args.thread
    set_global_seeds(workerseed)
    args.logdir = "{}/thread_{}".format(args.logdir, args.thread)
    create_if_need(args.logdir)
    act_fn, _, save_fn = prepare_fn(actor, critic, target_actor, target_critic, args)
    logger = Logger(args.logdir)

    env = create_env(args)
    random_process = create_random_process(args)
    epsilon_cycle_len = random.randint(args.epsilon_cycle_len // 2, args.epsilon_cycle_len * 2)
    epsilon_decay_fn = create_decay_fn(
        "cycle",
        initial_value=args.initial_epsilon,
        final_value=args.final_epsilon,
        cycle_len=epsilon_cycle_len,
        num_cycles=args.max_episodes // epsilon_cycle_len)

    episode = 1
    step = 0
    start_time = time.time()
    while global_episode.value < args.max_episodes * (args.num_threads - args.num_train_threads) \
            and global_update_step.value < args.max_update_steps * args.num_train_threads:
        if episode % 100 == 0:
            # recreate the environment every 100 episodes
            env = create_env(args)
        seed = random.randrange(2 ** 32 - 2)
        epsilon = min(args.initial_epsilon, max(args.final_epsilon, epsilon_decay_fn(episode)))

        episode_metrics = {
            "reward": 0.0,
            "step": 0,
            "epsilon": epsilon
        }

        observation = env.reset(seed=seed, difficulty=args.difficulty)
        random_process.reset_states()
        done = False
        replay = []
        while not done:
            action = act_fn(observation, noise=epsilon * random_process.sample())
            next_observation, reward, done, _ = env.step(action)
            replay.append((observation, action, reward, next_observation, done))
            episode_metrics["reward"] += reward
            episode_metrics["step"] += 1
            observation = next_observation
        episodes_queue.put(replay)

        episode += 1
        global_episode.value += 1

        if episode_metrics["reward"] > best_reward.value:
            best_reward.value = episode_metrics["reward"]
            logger.scalar_summary("best reward", best_reward.value, episode)
            if episode_metrics["reward"] > 15.0 * args.reward_scale:
                save_fn(episode)

        step += episode_metrics["step"]
        elapsed_time = time.time() - start_time
        for key, value in episode_metrics.items():
            logger.scalar_summary(key, value, episode)
        logger.scalar_summary(
            "episode per minute",
            episode / elapsed_time * 60,
            episode)
        logger.scalar_summary(
            "step per second",
            step / elapsed_time,
            episode)

        if elapsed_time > 86400 * args.max_train_days:
            # time budget exhausted: signal all workers to stop and bail out
            global_episode.value = args.max_episodes * (args.num_threads - args.num_train_threads) + 1
            raise KeyboardInterrupt
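
# For reference: a minimal sketch of a cyclic schedule consistent with how
# create_decay_fn("cycle", ...) is called above -- a sawtooth that ramps from
# initial_value to final_value over each cycle of cycle_len steps. This is an
# assumption; the real helper (and its use of num_cycles) may differ, which is
# why callers above clip the result into [final_epsilon, initial_epsilon].
def _cycle_decay_sketch(initial_value, final_value, cycle_len, num_cycles):
    def decay_fn(step):
        position = (step % cycle_len) / max(cycle_len - 1, 1)
        return initial_value + (final_value - initial_value) * position
    return decay_fn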

def train_multi_thread(actor, critic, target_actor, target_critic, args, prepare_fn, best_reward):
    # combined worker: plays episodes into a local replay buffer and trains
    # the shared networks from it
    workerseed = args.seed + 241 * args.thread
    set_global_seeds(workerseed)
    args.logdir = "{}/thread_{}".format(args.logdir, args.thread)
    create_if_need(args.logdir)
    act_fn, update_fn, save_fn = prepare_fn(actor, critic, target_actor, target_critic, args)
    logger = Logger(args.logdir)

    buffer = create_buffer(args)
    if args.prioritized_replay:
        beta_decay_fn = create_decay_fn(
            "linear",
            initial_value=args.prioritized_replay_beta0,
            final_value=1.0,
            max_step=args.max_episodes)

    env = create_env(args)
    random_process = create_random_process(args)

    actor_learning_rate_decay_fn = create_decay_fn(
        "linear",
        initial_value=args.actor_lr,
        final_value=args.actor_lr_end,
        max_step=args.max_episodes)
    critic_learning_rate_decay_fn = create_decay_fn(
        "linear",
        initial_value=args.critic_lr,
        final_value=args.critic_lr_end,
        max_step=args.max_episodes)
    epsilon_cycle_len = random.randint(args.epsilon_cycle_len // 2, args.epsilon_cycle_len * 2)
    epsilon_decay_fn = create_decay_fn(
        "cycle",
        initial_value=args.initial_epsilon,
        final_value=args.final_epsilon,
        cycle_len=epsilon_cycle_len,
        num_cycles=args.max_episodes // epsilon_cycle_len)

    episode = 0
    step = 0
    start_time = time.time()
    while episode < args.max_episodes:
        if episode % 100 == 0:
            # recreate the environment every 100 episodes
            env = create_env(args)
        seed = random.randrange(2 ** 32 - 2)
        actor_lr = actor_learning_rate_decay_fn(episode)
        critic_lr = critic_learning_rate_decay_fn(episode)
        epsilon = min(args.initial_epsilon, max(args.final_epsilon, epsilon_decay_fn(episode)))

        episode_metrics = {
            "value_loss": 0.0,
            "policy_loss": 0.0,
            "reward": 0.0,
            "step": 0,
            "epsilon": epsilon
        }

        observation = env.reset(seed=seed, difficulty=args.difficulty)
        random_process.reset_states()
        done = False
        while not done:
            action = act_fn(observation, noise=epsilon * random_process.sample())
            next_observation, reward, done, _ = env.step(action)
            buffer.add(observation, action, reward, next_observation, done)
            episode_metrics["reward"] += reward
            episode_metrics["step"] += 1

            if len(buffer) >= args.train_steps:
                if args.prioritized_replay:
                    (tr_observations, tr_actions, tr_rewards,
                     tr_next_observations, tr_dones,
                     weights, batch_idxes) = \
                        buffer.sample(batch_size=args.batch_size, beta=beta_decay_fn(episode))
                else:
                    (tr_observations, tr_actions, tr_rewards,
                     tr_next_observations, tr_dones) = \
                        buffer.sample(batch_size=args.batch_size)
                    weights, batch_idxes = np.ones_like(tr_rewards), None
                step_metrics, step_info = update_fn(
                    tr_observations, tr_actions, tr_rewards,
                    tr_next_observations, tr_dones, weights,
                    actor_lr, critic_lr)
                if args.prioritized_replay:
                    # resample proportionally to the magnitude of the TD error
                    new_priorities = np.abs(step_info["td_error"]) + 1e-6
                    buffer.update_priorities(batch_idxes, new_priorities)
                for key, value in step_metrics.items():
                    value = to_numpy(value)[0]
                    episode_metrics[key] += value

            observation = next_observation

        episode += 1

        if episode_metrics["reward"] > 15.0 * args.reward_scale \
                and episode_metrics["reward"] > best_reward.value:
            best_reward.value = episode_metrics["reward"]
            logger.scalar_summary("best reward", best_reward.value, episode)
            save_fn(episode)

        step += episode_metrics["step"]
        elapsed_time = time.time() - start_time
        for key, value in episode_metrics.items():
            # losses are accumulated per step, so report their per-step mean
            value = value if "loss" not in key else value / episode_metrics["step"]
            logger.scalar_summary(key, value, episode)
        logger.scalar_summary(
            "episode per minute",
            episode / elapsed_time * 60,
            episode)
        logger.scalar_summary(
            "step per second",
            step / elapsed_time,
            episode)
        logger.scalar_summary("actor lr", actor_lr, episode)
        logger.scalar_summary("critic lr", critic_lr, episode)

        if episode % args.save_step == 0:
            save_fn(episode)

        if elapsed_time > 86400 * args.max_train_days:
            episode = args.max_episodes + 1
            save_fn(episode)
            raise KeyboardInterrupt

def train(args, model_fn, act_update_fns, multi_thread, train_single, play_single):
    create_if_need(args.logdir)
    if args.restore_args_from is not None:
        args = restore_args(args)
    with open("{}/args.json".format(args.logdir), "w") as fout:
        json.dump(vars(args), fout, indent=4, ensure_ascii=False, sort_keys=True)

    env = create_env(args)
    if args.flip_state_action and hasattr(env, "state_transform"):
        args.flip_states = env.state_transform.flip_states
        # flipped transitions double the effective batch, so halve the sampled one
        args.batch_size = args.batch_size // 2
    args.n_action = env.action_space.shape[0]
    args.n_observation = env.observation_space.shape[0]
    args.actor_layers = str2params(args.actor_layers)
    args.critic_layers = str2params(args.critic_layers)
    args.actor_activation = activations[args.actor_activation]
    args.critic_activation = activations[args.critic_activation]

    actor, critic = model_fn(args)
    if args.restore_actor_from is not None:
        actor.load_state_dict(torch.load(args.restore_actor_from))
    if args.restore_critic_from is not None:
        critic.load_state_dict(torch.load(args.restore_critic_from))
    actor.train()
    critic.train()
    actor.share_memory()
    critic.share_memory()

    target_actor = copy.deepcopy(actor)
    target_critic = copy.deepcopy(critic)
    hard_update(target_actor, actor)
    hard_update(target_critic, critic)
    target_actor.train()
    target_critic.train()
    target_actor.share_memory()
    target_critic.share_memory()

    _, _, save_fn = act_update_fns(actor, critic, target_actor, target_critic, args)

    processes = []
    best_reward = Value("f", 0.0)
    try:
        if args.num_threads == args.num_train_threads:
            # every worker both plays and trains
            for rank in range(args.num_threads):
                args.thread = rank
                p = mp.Process(
                    target=multi_thread,
                    args=(actor, critic, target_actor, target_critic,
                          args, act_update_fns, best_reward))
                p.start()
                processes.append(p)
        else:
            # split workers: the first num_train_threads train from a shared
            # queue, the rest only play episodes and feed the queue
            global_episode = Value("i", 0)
            global_update_step = Value("i", 0)
            episodes_queue = mp.Queue()
            for rank in range(args.num_threads):
                args.thread = rank
                if rank < args.num_train_threads:
                    p = mp.Process(
                        target=train_single,
                        args=(actor, critic, target_actor, target_critic,
                              args, act_update_fns,
                              global_episode, global_update_step, episodes_queue))
                else:
                    p = mp.Process(
                        target=play_single,
                        args=(actor, critic, target_actor, target_critic,
                              args, act_update_fns,
                              global_episode, global_update_step, episodes_queue,
                              best_reward))
                p.start()
                processes.append(p)
        for p in processes:
            p.join()
    except KeyboardInterrupt:
        pass
    save_fn()

def train(args):
    import baselines.baselines_common.tf_util as U
    sess = U.single_threaded_session()
    sess.__enter__()

    if args.restore_args_from is not None:
        args = restore_params(args)

    rank = MPI.COMM_WORLD.Get_rank()
    workerseed = args.seed + 241 * rank
    set_global_seeds(workerseed)

    def policy_fn(name, ob_space, ac_space):
        return Actor(
            name=name, ob_space=ob_space, ac_space=ac_space,
            hid_size=args.hid_size, num_hid_layers=args.num_hid_layers,
            noise_type=args.noise_type)

    env = create_env(args)
    env.seed(workerseed)

    if rank == 0:
        create_if_need(args.logdir)
        with open("{}/args.json".format(args.logdir), "w") as fout:
            json.dump(vars(args), fout, indent=4, ensure_ascii=False, sort_keys=True)

    try:
        args.thread = rank
        if args.agent == "trpo":
            trpo.learn(
                env, policy_fn, args,
                timesteps_per_batch=1024,
                gamma=args.gamma, lam=0.98,
                max_kl=0.01, cg_iters=10, cg_damping=0.1,
                vf_iters=5, vf_stepsize=1e-3)
        elif args.agent == "ppo":
            # optimal settings:
            # timesteps_per_batch = optim_epochs * optim_batchsize
            ppo.learn(
                env, policy_fn, args,
                timesteps_per_batch=256,
                gamma=args.gamma, lam=0.95,
                clip_param=0.2, entcoeff=0.0,
                optim_epochs=4, optim_stepsize=3e-4, optim_batchsize=64,
                schedule='constant')
        else:
            raise NotImplementedError
    except KeyboardInterrupt:
        print("closing envs...")
        env.close()
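
# For reference: set_global_seeds (used throughout) mirrors the baselines
# helper of the same name; a minimal sketch, assuming it seeds TF1-style
# TensorFlow, NumPy and Python's random module:
def _set_global_seeds_sketch(seed):
    import random
    import numpy as np
    try:
        import tensorflow as tf
        tf.set_random_seed(seed)
    except ImportError:
        pass
    np.random.seed(seed)
    random.seed(seed)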