Beispiel #1
0
    def __init__(self, env, meta_pi, meta_oldpi, proximity_predictors,
                 num_primitives, trans_pis, trans_oldpis, config):
        """Store the networks, copy hyper-parameters and build the graph.

        Args:
            env: environment; on the chef worker ``env.unwrapped.reward_type``
                is appended to the scalar summary names.
            meta_pi, meta_oldpi: current / previous meta policy.
            proximity_predictors: predictors whose ``env_name`` is used to
                name the proximity histograms.
            num_primitives: number of primitive skills.
            trans_pis, trans_oldpis: current / previous transition policies;
                each ``trans_pis`` entry exposes ``env_name``.
            config: hyper-parameter namespace (attributes read below).
        """
        # --- networks and feature switches ---------------------------------
        self._env = env
        self._config = config
        self.meta_pi = meta_pi
        self.meta_oldpi = meta_oldpi
        self.proximity_predictors = proximity_predictors
        self._use_proximity_predictor = config.use_proximity_predictor
        self.trans_pis = trans_pis
        self.trans_oldpis = trans_oldpis
        self._num_primitives = num_primitives
        self._use_trans = config.use_trans

        # --- optimisation hyper-parameters (private attr <- config field) --
        self._cur_lrmult = 0
        for attr, key in (('_entcoeff', 'entcoeff'),
                          ('_meta_entcoeff', 'meta_entcoeff'),
                          ('_trans_entcoeff', 'trans_entcoeff'),
                          ('_optim_epochs', 'optim_epochs'),
                          ('_optim_proximity_epochs', 'proximity_optim_epochs'),
                          ('_optim_stepsize', 'optim_stepsize'),
                          ('_optim_proximity_stepsize', 'proximity_learning_rate'),
                          ('_optim_batchsize', 'optim_batchsize')):
            setattr(self, attr, getattr(config, key))

        # --- global step counter and its increment op -----------------------
        step_var = tf.Variable(0, name='global_step', dtype=tf.int64, trainable=False)
        self.global_step = step_var
        self.update_global_step = tf.assign(step_var, step_var + 1)

        # --- tensorboard summary names (rank 0 only) ------------------------
        self._is_chef = MPI.COMM_WORLD.Get_rank() == 0
        if self._is_chef:
            scalar_names = ["reward", "length"]
            scalar_names += env.unwrapped.reward_type
            histogram_names = ['reward_dist', 'primitive_dist']
            if self._use_trans:
                for trans_pi in self.trans_pis:
                    prefix = "trans_{}".format(trans_pi.env_name)
                    scalar_names.append(prefix + "/average_length")
                    scalar_names.append(prefix + "/rew")
                    if self._use_proximity_predictor:
                        scalar_names.append(prefix + "/proximity_rew")
                histogram_names.append("trans_len_histogram")
                if self._use_proximity_predictor:
                    # Four histograms per predictor, in a fixed order.
                    histogram_names.extend(
                        'proximity_predictor_{}/hist_{}'.format(predictor.env_name, suffix)
                        for predictor in self.proximity_predictors
                        for suffix in ('success_final', 'success_intermediate',
                                       'fail_final', 'fail_intermediate'))
            self.summary_name = scalar_names
            self.summary_histogram_name = histogram_names

        # --- loss / optimizer graph ------------------------------------------
        self._build()

        if not (self._is_chef and self._config.is_train):
            return
        self.ep_stats = stats(self.summary_name, self.summary_histogram_name)
        self.writer = U.file_writer(config.log_dir)
Beispiel #2
0
    def __init__(self, env, policy, old_policy, config):
        """Store the policies, copy hyper-parameters and build the loss graph.

        Args:
            env: environment; on the chef worker ``env.unwrapped.reward_type``
                extends the summary names.
            policy, old_policy: current and previous policy networks.
            config: hyper-parameter namespace; ``config.rl_method`` selects
                ``'trpo'`` or ``'ppo'``.

        Raises:
            NotImplementedError: if ``config.rl_method`` is neither.
        """
        self._env = env
        self._config = config
        self.policy, self.old_policy = policy, old_policy

        # Copy each optimisation setting to a private attribute of the same name.
        for field in ('entcoeff', 'optim_epochs', 'optim_stepsize', 'optim_batchsize'):
            setattr(self, '_' + field, getattr(config, field))

        # Global step counter plus an op that increments it by one.
        counter = tf.Variable(0,
                              name='global_step',
                              dtype=tf.int64,
                              trainable=False)
        self.global_step = counter
        self.update_global_step = tf.assign(counter, counter + 1)

        # Only rank 0 ("chef") owns summary bookkeeping.
        comm = MPI.COMM_WORLD
        self._is_chef = comm.Get_rank() == 0
        self._num_workers = comm.Get_size()
        if self._is_chef:
            names = ["reward", "length"]
            names.extend(env.unwrapped.reward_type)
            self.summary_name = names

        # Build the optimisation graph for the requested algorithm.
        method = self._config.rl_method
        if method == 'trpo':
            self._build_trpo()
        elif method == 'ppo':
            self._build_ppo()
        else:
            raise NotImplementedError

        if not (self._is_chef and self._config.is_train):
            return
        self.ep_stats = stats(self.summary_name)
        self.writer = U.file_writer(config.log_dir)
