Example #1
    def run(self):
        if self.duration is not None:
            self.end_time = time.time() + self.duration
        self.total_rew_avg = 0.0
        self.n_episodes = 0
        while self.duration is None or time.time() < self.end_time:
            if len(self.policies) == 1:
                action, _ = self.policies[0].act(self.ob)
            else:
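                # Split the multi-agent observation across the policies, query each
                # policy separately, then merge the per-policy actions into one batch.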
                self.ob = splitobs(self.ob, keepdims=False)
                ob_policy_idx = np.split(np.arange(len(self.ob)), len(self.policies))
                actions = []
                for i, policy in enumerate(self.policies):
                    inp = itemgetter(*ob_policy_idx[i])(self.ob)
                    inp = listdict2dictnp([inp] if ob_policy_idx[i].shape[0] == 1 else inp)
                    ac, info = policy.act(inp)
                    actions.append(ac)
                action = listdict2dictnp(actions, keepdims=True)

            self.ob, rew, done, env_info = self.env.step(action)
            self.total_rew += rew

            if done or env_info.get('discard_episode', False):
                self.reset_increment()

            if self.display_window:
                self.add_overlay(const.GRID_TOPRIGHT, "Reset env; (current seed: {})".format(self.seed), "N - next / P - previous ")
                self.add_overlay(const.GRID_TOPRIGHT, "Reward", str(self.total_rew))
                if hasattr(self.env.unwrapped, "viewer_stats"):
                    for k, v in self.env.unwrapped.viewer_stats.items():
                        self.add_overlay(const.GRID_TOPRIGHT, k, str(v))

                self.env.render()
Example #2
    def process_state_batch(self, states):
        '''
            Batch states together.
            Args:
                states -- list (batch) of dicts of states with shape (n_agent, dim_state).
        '''
        new_states = listdict2dictnp(states, keepdims=True)
        return new_states
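The helper listdict2dictnp is not defined in these snippets. As a rough illustration only, the sketch below shows the behaviour process_state_batch relies on; the name listdict2dictnp_sketch and the toy shapes are invented for this example:

import numpy as np

def listdict2dictnp_sketch(list_of_dicts, keepdims=False):
    # Merge a list of dicts sharing the same keys into one dict of numpy arrays.
    # keepdims=True stacks along a new leading batch axis; otherwise values are
    # concatenated along their existing first (agent) axis.
    out = {}
    for k in list_of_dicts[0]:
        vals = [d[k] for d in list_of_dicts]
        out[k] = np.stack(vals) if keepdims else np.concatenate(vals)
    return out

# Batch two per-step state dicts of shape (n_agent, dim_state)
states = [{'observation_self': np.zeros((2, 4))},
          {'observation_self': np.ones((2, 4))}]
batched = listdict2dictnp_sketch(states, keepdims=True)
print(batched['observation_self'].shape)  # (2, 2, 4) -> (batch, n_agent, dim_state)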
Example #3
    def process_observation_batch(self, obs):
        '''
            Batch obs together.
            Args:
                obs -- list of lists (batch, time), where elements are dictionary observations
        '''

        new_obs = deepcopy(obs)
        # List transpose -- now in (time, batch)
        new_obs = list(map(list, zip(*new_obs)))
        # Convert list of list of dicts to dict of numpy arrays
        new_obs = listdict2dictnp(
            [listdict2dictnp(batch, keepdims=True) for batch in new_obs])
        # Flatten out the agent dimension, so batches look like normal SA batches
        new_obs = {
            k: self.reshape_ma_observations(v)
            for k, v in new_obs.items()
        }

        return new_obs
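The list transpose in process_observation_batch is easy to misread. The toy example below (invented data) shows how zip(*obs) flips a (batch, time) nested list into (time, batch) before each time step is batched:

# Toy (batch=2, time=3) nested list of observations
obs = [['b0_t0', 'b0_t1', 'b0_t2'],
       ['b1_t0', 'b1_t1', 'b1_t2']]
transposed = list(map(list, zip(*obs)))  # now (time=3, batch=2)
print(transposed[0])                     # ['b0_t0', 'b1_t0']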
def ppo(env_fn,
        actor_critic=core.mlp_actor_critic,
        ac_kwargs=dict(),
        seed=33,
        steps_per_epoch=4000,
        epochs=50,
        gamma=0.998,
        clip_ratio=0.2,
        pi_lr=3e-4,
        vf_lr=1e-3,
        train_pi_iters=60,
        train_v_iters=60,
        lam=0.95,
        max_ep_len=1000,
        target_kl=0.01,
        logger_kwargs=dict(),
        save_freq=10):
    """
    Proximal Policy Optimization (by clipping),
    with early stopping based on approximate KL
    Args:
        env_fn : A function which creates a copy of the environment.
            The environment must satisfy the OpenAI Gym API.
        actor_critic: A function which takes in placeholder symbols
            for state, ``x_ph``, and action, ``a_ph``, and returns the main
            outputs from the agent's Tensorflow computation graph:
            ===========  ================  ======================================
            Symbol       Shape             Description
            ===========  ================  ======================================
            ``pi``       (batch, act_dim)  | Samples actions from policy given
                                           | states.
            ``logp``     (batch,)          | Gives log probability, according to
                                           | the policy, of taking actions ``a_ph``
                                           | in states ``x_ph``.
            ``logp_pi``  (batch,)          | Gives log probability, according to
                                           | the policy, of the action sampled by
                                           | ``pi``.
            ``v``        (batch,)          | Gives the value estimate for states
                                           | in ``x_ph``. (Critical: make sure
                                           | to flatten this!)
            ===========  ================  ======================================
        ac_kwargs (dict): Any kwargs appropriate for the actor_critic
            function you provided to PPO.
        seed (int): Seed for random number generators.
        steps_per_epoch (int): Number of steps of interaction (state-action pairs)
            for the agent and the environment in each epoch.
        epochs (int): Number of epochs of interaction (equivalent to
            number of policy updates) to perform.
        gamma (float): Discount factor. (Always between 0 and 1.)
        clip_ratio (float): Hyperparameter for clipping in the policy objective.
            Roughly: how far can the new policy go from the old policy while
            still profiting (improving the objective function)? The new policy
            can still go farther than the clip_ratio says, but it doesn't help
            on the objective anymore. (Usually small, 0.1 to 0.3.) Typically
            denoted by :math:`\epsilon`.
        pi_lr (float): Learning rate for policy optimizer.
        vf_lr (float): Learning rate for value function optimizer.
        train_pi_iters (int): Maximum number of gradient descent steps to take
            on policy loss per epoch. (Early stopping may cause optimizer
            to take fewer than this.)
        train_v_iters (int): Number of gradient descent steps to take on
            value function per epoch.
        lam (float): Lambda for GAE-Lambda. (Always between 0 and 1,
            close to 1.)
        max_ep_len (int): Maximum length of trajectory / episode / rollout.
        target_kl (float): Roughly what KL divergence we think is appropriate
            between new and old policies after an update. This will get used
            for early stopping. (Usually small, 0.01 or 0.05.)
        logger_kwargs (dict): Keyword args for EpochLogger.
        save_freq (int): How often (in terms of gap between epochs) to save
            the current policy and value function.
    """

    ## Logger setup
    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    ## Random seed setting
    seed += 10000 * proc_id()
    tf.set_random_seed(seed)
    np.random.seed(seed)

    ## Environment instantiation
    env = env_fn()
    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape

    # Policies vector (only one for this project)
    policies = []

    # TensorFlow session
    sess = tf.Session()

    # Build policy and value networks
    MAP = MAPolicy(scope='policy_0',
                   ob_space=env.observation_space,
                   ac_space=env.action_space,
                   network_spec=pi_specs,
                   normalize=True,
                   v_network_spec=v_specs)
    policies = [MAP]

    # Share information about action space with policy architecture
    ac_kwargs['action_space'] = env.action_space

    # Create aux placeholders for the computation graph
    adv_ph, ret_ph, logp_old_ph = core.placeholders(1, 1, 1)
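    # adv_ph: GAE advantage estimates; ret_ph: rewards-to-go targets for the value loss;
    # logp_old_ph: log-probabilities of the taken actions under the pre-update policy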

    # Get main placeholders for the computation graph
    map_phs_dict = MAP.phs
    map_phs = [v for k, v in map_phs_dict.items()]

    for k, v in map_phs_dict.items():
        if v.name is None:
            v.name = k

    # Append aux and main placeholders
    # Need placeholders in *this* order later (to zip with data from buffer)
    new_phs = [adv_ph, ret_ph, logp_old_ph]
    all_phs = np.append(map_phs, new_phs)

    # Instantiate experience buffer
    local_steps_per_epoch = int(steps_per_epoch / num_procs())
    buf = PPOBuffer(obs_dim, act_dim, local_steps_per_epoch, gamma, lam)

    # Count variables
    var_counts = tuple(
        core.count_vars(scope) for scope in ['policy_net', 'vpred_net'])
    logger.log('\nNumber of parameters: \t pi: %d, \t v: %d\n' % var_counts)

    # PPO objectives
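    # Clipped surrogate objective (Schulman et al., 2017):
    #   L_clip = E_t[ min(ratio_t * A_t, clip(ratio_t, 1 - eps, 1 + eps) * A_t) ]
    # where ratio_t = pi(a_t|s_t) / pi_old(a_t|s_t), A_t is the advantage and eps = clip_ratio;
    # min_adv below is the clipped bound, selected by the sign of the advantage.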
    ratio = tf.exp(MAP.taken_action_logp -
                   logp_old_ph)  # pi(a|s) / pi_old(a|s)
    min_adv = tf.where(adv_ph > 0, (1 + clip_ratio) * adv_ph,
                       (1 - clip_ratio) * adv_ph)  # PPO-clip limits
    pi_loss = -tf.reduce_mean(tf.minimum(ratio * adv_ph,
                                         min_adv))  # Policy loss function
    v_loss = tf.reduce_mean(
        (ret_ph - MAP.scaled_value_tensor)**2)  # Value loss function

    # Info (useful to watch during learning)
    approx_kl = tf.reduce_mean(
        logp_old_ph - MAP.taken_action_logp
    )  # a sample estimate for KL-divergence, easy to compute
    approx_ent = tf.reduce_mean(
        -MAP.taken_action_logp
    )  # a sample estimate for entropy, also easy to compute
    clipped = tf.logical_or(
        ratio > (1 + clip_ratio), ratio < (1 - clip_ratio)
    )  # a logical value which states whether there was clipping
    clipfrac = tf.reduce_mean(tf.cast(
        clipped, tf.float32))  # a measure of clipping for posterior analysis

    # Optimizers
    train_pi = MpiAdamOptimizer(learning_rate=pi_lr).minimize(
        pi_loss)  # Policy network optimizer
    train_v = MpiAdamOptimizer(learning_rate=vf_lr).minimize(
        v_loss)  # Value network optimizer

    # Initialize TensorFlow variables
    sess.run(tf.global_variables_initializer())

    # Sync params across processes
    sess.run(sync_all_params())

    # Set up logger variables to be saved (everything that is an input or output of the
    # networks must be saved so that the policy can be replayed later during testing)
    out_act_dict = MAP.sampled_action
    out_state_dict = MAP.state_out
    logger_outputs = {**out_act_dict, **out_state_dict}

    for k, v in logger_outputs.items():
        if 'lstm' in k:
            logger_outputs[k + '_out'] = logger_outputs.pop(k)

    logger_inputs = map_phs_dict

    logger.setup_tf_saver(sess, inputs=logger_inputs, outputs=logger_outputs)

    # ======================================================================== #
    # ===================== Auxiliary Training Functions ===================== #
    # ======================================================================== #

    # Compute metrics for analysis during and after training
    def compute_metrics(extra_dict={}):

        loss_outs = {
            'pi_loss': pi_loss,
            'v_loss': v_loss,
            'approx_ent': approx_ent,
            'approx_kl': approx_kl,
            'approx_cf': clipfrac,
            'taken_action_logp': MAP.taken_action_logp,
            'ratio': ratio,
            'min_adv': min_adv
        }

        out_loss = policies[0].sess_run(buf.obs_buf,
                                        sess_act=sess,
                                        extra_feed_dict=extra_dict,
                                        other_outputs=loss_outs,
                                        replace=True)

        return out_loss['pi_loss'], out_loss['v_loss'], out_loss[
            'approx_ent'], out_loss['approx_kl'], out_loss['approx_cf']

    # ======================================================================= #

    # Run session on policy and value optimizers for training their respective networks
    def train(net, extra_dict={}):

        if net == 'pi':
            train_outs = {'train_pi': train_pi, 'approx_kl': approx_kl}
        elif net == 'v':
            train_outs = {'train_v': train_v}
        else:
            print("Error: Network not defined")
            return

        out_train = policies[0].sess_run(buf.obs_buf,
                                         sess_act=sess,
                                         extra_feed_dict=extra_dict,
                                         other_outputs=train_outs,
                                         replace=True)
        if net == 'pi':
            return out_train['approx_kl']

    # ======================================================================= #

    # Perform training procedure
    def update():

        print("======= update!")

        # Get aux data from the buffer and match it with the corresponding placeholders
        buf_data = buf.get(aux_vars_only=True)
        aux_inputs = {k: v for k, v in zip(new_phs, buf_data)}

        # For training, the actions taken during the experience loop are also inputs to the network
        extra_dict = {k: v for k, v in buf.act_buf.items() if k != 'vpred'}

        for k, v in extra_dict.items():
            if k == 'action_movement':
                extra_dict[k] = np.expand_dims(v, 1)

        # Actions and aux variables from the buffer are joined and passed to compute_metrics (observations are added inside the function)
        extra_dict.update(aux_inputs)
        pi_l_old, v_l_old, ent, kl, cf = compute_metrics(extra_dict)

        # Policy training loop
        for i in range(train_pi_iters):
            if i % 10 == 0:
                print("training pi iter ", i)
            kl = train('pi', extra_dict)
            kl = mpi_avg(kl)
            if kl > 1.5 * target_kl:
                logger.log(
                    'Early stopping at step %d due to reaching max kl.' % i)
                break

        logger.store(StopIter=i)
        print("")

        # Value training loop
        for j in range(train_v_iters):
            if j % 10 == 0:
                print("training v iter ", j)
            train('v', extra_dict)

        # Log changes from update with a new run on compute_metrics
        pi_l_new, v_l_new, ent, kl, cf = compute_metrics(extra_dict)

        # Store information
        logger.store(LossPi=pi_l_old,
                     LossV=v_l_old,
                     KL=kl,
                     Entropy=ent,
                     ClipFrac=cf,
                     DeltaLossPi=(pi_l_new - pi_l_old),
                     DeltaLossV=(v_l_new - v_l_old))

        # Reset experience variables
        o, ep_ret, ep_len = env.reset(), 0, 0

        # Reset policy
        for policy in policies:
            policy.reset()

        print("======= update finished!")

    # ======================================================================= #

    start_time = time.time()
    o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0

    # ======================================================================= #
    # ========================== Experience Loop ============================ #
    # ======================================================================= #

    # Main loop: collect experience in env and update/log each epoch
    for epoch in range(epochs):
        print("epoch: ", epoch)
        for t in range(local_steps_per_epoch):

            # Pass observations through the networks and get action + predicted value
            if len(policies) == 1:  # this project's case
                a, info = policies[0].sess_run(o, sess_act=sess)
                v_t = info['vpred']
                logp_t = info['ac_logp']
            else:
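                # Multi-policy rollout (mirrors the viewer code above); this branch is not
                # exercised in this project and does not produce v_t / logp_t for the buffer.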
                o = splitobs(o, keepdims=False)
                ob_policy_idx = np.split(np.arange(len(o)), len(policies))
                actions = []
                for i, policy in enumerate(policies):
                    inp = operator.itemgetter(*ob_policy_idx[i])(o)
                    inp = listdict2dictnp([inp] if ob_policy_idx[i].shape[0] ==
                                          1 else inp)
                    ac, info = policy.act(inp)
                    actions.append(ac)
                a = listdict2dictnp(actions, keepdims=True)

            # Take a step in the environment
            o2, r, d, env_info = env.step(a)
            ep_ret += r
            ep_len += 1

            # If env.render() is uncommented below, the experience loop is rendered in
            # real time (much slower, but excellent for debugging)

            # env.render()

            # save experience in buffer and log
            buf.store(o, a, r, v_t, logp_t)
            logger.store(VVals=v_t)

            # Update obs (critical!)
            o = o2

            # Treat the end of a trajectory
            terminal = d or (ep_len == max_ep_len)
            if terminal or (t == local_steps_per_epoch - 1) or env_info.get(
                    'discard_episode', False):
                if not terminal:
                    print('Warning: trajectory cut off by epoch at %d steps.' %
                          ep_len)
                # If the trajectory didn't reach a terminal state, bootstrap the value target
                if d:
                    last_val = 0
                else:
                    _, info = policies[0].sess_run(o, sess_act=sess)
                    last_val = info['vpred']

                # Compute advantage estimates and rewards-to-go
                buf.finish_path(last_val)

                if terminal:
                    # only save EpRet / EpLen if trajectory finished
                    logger.store(EpRet=ep_ret, EpLen=ep_len)

                o, ep_ret, ep_len = env.reset(), 0, 0

                for policy in policies:
                    policy.reset()

        # Save model
        if (epoch % save_freq == 0) or (epoch == epochs - 1):
            print("Saved epoch: ", epoch)
            logger.save_state({'env': env}, None)

        # Perform PPO update!
        update()

        # Log info about epoch
        logger.log_tabular('Epoch', epoch)
        logger.log_tabular('EpRet', with_min_and_max=True)
        logger.log_tabular('EpLen', average_only=True)
        logger.log_tabular('VVals', with_min_and_max=True)
        logger.log_tabular('TotalEnvInteracts', (epoch + 1) * steps_per_epoch)
        logger.log_tabular('LossPi', average_only=True)
        logger.log_tabular('LossV', average_only=True)
        logger.log_tabular('DeltaLossPi', average_only=True)
        logger.log_tabular('DeltaLossV', average_only=True)
        logger.log_tabular('Entropy', average_only=True)
        logger.log_tabular('KL', average_only=True)
        logger.log_tabular('ClipFrac', average_only=True)
        logger.log_tabular('StopIter', average_only=True)
        logger.log_tabular('Time', time.time() - start_time)
        logger.dump_tabular()
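For reference on the clip_ratio description in the docstring and the objectives built above, here is a NumPy-only sketch with made-up numbers of the clipped-surrogate quantities; once the probability ratio moves past 1 ± clip_ratio in the direction favoured by the advantage, the surrogate stops improving:

import numpy as np

clip_ratio = 0.2
adv = np.array([1.0, 1.0, -1.0, -1.0])    # made-up advantage estimates
ratio = np.array([0.9, 1.5, 0.5, 1.1])    # made-up pi(a|s) / pi_old(a|s) ratios

# Clipped bound on the surrogate, chosen by the sign of the advantage
min_adv = np.where(adv > 0, (1 + clip_ratio) * adv, (1 - clip_ratio) * adv)
pi_loss = -np.mean(np.minimum(ratio * adv, min_adv))

# Diagnostics analogous to approx_kl / clipfrac in the graph above
logp_old = np.log(np.full(4, 0.25))
logp_new = logp_old + np.log(ratio)
approx_kl = np.mean(logp_old - logp_new)
clipfrac = np.mean((ratio > 1 + clip_ratio) | (ratio < 1 - clip_ratio))

print(pi_loss, approx_kl, clipfrac)
# The elements with ratio=1.5 (adv > 0) and ratio=0.5 (adv < 0) are clipped:
# pushing the policy further in those directions no longer improves the objective.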