Ejemplo n.º 1
0
def run_polopt_agent(env_fn,
                     agent=PPOAgent(),
                     actor_critic=mlp_actor_critic,
                     ac_kwargs=dict(),
                     seed=0,
                     render=False,
                     # Experience collection:
                     steps_per_epoch=4000,
                     epochs=50,
                     max_ep_len=1000,
                     # Discount factors:
                     gamma=0.99,
                     lam=0.97,
                     cost_gamma=0.99,
                     cost_lam=0.97,
                     # Policy learning:
                     ent_reg=0.,
                     # Cost constraints / penalties:
                     cost_lim=25,
                     penalty_init=1.,
                     penalty_lr=5e-2,
                     # KL divergence:
                     target_kl=0.01,
                     # Value learning:
                     vf_lr=1e-3,
                     vf_iters=80,
                     # Logging:
                     logger=None,
                     logger_kwargs=dict(),
                     save_freq=1
                     ):
    """Build the graph for a policy-optimization agent and collect rollouts.

    The function builds the TensorFlow graph for the actor, the reward and
    cost critics, the optional penalty variable and the policy/value training
    ops, hands the relevant symbols to ``agent``, and then steps the
    environment for ``epochs`` epochs while filling a ``CPOBuffer``.

    Args:
        env_fn: Zero-argument callable returning the environment. The
            environment must expose ``observation_space``, ``action_space``,
            ``reset()``, ``render()`` and ``step(a)`` returning
            ``(obs, reward, done, info)``; per-step cost is read from
            ``info['cost']`` (0 when the key is absent).
        agent: Agent object. This function reads its flags
            (``reward_penalized``, ``use_penalty``, ``learn_penalty``,
            ``penalty_param_loss``, ``clipped_adv``, ``clip_ratio``,
            ``objective_penalized``, ``trust_region``, ``damping_coeff``,
            ``first_order``, ``pi_lr``, ``cares_about_cost``) and calls
            ``set_logger``, ``prepare_update``, ``prepare_session`` and
            ``update_pi`` on it.
        actor_critic: Called as ``actor_critic(x_ph, a_ph, **ac_kwargs)``;
            must return the 9-tuple ``(pi, logp, logp_pi, pi_info,
            pi_info_phs, d_kl, ent, v, vc)``.
        ac_kwargs: Extra keyword arguments for ``actor_critic``. The
            environment's action space is added under ``'action_space'`` on a
            copy; the caller's dict is not modified.
        seed: Base random seed; offset by ``10000 * proc_id()``.
        render: If true, process 0 renders the first 1000 steps of each epoch.
        steps_per_epoch: Environment steps per epoch summed over all
            processes (divided by ``num_procs()`` to get the local count).
        epochs: Number of epochs of experience collection.
        max_ep_len: Episode length at which an episode is ended by timeout.
        gamma, lam: Passed to ``CPOBuffer`` for the reward signal.
        cost_gamma, cost_lam: Passed to ``CPOBuffer`` for the cost signal.
        ent_reg: Coefficient of the entropy term in the policy objective.
        cost_lim: Cost limit that the mean episode cost is compared against.
        penalty_init: Initial value of the (softplus-parameterized) penalty.
        penalty_lr: Learning rate of the penalty optimizer.
        target_kl: Forwarded to the agent in the training package.
        vf_lr: Learning rate of the value-function optimizer.
        vf_iters: Value-function gradient steps per call to ``update``.
        logger: Logger to use; an ``EpochLogger(**logger_kwargs)`` is created
            when ``None``.
        logger_kwargs: Keyword arguments for the default ``EpochLogger``.
        save_freq: Not used by the code visible in this function.
    """

    #=========================================================================#
    #  Prepare logger, seed, and environment in this process                  #
    #=========================================================================#

    logger = EpochLogger(**logger_kwargs) if logger is None else logger
    logger.save_config(locals())

    # Give every process its own random stream
    seed += 10000 * proc_id()
    tf.compat.v1.set_random_seed(seed)
    np.random.seed(seed)

    env = env_fn()

    agent.set_logger(logger)

    #=========================================================================#
    #  Create computation graph for actor and critic (not training routine)   #
    #=========================================================================#

    # Share information about action space with policy architecture.
    # Work on a copy so neither the shared default dict nor a dict owned by
    # the caller is mutated.
    ac_kwargs = dict(ac_kwargs, action_space=env.action_space)

    # Inputs to computation graph from environment spaces
    x_ph, a_ph = placeholders_from_spaces(env.observation_space, env.action_space)

    # Inputs to computation graph for batch data
    adv_ph, cadv_ph, ret_ph, cret_ph, logp_old_ph = placeholders(*(None for _ in range(5)))

    # Inputs to computation graph for special purposes
    surr_cost_rescale_ph = tf.compat.v1.placeholder(tf.float32, shape=())
    cur_cost_ph = tf.compat.v1.placeholder(tf.float32, shape=())

    # Outputs from actor critic
    ac_outs = actor_critic(x_ph, a_ph, **ac_kwargs)
    pi, logp, logp_pi, pi_info, pi_info_phs, d_kl, ent, v, vc = ac_outs

    # Organize placeholders for zipping with data from buffer on updates
    buf_phs = [x_ph, a_ph, adv_ph, cadv_ph, ret_ph, cret_ph, logp_old_ph]
    buf_phs += values_as_sorted_list(pi_info_phs)

    # Organize symbols we have to compute at each step of acting in env
    get_action_ops = dict(pi=pi,
                          v=v,
                          logp_pi=logp_pi,
                          pi_info=pi_info)

    # If agent is reward penalized, it doesn't use a separate value function
    # for costs and we don't need to include it in get_action_ops; otherwise we do.
    if not agent.reward_penalized:
        get_action_ops['vc'] = vc

    # Count variables
    var_counts = tuple(count_vars(scope) for scope in ['pi', 'vf', 'vc'])
    logger.log('\nNumber of parameters: \t pi: %d, \t v: %d, \t vc: %d\n'%var_counts)

    # Make a sample estimate for entropy to use as sanity check
    approx_ent = tf.reduce_mean(-logp)


    #=========================================================================#
    #  Create replay buffer                                                   #
    #=========================================================================#

    # Obs/act shapes
    obs_shape = env.observation_space.shape
    act_shape = env.action_space.shape

    # Experience buffer (sized for this process's share of the epoch)
    local_steps_per_epoch = int(steps_per_epoch / num_procs())
    # Drop the leading dimension of each policy-info placeholder
    pi_info_shapes = {k: ph.shape.as_list()[1:] for k, ph in pi_info_phs.items()}
    buf = CPOBuffer(local_steps_per_epoch,
                    obs_shape,
                    act_shape,
                    pi_info_shapes,
                    gamma,
                    lam,
                    cost_gamma,
                    cost_lam)


    #=========================================================================#
    #  Create computation graph for penalty learning, if applicable           #
    #=========================================================================#

    if agent.use_penalty:
        with tf.compat.v1.variable_scope('penalty'):
            # Inverse softplus, so that softplus(param_init) == penalty_init;
            # the 1e-8 floor keeps the log finite.
            param_init = np.log(max(np.exp(penalty_init)-1, 1e-8))
            penalty_param = tf.compat.v1.get_variable('penalty_param',
                                          initializer=float(param_init),
                                          trainable=agent.learn_penalty,
                                          dtype=tf.float32)
        # Softplus keeps the penalty positive
        penalty = tf.nn.softplus(penalty_param)

    if agent.learn_penalty:
        # Either the raw parameter or the softplus-transformed penalty is
        # multiplied by the constraint violation, depending on the agent.
        if agent.penalty_param_loss:
            penalty_loss = -penalty_param * (cur_cost_ph - cost_lim)
        else:
            penalty_loss = -penalty * (cur_cost_ph - cost_lim)
        train_penalty = MpiAdamOptimizer(learning_rate=penalty_lr).minimize(penalty_loss)


    #=========================================================================#
    #  Create computation graph for policy learning                           #
    #=========================================================================#

    # Likelihood ratio
    ratio = tf.exp(logp - logp_old_ph)

    # Surrogate advantage / clipped surrogate advantage
    if agent.clipped_adv:
        min_adv = tf.where(adv_ph>0,
                           (1+agent.clip_ratio)*adv_ph,
                           (1-agent.clip_ratio)*adv_ph
                           )
        surr_adv = tf.reduce_mean(tf.minimum(ratio * adv_ph, min_adv))
    else:
        surr_adv = tf.reduce_mean(ratio * adv_ph)

    # Surrogate cost
    surr_cost = tf.reduce_mean(ratio * cadv_ph)

    # Create policy objective function, including entropy regularization
    pi_objective = surr_adv + ent_reg * ent

    # Possibly include surr_cost in pi_objective
    if agent.objective_penalized:
        pi_objective -= penalty * surr_cost
        pi_objective /= (1 + penalty)

    # Loss function for pi is negative of pi_objective
    pi_loss = -pi_objective

    # Optimizer-specific symbols
    if agent.trust_region:

        # Symbols needed for CG solver for any trust region method
        pi_params = get_vars('pi')
        flat_g = tro.flat_grad(pi_loss, pi_params)
        v_ph, hvp = tro.hessian_vector_product(d_kl, pi_params)
        if agent.damping_coeff > 0:
            hvp += agent.damping_coeff * v_ph

        # Symbols needed for CG solver for CPO only
        flat_b = tro.flat_grad(surr_cost, pi_params)

        # Symbols for getting and setting params
        get_pi_params = tro.flat_concat(pi_params)
        set_pi_params = tro.assign_params_from_flat(v_ph, pi_params)

        training_package = dict(flat_g=flat_g,
                                flat_b=flat_b,
                                v_ph=v_ph,
                                hvp=hvp,
                                get_pi_params=get_pi_params,
                                set_pi_params=set_pi_params)

    elif agent.first_order:

        # Optimizer for first-order policy optimization
        train_pi = MpiAdamOptimizer(learning_rate=agent.pi_lr).minimize(pi_loss)

        # Prepare training package for agent
        training_package = dict(train_pi=train_pi)

    else:
        raise NotImplementedError

    # Provide training package to agent
    training_package.update(dict(pi_loss=pi_loss,
                                 surr_cost=surr_cost,
                                 d_kl=d_kl,
                                 target_kl=target_kl,
                                 cost_lim=cost_lim))
    agent.prepare_update(training_package)

    #=========================================================================#
    #  Create computation graph for value learning                            #
    #=========================================================================#

    # Value losses
    v_loss = tf.reduce_mean((ret_ph - v)**2)
    vc_loss = tf.reduce_mean((cret_ph - vc)**2)

    # If agent uses penalty directly in reward function, don't train a separate
    # value function for predicting cost returns. (Only use one vf for r - p*c.)
    if agent.reward_penalized:
        total_value_loss = v_loss
    else:
        total_value_loss = v_loss + vc_loss

    # Optimizer for value learning
    train_vf = MpiAdamOptimizer(learning_rate=vf_lr).minimize(total_value_loss)


    #=========================================================================#
    #  Create session, sync across procs, and set up saver                    #
    #=========================================================================#

    sess = tf.compat.v1.Session()
    sess.run(tf.compat.v1.global_variables_initializer())

    # Sync params across processes
    sess.run(sync_all_params())

    # Setup model saving
    logger.setup_tf_saver(sess, inputs={'x': x_ph}, outputs={'pi': pi, 'v': v, 'vc': vc})


    #=========================================================================#
    #  Provide session to agent                                               #
    #=========================================================================#
    agent.prepare_session(sess)


    #=========================================================================#
    #  Create function for running update (called at end of each epoch)       #
    #=========================================================================#

    def update():
        """Run one penalty / policy / value update on the buffered data."""
        cur_cost = logger.get_stats('EpCost')[0]
        c = cur_cost - cost_lim
        if c > 0 and agent.cares_about_cost:
            logger.log('Warning! Safety constraint is already violated.', 'red')

        #=====================================================================#
        #  Prepare feed dict                                                  #
        #=====================================================================#

        inputs = dict(zip(buf_phs, buf.get()))
        inputs[surr_cost_rescale_ph] = logger.get_stats('EpLen')[0]
        inputs[cur_cost_ph] = cur_cost

        #=====================================================================#
        #  Make some measurements before updating                             #
        #=====================================================================#

        measures = dict(LossPi=pi_loss,
                        SurrCost=surr_cost,
                        LossV=v_loss,
                        Entropy=ent)
        if not agent.reward_penalized:
            measures['LossVC'] = vc_loss
        if agent.use_penalty:
            measures['Penalty'] = penalty

        pre_update_measures = sess.run(measures, feed_dict=inputs)
        logger.store(**pre_update_measures)

        #=====================================================================#
        #  Update penalty if learning penalty                                 #
        #=====================================================================#
        if agent.learn_penalty:
            sess.run(train_penalty, feed_dict={cur_cost_ph: cur_cost})

        #=====================================================================#
        #  Update policy                                                      #
        #=====================================================================#
        agent.update_pi(inputs)

        #=====================================================================#
        #  Update value function                                              #
        #=====================================================================#
        for _ in range(vf_iters):
            sess.run(train_vf, feed_dict=inputs)

        #=====================================================================#
        #  Make some measurements after updating                              #
        #=====================================================================#

        # Entropy is only logged before the update; KL only makes sense after
        del measures['Entropy']
        measures['KL'] = d_kl

        post_update_measures = sess.run(measures, feed_dict=inputs)
        deltas = {'Delta'+k: post_update_measures[k] - pre_update_measures[k]
                  for k in post_update_measures
                  if k in pre_update_measures}
        logger.store(KL=post_update_measures['KL'], **deltas)


    #=========================================================================#
    #  Run main environment interaction loop                                  #
    #=========================================================================#

    start_time = time.time()
    o, r, d, c, ep_ret, ep_cost, ep_len = env.reset(), 0, False, 0, 0, 0, 0
    cur_penalty = 0
    cum_cost = 0

    for epoch in range(epochs):

        # The penalty value is held fixed for the whole epoch
        if agent.use_penalty:
            cur_penalty = sess.run(penalty)

        # Collect local_steps_per_epoch transitions in this process
        for t in range(local_steps_per_epoch):

            # Possibly render
            if render and proc_id()==0 and t < 1000:
                env.render()

            # Get outputs from policy
            get_action_outs = sess.run(get_action_ops,
                                       feed_dict={x_ph: o[np.newaxis]})
            a = get_action_outs['pi']
            v_t = get_action_outs['v']
            vc_t = get_action_outs.get('vc', 0)  # Agent may not use cost value func
            logp_t = get_action_outs['logp_pi']
            pi_info_t = get_action_outs['pi_info']

            # Step in environment
            o2, r, d, info = env.step(a)

            # Per-step cost reported by the environment (0 if not reported)
            c = info.get('cost', 0)

            # Track cumulative cost over training
            cum_cost += c

            # save and log
            if agent.reward_penalized:
                # Fold the cost into the reward; no separate cost signal stored
                r_total = r - cur_penalty * c
                r_total = r_total / (1 + cur_penalty)
                buf.store(o, a, r_total, v_t, 0, 0, logp_t, pi_info_t)
            else:
                buf.store(o, a, r, v_t, c, vc_t, logp_t, pi_info_t)
            logger.store(VVals=v_t, CostVVals=vc_t)

            o = o2
            ep_ret += r
            ep_cost += c
            ep_len += 1

            # Timeline of one epoch in this process. Each "ep" is one episode
            # of at most max_ep_len steps; the last one may be cut off when
            # the epoch ends.
            #
            #   t=0                               t=local_steps_per_epoch-1
            #   | ep |  ep    |     ep   |      ep    |    ep    |    ep   |

            # Episode ends on env termination or on hitting max_ep_len
            terminal = (d or (ep_len == max_ep_len))
            if terminal or (t==local_steps_per_epoch-1):

                # If trajectory didn't reach terminal state, bootstrap value target(s)
                if d and ep_len != max_ep_len:
                    # Note: we do not count env time out as true terminal state
                    last_val, last_cval = 0, 0
                else:
                    feed_dict={x_ph: o[np.newaxis]}
                    if agent.reward_penalized:
                        last_val = sess.run(v, feed_dict=feed_dict)
                        last_cval = 0
                    else:
                        last_val, last_cval = sess.run([v, vc], feed_dict=feed_dict)
                buf.finish_path(last_val, last_cval)

                # Only save EpRet / EpLen if trajectory finished
                if terminal:
                    logger.store(EpRet=ep_ret, EpLen=ep_len, EpCost=ep_cost)
                else:
                    print('Warning: trajectory cut off by epoch at %d steps.'%ep_len)

                # Reset environment
                o, r, d, c, ep_ret, ep_len, ep_cost = env.reset(), 0, False, 0, 0, 0, 0

        # NOTE(review): nothing in the visible code calls update(), saves the
        # model (save_freq is unused) or dumps the epoch log; the end-of-epoch
        # section appears to be missing here -- confirm against the full source.
