Example No. 1
def debug(args, model_fn, act_update_fns, multi_thread):
    create_if_need(args.logdir)
    env = create_env(args)

    if args.flip_state_action and hasattr(env, "state_transform"):
        args.flip_states = env.state_transform.flip_states

    args.n_action = env.action_space.shape[0]
    args.n_observation = env.observation_space.shape[0]

    args.actor_layers = str2params(args.actor_layers)
    args.critic_layers = str2params(args.critic_layers)

    args.actor_activation = activations[args.actor_activation]
    args.critic_activation = activations[args.critic_activation]

    actor, critic = model_fn(args)

    if args.restore_actor_from is not None:
        actor.load_state_dict(torch.load(args.restore_actor_from))
    if args.restore_critic_from is not None:
        critic.load_state_dict(torch.load(args.restore_critic_from))

    actor.train()
    critic.train()
    actor.share_memory()
    critic.share_memory()

    target_actor = copy.deepcopy(actor)
    target_critic = copy.deepcopy(critic)

    hard_update(target_actor, actor)
    hard_update(target_critic, critic)

    target_actor.train()
    target_critic.train()
    target_actor.share_memory()
    target_critic.share_memory()

    _, _, save_fn = act_update_fns(actor, critic, target_actor, target_critic,
                                   args)

    args.thread = 0
    best_reward = Value("f", 0.0)
    multi_thread(actor, critic, target_actor, target_critic, args,
                 act_update_fns, best_reward)

    save_fn()
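The helpers str2params and activations used when building the actor and critic are not part of this listing. A minimal sketch of what they might look like, assuming str2params parses a delimiter-separated size string and activations maps CLI names onto torch activation classes (the bodies below are assumptions, only the names come from the example):

import torch.nn as nn

# Assumed sketch: parse a CLI string such as "64-64-32" into a list of layer sizes.
def str2params(params_str, delimiter="-"):
    return [int(x) for x in params_str.split(delimiter)]

# Assumed sketch: map CLI activation names onto torch.nn activation classes.
activations = {
    "relu": nn.ReLU,
    "elu": nn.ELU,
    "tanh": nn.Tanh,
}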
Example No. 2
def submit_or_test(args, model_fn, act_update_fn, submit_fn, test_fn):
    args = restore_args(args)
    env = create_env(args)

    args.n_action = env.action_space.shape[0]
    args.n_observation = env.observation_space.shape[0]

    args.actor_layers = str2params(args.actor_layers)
    args.critic_layers = str2params(args.critic_layers)

    args.actor_activation = activations[args.actor_activation]
    args.critic_activation = activations[args.critic_activation]

    actor, critic = model_fn(args)
    actor.load_state_dict(torch.load(args.restore_actor_from))
    critic.load_state_dict(torch.load(args.restore_critic_from))

    if args.submit:
        submit_fn(actor, critic, args, act_update_fn)
    else:
        test_fn(actor, critic, args, act_update_fn)
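restore_args is only called at the top of this function, so its behavior has to be inferred. A hedged sketch, assuming it reloads the args.json written during training and merges it into the current namespace (the merge policy and the meaning of restore_args_from are assumptions):

import json
from argparse import Namespace

def restore_args(args):
    # Assumption: args.restore_args_from points at the args.json dumped at training time.
    with open(args.restore_args_from) as fin:
        saved = json.load(fin)
    merged = dict(saved)
    merged.update(vars(args))  # in this sketch, current CLI values win over saved ones
    return Namespace(**merged)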
Example No. 3
def play_single_thread(
        actor, critic, target_actor, target_critic, args, prepare_fn,
        global_episode, global_update_step, episodes_queue,
        best_reward):
    workerseed = args.seed + 241 * args.thread
    set_global_seeds(workerseed)

    args.logdir = "{}/thread_{}".format(args.logdir, args.thread)
    create_if_need(args.logdir)

    act_fn, _, save_fn = prepare_fn(actor, critic, target_actor, target_critic, args)

    logger = Logger(args.logdir)
    env = create_env(args)
    random_process = create_random_process(args)

    epsilon_cycle_len = random.randint(args.epsilon_cycle_len // 2, args.epsilon_cycle_len * 2)

    epsilon_decay_fn = create_decay_fn(
        "cycle",
        initial_value=args.initial_epsilon,
        final_value=args.final_epsilon,
        cycle_len=epsilon_cycle_len,
        num_cycles=args.max_episodes // epsilon_cycle_len)

    episode = 1
    step = 0
    start_time = time.time()
    while global_episode.value < args.max_episodes * (args.num_threads - args.num_train_threads) \
            and global_update_step.value < args.max_update_steps * args.num_train_threads:
        if episode % 100 == 0:
            env = create_env(args)
        seed = random.randrange(2 ** 32 - 2)

        epsilon = min(args.initial_epsilon, max(args.final_epsilon, epsilon_decay_fn(episode)))

        episode_metrics = {
            "reward": 0.0,
            "step": 0,
            "epsilon": epsilon
        }

        observation = env.reset(seed=seed, difficulty=args.difficulty)
        random_process.reset_states()
        done = False

        replay = []
        while not done:
            action = act_fn(observation, noise=epsilon * random_process.sample())
            next_observation, reward, done, _ = env.step(action)

            replay.append((observation, action, reward, next_observation, done))
            episode_metrics["reward"] += reward
            episode_metrics["step"] += 1

            observation = next_observation

        episodes_queue.put(replay)

        episode += 1
        global_episode.value += 1

        if episode_metrics["reward"] > best_reward.value:
            best_reward.value = episode_metrics["reward"]
            logger.scalar_summary("best reward", best_reward.value, episode)

            if episode_metrics["reward"] > 15.0 * args.reward_scale:
                save_fn(episode)

        step += episode_metrics["step"]
        elapsed_time = time.time() - start_time

        for key, value in episode_metrics.items():
            logger.scalar_summary(key, value, episode)
        logger.scalar_summary(
            "episode per minute",
            episode / elapsed_time * 60,
            episode)
        logger.scalar_summary(
            "step per second",
            step / elapsed_time,
            episode)

        if elapsed_time > 86400 * args.max_train_days:
            global_episode.value = args.max_episodes * (args.num_threads - args.num_train_threads) + 1

    raise KeyboardInterrupt
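The exploration schedule above is driven by create_decay_fn("cycle", ...). A possible implementation, assuming a half-cosine that falls from initial_value to final_value within each cycle and restarts at the next one; the actual waveform used by the project may differ:

import math

def create_cycle_decay_fn(initial_value, final_value, cycle_len, num_cycles):
    # Assumed sketch of the "cycle" decay mode: a half-cosine per cycle,
    # clamped to final_value once all cycles have run.
    def decay_fn(step):
        if step >= cycle_len * num_cycles:
            return final_value
        phase = (step % cycle_len) / float(cycle_len)
        return final_value + 0.5 * (initial_value - final_value) * (1.0 + math.cos(math.pi * phase))
    return decay_fn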
Example No. 4
def train_multi_thread(actor, critic, target_actor, target_critic, args, prepare_fn, best_reward):
    workerseed = args.seed + 241 * args.thread
    set_global_seeds(workerseed)

    args.logdir = "{}/thread_{}".format(args.logdir, args.thread)
    create_if_need(args.logdir)

    act_fn, update_fn, save_fn = prepare_fn(actor, critic, target_actor, target_critic, args)
    logger = Logger(args.logdir)

    buffer = create_buffer(args)
    if args.prioritized_replay:
        beta_decay_fn = create_decay_fn(
            "linear",
            initial_value=args.prioritized_replay_beta0,
            final_value=1.0,
            max_step=args.max_episodes)

    env = create_env(args)
    random_process = create_random_process(args)

    actor_learning_rate_decay_fn = create_decay_fn(
        "linear",
        initial_value=args.actor_lr,
        final_value=args.actor_lr_end,
        max_step=args.max_episodes)
    critic_learning_rate_decay_fn = create_decay_fn(
        "linear",
        initial_value=args.critic_lr,
        final_value=args.critic_lr_end,
        max_step=args.max_episodes)

    epsilon_cycle_len = random.randint(args.epsilon_cycle_len // 2, args.epsilon_cycle_len * 2)

    epsilon_decay_fn = create_decay_fn(
        "cycle",
        initial_value=args.initial_epsilon,
        final_value=args.final_epsilon,
        cycle_len=epsilon_cycle_len,
        num_cycles=args.max_episodes // epsilon_cycle_len)

    episode = 0
    step = 0
    start_time = time.time()
    while episode < args.max_episodes:
        if episode % 100 == 0:
            env = create_env(args)
        seed = random.randrange(2 ** 32 - 2)

        actor_lr = actor_learning_rate_decay_fn(episode)
        critic_lr = critic_learning_rate_decay_fn(episode)
        epsilon = min(args.initial_epsilon, max(args.final_epsilon, epsilon_decay_fn(episode)))

        episode_metrics = {
            "value_loss": 0.0,
            "policy_loss": 0.0,
            "reward": 0.0,
            "step": 0,
            "epsilon": epsilon
        }

        observation = env.reset(seed=seed, difficulty=args.difficulty)
        random_process.reset_states()
        done = False

        while not done:
            action = act_fn(observation, noise=epsilon * random_process.sample())
            next_observation, reward, done, _ = env.step(action)

            buffer.add(observation, action, reward, next_observation, done)
            episode_metrics["reward"] += reward
            episode_metrics["step"] += 1

            if len(buffer) >= args.train_steps:

                if args.prioritized_replay:
                    (tr_observations, tr_actions, tr_rewards, tr_next_observations, tr_dones,
                     weights, batch_idxes) = \
                        buffer.sample(batch_size=args.batch_size, beta=beta_decay_fn(episode))
                else:
                    (tr_observations, tr_actions, tr_rewards, tr_next_observations, tr_dones) = \
                        buffer.sample(batch_size=args.batch_size)
                    weights, batch_idxes = np.ones_like(tr_rewards), None

                step_metrics, step_info = update_fn(
                    tr_observations, tr_actions, tr_rewards,
                    tr_next_observations, tr_dones,
                    weights, actor_lr, critic_lr)

                if args.prioritized_replay:
                    new_priorities = np.abs(step_info["td_error"]) + 1e-6
                    buffer.update_priorities(batch_idxes, new_priorities)

                for key, value in step_metrics.items():
                    value = to_numpy(value)[0]
                    episode_metrics[key] += value

            observation = next_observation

        episode += 1

        if episode_metrics["reward"] > 15.0 * args.reward_scale \
                and episode_metrics["reward"] > best_reward.value:
            best_reward.value = episode_metrics["reward"]
            logger.scalar_summary("best reward", best_reward.value, episode)
            save_fn(episode)

        step += episode_metrics["step"]
        elapsed_time = time.time() - start_time

        for key, value in episode_metrics.items():
            value = value if "loss" not in key else value / episode_metrics["step"]
            logger.scalar_summary(key, value, episode)
        logger.scalar_summary(
            "episode per minute",
            episode / elapsed_time * 60,
            episode)
        logger.scalar_summary(
            "step per second",
            step / elapsed_time,
            episode)
        logger.scalar_summary("actor lr", actor_lr, episode)
        logger.scalar_summary("critic lr", critic_lr, episode)

        if episode % args.save_step == 0:
            save_fn(episode)

        if elapsed_time > 86400 * args.max_train_days:
            episode = args.max_episodes + 1

    save_fn(episode)

    raise KeyboardInterrupt
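The learning-rate and prioritized-replay beta schedules above both rely on create_decay_fn("linear", ...). A minimal sketch of that mode, assuming a plain linear interpolation clamped at max_step:

def create_linear_decay_fn(initial_value, final_value, max_step):
    # Assumed sketch of the "linear" decay mode: interpolate from initial_value
    # to final_value over max_step steps, then hold final_value.
    def decay_fn(step):
        t = min(max(step, 0), max_step) / float(max_step)
        return initial_value + (final_value - initial_value) * t
    return decay_fn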
Example No. 5
def train(args, model_fn, act_update_fns, multi_thread, train_single,
          play_single):
    create_if_need(args.logdir)

    if args.restore_args_from is not None:
        args = restore_args(args)

    with open("{}/args.json".format(args.logdir), "w") as fout:
        json.dump(vars(args),
                  fout,
                  indent=4,
                  ensure_ascii=False,
                  sort_keys=True)

    env = create_env(args)

    if args.flip_state_action and hasattr(env, "state_transform"):
        args.flip_states = env.state_transform.flip_states
        args.batch_size = args.batch_size // 2

    args.n_action = env.action_space.shape[0]
    args.n_observation = env.observation_space.shape[0]

    args.actor_layers = str2params(args.actor_layers)
    args.critic_layers = str2params(args.critic_layers)

    args.actor_activation = activations[args.actor_activation]
    args.critic_activation = activations[args.critic_activation]

    actor, critic = model_fn(args)

    if args.restore_actor_from is not None:
        actor.load_state_dict(torch.load(args.restore_actor_from))
    if args.restore_critic_from is not None:
        critic.load_state_dict(torch.load(args.restore_critic_from))

    actor.train()
    critic.train()
    actor.share_memory()
    critic.share_memory()

    target_actor = copy.deepcopy(actor)
    target_critic = copy.deepcopy(critic)

    hard_update(target_actor, actor)
    hard_update(target_critic, critic)

    target_actor.train()
    target_critic.train()
    target_actor.share_memory()
    target_critic.share_memory()

    _, _, save_fn = act_update_fns(actor, critic, target_actor, target_critic,
                                   args)

    processes = []
    best_reward = Value("f", 0.0)
    try:
        if args.num_threads == args.num_train_threads:
            for rank in range(args.num_threads):
                args.thread = rank
                p = mp.Process(target=multi_thread,
                               args=(actor, critic, target_actor,
                                     target_critic, args, act_update_fns,
                                     best_reward))
                p.start()
                processes.append(p)
        else:
            global_episode = Value("i", 0)
            global_update_step = Value("i", 0)
            episodes_queue = mp.Queue()
            for rank in range(args.num_threads):
                args.thread = rank
                if rank < args.num_train_threads:
                    p = mp.Process(target=train_single,
                                   args=(actor, critic, target_actor,
                                         target_critic, args, act_update_fns,
                                         global_episode, global_update_step,
                                         episodes_queue))
                else:
                    p = mp.Process(target=play_single,
                                   args=(actor, critic, target_actor,
                                         target_critic, args, act_update_fns,
                                         global_episode, global_update_step,
                                         episodes_queue, best_reward))
                p.start()
                processes.append(p)

        for p in processes:
            p.join()
    except KeyboardInterrupt:
        pass

    save_fn()
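hard_update is used above to synchronize the freshly copied target networks with the online ones. A sketch using the standard DDPG-style parameter copy; the companion soft_update is shown for context and is an assumption about how the target networks are tracked during training:

def hard_update(target, source):
    # Copy every source parameter into the target network.
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(param.data)

def soft_update(target, source, tau=1e-3):
    # Polyak averaging, typically applied after each training step.
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)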
Example No. 6
def train(args):
    import baselines.baselines_common.tf_util as U

    sess = U.single_threaded_session()
    sess.__enter__()

    if args.restore_args_from is not None:
        args = restore_params(args)

    rank = MPI.COMM_WORLD.Get_rank()

    workerseed = args.seed + 241 * rank
    set_global_seeds(workerseed)

    def policy_fn(name, ob_space, ac_space):
        return Actor(name=name,
                     ob_space=ob_space,
                     ac_space=ac_space,
                     hid_size=args.hid_size,
                     num_hid_layers=args.num_hid_layers,
                     noise_type=args.noise_type)

    env = create_env(args)
    env.seed(workerseed)

    if rank == 0:
        create_if_need(args.logdir)
        with open("{}/args.json".format(args.logdir), "w") as fout:
            json.dump(vars(args),
                      fout,
                      indent=4,
                      ensure_ascii=False,
                      sort_keys=True)

    try:
        args.thread = rank
        if args.agent == "trpo":
            trpo.learn(env,
                       policy_fn,
                       args,
                       timesteps_per_batch=1024,
                       gamma=args.gamma,
                       lam=0.98,
                       max_kl=0.01,
                       cg_iters=10,
                       cg_damping=0.1,
                       vf_iters=5,
                       vf_stepsize=1e-3)
        elif args.agent == "ppo":
            # optimal settings:
            # timesteps_per_batch = optim_epochs *  optim_batchsize
            ppo.learn(env,
                      policy_fn,
                      args,
                      timesteps_per_batch=256,
                      gamma=args.gamma,
                      lam=0.95,
                      clip_param=0.2,
                      entcoeff=0.0,
                      optim_epochs=4,
                      optim_stepsize=3e-4,
                      optim_batchsize=64,
                      schedule='constant')
        else:
            raise NotImplementedError
    except KeyboardInterrupt:
        print("closing envs...")

    env.close()
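set_global_seeds appears in several of the worker functions above. A plausible sketch, assuming it seeds the Python, NumPy and (when available) torch RNGs; the OpenAI baselines version additionally seeds TensorFlow:

import random
import numpy as np

def set_global_seeds(seed):
    # Assumed sketch: seed every RNG a worker is likely to touch.
    random.seed(seed)
    np.random.seed(seed)
    try:
        import torch
        torch.manual_seed(seed)
    except ImportError:
        pass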