コード例 #1
0
    def learn(self,
              total_timesteps,
              callback=None,
              seed=None,
              log_interval=100,
              tb_log_name="TRPO",
              reset_num_timesteps=True):
        """
        Run the training loop.

        Each iteration samples trajectory segments, updates the policy with a
        conjugate-gradient step followed by a backtracking line search, fits
        the value function, and (when ``self.using_gail`` is set) updates the
        discriminator held by ``self.reward_giver``.

        :param total_timesteps: timestep budget; when truthy, the loop stops
            once the timesteps accumulated during this call reach it
        :param callback: optional callable invoked at the start of every
            iteration as ``callback(locals(), globals())``; training stops only
            when it returns exactly ``False``
        :param seed: forwarded to ``self._setup_learn``
        :param log_interval: not used by this method
        :param tb_log_name: run name forwarded to ``TensorboardWriter``
        :param reset_num_timesteps: forwarded to ``self._init_num_timesteps``
        :return: ``self``
        """

        new_tb_log = self._init_num_timesteps(reset_num_timesteps)

        with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name, new_tb_log) \
                as writer:
            self._setup_learn(seed)

            with self.sess.as_default():
                seg_gen = traj_segment_generator(
                    self.policy_pi,
                    self.env,
                    self.timesteps_per_batch,
                    reward_giver=self.reward_giver,
                    gail=self.using_gail)

                # NOTE: local variable names in this method are visible to user
                # callbacks through ``locals()``, so they are kept stable.
                episodes_so_far = 0
                timesteps_so_far = 0
                iters_so_far = 0
                t_start = time.time()
                len_buffer = deque(
                    maxlen=40)  # rolling buffer for episode lengths
                reward_buffer = deque(
                    maxlen=40)  # rolling buffer for episode rewards
                self.episode_reward = np.zeros((self.n_envs, ))

                true_reward_buffer = None
                if self.using_gail:
                    true_reward_buffer = deque(maxlen=40)

                    # Initialize dataloader
                    batchsize = self.timesteps_per_batch // self.d_step
                    self.expert_dataset.init_dataloader(batchsize)

                    #  Stats not used for now
                    # TODO: replace with normal tb logging
                    #  g_loss_stats = Stats(loss_names)
                    #  d_loss_stats = Stats(reward_giver.loss_name)
                    #  ep_stats = Stats(["True_rewards", "Rewards", "Episode_length"])

                while True:
                    if callback is not None:
                        # Only stop training if return value is False, not when it is None. This is for backwards
                        # compatibility with callbacks that have no return statement.
                        if callback(locals(), globals()) is False:
                            break
                    if total_timesteps and timesteps_so_far >= total_timesteps:
                        break

                    logger.log("********** Iteration %i ************" %
                               iters_so_far)

                    # Closure over ``fvpargs``, which is (re)assigned inside the
                    # G-step loop below before this function is first called.
                    def fisher_vector_product(vec):
                        return self.allmean(
                            self.compute_fvp(
                                vec, *fvpargs,
                                sess=self.sess)) + self.cg_damping * vec

                    # ------------------ Update G ------------------
                    logger.log("Optimizing Policy...")
                    # g_step = 1 when not using GAIL
                    mean_losses = None
                    vpredbefore = None
                    tdlamret = None
                    observation = None
                    action = None
                    seg = None
                    for k in range(self.g_step):
                        with self.timed("sampling"):
                            seg = next(seg_gen)
                        add_vtarg_and_adv(seg, self.gamma, self.lam)
                        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
                        observation, action, atarg, tdlamret = seg["ob"], seg[
                            "ac"], seg["adv"], seg["tdlamret"]
                        vpredbefore = seg[
                            "vpred"]  # predicted value function before update
                        atarg = (atarg - atarg.mean()) / atarg.std(
                        )  # standardized advantage function estimate

                        # true_rew is the reward without discount
                        if writer is not None:
                            self.episode_reward = total_episode_reward_logger(
                                self.episode_reward, seg["true_rew"].reshape(
                                    (self.n_envs, -1)), seg["dones"].reshape(
                                        (self.n_envs, -1)), writer,
                                self.num_timesteps)

                        # The observations are passed twice.
                        # NOTE(review): presumably one copy feeds the current
                        # policy and one the old policy -- confirm.
                        args = seg["ob"], seg["ob"], seg["ac"], atarg
                        # Fisher-vector products use every 5th sample only.
                        fvpargs = [arr[::5] for arr in args]

                        self.assign_old_eq_new(sess=self.sess)

                        with self.timed("computegrad"):
                            steps = self.num_timesteps + (k + 1) * (
                                seg["total_timestep"] / self.g_step)
                            run_options = tf.RunOptions(
                                trace_level=tf.RunOptions.FULL_TRACE)
                            run_metadata = tf.RunMetadata(
                            ) if self.full_tensorboard_log else None
                            # run loss backprop with summary, and save the metadata (memory, compute time, ...)
                            if writer is not None:
                                summary, grad, *lossbefore = self.compute_lossandgrad(
                                    *args,
                                    tdlamret,
                                    sess=self.sess,
                                    options=run_options,
                                    run_metadata=run_metadata)
                                if self.full_tensorboard_log:
                                    writer.add_run_metadata(
                                        run_metadata, 'step%d' % steps)
                                writer.add_summary(summary, steps)
                            else:
                                _, grad, *lossbefore = self.compute_lossandgrad(
                                    *args,
                                    tdlamret,
                                    sess=self.sess,
                                    options=run_options,
                                    run_metadata=run_metadata)

                        lossbefore = self.allmean(np.array(lossbefore))
                        grad = self.allmean(grad)
                        if np.allclose(grad, 0):
                            logger.log("Got zero gradient. not updating")
                        else:
                            with self.timed("conjugate_gradient"):
                                stepdir = conjugate_gradient(
                                    fisher_vector_product,
                                    grad,
                                    cg_iters=self.cg_iters,
                                    verbose=self.rank == 0
                                    and self.verbose >= 1)
                            assert np.isfinite(stepdir).all()
                            shs = .5 * stepdir.dot(
                                fisher_vector_product(stepdir))
                            # abs(shs) to avoid taking square root of negative values
                            lagrange_multiplier = np.sqrt(
                                abs(shs) / self.max_kl)
                            # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
                            fullstep = stepdir / lagrange_multiplier
                            expectedimprove = grad.dot(fullstep)
                            surrbefore = lossbefore[0]
                            stepsize = 1.0
                            thbefore = self.get_flat()
                            thnew = None
                            # Backtracking line search: halve the step until
                            # losses are finite, KL stays within 1.5 * max_kl
                            # and the surrogate does not get worse; give up
                            # after 10 attempts and restore the old parameters.
                            for _ in range(10):
                                thnew = thbefore + fullstep * stepsize
                                self.set_from_flat(thnew)
                                mean_losses = surr, kl_loss, *_ = self.allmean(
                                    np.array(
                                        self.compute_losses(*args,
                                                            sess=self.sess)))
                                improve = surr - surrbefore
                                logger.log("Expected: %.3f Actual: %.3f" %
                                           (expectedimprove, improve))
                                if not np.isfinite(mean_losses).all():
                                    logger.log(
                                        "Got non-finite value of losses -- bad!"
                                    )
                                elif kl_loss > self.max_kl * 1.5:
                                    logger.log(
                                        "violated KL constraint. shrinking step."
                                    )
                                elif improve < 0:
                                    logger.log(
                                        "surrogate didn't improve. shrinking step."
                                    )
                                else:
                                    logger.log("Stepsize OK!")
                                    break
                                stepsize *= .5
                            else:
                                logger.log("couldn't compute a good step")
                                self.set_from_flat(thbefore)
                            # Every 20 iterations, check that all workers hold
                            # the same parameter sums.
                            if self.nworkers > 1 and iters_so_far % 20 == 0:
                                # list of tuples
                                paramsums = MPI.COMM_WORLD.allgather(
                                    (thnew.sum(), self.vfadam.getflat().sum()))
                                assert all(
                                    np.allclose(ps, paramsums[0])
                                    for ps in paramsums[1:])

                        with self.timed("vf"):
                            for _ in range(self.vf_iters):
                                # NOTE: for recurrent policies, use shuffle=False?
                                for (mbob, mbret) in dataset.iterbatches(
                                    (seg["ob"], seg["tdlamret"]),
                                        include_final_partial_batch=False,
                                        batch_size=128,
                                        shuffle=True):
                                    grad = self.allmean(
                                        self.compute_vflossandgrad(
                                            mbob, mbob, mbret, sess=self.sess))
                                    self.vfadam.update(grad, self.vf_stepsize)

                    # mean_losses is still None when every G step took the
                    # zero-gradient branch (the line search never ran).
                    # Iterating None would raise TypeError, so skip the loss
                    # rows for this iteration instead of crashing.
                    if mean_losses is not None:
                        for (loss_name, loss_val) in zip(self.loss_names,
                                                         mean_losses):
                            logger.record_tabular(loss_name, loss_val)

                    logger.record_tabular(
                        "explained_variance_tdlam_before",
                        explained_variance(vpredbefore, tdlamret))

                    if self.using_gail:
                        # ------------------ Update D ------------------
                        logger.log("Optimizing Discriminator...")
                        logger.log(fmt_row(13, self.reward_giver.loss_name))
                        assert len(observation) == self.timesteps_per_batch
                        batch_size = self.timesteps_per_batch // self.d_step

                        # NOTE: uses only the last g step for observation
                        d_losses = [
                        ]  # list of tuples, each of which gives the loss for a minibatch
                        # NOTE: for recurrent policies, use shuffle=False?
                        for ob_batch, ac_batch in dataset.iterbatches(
                            (observation, action),
                                include_final_partial_batch=False,
                                batch_size=batch_size,
                                shuffle=True):
                            ob_expert, ac_expert = self.expert_dataset.get_next_batch(
                            )
                            # update running mean/std for reward_giver
                            if self.reward_giver.normalize:
                                self.reward_giver.obs_rms.update(
                                    np.concatenate((ob_batch, ob_expert), 0))

                            # Reshape actions if needed when using discrete actions
                            if isinstance(self.action_space,
                                          gym.spaces.Discrete):
                                if len(ac_batch.shape) == 2:
                                    ac_batch = ac_batch[:, 0]
                                if len(ac_expert.shape) == 2:
                                    ac_expert = ac_expert[:, 0]
                            *newlosses, grad = self.reward_giver.lossandgrad(
                                ob_batch, ac_batch, ob_expert, ac_expert)
                            self.d_adam.update(self.allmean(grad),
                                               self.d_stepsize)
                            d_losses.append(newlosses)
                        logger.log(fmt_row(13, np.mean(d_losses, axis=0)))

                        # lr: lengths and rewards
                        lr_local = (seg["ep_lens"], seg["ep_rets"],
                                    seg["ep_true_rets"])  # local values
                        list_lr_pairs = MPI.COMM_WORLD.allgather(
                            lr_local)  # list of tuples
                        lens, rews, true_rets = map(flatten_lists,
                                                    zip(*list_lr_pairs))
                        true_reward_buffer.extend(true_rets)
                    else:
                        # lr: lengths and rewards
                        lr_local = (seg["ep_lens"], seg["ep_rets"]
                                    )  # local values
                        list_lr_pairs = MPI.COMM_WORLD.allgather(
                            lr_local)  # list of tuples
                        lens, rews = map(flatten_lists, zip(*list_lr_pairs))
                    len_buffer.extend(lens)
                    reward_buffer.extend(rews)

                    if len(len_buffer) > 0:
                        logger.record_tabular("EpLenMean", np.mean(len_buffer))
                        logger.record_tabular("EpRewMean",
                                              np.mean(reward_buffer))
                    # Same guard as above: np.mean of an empty buffer yields
                    # NaN and emits a RuntimeWarning, so only log once at least
                    # one episode has finished.
                    if self.using_gail and len(true_reward_buffer) > 0:
                        logger.record_tabular("EpTrueRewMean",
                                              np.mean(true_reward_buffer))
                    logger.record_tabular("EpThisIter", len(lens))
                    episodes_so_far += len(lens)
                    current_it_timesteps = MPI.COMM_WORLD.allreduce(
                        seg["total_timestep"])
                    timesteps_so_far += current_it_timesteps
                    self.num_timesteps += current_it_timesteps
                    iters_so_far += 1

                    logger.record_tabular("EpisodesSoFar", episodes_so_far)
                    logger.record_tabular("TimestepsSoFar", self.num_timesteps)
                    logger.record_tabular("TimeElapsed", time.time() - t_start)

                    if self.verbose >= 1 and self.rank == 0:
                        logger.dump_tabular()

        return self
