Example #1
0
def sac(env_fn,
        actor_fn=mlp_actor,
        critic_fn=mlp_critic,
        ac_kwargs=dict(),
        seed=0,
        steps_per_epoch=1000,
        epochs=100,
        replay_size=int(1e6),
        gamma=0.99,
        polyak=0.995,
        lr=1e-4,
        batch_size=1024,
        local_start_steps=int(1e3),
        max_ep_len=1000,
        logger_kwargs=dict(),
        save_freq=10,
        local_update_after=int(1e3),
        update_freq=1,
        render=False,
        fixed_entropy_bonus=None,
        entropy_constraint=-1.0,
        fixed_cost_penalty=None,
        cost_constraint=None,
        cost_lim=None,
        reward_scale=1,
        penalty_lr=5e-2,
        use_discor=False,
        cost_maxq=True):
    """
    Train a soft actor-critic agent with separate reward and cost critics.

    Builds the TensorFlow graph (actor, reward critics, cost critics,
    optional error critics, entropy bonus and cost penalty), then runs the
    environment interaction / update / logging loop.

    Args:
        env_fn : A function which creates a copy of the environment.
            The environment must satisfy the OpenAI Gym API.

        actor_fn: A function which takes in placeholder symbols
            for state, ``x_ph``, and action, ``a_ph``, and returns the actor
            outputs from the agent's Tensorflow computation graph:

            ===========  ================  ======================================
            Symbol       Shape             Description
            ===========  ================  ======================================
            ``mu``       (batch, act_dim)  | Computes mean actions from policy
                                           | given states.
            ``pi``       (batch, act_dim)  | Samples actions from policy given
                                           | states.
            ``logp_pi``  (batch,)          | Gives log probability, according to
                                           | the policy, of the action sampled by
                                           | ``pi``. Critical: must be differentiable
                                           | with respect to policy parameters all
                                           | the way through action sampling.
            ===========  ================  ======================================

        critic_fn: A function which takes in placeholder symbols
            for state, ``x_ph``, action, ``a_ph``, and policy ``pi``,
            and returns the critic outputs from the agent's Tensorflow computation graph:

            ===========  ================  ======================================
            Symbol       Shape             Description
            ===========  ================  ======================================
            ``critic``    (batch,)         | Gives one estimate of Q* for
                                           | states in ``x_ph`` and actions in
                                           | ``a_ph``.
            ``critic_pi`` (batch,)         | Gives another estimate of Q* for
                                           | states in ``x_ph`` and actions in
                                           | ``a_ph``.
            ===========  ================  ======================================

        ac_kwargs (dict): Any kwargs appropriate for the actor_fn / critic_fn
            function you provided to SAC. The dict is copied before the
            ``action_space`` entry is added, so the caller's dict (and the
            shared default) is never modified.

        seed (int): Seed for random number generators.

        steps_per_epoch (int): Number of steps of interaction (state-action pairs)
            for the agent and the environment in each epoch.

        epochs (int): Number of epochs to run and train agent.

        replay_size (int): Maximum length of replay buffer.

        gamma (float): Discount factor. (Always between 0 and 1.)

        polyak (float): Interpolation factor in polyak averaging for target
            networks. Target networks are updated towards main networks
            according to:

            .. math:: \\theta_{\\text{targ}} \\leftarrow
                \\rho \\theta_{\\text{targ}} + (1-\\rho) \\theta

            where :math:`\\rho` is polyak. (Always between 0 and 1, usually
            close to 1.) When ``use_discor`` is set, the same factor is used
            for the moving averages of the error-critic outputs.

        lr (float): Learning rate (used for both policy and value learning).

        batch_size (int): Minibatch size for SGD. Divided by the number of
            processes to get the per-process batch size.

        local_start_steps (int): Number of steps for uniform-random action selection,
            before running real policy. Helps exploration.

        max_ep_len (int): Maximum length of trajectory / episode / rollout.

        logger_kwargs (dict): Keyword args for EpochLogger.

        save_freq (int): How often (in terms of gap between epochs) to save
            the current policy and value function.

        local_update_after (int): While the local step counter is below this
            value, the update loop only evaluates and logs diagnostics;
            no training op is run.

        update_freq (int): Every ``update_freq`` local steps, run
            ``update_freq`` update iterations.

        render (bool): If True, render the first test episode of each
            evaluation on process 0.

        fixed_entropy_bonus (float or None): Fixed bonus to reward for entropy.
            Units are (points of discounted sum of future reward) / (nats of policy entropy).
            If None, use ``entropy_constraint`` to set bonus value instead.

        entropy_constraint (float): If ``fixed_entropy_bonus`` is None,
            Adjust entropy bonus to maintain at least this much entropy.
            Actual constraint value is multiplied by the dimensions of the action space.
            Units are (nats of policy entropy) / (action dimension).

        fixed_cost_penalty (float or None): Fixed penalty to reward for cost.
            Units are (points of discounted sum of future reward) / (points of discounted sum of future costs).
            If None, use ``cost_constraint`` to set penalty value instead.

        cost_constraint (float or None): If ``fixed_cost_penalty`` is None,
            Adjust cost penalty to maintain at most this much cost.
            Units are (points of discounted sum of future costs).
            Note: to get an approximate cost_constraint from a cost_lim (undiscounted sum of costs),
            multiply cost_lim by (1 - gamma ** episode_len) / (1 - gamma).
            If None, use cost_lim to calculate constraint.

        cost_lim (float or None): If ``cost_constraint`` is None,
            calculate an approximate constraint cost from this cost limit.
            Units are (expectation of undiscounted sum of costs in a single episode).
            If None, cost_lim is not used, and if no cost constraints are used, do naive optimization.
            A value of 0 is a valid limit and is not treated as "unset".

        reward_scale (float): Factor applied to every environment reward
            before it is accumulated and stored in the replay buffer.

        penalty_lr (float): Learning rate of the optimizer for the learned
            cost penalty.

        use_discor (bool): If True, build additional error critics and weight
            the per-sample critic losses by a softmax over their error
            backups instead of using a plain mean.

        cost_maxq (bool): If True, build a second cost critic and use the
            element-wise maximum of the two cost estimates.
    """
    # Only ``None`` means "not provided"; a value of 0 is a legitimate
    # penalty / constraint / limit and must not disable cost handling.
    use_costs = (fixed_cost_penalty is not None
                 or cost_constraint is not None
                 or cost_lim is not None)

    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    # Single source of truth for "the cost penalty is a trainable variable".
    learn_cost_penalty = use_costs and fixed_cost_penalty is None

    # Env instantiation
    env, test_env = env_fn(), env_fn()
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]

    # Setting seeds
    tf.set_random_seed(seed)
    np.random.seed(seed)
    env.seed(seed)
    test_env.seed(seed)

    # Action limit for clamping: critically, assumes all dimensions share the same bound!
    act_limit = env.action_space.high[0]

    # Share information about action space with policy architecture.
    # Work on a copy so neither the caller's dict nor the shared default
    # argument is mutated across calls.
    ac_kwargs = dict(ac_kwargs)
    ac_kwargs['action_space'] = env.action_space

    # Inputs to computation graph
    x_ph, a_ph, x2_ph, r_ph, d_ph, c_ph = placeholders(obs_dim, act_dim,
                                                       obs_dim, None, None,
                                                       None)

    # Main outputs from computation graph
    with tf.variable_scope('main'):
        mu, pi, logp_pi = actor_fn(x_ph, a_ph, **ac_kwargs)
        qr1, qr1_pi = critic_fn(x_ph, a_ph, pi, name='qr1', **ac_kwargs)
        qr2, qr2_pi = critic_fn(x_ph, a_ph, pi, name='qr2', **ac_kwargs)
        qc1, qc1_pi = critic_fn(x_ph, a_ph, pi, name='qc1', **ac_kwargs)
        if cost_maxq:
            qc2, qc2_pi = critic_fn(x_ph, a_ph, pi, name='qc2', **ac_kwargs)
        if use_discor:
            # Error critics; only the first output is used from the main scope.
            er1, _ = critic_fn(x_ph, a_ph, pi, name='er1', **ac_kwargs)
            er2, _ = critic_fn(x_ph, a_ph, pi, name='er2', **ac_kwargs)
            ec1, _ = critic_fn(x_ph, a_ph, pi, name='ec1', **ac_kwargs)
            if cost_maxq:
                ec2, _ = critic_fn(x_ph, a_ph, pi, name='ec2', **ac_kwargs)

    with tf.variable_scope('main', reuse=True):
        # Additional policy output from a different observation placeholder
        # This lets us do separate optimization updates (actor, critics, etc)
        # in a single tensorflow op.
        _, pi2, logp_pi2 = actor_fn(x2_ph, a_ph, **ac_kwargs)

    # Target value network
    with tf.variable_scope('target'):
        _, qr1_pi_targ = critic_fn(x2_ph, a_ph, pi2, name='qr1', **ac_kwargs)
        _, qr2_pi_targ = critic_fn(x2_ph, a_ph, pi2, name='qr2', **ac_kwargs)
        _, qc1_pi_targ = critic_fn(x2_ph, a_ph, pi2, name='qc1', **ac_kwargs)
        if cost_maxq:
            _, qc2_pi_targ = critic_fn(x2_ph,
                                       a_ph,
                                       pi2,
                                       name='qc2',
                                       **ac_kwargs)
        if use_discor:
            # NOTE(review): unlike the Q targets above, these target error
            # critics are evaluated at (x_ph, pi) rather than the next-state
            # pair (x2_ph, pi2), although they are discounted by gamma in
            # the error backups below. Confirm this is intended.
            _, er1_pi_targ = critic_fn(x_ph, a_ph, pi, name='er1', **ac_kwargs)
            _, er2_pi_targ = critic_fn(x_ph, a_ph, pi, name='er2', **ac_kwargs)
            _, ec1_pi_targ = critic_fn(x_ph, a_ph, pi, name='ec1', **ac_kwargs)
            if cost_maxq:
                _, ec2_pi_targ = critic_fn(x_ph,
                                           a_ph,
                                           pi,
                                           name='ec2',
                                           **ac_kwargs)

    # Entropy bonus
    if fixed_entropy_bonus is None:
        with tf.variable_scope('entreg'):
            soft_alpha = tf.get_variable('soft_alpha',
                                         initializer=0.0,
                                         trainable=True,
                                         dtype=tf.float32)
        # softplus keeps the learned bonus strictly positive
        alpha = tf.nn.softplus(soft_alpha)
    else:
        alpha = tf.constant(fixed_entropy_bonus)
    log_alpha = tf.log(alpha)

    # Cost penalty
    if use_costs:
        if fixed_cost_penalty is None:
            with tf.variable_scope('costpen'):
                soft_beta = tf.get_variable('soft_beta',
                                            initializer=0.0,
                                            trainable=True,
                                            dtype=tf.float32)
            beta = tf.nn.softplus(soft_beta)
        else:
            beta = tf.constant(fixed_cost_penalty)
        log_beta = tf.log(beta)
    else:
        beta = 0.0  # costs do not contribute to policy optimization
        print('Not using costs')

    if use_discor:
        # Non-trainable temperatures for the softmax loss weights; they are
        # updated as moving averages of the error-critic outputs.
        with tf.variable_scope('discor'):
            tr1 = tf.get_variable('tr1',
                                  initializer=10.0,
                                  trainable=False,
                                  dtype=tf.float32)
            tr2 = tf.get_variable('tr2',
                                  initializer=10.0,
                                  trainable=False,
                                  dtype=tf.float32)
            tc1 = tf.get_variable('tc1',
                                  initializer=10.0,
                                  trainable=False,
                                  dtype=tf.float32)
            if cost_maxq:
                tc2 = tf.get_variable('tc2',
                                      initializer=10.0,
                                      trainable=False,
                                      dtype=tf.float32)

    # Experience buffer
    replay_buffer = ReplayBuffer(obs_dim=obs_dim,
                                 act_dim=act_dim,
                                 size=replay_size)

    # Count variables
    if proc_id() == 0:
        var_counts = tuple(
            count_vars(scope) for scope in
            ['main/pi', 'main/qr1', 'main/qr2', 'main/qc1', 'main'])
        print((
            '\nNumber of parameters: \t pi: %d, \t qr1: %d, \t qr2: %d, \t qc1: %d, \t total: %d\n'
        ) % var_counts)

    # Min Double-Q:
    min_q_pi = tf.minimum(qr1_pi, qr2_pi)
    min_q_pi_targ = tf.minimum(qr1_pi_targ, qr2_pi_targ)

    # Cost estimate: maximum of the two cost critics when both exist
    if cost_maxq:
        max_qc_pi = tf.maximum(qc1_pi, qc2_pi)
        max_qc_pi_targ = tf.maximum(qc1_pi_targ, qc2_pi_targ)
    else:
        max_qc_pi = qc1_pi
        max_qc_pi_targ = qc1_pi_targ

    # Targets for Q and V regression
    q_backup = tf.stop_gradient(r_ph + gamma * (1 - d_ph) *
                                (min_q_pi_targ - alpha * logp_pi2))
    qc_backup = tf.stop_gradient(c_ph + gamma * (1 - d_ph) * max_qc_pi_targ)

    if use_discor:
        er1_backup = tf.stop_gradient(
            tf.abs(qr1 - q_backup) + gamma * (1 - d_ph) * er1_pi_targ)
        er2_backup = tf.stop_gradient(
            tf.abs(qr2 - q_backup) + gamma * (1 - d_ph) * er2_pi_targ)
        ec1_backup = tf.stop_gradient(
            tf.abs(qc1 - qc_backup) + gamma * (1 - d_ph) * ec1_pi_targ)
        if cost_maxq:
            ec2_backup = tf.stop_gradient(
                tf.abs(qc2 - qc_backup) + gamma * (1 - d_ph) * ec2_pi_targ)

        # Critic losses weighted per sample by a softmax over the batch
        qr1_loss = 0.5 * tf.reduce_sum(
            tf.nn.softmax(er1_backup / tr1, axis=0) * (q_backup - qr1)**2)
        qr2_loss = 0.5 * tf.reduce_sum(
            tf.nn.softmax(er2_backup / tr2, axis=0) * (q_backup - qr2)**2)
        qc1_loss = 0.5 * tf.reduce_sum(
            tf.nn.softmax(ec1_backup / tc1, axis=0) * (qc_backup - qc1)**2)
        if cost_maxq:
            qc2_loss = 0.5 * tf.reduce_sum(
                tf.nn.softmax(ec2_backup / tc2, axis=0) * (qc_backup - qc2)**2)
    else:
        qr1_loss = 0.5 * tf.reduce_mean((q_backup - qr1)**2)
        qr2_loss = 0.5 * tf.reduce_mean((q_backup - qr2)**2)
        qc1_loss = 0.5 * tf.reduce_mean((qc_backup - qc1)**2)
        if cost_maxq:
            qc2_loss = 0.5 * tf.reduce_mean((qc_backup - qc2)**2)

    # Soft actor-critic losses
    q_loss = qr1_loss + qr2_loss + qc1_loss
    if cost_maxq:
        q_loss += qc2_loss
    pi_loss = tf.reduce_mean(alpha * logp_pi - min_q_pi +
                             beta * max_qc_pi) / (1 + beta)

    if use_discor:
        er1_loss = 0.5 * tf.reduce_mean((er1_backup - er1)**2)
        er2_loss = 0.5 * tf.reduce_mean((er2_backup - er2)**2)
        ec1_loss = 0.5 * tf.reduce_mean((ec1_backup - ec1)**2)
        error_loss = er1_loss + er2_loss + ec1_loss
        if cost_maxq:
            ec2_loss = 0.5 * tf.reduce_mean((ec2_backup - ec2)**2)
            error_loss += ec2_loss
            ec2_mean = tf.reduce_mean(ec2)
        er1_mean = tf.reduce_mean(er1)
        er2_mean = tf.reduce_mean(er2)
        ec1_mean = tf.reduce_mean(ec1)

    # Loss for alpha
    entropy_constraint *= act_dim
    pi_entropy = -tf.reduce_mean(logp_pi)
    alpha_loss = -alpha * (entropy_constraint - pi_entropy)
    print('using entropy constraint', entropy_constraint)

    # Loss for beta
    if learn_cost_penalty:
        if cost_constraint is None:
            # Convert assuming equal cost accumulated each step
            # Note this isn't the case, since the early in episode doesn't usually have cost,
            # but since our algorithm optimizes the discounted infinite horizon from each entry
            # in the replay buffer, we should be approximately correct here.
            # It's worth checking empirical total undiscounted costs to see if they match.
            cost_constraint = cost_lim * (1 - gamma**max_ep_len) / (
                1 - gamma) / max_ep_len
        print('using cost constraint', cost_constraint)
        # NOTE(review): this is a per-sample tensor built from qc1 only
        # (not the max over cost critics). TODO: What is the correct target here?
        beta_loss = beta * (cost_constraint - qc1)

    # Policy train op
    # (has to be separate from value train op, because qr1_pi appears in pi_loss)
    train_pi_op = MpiAdamOptimizer(learning_rate=lr).minimize(
        pi_loss, var_list=get_vars('main/pi'), name='train_pi')

    # Value train op
    with tf.control_dependencies([train_pi_op]):
        train_q_op = MpiAdamOptimizer(learning_rate=lr).minimize(
            q_loss, var_list=get_vars('main/q'), name='train_q')
    with tf.control_dependencies([train_q_op]):
        if use_discor:
            train_e_op = MpiAdamOptimizer(learning_rate=lr).minimize(
                error_loss, var_list=get_vars('main/e'), name='train_e')
            with tf.control_dependencies([train_e_op]):
                # Moving-average updates of the softmax temperatures
                temperature_updates = [
                    tf.assign(tr1, (1 - polyak) * er1_mean + polyak * tr1),
                    tf.assign(tr2, (1 - polyak) * er2_mean + polyak * tr2),
                    tf.assign(tc1, (1 - polyak) * ec1_mean + polyak * tc1),
                ]
                if cost_maxq:
                    temperature_updates.append(
                        tf.assign(tc2,
                                  (1 - polyak) * ec2_mean + polyak * tc2))
                train_e_out_op = tf.group(temperature_updates)
        else:
            train_e_out_op = tf.no_op()

    # Chain the optional alpha / beta updates after whichever op was built
    # last, so the beta update does not depend on an alpha op that only
    # exists when the entropy bonus is learned.
    last_train_op = train_e_out_op
    if fixed_entropy_bonus is None:
        entreg_optimizer = MpiAdamOptimizer(learning_rate=lr)
        with tf.control_dependencies([last_train_op]):
            train_entreg_op = entreg_optimizer.minimize(
                alpha_loss, var_list=get_vars('entreg'))
        last_train_op = train_entreg_op
    if learn_cost_penalty:
        costpen_optimizer = MpiAdamOptimizer(learning_rate=penalty_lr)
        with tf.control_dependencies([last_train_op]):
            train_costpen_op = costpen_optimizer.minimize(
                beta_loss, var_list=get_vars('costpen'))

    # Polyak averaging for target variables
    target_update = get_target_update('main', 'target', polyak)

    # Single monolithic update with explicit control dependencies
    with tf.control_dependencies([train_pi_op]):
        with tf.control_dependencies([train_q_op]):
            if use_discor:
                with tf.control_dependencies([train_e_op]):
                    with tf.control_dependencies([train_e_out_op]):
                        grouped_update = tf.group([target_update])
            else:
                grouped_update = tf.group([target_update])

    if fixed_entropy_bonus is None:
        grouped_update = tf.group([grouped_update, train_entreg_op])
    if learn_cost_penalty:
        # Variant of the update that additionally trains the cost penalty
        grouped_update_a = tf.group([grouped_update, train_costpen_op])

    # Initializing targets to match main variables
    # As a shortcut, use our exponential moving average update w/ coefficient zero
    target_init = get_target_update('main', 'target', 0.0)

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    sess.run(target_init)

    # Sync params across processes
    sess.run(sync_all_params())

    # Setup model saving
    logger.setup_tf_saver(sess,
                          inputs={
                              'x': x_ph,
                              'a': a_ph
                          },
                          outputs={
                              'mu': mu,
                              'pi': pi,
                              'qr1': qr1,
                              'qr2': qr2,
                              'qc1': qc1
                          })

    def get_action(o, deterministic=False):
        """Return one action for observation ``o`` (mean action if deterministic)."""
        act_op = mu if deterministic else pi
        return sess.run(act_op, feed_dict={x_ph: o.reshape(1, -1)})[0]

    def test_agent(n=10):
        """Run ``n`` deterministic test episodes and store their statistics."""
        for j in range(n):
            o, r, d, ep_ret, ep_cost, ep_len, ep_goals, = test_env.reset(
            ), 0, False, 0, 0, 0, 0
            while not (d or (ep_len == max_ep_len)):
                # Take deterministic actions at test time
                o, r, d, info = test_env.step(get_action(o, True))
                if render and proc_id() == 0 and j == 0:
                    test_env.render()
                ep_ret += r
                ep_cost += info.get('cost', 0)
                ep_len += 1
                ep_goals += 1 if info.get('goal_met', False) else 0
            logger.store(TestEpRet=ep_ret,
                         TestEpCost=ep_cost,
                         TestEpLen=ep_len,
                         TestEpGoals=ep_goals)

    start_time = time.time()
    o, r, d, ep_ret, ep_cost, ep_len, ep_goals = env.reset(
    ), 0, False, 0, 0, 0, 0
    total_steps = steps_per_epoch * epochs

    # variables to measure in an update
    vars_to_get = dict(LossPi=pi_loss,
                       LossQR1=qr1_loss,
                       LossQR2=qr2_loss,
                       LossQC1=qc1_loss,
                       QR1Vals=qr1,
                       QR2Vals=qr2,
                       QC1Vals=qc1,
                       LogPi=logp_pi,
                       PiEntropy=pi_entropy,
                       Alpha=alpha,
                       LogAlpha=log_alpha,
                       LossAlpha=alpha_loss)
    if learn_cost_penalty:
        vars_to_get.update(
            dict(Beta=beta, LogBeta=log_beta, LossBeta=beta_loss))
    if use_discor:
        vars_to_get.update(dict(TR1=tr1))

    print('starting training', proc_id())

    # Main loop: collect experience in env and update/log each epoch
    local_steps = 0
    local_steps_per_epoch = steps_per_epoch // num_procs()
    local_batch_size = batch_size // num_procs()
    epoch_start_time = time.time()
    for t in range(total_steps // num_procs()):
        # Until local_start_steps have elapsed, randomly sample actions
        # from a uniform distribution for better exploration. Afterwards,
        # use the learned policy.
        if t > local_start_steps:
            a = get_action(o)
        else:
            a = env.action_space.sample()

        # Step the env
        o2, r, d, info = env.step(a)
        r *= reward_scale  # yee-haw
        c = info.get('cost', 0)
        ep_ret += r
        ep_cost += c
        ep_len += 1
        ep_goals += 1 if info.get('goal_met', False) else 0
        local_steps += 1

        # Ignore the "done" signal if it comes from hitting the time
        # horizon (that is, when it's an artificial terminal signal
        # that isn't based on the agent's state)
        d = False if ep_len == max_ep_len else d

        # Store experience to replay buffer
        replay_buffer.store(o, a, r, o2, d, c)

        # Super critical, easy to overlook step: make sure to update
        # most recent observation!
        o = o2

        if d or (ep_len == max_ep_len):
            logger.store(EpRet=ep_ret,
                         EpCost=ep_cost,
                         EpLen=ep_len,
                         EpGoals=ep_goals)
            o, r, d, ep_ret, ep_cost, ep_len, ep_goals = env.reset(
            ), 0, False, 0, 0, 0, 0

        if t > 0 and t % update_freq == 0:
            for j in range(update_freq):
                batch = replay_buffer.sample_batch(local_batch_size)
                feed_dict = {
                    x_ph: batch['obs1'],
                    x2_ph: batch['obs2'],
                    a_ph: batch['acts'],
                    r_ph: batch['rews'],
                    c_ph: batch['costs'],
                    d_ph: batch['done'],
                }
                if t < local_update_after:
                    # Warm-up: log diagnostics only, no training op
                    logger.store(**sess.run(vars_to_get, feed_dict))
                else:
                    # The learned cost penalty is trained only on the last
                    # iteration of each group of updates.
                    is_last_iteration = j == update_freq - 1
                    if is_last_iteration and learn_cost_penalty:
                        update_op = grouped_update_a
                    else:
                        update_op = grouped_update
                    values, _ = sess.run([vars_to_get, update_op], feed_dict)
                    logger.store(**values)

        # End of epoch wrap-up
        if t > 0 and t % local_steps_per_epoch == 0:
            epoch = t // local_steps_per_epoch

            # Save model
            if (epoch % save_freq == 0) or (epoch == epochs - 1):
                logger.save_state({'env': env}, None)

            # Test the performance of the deterministic version of the agent.
            test_start_time = time.time()
            test_agent()
            logger.store(TestTime=time.time() - test_start_time)

            logger.store(EpochTime=time.time() - epoch_start_time)
            epoch_start_time = time.time()

            # Log info about epoch
            logger.log_tabular('Epoch', epoch)
            logger.log_tabular('EpRet', with_min_and_max=True)
            logger.log_tabular('TestEpRet', with_min_and_max=True)
            logger.log_tabular('EpCost', with_min_and_max=True)
            logger.log_tabular('TestEpCost', with_min_and_max=True)
            logger.log_tabular('EpLen', average_only=True)
            logger.log_tabular('TestEpLen', average_only=True)
            logger.log_tabular('EpGoals', average_only=True)
            logger.log_tabular('TestEpGoals', average_only=True)
            logger.log_tabular('TotalEnvInteracts', mpi_sum(local_steps))
            logger.log_tabular('QR1Vals', with_min_and_max=True)
            logger.log_tabular('QR2Vals', with_min_and_max=True)
            logger.log_tabular('QC1Vals', with_min_and_max=True)
            logger.log_tabular('LogPi', with_min_and_max=True)
            logger.log_tabular('LossPi', average_only=True)
            logger.log_tabular('LossQR1', average_only=True)
            logger.log_tabular('LossQR2', average_only=True)
            logger.log_tabular('LossQC1', average_only=True)
            logger.log_tabular('LossAlpha', average_only=True)
            logger.log_tabular('LogAlpha', average_only=True)
            logger.log_tabular('Alpha', average_only=True)
            if learn_cost_penalty:
                logger.log_tabular('LossBeta', average_only=True)
                logger.log_tabular('LogBeta', average_only=True)
                logger.log_tabular('Beta', average_only=True)
            if use_discor:
                logger.log_tabular('TR1', average_only=True)
            logger.log_tabular('PiEntropy', average_only=True)
            logger.log_tabular('TestTime', average_only=True)
            logger.log_tabular('EpochTime', average_only=True)
            logger.log_tabular('TotalTime', time.time() - start_time)
            logger.dump_tabular()
Example #2
0
def run_polopt_agent(env_fn,
                     agent=PPOAgent(),
                     actor_critic=mlp_actor_critic,
                     ac_kwargs=dict(),
                     seed=0,
                     render=False,
                     # Experience collection:
                     steps_per_epoch=4000,
                     epochs=50,
                     max_ep_len=1000,
                     # Discount factors:
                     gamma=0.99,
                     lam=0.97,
                     cost_gamma=0.99,
                     cost_lam=0.97,
                     # Policy learning:
                     ent_reg=0.,
                     # Cost constraints / penalties:
                     cost_lim=25,
                     penalty_init=1.,
                     penalty_lr=5e-2,
                     # KL divergence:
                     target_kl=0.01,
                     # Value learning:
                     vf_lr=1e-3,
                     vf_iters=80,
                     # Logging:
                     logger=None,
                     logger_kwargs=dict(),
                     save_freq=1
                     ):
    """Build the graph for a constrained policy-optimization agent and collect rollouts.

    The function creates the actor-critic graph, an experience buffer, the
    optional penalty (Lagrange-multiplier style) variable, the policy and value
    losses, a TF session, and then steps the environment for ``epochs`` epochs
    of ``steps_per_epoch / num_procs()`` local steps each.

    NOTE(review): this listing ends after the per-step rollout bookkeeping.
    The nested ``update()`` is defined but never called in the visible code,
    and ``save_freq`` is never read -- the epoch-end save/update/log section
    appears to be missing from this copy. Confirm against the full source.

    Args:
        env_fn: Zero-argument callable returning the environment. The
            environment is used through ``reset()``, ``step(a)``, ``render()``,
            ``observation_space`` and ``action_space``; ``step`` must return
            ``(obs, reward, done, info)`` and the per-step cost is read from
            ``info.get('cost', 0)``.
        agent: Agent object. The code reads its flags (``reward_penalized``,
            ``use_penalty``, ``learn_penalty``, ``penalty_param_loss``,
            ``clipped_adv``, ``clip_ratio``, ``objective_penalized``,
            ``trust_region``, ``damping_coeff``, ``first_order``, ``pi_lr``,
            ``cares_about_cost``) and calls ``set_logger``, ``prepare_update``,
            ``prepare_session`` and ``update_pi`` on it.
        actor_critic: Callable invoked as ``actor_critic(x_ph, a_ph, **ac_kwargs)``
            that returns the 9-tuple
            ``(pi, logp, logp_pi, pi_info, pi_info_phs, d_kl, ent, v, vc)``.
        ac_kwargs: Extra keyword arguments for ``actor_critic``. A copy is
            made before ``action_space`` is added, so the caller's dict (and
            the shared default) is not modified.
        seed: Base random seed; offset by ``10000 * proc_id()`` per process.
        render: If true, process 0 renders the first 1000 steps of each epoch.
        steps_per_epoch: Total environment steps per epoch across processes.
        epochs: Number of epochs to run.
        max_ep_len: Episode length at which an episode is treated as finished.
        gamma, lam, cost_gamma, cost_lam: Forwarded to ``CPOBuffer``.
        ent_reg: Coefficient on ``ent`` in the policy objective.
        cost_lim: Cost limit used in the penalty loss and handed to the agent.
        penalty_init: Initial value of the penalty; the underlying parameter is
            initialised with the inverse softplus of this value.
        penalty_lr: Learning rate for the penalty optimizer.
        target_kl: Handed to the agent through the training package.
        vf_lr: Learning rate for the value-function optimizer.
        vf_iters: Number of value-function gradient steps per ``update()``.
        logger: Existing logger to use; an ``EpochLogger`` is created from
            ``logger_kwargs`` when this is ``None``.
        logger_kwargs: Keyword arguments for ``EpochLogger``.
        save_freq: Not used by the visible code (see the note above).

    Raises:
        NotImplementedError: If the agent is neither ``trust_region`` nor
            ``first_order``.
    """

    #=========================================================================#
    #  Prepare logger, seed, and environment in this process                  #
    #=========================================================================#

    logger = EpochLogger(**logger_kwargs) if logger is None else logger
    logger.save_config(locals())

    # Give every process its own seed so parallel workers do not collect
    # identical trajectories.
    seed += 10000 * proc_id()
    tf.compat.v1.set_random_seed(seed)
    np.random.seed(seed)

    env = env_fn()

    agent.set_logger(logger)

    #=========================================================================#
    #  Create computation graph for actor and critic (not training routine)   #
    #=========================================================================#

    # Share information about action space with policy architecture.
    # Work on a copy: the default ``dict()`` is shared between calls and the
    # caller's dict should not be modified behind their back.
    ac_kwargs = dict(ac_kwargs)
    ac_kwargs['action_space'] = env.action_space

    # Inputs to computation graph from environment spaces
    x_ph, a_ph = placeholders_from_spaces(env.observation_space, env.action_space)

    # Inputs to computation graph for batch data
    adv_ph, cadv_ph, ret_ph, cret_ph, logp_old_ph = placeholders(*(None for _ in range(5)))

    # Inputs to computation graph for special purposes
    surr_cost_rescale_ph = tf.compat.v1.placeholder(tf.float32, shape=())
    cur_cost_ph = tf.compat.v1.placeholder(tf.float32, shape=())

    # Outputs from actor critic
    ac_outs = actor_critic(x_ph, a_ph, **ac_kwargs)
    pi, logp, logp_pi, pi_info, pi_info_phs, d_kl, ent, v, vc = ac_outs

    # Organize placeholders for zipping with data from buffer on updates
    buf_phs = [x_ph, a_ph, adv_ph, cadv_ph, ret_ph, cret_ph, logp_old_ph]
    buf_phs += values_as_sorted_list(pi_info_phs)

    # Organize symbols we have to compute at each step of acting in env
    get_action_ops = dict(pi=pi,
                          v=v,
                          logp_pi=logp_pi,
                          pi_info=pi_info)

    # If agent is reward penalized, it doesn't use a separate value function
    # for costs and we don't need to include it in get_action_ops; otherwise we do.
    if not(agent.reward_penalized):
        get_action_ops['vc'] = vc

    # Count variables
    var_counts = tuple(count_vars(scope) for scope in ['pi', 'vf', 'vc'])
    logger.log('\nNumber of parameters: \t pi: %d, \t v: %d, \t vc: %d\n'%var_counts)

    # Make a sample estimate for entropy to use as sanity check
    approx_ent = tf.reduce_mean(-logp)


    #=========================================================================#
    #  Create replay buffer                                                   #
    #=========================================================================#

    # Obs/act shapes
    obs_shape = env.observation_space.shape
    act_shape = env.action_space.shape

    # Experience buffer: each process only stores its own share of the epoch.
    local_steps_per_epoch = int(steps_per_epoch / num_procs())
    # Drop the leading (batch) dimension of every policy-info placeholder.
    pi_info_shapes = {k: ph.shape.as_list()[1:] for k, ph in pi_info_phs.items()}
    buf = CPOBuffer(local_steps_per_epoch,
                    obs_shape,
                    act_shape,
                    pi_info_shapes,
                    gamma,
                    lam,
                    cost_gamma,
                    cost_lam)


    #=========================================================================#
    #  Create computation graph for penalty learning, if applicable           #
    #=========================================================================#

    if agent.use_penalty:
        with tf.compat.v1.variable_scope('penalty'):
            # Inverse softplus, so that softplus(param_init) == penalty_init;
            # the max() guards against log(0) for penalty_init <= 0.
            param_init = np.log(max(np.exp(penalty_init)-1, 1e-8))
            penalty_param = tf.compat.v1.get_variable('penalty_param',
                                          initializer=float(param_init),
                                          trainable=agent.learn_penalty,
                                          dtype=tf.float32)
        # softplus keeps the penalty non-negative.
        penalty = tf.nn.softplus(penalty_param)

    if agent.learn_penalty:
        # Minimising this loss raises the penalty while the measured cost
        # exceeds cost_lim and lowers it otherwise.
        if agent.penalty_param_loss:
            penalty_loss = -penalty_param * (cur_cost_ph - cost_lim)
        else:
            penalty_loss = -penalty * (cur_cost_ph - cost_lim)
        train_penalty = MpiAdamOptimizer(learning_rate=penalty_lr).minimize(penalty_loss)


    #=========================================================================#
    #  Create computation graph for policy learning                           #
    #=========================================================================#

    # Likelihood ratio
    ratio = tf.exp(logp - logp_old_ph)

    # Surrogate advantage / clipped surrogate advantage
    if agent.clipped_adv:
        min_adv = tf.where(adv_ph>0,
                           (1+agent.clip_ratio)*adv_ph,
                           (1-agent.clip_ratio)*adv_ph
                           )
        surr_adv = tf.reduce_mean(tf.minimum(ratio * adv_ph, min_adv))
    else:
        surr_adv = tf.reduce_mean(ratio * adv_ph)

    # Surrogate cost
    surr_cost = tf.reduce_mean(ratio * cadv_ph)

    # Create policy objective function, including entropy regularization
    pi_objective = surr_adv + ent_reg * ent

    # Possibly include surr_cost in pi_objective
    if agent.objective_penalized:
        pi_objective -= penalty * surr_cost
        pi_objective /= (1 + penalty)

    # Loss function for pi is negative of pi_objective
    pi_loss = -pi_objective

    # Optimizer-specific symbols
    if agent.trust_region:

        # Symbols needed for CG solver for any trust region method
        pi_params = get_vars('pi')
        flat_g = tro.flat_grad(pi_loss, pi_params)
        v_ph, hvp = tro.hessian_vector_product(d_kl, pi_params)
        if agent.damping_coeff > 0:
            hvp += agent.damping_coeff * v_ph

        # Symbols needed for CG solver for CPO only
        flat_b = tro.flat_grad(surr_cost, pi_params)

        # Symbols for getting and setting params
        get_pi_params = tro.flat_concat(pi_params)
        set_pi_params = tro.assign_params_from_flat(v_ph, pi_params)

        training_package = dict(flat_g=flat_g,
                                flat_b=flat_b,
                                v_ph=v_ph,
                                hvp=hvp,
                                get_pi_params=get_pi_params,
                                set_pi_params=set_pi_params)

    elif agent.first_order:

        # Optimizer for first-order policy optimization
        train_pi = MpiAdamOptimizer(learning_rate=agent.pi_lr).minimize(pi_loss)

        # Prepare training package for agent
        training_package = dict(train_pi=train_pi)

    else:
        raise NotImplementedError

    # Provide training package to agent
    training_package.update(dict(pi_loss=pi_loss,
                                 surr_cost=surr_cost,
                                 d_kl=d_kl,
                                 target_kl=target_kl,
                                 cost_lim=cost_lim))
    agent.prepare_update(training_package)

    #=========================================================================#
    #  Create computation graph for value learning                            #
    #=========================================================================#

    # Value losses
    v_loss = tf.reduce_mean((ret_ph - v)**2)
    vc_loss = tf.reduce_mean((cret_ph - vc)**2)

    # If agent uses penalty directly in reward function, don't train a separate
    # value function for predicting cost returns. (Only use one vf for r - p*c.)
    if agent.reward_penalized:
        total_value_loss = v_loss
    else:
        total_value_loss = v_loss + vc_loss

    # Optimizer for value learning
    train_vf = MpiAdamOptimizer(learning_rate=vf_lr).minimize(total_value_loss)


    #=========================================================================#
    #  Create session, sync across procs, and set up saver                    #
    #=========================================================================#

    sess = tf.compat.v1.Session()
    sess.run(tf.compat.v1.global_variables_initializer())

    # Sync params across processes
    sess.run(sync_all_params())

    # Setup model saving
    logger.setup_tf_saver(sess, inputs={'x': x_ph}, outputs={'pi': pi, 'v': v, 'vc': vc})


    #=========================================================================#
    #  Provide session to agent                                               #
    #=========================================================================#
    agent.prepare_session(sess)


    #=========================================================================#
    #  Create function for running update (called at end of each epoch)       #
    #=========================================================================#

    def update():
        """Run one penalty / policy / value update from the buffered data.

        Logs a set of measurements before the update, then the KL and the
        change ("Delta...") of each shared measurement after it.
        """
        cur_cost = logger.get_stats('EpCost')[0]
        cost_excess = cur_cost - cost_lim
        if cost_excess > 0 and agent.cares_about_cost:
            logger.log('Warning! Safety constraint is already violated.', 'red')

        #=====================================================================#
        #  Prepare feed dict                                                  #
        #=====================================================================#

        # buf.get() yields arrays in the same order as buf_phs.
        inputs = dict(zip(buf_phs, buf.get()))
        inputs[surr_cost_rescale_ph] = logger.get_stats('EpLen')[0]
        inputs[cur_cost_ph] = cur_cost

        #=====================================================================#
        #  Make some measurements before updating                             #
        #=====================================================================#

        measures = dict(LossPi=pi_loss,
                        SurrCost=surr_cost,
                        LossV=v_loss,
                        Entropy=ent)
        if not(agent.reward_penalized):
            measures['LossVC'] = vc_loss
        if agent.use_penalty:
            measures['Penalty'] = penalty

        pre_update_measures = sess.run(measures, feed_dict=inputs)
        logger.store(**pre_update_measures)

        #=====================================================================#
        #  Update penalty if learning penalty                                 #
        #=====================================================================#
        if agent.learn_penalty:
            sess.run(train_penalty, feed_dict={cur_cost_ph: cur_cost})

        #=====================================================================#
        #  Update policy                                                      #
        #=====================================================================#
        agent.update_pi(inputs)

        #=====================================================================#
        #  Update value function                                              #
        #=====================================================================#
        for _ in range(vf_iters):
            sess.run(train_vf, feed_dict=inputs)

        #=====================================================================#
        #  Make some measurements after updating                              #
        #=====================================================================#

        # Entropy is only reported pre-update; KL only makes sense post-update.
        del measures['Entropy']
        measures['KL'] = d_kl

        post_update_measures = sess.run(measures, feed_dict=inputs)
        deltas = {'Delta' + k: post_update_measures[k] - pre_update_measures[k]
                  for k in post_update_measures
                  if k in pre_update_measures}
        logger.store(KL=post_update_measures['KL'], **deltas)


    #=========================================================================#
    #  Run main environment interaction loop                                  #
    #=========================================================================#

    start_time = time.time()
    o, r, d, c, ep_ret, ep_cost, ep_len = env.reset(), 0, False, 0, 0, 0, 0
    cur_penalty = 0
    cum_cost = 0

    for epoch in range(epochs):

        # The penalty is read once per epoch and held fixed during the rollout.
        if agent.use_penalty:
            cur_penalty = sess.run(penalty)

        # Collect local_steps_per_epoch transitions in this process.
        for t in range(local_steps_per_epoch):

            # Possibly render
            if render and proc_id()==0 and t < 1000:
                env.render()

            # Get outputs from policy
            get_action_outs = sess.run(get_action_ops,
                                       feed_dict={x_ph: o[np.newaxis]})
            a = get_action_outs['pi']
            v_t = get_action_outs['v']
            vc_t = get_action_outs.get('vc', 0)  # Agent may not use cost value func
            logp_t = get_action_outs['logp_pi']
            pi_info_t = get_action_outs['pi_info']

            # Step in environment
            o2, r, d, info = env.step(a)

            # Include penalty on cost
            c = info.get('cost', 0)

            # Track cumulative cost over training
            cum_cost += c

            # save and log
            if agent.reward_penalized:
                # Fold the cost into the reward and rescale; no separate cost
                # signal is stored in this mode.
                r_total = r - cur_penalty * c
                r_total = r_total / (1 + cur_penalty)
                buf.store(o, a, r_total, v_t, 0, 0, logp_t, pi_info_t)
            else:
                buf.store(o, a, r, v_t, c, vc_t, logp_t, pi_info_t)
            logger.store(VVals=v_t, CostVVals=vc_t)

            o = o2
            ep_ret += r
            ep_cost += c
            ep_len += 1

            # Layout of one epoch (illustrated for local_steps_per_epoch =
            # 30,000 and max_ep_len = 1000):
            #
            #   t=0                                                t = 30,000
            #   |              30,000 local_steps_per_epoch                |
            #   | ep |  ep    |     ep   |      ep    |    ep    |    ep   |
            #                     each episode: max 1000 steps
            #
            # An episode ends when the env reports done or max_ep_len is hit.
            terminal = (d or (ep_len == max_ep_len))
            if terminal or (t==local_steps_per_epoch-1):

                # If trajectory didn't reach terminal state, bootstrap value target(s)
                if d and not(ep_len == max_ep_len):
                    # Note: we do not count env time out as true terminal state
                    last_val, last_cval = 0, 0
                else:
                    feed_dict={x_ph: o[np.newaxis]}
                    if agent.reward_penalized:
                        last_val = sess.run(v, feed_dict=feed_dict)
                        last_cval = 0
                    else:
                        last_val, last_cval = sess.run([v, vc], feed_dict=feed_dict)
                buf.finish_path(last_val, last_cval)

                # Only save EpRet / EpLen if trajectory finished
                if terminal:
                    logger.store(EpRet=ep_ret, EpLen=ep_len, EpCost=ep_cost)
                else:
                    print('Warning: trajectory cut off by epoch at %d steps.'%ep_len)

                # Reset environment
                o, r, d, c, ep_ret, ep_len, ep_cost = env.reset(), 0, False, 0, 0, 0, 0
Example #3
0
def run_polopt_agent(
    env_fn,
    agent=PPOAgent(),
    actor_critic=actor_critic,
    ac_kwargs=dict(),
    seed=0,
    render=False,
    # Experience collection:
    steps_per_epoch=4000,
    epochs=50,
    max_ep_len=1000,
    # Discount factors:
    gamma=0.99,
    lam=0.97,
    cost_gamma=0.99,
    cost_lam=0.97,
    # Policy learning:
    ent_reg=0.0,
    # Cost constraints / penalties:
    cost_lim=25,
    penalty_init=1.0,
    penalty_lr=5e-2,
    # KL divergence:
    target_kl=0.01,
    # Value learning:
    vf_lr=1e-3,
    vf_iters=80,
    # Logging:
    logger=None,
    logger_kwargs=dict(),
    save_freq=1,
    visual_obs=False,
    safety_checks=False,
    sym_features=False,
    env_name="",
    verbose=False,
    log_params=None,
    n_envs=6,
    discretize=False,
):

    oracle = True
    assert not discretize, "not yet supported; have to change the loss function too?"

    sym_features = False
    global IMG_SIZE
    if not (safety_checks or sym_features):
        IMG_SIZE = IMG_RESIZE

    # =========================================================================#
    #  Prepare logger, seed, and environment in this process                  #
    # =========================================================================#

    logger = EpochLogger(**logger_kwargs) if logger is None else logger
    logger.save_config(locals())

    tf.set_random_seed(seed)
    np.random.seed(seed)

    ray.init()

    if safety_checks or sym_features:
        if not oracle:
            device = torch.device("cuda")
            model = torch.jit.load(
                "/srl/models/model_0166d3228ffa4cb0a55a7c7c696e43b7_final.zip")
            model = model.to(device).eval()
            sym_map = SymMapCenterNet(model, device)

    @ray.remote
    class RemoteEnv:
        """Ray actor that owns one environment and samples/filters actions for it.

        The actor keeps a ``State`` object in sync with features obtained from
        ``SymMapOracle`` (or supplied by the caller) and, when
        ``safety_checks`` is enabled, resamples Gaussian actions until
        ``State.is_safe_action`` accepts one.
        """

        def __init__(self, env, visual_obs: bool, safety_checks: bool):
            """Wrap ``env`` and set up the safety-tracking state.

            Args:
                env: Environment instance to drive.
                visual_obs: If true, the ``VisionWrapper`` around ``env`` is
                    used as the environment itself.
                safety_checks: If true, actions are filtered in ``step`` and a
                    ``VisionWrapper`` is kept for rendering.
            """
            if visual_obs or safety_checks:
                self.visual_env = VisionWrapper(env, IMG_SIZE, IMG_SIZE)
                # Make sure a viewer exists before the first render call.
                if self.visual_env.viewer is None:
                    self.visual_env.reset()
                    self.visual_env._make_viewer()
            if visual_obs:
                env = self.visual_env

            self.env = env
            self.state = State()
            # Number of times step() let through an action that failed the
            # safety check.
            self.n_unsafe_allowed = 0
            self.safety_checks = safety_checks
            self.visual_obs = visual_obs
            self.sym_map = SymMapOracle(env)

        def _set_state(self, robot_position, speed, robot_direction, obstacles):
            """Copy the given symbolic features into ``self.state``."""
            self.state.robot_position = robot_position
            self.state.robot_velocity = speed
            self.state.robot_direction = robot_direction
            self.state.obstacles = obstacles

        def reset(self):
            """Reset the environment, retrying until the start state is safe.

            A start state counts as safe when ``is_safe_action(0, 0)`` holds.
            After more than 100 failed attempts the last (unsafe) start state
            is accepted so that the call always terminates.

            Returns:
                ``(obs, visual_obs)`` where ``visual_obs`` is a rendered frame
                when safety checks are on without visual observations, and
                ``None`` otherwise.
            """
            failed_attempts = 0
            while True:
                obs = self.env.reset()
                self._set_state(*self.sym_map())
                if self.state.is_safe_action(0, 0):
                    break  # found a safe starting position.
                failed_attempts += 1
                if failed_attempts > 100:
                    print(
                        "proceeding with an unsafe starting position.... why is it so"
                        " hard to find a safe way to start life?!")
                    # Actually proceed, as the message says; without this the
                    # loop would spin forever when no safe start exists.
                    break

            if self.safety_checks and not self.visual_obs:
                # have to render still for safety
                visual_obs = self.visual_env._render()
                return obs, visual_obs
            return obs, None

        def get_oracle_features(self):
            """Return whatever ``self.sym_map()`` currently reports."""
            return self.sym_map()

        def get_n_unsafe_allowed(self):
            """Return how many unsafe actions ``step`` has let through so far."""
            return self.n_unsafe_allowed

        def step(
            self,
            mu,
            log_std,
            oracle=False,
            robot_position=None,
            robot_direction=None,
            obstacles=None,
        ):
            """Sample an action from N(mu, exp(log_std)), filter it, and step the env.

            Args:
                mu: Mean of the Gaussian action distribution.
                log_std: Log standard deviation of the distribution.
                oracle: If true (and safety checks are on), the symbolic state
                    is read from ``self.sym_map()``; otherwise the caller's
                    ``robot_position`` / ``robot_direction`` / ``obstacles``
                    are used and the speed is taken from the simulator.
                robot_position, robot_direction, obstacles: Caller-supplied
                    features, only used when ``oracle`` is false.

            Returns:
                The tuple returned by ``self.env.step(action)`` extended with
                ``(action, log_p, visual_obs)``, where ``log_p`` is the
                Gaussian log-density of the action that was actually taken.
            """
            std = np.exp(log_std)
            action = mu + np.random.normal(scale=std, size=mu.shape)

            if self.safety_checks:
                if oracle:
                    robot_position, speed, robot_direction, obstacles = self.sym_map(
                    )
                else:
                    vel = self.env.world.data.get_body_xvelp("robot")
                    # Planar speed from the x/y velocity components.
                    speed = math.sqrt((vel[0]**2 + vel[1]**2))

                self._set_state(robot_position, speed, robot_direction,
                                obstacles)

                # TODO - better to discretize or to use sampling?
                # discretization might help if the probability of safe actions is
                # very low
                thresh = 100
                n_attempts = 0
                while not self.state.is_safe_action(*action):
                    action = mu + np.random.normal(scale=std, size=mu.shape)
                    n_attempts += 1
                    if n_attempts >= thresh:
                        # Rejection sampling has failed; ask the state for a
                        # safe action, falling back if that raises.
                        try:
                            action = self.state.find_safe_action()
                        except Exception:
                            action = self.state.safe_fallback_action()
                            # Note: the trailing flag can be set to True to get
                            # more info about why the fallback is not safe.
                            if not self.state.is_safe_action(*action, False):
                                # Give up and let the unsafe action through.
                                self.n_unsafe_allowed += 1
                                break

            # Log-density of the chosen action under the diagonal Gaussian;
            # eps avoids division by zero for a vanishing std.
            eps = 1e-10
            pre_sum = -0.5 * (
                ((action - mu) /
                 (std + eps))**2 + 2 * log_std + np.log(2 * np.pi))
            log_p = pre_sum.sum()

            if self.safety_checks and not self.visual_obs:
                visual_obs = self.visual_env._render()
            else:
                visual_obs = None

            return (*self.env.step(action), action, log_p, visual_obs)

    envs = [env_fn() for _ in range(n_envs)]
    envs = [RemoteEnv.remote(env, visual_obs, safety_checks) for env in envs]

    # one extra to more easily get shapes, etc.
    env = env_fn()
    if visual_obs:
        env = VisionWrapper(env, IMG_SIZE, IMG_SIZE)
    if discretize:
        n_bins = 20
        action_space = gym.spaces.MultiDiscrete((n_bins, n_bins))
    else:
        action_space = env.action_space

    range_ = lambda *args, **kwargs: trange(*args, leave=False, **kwargs)
    exp = comet_ml.Experiment(log_env_gpu=False, log_env_cpu=False)
    exp.add_tag("crl")

    if exp:
        if "Point" in env_name:
            robot_type = "Point"
        elif "Car" in env_name:
            robot_type = "Car"
        elif "Doggo" in env_name:
            robot_type = "Doggo"
        else:
            assert False
        task = (env_name.replace("-v0",
                                 "").replace("Safexp-",
                                             "").replace(robot_type, ""))
        task, difficulty = task[:-1], task[-1]

        exp.log_parameters({
            "robot": robot_type,
            "task": task,
            "difficulty": difficulty,
            "model": "cnn0" if visual_obs else "mlp",
            "use_vision": visual_obs,
            "steps_per_epoch": steps_per_epoch,
            "vf_iters": vf_iters,
            "reduced_obstacles": True,
            "cost_lim": cost_lim,
            "oracle": oracle,
        })
        if log_params:
            exp.log_parameters(log_params)

    agent.set_logger(logger)

    # =========================================================================#
    #  Create computation graph for actor and critic (not training routine)   #
    # =========================================================================#

    # Share information about action space with policy architecture
    ac_kwargs["action_space"] = action_space

    if visual_obs:
        ac_kwargs["net_type"] = "cnn"

    # Inputs to computation graph from environment spaces
    if visual_obs:
        a_ph = placeholder_from_space(action_space)
        x_ph = tf.placeholder(dtype=tf.float32,
                              shape=(None, IMG_RESIZE, IMG_RESIZE, 3))
    else:
        x_ph, a_ph = placeholders_from_spaces(env.observation_space,
                                              action_space)

    # Inputs to computation graph for batch data
    adv_ph, cadv_ph, ret_ph, cret_ph, logp_old_ph = placeholders(
        *(None for _ in range(5)))

    # Inputs to computation graph for special purposes
    surr_cost_rescale_ph = tf.placeholder(tf.float32, shape=())
    cur_cost_ph = tf.placeholder(tf.float32, shape=())

    # Outputs from actor critic
    ac_outs = actor_critic(x_ph, a_ph, **ac_kwargs)
    pi, logp, logp_pi, pi_info, pi_info_phs, d_kl, ent, v, vc = ac_outs

    # Organize placeholders for zipping with data from buffer on updates
    buf_phs = [x_ph, a_ph, adv_ph, cadv_ph, ret_ph, cret_ph, logp_old_ph]
    buf_phs += values_as_sorted_list(pi_info_phs)

    # Organize symbols we have to compute at each step of acting in env
    get_action_ops = dict(pi=pi, v=v, logp_pi=logp_pi, pi_info=pi_info)

    # If agent is reward penalized, it doesn't use a separate value function
    # for costs and we don't need to include it in get_action_ops; otherwise we do.
    if not (agent.reward_penalized):
        get_action_ops["vc"] = vc

    # Count variables
    var_counts = tuple(count_vars(scope) for scope in ["pi", "vf", "vc"])
    logger.log("\nNumber of parameters: \t pi: %d, \t v: %d, \t vc: %d\n" %
               var_counts)

    # Make a sample estimate for entropy to use as sanity check
    approx_ent = tf.reduce_mean(-logp)

    # =========================================================================#
    #  Create replay buffer                                                   #
    # =========================================================================#

    # Obs/act shapes
    if visual_obs:
        obs_shape = (IMG_RESIZE, IMG_RESIZE, 3)
    else:
        obs_shape = env.observation_space.shape
    act_shape = action_space.shape

    # Experience buffer
    local_steps_per_epoch = int(steps_per_epoch / n_envs)
    pi_info_shapes = {k: v.shape.as_list()[1:] for k, v in pi_info_phs.items()}
    bufs = [
        CPOBuffer(
            local_steps_per_epoch,
            obs_shape,
            act_shape,
            pi_info_shapes,
            gamma,
            lam,
            cost_gamma,
            cost_lam,
        ) for _ in range(n_envs)
    ]

    # =========================================================================#
    #  Create computation graph for penalty learning, if applicable           #
    # =========================================================================#

    if agent.use_penalty:
        with tf.variable_scope("penalty"):
            # param_init = np.log(penalty_init)
            param_init = np.log(max(np.exp(penalty_init) - 1, 1e-8))
            penalty_param = tf.get_variable(
                "penalty_param",
                initializer=float(param_init),
                trainable=agent.learn_penalty,
                dtype=tf.float32,
            )
        # penalty = tf.exp(penalty_param)
        penalty = tf.nn.softplus(penalty_param)

    if agent.learn_penalty:
        if agent.penalty_param_loss:
            penalty_loss = -penalty_param * (cur_cost_ph - cost_lim)
        else:
            penalty_loss = -penalty * (cur_cost_ph - cost_lim)
        # train_penalty = MpiAdamOptimizer(learning_rate=penalty_lr).minimize(penalty_loss)
        train_penalty = tf.train.AdamOptimizer(
            learning_rate=penalty_lr).minimize(penalty_loss)

    # =========================================================================#
    #  Create computation graph for policy learning                           #
    # =========================================================================#

    # Likelihood ratio
    ratio = tf.exp(logp - logp_old_ph)

    # Surrogate advantage / clipped surrogate advantage
    if agent.clipped_adv:
        min_adv = tf.where(adv_ph > 0, (1 + agent.clip_ratio) * adv_ph,
                           (1 - agent.clip_ratio) * adv_ph)
        surr_adv = tf.reduce_mean(tf.minimum(ratio * adv_ph, min_adv))
    else:
        surr_adv = tf.reduce_mean(ratio * adv_ph)

    # Surrogate cost
    surr_cost = tf.reduce_mean(ratio * cadv_ph)

    # Create policy objective function, including entropy regularization
    pi_objective = surr_adv + ent_reg * ent

    # Possibly include surr_cost in pi_objective
    if agent.objective_penalized:
        pi_objective -= penalty * surr_cost
        pi_objective /= 1 + penalty

    # Loss function for pi is negative of pi_objective
    pi_loss = -pi_objective

    # Optimizer-specific symbols
    if agent.trust_region:

        # Symbols needed for CG solver for any trust region method
        pi_params = get_vars("pi")
        flat_g = tro.flat_grad(pi_loss, pi_params)
        v_ph, hvp = tro.hessian_vector_product(d_kl, pi_params)
        if agent.damping_coeff > 0:
            hvp += agent.damping_coeff * v_ph

        # Symbols needed for CG solver for CPO only
        flat_b = tro.flat_grad(surr_cost, pi_params)

        # Symbols for getting and setting params
        get_pi_params = tro.flat_concat(pi_params)
        set_pi_params = tro.assign_params_from_flat(v_ph, pi_params)

        training_package = dict(
            flat_g=flat_g,
            flat_b=flat_b,
            v_ph=v_ph,
            hvp=hvp,
            get_pi_params=get_pi_params,
            set_pi_params=set_pi_params,
        )

    elif agent.first_order:

        # Optimizer for first-order policy optimization
        # train_pi = MpiAdamOptimizer(learning_rate=agent.pi_lr).minimize(pi_loss)
        train_pi = tf.train.AdamOptimizer(
            learning_rate=agent.pi_lr).minimize(pi_loss)

        # Prepare training package for agent
        training_package = dict(train_pi=train_pi)

    else:
        raise NotImplementedError

    # Provide training package to agent
    training_package.update(
        dict(
            pi_loss=pi_loss,
            surr_cost=surr_cost,
            d_kl=d_kl,
            target_kl=target_kl,
            cost_lim=cost_lim,
        ))
    agent.prepare_update(training_package)

    # =========================================================================#
    #  Create computation graph for value learning                            #
    # =========================================================================#

    # Value losses
    v_loss = tf.reduce_mean((ret_ph - v)**2)
    vc_loss = tf.reduce_mean((cret_ph - vc)**2)

    # If agent uses penalty directly in reward function, don't train a separate
    # value function for predicting cost returns. (Only use one vf for r - p*c.)
    if agent.reward_penalized:
        total_value_loss = v_loss
    else:
        total_value_loss = v_loss + vc_loss

    # Optimizer for value learning
    # train_vf = MpiAdamOptimizer(learning_rate=vf_lr).minimize(total_value_loss)
    train_vf = tf.train.AdamOptimizer(
        learning_rate=vf_lr).minimize(total_value_loss)

    # =========================================================================#
    #  Create session, sync across procs, and set up saver                    #
    # =========================================================================#

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # Sync params across processes
    # sess.run(sync_all_params())

    # Setup model saving
    logger.setup_tf_saver(sess,
                          inputs={"x": x_ph},
                          outputs={
                              "pi": pi,
                              "v": v,
                              "vc": vc
                          })

    # =========================================================================#
    #  Provide session to agent                                               #
    # =========================================================================#
    agent.prepare_session(sess)

    # =========================================================================#
    #  Create function for running update (called at end of each epoch)       #
    # =========================================================================#

    def update():
        """Run one round of penalty, policy and value-function updates.

        Reads the "EpCost" and "EpLen" statistics from the logger, pulls the
        collected data out of every buffer in ``bufs`` and feeds it to the
        graph. For each data split the sequence is:

        1. evaluate the losses / entropy (and penalty, if used) and store them;
        2. run ``train_penalty`` when the agent learns its penalty;
        3. let the agent update the policy;
        4. run ``train_vf`` ``vf_iters`` times;
        5. re-evaluate the same quantities plus the KL and store the KL and
           the "Delta*" differences.

        When ``visual_obs`` is set the buffer data is processed as two
        interleaved (stride-2) halves, so the sequence above runs twice;
        otherwise it runs once on all of the data.
        NOTE(review): presumably the split keeps the feed small enough for
        image observations -- confirm.
        """
        # TODO!!! - is this the correct epcost...
        cur_cost = logger.get_stats("EpCost")[0]
        constraint_excess = cur_cost - cost_lim
        if constraint_excess > 0 and agent.cares_about_cost and verbose:
            logger.log("Warning! Safety constraint is already violated.",
                       "red")

        # =====================================================================#
        #  Prepare feed dict                                                  #
        # =====================================================================#

        feed = {
            surr_cost_rescale_ph: logger.get_stats("EpLen")[0],
            cur_cost_ph: cur_cost,
        }

        per_buffer_data = [buf.get() for buf in bufs]
        n_splits = 2 if visual_obs else 1
        for offset in range(n_splits):
            # Take every n_splits-th sample (starting at `offset`) from each
            # buffer and join the pieces across buffers, per placeholder.
            for idx, placeholder in enumerate(buf_phs):
                chunks = [
                    data[idx][offset::n_splits] for data in per_buffer_data
                ]
                feed[placeholder] = np.concatenate(chunks)

            # =================================================================#
            #  Make some measurements before updating                         #
            # =================================================================#

            fetches_before = {
                "LossPi": pi_loss,
                "SurrCost": surr_cost,
                "LossV": v_loss,
                "Entropy": ent,
            }
            if not agent.reward_penalized:
                fetches_before["LossVC"] = vc_loss
            if agent.use_penalty:
                fetches_before["Penalty"] = penalty

            before = sess.run(fetches_before, feed_dict=feed)
            logger.store(**before)

            # =================================================================#
            #  Update penalty if learning penalty                             #
            # =================================================================#
            if agent.learn_penalty:
                sess.run(train_penalty, feed_dict={cur_cost_ph: cur_cost})

            # =================================================================#
            #  Update policy                                                  #
            # =================================================================#
            agent.update_pi(feed)

            # =================================================================#
            #  Update value function                                          #
            # =================================================================#
            for _ in range(vf_iters):
                sess.run(train_vf, feed_dict=feed)

            # =================================================================#
            #  Make some measurements after updating                          #
            # =================================================================#

            # Same fetches as before, minus the entropy and plus the KL.
            fetches_after = {
                name: op
                for name, op in fetches_before.items() if name != "Entropy"
            }
            fetches_after["KL"] = d_kl

            after = sess.run(fetches_after, feed_dict=feed)
            # Change in every quantity that was measured both before and after.
            deltas = {
                "Delta" + name: after[name] - before[name]
                for name in after if name in before
            }
            logger.store(KL=after["KL"], **deltas)

    # =========================================================================#
    #  Run main environment interaction loop                                  #
    # =========================================================================#

    start_time = time.time()
    rs = np.zeros(n_envs)
    ds = [False] * n_envs
    cs = np.zeros(n_envs)
    ep_rets = np.zeros(n_envs)
    ep_costs = np.zeros(n_envs)
    ep_lens = np.zeros(n_envs)
    vc_t0 = np.zeros(n_envs)

    os = []
    visual_os = []
    for o, visual_o in ray.get([env.reset.remote() for env in envs]):
        os.append(o)
        if safety_checks and not visual_obs:
            visual_os.append(visual_o)
    os = np.stack(os)
    if safety_checks and not visual_obs:
        visual_os = np.stack(visual_os)

    cur_penalty = 0
    cum_cost = 0

    n_unsafe = 0
    n_unsafe_allowed = 0
    for epoch in range_(epochs):

        if agent.use_penalty:
            cur_penalty = sess.run(penalty)

        for t in range_(local_steps_per_epoch):

            # Possibly render
            # if render and rank == 0 and t < 1000:
            #     env.render()

            if safety_checks or sym_features:
                if visual_obs:
                    if not oracle:
                        robot_position, robot_direction, obstacles = sym_map(
                            os)
                    os = np.stack([
                        np.array(
                            Image.fromarray((o * 255).astype(np.uint8)).resize(
                                (IMG_RESIZE, IMG_RESIZE), resample=4))
                        for o in os
                    ])
                else:
                    if not oracle:
                        robot_position, robot_direction, obstacles = sym_map(
                            visual_os)

            # Get outputs from policy
            get_action_outs = sess.run(get_action_ops, feed_dict={x_ph: os})
            a = get_action_outs["pi"]
            v_t = get_action_outs["v"]
            vc_t = get_action_outs.get(
                "vc", vc_t0)  # Agent may not use cost value func
            logp_t = get_action_outs["logp_pi"]
            pi_info_t = get_action_outs["pi_info"]
            mu = pi_info_t["mu"]
            log_std = pi_info_t["log_std"]
            pi_info_t = [{
                "mu": mu[i:i + 1],
                "log_std": log_std
            } for i in range(n_envs)]

            # Step in environment

            args = []
            for i in range(n_envs):
                if safety_checks:
                    if oracle:
                        args.append((mu[i], log_std, oracle))
                    else:
                        args.append((
                            mu[i],
                            log_std,
                            oracle,
                            robot_position[i],
                            robot_direction[i],
                            obstacles[obstacles[:, 0] == i, 1:],
                        ))
                else:
                    args.append((mu[i], log_std, oracle))

            # could consider using ray.wait and handling each env separately. since we use
            # a for loop for much of the computation below anyway, this would probably
            # be faster (time before + after)
            o2s, rs, ds, infos, actions, logps, visual_os = zip(*ray.get(
                [env.step.remote(*arg) for env, arg in zip(envs, args)]))
            a[:] = actions  # new actions
            logp_t[:] = logps  # new log ps
            o2s = np.stack(o2s)
            if safety_checks and not visual_obs:
                visual_os = np.stack(visual_os)
            rs = np.array(rs)

            # Include penalty on cost
            cs = np.array([info.get("cost", 0) for info in infos])

            # Track cumulative cost over training
            n_unsafe += (cs > 0).sum()
            cum_cost += cs.sum()

            # save and log
            if agent.reward_penalized:
                r_totals = rs - cur_penalty * cs
                r_totals = r_totals / (1 + cur_penalty)
                for i, buf in enumerate(bufs):
                    buf.store(os[i], a[i], r_totals[i], v_t[i], 0, 0,
                              logp_t[i], pi_info_t[i])
            else:
                for i, buf in enumerate(bufs):
                    buf.store(
                        os[i],
                        a[i],
                        rs[i],
                        v_t[i],
                        cs[i],
                        vc_t[i],
                        logp_t[i],
                        pi_info_t[i],
                    )
            # TODO - what values to use here??
            logger.store(VVals=v_t[0], CostVVals=vc_t[0])

            os = o2s
            ep_rets += rs
            ep_costs += cs
            ep_lens += 1

            for i, buf in enumerate(bufs):
                ep_len = ep_lens[i]
                d = ds[i]
                terminal = d or (ep_len == max_ep_len)
                if terminal or (t == local_steps_per_epoch - 1):
                    # start resetting environment now; get results later
                    reset_id = envs[i].reset.remote()

                    # If trajectory didn't reach terminal state, bootstrap value target(s)
                    if d and not (ep_len == max_ep_len):
                        # Note: we do not count env time out as true terminal state
                        last_val, last_cval = 0, 0
                    else:
                        if visual_obs:
                            o = np.array(
                                Image.fromarray(
                                    (os[i] * 255).astype(np.uint8)).resize(
                                        (IMG_RESIZE, IMG_RESIZE), resample=4))
                            print(
                                "check o's dtype; make float32. Make necessary changes"
                                " after calling sym_map(os) too.")
                            breakpoint()
                        else:
                            o = os[i]
                        feed_dict = {x_ph: o[None]}
                        if agent.reward_penalized:
                            last_val = sess.run(v, feed_dict=feed_dict)
                            last_cval = 0
                        else:
                            last_val, last_cval = sess.run([v, vc],
                                                           feed_dict=feed_dict)
                    buf.finish_path(last_val, last_cval)

                    # Only save EpRet / EpLen if trajectory finished
                    if terminal:
                        ep_ret = ep_rets[i]
                        ep_cost = ep_costs[i]
                        logger.store(EpRet=ep_ret,
                                     EpLen=ep_len,
                                     EpCost=ep_cost)
                        if exp:
                            exp.log_metrics(
                                {
                                    "return": ep_ret,
                                    "episode_length": ep_len,
                                    "cost": ep_cost,
                                },
                                step=epoch * steps_per_epoch + t,
                            )
                    else:
                        if verbose:
                            print(
                                "Warning: trajectory cut off by epoch at %d steps."
                                % ep_len)

                    o, visual_o = ray.get(reset_id)
                    os[i] = o
                    if safety_checks and not visual_obs:
                        visual_os[i] = visual_o
                    rs[i] = 0
                    # ds[i] = False
                    cs[i] = 0
                    ep_rets[i] = 0
                    ep_lens[i] = 0
                    ep_costs[i] = 0

        cost_rate = cum_cost / ((epoch + 1) * steps_per_epoch)

        n_unsafe_allowed += sum(
            ray.get([env.get_n_unsafe_allowed.remote() for env in envs]))
        exp.log_metrics(
            {
                "n_unsafe_allowed": n_unsafe_allowed,
                "n_unsafe": n_unsafe,
                "cum_cost": cum_cost,
                "cost_rate": cost_rate,
            },
            step=epoch * steps_per_epoch + t,
        )

        # Save model
        if (epoch % save_freq == 0) or (epoch == epochs - 1):
            logger.save_state({"env": env}, None)

        # =====================================================================#
        #  Run RL update                                                      #
        # =====================================================================#
        update()

        # =====================================================================#
        #  Log performance and stats                                          #
        # =====================================================================#

        logger.log_tabular("Epoch", epoch)

        # Performance stats
        logger.log_tabular("EpRet", with_min_and_max=True)
        logger.log_tabular("EpCost", with_min_and_max=True)
        logger.log_tabular("EpLen", average_only=True)
        logger.log_tabular("CumulativeCost", cum_cost)
        logger.log_tabular("CostRate", cost_rate)

        # Value function values
        logger.log_tabular("VVals", with_min_and_max=True)
        logger.log_tabular("CostVVals", with_min_and_max=True)

        # Pi loss and change
        logger.log_tabular("LossPi", average_only=True)
        logger.log_tabular("DeltaLossPi", average_only=True)

        # Surr cost and change
        logger.log_tabular("SurrCost", average_only=True)
        logger.log_tabular("DeltaSurrCost", average_only=True)

        # V loss and change
        logger.log_tabular("LossV", average_only=True)
        logger.log_tabular("DeltaLossV", average_only=True)

        # Vc loss and change, if applicable (reward_penalized agents don't use vc)
        if not (agent.reward_penalized):
            logger.log_tabular("LossVC", average_only=True)
            logger.log_tabular("DeltaLossVC", average_only=True)

        if agent.use_penalty or agent.save_penalty:
            logger.log_tabular("Penalty", average_only=True)
            logger.log_tabular("DeltaPenalty", average_only=True)
        else:
            logger.log_tabular("Penalty", 0)
            logger.log_tabular("DeltaPenalty", 0)

        # Anything from the agent?
        agent.log()

        # Policy stats
        logger.log_tabular("Entropy", average_only=True)
        logger.log_tabular("KL", average_only=True)

        # Time and steps elapsed
        logger.log_tabular("TotalEnvInteracts", (epoch + 1) * steps_per_epoch)
        logger.log_tabular("Time", time.time() - start_time)

        # Show results!
        if verbose:
            logger.dump_tabular()
        else:
            logger.log_current_row.clear()
            logger.first_row = False