Beispiel #3
0
def run(args):
    """Build global/local trainers from ``args`` and either evaluate or train.

    Evaluation (``args.is_train`` false) requires a single MPI worker and runs
    ``evaluate`` on the global trainer and every local trainer. Training runs
    ``args.T`` epochs of ``args.R`` local-update rounds each, then updates the
    global policy (``'dnc'``) or copies the first local network
    (``copy_network``) for other methods. The environment is closed on both
    paths.

    Args:
        args: parsed command-line namespace; fields read include ``method``,
            ``env``, ``obs_norm``, ``num_contexts``, ``num_rollouts``,
            ``log_dir``, ``prefix``, ``is_train``, ``load_model_path``,
            ``record``, ``T``, ``R``, ``debug``, ``training_video_record``.
            Note: ``args.num_rollouts``, ``args.num_contexts`` and
            ``args.log_dir`` may be mutated in place.
    """
    sess = U.single_threaded_session()
    sess.__enter__()

    is_chef = (MPI.COMM_WORLD.Get_rank() == 0)
    num_workers = MPI.COMM_WORLD.Get_size()

    # TRPO uses a single context: fold all contexts' rollouts into one.
    if args.method == 'trpo':
        args.num_rollouts *= args.num_contexts
        args.num_contexts = 1

    # setting envs and networks
    env = gym.make(args.env)
    if args.obs_norm == 'predefined':
        env.unwrapped.set_norm(True)

    global_network = MlpPolicy(0, 'global', env, args)
    global_runner = Runner(env, global_network, args)
    global_trainer = GlobalTrainer('global', env, global_runner,
                                   global_network, args)

    networks = []
    old_networks = []
    trainers = []
    for i in range(args.num_contexts):
        network = MlpPolicy(i, 'local_%d' % i, env, args)
        old_network = MlpPolicy(i, 'old_local_%d' % i, env, args)
        networks.append(network)
        old_networks.append(old_network)

    # Every local trainer receives the full list of local networks.
    for i in range(args.num_contexts):
        runner = Runner(env, networks[i], args)
        trainer = LocalTrainer(i, env, runner, networks[i], old_networks[i],
                               networks, global_network, args)
        trainers.append(trainer)

    # summaries
    if is_chef:
        if args.debug:
            print_variables()

        exp_name = '{}_{}'.format(args.env, args.method)
        if args.prefix:
            exp_name = '{}_{}'.format(exp_name, args.prefix)
        args.log_dir = os.path.join(args.log_dir, exp_name)
        logger.info("Events directory: %s", args.log_dir)
        os.makedirs(args.log_dir, exist_ok=True)
        write_info(args)

        if args.is_train:
            summary_writer = tf.summary.FileWriter(args.log_dir)
            summary_name = global_trainer.summary_name.copy()
            for trainer in trainers:
                summary_name.extend(trainer.summary_name)
            ep_stats = stats(summary_name)

    # initialize model
    if args.load_model_path:
        logger.info('Load models from checkpoint...')

        def load_model(load_model_path, var_list=None):
            # Accept either a checkpoint directory (use its latest checkpoint)
            # or a direct checkpoint path.
            if os.path.isdir(load_model_path):
                ckpt_path = tf.train.latest_checkpoint(load_model_path)
            else:
                ckpt_path = load_model_path
            logger.info("Load checkpoint: %s", ckpt_path)
            U.load_state(ckpt_path, var_list)

        load_model(args.load_model_path)

    # evaluation
    if not args.is_train:
        assert num_workers == 1
        global_trainer.evaluate(ckpt_num=None, record=args.record)
        for trainer in trainers:
            trainer.evaluate(ckpt_num=None,
                             record=args.record,
                             context=trainer.id)
        # Close the environment before the early return, matching the end of
        # the training path.
        env.close()
        return

    # training
    global_step = sess.run(global_trainer.global_step)
    logger.info("Starting training at step=%d", global_step)

    pbar = tqdm.trange(global_step, args.T, total=args.T, initial=global_step)
    for epoch in pbar:
        for trainer in trainers:
            trainer.init_network()
        step = epoch * args.R

        # NOTE(review): assumes args.R >= 1; otherwise ``info`` and
        # ``rollouts`` below are unbound.
        for _ in range(args.R):
            # get rollouts
            rollouts = []
            for trainer in trainers:
                trainer.generate_rollout(
                    sess=sess,
                    context=trainer.id if args.method == 'dnc' else None)
                rollouts.append(trainer.rollout)

            # update local policies
            info = {}
            for trainer in trainers:
                _info = trainer.update(sess, rollouts, step)
                info.update(_info)
            if is_chef:
                ep_stats.add_all_summary_dict(summary_writer, info, step)

            # update ob running average
            if args.obs_norm == 'learn':
                trainers[0].update_ob_rms(rollouts)

            step += 1

        # update global policy using the last rollouts
        global_info = info
        if args.method == 'dnc':
            ob = np.concatenate([rollout['ob'] for rollout in rollouts])
            ac = np.concatenate([rollout['ac'] for rollout in rollouts])
            ret = np.concatenate([rollout['tdlamret'] for rollout in rollouts])
            info = global_trainer.update(step, ob, ac, ret)
            global_info.update(info)
        else:
            trainers[0].copy_network()

        if is_chef:
            # evaluate local policies
            for trainer in trainers:
                trainer.evaluate(
                    step,
                    record=args.training_video_record,
                    context=trainer.id if args.method == 'dnc' else None)

            # evaluate global policy
            info = global_trainer.summary(step)
            global_info.update(info)
            ep_stats.add_all_summary_dict(summary_writer, global_info, step)
            pbar.set_description(
                '[step {}] reward {} length {} success {}'.format(
                    step, global_info['global/reward'],
                    global_info['global/length'],
                    global_info['global/success']))

    env.close()