コード例 #2
0
    def learn(self,
              total_timesteps,
              callback=None,
              log_interval=100,
              tb_log_name="PPO1",
              reset_num_timesteps=True):
        """
        Run the training loop.

        Each iteration pulls one trajectory segment from the generator, runs
        ``self.optim_epochs`` passes of minibatch updates through
        ``self.adam``, re-evaluates the losses, and logs episode statistics
        gathered from all MPI workers.

        :param total_timesteps: the loop stops once the timesteps accumulated
            during this call reach this value; also the denominator of the
            ``'linear'`` learning-rate schedule
        :param callback: passed through ``self._init_callback``; the resulting
            object receives ``on_training_start`` / ``on_training_end`` and is
            handed to the trajectory generator
        :param log_interval: not used by this method
        :param tb_log_name: run name forwarded to ``TensorboardWriter``
        :param reset_num_timesteps: forwarded to ``self._init_num_timesteps``
        :return: ``self``
        """

        new_tb_log = self._init_num_timesteps(reset_num_timesteps)
        callback = self._init_callback(callback)

        with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name, new_tb_log) \
                as writer:
            self._setup_learn()

            assert issubclass(self.policy, ActorCriticPolicy), "Error: the input policy for the PPO1 model must be " \
                                                               "an instance of common.policies.ActorCriticPolicy."

            with self.sess.as_default():
                self.adam.sync()
                # The callback receives this frame's locals, so local variable
                # names in this method are observable from outside.
                callback.on_training_start(locals(), globals())

                # Prepare for rollouts
                seg_gen = traj_segment_generator(self.policy_pi,
                                                 self.env,
                                                 self.timesteps_per_actorbatch,
                                                 callback=callback)

                episodes_so_far = 0
                timesteps_so_far = 0
                iters_so_far = 0
                t_start = time.time()

                # rolling buffer for episode lengths
                len_buffer = deque(maxlen=100)
                # rolling buffer for episode rewards
                reward_buffer = deque(maxlen=100)

                while True:
                    if timesteps_so_far >= total_timesteps:
                        break

                    # Learning-rate multiplier: fixed at 1.0 for 'constant';
                    # for 'linear' it decays with the fraction of the budget
                    # already consumed and is floored at 0.
                    if self.schedule == 'constant':
                        cur_lrmult = 1.0
                    elif self.schedule == 'linear':
                        cur_lrmult = max(
                            1.0 - float(timesteps_so_far) / total_timesteps, 0)
                    else:
                        raise NotImplementedError

                    logger.log("********** Iteration %i ************" %
                               iters_so_far)

                    seg = seg_gen.__next__()

                    # Stop training early (triggered by the callback)
                    if not seg.get('continue_training', True):  # pytype: disable=attribute-error
                        break

                    # Adds the "adv" and "tdlamret" entries read just below.
                    add_vtarg_and_adv(seg, self.gamma, self.lam)

                    # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
                    observations, actions = seg["observations"], seg["actions"]
                    atarg, tdlamret = seg["adv"], seg["tdlamret"]

                    # true_rew is the reward without discount
                    # NOTE(review): the return value is discarded here --
                    # presumably the logger updates self.episode_reward in
                    # place; confirm against total_episode_reward_logger.
                    if writer is not None:
                        total_episode_reward_logger(
                            self.episode_reward, seg["true_rewards"].reshape(
                                (self.n_envs, -1)), seg["dones"].reshape(
                                    (self.n_envs, -1)), writer,
                            self.num_timesteps)

                    # predicted value function before update
                    vpredbefore = seg["vpred"]

                    # standardized advantage function estimate
                    atarg = (atarg - atarg.mean()) / atarg.std()
                    # Shuffling is disabled for recurrent policies.
                    dataset = Dataset(dict(ob=observations,
                                           ac=actions,
                                           atarg=atarg,
                                           vtarg=tdlamret),
                                      shuffle=not self.policy.recurrent)
                    # Fall back to one full-batch update when no minibatch
                    # size is configured.
                    optim_batchsize = self.optim_batchsize or observations.shape[
                        0]

                    # set old parameter values to new parameter values
                    self.assign_old_eq_new(sess=self.sess)
                    logger.log("Optimizing...")
                    logger.log(fmt_row(13, self.loss_names))

                    # Here we do a bunch of optimization epochs over the data
                    for k in range(self.optim_epochs):
                        # list of tuples, each of which gives the loss for a minibatch
                        losses = []
                        for i, batch in enumerate(
                                dataset.iterate_once(optim_batchsize)):
                            # Step index used for the tensorboard summaries.
                            # NOTE(review): len(dataset.data_map) looks like the
                            # number of dataset keys rather than the number of
                            # samples -- confirm this is intended.
                            steps = (
                                self.num_timesteps + k * optim_batchsize +
                                int(i *
                                    (optim_batchsize / len(dataset.data_map))))
                            # batch["ob"] is passed twice in every call below.
                            # NOTE(review): presumably it feeds both the current
                            # and the old policy inputs -- confirm.
                            if writer is not None:
                                # run loss backprop with summary, but once every 10 runs save the metadata
                                # (memory, compute time, ...)
                                if self.full_tensorboard_log and (1 +
                                                                  k) % 10 == 0:
                                    run_options = tf.compat.v1.RunOptions(
                                        trace_level=tf.compat.v1.RunOptions.
                                        FULL_TRACE)
                                    run_metadata = tf.compat.v1.RunMetadata()
                                    summary, grad, *newlosses = self.lossandgrad(
                                        batch["ob"],
                                        batch["ob"],
                                        batch["ac"],
                                        batch["atarg"],
                                        batch["vtarg"],
                                        cur_lrmult,
                                        sess=self.sess,
                                        options=run_options,
                                        run_metadata=run_metadata)
                                    writer.add_run_metadata(
                                        run_metadata, 'step%d' % steps)
                                else:
                                    summary, grad, *newlosses = self.lossandgrad(
                                        batch["ob"],
                                        batch["ob"],
                                        batch["ac"],
                                        batch["atarg"],
                                        batch["vtarg"],
                                        cur_lrmult,
                                        sess=self.sess)
                                writer.add_summary(summary, steps)
                            else:
                                # No tensorboard writer: the summary output is
                                # discarded.
                                _, grad, *newlosses = self.lossandgrad(
                                    batch["ob"],
                                    batch["ob"],
                                    batch["ac"],
                                    batch["atarg"],
                                    batch["vtarg"],
                                    cur_lrmult,
                                    sess=self.sess)

                            self.adam.update(grad,
                                             self.optim_stepsize * cur_lrmult)
                            losses.append(newlosses)
                        logger.log(fmt_row(13, np.mean(losses, axis=0)))

                    # Re-evaluate the losses on the same data after the
                    # updates, without computing gradients.
                    logger.log("Evaluating losses...")
                    losses = []
                    for batch in dataset.iterate_once(optim_batchsize):
                        newlosses = self.compute_losses(batch["ob"],
                                                        batch["ob"],
                                                        batch["ac"],
                                                        batch["atarg"],
                                                        batch["vtarg"],
                                                        cur_lrmult,
                                                        sess=self.sess)
                        losses.append(newlosses)
                    # Only the first value returned by mpi_moments is used.
                    mean_losses, _, _ = mpi_moments(losses, axis=0)
                    logger.log(fmt_row(13, mean_losses))
                    for (loss_val, name) in zipsame(mean_losses,
                                                    self.loss_names):
                        logger.record_tabular("loss_" + name, loss_val)
                    logger.record_tabular(
                        "ev_tdlam_before",
                        explained_variance(vpredbefore, tdlamret))

                    # local values
                    lrlocal = (seg["ep_lens"], seg["ep_rets"])

                    # list of tuples
                    listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)
                    lens, rews = map(flatten_lists, zip(*listoflrpairs))
                    len_buffer.extend(lens)
                    reward_buffer.extend(rews)
                    # Means are only logged once at least one episode has been
                    # recorded.
                    if len(len_buffer) > 0:
                        logger.record_tabular("EpLenMean", np.mean(len_buffer))
                        logger.record_tabular("EpRewMean",
                                              np.mean(reward_buffer))
                    logger.record_tabular("EpThisIter", len(lens))
                    episodes_so_far += len(lens)
                    # Combine this iteration's timestep count across MPI
                    # workers before advancing the counters.
                    current_it_timesteps = MPI.COMM_WORLD.allreduce(
                        seg["total_timestep"])
                    timesteps_so_far += current_it_timesteps
                    self.num_timesteps += current_it_timesteps
                    iters_so_far += 1
                    logger.record_tabular("EpisodesSoFar", episodes_so_far)
                    logger.record_tabular("TimestepsSoFar", self.num_timesteps)
                    logger.record_tabular("TimeElapsed", time.time() - t_start)
                    # Only the rank-0 worker dumps the table.
                    if self.verbose >= 1 and MPI.COMM_WORLD.Get_rank() == 0:
                        logger.dump_tabular()
        # Runs after the verbosity / tensorboard context managers have exited.
        callback.on_training_end()
        return self
