Example 1
    def setup_actor(self):
        logger.info("setting up actor optimizer")

        losses = OrderedDict()

        # Create the Q loss as the negative of the mean Q value over the batch
        q_loss = -tf.reduce_mean(self.critic_pred_w_actor)
        q_loss *= self.hps.q_actor_loss_scale

        # Create the actor loss w/ the scaled Q loss
        loss = q_loss

        losses.update({'actor_q_loss': q_loss})

        # Create the D loss as the negative of the mean D value over the batch
        d_loss = -tf.reduce_mean(self.d_pred_w_actor)
        d_loss *= self.hps.d_actor_loss_scale

        # Add the D loss to the actor loss
        loss += d_loss

        losses.update({'actor_d_loss': d_loss})

        # Add assembled actor loss
        losses.update({'actor_total_loss': loss})

        # Create gradients
        grads = flatgrad(loss, self.actor.trainable_vars, self.hps.clip_norm)

        # Create mpi adam optimizer
        optimizer = MpiAdamOptimizer(comm=self.comm,
                                     clip_norm=self.hps.clip_norm,
                                     learning_rate=self.hps.actor_lr,
                                     name='actor_adam')
        optimize_ = optimizer.minimize(loss=loss,
                                       var_list=self.actor.trainable_vars)

        # Create callable objects
        get_losses = TheanoFunction(inputs=[self.obs0],
                                    outputs=list(losses.values()))
        get_grads = TheanoFunction(inputs=[self.obs0], outputs=grads)
        optimize = TheanoFunction(inputs=[self.obs0], outputs=optimize_)

        # Log statistics
        log_module_info(logger, self.name, self.actor)

        # Return the actor ops
        return {
            'names': list(losses.keys()),
            'losses': get_losses,
            'grads': get_grads,
            'optimizer': optimizer,
            'optimize': optimize
        }
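
The helper `flatgrad` is used here (and in the later examples) but not defined in the excerpt. A minimal sketch of what such a helper typically does, in the style of the OpenAI baselines utility (compute gradients, optionally clip each tensor by norm, then flatten everything into a single 1-D tensor), is given below; the exact signature and behavior in this codebase are assumptions.

import tensorflow as tf

def flatgrad(loss, var_list, clip_norm=None):
    """Hypothetical sketch: gradients of `loss` w.r.t. `var_list`,
    optionally clipped per tensor, flattened into one 1-D tensor."""
    grads = tf.gradients(loss, var_list)
    if clip_norm is not None:
        grads = [tf.clip_by_norm(g, clip_norm) if g is not None else g
                 for g in grads]
    # Replace missing gradients by zeros so the flat vector keeps a fixed size
    return tf.concat(axis=0, values=[
        tf.reshape(g if g is not None else tf.zeros_like(v), [-1])
        for (v, g) in zip(var_list, grads)
    ])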
Example 2
    def setup_critic(self):
        logger.info("setting up critic optimizer")

        losses = OrderedDict()

        phs = [self.obs0, self.acs]

        if self.hps.prioritized_replay:
            phs.append(self.iws)

        # Create the 1-step look-ahead TD error loss
        td_errors_1 = self.critic_pred - self.tc1z
        hubered_td_errors_1 = huber_loss(td_errors_1)
        if self.hps.prioritized_replay:
            # Adjust with importance weights
            hubered_td_errors_1 *= self.iws
        td_loss_1 = tf.reduce_mean(hubered_td_errors_1)
        td_loss_1 *= self.hps.td_loss_1_scale

        # Create the critic loss w/ the scaled 1-step TD loss
        loss = td_loss_1

        losses.update({'critic_td_loss_1': td_loss_1})

        phs.append(self.tc1s)

        if self.hps.n_step_returns:
            # Create the n-step look-ahead TD error loss
            td_errors_n = self.critic_pred - self.tcnz
            hubered_td_errors_n = huber_loss(td_errors_n)
            if self.hps.prioritized_replay:
                # Adjust with importance weights
                hubered_td_errors_n *= self.iws
            td_loss_n = tf.reduce_mean(hubered_td_errors_n)
            td_loss_n *= self.hps.td_loss_n_scale

            # Add the scaled n-step TD loss to the critic loss
            loss += td_loss_n

            losses.update({'critic_td_loss_n': td_loss_n})

            phs.append(self.tcns)

        # Fetch critic's regularization losses (@property of the network)
        wd_loss = tf.reduce_sum(self.critic.regularization_losses)
        # Note: no need to multiply by a scale as it has already been scaled
        logger.info("setting up weight decay")
        if self.hps.wd_scale > 0:
            for var in self.critic.trainable_vars:
                if var in self.critic.decayable_vars:
                    logger.info("  {} <- wd w/ scale {}".format(
                        var.name, self.hps.wd_scale))
                else:
                    logger.info("  {}".format(var.name))

        # Add critic weight decay regularization to the critic loss
        loss += wd_loss

        losses.update({'critic_wd': wd_loss})

        # Add assembled critic loss
        losses.update({'critic_total_loss': loss})

        # Create gradients
        grads = flatgrad(loss, self.critic.trainable_vars, self.hps.clip_norm)

        # Create mpi adam optimizer
        optimizer = MpiAdamOptimizer(comm=self.comm,
                                     clip_norm=self.hps.clip_norm,
                                     learning_rate=self.hps.critic_lr,
                                     name='critic_adam')
        optimize_ = optimizer.minimize(loss=loss,
                                       var_list=self.critic.trainable_vars)

        # Create callable objects
        get_losses = TheanoFunction(inputs=phs, outputs=list(losses.values()))
        get_grads = TheanoFunction(inputs=phs, outputs=grads)
        optimize = TheanoFunction(inputs=phs, outputs=optimize_)

        if self.hps.prioritized_replay:
            td_errors_ops = [td_errors_1]
            if self.hps.n_step_returns:
                td_errors_ops.append(td_errors_n)
            get_td_errors = TheanoFunction(inputs=phs, outputs=td_errors_ops)

        # Log statistics
        log_module_info(logger, self.name, self.critic)

        # Return the critic ops
        out = {
            'names': list(losses.keys()),
            'losses': get_losses,
            'grads': get_grads,
            'optimizer': optimizer,
            'optimize': optimize
        }
        if self.hps.prioritized_replay:
            out.update({'td_errors': get_td_errors})

        return out
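
`huber_loss` is another helper referenced but not shown in this example. A minimal sketch consistent with how it is applied to the raw TD errors (quadratic near zero, linear in the tails) could look like the following; the threshold `delta=1.0` is an assumed default, not taken from the source.

import tensorflow as tf

def huber_loss(x, delta=1.0):
    """Hypothetical sketch: element-wise Huber loss, quadratic for
    |x| < delta and linear beyond, which bounds the TD-error gradients."""
    return tf.where(tf.abs(x) < delta,
                    0.5 * tf.square(x),
                    delta * (tf.abs(x) - 0.5 * delta))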
Example 3
    def _init(self, env, hps, comm):
        self.env = env
        self.ob_shape = self.env.observation_space.shape
        self.ac_space = self.env.action_space
        self.ac_shape = self.ac_space.shape
        if "NoFrameskip" in env.spec.id:
            # Expand the dimension for Atari
            self.ac_shape = (1,) + self.ac_shape
        self.hps = hps
        assert self.hps.ent_reg_scale >= 0, "'ent_reg_scale' must be non-negative"
        self.comm = comm

        # Assemble clipping functions
        unlimited_range = (-np.infty, np.infty)
        if isinstance(self.ac_space, spaces.Box):
            self.clip_obs = clip((-5., 5.))
        elif isinstance(self.ac_space, spaces.Discrete):
            self.clip_obs = clip(unlimited_range)
        else:
            raise RuntimeError("ac space is neither Box nor Discrete")

        # Define the synthetic reward network
        self.reward_nn = RewardNN(scope=self.scope, name='sr', hps=self.hps)

        # Create inputs
        self.p_obs = tf.placeholder(name='p_obs', dtype=tf.float32,
                                    shape=(None,) + self.ob_shape)
        self.p_acs = tf.placeholder(name='p_acs', dtype=tf.float32,
                                    shape=(None,) + self.ac_shape)
        self.e_obs = tf.placeholder(name='e_obs', dtype=tf.float32,
                                    shape=(None,) + self.ob_shape)
        self.e_acs = tf.placeholder(name='e_acs', dtype=tf.float32,
                                    shape=(None,) + self.ac_shape)

        # Rescale observations
        if self.hps.from_raw_pixels:
            # Scale the pixel values
            p_obz = self.p_obs / 255.0
            e_obz = self.e_obs / 255.0
        else:
            if self.hps.rmsify_obs:
                # Smooth out observations using running statistics and clip
                with tf.variable_scope("apply_obs_rms"):
                    self.obs_rms = MpiRunningMeanStd(shape=self.ob_shape)
                p_obz = self.clip_obs(rmsify(self.p_obs, self.obs_rms))
                e_obz = self.clip_obs(rmsify(self.e_obs, self.obs_rms))
            else:
                p_obz = self.p_obs
                e_obz = self.e_obs

        # Build graph
        p_scores = self.reward_nn(p_obz, self.p_acs)
        e_scores = self.reward_nn(e_obz, self.e_acs)
        scores = tf.concat([p_scores, e_scores], axis=0)
        # `scores` define the conditional distribution D(s,a) := p(label|(state,action))

        # Create entropy loss
        bernoulli_pd = BernoulliPd(logits=scores)
        ent_mean = tf.reduce_mean(bernoulli_pd.entropy())
        ent_loss = -self.hps.ent_reg_scale * ent_mean

        # Create labels
        fake_labels = tf.zeros_like(p_scores)
        real_labels = tf.ones_like(e_scores)
        if self.hps.label_smoothing:
            # Label smoothing, suggested in 'Improved Techniques for Training GANs',
            # Salimans 2016, https://arxiv.org/abs/1606.03498
            # The paper advises on the use of one-sided label smoothing (positive targets side)
            # Extra comment explanation: https://github.com/openai/improved-gan/blob/
            # 9ff96a7e9e5ac4346796985ddbb9af3239c6eed1/imagenet/build_model.py#L88-L121
            if not self.hps.one_sided_label_smoothing:
                # Fake labels (negative targets)
                soft_fake_l_b = 0.0  # standard, hyperparameterization not needed
                soft_fake_u_b = 0.3  # standard, hyperparameterization not needed
                fake_labels = tf.random_uniform(shape=tf.shape(fake_labels),
                                                name="fake_labels_smoothing",
                                                minval=soft_fake_l_b, maxval=soft_fake_u_b)
            # Real labels (positive targets)
            soft_real_l_b = 0.7  # standard, hyperparameterization not needed
            soft_real_u_b = 1.2  # standard, hyperparameterization not needed
            real_labels = tf.random_uniform(shape=tf.shape(real_labels),
                                            name="real_labels_smoothing",
                                            minval=soft_real_l_b, maxval=soft_real_u_b)

        # Build accuracies
        p_acc = tf.reduce_mean(tf.sigmoid(p_scores))
        e_acc = tf.reduce_mean(tf.sigmoid(e_scores))

        # Build binary classification (cross-entropy) losses, equal to the negative log likelihood
        # for random variables following a Bernoulli law, divided by the batch size
        p_bernoulli_pd = BernoulliPd(logits=p_scores)
        p_loss_mean = tf.reduce_mean(p_bernoulli_pd.neglogp(fake_labels))
        e_bernoulli_pd = BernoulliPd(logits=e_scores)
        e_loss_mean = tf.reduce_mean(e_bernoulli_pd.neglogp(real_labels))

        # Add a gradient penalty (motivation from WGANs (Gulrajani),
        # but empirically useful in JS-GANs (Lucic et al. 2017))

        def batch_size(x):
            """Returns an int corresponding to the batch size of the input tensor"""
            return tf.to_float(tf.shape(x)[0], name='get_batch_size_in_fl32')

        shape_obz = (tf.to_int64(batch_size(p_obz)),) + self.ob_shape
        eps_obz = tf.random_uniform(shape=shape_obz, minval=0.0, maxval=1.0)
        obz_interp = eps_obz * p_obz + (1. - eps_obz) * e_obz
        shape_acs = (tf.to_int64(batch_size(self.p_acs)),) + self.ac_shape
        eps_acs = tf.random_uniform(shape=shape_acs, minval=0.0, maxval=1.0)
        acs_interp = eps_acs * self.p_acs + (1. - eps_acs) * self.e_acs
        interp_scores = self.reward_nn(obz_interp, acs_interp)
        grads = tf.gradients(interp_scores, [obz_interp, acs_interp], name="interp_grads")
        assert len(grads) == 2, "length must be exacty 2"
        grad_squared_norms = [tf.reduce_mean(tf.square(grad)) for grad in grads]
        grad_norm = tf.sqrt(tf.reduce_sum(grad_squared_norms))
        grad_pen = tf.reduce_mean(tf.square(grad_norm - 1.0))

        losses = OrderedDict()

        # Add losses
        losses.update({'d_policy_loss': p_loss_mean,
                       'd_expert_loss': e_loss_mean,
                       'd_ent_mean': ent_mean,
                       'd_ent_loss': ent_loss,
                       'd_policy_acc': p_acc,
                       'd_expert_acc': e_acc,
                       'd_grad_pen': grad_pen})

        # Assemble discriminator loss
        loss = p_loss_mean + e_loss_mean + ent_loss + 10 * grad_pen
        # gradient penalty coefficient aligned with the value used in Gulrajani et al.

        # Add assembled discriminator loss
        losses.update({'d_total_loss': loss})

        # Compute gradients
        grads = flatgrad(loss, self.trainable_vars, self.hps.clip_norm)

        # Create mpi adam optimizer
        self.optimizer = MpiAdamOptimizer(comm=self.comm,
                                          clip_norm=self.hps.clip_norm,
                                          learning_rate=self.hps.d_lr,
                                          name='d_adam')
        optimize_ = self.optimizer.minimize(loss=loss, var_list=self.trainable_vars)

        # Create callable objects
        phs = [self.p_obs, self.p_acs, self.e_obs, self.e_acs]
        self.get_losses = TheanoFunction(inputs=phs, outputs=list(losses.values()))
        self.get_grads = TheanoFunction(inputs=phs, outputs=grads)
        self.optimize = TheanoFunction(inputs=phs, outputs=optimize_)

        # Make loss names graspable from outside
        self.names = list(losses.keys())

        # Define synthetic reward
        if self.hps.non_satur_grad:
            # Recommended in the original GAN paper and later in Fedus et al. 2017 (Many Paths...)
            # 0 for expert-like states, goes to -inf for non-expert-like states
            # compatible with envs with traj cutoffs for good (expert-like) behavior
            # e.g. mountain car, which gets cut off when the car reaches the destination
            reward = tf.log_sigmoid(p_scores)
        else:
            # 0 for non-expert-like states, goes to +inf for expert-like states
            # compatible with envs with traj cutoffs for bad (non-expert-like) behavior
            # e.g. walking simulations that get cut off when the robot falls over
            reward = -tf.log(1. - tf.sigmoid(p_scores) + 1e-8)  # HAXX: avoids log(0)

        # Create Theano-like op that computes the synthetic reward
        self.compute_reward = TheanoFunction(inputs=[self.p_obs, self.p_acs],
                                             outputs=reward)

        # Summarize module information in logs
        log_module_info(logger, self.name, self.reward_nn)
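
All of the examples rely on `TheanoFunction` to turn a list of placeholders and a fetch into a Python callable, but the wrapper itself is not reproduced. A minimal sketch matching the call sites above (which pass a feed dictionary keyed by placeholder, sometimes empty) might be as follows; the class body is an assumption, not the repository's actual implementation.

import tensorflow as tf

class TheanoFunction(object):
    """Hypothetical sketch: binds placeholders and fetches into a callable,
    Theano-style, evaluated in the default TensorFlow session."""

    def __init__(self, inputs, outputs):
        self.inputs = inputs    # placeholders expected in the feed (kept for reference)
        self.outputs = outputs  # tensor, op, or list thereof to fetch

    def __call__(self, feed_dict=None):
        # `feed_dict` maps each placeholder in `self.inputs` to a numpy array
        return tf.get_default_session().run(self.outputs,
                                            feed_dict=feed_dict or {})

For instance, `self.compute_reward({self.p_obs: obs_batch, self.p_acs: acs_batch})` would then evaluate the synthetic reward for a batch of policy transitions.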
Example 4
def learn(comm, env, xpo_agent_wrapper, sample_or_mode, gamma, max_kl,
          save_frequency, ckpt_dir, summary_dir, timesteps_per_batch,
          batch_size, experiment_name, ent_reg_scale, gae_lambda, cg_iters,
          cg_damping, vf_iters, vf_lr, max_iters):

    rank = comm.Get_rank()

    # Create policies
    pi = xpo_agent_wrapper('pi')
    old_pi = xpo_agent_wrapper('old_pi')

    # Create and retrieve already-existing placeholders
    ob = get_placeholder_cached(name='ob')
    ac = pi.pd_type.sample_placeholder([None])
    adv = tf.placeholder(name='adv', dtype=tf.float32, shape=[None])
    ret = tf.placeholder(name='ret', dtype=tf.float32, shape=[None])
    flat_tangent = tf.placeholder(name='flat_tan',
                                  dtype=tf.float32,
                                  shape=[None])

    # Build graphs
    kl_mean = tf.reduce_mean(old_pi.pd_pred.kl(pi.pd_pred))
    ent_mean = tf.reduce_mean(pi.pd_pred.entropy())
    ent_bonus = ent_reg_scale * ent_mean
    vf_err = tf.reduce_mean(tf.square(pi.v_pred - ret))  # MC error
    # The surrogate objective is defined as: advantage * pnew / pold
    ratio = tf.exp(pi.pd_pred.logp(ac) - old_pi.pd_pred.logp(ac))  # IS
    surr_gain = tf.reduce_mean(ratio * adv)  # surrogate objective (CPI)
    # Add entropy bonus
    optim_gain = surr_gain + ent_bonus

    losses = OrderedDict()

    # Add losses
    losses.update({
        'pol_kl_mean': kl_mean,
        'pol_ent_mean': ent_mean,
        'pol_ent_bonus': ent_bonus,
        'pol_surr_gain': surr_gain,
        'pol_optim_gain': optim_gain,
        'pol_vf_err': vf_err
    })

    # Build natural gradient material
    get_flat = GetFlat(pi.pol_trainable_vars)
    set_from_flat = SetFromFlat(pi.pol_trainable_vars)
    kl_grads = tf.gradients(kl_mean, pi.pol_trainable_vars)
    shapes = [var.get_shape().as_list() for var in pi.pol_trainable_vars]
    start = 0
    tangents = []
    for shape in shapes:
        sz = intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    # Create the gradient vector product
    gvp = tf.add_n([
        tf.reduce_sum(g * tangent)
        for (g, tangent) in zipsame(kl_grads, tangents)
    ])
    # Create the Fisher vector product
    fvp = flatgrad(gvp, pi.pol_trainable_vars)

    # Make the current `pi` become the next `old_pi`
    zipped = zipsame(old_pi.vars, pi.vars)
    updates_op = []
    for k, v in zipped:
        # Populate list of assignment operations
        logger.info("assignment: {} <- {}".format(k, v))
        assign_op = tf.assign(k, v)
        updates_op.append(assign_op)
    assert len(updates_op) == len(pi.vars)

    # Create mpi adam optimizer for the value function
    vf_optimizer = MpiAdamOptimizer(comm=comm,
                                    clip_norm=5.0,
                                    learning_rate=vf_lr,
                                    name='vf_adam')
    optimize_vf = vf_optimizer.minimize(loss=vf_err,
                                        var_list=pi.vf_trainable_vars)

    # Create gradients
    grads = flatgrad(optim_gain, pi.pol_trainable_vars)

    # Create callable objects
    assign_old_eq_new = TheanoFunction(inputs=[], outputs=updates_op)
    compute_losses = TheanoFunction(inputs=[ob, ac, adv, ret],
                                    outputs=list(losses.values()))
    compute_losses_grads = TheanoFunction(inputs=[ob, ac, adv, ret],
                                          outputs=list(losses.values()) +
                                          [grads])
    compute_fvp = TheanoFunction(inputs=[flat_tangent, ob, ac, adv],
                                 outputs=fvp)
    optimize_vf = TheanoFunction(inputs=[ob, ret], outputs=optimize_vf)

    # Initialise variables
    initialize()

    # Sync params of all processes with the params of the root process
    theta_init = get_flat()
    comm.Bcast(theta_init, root=0)
    set_from_flat(theta_init)

    vf_optimizer.sync_from_root(pi.vf_trainable_vars)

    # Create context manager that records the time taken by encapsulated ops
    timed = timed_cm_wrapper(comm, logger)

    if rank == 0:
        # Create summary writer
        summary_writer = tf.summary.FileWriterCache.get(summary_dir)

    # Create segment generator
    seg_gen = traj_segment_generator(env, pi, timesteps_per_batch,
                                     sample_or_mode)

    eps_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()

    # Define rolling buffers for recent stats aggregation
    maxlen = 100
    len_buffer = deque(maxlen=maxlen)
    env_ret_buffer = deque(maxlen=maxlen)
    pol_losses_buffer = deque(maxlen=maxlen)

    while iters_so_far <= max_iters:

        pretty_iter(logger, iters_so_far)
        pretty_elapsed(logger, tstart)

        # Verify that the processes are still in sync
        if iters_so_far > 0 and iters_so_far % 10 == 0:
            vf_optimizer.check_synced(pi.vf_trainable_vars)
            logger.info("vf params still in sync across processes")

        # Save the model
        if rank == 0 and iters_so_far % save_frequency == 0 and ckpt_dir is not None:
            model_path = osp.join(ckpt_dir, experiment_name)
            save_state(model_path, iters_so_far=iters_so_far)
            logger.info("saving model")
            logger.info("  @: {}".format(model_path))

        with timed("sampling mini-batch"):
            seg = seg_gen.__next__()

        augment_segment_gae_stats(seg, gamma, gae_lambda, rew_key="env_rews")

        # Standardize advantage function estimate
        seg['advs'] = (seg['advs'] - seg['advs'].mean()) / (seg['advs'].std() +
                                                            1e-8)

        # Update running mean and std
        if hasattr(pi, 'obs_rms'):
            with timed("normalizing obs via rms"):
                pi.obs_rms.update(seg['obs'], comm)

        def fisher_vector_product(p):
            computed_fvp = compute_fvp({
                flat_tangent: p,
                ob: seg['obs'],
                ac: seg['acs'],
                adv: seg['advs']
            })
            return mpi_mean_like(computed_fvp, comm) + cg_damping * p

        assign_old_eq_new({})

        # Compute gradients
        with timed("computing gradients"):
            *loss_before, g = compute_losses_grads({
                ob: seg['obs'],
                ac: seg['acs'],
                adv: seg['advs'],
                ret: seg['td_lam_rets']
            })

        loss_before = mpi_mean_like(loss_before, comm)

        g = mpi_mean_like(g, comm)

        if np.allclose(g, 0):
            logger.info("got zero gradient -> not updating")
        else:
            with timed("performing conjugate gradient procedure"):
                step_direction = conjugate_gradient(f_Ax=fisher_vector_product,
                                                    b=g,
                                                    cg_iters=cg_iters,
                                                    verbose=(rank == 0))
            assert np.isfinite(step_direction).all()
            shs = 0.5 * step_direction.dot(
                fisher_vector_product(step_direction))
            # shs is (1/2)*s^T*A*s in the paper
            lm = np.sqrt(shs / max_kl)
            # lm is 1/beta in the paper (max_kl is user-specified delta)
            full_step = step_direction / lm  # beta*s
            expected_improve = g.dot(full_step)  # project s on g
            surr_before = loss_before[4]  # 5th entry in the loss list, i.e. 'pol_optim_gain'
            step_size = 1.0
            theta_before = get_flat()

            with timed("updating policy"):
                for _ in range(10):  # try (10 times max) until the step size is OK
                    # Update the policy parameters
                    theta_new = theta_before + full_step * step_size
                    set_from_flat(theta_new)
                    pol_losses = compute_losses({
                        ob: seg['obs'],
                        ac: seg['acs'],
                        adv: seg['advs'],
                        ret: seg['td_lam_rets']
                    })

                    pol_losses_buffer.append(pol_losses)

                    pol_losses_mpi_mean = mpi_mean_like(pol_losses, comm)
                    surr = pol_losses_mpi_mean[4]
                    kl = pol_losses_mpi_mean[0]
                    actual_improve = surr - surr_before
                    logger.info("  expected: {:.3f} | actual: {:.3f}".format(
                        expected_improve, actual_improve))
                    if not np.isfinite(pol_losses_mpi_mean).all():
                        logger.info("  got non-finite value of losses :(")
                    elif kl > max_kl * 1.5:
                        logger.info(
                            "  violated KL constraint -> shrinking step.")
                    elif actual_improve < 0:
                        logger.info(
                            "  surrogate didn't improve -> shrinking step.")
                    else:
                        logger.info("  stepsize fine :)")
                        break
                    step_size *= 0.5  # backtracking when the step size is deemed inappropriate
                else:
                    logger.info("  couldn't compute a good step")
                    set_from_flat(theta_before)

        # Create Feeder object to iterate over (ob, ret) pairs
        feeder = Feeder(data_map={'obs': seg['obs'],
                                  'td_lam_rets': seg['td_lam_rets']},
                        enable_shuffle=True)

        # Update state-value function
        with timed("updating value function"):
            for _ in range(vf_iters):
                for minibatch in feeder.get_feed(batch_size=batch_size):
                    optimize_vf({
                        ob: minibatch['obs'],
                        ret: minibatch['td_lam_rets']
                    })

        # Log policy update statistics
        logger.info("logging pol training losses (log)")
        pol_losses_np_mean = np.mean(pol_losses_buffer, axis=0)
        pol_losses_mpi_mean = mpi_mean_reduce(pol_losses_buffer, comm, axis=0)
        zipped_pol_losses = zipsame(list(losses.keys()), pol_losses_np_mean,
                                    pol_losses_mpi_mean)
        logger.info(
            columnize(names=['name', 'local', 'global'],
                      tuples=zipped_pol_losses,
                      widths=[20, 16, 16]))

        # Log statistics

        logger.info("logging misc training stats (log + csv)")
        # Gather statistics across workers
        local_lens_rets = (seg['ep_lens'], seg['ep_env_rets'])
        gathered_lens_rets = comm.allgather(local_lens_rets)
        lens, env_rets = map(flatten_lists, zip(*gathered_lens_rets))
        # Extend the deques of recorded statistics
        len_buffer.extend(lens)
        env_ret_buffer.extend(env_rets)
        ep_len_mpi_mean = np.mean(len_buffer)
        ep_env_ret_mpi_mean = np.mean(env_ret_buffer)
        logger.record_tabular('ep_len_mpi_mean', ep_len_mpi_mean)
        logger.record_tabular('ep_env_ret_mpi_mean', ep_env_ret_mpi_mean)
        eps_this_iter = len(lens)
        timesteps_this_iter = sum(lens)
        eps_so_far += eps_this_iter
        timesteps_so_far += timesteps_this_iter
        eps_this_iter_mpi_mean = mpi_mean_like(eps_this_iter, comm)
        timesteps_this_iter_mpi_mean = mpi_mean_like(timesteps_this_iter, comm)
        eps_so_far_mpi_mean = mpi_mean_like(eps_so_far, comm)
        timesteps_so_far_mpi_mean = mpi_mean_like(timesteps_so_far, comm)
        logger.record_tabular('eps_this_iter_mpi_mean', eps_this_iter_mpi_mean)
        logger.record_tabular('timesteps_this_iter_mpi_mean',
                              timesteps_this_iter_mpi_mean)
        logger.record_tabular('eps_so_far_mpi_mean', eps_so_far_mpi_mean)
        logger.record_tabular('timesteps_so_far_mpi_mean',
                              timesteps_so_far_mpi_mean)
        logger.record_tabular('elapsed time',
                              prettify_time(time.time() -
                                            tstart))  # no mpi mean
        logger.record_tabular(
            'ev_td_lam_before',
            explained_variance(seg['vs'], seg['td_lam_rets']))
        iters_so_far += 1

        if rank == 0:
            logger.dump_tabular()

        if rank == 0:
            # Add summaries
            summary = tf.summary.Summary()
            tab = 'trpo'
            # Episode stats
            summary.value.add(tag="{}/{}".format(tab, 'mean_ep_len'),
                              simple_value=ep_len_mpi_mean)
            summary.value.add(tag="{}/{}".format(tab, 'mean_ep_env_ret'),
                              simple_value=ep_env_ret_mpi_mean)
            # Losses
            for name, loss in zipsame(list(losses.keys()),
                                      pol_losses_mpi_mean):
                summary.value.add(tag="{}/{}".format(tab, name),
                                  simple_value=loss)

            summary_writer.add_summary(summary, iters_so_far)
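
The TRPO update above delegates the natural-gradient solve to `conjugate_gradient`, which is not shown. A minimal sketch of the standard matrix-free conjugate-gradient routine it presumably wraps (solving A x = b given only the Fisher-vector product `f_Ax`) follows; the keyword arguments mirror the call site, but the tolerance and loop structure are assumptions.

import numpy as np

def conjugate_gradient(f_Ax, b, cg_iters=10, residual_tol=1e-10, verbose=False):
    """Hypothetical sketch: solve A x = b with conjugate gradients,
    where A is only accessible through the matrix-vector product f_Ax."""
    x = np.zeros_like(b)
    r = b.copy()  # residual b - A x (x starts at zero)
    p = b.copy()  # search direction
    rdotr = r.dot(r)
    for i in range(cg_iters):
        if verbose:
            print("cg iter {:d} | squared residual {:.6e}".format(i, rdotr))
        z = f_Ax(p)
        alpha = rdotr / p.dot(z)
        x += alpha * p
        r -= alpha * z
        new_rdotr = r.dot(r)
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
        if rdotr < residual_tol:
            break
    return x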