Ejemplo n.º 2
0
def sac(env_fn,
        actor_fn=mlp_actor,
        critic_fn=mlp_critic,
        ac_kwargs=dict(),
        seed=0,
        steps_per_epoch=1000,
        epochs=100,
        replay_size=int(1e6),
        gamma=0.99,
        polyak=0.995,
        lr=1e-4,
        batch_size=1024,
        local_start_steps=int(1e3),
        max_ep_len=1000,
        logger_kwargs=dict(),
        save_freq=10,
        local_update_after=int(1e3),
        update_freq=1,
        render=False,
        fixed_entropy_bonus=None,
        entropy_constraint=-1.0,
        fixed_cost_penalty=None,
        cost_constraint=None,
        cost_lim=None,
        reward_scale=1,
        penalty_lr=5e-2,
        use_discor=False,
        cost_maxq=True):
    """

    Args:
        env_fn : A function which creates a copy of the environment.
            The environment must satisfy the OpenAI Gym API.

        actor_fn: A function which takes in placeholder symbols
            for state, ``x_ph``, and action, ``a_ph``, and returns the actor
            outputs from the agent's Tensorflow computation graph:

            ===========  ================  ======================================
            Symbol       Shape             Description
            ===========  ================  ======================================
            ``mu``       (batch, act_dim)  | Computes mean actions from policy
                                           | given states.
            ``pi``       (batch, act_dim)  | Samples actions from policy given
                                           | states.
            ``logp_pi``  (batch,)          | Gives log probability, according to
                                           | the policy, of the action sampled by
                                           | ``pi``. Critical: must be differentiable
                                           | with respect to policy parameters all
                                           | the way through action sampling.
            ===========  ================  ======================================

        critic_fn: A function which takes in placeholder symbols
            for state, ``x_ph``, action, ``a_ph``, and policy ``pi``,
            and returns the critic outputs from the agent's Tensorflow computation graph:

            ===========  ================  ======================================
            Symbol       Shape             Description
            ===========  ================  ======================================
            ``critic``    (batch,)         | Gives one estimate of Q* for
                                           | states in ``x_ph`` and actions in
                                           | ``a_ph``.
            ``critic_pi`` (batch,)         | Gives another estimate of Q* for
                                           | states in ``x_ph`` and actions in
                                           | ``a_ph``.
            ===========  ================  ======================================

        ac_kwargs (dict): Any kwargs appropriate for the actor_fn / critic_fn
            function you provided to SAC.

        seed (int): Seed for random number generators.

        steps_per_epoch (int): Number of steps of interaction (state-action pairs)
            for the agent and the environment in each epoch.

        epochs (int): Number of epochs to run and train agent.

        replay_size (int): Maximum length of replay buffer.

        gamma (float): Discount factor. (Always between 0 and 1.)

        polyak (float): Interpolation factor in polyak averaging for target
            networks. Target networks are updated towards main networks
            according to:

            .. math:: \\theta_{\\text{targ}} \\leftarrow
                \\rho \\theta_{\\text{targ}} + (1-\\rho) \\theta

            where :math:`\\rho` is polyak. (Always between 0 and 1, usually
            close to 1.)

        lr (float): Learning rate (used for both policy and value learning).

        batch_size (int): Minibatch size for SGD.

        local_start_steps (int): Number of steps for uniform-random action selection,
            before running real policy. Helps exploration.

        max_ep_len (int): Maximum length of trajectory / episode / rollout.

        logger_kwargs (dict): Keyword args for EpochLogger.

        save_freq (int): How often (in terms of gap between epochs) to save
            the current policy and value function.

        fixed_entropy_bonus (float or None): Fixed bonus to reward for entropy.
            Units are (points of discounted sum of future reward) / (nats of policy entropy).
            If None, use ``entropy_constraint`` to set bonus value instead.

        entropy_constraint (float): If ``fixed_entropy_bonus`` is None,
            Adjust entropy bonus to maintain at least this much entropy.
            Actual constraint value is multiplied by the dimensions of the action space.
            Units are (nats of policy entropy) / (action dimension).

        fixed_cost_penalty (float or None): Fixed penalty to reward for cost.
            Units are (points of discounted sum of future reward) / (points of discounted sum of future costs).
            If None, use ``cost_constraint`` to set penalty value instead.

        cost_constraint (float or None): If ``fixed_cost_penalty`` is None,
            Adjust cost penalty to maintain at most this much cost.
            Units are (points of discounted sum of future costs).
            Note: to get an approximate cost_constraint from a cost_lim (undiscounted sum of costs),
            multiply cost_lim by (1 - gamma ** episode_len) / (1 - gamma).
            If None, use cost_lim to calculate constraint.

        cost_lim (float or None): If ``cost_constraint`` is None,
            calculate an approximate constraint cost from this cost limit.
            Units are (expectation of undiscounted sum of costs in a single episode).
            If None, cost_lim is not used, and if no cost constraints are used, do naive optimization.
    """
    use_costs = fixed_cost_penalty or cost_constraint or cost_lim

    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    # Env instantiation
    env, test_env = env_fn(), env_fn()
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]

    # Setting seeds
    tf.set_random_seed(seed)
    np.random.seed(seed)
    env.seed(seed)
    test_env.seed(seed)

    # Action limit for clamping: critically, assumes all dimensions share the same bound!
    act_limit = env.action_space.high[0]

    # Share information about action space with policy architecture
    ac_kwargs['action_space'] = env.action_space

    # Inputs to computation graph
    x_ph, a_ph, x2_ph, r_ph, d_ph, c_ph = placeholders(obs_dim, act_dim,
                                                       obs_dim, None, None,
                                                       None)

    # Main outputs from computation graph
    with tf.variable_scope('main'):
        mu, pi, logp_pi = actor_fn(x_ph, a_ph, **ac_kwargs)
        qr1, qr1_pi = critic_fn(x_ph, a_ph, pi, name='qr1', **ac_kwargs)
        qr2, qr2_pi = critic_fn(x_ph, a_ph, pi, name='qr2', **ac_kwargs)
        qc1, qc1_pi = critic_fn(x_ph, a_ph, pi, name='qc1', **ac_kwargs)
        if cost_maxq:
            qc2, qc2_pi = critic_fn(x_ph, a_ph, pi, name='qc2', **ac_kwargs)
        if use_discor:
            er1, er1_targ = critic_fn(x_ph, a_ph, pi, name='er1', **ac_kwargs)
            er2, er2_targ = critic_fn(x_ph, a_ph, pi, name='er2', **ac_kwargs)
            ec1, ec1_targ = critic_fn(x_ph, a_ph, pi, name='ec1', **ac_kwargs)
            if cost_maxq:
                ec2, ec2_targ = critic_fn(x_ph,
                                          a_ph,
                                          pi,
                                          name='ec2',
                                          **ac_kwargs)

    with tf.variable_scope('main', reuse=True):
        # Additional policy output from a different observation placeholder
        # This lets us do separate optimization updates (actor, critics, etc)
        # in a single tensorflow op.
        _, pi2, logp_pi2 = actor_fn(x2_ph, a_ph, **ac_kwargs)

    # Target value network
    with tf.variable_scope('target'):
        _, qr1_pi_targ = critic_fn(x2_ph, a_ph, pi2, name='qr1', **ac_kwargs)
        _, qr2_pi_targ = critic_fn(x2_ph, a_ph, pi2, name='qr2', **ac_kwargs)
        _, qc1_pi_targ = critic_fn(x2_ph, a_ph, pi2, name='qc1', **ac_kwargs)
        if cost_maxq:
            _, qc2_pi_targ = critic_fn(x2_ph,
                                       a_ph,
                                       pi2,
                                       name='qc2',
                                       **ac_kwargs)
        if use_discor:
            _, er1_pi_targ = critic_fn(x_ph, a_ph, pi, name='er1', **ac_kwargs)
            _, er2_pi_targ = critic_fn(x_ph, a_ph, pi, name='er2', **ac_kwargs)
            _, ec1_pi_targ = critic_fn(x_ph, a_ph, pi, name='ec1', **ac_kwargs)
            if cost_maxq:
                _, ec2_pi_targ = critic_fn(x_ph,
                                           a_ph,
                                           pi,
                                           name='ec2',
                                           **ac_kwargs)

    # Entropy bonus
    if fixed_entropy_bonus is None:
        with tf.variable_scope('entreg'):
            soft_alpha = tf.get_variable('soft_alpha',
                                         initializer=0.0,
                                         trainable=True,
                                         dtype=tf.float32)
        alpha = tf.nn.softplus(soft_alpha)
    else:
        alpha = tf.constant(fixed_entropy_bonus)
    log_alpha = tf.log(alpha)

    # Cost penalty
    if use_costs:
        if fixed_cost_penalty is None:
            with tf.variable_scope('costpen'):
                soft_beta = tf.get_variable('soft_beta',
                                            initializer=0.0,
                                            trainable=True,
                                            dtype=tf.float32)
            beta = tf.nn.softplus(soft_beta)
            log_beta = tf.log(beta)
        else:
            beta = tf.constant(fixed_cost_penalty)
            log_beta = tf.log(beta)
    else:
        beta = 0.0  # costs do not contribute to policy optimization
        print('Not using costs')

    if use_discor:
        with tf.variable_scope('discor'):
            tr1 = tf.get_variable('tr1',
                                  initializer=10.0,
                                  trainable=False,
                                  dtype=tf.float32)
            tr2 = tf.get_variable('tr2',
                                  initializer=10.0,
                                  trainable=False,
                                  dtype=tf.float32)
            tc1 = tf.get_variable('tc1',
                                  initializer=10.0,
                                  trainable=False,
                                  dtype=tf.float32)
            if cost_maxq:
                tc2 = tf.get_variable('tc2',
                                      initializer=10.0,
                                      trainable=False,
                                      dtype=tf.float32)

    # Experience buffer
    replay_buffer = ReplayBuffer(obs_dim=obs_dim,
                                 act_dim=act_dim,
                                 size=replay_size)

    # Count variables
    if proc_id() == 0:
        var_counts = tuple(
            count_vars(scope) for scope in
            ['main/pi', 'main/qr1', 'main/qr2', 'main/qc1', 'main'])
        print((
            '\nNumber of parameters: \t pi: %d, \t qr1: %d, \t qr2: %d, \t qc1: %d, \t total: %d\n'
        ) % var_counts)

    # Min Double-Q:
    min_q_pi = tf.minimum(qr1_pi, qr2_pi)
    min_q_pi_targ = tf.minimum(qr1_pi_targ, qr2_pi_targ)

    if cost_maxq:
        max_qc_pi = tf.maximum(qc1_pi, qc2_pi)
        max_qc_pi_targ = tf.maximum(qc1_pi_targ, qc2_pi_targ)
    else:
        max_qc_pi = qc1_pi
        max_qc_pi_targ = qc1_pi_targ

    # Targets for Q and V regression
    q_backup = tf.stop_gradient(r_ph + gamma * (1 - d_ph) *
                                (min_q_pi_targ - alpha * logp_pi2))
    qc_backup = tf.stop_gradient(c_ph + gamma * (1 - d_ph) * max_qc_pi_targ)

    if use_discor:
        er1_backup = tf.stop_gradient(
            tf.abs(qr1 - q_backup) + gamma * (1 - d_ph) * er1_pi_targ)
        er2_backup = tf.stop_gradient(
            tf.abs(qr2 - q_backup) + gamma * (1 - d_ph) * er2_pi_targ)
        ec1_backup = tf.stop_gradient(
            tf.abs(qc1 - qc_backup) + gamma * (1 - d_ph) * ec1_pi_targ)
        if cost_maxq:
            ec2_backup = tf.stop_gradient(
                tf.abs(qc2 - qc_backup) + gamma * (1 - d_ph) * ec2_pi_targ)

        qr1_loss = 0.5 * tf.reduce_sum(
            tf.nn.softmax(er1_backup / tr1, axis=0) * (q_backup - qr1)**2)
        qr2_loss = 0.5 * tf.reduce_sum(
            tf.nn.softmax(er2_backup / tr2, axis=0) * (q_backup - qr2)**2)
        qc1_loss = 0.5 * tf.reduce_sum(
            tf.nn.softmax(ec1_backup / tc1, axis=0) * (qc_backup - qc1)**2)
        if cost_maxq:
            qc2_loss = 0.5 * tf.reduce_sum(
                tf.nn.softmax(ec2_backup / tc2, axis=0) * (qc_backup - qc2)**2)
    else:
        qr1_loss = 0.5 * tf.reduce_mean((q_backup - qr1)**2)
        qr2_loss = 0.5 * tf.reduce_mean((q_backup - qr2)**2)
        qc1_loss = 0.5 * tf.reduce_mean((qc_backup - qc1)**2)
        if cost_maxq:
            qc2_loss = 0.5 * tf.reduce_mean((qc_backup - qc2)**2)
    # Soft actor-critic losses
    q_loss = qr1_loss + qr2_loss + qc1_loss
    if cost_maxq:
        q_loss += qc2_loss
    pi_loss = tf.reduce_mean(alpha * logp_pi - min_q_pi +
                             beta * max_qc_pi) / (1 + beta)

    if use_discor:
        er1_loss = 0.5 * tf.reduce_mean((er1_backup - er1)**2)
        er2_loss = 0.5 * tf.reduce_mean((er2_backup - er2)**2)
        ec1_loss = 0.5 * tf.reduce_mean((ec1_backup - ec1)**2)
        error_loss = er1_loss + er2_loss + ec1_loss
        if cost_maxq:
            ec2_loss = 0.5 * tf.reduce_mean((ec2_backup - ec2)**2)
            error_loss += +ec2_loss
            ec2_mean = tf.reduce_mean(ec2)
        er1_mean = tf.reduce_mean(er1)
        er2_mean = tf.reduce_mean(er2)
        ec1_mean = tf.reduce_mean(ec1)

    # Loss for alpha
    entropy_constraint *= act_dim
    pi_entropy = -tf.reduce_mean(logp_pi)
    # alpha_loss = - soft_alpha * (entropy_constraint - pi_entropy)
    alpha_loss = -alpha * (entropy_constraint - pi_entropy)
    print('using entropy constraint', entropy_constraint)

    # Loss for beta
    if use_costs and not fixed_cost_penalty:
        if cost_constraint is None:
            # Convert assuming equal cost accumulated each step
            # Note this isn't the case, since the early in episode doesn't usually have cost,
            # but since our algorithm optimizes the discounted infinite horizon from each entry
            # in the replay buffer, we should be approximately correct here.
            # It's worth checking empirical total undiscounted costs to see if they match.
            cost_constraint = cost_lim * (1 - gamma**max_ep_len) / (
                1 - gamma) / max_ep_len
        print('using cost constraint', cost_constraint)
        beta_loss = beta * (cost_constraint - qc1)
        #TODO: What is the correct target here?
    # Policy train op
    # (has to be separate from value train op, because qr1_pi appears in pi_loss)
    train_pi_op = MpiAdamOptimizer(learning_rate=lr).minimize(
        pi_loss, var_list=get_vars('main/pi'), name='train_pi')

    # Value train op
    with tf.control_dependencies([train_pi_op]):
        train_q_op = MpiAdamOptimizer(learning_rate=lr).minimize(
            q_loss, var_list=get_vars('main/q'), name='train_q')
    with tf.control_dependencies([train_q_op]):
        if use_discor:
            train_e_op = MpiAdamOptimizer(learning_rate=lr).minimize(
                error_loss, var_list=get_vars('main/e'), name='train_e')
            with tf.control_dependencies([train_e_op]):
                if cost_maxq:
                    train_e_out_op = tf.group([
                        tf.assign(tr1, (1 - polyak) * er1_mean + polyak * tr1),
                        tf.assign(tr2, (1 - polyak) * er2_mean + polyak * tr2),
                        tf.assign(tc1, (1 - polyak) * ec1_mean + polyak * tc1),
                        tf.assign(tc2, (1 - polyak) * ec2_mean + polyak * tc2)
                    ])
                else:
                    train_e_out_op = tf.group([
                        tf.assign(tr1, (1 - polyak) * er1_mean + polyak * tr1),
                        tf.assign(tr2, (1 - polyak) * er2_mean + polyak * tr2),
                        tf.assign(tc1, (1 - polyak) * ec1_mean + polyak * tc1)
                    ])
        else:
            train_e_out_op = tf.no_op()
    if fixed_entropy_bonus is None:
        entreg_optimizer = MpiAdamOptimizer(learning_rate=lr)
        with tf.control_dependencies([train_e_out_op]):
            train_entreg_op = entreg_optimizer.minimize(
                alpha_loss, var_list=get_vars('entreg'))
    if use_costs and fixed_cost_penalty is None:
        costpen_optimizer = MpiAdamOptimizer(learning_rate=penalty_lr)
        with tf.control_dependencies([train_entreg_op]):
            train_costpen_op = costpen_optimizer.minimize(
                beta_loss, var_list=get_vars('costpen'))

    # Polyak averaging for target variables
    target_update = get_target_update('main', 'target', polyak)

    # Single monolithic update with explicit control dependencies
    with tf.control_dependencies([train_pi_op]):
        with tf.control_dependencies([train_q_op]):
            if use_discor:
                with tf.control_dependencies([train_e_op]):
                    with tf.control_dependencies([train_e_out_op]):
                        grouped_update = tf.group([target_update])
            else:
                grouped_update = tf.group([target_update])

    if fixed_entropy_bonus is None:
        grouped_update = tf.group([grouped_update, train_entreg_op])
    if use_costs and fixed_cost_penalty is None:
        grouped_update_a = tf.group([grouped_update, train_costpen_op])

    # Initializing targets to match main variables
    # As a shortcut, use our exponential moving average update w/ coefficient zero
    target_init = get_target_update('main', 'target', 0.0)

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    sess.run(target_init)

    # Sync params across processes
    sess.run(sync_all_params())

    # Setup model saving
    logger.setup_tf_saver(sess,
                          inputs={
                              'x': x_ph,
                              'a': a_ph
                          },
                          outputs={
                              'mu': mu,
                              'pi': pi,
                              'qr1': qr1,
                              'qr2': qr2,
                              'qc1': qc1
                          })

    def get_action(o, deterministic=False):
        """Evaluate the policy on a single observation and return its action.

        Args:
            o: Observation for one step. It is reshaped to a single-row batch
                (``o.reshape(1, -1)``) before being fed to ``x_ph``.
            deterministic: If true, the ``mu`` op is evaluated; otherwise the
                ``pi`` op is evaluated.

        Returns:
            The first (and only) row of the evaluated op's output.
        """
        if deterministic:
            policy_op = mu
        else:
            policy_op = pi
        # The graph expects a batch dimension, so wrap the observation in one.
        batched_obs = o.reshape(1, -1)
        batched_out = sess.run(policy_op, feed_dict={x_ph: batched_obs})
        return batched_out[0]

    def test_agent(n=10):
        """Run ``n`` evaluation episodes on ``test_env`` and log their statistics.

        Each episode starts from ``test_env.reset()`` and continues until the
        environment reports done or the episode length reaches ``max_ep_len``.
        Actions come from ``get_action(obs, True)``. The first episode is
        rendered when ``render`` is set and ``proc_id()`` is 0.

        Args:
            n: Number of evaluation episodes to run.

        Side effects:
            Stores ``TestEpRet``, ``TestEpCost``, ``TestEpLen`` and
            ``TestEpGoals`` in ``logger`` once per episode.
        """
        for episode_idx in range(n):
            obs = test_env.reset()
            done = False
            total_ret, total_cost, n_steps, n_goals = 0, 0, 0, 0

            while not done and n_steps != max_ep_len:
                # Take deterministic actions at test time
                action = get_action(obs, True)
                obs, reward, done, info = test_env.step(action)

                if render and proc_id() == 0 and episode_idx == 0:
                    test_env.render()

                total_ret += reward
                total_cost += info.get('cost', 0)
                n_steps += 1
                # Count every step on which the env flags the goal as met.
                if info.get('goal_met', False):
                    n_goals += 1

            logger.store(
                TestEpRet=total_ret,
                TestEpCost=total_cost,
                TestEpLen=n_steps,
                TestEpGoals=n_goals,
            )

    start_time = time.time()
    o, r, d, ep_ret, ep_cost, ep_len, ep_goals = env.reset(
    ), 0, False, 0, 0, 0, 0
    total_steps = steps_per_epoch * epochs

    # variables to measure in an update
    vars_to_get = dict(LossPi=pi_loss,
                       LossQR1=qr1_loss,
                       LossQR2=qr2_loss,
                       LossQC1=qc1_loss,
                       QR1Vals=qr1,
                       QR2Vals=qr2,
                       QC1Vals=qc1,
                       LogPi=logp_pi,
                       PiEntropy=pi_entropy,
                       Alpha=alpha,
                       LogAlpha=log_alpha,
                       LossAlpha=alpha_loss)
    if use_costs and not fixed_cost_penalty:
        vars_to_get.update(
            dict(Beta=beta, LogBeta=log_beta, LossBeta=beta_loss))
    if use_discor:
        vars_to_get.update(dict(TR1=tr1))

    print('starting training', proc_id())

    # Main loop: collect experience in env and update/log each epoch
    local_steps = 0
    local_steps_per_epoch = steps_per_epoch // num_procs()
    local_batch_size = batch_size // num_procs()
    epoch_start_time = time.time()
    for t in range(total_steps // num_procs()):
        """
        Until local_start_steps have elapsed, randomly sample actions
        from a uniform distribution for better exploration. Afterwards,
        use the learned policy.
        """
        if t > local_start_steps:
            a = get_action(o)
        else:
            a = env.action_space.sample()

        # Step the env
        o2, r, d, info = env.step(a)
        r *= reward_scale  # yee-haw
        c = info.get('cost', 0)
        ep_ret += r
        ep_cost += c
        ep_len += 1
        ep_goals += 1 if info.get('goal_met', False) else 0
        local_steps += 1

        # Ignore the "done" signal if it comes from hitting the time
        # horizon (that is, when it's an artificial terminal signal
        # that isn't based on the agent's state)
        d = False if ep_len == max_ep_len else d

        # Store experience to replay buffer
        replay_buffer.store(o, a, r, o2, d, c)

        # Super critical, easy to overlook step: make sure to update
        # most recent observation!
        o = o2

        if d or (ep_len == max_ep_len):
            logger.store(EpRet=ep_ret,
                         EpCost=ep_cost,
                         EpLen=ep_len,
                         EpGoals=ep_goals)
            o, r, d, ep_ret, ep_cost, ep_len, ep_goals = env.reset(
            ), 0, False, 0, 0, 0, 0

        if t > 0 and t % update_freq == 0:
            for j in range(update_freq):
                batch = replay_buffer.sample_batch(local_batch_size)
                feed_dict = {
                    x_ph: batch['obs1'],
                    x2_ph: batch['obs2'],
                    a_ph: batch['acts'],
                    r_ph: batch['rews'],
                    c_ph: batch['costs'],
                    d_ph: batch['done'],
                }
                if t < local_update_after:
                    logger.store(**sess.run(vars_to_get, feed_dict))
                else:
                    if (not j == update_freq -
                            1) or not (use_costs and not fixed_cost_penalty):
                        values, _ = sess.run([vars_to_get, grouped_update],
                                             feed_dict)
                        logger.store(**values)
                    else:
                        values, _ = sess.run([vars_to_get, grouped_update_a],
                                             feed_dict)
                        logger.store(**values)

        # End of epoch wrap-up
        if t > 0 and t % local_steps_per_epoch == 0:
            epoch = t // local_steps_per_epoch

            # Save model
            if (epoch % save_freq == 0) or (epoch == epochs - 1):
                logger.save_state({'env': env}, None)

            # Test the performance of the deterministic version of the agent.
            test_start_time = time.time()
            test_agent()
            logger.store(TestTime=time.time() - test_start_time)

            logger.store(EpochTime=time.time() - epoch_start_time)
            epoch_start_time = time.time()

            # Log info about epoch
            logger.log_tabular('Epoch', epoch)
            logger.log_tabular('EpRet', with_min_and_max=True)
            logger.log_tabular('TestEpRet', with_min_and_max=True)
            logger.log_tabular('EpCost', with_min_and_max=True)
            logger.log_tabular('TestEpCost', with_min_and_max=True)
            logger.log_tabular('EpLen', average_only=True)
            logger.log_tabular('TestEpLen', average_only=True)
            logger.log_tabular('EpGoals', average_only=True)
            logger.log_tabular('TestEpGoals', average_only=True)
            logger.log_tabular('TotalEnvInteracts', mpi_sum(local_steps))
            logger.log_tabular('QR1Vals', with_min_and_max=True)
            logger.log_tabular('QR2Vals', with_min_and_max=True)
            logger.log_tabular('QC1Vals', with_min_and_max=True)
            logger.log_tabular('LogPi', with_min_and_max=True)
            logger.log_tabular('LossPi', average_only=True)
            logger.log_tabular('LossQR1', average_only=True)
            logger.log_tabular('LossQR2', average_only=True)
            logger.log_tabular('LossQC1', average_only=True)
            logger.log_tabular('LossAlpha', average_only=True)
            logger.log_tabular('LogAlpha', average_only=True)
            logger.log_tabular('Alpha', average_only=True)
            if use_costs and not fixed_cost_penalty:
                logger.log_tabular('LossBeta', average_only=True)
                logger.log_tabular('LogBeta', average_only=True)
                logger.log_tabular('Beta', average_only=True)
            if use_discor:
                logger.log_tabular('TR1', average_only=True)
            logger.log_tabular('PiEntropy', average_only=True)
            logger.log_tabular('TestTime', average_only=True)
            logger.log_tabular('EpochTime', average_only=True)
            logger.log_tabular('TotalTime', time.time() - start_time)
            logger.dump_tabular()