コード例 #3
0
ファイル: pposgd_simple.py プロジェクト: lingjunz/pong_src
    def learn(self,
              total_timesteps,
              callback=None,
              seed=None,
              log_interval=100,
              tb_log_name="PPO1",
              reset_num_timesteps=True):

        new_tb_log = self._init_num_timesteps(reset_num_timesteps)

        with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name, new_tb_log) \
                as writer:
            self._setup_learn(seed)

            assert issubclass(self.policy, ActorCriticPolicy), "Error: the input policy for the PPO1 model must be " \
                                                               "an instance of common.policies.ActorCriticPolicy({}).".format(self.policy)

            with self.sess.as_default():
                self.adam.sync()
                trajectory_dic = None
                # Prepare for rollouts
                seg_gen = traj_segment_generator(self.policy_pi, self.env,
                                                 self.timesteps_per_actorbatch)

                episodes_so_far = 0
                timesteps_so_far = 0
                iters_so_far = 0
                t_start = time.time()

                # rolling buffer for episode lengths
                lenbuffer = deque(maxlen=100)
                # rolling buffer for episode rewards
                rewbuffer = deque(maxlen=100)

                self.episode_reward = np.zeros((self.n_envs, ))
                if self.save_trajectory:
                    hidden_list = []
                    obs_list = []
                    act_list = []
                    rwds_list = []
                    dones_list = []

                while True:
                    if callback is not None:
                        # Only stop training if return value is False, not when it is None. This is for backwards
                        # compatibility with callbacks that have no return statement.
                        if callback(locals(), globals()) is False:
                            break
                    if total_timesteps and timesteps_so_far >= total_timesteps:
                        break

                    if self.schedule == 'constant':
                        cur_lrmult = 1.0
                    elif self.schedule == 'linear':
                        cur_lrmult = max(
                            1.0 - float(timesteps_so_far) / total_timesteps, 0)
                    else:
                        raise NotImplementedError

                    # logger.log("********** Iteration %i ************" % iters_so_far)
                    logger.log("********** Iteration %i %i************" %
                               (iters_so_far, self.n_envs))

                    seg = seg_gen.__next__()

                    add_vtarg_and_adv(seg, self.gamma, self.lam)

                    # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
                    obs_ph, hiddens_ph, action_ph, atarg, tdlamret = seg[
                        "ob"], seg["hiddens"], seg["ac"], seg["adv"], seg[
                            "tdlamret"]
                    # print(">>>hiddens_ph:",len(hiddens_ph))
                    if self.save_trajectory:
                        rwds_ph, dones_ph = seg["rew"], seg["dones"]
                        obs_list.append(obs_ph.copy())
                        hidden_list.append(hiddens_ph.copy())
                        act_list.append(action_ph.copy())
                        rwds_list.append(rwds_ph.copy())
                        dones_list.append(dones_ph.copy())

                    # true_rew is the reward without discount
                    if writer is not None:
                        self.episode_reward = total_episode_reward_logger(
                            self.episode_reward, seg["true_rew"].reshape(
                                (self.n_envs, -1)), seg["dones"].reshape(
                                    (self.n_envs, -1)), writer,
                            self.num_timesteps)

                    # predicted value function before udpate
                    vpredbefore = seg["vpred"]

                    # standardized advantage function estimate
                    atarg = (atarg - atarg.mean()) / atarg.std()
                    dataset = Dataset(dict(ob=obs_ph,
                                           ac=action_ph,
                                           atarg=atarg,
                                           vtarg=tdlamret),
                                      shuffle=not self.policy.recurrent)
                    optim_batchsize = self.optim_batchsize or obs_ph.shape[0]

                    # set old parameter values to new parameter values
                    self.assign_old_eq_new(sess=self.sess)
                    logger.log("Optimizing...")
                    logger.log(fmt_row(13, self.loss_names))

                    # Here we do a bunch of optimization epochs over the data
                    for k in range(self.optim_epochs):
                        # list of tuples, each of which gives the loss for a minibatch
                        losses = []
                        for i, batch in enumerate(
                                dataset.iterate_once(optim_batchsize)):
                            steps = (
                                self.num_timesteps + k * optim_batchsize +
                                int(i *
                                    (optim_batchsize / len(dataset.data_map))))
                            if writer is not None:
                                # run loss backprop with summary, but once every 10 runs save the metadata
                                # (memory, compute time, ...)
                                if self.full_tensorboard_log and (1 +
                                                                  k) % 10 == 0:
                                    run_options = tf.RunOptions(
                                        trace_level=tf.RunOptions.FULL_TRACE)
                                    run_metadata = tf.RunMetadata()

                                    summary, grad, *newlosses = self.lossandgrad(
                                        batch["ob"],
                                        batch["ob"],
                                        batch["ac"],
                                        batch["atarg"],
                                        batch["vtarg"],
                                        cur_lrmult,
                                        sess=self.sess,
                                        options=run_options,
                                        run_metadata=run_metadata)
                                    writer.add_run_metadata(
                                        run_metadata, 'step%d' % steps)
                                else:
                                    summary, grad, *newlosses = self.lossandgrad(
                                        batch["ob"],
                                        batch["ob"],
                                        batch["ac"],
                                        batch["atarg"],
                                        batch["vtarg"],
                                        cur_lrmult,
                                        sess=self.sess)
                                writer.add_summary(summary, steps)
                            else:
                                _, grad, *newlosses = self.lossandgrad(
                                    batch["ob"],
                                    batch["ob"],
                                    batch["ac"],
                                    batch["atarg"],
                                    batch["vtarg"],
                                    cur_lrmult,
                                    sess=self.sess)

                            self.adam.update(grad,
                                             self.optim_stepsize * cur_lrmult)
                            losses.append(newlosses)
                        logger.log(fmt_row(13, np.mean(losses, axis=0)))

                    logger.log("Evaluating losses...")
                    losses = []
                    for batch in dataset.iterate_once(optim_batchsize):
                        newlosses = self.compute_losses(batch["ob"],
                                                        batch["ob"],
                                                        batch["ac"],
                                                        batch["atarg"],
                                                        batch["vtarg"],
                                                        cur_lrmult,
                                                        sess=self.sess)
                        losses.append(newlosses)
                    mean_losses, _, _ = mpi_moments(losses, axis=0)
                    logger.log(fmt_row(13, mean_losses))
                    for (loss_val, name) in zipsame(mean_losses,
                                                    self.loss_names):
                        logger.record_tabular("loss_" + name, loss_val)
                    logger.record_tabular(
                        "ev_tdlam_before",
                        explained_variance(vpredbefore, tdlamret))

                    # local values
                    lrlocal = (seg["ep_lens"], seg["ep_rets"])

                    # list of tuples
                    listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)
                    lens, rews = map(flatten_lists, zip(*listoflrpairs))
                    lenbuffer.extend(lens)
                    rewbuffer.extend(rews)
                    if len(lenbuffer) > 0:
                        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
                        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
                    logger.record_tabular("EpThisIter", len(lens))
                    episodes_so_far += len(lens)
                    current_it_timesteps = MPI.COMM_WORLD.allreduce(
                        seg["total_timestep"])
                    timesteps_so_far += current_it_timesteps
                    self.num_timesteps += current_it_timesteps
                    iters_so_far += 1
                    logger.record_tabular("EpisodesSoFar", episodes_so_far)
                    logger.record_tabular("TimestepsSoFar", self.num_timesteps)
                    logger.record_tabular("TimeElapsed", time.time() - t_start)
                    if self.verbose >= 1 and MPI.COMM_WORLD.Get_rank() == 0:
                        logger.dump_tabular()

                if self.save_trajectory:

                    length = np.vstack(obs_list).shape[0]
                    print("Save trajectory...(length:{})".format(length))
                    trajectory_dic = {
                        "all_obvs": np.vstack(obs_list).reshape(length, -1),
                        "all_hiddens":
                        np.vstack(hidden_list).reshape(length, -1),
                        "all_acts": np.vstack(act_list).reshape(length, -1),
                        "all_rwds": np.vstack(rwds_list).reshape(length, -1),
                        "all_dones": np.vstack(dones_list).reshape(length, -1)
                    }

                    # with open('../saved/{}-trajectory.pkl'.format(str(self.__class__).split("'")[-2].split(".")[-1]), 'wb+') as f:
                    #         pkl.dump(trajectory_dic, f, protocol=2)

        return self, trajectory_dic