def learn(env, policy_func, checkpoint_dir, log_dir, *,
          timesteps_per_actorbatch,  # timesteps per actor per update
          clip_param, entcoeff,  # clipping parameter epsilon, entropy coeff
          optim_epochs, optim_stepsize, optim_batchsize,  # optimization hypers
          gamma, lam,  # advantage estimation
          max_timesteps=0, max_episodes=0, max_iters=0, max_seconds=0,  # time constraint
          callback=None,  # you can do anything in the callback, since it takes locals(), globals()
          adam_epsilon=1e-5,
          schedule='constant'  # annealing for stepsize parameters (epsilon and adam)
          ):
    """Train a policy built by ``policy_func`` on ``env`` with clipped PPO.

    Each iteration collects ``timesteps_per_actorbatch`` steps per worker,
    standardizes the advantages, and runs ``optim_epochs`` passes of Adam
    (``MpiAdam``) over minibatches of ``optim_batchsize`` (the whole batch
    when falsy) on ``pol_surr + pol_entpen + vf_loss``.

    Args:
        env: environment exposing ``observation_space`` / ``action_space``.
        policy_func: ``(name, ob_space, ac_space) -> policy`` factory.
        checkpoint_dir: directory receiving ``ppo1-<iter>`` checkpoints every
            50 iterations (written by rank 0 only).
        log_dir: TensorBoard directory (written by rank 0 only).
        schedule: ``'constant'`` or ``'linear'``. ``'linear'`` divides by
            ``max_timesteps``, so it requires ``max_timesteps > 0``.

    Exactly one of ``max_iters``, ``max_timesteps``, ``max_episodes`` and
    ``max_seconds`` must be positive (asserted).

    Raises:
        NotImplementedError: for an unknown ``schedule``.
    """
    rank = MPI.COMM_WORLD.Get_rank()

    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space, ac_space)  # Construct network for new policy
    oldpi = policy_func("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.placeholder(dtype=tf.float32, shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(name='lrmult', dtype=tf.float32, shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = U.mean(kloldnew)
    meanent = U.mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = U.clip(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg
    pol_surr = - U.mean(tf.minimum(surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = U.mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                             losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function([], [], updates=[
        tf.assign(oldv, newv)
        for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    # Only rank 0 ever writes summaries (see the end of the loop), so only
    # rank 0 opens an event writer on ``log_dir``.
    writer = U.file_writer(log_dir) if rank == 0 else None
    U.initialize()
    adam.sync()

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, env, timesteps_per_actorbatch, stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards
    ep_stats = stats(["Episode_rewards", "Episode_length"])

    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0, max_seconds > 0]) == 1, "Only one time constraint permitted"

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = next(seg_gen)
        add_vtarg_and_adv(seg, gamma, lam)

        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()) / atarg.std()  # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret), shuffle=not pi.recurrent)
        # A falsy batch size means "use the whole batch"; this rebinding
        # persists for later iterations.
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"): pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data.
        # NOTE: ``losses`` is reused below for numeric per-minibatch values;
        # the tensor list was already consumed by ``lossandgrad`` /
        # ``compute_losses`` above.
        for _ in range(optim_epochs):
            losses = []  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult)
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_" + name, lossval)
        logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        if rank == 0:
            logger.dump_tabular()
            ep_stats.add_all_summary(writer, [np.mean(rewbuffer),
                                              np.mean(lenbuffer)], iters_so_far)
        # Save from the chef only: every rank would otherwise write the same
        # checkpoint path concurrently. Parameters are presumably kept in sync
        # across ranks by MpiAdam (``adam.sync()`` above) — confirm.
        if rank == 0 and iters_so_far % 50 == 0:
            U.save_state('{}/ppo1-{}'.format(checkpoint_dir, iters_so_far))
        iters_so_far += 1
def learn(
        env,
        policy_func,
        checkpoint_dir,
        log_dir,
        *,
        render,
        timesteps_per_batch,  # what to train on
        max_kl,
        cg_iters,
        gamma,
        lam,  # advantage estimation
        entcoeff=0.0,
        cg_damping=1e-2,
        vf_stepsize=3e-4,
        vf_iters=3,
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,  # time constraint
        callback=None):
    """Train a policy built by ``policy_func`` on ``env`` with TRPO.

    Each iteration performs four steps:

    1. Sample ``timesteps_per_batch`` steps per worker.
    2. Compute the MPI-averaged surrogate gradient.
    3. Solve for a step direction with conjugate gradient (``cg_iters``) on a
       damped Fisher-vector product, then run a 10-step backtracking line
       search. A candidate step is rejected if the KL exceeds
       ``1.5 * max_kl`` or the surrogate does not improve.
    4. Fit the value function with ``vf_iters`` passes of MPI-averaged Adam
       over minibatches of 64.

    Policy variables are those whose name has ``"pol"`` as its third
    ``/``-separated component; value-function variables use ``"vf"``.

    Args:
        env: environment; ``env.spec.id`` scopes the networks and
            ``env.unwrapped.reward_type`` lists extra per-reward summaries.
        policy_func: ``(name, ob_space, ac_space) -> policy`` factory.
        checkpoint_dir: directory receiving ``trpo-<iter>`` checkpoints every
            100 iterations (written by rank 0 only).
        log_dir: TensorBoard directory (written by rank 0 only).
        render: forwarded to ``traj_segment_generator``.

    Exactly one of ``max_iters``, ``max_timesteps`` and ``max_episodes`` must
    be positive (asserted).
    """
    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("{}/pi".format(env.spec.id), ob_space, ac_space)
    oldpi = policy_func("{}/oldpi".format(env.spec.id), ob_space, ac_space)

    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = U.mean(kloldnew)
    meanent = U.mean(ent)
    entbonus = entcoeff * meanent

    vferr = U.mean(tf.square(pi.vpred - ret))

    ratio = tf.exp(pi.pd.logp(ac) -
                   oldpi.pd.logp(ac))  # advantage * pnew / pold
    surrgain = U.mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    # Split trainable variables into policy ("pol") and value ("vf") sets
    # by the third component of the variable scope path.
    all_var_list = pi.get_trainable_variables()
    var_list = [v for v in all_var_list if v.name.split("/")[2] == "pol"]
    vf_var_list = [v for v in all_var_list if v.name.split("/")[2] == "vf"]
    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32,
                                  shape=[None],
                                  name="flat_tan")
    # Unflatten the tangent vector into per-variable shapes.
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    gvp = tf.add_n(
        [U.sum(g * tangent) for (g, tangent) in zipsame(klgrads, tangents)])  #pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg], losses +
                                     [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret],
                                       U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        # Print wall-clock timing of the wrapped section on rank 0 only.
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(
                colorize("done in %.3f seconds" % (time.time() - tstart),
                         color='magenta'))
        else:
            yield

    def allmean(x):
        # Element-wise mean of ``x`` across all MPI workers.
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    if rank == 0:
        writer = U.file_writer(log_dir)
    U.initialize()
    # Start every worker from rank 0's policy parameters.
    th_init = get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    vfadam.sync()
    print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_batch,
                                     stochastic=True,
                                     render=render)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards
    reward_details_buffer = {}
    for name in env.unwrapped.reward_type:
        reward_details_buffer.update({name: deque(maxlen=40)})
    if rank == 0:
        ep_stats = stats(["Episode_rewards", "Episode_length"] +
                         env.unwrapped.reward_type)
    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0]) == 1

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        # Policy parameters start identical (Bcast above) and are updated
        # with MPI-averaged quantities (checked by the paramsums assertion),
        # so a single checkpoint from rank 0 suffices; saving from every
        # rank would write the same path concurrently.
        if rank == 0 and iters_so_far % 100 == 0:
            U.save_state('{}/trpo-{}'.format(checkpoint_dir, iters_so_far))
        logger.log("********** Iteration %i ************" % iters_so_far)

        with timed("sampling"):
            seg = next(seg_gen)
        add_vtarg_and_adv(seg, gamma, lam)

        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate

        if hasattr(pi, "ret_rms"): pi.ret_rms.update(tdlamret)
        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        args = seg["ob"], seg["ac"], atarg
        # Fisher-vector products use every 5th sample to save compute.
        fvpargs = [arr[::5] for arr in args]

        def fisher_vector_product(p):
            return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

        assign_old_eq_new()  # set old parameter values to new parameter values
        with timed("computegrad"):
            *lossbefore, g = compute_lossandgrad(*args)
        lossbefore = allmean(np.array(lossbefore))
        g = allmean(g)
        if np.allclose(g, 0):
            logger.log("Got zero gradient. not updating")
            # No step is taken, so log the pre-update losses. Without this,
            # ``meanlosses`` is unbound on the first iteration (NameError)
            # and stale on later ones.
            meanlosses = lossbefore
        else:
            with timed("cg"):
                stepdir = cg(fisher_vector_product,
                             g,
                             cg_iters=cg_iters,
                             verbose=rank == 0)
            assert np.isfinite(stepdir).all()
            # Scale the step so its quadratic KL estimate equals max_kl.
            shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
            lm = np.sqrt(shs / max_kl)
            fullstep = stepdir / lm
            expectedimprove = g.dot(fullstep)
            surrbefore = lossbefore[0]
            stepsize = 1.0
            thbefore = get_flat()
            # Backtracking line search: halve the step until acceptable.
            for _ in range(10):
                thnew = thbefore + fullstep * stepsize
                set_from_flat(thnew)
                meanlosses = surr, kl, *_ = allmean(
                    np.array(compute_losses(*args)))
                improve = surr - surrbefore
                logger.log("Expected: %.3f Actual: %.3f" %
                           (expectedimprove, improve))
                if not np.isfinite(meanlosses).all():
                    logger.log("Got non-finite value of losses -- bad!")
                elif kl > max_kl * 1.5:
                    logger.log("violated KL constraint. shrinking step.")
                elif improve < 0:
                    logger.log("surrogate didn't improve. shrinking step.")
                else:
                    logger.log("Stepsize OK!")
                    break
                stepsize *= .5
            else:
                logger.log("couldn't compute a good step")
                set_from_flat(thbefore)
            # Periodically verify that all workers hold identical parameters.
            if nworkers > 1 and iters_so_far % 20 == 0:
                paramsums = MPI.COMM_WORLD.allgather(
                    (thnew.sum(), vfadam.getflat().sum()))  # list of tuples
                assert all(
                    np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.record_tabular(lossname, lossval)

        with timed("vf"):

            for _ in range(vf_iters):
                for (mbob, mbret) in dataset.iterbatches(
                    (seg["ob"], seg["tdlamret"]),
                        include_final_partial_batch=False,
                        batch_size=64):
                    g = allmean(compute_vflossandgrad(mbob, mbret))
                    vfadam.update(g, vf_stepsize)

        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))

        # Gather episode lengths/returns plus per-reward-type values from all
        # workers; log_data[i + 2] corresponds to reward_type[i].
        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        for name in env.unwrapped.reward_type:
            lrlocal += ([seg[name]], )
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        log_data = map(flatten_lists, zip(*listoflrpairs))
        log_data = [i for i in log_data]
        lens, rews = log_data[0], log_data[1]

        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        for i, name in enumerate(env.unwrapped.reward_type):
            reward_details_buffer[name].extend(log_data[i + 2])

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        if rank == 0:
            logger.dump_tabular()
            to_write = [np.mean(rewbuffer), np.mean(lenbuffer)]
            for name in env.unwrapped.reward_type:
                to_write += [
                    np.mean(reward_details_buffer[name]),
                ]
            ep_stats.add_all_summary(writer, to_write, iters_so_far)
        iters_so_far += 1