コード例 #4
0
ファイル: trpo_mpi.py プロジェクト: safrooze/stable-baselines
    def learn(self,
              total_timesteps,
              callback=None,
              seed=None,
              log_interval=100):
        """
        Train the policy with a KL-constrained step (conjugate gradient
        followed by a backtracking line search), fit the value function, and,
        when ``self.using_gail`` is set, update the discriminator held by
        ``self.reward_giver``.

        :param total_timesteps: timestep budget; the loop stops once
            ``timesteps_so_far`` reaches it. A falsy value (``0``/``None``)
            disables this check, so only the callback can end training.
        :param callback: optional callable invoked at the start of every
            iteration as ``callback(locals(), globals())``. Returning ``False``
            stops training; any other return value (including ``None``)
            continues. Because ``locals()`` is handed out, the local variable
            names in this method are visible to callbacks.
        :param seed: forwarded to ``self._setup_learn``.
        :param log_interval: not used by this implementation.
        :return: ``self``
        """
        with SetVerbosity(self.verbose):
            self._setup_learn(seed)

            with self.sess.as_default():
                seg_gen = traj_segment_generator(
                    self.policy_pi,
                    self.env,
                    self.timesteps_per_batch,
                    reward_giver=self.reward_giver,
                    gail=self.using_gail)

                episodes_so_far = 0
                timesteps_so_far = 0
                iters_so_far = 0
                t_start = time.time()
                lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
                rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards

                true_rewbuffer = None
                if self.using_gail:
                    true_rewbuffer = deque(maxlen=40)
                    #  Stats not used for now
                    #  g_loss_stats = Stats(loss_names)
                    #  d_loss_stats = Stats(reward_giver.loss_name)
                    #  ep_stats = Stats(["True_rewards", "Rewards", "Episode_length"])

                    # Load pretrained policy weights when a path was provided
                    if self.pretrained_weight is not None:
                        tf_util.load_state(
                            self.pretrained_weight,
                            var_list=tf_util.get_globals_vars("pi"),
                            sess=self.sess)

                while True:
                    if callback:
                        # Only stop training if the return value is False, not when it is None, so
                        # callbacks without a return statement keep working.
                        if callback(locals(), globals()) is False:
                            break
                    if total_timesteps and timesteps_so_far >= total_timesteps:
                        break

                    logger.log("********** Iteration %i ************" % iters_so_far)

                    # `fvpargs` is looked up when the function is called, i.e. after it has been
                    # assigned inside the G-step loop below.
                    def fisher_vector_product(vec):
                        return self.allmean(
                            self.compute_fvp(vec, *fvpargs, sess=self.sess)) + self.cg_damping * vec

                    # ------------------ Update G ------------------
                    logger.log("Optimizing Policy...")
                    # g_step = 1 when not using GAIL
                    mean_losses = None
                    vpredbefore = None
                    tdlamret = None
                    observation = None
                    action = None
                    seg = None
                    for _ in range(self.g_step):
                        with self.timed("sampling"):
                            seg = next(seg_gen)
                        add_vtarg_and_adv(seg, self.gamma, self.lam)
                        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
                        observation, action = seg["ob"], seg["ac"]
                        atarg, tdlamret = seg["adv"], seg["tdlamret"]
                        vpredbefore = seg["vpred"]  # predicted value function before update
                        # standardized advantage function estimate
                        # NOTE(review): divides by atarg.std(); a constant advantage batch would divide by zero.
                        atarg = (atarg - atarg.mean()) / atarg.std()

                        # Observations are passed twice -- presumably once per policy placeholder
                        # (current and old); confirm against the compute_lossandgrad definition.
                        args = seg["ob"], seg["ob"], seg["ac"], atarg
                        # Keep every 5th sample for the Fisher-vector products
                        fvpargs = [arr[::5] for arr in args]

                        self.assign_old_eq_new(sess=self.sess)

                        with self.timed("computegrad"):
                            *lossbefore, grad = self.compute_lossandgrad(*args, sess=self.sess)
                        lossbefore = self.allmean(np.array(lossbefore))
                        grad = self.allmean(grad)
                        if np.allclose(grad, 0):
                            logger.log("Got zero gradient. not updating")
                        else:
                            with self.timed("cg"):
                                stepdir = conjugate_gradient(
                                    fisher_vector_product,
                                    grad,
                                    cg_iters=self.cg_iters,
                                    verbose=self.rank == 0 and self.verbose >= 1)
                            assert np.isfinite(stepdir).all()
                            shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
                            # abs(shs) to avoid taking square root of negative values
                            lagrange_multiplier = np.sqrt(abs(shs) / self.max_kl)
                            fullstep = stepdir / lagrange_multiplier
                            expectedimprove = grad.dot(fullstep)
                            surrbefore = lossbefore[0]
                            stepsize = 1.0
                            thbefore = self.get_flat()
                            thnew = None
                            # Backtracking line search: try the full step, then halve it up to 10 times
                            # until losses are finite, KL is within 1.5 * max_kl and the surrogate
                            # did not get worse.
                            for _ in range(10):
                                thnew = thbefore + fullstep * stepsize
                                self.set_from_flat(thnew)
                                mean_losses = surr, kl_loss, *_ = self.allmean(
                                    np.array(self.compute_losses(*args, sess=self.sess)))
                                improve = surr - surrbefore
                                logger.log("Expected: %.3f Actual: %.3f" % (expectedimprove, improve))
                                if not np.isfinite(mean_losses).all():
                                    logger.log("Got non-finite value of losses -- bad!")
                                elif kl_loss > self.max_kl * 1.5:
                                    logger.log("violated KL constraint. shrinking step.")
                                elif improve < 0:
                                    logger.log("surrogate didn't improve. shrinking step.")
                                else:
                                    logger.log("Stepsize OK!")
                                    break
                                stepsize *= .5
                            else:
                                # No candidate was accepted: restore the parameters from before the search
                                logger.log("couldn't compute a good step")
                                self.set_from_flat(thbefore)
                            if self.nworkers > 1 and iters_so_far % 20 == 0:
                                # Sanity check that all workers hold the same parameter sums
                                # list of tuples
                                paramsums = MPI.COMM_WORLD.allgather(
                                    (thnew.sum(), self.vfadam.getflat().sum()))
                                assert all(np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

                        with self.timed("vf"):
                            for _ in range(self.vf_iters):
                                for (mbob, mbret) in dataset.iterbatches(
                                        (seg["ob"], seg["tdlamret"]),
                                        include_final_partial_batch=False,
                                        batch_size=128):
                                    grad = self.allmean(
                                        self.compute_vflossandgrad(mbob, mbob, mbret, sess=self.sess))
                                    self.vfadam.update(grad, self.vf_stepsize)

                    # mean_losses is still None when every G step took the zero-gradient branch;
                    # iterating over it would raise TypeError, so skip the loss rows in that case.
                    if mean_losses is not None:
                        for (loss_name, loss_val) in zip(self.loss_names, mean_losses):
                            logger.record_tabular(loss_name, loss_val)

                    logger.record_tabular(
                        "ev_tdlam_before",
                        explained_variance(vpredbefore, tdlamret))

                    if self.using_gail:
                        # ------------------ Update D ------------------
                        logger.log("Optimizing Discriminator...")
                        logger.log(fmt_row(13, self.reward_giver.loss_name))
                        # NOTE(review): this batch is overwritten inside the loop before being used;
                        # the call is kept because it advances the expert dataset.
                        ob_expert, ac_expert = self.expert_dataset.get_next_batch(len(observation))
                        batch_size = len(observation) // self.d_step
                        # list of tuples, each of which gives the loss for a minibatch
                        d_losses = []
                        # Uses the observations/actions of the last G step only
                        for ob_batch, ac_batch in dataset.iterbatches(
                                (observation, action),
                                include_final_partial_batch=False,
                                batch_size=batch_size):
                            ob_expert, ac_expert = self.expert_dataset.get_next_batch(len(ob_batch))
                            # update running mean/std for reward_giver
                            if hasattr(self.reward_giver, "obs_rms"):
                                self.reward_giver.obs_rms.update(
                                    np.concatenate((ob_batch, ob_expert), 0))
                            *newlosses, grad = self.reward_giver.lossandgrad(
                                ob_batch, ac_batch, ob_expert, ac_expert)
                            self.d_adam.update(self.allmean(grad), self.d_stepsize)
                            d_losses.append(newlosses)
                        logger.log(fmt_row(13, np.mean(d_losses, axis=0)))

                        # local values
                        lrlocal = (seg["ep_lens"], seg["ep_rets"], seg["ep_true_rets"])
                        # list of tuples
                        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)
                        lens, rews, true_rets = map(flatten_lists, zip(*listoflrpairs))
                        true_rewbuffer.extend(true_rets)
                    else:
                        # local values
                        lrlocal = (seg["ep_lens"], seg["ep_rets"])
                        # list of tuples
                        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)
                        lens, rews = map(flatten_lists, zip(*listoflrpairs))
                    lenbuffer.extend(lens)
                    rewbuffer.extend(rews)

                    # The buffers are empty until the first episode finishes; np.mean of an empty
                    # sequence yields nan with a RuntimeWarning, so only log means once data exists.
                    if len(lenbuffer) > 0:
                        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
                        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
                    if self.using_gail and len(true_rewbuffer) > 0:
                        logger.record_tabular("EpTrueRewMean", np.mean(true_rewbuffer))
                    logger.record_tabular("EpThisIter", len(lens))
                    episodes_so_far += len(lens)
                    timesteps_so_far += seg["total_timestep"]
                    iters_so_far += 1

                    logger.record_tabular("EpisodesSoFar", episodes_so_far)
                    logger.record_tabular("TimestepsSoFar", timesteps_so_far)
                    logger.record_tabular("TimeElapsed", time.time() - t_start)

                    if self.verbose >= 1 and self.rank == 0:
                        logger.dump_tabular()

        return self
コード例 #5
0
    def learn(self,
              total_timesteps,
              callback=None,
              seed=None,
              log_interval=100):
        """
        Run the training loop: repeatedly collect a rollout segment, run
        ``self.optim_epochs`` passes of minibatch Adam updates over it, then
        log the resulting losses and episode statistics.

        :param total_timesteps: timestep budget; the loop stops once
            ``timesteps_so_far`` reaches it. A falsy value (``0``/``None``)
            disables this check, so only the callback can end training.
        :param callback: optional callable invoked at the start of every
            iteration as ``callback(locals(), globals())``. Returning ``False``
            stops training; any other return value (including ``None``)
            continues. Because ``locals()`` is handed out, the local variable
            names in this method are visible to callbacks.
        :param seed: forwarded to ``self._setup_learn``.
        :param log_interval: not used by this implementation.
        :return: ``self``
        :raises NotImplementedError: if ``self.schedule`` is neither
            ``'constant'`` nor ``'linear'``.
        """
        with SetVerbosity(self.verbose):
            self._setup_learn(seed)

            with self.sess.as_default():
                self.adam.sync()

                # Prepare for rollouts
                seg_gen = traj_segment_generator(self.policy_pi, self.env,
                                                 self.timesteps_per_actorbatch)

                episodes_so_far = 0
                timesteps_so_far = 0
                iters_so_far = 0
                t_start = time.time()

                # rolling buffer for episode lengths
                lenbuffer = deque(maxlen=100)
                # rolling buffer for episode rewards
                rewbuffer = deque(maxlen=100)

                while True:
                    if callback:
                        # Only stop training if the return value is False, not when it is None, so
                        # callbacks without a return statement keep working.
                        if callback(locals(), globals()) is False:
                            break
                    if total_timesteps and timesteps_so_far >= total_timesteps:
                        break

                    if self.schedule == 'constant':
                        cur_lrmult = 1.0
                    elif self.schedule == 'linear':
                        # Decays from 1 to 0 over the timestep budget.
                        # NOTE(review): divides by total_timesteps, so the linear schedule
                        # needs a non-zero total_timesteps.
                        cur_lrmult = max(
                            1.0 - float(timesteps_so_far) / total_timesteps, 0)
                    else:
                        raise NotImplementedError

                    logger.log("********** Iteration %i ************" % iters_so_far)

                    seg = next(seg_gen)
                    add_vtarg_and_adv(seg, self.gamma, self.lam)

                    # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
                    obs_ph, action_ph = seg["ob"], seg["ac"]
                    atarg, tdlamret = seg["adv"], seg["tdlamret"]

                    # predicted value function before update
                    vpredbefore = seg["vpred"]

                    # standardized advantage function estimate
                    # NOTE(review): divides by atarg.std(); a constant advantage batch would divide by zero.
                    atarg = (atarg - atarg.mean()) / atarg.std()
                    # Shuffling is disabled when the policy class derives from LstmPolicy
                    dataset = Dataset(
                        dict(ob=obs_ph,
                             ac=action_ph,
                             atarg=atarg,
                             vtarg=tdlamret),
                        shuffle=not issubclass(self.policy, LstmPolicy))
                    # Fall back to one batch spanning the whole segment when no batch size is set
                    optim_batchsize = self.optim_batchsize or obs_ph.shape[0]

                    # set old parameter values to new parameter values
                    self.assign_old_eq_new(sess=self.sess)
                    logger.log("Optimizing...")
                    logger.log(fmt_row(13, self.loss_names))

                    # Here we do a bunch of optimization epochs over the data
                    for _ in range(self.optim_epochs):
                        # list of tuples, each of which gives the loss for a minibatch
                        losses = []
                        for batch in dataset.iterate_once(optim_batchsize):
                            # Observations are passed twice -- presumably once per policy
                            # placeholder (current and old); confirm against lossandgrad.
                            *newlosses, grad = self.lossandgrad(batch["ob"],
                                                                batch["ob"],
                                                                batch["ac"],
                                                                batch["atarg"],
                                                                batch["vtarg"],
                                                                cur_lrmult,
                                                                sess=self.sess)
                            self.adam.update(grad,
                                             self.optim_stepsize * cur_lrmult)
                            losses.append(newlosses)
                        logger.log(fmt_row(13, np.mean(losses, axis=0)))

                    logger.log("Evaluating losses...")
                    losses = []
                    for batch in dataset.iterate_once(optim_batchsize):
                        newlosses = self.compute_losses(batch["ob"],
                                                        batch["ob"],
                                                        batch["ac"],
                                                        batch["atarg"],
                                                        batch["vtarg"],
                                                        cur_lrmult,
                                                        sess=self.sess)
                        losses.append(newlosses)
                    mean_losses, _, _ = mpi_moments(losses, axis=0)
                    logger.log(fmt_row(13, mean_losses))
                    for (loss_val, name) in zipsame(mean_losses,
                                                    self.loss_names):
                        logger.record_tabular("loss_" + name, loss_val)
                    logger.record_tabular(
                        "ev_tdlam_before",
                        explained_variance(vpredbefore, tdlamret))

                    # local values
                    lrlocal = (seg["ep_lens"], seg["ep_rets"])

                    # list of tuples
                    listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)
                    lens, rews = map(flatten_lists, zip(*listoflrpairs))
                    lenbuffer.extend(lens)
                    rewbuffer.extend(rews)
                    # The buffers are empty until the first episode finishes; np.mean of an empty
                    # sequence yields nan with a RuntimeWarning, so only log means once data exists.
                    if len(lenbuffer) > 0:
                        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
                        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
                    logger.record_tabular("EpThisIter", len(lens))
                    episodes_so_far += len(lens)
                    timesteps_so_far += seg["total_timestep"]
                    iters_so_far += 1
                    logger.record_tabular("EpisodesSoFar", episodes_so_far)
                    logger.record_tabular("TimestepsSoFar", timesteps_so_far)
                    logger.record_tabular("TimeElapsed", time.time() - t_start)
                    if self.verbose >= 1 and MPI.COMM_WORLD.Get_rank() == 0:
                        logger.dump_tabular()

        return self
コード例 #6
0
    def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="MDPO",
              reset_num_timesteps=True):

        new_tb_log = self._init_num_timesteps(reset_num_timesteps)
        callback = self._init_callback(callback)
        print("got seed {}, sgd_steps {}".format(seed, self.sgd_steps))

        with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name, new_tb_log) \
                as writer:

            with self.sess.as_default():
                callback.on_training_start(locals(), globals())

                seg_gen = traj_segment_generator(self.old_policy, self.env, self.timesteps_per_batch,
                                                     reward_giver=self.reward_giver,
                                                     gail=self.using_gail, mdal=self.using_mdal, neural=self.neural,
                                                     action_space=self.action_space, gamma=self.gamma, callback=callback)


                episodes_so_far = 0
                timesteps_so_far = 0
                iters_so_far = 0
                t_start = time.time()
                len_buffer = deque(maxlen=40)  # rolling buffer for episode lengths
                reward_buffer = deque(maxlen=40)  # rolling buffer for episode rewards

                self.episode_reward = np.zeros((self.n_envs,))
                self.outer_learning_rate = get_schedule_fn(3e-4)
                self.cliprange_vf = get_schedule_fn(0.2)

                true_reward_buffer = None
                if self.using_gail or self.using_mdal:
                    true_reward_buffer = deque(maxlen=40)

                    # Initialize dataloader
                    batchsize = self.timesteps_per_batch // self.d_step
                    self.expert_dataset.init_dataloader(batchsize)

                    #  Stats not used for now
                    # TODO: replace with normal tb logging
                    #  g_loss_stats = Stats(loss_names)
                    #  d_loss_stats = Stats(reward_giver.loss_name)
                    #  ep_stats = Stats(["True_rewards", "Rewards", "Episode_length"])

                while True:
                    # if callback is not None:
                    #     # Only stop training if return value is False, not when it is None. This is for backwards
                    #     # compatibility with callbacks that have no return statement.
                    #     if callback(locals(), globals()) is False:
                    #         break
                    if total_timesteps and timesteps_so_far >= total_timesteps:
                        break

                    logger.log("********** Iteration %i ************" % iters_so_far)

                    #def fisher_vector_product(vec):
                    #    return self.allmean(self.compute_fvp(vec, *fvpargs, sess=self.sess)) + self.cg_damping * vec

                    # ------------------ Update G ------------------
                    # logger.log("Optimizing Policy...")
                    # g_step = 1 when not using GAIL
                    mean_losses = None
                    vpredbefore = None
                    tdlamret = None
                    observation = None
                    action = None
                    seg = None
                    for k in range(self.g_step):
                        with self.timed("sampling"):
                            seg = seg_gen.__next__()
                        if not seg.get('continue_training', True):  # pytype: disable=attribute-error
                            break

                        add_vtarg_and_adv(seg, self.gamma, self.lam)
                        if self.using_mdal:
                            policy_successor_features = add_successor_features(seg, self.gamma,
                                                                           is_action_features=self.is_action_features)
                        else:
                            policy_successor_features = add_successor_features(seg, self.gamma)
                        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
                        observation, action = seg["observations"], seg["actions"]
                        atarg, tdlamret = seg["adv"], seg["tdlamret"]


                        vpredbefore = seg["vpred"]  # predicted value function before update
                        atarg = (atarg - atarg.mean()) / atarg.std()  # standardized advantage function estimate

                        # true_rew is the reward without discount
                        if writer is not None:
                            self.episode_reward = total_episode_reward_logger(self.episode_reward,
                                                                              seg["true_rewards"].reshape(
                                                                                  (self.n_envs, -1)),
                                                                              seg["dones"].reshape((self.n_envs, -1)),
                                                                              writer, self.num_timesteps)

                        n_updates = int(total_timesteps / self.timesteps_per_batch)
                        lr_now = np.float32(1.0 - (iters_so_far - 1.0) / n_updates)
                        outer_lr_now = self.outer_learning_rate(1.0 - (iters_so_far - 1.0) / n_updates)
                        clip_now = self.cliprange_vf(1.0 - (iters_so_far - 1.0) / n_updates)
                        args = seg["observations"], seg["observations"], seg["actions"], atarg
                        # Subsampling: see p40-42 of John Schulman thesis
                        # http://joschu.net/docs/thesis.pdf
                        #fvpargs = [arr[::5] for arr in args]

                        with self.timed("computegrad"):
                            steps = self.num_timesteps + (k + 1) * (seg["total_timestep"] / self.g_step)
                            run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                            run_metadata = tf.RunMetadata() if self.full_tensorboard_log else None
                            # run loss backprop with summary, and save the metadata (memory, compute time, ...)
                            if writer is not None:
                                summary, grad, *lossbefore = self.compute_lossandgrad(*args, tdlamret,
                                                                                      lr_now, seg["vpred"],
                                                                                      seg["observations"],
                                                                                      sess=self.sess,
                                                                                      options=run_options,
                                                                                      run_metadata=run_metadata)
                                if self.full_tensorboard_log:
                                    writer.add_run_metadata(run_metadata, 'step%d' % steps)
                                writer.add_summary(summary, steps)
                            else:
                                _, grad, *lossbefore = self.compute_lossandgrad(*args, tdlamret,
                                                                                lr_now, seg["vpred"],
                                                                                seg["observations"],
                                                                                sess=self.sess,
                                                                                options=run_options,
                                                                                run_metadata=run_metadata)
                                td_map = {self.policy_pi.obs_ph: seg["observations"],
                                            self.old_policy.obs_ph: seg["observations"],
                                            self.closed_policy.obs_ph: seg["observations"],
                                            self.action: seg["actions"], self.atarg: atarg, self.ret: tdlamret,
                                            self.learning_rate_ph: lr_now, self.outer_learning_rate_ph: outer_lr_now,
                                            self.vtarg: seg["vpred"]}
                                for _ in range(int(self.sgd_steps)):
                                    _ = self.sess.run(self._train, td_map)
                                    #if self.method == "closed-KL":
                                    #    _ = self.sess.run(self._train_policy, td_map)

                        if np.allclose(grad, 0):
                            logger.log("Got zero gradient. not updating")
                        else:
                            for _ in range(1):
                                mean_losses = surr, kl_loss, *_ = self.allmean(
                                    np.array(self.compute_losses(*args, lr_now, seg["vpred"], sess=self.sess)))

                        with self.timed("vf"):
                            for _ in range(self.vf_iters):
                                # NOTE: for recurrent policies, use shuffle=False?
                                for (mbob, mbret, mbval) in dataset.iterbatches((seg["observations"], seg["tdlamret"], seg["vpred"]),
                                                                         include_final_partial_batch=False,
                                                                         batch_size=128,
                                                                         shuffle=True):
                                    grad = self.allmean(self.compute_vflossandgrad(mbob, mbob, mbret, mbval, clip_now, sess=self.sess))
                                    self.vfadam.update(grad, outer_lr_now) #self.vf_stepsize)

                        if iters_so_far % 1 == 0:
                            # print("updating theta now")
                            self.assign_old_eq_new(sess=self.sess)

                    for (loss_name, loss_val) in zip(self.loss_names, mean_losses):
                        logger.record_tabular(loss_name, loss_val)

                    logger.record_tabular("explained_variance_tdlam_before",
                                          explained_variance(vpredbefore, tdlamret))

                    if self.using_gail:
                        # ------------------ Update D ------------------
                        logger.log("Optimizing Discriminator...")
                        logger.log(fmt_row(13, self.reward_giver.loss_name))
                        assert len(observation) == self.timesteps_per_batch
                        batch_size = self.timesteps_per_batch // self.d_step

                        # NOTE: uses only the last g step for observation
                        d_losses = []  # list of tuples, each of which gives the loss for a minibatch
                        # NOTE: for recurrent policies, use shuffle=False?
                        for ob_batch, ac_batch in dataset.iterbatches((observation, action),
                                                                      include_final_partial_batch=False,
                                                                      batch_size=batch_size,
                                                                      shuffle=True):
                            ob_expert, ac_expert = self.expert_dataset.get_next_batch()
                            # update running mean/std for reward_giver
                            if self.reward_giver.normalize:
                                self.reward_giver.obs_rms.update(np.concatenate((ob_batch, ob_expert), 0))

                            # Reshape actions if needed when using discrete actions
                            if isinstance(self.action_space, gym.spaces.Discrete):
                                if len(ac_batch.shape) == 2:
                                    ac_batch = ac_batch[:, 0]
                                if len(ac_expert.shape) == 2:
                                    ac_expert = ac_expert[:, 0]
                            *newlosses, grad = self.reward_giver.lossandgrad(ob_batch, ac_batch, ob_expert, ac_expert)
                            self.d_adam.update(self.allmean(grad), self.d_stepsize)
                            d_losses.append(newlosses)
                        logger.log(fmt_row(13, np.mean(d_losses, axis=0)))

                    elif self.using_mdal:
                        batch_sampling = True

                        if self.neural:

                            if batch_sampling:
                                batch_size = self.timesteps_per_batch // self.d_step

                                # NOTE: uses only the last g step for observation
                                d_losses = []  # list of tuples, each of which gives the loss for a minibatch
                                # NOTE: for recurrent policies, use shuffle=False?
                                for ob_batch, ac_batch in dataset.iterbatches((observation, action),
                                                                              include_final_partial_batch=False,
                                                                              batch_size=batch_size,
                                                                              shuffle=True):
                                # ob_batch, ac_batch, gamma_batch = np.array(batch_buffer['obs']), np.array(
                                #     batch_buffer['acs']), np.array(batch_buffer['gammas'])
                                    gamma_batch = np.ones((ob_batch.shape[0]))
                                    ob_expert, ac_expert = self.expert_dataset.get_next_batch()
                                    gamma_expert = np.ones((ob_expert.shape[0]))
                                    # ob_expert, ac_expert, gamma_expert = np.concatenate(self.expert_dataset.ep_obs),\
                                    #                                      np.concatenate(self.expert_dataset.ep_acs),\
                                    #                                      np.concatenate(self.expert_dataset.ep_gammas)

                                    # update running mean/std for reward_giver
                                    if self.reward_giver.normalize:
                                        self.reward_giver.obs_rms.update(np.concatenate((ob_batch, ob_expert), 0))

                                    # Reshape actions if needed when using discrete actions
                                    if isinstance(self.action_space, gym.spaces.Discrete):
                                        if len(ac_batch.shape) == 2:
                                            ac_batch = ac_batch[:, 0]
                                        if len(ac_expert.shape) == 2:
                                            ac_expert = ac_expert[:, 0]

                                    ob_reg_expert, ac_reg_expert = np.array(ob_expert), np.array(ac_expert)

                                    # while True:
                                    #     if ob_reg_expert.shape[0] == ob_batch.shape[0] and ac_reg_expert.shape[0] == \
                                    #             ac_batch.shape[0]:
                                    #         break
                                    #     ob_reg_expert, ac_reg_expert = self.expert_dataset.get_next_batch()
                                    #     ob_reg_expert, ac_reg_expert = np.array(ob_reg_expert), np.array(ac_reg_expert)


                                    alpha = np.random.uniform(0.0, 1.0, size=(ob_reg_expert.shape[0], 1))
                                    ob_mix_batch = alpha * ob_batch[:ob_reg_expert.shape[0]] + (1 - alpha) * ob_reg_expert
                                    ac_mix_batch = alpha * ac_batch[:ac_reg_expert.shape[0]] + (1 - alpha) * ac_reg_expert
                                    with self.sess.as_default():
                                        # self.reward_giver.train(ob_batch, ac_batch, np.expand_dims(gamma_batch, axis=1),
                                        #                         ob_expert, ac_expert, np.expand_dims(gamma_expert, axis=1))
                                        *newlosses, grad = self.reward_giver.lossandgrad(
                                                                ob_batch, ac_batch, np.expand_dims(gamma_batch, axis=1),
                                                                ob_expert, ac_expert, np.expand_dims(gamma_expert, axis=1),
                                                                ob_mix_batch, ac_mix_batch)
                                        self.d_adam.update(self.allmean(grad), self.d_stepsize)
                            else:
                                # assert len(observation) == self.timesteps_per_batch
                                # Comment out if you want only the latest rewards:
                                obs_batch, acs_batch, gammas_batch = seg['obs_batch'], seg['acs_batch'], seg['gammas_batch']
                                batch_successor_features = seg['successor_features_batch']


                                if self.reward_giver.normalize:
                                    ob_reg_batch, ac_reg_batch = observation, action
                                    ob_expert, _ = self.expert_dataset.get_next_batch()
                                    self.reward_giver.obs_rms.update(np.concatenate((ob_reg_batch, ob_expert), 0))
                                #     self.reward_giver.obs_rms.update(
                                #         np.array(batch_successor_features)[:, :self.observation_space.shape[0]])

                                for idx, (ob_batch, ac_batch, gamma_batch) in enumerate(
                                        zip(obs_batch, acs_batch, gammas_batch)):
                                    rand_traj = np.random.randint(self.expert_dataset.num_traj)
                                    ob_expert, ac_expert, gamma_expert = self.expert_dataset.ep_obs[rand_traj], \
                                                                         self.expert_dataset.ep_acs[rand_traj], \
                                                                         self.expert_dataset.ep_gammas[rand_traj]

                                    ob_batch, ac_batch, gamma_batch = np.array(ob_batch), np.array(ac_batch), np.array(
                                        gamma_batch)

                                    while True:
                                        ob_reg_expert, ac_reg_expert = self.expert_dataset.get_next_batch()
                                        ob_reg_expert, ac_reg_expert = np.array(ob_reg_expert), np.array(ac_reg_expert)

                                        if ob_reg_expert.shape[0] == ob_reg_batch.shape[0] and ac_reg_expert.shape[0] == \
                                                ac_reg_batch.shape[0]:
                                            break
                                    alpha = np.random.uniform(0.0, 1.0, size=(ob_reg_batch.shape[0], 1))
                                    ob_mix_batch = alpha * ob_reg_batch + (1 - alpha) * ob_reg_expert
                                    ac_mix_batch = alpha * ac_reg_batch + (1 - alpha) * ac_reg_expert

                                    with self.sess.as_default():
                                        *newlosses, grad = self.reward_giver.lossandgrad(
                                                                ob_batch, ac_batch, np.expand_dims(gamma_batch, axis=1),
                                                                ob_expert, ac_expert, np.expand_dims(gamma_expert, axis=1),
                                                                ob_mix_batch, ac_mix_batch)
                                        self.d_adam.update(self.allmean(grad), self.d_stepsize)
                                        # self.reward_giver.train(ob_batch, ac_batch, np.expand_dims(gamma_batch, axis=1),
                                        #                         ob_expert, ac_expert,
                                        #                         np.expand_dims(gamma_expert, axis=1),
                                        #                         ob_mix_batch, ac_mix_batch)

                    if self.using_gail or self.using_mdal:
                        # lr: lengths and rewards
                        lr_local = (seg["ep_lens"], seg["ep_rets"], seg["ep_true_rets"])  # local values
                        list_lr_pairs = MPI.COMM_WORLD.allgather(lr_local)  # list of tuples
                        lens, rews, true_rets = map(flatten_lists, zip(*list_lr_pairs))
                        true_reward_buffer.extend(true_rets)
                    else:
                        lr_local = (seg["ep_lens"], seg["ep_rets"])  # local values
                        list_lr_pairs = MPI.COMM_WORLD.allgather(lr_local)  # list of tuples
                        lens, rews = map(flatten_lists, zip(*list_lr_pairs))
                    len_buffer.extend(lens)
                    reward_buffer.extend(rews)

                    if len(len_buffer) > 0:
                        if self.using_gail or self.using_mdal:
                            logger.record_tabular("EpTrueRewMean", np.mean(true_reward_buffer))

                        logger.record_tabular("EpRewMean", np.mean(reward_buffer))
                        logger.record_tabular("EpLenMean", np.mean(len_buffer))


                    logger.record_tabular("EpThisIter", len(lens))
                    episodes_so_far += len(lens)
                    current_it_timesteps = MPI.COMM_WORLD.allreduce(seg["total_timestep"])
                    timesteps_so_far += current_it_timesteps
                    self.num_timesteps += current_it_timesteps
                    iters_so_far += 1

                    logger.record_tabular("EpisodesSoFar", episodes_so_far)
                    logger.record_tabular("TimestepsSoFar", self.num_timesteps)
                    logger.record_tabular("TimeElapsed", time.time() - t_start)
                    logger.record_tabular("Tsallis-q", self.tsallis_q)
                    logger.record_tabular("steps", self.num_timesteps)
                    logger.record_tabular("seed", self.seed)

                    if self.verbose >= 1 and self.rank == 0:
                        logger.dump_tabular()
        callback.on_training_end()

        return self