Example no. 1
class TRPO(ActorCriticRLModel):
    """
    Trust Region Policy Optimization (https://arxiv.org/abs/1502.05477)

    :param policy: (ActorCriticPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, CnnLstmPolicy, ...)
    :param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str)
    :param gamma: (float) the discount factor
    :param timesteps_per_batch: (int) the number of timesteps to run per batch (horizon)
    :param max_kl: (float) the Kullback-Leibler divergence threshold
    :param cg_iters: (int) the number of iterations for the conjugate gradient calculation
    :param lam: (float) GAE factor
    :param entcoeff: (float) the weight for the entropy loss
    :param cg_damping: (float) the conjugate gradient damping factor
    :param vf_stepsize: (float) the value function stepsize
    :param vf_iters: (int) the value function's number of iterations for learning
    :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
    :param tensorboard_log: (str) the log location for tensorboard (if None, no logging)
    :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
    :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
    :param full_tensorboard_log: (bool) enable additional logging when using tensorboard
        WARNING: this logging can take a lot of space quickly
    """
    def __init__(self,
                 policy,
                 env,
                 gamma=0.99,
                 timesteps_per_batch=1024,
                 max_kl=0.01,
                 cg_iters=10,
                 lam=0.98,
                 entcoeff=0.0,
                 cg_damping=1e-2,
                 vf_stepsize=3e-4,
                 vf_iters=3,
                 verbose=0,
                 tensorboard_log=None,
                 _init_setup_model=True,
                 policy_kwargs=None,
                 full_tensorboard_log=False):
        super(TRPO, self).__init__(policy=policy,
                                   env=env,
                                   verbose=verbose,
                                   requires_vec_env=False,
                                   _init_setup_model=_init_setup_model,
                                   policy_kwargs=policy_kwargs)

        self.using_gail = False
        self.timesteps_per_batch = timesteps_per_batch
        self.cg_iters = cg_iters
        self.cg_damping = cg_damping
        self.gamma = gamma
        self.lam = lam
        self.max_kl = max_kl
        self.vf_iters = vf_iters
        self.vf_stepsize = vf_stepsize
        self.entcoeff = entcoeff
        self.tensorboard_log = tensorboard_log
        self.full_tensorboard_log = full_tensorboard_log

        # GAIL Params
        self.hidden_size_adversary = 100
        self.adversary_entcoeff = 1e-3
        self.expert_dataset = None
        self.g_step = 1
        self.d_step = 1
        self.d_stepsize = 3e-4

        self.graph = None
        self.sess = None
        self.policy_pi = None
        self.loss_names = None
        self.assign_old_eq_new = None
        self.compute_losses = None
        self.compute_lossandgrad = None
        self.compute_fvp = None
        self.compute_vflossandgrad = None
        self.d_adam = None
        self.vfadam = None
        self.get_flat = None
        self.set_from_flat = None
        self.timed = None
        self.allmean = None
        self.nworkers = None
        self.rank = None
        self.reward_giver = None
        self.step = None
        self.proba_step = None
        self.initial_state = None
        self.params = None
        self.summary = None
        self.episode_reward = None

        if _init_setup_model:
            self.setup_model()

    def _get_pretrain_placeholders(self):
        policy = self.policy_pi
        action_ph = policy.pdtype.sample_placeholder([None])
        if isinstance(self.action_space, gym.spaces.Discrete):
            return policy.obs_ph, action_ph, policy.policy
        return policy.obs_ph, action_ph, policy.deterministic_action

    def setup_model(self):
        # prevent import loops
        from stable_baselines.gail.adversary import TransitionClassifier

        with SetVerbosity(self.verbose):

            assert issubclass(self.policy, ActorCriticPolicy), "Error: the input policy for the TRPO model must be " \
                                                               "an instance of common.policies.ActorCriticPolicy."

            self.nworkers = MPI.COMM_WORLD.Get_size()
            self.rank = MPI.COMM_WORLD.Get_rank()
            np.set_printoptions(precision=3)

            self.graph = tf.Graph()
            with self.graph.as_default():
                self.sess = tf_util.single_threaded_session(graph=self.graph)

                if self.using_gail:
                    self.reward_giver = TransitionClassifier(
                        self.observation_space,
                        self.action_space,
                        self.hidden_size_adversary,
                        entcoeff=self.adversary_entcoeff)

                # Construct network for new policy
                self.policy_pi = self.policy(self.sess,
                                             self.observation_space,
                                             self.action_space,
                                             self.n_envs,
                                             1,
                                             None,
                                             reuse=False,
                                             **self.policy_kwargs)

                # Network for old policy
                with tf.variable_scope("oldpi", reuse=False):
                    old_policy = self.policy(self.sess,
                                             self.observation_space,
                                             self.action_space,
                                             self.n_envs,
                                             1,
                                             None,
                                             reuse=False,
                                             **self.policy_kwargs)

                with tf.variable_scope("loss", reuse=False):
                    # Target advantage function (if applicable)
                    atarg = tf.placeholder(dtype=tf.float32, shape=[None])
                    # Empirical return
                    ret = tf.placeholder(dtype=tf.float32, shape=[None])

                    observation = self.policy_pi.obs_ph
                    action = self.policy_pi.pdtype.sample_placeholder([None])

                    kloldnew = old_policy.proba_distribution.kl(
                        self.policy_pi.proba_distribution)
                    ent = self.policy_pi.proba_distribution.entropy()
                    meankl = tf.reduce_mean(kloldnew)
                    meanent = tf.reduce_mean(ent)
                    entbonus = self.entcoeff * meanent

                    vferr = tf.reduce_mean(
                        tf.square(self.policy_pi.value_fn[:, 0] - ret))

                    # advantage * pnew / pold
                    ratio = tf.exp(
                        self.policy_pi.proba_distribution.logp(action) -
                        old_policy.proba_distribution.logp(action))
                    surrgain = tf.reduce_mean(ratio * atarg)
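                    # surrgain is the importance-sampled surrogate objective
                    # E[(pi_new / pi_old) * advantage] that TRPO maximizes under the KL constraint.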

                    optimgain = surrgain + entbonus
                    losses = [optimgain, meankl, entbonus, surrgain, meanent]
                    self.loss_names = [
                        "optimgain", "meankl", "entloss", "surrgain", "entropy"
                    ]

                    dist = meankl

                    all_var_list = tf_util.get_trainable_vars("model")
                    var_list = [
                        v for v in all_var_list
                        if "/vf" not in v.name and "/q/" not in v.name
                    ]
                    vf_var_list = [
                        v for v in all_var_list
                        if "/pi" not in v.name and "/logstd" not in v.name
                    ]

                    self.get_flat = tf_util.GetFlat(var_list, sess=self.sess)
                    self.set_from_flat = tf_util.SetFromFlat(var_list,
                                                             sess=self.sess)

                    klgrads = tf.gradients(dist, var_list)
                    flat_tangent = tf.placeholder(dtype=tf.float32,
                                                  shape=[None],
                                                  name="flat_tan")
                    shapes = [var.get_shape().as_list() for var in var_list]
                    start = 0
                    tangents = []
                    for shape in shapes:
                        var_size = tf_util.intprod(shape)
                        tangents.append(
                            tf.reshape(flat_tangent[start:start + var_size],
                                       shape))
                        start += var_size
                    gvp = tf.add_n([
                        tf.reduce_sum(grad * tangent)
                        for (grad, tangent) in zipsame(klgrads, tangents)
                    ])  # pylint: disable=E1111
                    fvp = tf_util.flatgrad(gvp, var_list)
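                    # fvp is the Fisher-vector product: differentiating (grad KL . tangent)
                    # yields F*v without ever forming the full Fisher matrix; the conjugate
                    # gradient solver calls this repeatedly during the policy update.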

                    tf.summary.scalar('entropy_loss', meanent)
                    tf.summary.scalar('policy_gradient_loss', optimgain)
                    tf.summary.scalar('value_function_loss', surrgain)
                    tf.summary.scalar('approximate_kullback-leibler', meankl)
                    tf.summary.scalar(
                        'loss',
                        optimgain + meankl + entbonus + surrgain + meanent)

                    self.assign_old_eq_new = \
                        tf_util.function([], [], updates=[tf.assign(oldv, newv) for (oldv, newv) in
                                                          zipsame(tf_util.get_globals_vars("oldpi"),
                                                                  tf_util.get_globals_vars("model"))])
                    self.compute_losses = tf_util.function(
                        [observation, old_policy.obs_ph, action, atarg],
                        losses)
                    self.compute_fvp = tf_util.function([
                        flat_tangent, observation, old_policy.obs_ph, action,
                        atarg
                    ], fvp)
                    self.compute_vflossandgrad = tf_util.function(
                        [observation, old_policy.obs_ph, ret],
                        tf_util.flatgrad(vferr, vf_var_list))

                    @contextmanager
                    def timed(msg):
                        if self.rank == 0 and self.verbose >= 1:
                            print(colorize(msg, color='magenta'))
                            start_time = time.time()
                            yield
                            print(
                                colorize("done in {:.3f} seconds".format(
                                    (time.time() - start_time)),
                                         color='magenta'))
                        else:
                            yield

                    def allmean(arr):
                        assert isinstance(arr, np.ndarray)
                        out = np.empty_like(arr)
                        MPI.COMM_WORLD.Allreduce(arr, out, op=MPI.SUM)
                        out /= self.nworkers
                        return out

                    tf_util.initialize(sess=self.sess)

                    th_init = self.get_flat()
                    MPI.COMM_WORLD.Bcast(th_init, root=0)
                    self.set_from_flat(th_init)

                with tf.variable_scope("Adam_mpi", reuse=False):
                    self.vfadam = MpiAdam(vf_var_list, sess=self.sess)
                    if self.using_gail:
                        self.d_adam = MpiAdam(
                            self.reward_giver.get_trainable_variables(),
                            sess=self.sess)
                        self.d_adam.sync()
                    self.vfadam.sync()

                with tf.variable_scope("input_info", reuse=False):
                    tf.summary.scalar('discounted_rewards',
                                      tf.reduce_mean(ret))
                    tf.summary.scalar('learning_rate',
                                      tf.reduce_mean(self.vf_stepsize))
                    tf.summary.scalar('advantage', tf.reduce_mean(atarg))
                    tf.summary.scalar('kl_clip_range',
                                      tf.reduce_mean(self.max_kl))

                    if self.full_tensorboard_log:
                        tf.summary.histogram('discounted_rewards', ret)
                        tf.summary.histogram('learning_rate', self.vf_stepsize)
                        tf.summary.histogram('advantage', atarg)
                        tf.summary.histogram('kl_clip_range', self.max_kl)
                        if tf_util.is_image(self.observation_space):
                            tf.summary.image('observation', observation)
                        else:
                            tf.summary.histogram('observation', observation)

                self.timed = timed
                self.allmean = allmean

                self.step = self.policy_pi.step
                self.proba_step = self.policy_pi.proba_step
                self.initial_state = self.policy_pi.initial_state

                self.params = find_trainable_variables("model")
                if self.using_gail:
                    self.params.extend(
                        self.reward_giver.get_trainable_variables())

                self.summary = tf.summary.merge_all()

                self.compute_lossandgrad = \
                    tf_util.function([observation, old_policy.obs_ph, action, atarg, ret],
                                     [self.summary, tf_util.flatgrad(optimgain, var_list)] + losses)

    def learn(self,
              total_timesteps,
              callback=None,
              seed=None,
              log_interval=100,
              tb_log_name="TRPO",
              reset_num_timesteps=True):

        new_tb_log = self._init_num_timesteps(reset_num_timesteps)

        with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name, new_tb_log) \
                as writer:
            self._setup_learn(seed)

            with self.sess.as_default():
                seg_gen = traj_segment_generator(
                    self.policy_pi,
                    self.env,
                    self.timesteps_per_batch,
                    reward_giver=self.reward_giver,
                    gail=self.using_gail)

                episodes_so_far = 0
                timesteps_so_far = 0
                iters_so_far = 0
                t_start = time.time()
                len_buffer = deque(maxlen=40)  # rolling buffer for episode lengths
                reward_buffer = deque(maxlen=40)  # rolling buffer for episode rewards
                self.episode_reward = np.zeros((self.n_envs, ))

                true_reward_buffer = None
                if self.using_gail:
                    true_reward_buffer = deque(maxlen=40)

                    # Initialize dataloader
                    batchsize = self.timesteps_per_batch // self.d_step
                    self.expert_dataset.init_dataloader(batchsize)

                    #  Stats not used for now
                    # TODO: replace with normal tb logging
                    #  g_loss_stats = Stats(loss_names)
                    #  d_loss_stats = Stats(reward_giver.loss_name)
                    #  ep_stats = Stats(["True_rewards", "Rewards", "Episode_length"])

                while True:
                    if callback is not None:
                        # Only stop training if return value is False, not when it is None. This is for backwards
                        # compatibility with callbacks that have no return statement.
                        if callback(locals(), globals()) is False:
                            break
                    if total_timesteps and timesteps_so_far >= total_timesteps:
                        break

                    logger.log("********** Iteration %i ************" %
                               iters_so_far)

                    def fisher_vector_product(vec):
                        return self.allmean(
                            self.compute_fvp(
                                vec, *fvpargs,
                                sess=self.sess)) + self.cg_damping * vec

                    # ------------------ Update G ------------------
                    logger.log("Optimizing Policy...")
                    # g_step = 1 when not using GAIL
                    mean_losses = None
                    vpredbefore = None
                    tdlamret = None
                    observation = None
                    action = None
                    seg = None
                    for k in range(self.g_step):
                        with self.timed("sampling"):
                            seg = seg_gen.__next__()
                        add_vtarg_and_adv(seg, self.gamma, self.lam)
                        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
                        observation, action = seg["ob"], seg["ac"]
                        atarg, tdlamret = seg["adv"], seg["tdlamret"]
                        # predicted value function before the update
                        vpredbefore = seg["vpred"]
                        # standardized advantage function estimate
                        atarg = (atarg - atarg.mean()) / atarg.std()

                        # true_rew is the reward without discount
                        if writer is not None:
                            self.episode_reward = total_episode_reward_logger(
                                self.episode_reward,
                                seg["true_rew"].reshape((self.n_envs, -1)),
                                seg["dones"].reshape((self.n_envs, -1)),
                                writer, self.num_timesteps)

                        args = seg["ob"], seg["ob"], seg["ac"], atarg
                        fvpargs = [arr[::5] for arr in args]

                        self.assign_old_eq_new(sess=self.sess)

                        with self.timed("computegrad"):
                            steps = self.num_timesteps + (k + 1) * (
                                seg["total_timestep"] / self.g_step)
                            run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                            run_metadata = tf.RunMetadata() if self.full_tensorboard_log else None
                            # run loss backprop with summary, and save the metadata (memory, compute time, ...)
                            if writer is not None:
                                summary, grad, *lossbefore = self.compute_lossandgrad(
                                    *args,
                                    tdlamret,
                                    sess=self.sess,
                                    options=run_options,
                                    run_metadata=run_metadata)
                                if self.full_tensorboard_log:
                                    writer.add_run_metadata(
                                        run_metadata, 'step%d' % steps)
                                writer.add_summary(summary, steps)
                            else:
                                _, grad, *lossbefore = self.compute_lossandgrad(
                                    *args,
                                    tdlamret,
                                    sess=self.sess,
                                    options=run_options,
                                    run_metadata=run_metadata)

                        lossbefore = self.allmean(np.array(lossbefore))
                        grad = self.allmean(grad)
                        if np.allclose(grad, 0):
                            logger.log("Got zero gradient. not updating")
                        else:
                            with self.timed("conjugate_gradient"):
                                stepdir = conjugate_gradient(
                                    fisher_vector_product,
                                    grad,
                                    cg_iters=self.cg_iters,
                                    verbose=self.rank == 0
                                    and self.verbose >= 1)
                            assert np.isfinite(stepdir).all()
                            shs = .5 * stepdir.dot(
                                fisher_vector_product(stepdir))
                            # abs(shs) to avoid taking square root of negative values
                            lagrange_multiplier = np.sqrt(
                                abs(shs) / self.max_kl)
                            # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
                            fullstep = stepdir / lagrange_multiplier
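                            # With stepsize = 1.0 this step satisfies 0.5 * step^T F step = max_kl,
                            # i.e. it lies on the boundary of the trust region.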
                            expectedimprove = grad.dot(fullstep)
                            surrbefore = lossbefore[0]
                            stepsize = 1.0
                            thbefore = self.get_flat()
                            thnew = None
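                            # Backtracking line search: halve the step until the surrogate
                            # improves and the KL constraint is respected (at most 10 tries).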
                            for _ in range(10):
                                thnew = thbefore + fullstep * stepsize
                                self.set_from_flat(thnew)
                                mean_losses = surr, kl_loss, *_ = self.allmean(
                                    np.array(self.compute_losses(*args, sess=self.sess)))
                                improve = surr - surrbefore
                                logger.log("Expected: %.3f Actual: %.3f" %
                                           (expectedimprove, improve))
                                if not np.isfinite(mean_losses).all():
                                    logger.log("Got non-finite value of losses -- bad!")
                                elif kl_loss > self.max_kl * 1.5:
                                    logger.log("violated KL constraint. shrinking step.")
                                elif improve < 0:
                                    logger.log("surrogate didn't improve. shrinking step.")
                                else:
                                    logger.log("Stepsize OK!")
                                    break
                                stepsize *= .5
                            else:
                                logger.log("couldn't compute a good step")
                                self.set_from_flat(thbefore)
                            if self.nworkers > 1 and iters_so_far % 20 == 0:
                                # list of tuples
                                paramsums = MPI.COMM_WORLD.allgather(
                                    (thnew.sum(), self.vfadam.getflat().sum()))
                                assert all(
                                    np.allclose(ps, paramsums[0])
                                    for ps in paramsums[1:])

                        with self.timed("vf"):
                            for _ in range(self.vf_iters):
                                # NOTE: for recurrent policies, use shuffle=False?
                                for (mbob, mbret) in dataset.iterbatches(
                                    (seg["ob"], seg["tdlamret"]),
                                        include_final_partial_batch=False,
                                        batch_size=128,
                                        shuffle=True):
                                    grad = self.allmean(
                                        self.compute_vflossandgrad(
                                            mbob, mbob, mbret, sess=self.sess))
                                    self.vfadam.update(grad, self.vf_stepsize)

                    for (loss_name, loss_val) in zip(self.loss_names,
                                                     mean_losses):
                        logger.record_tabular(loss_name, loss_val)

                    logger.record_tabular(
                        "explained_variance_tdlam_before",
                        explained_variance(vpredbefore, tdlamret))

                    if self.using_gail:
                        # ------------------ Update D ------------------
                        logger.log("Optimizing Discriminator...")
                        logger.log(fmt_row(13, self.reward_giver.loss_name))
                        assert len(observation) == self.timesteps_per_batch
                        batch_size = self.timesteps_per_batch // self.d_step

                        # NOTE: uses only the last g step for observation
                        # list of tuples, each of which gives the loss for a minibatch
                        d_losses = []
                        # NOTE: for recurrent policies, use shuffle=False?
                        for ob_batch, ac_batch in dataset.iterbatches(
                            (observation, action),
                                include_final_partial_batch=False,
                                batch_size=batch_size,
                                shuffle=True):
                            ob_expert, ac_expert = self.expert_dataset.get_next_batch()
                            # update running mean/std for reward_giver
                            if self.reward_giver.normalize:
                                self.reward_giver.obs_rms.update(
                                    np.concatenate((ob_batch, ob_expert), 0))

                            # Reshape actions if needed when using discrete actions
                            if isinstance(self.action_space,
                                          gym.spaces.Discrete):
                                if len(ac_batch.shape) == 2:
                                    ac_batch = ac_batch[:, 0]
                                if len(ac_expert.shape) == 2:
                                    ac_expert = ac_expert[:, 0]
                            *newlosses, grad = self.reward_giver.lossandgrad(
                                ob_batch, ac_batch, ob_expert, ac_expert)
                            self.d_adam.update(self.allmean(grad),
                                               self.d_stepsize)
                            d_losses.append(newlosses)
                        logger.log(fmt_row(13, np.mean(d_losses, axis=0)))

                        # lr: lengths and rewards (local values)
                        lr_local = (seg["ep_lens"], seg["ep_rets"], seg["ep_true_rets"])
                        list_lr_pairs = MPI.COMM_WORLD.allgather(lr_local)  # list of tuples
                        lens, rews, true_rets = map(flatten_lists, zip(*list_lr_pairs))
                        true_reward_buffer.extend(true_rets)
                    else:
                        # lr: lengths and rewards (local values)
                        lr_local = (seg["ep_lens"], seg["ep_rets"])
                        list_lr_pairs = MPI.COMM_WORLD.allgather(lr_local)  # list of tuples
                        lens, rews = map(flatten_lists, zip(*list_lr_pairs))
                    len_buffer.extend(lens)
                    reward_buffer.extend(rews)

                    if len(len_buffer) > 0:
                        logger.record_tabular("EpLenMean", np.mean(len_buffer))
                        logger.record_tabular("EpRewMean",
                                              np.mean(reward_buffer))
                    if self.using_gail:
                        logger.record_tabular("EpTrueRewMean",
                                              np.mean(true_reward_buffer))
                    logger.record_tabular("EpThisIter", len(lens))
                    episodes_so_far += len(lens)
                    current_it_timesteps = MPI.COMM_WORLD.allreduce(
                        seg["total_timestep"])
                    timesteps_so_far += current_it_timesteps
                    self.num_timesteps += current_it_timesteps
                    iters_so_far += 1

                    logger.record_tabular("EpisodesSoFar", episodes_so_far)
                    logger.record_tabular("TimestepsSoFar", self.num_timesteps)
                    logger.record_tabular("TimeElapsed", time.time() - t_start)

                    if self.verbose >= 1 and self.rank == 0:
                        logger.dump_tabular()

        return self

    def save(self, save_path):
        if self.using_gail and self.expert_dataset is not None:
            # Exit processes to pickle the dataset
            self.expert_dataset.prepare_pickling()
        data = {
            "gamma": self.gamma,
            "timesteps_per_batch": self.timesteps_per_batch,
            "max_kl": self.max_kl,
            "cg_iters": self.cg_iters,
            "lam": self.lam,
            "entcoeff": self.entcoeff,
            "cg_damping": self.cg_damping,
            "vf_stepsize": self.vf_stepsize,
            "vf_iters": self.vf_iters,
            "hidden_size_adversary": self.hidden_size_adversary,
            "adversary_entcoeff": self.adversary_entcoeff,
            "expert_dataset": self.expert_dataset,
            "g_step": self.g_step,
            "d_step": self.d_step,
            "d_stepsize": self.d_stepsize,
            "using_gail": self.using_gail,
            "verbose": self.verbose,
            "policy": self.policy,
            "observation_space": self.observation_space,
            "action_space": self.action_space,
            "n_envs": self.n_envs,
            "_vectorize_action": self._vectorize_action,
            "policy_kwargs": self.policy_kwargs
        }

        params = self.sess.run(self.params)

        self._save_to_file(save_path, data=data, params=params)
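

# A minimal usage sketch (hypothetical, not part of the original class): it assumes
# the standard stable-baselines MlpPolicy import and a Gym-registered environment id.
if __name__ == "__main__":
    from stable_baselines.common.policies import MlpPolicy

    model = TRPO(MlpPolicy, "CartPole-v1", timesteps_per_batch=1024, verbose=1)
    model.learn(total_timesteps=25000)   # run a short training loop
    model.save("trpo_cartpole")          # serialize hyperparameters and weights
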
Example no. 2
def learn(
        env,
        policy_func,
        *,
        timesteps_per_batch,  # timesteps per actor per update
        clip_param,  # clipping parameter epsilon
        entcoeff,  # entropy coefficient
        optim_epochs,
        optim_stepsize,
        optim_batchsize,  # optimization hypers
        gamma,
        lam,  # advantage estimation
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,
        max_seconds=0,  # time constraint
        callback=None,  # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
        num_options=2,
        app='',
        saves=False,
        wsaves=False,
        epoch=-1,
        seed=1,
        dc=0):

    optim_batchsize_ideal = optim_batchsize
    np.random.seed(seed)
    tf.set_random_seed(seed)

    # env._seed(seed)

    gamename = env.spec.id[:-3].lower()
    gamename += 'seed' + str(seed)
    gamename += app

    dirname = '{}_{}opts_saves/'.format(gamename, num_options)

    if wsaves:
        first = True
        if not os.path.exists(dirname):
            os.makedirs(dirname)
            first = False
        # while os.path.exists(dirname) and first:
        #     dirname += '0'

        files = ['pposgd_simple.py', 'mlp_policy.py', 'run_main.py']
        for file_name in files:
            src = os.path.expanduser('~/baselines/baselines/ppo1/') + file_name
            dest = os.path.expanduser('~/baselines/baselines/ppo1/') + dirname
            shutil.copy2(src, dest)

    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space,
                     ac_space)  # Construct network for new policy
    oldpi = policy_func("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    # option = tf.placeholder(dtype=tf.int32, shape=[None])

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    # pdb.set_trace()
    ob = U.get_placeholder_cached(name="ob")
    option = U.get_placeholder_cached(name="option")
    term_adv = U.get_placeholder(name='term_adv',
                                 dtype=tf.float32,
                                 shape=[None])

    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param,
                             1.0 + clip_param) * atarg  #
    pol_surr = -tf.reduce_mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)

    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    term_loss = pi.tpred * term_adv

    log_pi = tf.log(tf.clip_by_value(pi.op_pi, 1e-20, 1.0))
    entropy = -tf.reduce_sum(pi.op_pi * log_pi, reduction_indices=1)
    op_loss = -tf.reduce_sum(log_pi[0][option[0]] * atarg + entropy * 0.1)
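    # Loss for the policy over options: log-probability of the active option
    # weighted by the advantage, plus a small (0.1) entropy bonus on op_pi.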

    total_loss += op_loss

    var_list = pi.get_trainable_variables()
    term_list = var_list[6:8]

    lossandgrad = U.function([ob, ac, atarg, ret, lrmult, option, term_adv],
                             losses + [U.flatgrad(total_loss, var_list)])
    termloss = U.function([ob, option, term_adv],
                          [U.flatgrad(term_loss, var_list)
                           ])  # Since we will use a different step size.
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])

    compute_losses = U.function([ob, ac, atarg, ret, lrmult, option], losses)
    U.initialize()
    adam.sync()
    saver = tf.train.Saver(max_to_keep=10000)

    results = []
    if saves:
        results = open(
            gamename + '_' + str(num_options) + 'opts_' + '_results.csv', 'w')

        out = 'epoch,avg_reward'

        for opt in range(num_options):
            out += ',option {} dur'.format(opt)
        for opt in range(num_options):
            out += ',option {} std'.format(opt)
        for opt in range(num_options):
            out += ',option {} term'.format(opt)
        for opt in range(num_options):
            out += ',option {} adv'.format(opt)

        out += '\n'

        results.write(out)

        # results.write('epoch,avg_reward,option 1 dur, option 2 dur, option 1 term, option 2 term\n')
        results.flush()

    if epoch >= 0:
        dirname = '{}_{}opts_saves/'.format(gamename, num_options)
        print("Loading weights from iteration: " + str(epoch))

        filename = dirname + '{}_epoch_{}.ckpt'.format(gamename, epoch)
        saver.restore(U.get_session(), filename)

    episodes_so_far = 0
    timesteps_so_far = 0
    global iters_so_far
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    assert sum(
        [max_iters > 0, max_timesteps > 0, max_episodes > 0,
         max_seconds > 0]) == 1, "Only one time constraint permitted"

    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_batch,
                                     stochastic=True,
                                     num_options=num_options,
                                     saves=saves,
                                     results=results,
                                     rewbuffer=rewbuffer,
                                     dc=dc)

    datas = [0 for _ in range(num_options)]

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        opt_d = []
        for i in range(num_options):
            dur = np.mean(seg['opt_dur'][i]) if len(seg['opt_dur'][i]) > 0 else 0.
            opt_d.append(dur)

        std = []
        for i in range(num_options):
            logstd = np.mean(seg['logstds'][i]) if len(seg['logstds'][i]) > 0 else 0.
            std.append(np.exp(logstd))

        print("mean opt dur:", opt_d)
        print("mean op pol:", np.mean(np.array(seg['optpol_p']), axis=0))
        print("mean term p:", np.mean(np.array(seg['term_p']), axis=0))
        print("mean value val:", np.mean(np.array(seg['value_val']), axis=0))
        ob, ac, opts = seg["ob"], seg["ac"], seg["opts"]
        atarg, tdlamret = seg["adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        # standardized advantage function estimate
        atarg = (atarg - atarg.mean()) / atarg.std()

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy
        assign_old_eq_new()  # set old parameter values to new parameter values

        if iters_so_far % 5 == 0 and wsaves:
            print("weights are saved...")
            filename = dirname + '{}_epoch_{}.ckpt'.format(
                gamename, iters_so_far)
            save_path = saver.save(U.get_session(), filename)

        min_batch = 160
        t_advs = [[] for _ in range(num_options)]
        for opt in range(num_options):
            indices = np.where(opts == opt)[0]
            print("batch size:", indices.size)
            opt_d[opt] = indices.size
            if not indices.size:
                t_advs[opt].append(0.)
                continue

            # This part is only necessary when we use options.
            # Trajectories for each option are merged across iterations so that no
            # collected data is discarded and every update sees at least min_batch samples.
            if datas[opt] != 0:
                if (indices.size < min_batch and datas[opt].n > min_batch):
                    datas[opt] = Dataset(dict(ob=ob[indices],
                                              ac=ac[indices],
                                              atarg=atarg[indices],
                                              vtarg=tdlamret[indices]),
                                         shuffle=not pi.recurrent)
                    t_advs[opt].append(0.)
                    continue

                elif indices.size + datas[opt].n < min_batch:
                    # pdb.set_trace()
                    oldmap = datas[opt].data_map

                    cat_ob = np.concatenate((oldmap['ob'], ob[indices]))
                    cat_ac = np.concatenate((oldmap['ac'], ac[indices]))
                    cat_atarg = np.concatenate(
                        (oldmap['atarg'], atarg[indices]))
                    cat_vtarg = np.concatenate(
                        (oldmap['vtarg'], tdlamret[indices]))
                    datas[opt] = Dataset(dict(ob=cat_ob,
                                              ac=cat_ac,
                                              atarg=cat_atarg,
                                              vtarg=cat_vtarg),
                                         shuffle=not pi.recurrent)
                    t_advs[opt].append(0.)
                    continue

                elif (indices.size + datas[opt].n > min_batch and datas[opt].n
                      < min_batch) or (indices.size > min_batch
                                       and datas[opt].n < min_batch):

                    oldmap = datas[opt].data_map
                    cat_ob = np.concatenate((oldmap['ob'], ob[indices]))
                    cat_ac = np.concatenate((oldmap['ac'], ac[indices]))
                    cat_atarg = np.concatenate(
                        (oldmap['atarg'], atarg[indices]))
                    cat_vtarg = np.concatenate(
                        (oldmap['vtarg'], tdlamret[indices]))
                    datas[opt] = d = Dataset(dict(ob=cat_ob,
                                                  ac=cat_ac,
                                                  atarg=cat_atarg,
                                                  vtarg=cat_vtarg),
                                             shuffle=not pi.recurrent)

                if (indices.size > min_batch and datas[opt].n > min_batch):
                    datas[opt] = d = Dataset(dict(ob=ob[indices],
                                                  ac=ac[indices],
                                                  atarg=atarg[indices],
                                                  vtarg=tdlamret[indices]),
                                             shuffle=not pi.recurrent)

            elif datas[opt] == 0:
                datas[opt] = d = Dataset(dict(ob=ob[indices],
                                              ac=ac[indices],
                                              atarg=atarg[indices],
                                              vtarg=tdlamret[indices]),
                                         shuffle=not pi.recurrent)

            optim_batchsize = optim_batchsize or ob.shape[0]
            optim_epochs = np.clip(
                np.int(10 * (indices.size /
                             (timesteps_per_batch / num_options))), 10,
                10) if num_options > 1 else optim_epochs
            print("optim epochs:", optim_epochs)
            logger.log("Optimizing...")

            # Here we do a bunch of optimization epochs over the data
            for _ in range(optim_epochs):
                # list of tuples, each of which gives the loss for a minibatch
                losses = []
                for batch in d.iterate_once(optim_batchsize):

                    tadv, nodc_adv = pi.get_term_adv(batch["ob"], [opt])
                    tadv = tadv if num_options > 1 else np.zeros_like(tadv)
                    t_advs[opt].append(nodc_adv)

                    *newlosses, grads = lossandgrad(batch["ob"], batch["ac"],
                                                    batch["atarg"],
                                                    batch["vtarg"], cur_lrmult,
                                                    [opt], tadv)
                    termg = termloss(batch["ob"], [opt], tadv)
                    adam.update(termg[0], 5e-7 * cur_lrmult)
                    adam.update(grads, optim_stepsize * cur_lrmult)
                    losses.append(newlosses)

        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()

        if saves:
            out = "{},{}"
            for _ in range(num_options):
                out += ",{},{},{},{}"
            out += "\n"

            info = [iters_so_far, np.mean(rewbuffer)]
            for i in range(num_options):
                info.append(opt_d[i])
            for i in range(num_options):
                info.append(std[i])
            for i in range(num_options):
                info.append(np.mean(np.array(seg['term_p']), axis=0)[i])
            for i in range(num_options):
                info.append(np.mean(t_advs[i]))

            results.write(out.format(*info))
            results.flush()
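

# A minimal sketch of the GAE(lambda) helper that both examples call but never define
# here. It is an assumption based on the standard baselines implementation and its
# usual segment keys ("rew", "new", "vpred", "nextvpred"); the real helper may differ.
def add_vtarg_and_adv_sketch(seg, gamma, lam):
    import numpy as np  # assumed available, as elsewhere in these examples
    new = np.append(seg["new"], 0)  # episode-start flags, padded for the t+1 lookup
    vpred = np.append(seg["vpred"], seg["nextvpred"])  # bootstrap with the last value
    rewards = seg["rew"]
    horizon = len(rewards)
    adv = np.zeros(horizon, "float32")
    lastgaelam = 0.0
    for step in reversed(range(horizon)):
        nonterminal = 1 - new[step + 1]
        delta = rewards[step] + gamma * vpred[step + 1] * nonterminal - vpred[step]
        adv[step] = lastgaelam = delta + gamma * lam * nonterminal * lastgaelam
    seg["adv"] = adv                      # GAE advantage estimates
    seg["tdlamret"] = adv + seg["vpred"]  # lambda-returns used as value targets
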
class PPO1(ActorCriticRLModel):
    """
    Proximal Policy Optimization algorithm (MPI version).
    Paper: https://arxiv.org/abs/1707.06347

    :param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str)
    :param policy: (ActorCriticPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, CnnLstmPolicy, ...)
    :param timesteps_per_actorbatch: (int) timesteps per actor per update
    :param clip_param: (float) clipping parameter epsilon
    :param entcoeff: (float) the entropy loss weight
    :param optim_epochs: (int) the optimizer's number of epochs
    :param optim_stepsize: (float) the optimizer's stepsize
    :param optim_batchsize: (int) the optimizer's batch size
    :param gamma: (float) discount factor
    :param lam: (float) advantage estimation factor (GAE lambda)
    :param adam_epsilon: (float) the epsilon value for the adam optimizer
    :param schedule: (str) The type of scheduler for the learning rate update ('linear', 'constant',
        'double_linear_con', 'middle_drop' or 'double_middle_drop')
    :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
    :param tensorboard_log: (str) the log location for tensorboard (if None, no logging)
    :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
    :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
    :param full_tensorboard_log: (bool) enable additional logging when using tensorboard
        WARNING: this logging can take a lot of space quickly
    :param seed: (int) Seed for the pseudo-random generators (python, numpy, tensorflow).
        If None (default), use random seed. Note that if you want completely deterministic
        results, you must set `n_cpu_tf_sess` to 1.
    :param n_cpu_tf_sess: (int) The number of threads for TensorFlow operations
        If None, the number of cpu of the current machine will be used.
    """
    def __init__(self,
                 policy,
                 env,
                 gamma=0.99,
                 timesteps_per_actorbatch=256,
                 clip_param=0.2,
                 entcoeff=0.01,
                 optim_epochs=4,
                 optim_stepsize=1e-3,
                 optim_batchsize=64,
                 lam=0.95,
                 adam_epsilon=1e-5,
                 schedule='linear',
                 verbose=0,
                 tensorboard_log=None,
                 _init_setup_model=True,
                 policy_kwargs=None,
                 full_tensorboard_log=False,
                 seed=None,
                 n_cpu_tf_sess=1):

        super().__init__(policy=policy,
                         env=env,
                         verbose=verbose,
                         requires_vec_env=False,
                         _init_setup_model=_init_setup_model,
                         policy_kwargs=policy_kwargs,
                         seed=seed,
                         n_cpu_tf_sess=n_cpu_tf_sess)

        self.gamma = gamma
        self.timesteps_per_actorbatch = timesteps_per_actorbatch
        self.clip_param = clip_param
        self.entcoeff = entcoeff
        self.optim_epochs = optim_epochs
        self.optim_stepsize = optim_stepsize
        self.optim_batchsize = optim_batchsize
        self.lam = lam
        self.adam_epsilon = adam_epsilon
        self.schedule = schedule
        self.tensorboard_log = tensorboard_log
        self.full_tensorboard_log = full_tensorboard_log

        self.graph = None
        self.sess = None
        self.policy_pi = None
        self.loss_names = None
        self.lossandgrad = None
        self.adam = None
        self.assign_old_eq_new = None
        self.compute_losses = None
        self.params = None
        self.step = None
        self.proba_step = None
        self.initial_state = None
        self.summary = None

        if _init_setup_model:
            self.setup_model()

    def _get_pretrain_placeholders(self):
        policy = self.policy_pi
        action_ph = policy.pdtype.sample_placeholder([None])
        if isinstance(self.action_space, gym.spaces.Discrete):
            return policy.obs_ph, action_ph, policy.policy
        return policy.obs_ph, action_ph, policy.deterministic_action

    def setup_model(self):
        with SetVerbosity(self.verbose):

            self.graph = tf.Graph()
            with self.graph.as_default():
                self.set_random_seed(self.seed)
                self.sess = tf_util.make_session(num_cpu=self.n_cpu_tf_sess,
                                                 graph=self.graph)

                # Construct network for new policy
                self.policy_pi = self.policy(self.sess,
                                             self.observation_space,
                                             self.action_space,
                                             self.n_envs,
                                             1,
                                             None,
                                             reuse=False,
                                             **self.policy_kwargs)

                # Network for old policy
                with tf.compat.v1.variable_scope("oldpi", reuse=False):
                    old_pi = self.policy(self.sess,
                                         self.observation_space,
                                         self.action_space,
                                         self.n_envs,
                                         1,
                                         None,
                                         reuse=False,
                                         **self.policy_kwargs)

                with tf.compat.v1.variable_scope("loss", reuse=False):
                    # Target advantage function (if applicable)
                    atarg = tf.compat.v1.placeholder(dtype=tf.float32,
                                                     shape=[None])

                    # Empirical return
                    ret = tf.compat.v1.placeholder(dtype=tf.float32,
                                                   shape=[None])

                    # learning rate multiplier, updated with schedule
                    lrmult = tf.compat.v1.placeholder(name='lrmult',
                                                      dtype=tf.float32,
                                                      shape=[])

                    # Annealed clipping parameter epsilon
                    clip_param = self.clip_param * lrmult

                    obs_ph = self.policy_pi.obs_ph
                    action_ph = self.policy_pi.pdtype.sample_placeholder(
                        [None])

                    kloldnew = old_pi.proba_distribution.kl(
                        self.policy_pi.proba_distribution)
                    ent = self.policy_pi.proba_distribution.entropy()
                    meankl = tf.reduce_mean(input_tensor=kloldnew)
                    meanent = tf.reduce_mean(input_tensor=ent)
                    pol_entpen = (-self.entcoeff) * meanent

                    # pnew / pold
                    ratio = tf.exp(
                        self.policy_pi.proba_distribution.logp(action_ph) -
                        old_pi.proba_distribution.logp(action_ph))

                    # surrogate from conservative policy iteration
                    surr1 = ratio * atarg
                    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param,
                                             1.0 + clip_param) * atarg

                    # PPO's pessimistic surrogate (L^CLIP)
                    pol_surr = -tf.reduce_mean(
                        input_tensor=tf.minimum(surr1, surr2))
                    vf_loss = tf.reduce_mean(
                        input_tensor=tf.square(self.policy_pi.value_flat -
                                               ret))
                    total_loss = pol_surr + pol_entpen + vf_loss
                    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
                    self.loss_names = [
                        "pol_surr", "pol_entpen", "vf_loss", "kl", "ent"
                    ]

                    tf.compat.v1.summary.scalar('entropy_loss', pol_entpen)
                    tf.compat.v1.summary.scalar('policy_gradient_loss',
                                                pol_surr)
                    tf.compat.v1.summary.scalar('value_function_loss', vf_loss)
                    tf.compat.v1.summary.scalar('approximate_kullback-leibler',
                                                meankl)
                    tf.compat.v1.summary.scalar('clip_factor', clip_param)
                    tf.compat.v1.summary.scalar('loss', total_loss)

                    self.params = tf_util.get_trainable_vars("model")

                    self.assign_old_eq_new = tf_util.function(
                        [], [],
                        updates=[
                            tf.compat.v1.assign(oldv, newv)
                            for (oldv, newv) in zipsame(
                                tf_util.get_globals_vars("oldpi"),
                                tf_util.get_globals_vars("model"))
                        ])

                with tf.compat.v1.variable_scope("Adam_mpi", reuse=False):
                    self.adam = MpiAdam(self.params,
                                        epsilon=self.adam_epsilon,
                                        sess=self.sess)

                with tf.compat.v1.variable_scope("input_info", reuse=False):
                    tf.compat.v1.summary.scalar(
                        'discounted_rewards', tf.reduce_mean(input_tensor=ret))
                    tf.compat.v1.summary.scalar(
                        'learning_rate',
                        tf.reduce_mean(input_tensor=self.optim_stepsize))
                    tf.compat.v1.summary.scalar(
                        'advantage', tf.reduce_mean(input_tensor=atarg))
                    tf.compat.v1.summary.scalar(
                        'clip_range',
                        tf.reduce_mean(input_tensor=self.clip_param))

                    if self.full_tensorboard_log:
                        tf.compat.v1.summary.histogram('discounted_rewards',
                                                       ret)
                        tf.compat.v1.summary.histogram('learning_rate',
                                                       self.optim_stepsize)
                        tf.compat.v1.summary.histogram('advantage', atarg)
                        tf.compat.v1.summary.histogram('clip_range',
                                                       self.clip_param)
                        if tf_util.is_image(self.observation_space):
                            tf.compat.v1.summary.image('observation', obs_ph)
                        else:
                            tf.compat.v1.summary.histogram(
                                'observation', obs_ph)

                self.step = self.policy_pi.step
                self.proba_step = self.policy_pi.proba_step
                self.initial_state = self.policy_pi.initial_state

                tf_util.initialize(sess=self.sess)

                self.summary = tf.compat.v1.summary.merge_all()

                self.lossandgrad = tf_util.function(
                    [obs_ph, old_pi.obs_ph, action_ph, atarg, ret, lrmult],
                    [self.summary,
                     tf_util.flatgrad(total_loss, self.params)] + losses)
                self.compute_losses = tf_util.function(
                    [obs_ph, old_pi.obs_ph, action_ph, atarg, ret, lrmult],
                    losses)

    def learn(self,
              total_timesteps,
              callback=None,
              log_interval=100,
              tb_log_name="PPO1",
              reset_num_timesteps=True):

        new_tb_log = self._init_num_timesteps(reset_num_timesteps)
        callback = self._init_callback(callback)

        with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name, new_tb_log) \
                as writer:
            self._setup_learn()

            assert issubclass(self.policy, ActorCriticPolicy), "Error: the input policy for the PPO1 model must be " \
                                                               "an instance of common.policies.ActorCriticPolicy."

            with self.sess.as_default():
                self.adam.sync()
                callback.on_training_start(locals(), globals())

                # Prepare for rollouts
                seg_gen = traj_segment_generator(self.policy_pi,
                                                 self.env,
                                                 self.timesteps_per_actorbatch,
                                                 callback=callback)

                episodes_so_far = 0
                timesteps_so_far = 0
                iters_so_far = 0
                t_start = time.time()

                # rolling buffer for episode lengths
                len_buffer = deque(maxlen=100)
                # rolling buffer for episode rewards
                reward_buffer = deque(maxlen=100)

                while True:
                    if timesteps_so_far >= total_timesteps:
                        break

                    if self.schedule == 'constant':
                        cur_lrmult = 1.0
                    elif self.schedule == 'linear':
                        cur_lrmult = max(
                            1.0 - float(timesteps_so_far) / total_timesteps, 0)
                    else:
                        raise NotImplementedError
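                    # cur_lrmult anneals from 1 to 0 over training under the 'linear'
                    # schedule; it is fed to the loss graph through the `lrmult` input
                    # and also scales the Adam step size below
                    # (self.optim_stepsize * cur_lrmult).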

                    logger.log("********** Iteration %i ************" %
                               iters_so_far)

                    seg = seg_gen.__next__()

                    # Stop training early (triggered by the callback)
                    if not seg.get('continue_training', True):  # pytype: disable=attribute-error
                        break

                    add_vtarg_and_adv(seg, self.gamma, self.lam)
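                    # add_vtarg_and_adv augments the segment in place with GAE(lambda)
                    # advantage estimates (seg["adv"]) and the corresponding TD(lambda)
                    # value targets (seg["tdlamret"]), using self.gamma and self.lam.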

                    # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
                    observations, actions = seg["observations"], seg["actions"]
                    atarg, tdlamret = seg["adv"], seg["tdlamret"]

                    # true_rew is the reward without discount
                    if writer is not None:
                        total_episode_reward_logger(
                            self.episode_reward, seg["true_rewards"].reshape(
                                (self.n_envs, -1)), seg["dones"].reshape(
                                    (self.n_envs, -1)), writer,
                            self.num_timesteps)

                    # predicted value function before update
                    vpredbefore = seg["vpred"]

                    # standardized advantage function estimate
                    atarg = (atarg - atarg.mean()) / atarg.std()
                    dataset = Dataset(dict(ob=observations,
                                           ac=actions,
                                           atarg=atarg,
                                           vtarg=tdlamret),
                                      shuffle=not self.policy.recurrent)
                    optim_batchsize = self.optim_batchsize or observations.shape[
                        0]

                    # set old parameter values to new parameter values
                    self.assign_old_eq_new(sess=self.sess)
                    logger.log("Optimizing...")
                    logger.log(fmt_row(13, self.loss_names))

                    # Here we do a bunch of optimization epochs over the data
                    for k in range(self.optim_epochs):
                        # list of tuples, each of which gives the loss for a minibatch
                        losses = []
                        for i, batch in enumerate(
                                dataset.iterate_once(optim_batchsize)):
                            steps = (
                                self.num_timesteps + k * optim_batchsize +
                                int(i *
                                    (optim_batchsize / len(dataset.data_map))))
                            if writer is not None:
                                # run loss backprop with summary, but once every 10 runs save the metadata
                                # (memory, compute time, ...)
                                if self.full_tensorboard_log and (1 +
                                                                  k) % 10 == 0:
                                    run_options = tf.compat.v1.RunOptions(
                                        trace_level=tf.compat.v1.RunOptions.
                                        FULL_TRACE)
                                    run_metadata = tf.compat.v1.RunMetadata()
                                    summary, grad, *newlosses = self.lossandgrad(
                                        batch["ob"],
                                        batch["ob"],
                                        batch["ac"],
                                        batch["atarg"],
                                        batch["vtarg"],
                                        cur_lrmult,
                                        sess=self.sess,
                                        options=run_options,
                                        run_metadata=run_metadata)
                                    writer.add_run_metadata(
                                        run_metadata, 'step%d' % steps)
                                else:
                                    summary, grad, *newlosses = self.lossandgrad(
                                        batch["ob"],
                                        batch["ob"],
                                        batch["ac"],
                                        batch["atarg"],
                                        batch["vtarg"],
                                        cur_lrmult,
                                        sess=self.sess)
                                writer.add_summary(summary, steps)
                            else:
                                _, grad, *newlosses = self.lossandgrad(
                                    batch["ob"],
                                    batch["ob"],
                                    batch["ac"],
                                    batch["atarg"],
                                    batch["vtarg"],
                                    cur_lrmult,
                                    sess=self.sess)

                            self.adam.update(grad,
                                             self.optim_stepsize * cur_lrmult)
                            losses.append(newlosses)
                        logger.log(fmt_row(13, np.mean(losses, axis=0)))

                    logger.log("Evaluating losses...")
                    losses = []
                    for batch in dataset.iterate_once(optim_batchsize):
                        newlosses = self.compute_losses(batch["ob"],
                                                        batch["ob"],
                                                        batch["ac"],
                                                        batch["atarg"],
                                                        batch["vtarg"],
                                                        cur_lrmult,
                                                        sess=self.sess)
                        losses.append(newlosses)
                    mean_losses, _, _ = mpi_moments(losses, axis=0)
                    logger.log(fmt_row(13, mean_losses))
                    for (loss_val, name) in zipsame(mean_losses,
                                                    self.loss_names):
                        logger.record_tabular("loss_" + name, loss_val)
                    logger.record_tabular(
                        "ev_tdlam_before",
                        explained_variance(vpredbefore, tdlamret))

                    # local values
                    lrlocal = (seg["ep_lens"], seg["ep_rets"])

                    # list of tuples
                    listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)
                    lens, rews = map(flatten_lists, zip(*listoflrpairs))
                    len_buffer.extend(lens)
                    reward_buffer.extend(rews)
                    if len(len_buffer) > 0:
                        logger.record_tabular("EpLenMean", np.mean(len_buffer))
                        logger.record_tabular("EpRewMean",
                                              np.mean(reward_buffer))
                    logger.record_tabular("EpThisIter", len(lens))
                    episodes_so_far += len(lens)
                    current_it_timesteps = MPI.COMM_WORLD.allreduce(
                        seg["total_timestep"])
                    timesteps_so_far += current_it_timesteps
                    self.num_timesteps += current_it_timesteps
                    iters_so_far += 1
                    logger.record_tabular("EpisodesSoFar", episodes_so_far)
                    logger.record_tabular("TimestepsSoFar", self.num_timesteps)
                    logger.record_tabular("TimeElapsed", time.time() - t_start)
                    if self.verbose >= 1 and MPI.COMM_WORLD.Get_rank() == 0:
                        logger.dump_tabular()
        callback.on_training_end()
        return self

    def save(self, save_path, cloudpickle=False):
        data = {
            "gamma": self.gamma,
            "timesteps_per_actorbatch": self.timesteps_per_actorbatch,
            "clip_param": self.clip_param,
            "entcoeff": self.entcoeff,
            "optim_epochs": self.optim_epochs,
            "optim_stepsize": self.optim_stepsize,
            "optim_batchsize": self.optim_batchsize,
            "lam": self.lam,
            "adam_epsilon": self.adam_epsilon,
            "schedule": self.schedule,
            "verbose": self.verbose,
            "policy": self.policy,
            "observation_space": self.observation_space,
            "action_space": self.action_space,
            "n_envs": self.n_envs,
            "n_cpu_tf_sess": self.n_cpu_tf_sess,
            "seed": self.seed,
            "_vectorize_action": self._vectorize_action,
            "policy_kwargs": self.policy_kwargs
        }

        params_to_save = self.get_parameters()

        self._save_to_file(save_path,
                           data=data,
                           params=params_to_save,
                           cloudpickle=cloudpickle)
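
# Usage sketch (not part of the original snippet): a minimal, hedged example of
# training the PPO1 model defined above. The import paths and the "CartPole-v1"
# environment id are assumptions based on the stable-baselines 2.x package
# layout; adjust them if this PPO1 class lives in a different module.
from stable_baselines.common.policies import MlpPolicy
from stable_baselines import PPO1

model = PPO1(MlpPolicy, "CartPole-v1", verbose=1)
model.learn(total_timesteps=25000)
model.save("ppo1_cartpole")
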
Example no. 4
class DDPG(OffPolicyRLModel):
    """
    Deep Deterministic Policy Gradient (DDPG) model

    DDPG: https://arxiv.org/pdf/1509.02971.pdf

    :param policy: (DDPGPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, LnMlpPolicy, ...)
    :param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str)
    :param gamma: (float) the discount factor
    :param memory_policy: (Memory) the replay buffer (if None, default to baselines.ddpg.memory.Memory)
    :param eval_env: (Gym Environment) the evaluation environment (can be None)
    :param nb_train_steps: (int) the number of training steps
    :param nb_rollout_steps: (int) the number of rollout steps
    :param nb_eval_steps: (int) the number of evaluation steps
    :param param_noise: (AdaptiveParamNoiseSpec) the parameter noise type (can be None)
    :param action_noise: (ActionNoise) the action noise type (can be None)
    :param param_noise_adaption_interval: (int) apply param noise every N steps
    :param tau: (float) the soft update coefficient (keep old values, between 0 and 1)
    :param normalize_returns: (bool) should the critic output be normalized
    :param enable_popart: (bool) enable pop-art normalization of the critic output
        (https://arxiv.org/pdf/1602.07714.pdf)
    :param normalize_observations: (bool) should the observation be normalized
    :param batch_size: (int) the size of the batch for learning the policy
    :param observation_range: (tuple) the bounding values for the observation
    :param return_range: (tuple) the bounding values for the critic output
    :param critic_l2_reg: (float) l2 regularizer coefficient
    :param actor_lr: (float) the actor learning rate
    :param critic_lr: (float) the critic learning rate
    :param clip_norm: (float) clip the gradients (disabled if None)
    :param reward_scale: (float) the value the reward should be scaled by
    :param render: (bool) enable rendering of the environment
    :param render_eval: (bool) enable rendering of the evaluation environment
    :param memory_limit: (int) the max number of transitions to store
    :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
    :param tensorboard_log: (str) the log location for tensorboard (if None, no logging)
    :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
    """

    def __init__(self, policy, env, gamma=0.99, memory_policy=None, eval_env=None, nb_train_steps=50,
                 nb_rollout_steps=100, nb_eval_steps=100, param_noise=None, action_noise=None,
                 normalize_observations=False, tau=0.001, batch_size=128, param_noise_adaption_interval=50,
                 normalize_returns=False, enable_popart=False, observation_range=(-5., 5.), critic_l2_reg=0.,
                 return_range=(-np.inf, np.inf), actor_lr=1e-4, critic_lr=1e-3, clip_norm=None, reward_scale=1.,
                 render=False, render_eval=False, memory_limit=100, verbose=0, tensorboard_log=None,
                 _init_setup_model=True):

        # TODO: replay_buffer refactoring
        super(DDPG, self).__init__(policy=policy, env=env, replay_buffer=None, verbose=verbose, policy_base=DDPGPolicy,
                                   requires_vec_env=False)

        # Parameters.
        self.gamma = gamma
        self.tau = tau
        self.memory_policy = memory_policy or Memory
        self.normalize_observations = normalize_observations
        self.normalize_returns = normalize_returns
        self.action_noise = action_noise
        self.param_noise = param_noise
        self.return_range = return_range
        self.observation_range = observation_range
        self.actor_lr = actor_lr
        self.critic_lr = critic_lr
        self.clip_norm = clip_norm
        self.enable_popart = enable_popart
        self.reward_scale = reward_scale
        self.batch_size = batch_size
        self.critic_l2_reg = critic_l2_reg
        self.eval_env = eval_env
        self.render = render
        self.render_eval = render_eval
        self.nb_eval_steps = nb_eval_steps
        self.param_noise_adaption_interval = param_noise_adaption_interval
        self.nb_train_steps = nb_train_steps
        self.nb_rollout_steps = nb_rollout_steps
        self.memory_limit = memory_limit
        self.tensorboard_log = tensorboard_log

        # init
        self.graph = None
        self.stats_sample = None
        self.memory = None
        self.policy_tf = None
        self.target_init_updates = None
        self.target_soft_updates = None
        self.critic_loss = None
        self.critic_grads = None
        self.critic_optimizer = None
        self.sess = None
        self.stats_ops = None
        self.stats_names = None
        self.perturbed_actor_tf = None
        self.perturb_policy_ops = None
        self.perturb_adaptive_policy_ops = None
        self.adaptive_policy_distance = None
        self.actor_loss = None
        self.actor_grads = None
        self.actor_optimizer = None
        self.old_std = None
        self.old_mean = None
        self.renormalize_q_outputs_op = None
        self.obs_rms = None
        self.ret_rms = None
        self.target_policy = None
        self.actor_tf = None
        self.normalized_critic_tf = None
        self.critic_tf = None
        self.normalized_critic_with_actor_tf = None
        self.critic_with_actor_tf = None
        self.target_q = None
        self.obs_train = None
        self.action_train_ph = None
        self.obs_target = None
        self.action_target = None
        self.obs_noise = None
        self.action_noise_ph = None
        self.obs_adapt_noise = None
        self.action_adapt_noise = None
        self.terminals1 = None
        self.rewards = None
        self.actions = None
        self.critic_target = None
        self.param_noise_stddev = None
        self.param_noise_actor = None
        self.adaptive_param_noise_actor = None
        self.params = None
        self.summary = None
        self.episode_reward = None
        self.tb_seen_steps = None
        self.target_params = None

        if _init_setup_model:
            self.setup_model()

    def setup_model(self):
        with SetVerbosity(self.verbose):

            assert isinstance(self.action_space, gym.spaces.Box), \
                "Error: DDPG cannot output a {} action space, only spaces.Box is supported.".format(self.action_space)
            assert issubclass(self.policy, DDPGPolicy), "Error: the input policy for the DDPG model must be " \
                                                        "an instance of DDPGPolicy."

            self.graph = tf.Graph()
            with self.graph.as_default():
                self.sess = tf_util.single_threaded_session(graph=self.graph)

                self.memory = self.memory_policy(limit=self.memory_limit, action_shape=self.action_space.shape,
                                                 observation_shape=self.observation_space.shape)

                with tf.variable_scope("input", reuse=False):
                    # Observation normalization.
                    if self.normalize_observations:
                        with tf.variable_scope('obs_rms'):
                            self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)
                    else:
                        self.obs_rms = None

                    # Return normalization.
                    if self.normalize_returns:
                        with tf.variable_scope('ret_rms'):
                            self.ret_rms = RunningMeanStd()
                    else:
                        self.ret_rms = None

                    self.policy_tf = self.policy(self.sess, self.observation_space, self.action_space, 1, 1, None)

                    # Create target networks.
                    self.target_policy = self.policy(self.sess, self.observation_space, self.action_space, 1, 1, None)
                    self.obs_target = self.target_policy.obs_ph
                    self.action_target = self.target_policy.action_ph

                    normalized_obs0 = tf.clip_by_value(normalize(self.policy_tf.processed_obs, self.obs_rms),
                                                       self.observation_range[0], self.observation_range[1])
                    normalized_obs1 = tf.clip_by_value(normalize(self.target_policy.processed_obs, self.obs_rms),
                                                       self.observation_range[0], self.observation_range[1])

                    if self.param_noise is not None:
                        # Configure perturbed actor.
                        self.param_noise_actor = self.policy(self.sess, self.observation_space, self.action_space, 1, 1,
                                                             None)
                        self.obs_noise = self.param_noise_actor.obs_ph
                        self.action_noise_ph = self.param_noise_actor.action_ph

                        # Configure separate copy for stddev adaptation.
                        self.adaptive_param_noise_actor = self.policy(self.sess, self.observation_space,
                                                                      self.action_space, 1, 1, None)
                        self.obs_adapt_noise = self.adaptive_param_noise_actor.obs_ph
                        self.action_adapt_noise = self.adaptive_param_noise_actor.action_ph

                    # Inputs.
                    self.obs_train = self.policy_tf.obs_ph
                    self.action_train_ph = self.policy_tf.action_ph
                    self.terminals1 = tf.placeholder(tf.float32, shape=(None, 1), name='terminals1')
                    self.rewards = tf.placeholder(tf.float32, shape=(None, 1), name='rewards')
                    self.actions = tf.placeholder(tf.float32, shape=(None,) + self.action_space.shape, name='actions')
                    self.critic_target = tf.placeholder(tf.float32, shape=(None, 1), name='critic_target')
                    self.param_noise_stddev = tf.placeholder(tf.float32, shape=(), name='param_noise_stddev')

                # Create networks and core TF parts that are shared across setup parts.
                with tf.variable_scope("model", reuse=False):
                    self.actor_tf = self.policy_tf.make_actor(normalized_obs0)
                    self.normalized_critic_tf = self.policy_tf.make_critic(normalized_obs0, self.actions)
                    self.normalized_critic_with_actor_tf = self.policy_tf.make_critic(normalized_obs0,
                                                                                      self.actor_tf,
                                                                                      reuse=True)
                # Noise setup
                if self.param_noise is not None:
                    self._setup_param_noise(normalized_obs0)

                with tf.variable_scope("target", reuse=False):
                    critic_target = self.target_policy.make_critic(normalized_obs1,
                                                                   self.target_policy.make_actor(normalized_obs1))

                with tf.variable_scope("loss", reuse=False):
                    self.critic_tf = denormalize(
                        tf.clip_by_value(self.normalized_critic_tf, self.return_range[0], self.return_range[1]),
                        self.ret_rms)

                    self.critic_with_actor_tf = denormalize(
                        tf.clip_by_value(self.normalized_critic_with_actor_tf,
                                         self.return_range[0], self.return_range[1]),
                        self.ret_rms)

                    q_obs1 = denormalize(critic_target, self.ret_rms)
                    self.target_q = self.rewards + (1. - self.terminals1) * self.gamma * q_obs1
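                    # 1-step Bellman target: r + gamma * Q_target(s', pi_target(s')),
                    # with the bootstrap term masked out on terminal transitions via
                    # (1. - terminals1).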

                    tf.summary.scalar('critic_target', tf.reduce_mean(self.critic_target))
                    tf.summary.histogram('critic_target', self.critic_target)

                    # Set up parts.
                    if self.normalize_returns and self.enable_popart:
                        self._setup_popart()
                    self._setup_stats()
                    self._setup_target_network_updates()

                with tf.variable_scope("input_info", reuse=False):
                    tf.summary.scalar('rewards', tf.reduce_mean(self.rewards))
                    tf.summary.histogram('rewards', self.rewards)
                    tf.summary.scalar('param_noise_stddev', tf.reduce_mean(self.param_noise_stddev))
                    tf.summary.histogram('param_noise_stddev', self.param_noise_stddev)
                    if len(self.observation_space.shape) == 3 and self.observation_space.shape[0] in [1, 3, 4]:
                        tf.summary.image('observation', self.obs_train)
                    else:
                        tf.summary.histogram('observation', self.obs_train)

                with tf.variable_scope("Adam_mpi", reuse=False):
                    self._setup_actor_optimizer()
                    self._setup_critic_optimizer()
                    tf.summary.scalar('actor_loss', self.actor_loss)
                    tf.summary.scalar('critic_loss', self.critic_loss)

                self.params = find_trainable_variables("model")
                self.target_params = find_trainable_variables("target")

                with self.sess.as_default():
                    self._initialize(self.sess)

                self.summary = tf.summary.merge_all()

    def _setup_target_network_updates(self):
        """
        set the target update operations
        """
        init_updates, soft_updates = get_target_updates(tf_util.get_trainable_vars('model/'),
                                                        tf_util.get_trainable_vars('target/'), self.tau,
                                                        self.verbose)
        self.target_init_updates = init_updates
        self.target_soft_updates = soft_updates

    def _setup_param_noise(self, normalized_obs0):
        """
        set the parameter noise operations

        :param normalized_obs0: (TensorFlow Tensor) the normalized observation
        """
        assert self.param_noise is not None

        with tf.variable_scope("noise", reuse=False):
            self.perturbed_actor_tf = self.param_noise_actor.make_actor(normalized_obs0)

        with tf.variable_scope("noise_adapt", reuse=False):
            adaptive_actor_tf = self.adaptive_param_noise_actor.make_actor(normalized_obs0)

        with tf.variable_scope("noise_update_func", reuse=False):
            if self.verbose >= 2:
                logger.info('setting up param noise')
            self.perturb_policy_ops = get_perturbed_actor_updates('model/pi/', 'noise/pi/', self.param_noise_stddev,
                                                                  verbose=self.verbose)

            self.perturb_adaptive_policy_ops = get_perturbed_actor_updates('model/pi/', 'noise_adapt/pi/',
                                                                           self.param_noise_stddev,
                                                                           verbose=self.verbose)
            self.adaptive_policy_distance = tf.sqrt(tf.reduce_mean(tf.square(self.actor_tf - adaptive_actor_tf)))

    def _setup_actor_optimizer(self):
        """
        setup the optimizer for the actor
        """
        if self.verbose >= 2:
            logger.info('setting up actor optimizer')
        self.actor_loss = -tf.reduce_mean(self.critic_with_actor_tf)
        actor_shapes = [var.get_shape().as_list() for var in tf_util.get_trainable_vars('model/pi/')]
        actor_nb_params = sum([reduce(lambda x, y: x * y, shape) for shape in actor_shapes])
        if self.verbose >= 2:
            logger.info('  actor shapes: {}'.format(actor_shapes))
            logger.info('  actor params: {}'.format(actor_nb_params))
        self.actor_grads = tf_util.flatgrad(self.actor_loss, tf_util.get_trainable_vars('model/pi/'),
                                            clip_norm=self.clip_norm)
        self.actor_optimizer = MpiAdam(var_list=tf_util.get_trainable_vars('model/pi/'), beta1=0.9, beta2=0.999,
                                       epsilon=1e-08)

    def _setup_critic_optimizer(self):
        """
        setup the optimizer for the critic
        """
        if self.verbose >= 2:
            logger.info('setting up critic optimizer')
        normalized_critic_target_tf = tf.clip_by_value(normalize(self.critic_target, self.ret_rms),
                                                       self.return_range[0], self.return_range[1])
        self.critic_loss = tf.reduce_mean(tf.square(self.normalized_critic_tf - normalized_critic_target_tf))
        if self.critic_l2_reg > 0.:
            critic_reg_vars = [var for var in tf_util.get_trainable_vars('model/qf/')
                               if 'bias' not in var.name and 'output' not in var.name and 'b' not in var.name]
            if self.verbose >= 2:
                for var in critic_reg_vars:
                    logger.info('  regularizing: {}'.format(var.name))
                logger.info('  applying l2 regularization with {}'.format(self.critic_l2_reg))
            critic_reg = tc.layers.apply_regularization(
                tc.layers.l2_regularizer(self.critic_l2_reg),
                weights_list=critic_reg_vars
            )
            self.critic_loss += critic_reg
        critic_shapes = [var.get_shape().as_list() for var in tf_util.get_trainable_vars('model/qf/')]
        critic_nb_params = sum([reduce(lambda x, y: x * y, shape) for shape in critic_shapes])
        if self.verbose >= 2:
            logger.info('  critic shapes: {}'.format(critic_shapes))
            logger.info('  critic params: {}'.format(critic_nb_params))
        self.critic_grads = tf_util.flatgrad(self.critic_loss, tf_util.get_trainable_vars('model/qf/'),
                                             clip_norm=self.clip_norm)
        self.critic_optimizer = MpiAdam(var_list=tf_util.get_trainable_vars('model/qf/'), beta1=0.9, beta2=0.999,
                                        epsilon=1e-08)

    def _setup_popart(self):
        """
        setup pop-art normalization of the critic output

        See https://arxiv.org/pdf/1602.07714.pdf for details.
        "Preserving Outputs Precisely, while Adaptively Rescaling Targets".
        """
        self.old_std = tf.placeholder(tf.float32, shape=[1], name='old_std')
        new_std = self.ret_rms.std
        self.old_mean = tf.placeholder(tf.float32, shape=[1], name='old_mean')
        new_mean = self.ret_rms.mean

        self.renormalize_q_outputs_op = []
        for out_vars in [[var for var in tf_util.get_trainable_vars('model/qf/') if 'output' in var.name],
                         [var for var in tf_util.get_trainable_vars('target/qf/') if 'output' in var.name]]:
            assert len(out_vars) == 2
            # weight and bias of the last layer
            weight, bias = out_vars
            assert 'kernel' in weight.name
            assert 'bias' in bias.name
            assert weight.get_shape()[-1] == 1
            assert bias.get_shape()[-1] == 1
            self.renormalize_q_outputs_op += [weight.assign(weight * self.old_std / new_std)]
            self.renormalize_q_outputs_op += [bias.assign((bias * self.old_std + self.old_mean - new_mean) / new_std)]
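            # Why this preserves outputs: the denormalized critic is
            # std * (w^T x + b) + mean. With w' = w * old_std / new_std and
            # b' = (b * old_std + old_mean - new_mean) / new_std we get
            # new_std * (w'^T x + b') + new_mean == old_std * (w^T x + b) + old_mean,
            # so re-normalizing the targets does not change the predicted Q values.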

    def _setup_stats(self):
        """
        setup the running means and std of the inputs and outputs of the model
        """
        ops = []
        names = []

        if self.normalize_returns:
            ops += [self.ret_rms.mean, self.ret_rms.std]
            names += ['ret_rms_mean', 'ret_rms_std']

        if self.normalize_observations:
            ops += [tf.reduce_mean(self.obs_rms.mean), tf.reduce_mean(self.obs_rms.std)]
            names += ['obs_rms_mean', 'obs_rms_std']

        ops += [tf.reduce_mean(self.critic_tf)]
        names += ['reference_Q_mean']
        ops += [reduce_std(self.critic_tf)]
        names += ['reference_Q_std']

        ops += [tf.reduce_mean(self.critic_with_actor_tf)]
        names += ['reference_actor_Q_mean']
        ops += [reduce_std(self.critic_with_actor_tf)]
        names += ['reference_actor_Q_std']

        ops += [tf.reduce_mean(self.actor_tf)]
        names += ['reference_action_mean']
        ops += [reduce_std(self.actor_tf)]
        names += ['reference_action_std']

        if self.param_noise:
            ops += [tf.reduce_mean(self.perturbed_actor_tf)]
            names += ['reference_perturbed_action_mean']
            ops += [reduce_std(self.perturbed_actor_tf)]
            names += ['reference_perturbed_action_std']

        self.stats_ops = ops
        self.stats_names = names

    def _policy(self, obs, apply_noise=True, compute_q=True):
        """
        Get the actions and critic output, from a given observation

        :param obs: ([float] or [int]) the observation
        :param apply_noise: (bool) enable the noise
        :param compute_q: (bool) compute the critic output
        :return: ([float], float) the action and critic value
        """
        obs = np.array(obs).reshape((-1,) + self.observation_space.shape)
        feed_dict = {self.obs_train: obs}
        if self.param_noise is not None and apply_noise:
            actor_tf = self.perturbed_actor_tf
            feed_dict[self.obs_noise] = obs
        else:
            actor_tf = self.actor_tf

        if compute_q:
            action, q_value = self.sess.run([actor_tf, self.critic_with_actor_tf], feed_dict=feed_dict)
        else:
            action = self.sess.run(actor_tf, feed_dict=feed_dict)
            q_value = None

        action = action.flatten()
        if self.action_noise is not None and apply_noise:
            noise = self.action_noise()
            assert noise.shape == action.shape
            action += noise
        action = np.clip(action, -1, 1)
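        # the policy outputs actions in [-1, 1]; they are rescaled by
        # np.abs(self.action_space.low) in learn() / predict() before being
        # applied, relying on the symmetric-action assertion in learn().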
        return action, q_value

    def _store_transition(self, obs0, action, reward, obs1, terminal1):
        """
        Store a transition in the replay buffer

        :param obs0: ([float] or [int]) the last observation
        :param action: ([float]) the action
        :param reward: (float) the reward
        :param obs1: ([float] or [int]) the current observation
        :param terminal1: (bool) is the episode done
        """
        reward *= self.reward_scale
        self.memory.append(obs0, action, reward, obs1, terminal1)
        if self.normalize_observations:
            self.obs_rms.update(np.array([obs0]))

    def _train_step(self, step, writer, log=False):
        """
        run a step of training from batch

        :param step: (int) the current step iteration
        :param writer: (TensorFlow Summary.writer) the writer for tensorboard
        :param log: (bool) whether or not to log to metadata
        :return: (float, float) critic loss, actor loss
        """
        # Get a batch
        batch = self.memory.sample(batch_size=self.batch_size)

        if self.normalize_returns and self.enable_popart:
            old_mean, old_std, target_q = self.sess.run([self.ret_rms.mean, self.ret_rms.std, self.target_q],
                                                        feed_dict={
                                                            self.obs_target: batch['obs1'],
                                                            self.rewards: batch['rewards'],
                                                            self.terminals1: batch['terminals1'].astype('float32')
                                                        })
            self.ret_rms.update(target_q.flatten())
            self.sess.run(self.renormalize_q_outputs_op, feed_dict={
                self.old_std: np.array([old_std]),
                self.old_mean: np.array([old_mean]),
            })

        else:
            target_q = self.sess.run(self.target_q, feed_dict={
                self.obs_target: batch['obs1'],
                self.rewards: batch['rewards'],
                self.terminals1: batch['terminals1'].astype('float32')
            })

        # Get all gradients and perform a synced update.
        ops = [self.actor_grads, self.actor_loss, self.critic_grads, self.critic_loss]
        td_map = {
            self.obs_train: batch['obs0'],
            self.actions: batch['actions'],
            self.action_train_ph: batch['actions'],
            self.rewards: batch['rewards'],
            self.critic_target: target_q,
            self.param_noise_stddev: 0 if self.param_noise is None else self.param_noise.current_stddev
        }
        if writer is not None:
            # run loss backprop with summary if the step_id was not already logged (can happen with the right
            # parameters as the step value is only an estimate)
            if log and step not in self.tb_seen_steps:
                run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                run_metadata = tf.RunMetadata()
                summary, actor_grads, actor_loss, critic_grads, critic_loss = \
                    self.sess.run([self.summary] + ops, td_map, options=run_options, run_metadata=run_metadata)

                writer.add_run_metadata(run_metadata, 'step%d' % step)
                self.tb_seen_steps.append(step)
            else:
                summary, actor_grads, actor_loss, critic_grads, critic_loss = self.sess.run([self.summary] + ops,
                                                                                            td_map)
            writer.add_summary(summary, step)
        else:
            actor_grads, actor_loss, critic_grads, critic_loss = self.sess.run(ops, td_map)

        self.actor_optimizer.update(actor_grads, learning_rate=self.actor_lr)
        self.critic_optimizer.update(critic_grads, learning_rate=self.critic_lr)

        return critic_loss, actor_loss

    def _initialize(self, sess):
        """
        initialize the model parameters and optimizers

        :param sess: (TensorFlow Session) the current TensorFlow session
        """
        self.sess = sess
        self.sess.run(tf.global_variables_initializer())
        self.actor_optimizer.sync()
        self.critic_optimizer.sync()
        self.sess.run(self.target_init_updates)

    def _update_target_net(self):
        """
        run target soft update operation
        """
        self.sess.run(self.target_soft_updates)

    def _get_stats(self):
        """
        Get the mean and standard deviation of the model's inputs and outputs

        :return: (dict) the means and stds
        """
        if self.stats_sample is None:
            # Get a sample and keep that fixed for all further computations.
            # This allows us to estimate the change in value for the same set of inputs.
            self.stats_sample = self.memory.sample(batch_size=self.batch_size)

        feed_dict = {
            self.actions: self.stats_sample['actions']
        }

        for placeholder in [self.action_train_ph, self.action_target, self.action_adapt_noise, self.action_noise_ph]:
            if placeholder is not None:
                feed_dict[placeholder] = self.stats_sample['actions']

        for placeholder in [self.obs_train, self.obs_target, self.obs_adapt_noise, self.obs_noise]:
            if placeholder is not None:
                feed_dict[placeholder] = self.stats_sample['obs0']

        values = self.sess.run(self.stats_ops, feed_dict=feed_dict)

        names = self.stats_names[:]
        assert len(names) == len(values)
        stats = dict(zip(names, values))

        if self.param_noise is not None:
            stats = {**stats, **self.param_noise.get_stats()}

        return stats

    def _adapt_param_noise(self):
        """
        calculate the adaptation for the parameter noise

        :return: (float) the mean distance for the parameter noise
        """
        if self.param_noise is None:
            return 0.

        # Perturb a separate copy of the policy to adjust the scale for the next "real" perturbation.
        batch = self.memory.sample(batch_size=self.batch_size)
        self.sess.run(self.perturb_adaptive_policy_ops, feed_dict={
            self.param_noise_stddev: self.param_noise.current_stddev,
        })
        distance = self.sess.run(self.adaptive_policy_distance, feed_dict={
            self.obs_adapt_noise: batch['obs0'], self.obs_train: batch['obs0'],
            self.param_noise_stddev: self.param_noise.current_stddev,
        })

        mean_distance = MPI.COMM_WORLD.allreduce(distance, op=MPI.SUM) / MPI.COMM_WORLD.Get_size()
        self.param_noise.adapt(mean_distance)
        return mean_distance

    def _reset(self):
        """
        Reset internal state after an episode is complete.
        """
        if self.action_noise is not None:
            self.action_noise.reset()
        if self.param_noise is not None:
            self.sess.run(self.perturb_policy_ops, feed_dict={
                self.param_noise_stddev: self.param_noise.current_stddev,
            })

    def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="DDPG"):
        with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name) as writer:
            self._setup_learn(seed)

            # a list for tensorboard logging, to prevent logging with the same step number, if it already occurred
            self.tb_seen_steps = []

            rank = MPI.COMM_WORLD.Get_rank()
            # we assume symmetric actions.
            assert np.all(np.abs(self.env.action_space.low) == self.env.action_space.high)
            if self.verbose >= 2:
                logger.log('Using agent with the following configuration:')
                logger.log(str(self.__dict__.items()))

            eval_episode_rewards_history = deque(maxlen=100)
            episode_rewards_history = deque(maxlen=100)
            self.episode_reward = np.zeros((1,))
            with self.sess.as_default(), self.graph.as_default():
                # Prepare everything.
                self._reset()
                obs = self.env.reset()
                eval_obs = None
                if self.eval_env is not None:
                    eval_obs = self.eval_env.reset()
                episode_reward = 0.
                episode_step = 0
                episodes = 0
                step = 0
                total_steps = 0

                start_time = time.time()

                epoch_episode_rewards = []
                epoch_episode_steps = []
                epoch_actor_losses = []
                epoch_critic_losses = []
                epoch_adaptive_distances = []
                eval_episode_rewards = []
                eval_qs = []
                epoch_actions = []
                epoch_qs = []
                epoch_episodes = 0
                epoch = 0
                while True:
                    for _ in range(log_interval):
                        # Perform rollouts.
                        for _ in range(self.nb_rollout_steps):
                            if total_steps >= total_timesteps:
                                return self

                            # Predict next action.
                            action, q_value = self._policy(obs, apply_noise=True, compute_q=True)
                            assert action.shape == self.env.action_space.shape

                            # Execute next action.
                            if rank == 0 and self.render:
                                self.env.render()
                            new_obs, reward, done, _ = self.env.step(action * np.abs(self.action_space.low))

                            if writer is not None:
                                ep_rew = np.array([reward]).reshape((1, -1))
                                ep_done = np.array([done]).reshape((1, -1))
                                self.episode_reward = total_episode_reward_logger(self.episode_reward, ep_rew, ep_done,
                                                                                  writer, total_steps)
                            step += 1
                            total_steps += 1
                            if rank == 0 and self.render:
                                self.env.render()
                            episode_reward += reward
                            episode_step += 1

                            # Book-keeping.
                            epoch_actions.append(action)
                            epoch_qs.append(q_value)
                            self._store_transition(obs, action, reward, new_obs, done)
                            obs = new_obs
                            if callback is not None:
                                # Only stop training if return value is False, not when it is None. This is for backwards
                                # compatibility with callbacks that have no return statement.
                                if callback(locals(), globals()) is False:
                                    return self

                            if done:
                                # Episode done.
                                epoch_episode_rewards.append(episode_reward)
                                episode_rewards_history.append(episode_reward)
                                epoch_episode_steps.append(episode_step)
                                episode_reward = 0.
                                episode_step = 0
                                epoch_episodes += 1
                                episodes += 1

                                self._reset()
                                if not isinstance(self.env, VecEnv):
                                    obs = self.env.reset()

                        # Train.
                        epoch_actor_losses = []
                        epoch_critic_losses = []
                        epoch_adaptive_distances = []
                        for t_train in range(self.nb_train_steps):
                            # Adapt param noise, if necessary.
                            if self.memory.nb_entries >= self.batch_size and \
                                    t_train % self.param_noise_adaption_interval == 0:
                                distance = self._adapt_param_noise()
                                epoch_adaptive_distances.append(distance)

                            # weird equation to deal with the fact the nb_train_steps will be different
                            # to nb_rollout_steps
                            step = (int(t_train * (self.nb_rollout_steps / self.nb_train_steps)) +
                                    total_steps - self.nb_rollout_steps)

                            critic_loss, actor_loss = self._train_step(step, writer, log=t_train == 0)
                            epoch_critic_losses.append(critic_loss)
                            epoch_actor_losses.append(actor_loss)
                            self._update_target_net()

                        # Evaluate.
                        eval_episode_rewards = []
                        eval_qs = []
                        if self.eval_env is not None:
                            eval_episode_reward = 0.
                            for _ in range(self.nb_eval_steps):
                                if total_steps >= total_timesteps:
                                    return self

                                eval_action, eval_q = self._policy(eval_obs, apply_noise=False, compute_q=True)
                                eval_obs, eval_r, eval_done, _ = self.eval_env.step(eval_action *
                                                                                    np.abs(self.action_space.low))
                                if self.render_eval:
                                    self.eval_env.render()
                                eval_episode_reward += eval_r

                                eval_qs.append(eval_q)
                                if eval_done:
                                    if not isinstance(self.env, VecEnv):
                                        eval_obs = self.eval_env.reset()
                                    eval_episode_rewards.append(eval_episode_reward)
                                    eval_episode_rewards_history.append(eval_episode_reward)
                                    eval_episode_reward = 0.

                    mpi_size = MPI.COMM_WORLD.Get_size()
                    # Log stats.
                    # XXX shouldn't call np.mean on variable length lists
                    duration = time.time() - start_time
                    stats = self._get_stats()
                    combined_stats = stats.copy()
                    combined_stats['rollout/return'] = np.mean(epoch_episode_rewards)
                    combined_stats['rollout/return_history'] = np.mean(episode_rewards_history)
                    combined_stats['rollout/episode_steps'] = np.mean(epoch_episode_steps)
                    combined_stats['rollout/actions_mean'] = np.mean(epoch_actions)
                    combined_stats['rollout/Q_mean'] = np.mean(epoch_qs)
                    combined_stats['train/loss_actor'] = np.mean(epoch_actor_losses)
                    combined_stats['train/loss_critic'] = np.mean(epoch_critic_losses)
                    if len(epoch_adaptive_distances) != 0:
                        combined_stats['train/param_noise_distance'] = np.mean(epoch_adaptive_distances)
                    combined_stats['total/duration'] = duration
                    combined_stats['total/steps_per_second'] = float(step) / float(duration)
                    combined_stats['total/episodes'] = episodes
                    combined_stats['rollout/episodes'] = epoch_episodes
                    combined_stats['rollout/actions_std'] = np.std(epoch_actions)
                    # Evaluation statistics.
                    if self.eval_env is not None:
                        combined_stats['eval/return'] = eval_episode_rewards
                        combined_stats['eval/return_history'] = np.mean(eval_episode_rewards_history)
                        combined_stats['eval/Q'] = eval_qs
                        combined_stats['eval/episodes'] = len(eval_episode_rewards)

                    def as_scalar(scalar):
                        """
                        check and return the input if it is a scalar, otherwise raise ValueError

                        :param scalar: (Any) the object to check
                        :return: (Number) the input as a scalar
                        """
                        if isinstance(scalar, np.ndarray):
                            assert scalar.size == 1
                            return scalar[0]
                        elif np.isscalar(scalar):
                            return scalar
                        else:
                            raise ValueError('expected scalar, got %s' % scalar)

                    combined_stats_sums = MPI.COMM_WORLD.allreduce(
                        np.array([as_scalar(x) for x in combined_stats.values()]))
                    combined_stats = {k: v / mpi_size for (k, v) in zip(combined_stats.keys(), combined_stats_sums)}

                    # Total statistics.
                    combined_stats['total/epochs'] = epoch + 1
                    combined_stats['total/steps'] = step

                    for key in sorted(combined_stats.keys()):
                        logger.record_tabular(key, combined_stats[key])
                    logger.dump_tabular()
                    logger.info('')
                    logdir = logger.get_dir()
                    if rank == 0 and logdir:
                        if hasattr(self.env, 'get_state'):
                            with open(os.path.join(logdir, 'env_state.pkl'), 'wb') as file_handler:
                                pickle.dump(self.env.get_state(), file_handler)
                        if self.eval_env and hasattr(self.eval_env, 'get_state'):
                            with open(os.path.join(logdir, 'eval_env_state.pkl'), 'wb') as file_handler:
                                pickle.dump(self.eval_env.get_state(), file_handler)

    def predict(self, observation, state=None, mask=None, deterministic=True):
        observation = np.array(observation)
        vectorized_env = self._is_vectorized_observation(observation, self.observation_space)

        observation = observation.reshape((-1,) + self.observation_space.shape)
        actions, _ = self._policy(observation, apply_noise=not deterministic, compute_q=False)
        actions = actions.reshape((-1,) + self.action_space.shape)  # reshape to the correct action shape
        actions = actions * np.abs(self.action_space.low)  # scale the output for the prediction

        if not vectorized_env:
            actions = actions[0]

        return actions, None

    def action_probability(self, observation, state=None, mask=None, actions=None):
        observation = np.array(observation)

        if actions is not None:
            raise ValueError("Error: DDPG does not have action probabilities.")

        # here there are no action probabilities, as DDPG does not use a probability distribution
        warnings.warn("Warning: action probability is meaningless for DDPG. Returning None")
        return None

    def save(self, save_path):
        data = {
            "observation_space": self.observation_space,
            "action_space": self.action_space,
            "nb_eval_steps": self.nb_eval_steps,
            "param_noise_adaption_interval": self.param_noise_adaption_interval,
            "nb_train_steps": self.nb_train_steps,
            "nb_rollout_steps": self.nb_rollout_steps,
            "verbose": self.verbose,
            "param_noise": self.param_noise,
            "action_noise": self.action_noise,
            "gamma": self.gamma,
            "tau": self.tau,
            "normalize_returns": self.normalize_returns,
            "enable_popart": self.enable_popart,
            "normalize_observations": self.normalize_observations,
            "batch_size": self.batch_size,
            "observation_range": self.observation_range,
            "return_range": self.return_range,
            "critic_l2_reg": self.critic_l2_reg,
            "actor_lr": self.actor_lr,
            "critic_lr": self.critic_lr,
            "clip_norm": self.clip_norm,
            "reward_scale": self.reward_scale,
            "memory_limit": self.memory_limit,
            "policy": self.policy,
            "memory_policy": self.memory_policy,
            "n_envs": self.n_envs,
            "_vectorize_action": self._vectorize_action
        }

        params = self.sess.run(self.params)
        target_params = self.sess.run(self.target_params)

        self._save_to_file(save_path, data=data, params=params + target_params)

    @classmethod
    def load(cls, load_path, env=None, **kwargs):
        data, params = cls._load_from_file(load_path)

        model = cls(None, env, _init_setup_model=False)
        model.__dict__.update(data)
        model.__dict__.update(kwargs)
        model.set_env(env)
        model.setup_model()

        restores = []
        for param, loaded_p in zip(model.params + model.target_params, params):
            restores.append(param.assign(loaded_p))
        model.sess.run(restores)

        return model
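
# Usage sketch (not part of the original snippet): a minimal, hedged example of
# training the DDPG model defined above with Ornstein-Uhlenbeck action noise.
# The import paths and the "Pendulum-v0" environment id are assumptions based
# on the stable-baselines 2.x package layout; adjust them if this DDPG class
# lives in a different module.
import gym
import numpy as np
from stable_baselines.ddpg.policies import MlpPolicy
from stable_baselines.ddpg.noise import OrnsteinUhlenbeckActionNoise
from stable_baselines import DDPG

env = gym.make("Pendulum-v0")
n_actions = env.action_space.shape[-1]
action_noise = OrnsteinUhlenbeckActionNoise(mean=np.zeros(n_actions),
                                            sigma=0.1 * np.ones(n_actions))
# memory_limit is raised above the small default shown in the snippet so the
# replay buffer can hold more than one training batch.
model = DDPG(MlpPolicy, env, action_noise=action_noise,
             memory_limit=50000, verbose=1)
model.learn(total_timesteps=10000)
model.save("ddpg_pendulum")
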
class TRPO(ActorCriticRLModel):
    """
    Trust Region Policy Optimization (https://arxiv.org/abs/1502.05477)

    :param policy: (ActorCriticPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, CnnLstmPolicy, ...)
    :param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str)
    :param gamma: (float) the discount value
    :param timesteps_per_batch: (int) the number of timesteps to run per batch (horizon)
    :param max_kl: (float) the Kullback-Leibler loss threshold
    :param cg_iters: (int) the number of iterations for the conjugate gradient calculation
    :param lam: (float) GAE factor
    :param entcoeff: (float) the weight for the entropy loss
    :param cg_damping: (float) the compute gradient dampening factor
    :param vf_stepsize: (float) the value function stepsize
    :param vf_iters: (int) the value function's number iterations for learning
    :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
    :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
    :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
    """
    def __init__(self,
                 policy,
                 env,
                 test_env=None,
                 gamma=0.99,
                 timesteps_per_batch=1024,
                 max_kl=0.01,
                 cg_iters=10,
                 kappa=1.0,
                 entcoeff=0.0,
                 cg_damping=1e-2,
                 vf_stepsize=3e-4,
                 vf_iters=3,
                 vf_phi_update_interval=1,
                 verbose=0,
                 _init_setup_model=True,
                 policy_kwargs=None,
                 eval_freq=10,
                 seed=0):
        super(TRPO, self).__init__(policy=policy,
                                   env=env,
                                   verbose=verbose,
                                   requires_vec_env=False,
                                   _init_setup_model=_init_setup_model,
                                   policy_kwargs=policy_kwargs)

        self.timesteps_per_batch = timesteps_per_batch
        self.cg_iters = cg_iters
        self.cg_damping = cg_damping
        self.gamma = gamma
        self.kappa = kappa
        self.max_kl = max_kl
        self.vf_iters = vf_iters
        self.vf_phi_update_interval = vf_phi_update_interval
        self.vf_stepsize = vf_stepsize
        self.entcoeff = entcoeff

        # GAIL Params
        self.g_step = 1

        self.graph = None
        self.sess = None
        self.policy_pi = None
        self.loss_names = None
        self.assign_old_eq_new = None
        self.compute_losses = None
        self.compute_lossandgrad = None
        self.compute_fvp = None
        self.compute_vflossandgrad = None
        self.d_adam = None
        self.vfadam = None
        self.get_flat = None
        self.set_from_flat = None
        self.timed = None
        self.allmean = None
        self.nworkers = None
        self.rank = None
        self.reward_giver = None
        self.step = None
        self.proba_step = None
        self.initial_state = None
        self.params = None
        self.summary = None
        self.episode_reward = None
        self.test_env = test_env
        self.eval_freq = eval_freq
        self.eval_reward = None
        self.seed = seed
        self.phi_seed = seed

        if _init_setup_model:
            self.setup_model()

    def _get_pretrain_placeholders(self):
        policy = self.policy_pi
        action_ph = policy.pdtype.sample_placeholder([None])
        if isinstance(self.action_space, gym.spaces.Discrete):
            return policy.obs_ph, action_ph, policy.policy
        return policy.obs_ph, action_ph, policy.deterministic_action

    def setup_model(self):

        with SetVerbosity(self.verbose):

            assert issubclass(self.policy, ActorCriticPolicy), "Error: the input policy for the TRPO model must be " \
                                                               "an instance of common.policies.ActorCriticPolicy."

            self.nworkers = MPI.COMM_WORLD.Get_size()
            print("number of workers are", self.nworkers)
            self.rank = MPI.COMM_WORLD.Get_rank()
            np.set_printoptions(precision=3)

            self.graph = tf.Graph()
            with self.graph.as_default():
                self.sess = tf_util.single_threaded_session(graph=self.graph)
                self._setup_learn(self.seed)

                # Construct network for new policy
                self.policy_pi = self.policy(self.sess,
                                             self.observation_space,
                                             self.action_space,
                                             self.n_envs,
                                             1,
                                             None,
                                             reuse=False,
                                             **self.policy_kwargs)

                # Network for old policy
                with tf.variable_scope("oldpi", reuse=False):
                    old_policy = self.policy(self.sess,
                                             self.observation_space,
                                             self.action_space,
                                             self.n_envs,
                                             1,
                                             None,
                                             reuse=False,
                                             **self.policy_kwargs)
                # Network for phi
                with tf.variable_scope("phi", reuse=False):
                    self.policy_phi = self.policy(self.sess,
                                                  self.observation_space,
                                                  self.action_space,
                                                  self.n_envs,
                                                  1,
                                                  None,
                                                  reuse=False,
                                                  **self.policy_kwargs)
                # Network for phi old
                with tf.variable_scope("oldphi", reuse=False):
                    self.policy_phi_old = self.policy(self.sess,
                                                      self.observation_space,
                                                      self.action_space,
                                                      self.n_envs,
                                                      1,
                                                      None,
                                                      reuse=False,
                                                      **self.policy_kwargs)

                with tf.variable_scope("loss", reuse=False):
                    atarg = tf.placeholder(dtype=tf.float32, shape=[
                        None
                    ])  # Target advantage function (if applicable)
                    ret = tf.placeholder(dtype=tf.float32,
                                         shape=[None])  # Empirical return

                    observation = self.policy_pi.obs_ph
                    action = self.policy_pi.pdtype.sample_placeholder([None])

                    kloldnew = old_policy.proba_distribution.kl(
                        self.policy_pi.proba_distribution)
                    #kloldnew = self.policy_pi.proba_distribution.kl(old_policy.proba_distribution)
                    ent = self.policy_pi.proba_distribution.entropy()
                    meankl = tf.reduce_mean(kloldnew)
                    meanent = tf.reduce_mean(ent)
                    entbonus = self.entcoeff * meanent

                    vferr = tf.reduce_mean(
                        tf.square(self.policy_pi.value_flat - ret))
                    vf_phi_err = tf.reduce_mean(
                        tf.square(self.policy_phi.value_flat - ret))
                    vf_phi_old_err = tf.reduce_mean(
                        tf.square(self.policy_phi_old.value_flat))

                    # advantage * pnew / pold
                    ratio = tf.exp(
                        self.policy_pi.proba_distribution.logp(action) -
                        old_policy.proba_distribution.logp(action))
                    surrgain = tf.reduce_mean(ratio * atarg)

                    optimgain = surrgain + entbonus
                    losses = [optimgain, meankl, entbonus, surrgain, meanent]
                    self.loss_names = [
                        "optimgain", "meankl", "entloss", "surrgain", "entropy"
                    ]

                    dist = meankl

                    all_var_list = tf_util.get_trainable_vars("model")
                    var_list = [
                        v for v in all_var_list
                        if "/vf" not in v.name and "/q/" not in v.name
                    ]
                    vf_var_list = [
                        v for v in all_var_list
                        if "/pi" not in v.name and "/logstd" not in v.name
                    ]
                    all_var_oldpi_list = tf_util.get_trainable_vars("oldpi")
                    var_oldpi_list = [
                        v for v in all_var_oldpi_list
                        if "/vf" not in v.name and "/q/" not in v.name
                    ]

                    all_var_phi_list = tf_util.get_trainable_vars("phi")
                    vf_phi_var_list = [
                        v for v in all_var_phi_list if "/pi" not in v.name
                        and "/logstd" not in v.name and "/q" not in v.name
                    ]
                    all_var_phi_old_list = tf_util.get_trainable_vars("oldphi")
                    vf_phi_old_var_list = [
                        v for v in all_var_phi_old_list if "/pi" not in v.name
                        and "/logstd" not in v.name and "/q" not in v.name
                    ]
                    #print("vars", vf_var_list)
                    self.policy_vars = all_var_list
                    self.oldpolicy_vars = all_var_oldpi_list
                    print("all var list", all_var_list)
                    print("phi vars", vf_phi_var_list)
                    print("phi old vars", vf_phi_old_var_list)

                    self.get_flat = tf_util.GetFlat(var_list, sess=self.sess)
                    self.set_from_flat = tf_util.SetFromFlat(var_list,
                                                             sess=self.sess)

                    klgrads = tf.gradients(dist, var_list)
                    flat_tangent = tf.placeholder(dtype=tf.float32,
                                                  shape=[None],
                                                  name="flat_tan")
                    shapes = [var.get_shape().as_list() for var in var_list]
                    start = 0
                    tangents = []
                    for shape in shapes:
                        var_size = tf_util.intprod(shape)
                        tangents.append(
                            tf.reshape(flat_tangent[start:start + var_size],
                                       shape))
                        start += var_size
                    gvp = tf.add_n([
                        tf.reduce_sum(grad * tangent)
                        for (grad, tangent) in zipsame(klgrads, tangents)
                    ])  # pylint: disable=E1111
                    fvp = tf_util.flatgrad(gvp, var_list)

                    tf.summary.scalar('entropy_loss', meanent)
                    tf.summary.scalar('policy_gradient_loss', optimgain)
                    tf.summary.scalar('value_function_loss', surrgain)
                    tf.summary.scalar('approximate_kullback-leibler', meankl)
                    tf.summary.scalar(
                        'loss',
                        optimgain + meankl + entbonus + surrgain + meanent)

                    self.assign_old_eq_new = \
                        tf_util.function([], [], updates=[tf.assign(oldv, newv) for (oldv, newv) in
                                                          zipsame(tf_util.get_globals_vars("oldpi"),
                                                                  tf_util.get_globals_vars("model"))])
                    self.compute_losses = tf_util.function(
                        [observation, old_policy.obs_ph, action, atarg],
                        losses)
                    self.compute_fvp = tf_util.function([
                        flat_tangent, observation, old_policy.obs_ph, action,
                        atarg
                    ], fvp)
                    self.compute_vflossandgrad = tf_util.function(
                        [observation, old_policy.obs_ph, ret],
                        tf_util.flatgrad(vferr, vf_var_list))
                    self.compute_vf_phi_lossandgrad = tf_util.function(
                        [observation, self.policy_phi.obs_ph, ret],
                        tf_util.flatgrad(vf_phi_err, vf_phi_var_list))
                    self.compute_vf_loss = tf_util.function(
                        [observation, old_policy.obs_ph, ret], vferr)
                    self.compute_vf_phi_loss = tf_util.function(
                        [observation, self.policy_phi.obs_ph, ret], vf_phi_err)
                    #self.compute_vf_phi_old_loss = tf_util.function([self.policy_phi_old.obs_ph], vf_phi_old_err)
                    #self.phi_old_obs = np.array([-0.012815  , -0.02076313,  0.07524705,  0.09407324,  0.0901745 , -0.09339058,  0.03544853, -0.03297224])
                    #self.phi_old_obs = self.phi_old_obs.reshape((1, 8))

                    update_phi_old_expr = []
                    for var, var_target in zip(
                            sorted(vf_phi_var_list, key=lambda v: v.name),
                            sorted(vf_phi_old_var_list, key=lambda v: v.name)):
                        update_phi_old_expr.append(var_target.assign(var))
                    update_phi_old_expr = tf.group(*update_phi_old_expr)

                    self.update_phi_old = tf_util.function(
                        [], [], updates=[update_phi_old_expr])

                    @contextmanager
                    def timed(msg):
                        if self.rank == 0 and self.verbose >= 1:
                            print(colorize(msg, color='magenta'))
                            start_time = time.time()
                            yield
                            print(
                                colorize("done in {:.3f} seconds".format(
                                    (time.time() - start_time)),
                                         color='magenta'))
                        else:
                            yield

                    @contextmanager
                    def temp_seed(seed):
                        state = np.random.get_state()
                        np.random.seed(seed)
                        try:
                            yield
                        finally:
                            np.random.set_state(state)

                    def allmean(arr):
                        assert isinstance(arr, np.ndarray)
                        out = np.empty_like(arr)
                        MPI.COMM_WORLD.Allreduce(arr, out, op=MPI.SUM)
                        out /= self.nworkers
                        return out

                    tf_util.initialize(sess=self.sess)

                    th_init = self.get_flat()
                    MPI.COMM_WORLD.Bcast(th_init, root=0)
                    self.set_from_flat(th_init)

                with tf.variable_scope("Adam_mpi", reuse=False):
                    self.vfadam = MpiAdam(vf_var_list, sess=self.sess)
                    self.vf_phi_adam = MpiAdam(vf_phi_var_list, sess=self.sess)
                    self.vfadam.sync()
                    self.vf_phi_adam.sync()

                with tf.variable_scope("input_info", reuse=False):
                    tf.summary.scalar('discounted_rewards',
                                      tf.reduce_mean(ret))
                    tf.summary.scalar('learning_rate',
                                      tf.reduce_mean(self.vf_stepsize))
                    tf.summary.scalar('advantage', tf.reduce_mean(atarg))
                    tf.summary.scalar('kl_clip_range',
                                      tf.reduce_mean(self.max_kl))

                self.timed = timed
                self.allmean = allmean
                self.temp_seed = temp_seed

                self.step = self.policy_pi.step
                self.proba_step = self.policy_pi.proba_step
                self.initial_state = self.policy_pi.initial_state

                self.params = tf_util.get_trainable_vars(
                    "model") + tf_util.get_trainable_vars("oldpi")

                self.summary = tf.summary.merge_all()

                self.compute_lossandgrad = \
                    tf_util.function([observation, old_policy.obs_ph, action, atarg, ret],
                                     [self.summary, tf_util.flatgrad(optimgain, var_list)] + losses)

    def learn(self,
              total_timesteps,
              callback=None,
              seed=None,
              log_interval=100,
              tb_log_name="TRPO",
              reset_num_timesteps=True):

        new_tb_log = self._init_num_timesteps(reset_num_timesteps)
        print("args", self.kappa, self.vf_phi_update_interval, seed)

        with SetVerbosity(self.verbose):
            #self._setup_learn(seed)

            with self.sess.as_default():
                seg_gen = traj_segment_generator(
                    self.policy_pi,
                    self.policy_phi_old,
                    self.env,
                    self.timesteps_per_batch,
                    self.gamma,
                    self.kappa,
                    reward_giver=self.reward_giver)

                episodes_so_far = 0
                timesteps_so_far = 0
                iters_so_far = 0
                t_start = time.time()
                len_buffer = deque(
                    maxlen=40)  # rolling buffer for episode lengths
                reward_buffer = deque(
                    maxlen=40)  # rolling buffer for episode rewards
                vf_loss_buffer = deque(maxlen=80)  # rolling buffer for vf loss
                vf_phi_loss_buffer = deque(
                    maxlen=80)  # rolling buffer for vf phi loss
                self.episode_reward = np.zeros((self.n_envs, ))

                true_reward_buffer = None

                while True:
                    if callback is not None:
                        # Only stop training if return value is False, not when it is None. This is for backwards
                        # compatibility with callbacks that have no return statement.
                        if callback(locals(), globals()) is False:
                            break
                    if total_timesteps and timesteps_so_far >= total_timesteps:
                        break

                    logger.log("********** Iteration %i ************" %
                               iters_so_far)

                    def fisher_vector_product(vec):
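                        # Matrix-free Fisher-vector product: the MPI-averaged analytic FVP on the
                        # subsampled fvpargs batch, plus conjugate-gradient damping (cg_damping).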
                        return self.allmean(
                            self.compute_fvp(
                                vec, *fvpargs,
                                sess=self.sess)) + self.cg_damping * vec

                    # ------------------ Update G ------------------
                    logger.log("Optimizing Policy...")
                    # g_step = 1 when not using GAIL
                    mean_losses = None
                    vpredbefore = None
                    tdlamret = None
                    observation = None
                    action = None
                    seg = None
                    for k in range(self.g_step):
                        with self.timed("sampling"):
                            seg = seg_gen.__next__()
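                        # add_vtarg_and_adv fills seg["adv"] and seg["tdlamret"] (and, per the
                        # commented-out print below, seg["tdlamret_phi"]) with GAE-style targets
                        # built from gamma and kappa.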
                        add_vtarg_and_adv(seg, self.gamma, self.kappa)
                        #print("seg is", seg["tdlamret"], seg["tdlamret_phi"])
                        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
                        iters_val = int(
                            (2 * int(iters_so_far /
                                     (1 * self.vf_phi_update_interval)) + 1) *
                            (1 * self.vf_phi_update_interval) * 1 / 2)
                        print("iters_val", iters_val,
                              int(iters_so_far / self.vf_phi_update_interval),
                              self.vf_phi_update_interval)
                        vf_loss = []

                        #if iters_so_far < iters_val:
                        print("optimizing surrogate mdp...")
                        observation, action = seg["observations"], seg[
                            "actions"]
                        atarg, tdlamret = seg["adv"], seg["tdlamret"]

                        vpredbefore = seg[
                            "vpred"]  # predicted value function before update
                        atarg = (atarg - atarg.mean()) / atarg.std(
                        )  # standardized advantage function estimate

                        args = seg["observations"], seg["observations"], seg[
                            "actions"], atarg
                        # Subsampling: see p40-42 of John Schulman thesis
                        # http://joschu.net/docs/thesis.pdf
                        fvpargs = [arr[::5] for arr in args]

                        self.assign_old_eq_new(sess=self.sess)

                        with self.timed("computegrad"):
                            steps = self.num_timesteps + (k + 1) * (
                                seg["total_timestep"] / self.g_step)
                            run_options = tf.RunOptions(
                                trace_level=tf.RunOptions.FULL_TRACE)
                            run_metadata = None
                            # run loss backprop with summary; run_metadata is left as None here,
                            # so the FULL_TRACE metadata requested above is not actually collected
                            _, grad, *lossbefore = self.compute_lossandgrad(
                                *args,
                                tdlamret,
                                sess=self.sess,
                                options=run_options,
                                run_metadata=run_metadata)

                        lossbefore = self.allmean(np.array(lossbefore))
                        grad = self.allmean(grad)
                        if np.allclose(grad, 0):
                            logger.log("Got zero gradient. not updating")
                        else:
                            with self.timed("conjugate_gradient"):
                                stepdir = conjugate_gradient(
                                    fisher_vector_product,
                                    grad,
                                    cg_iters=self.cg_iters,
                                    verbose=self.rank == 0
                                    and self.verbose >= 1)
                            assert np.isfinite(stepdir).all()
                            shs = .5 * stepdir.dot(
                                fisher_vector_product(stepdir))
                            # abs(shs) to avoid taking square root of negative values
                            lagrange_multiplier = np.sqrt(
                                abs(shs) / self.max_kl)
                            # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
                            fullstep = stepdir / lagrange_multiplier
                            expectedimprove = grad.dot(fullstep)
                            surrbefore = lossbefore[0]
                            stepsize = 1.0
                            thbefore = self.get_flat()
                            thnew = None
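                            # Backtracking line search: halve the step until the surrogate improves
                            # without violating the KL constraint (max_kl with a 1.5x tolerance).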
                            for _ in range(10):
                                thnew = thbefore + fullstep * stepsize
                                self.set_from_flat(thnew)
                                mean_losses = surr, kl_loss, *_ = self.allmean(
                                    np.array(
                                        self.compute_losses(*args,
                                                            sess=self.sess)))
                                improve = surr - surrbefore
                                logger.log("Expected: %.3f Actual: %.3f" %
                                           (expectedimprove, improve))
                                if not np.isfinite(mean_losses).all():
                                    logger.log(
                                        "Got non-finite value of losses -- bad!"
                                    )
                                elif kl_loss > self.max_kl * 1.5:
                                    logger.log(
                                        "violated KL constraint. shrinking step."
                                    )
                                elif improve < 0:
                                    logger.log(
                                        "surrogate didn't improve. shrinking step."
                                    )
                                else:
                                    logger.log("Stepsize OK!")
                                    break
                                stepsize *= .5
                            else:
                                logger.log("couldn't compute a good step")
                                self.set_from_flat(thbefore)
                            if self.nworkers > 1 and iters_so_far % 20 == 0:
                                # list of tuples
                                paramsums = MPI.COMM_WORLD.allgather(
                                    (thnew.sum(), self.vfadam.getflat().sum()))
                                assert all(
                                    np.allclose(ps, paramsums[0])
                                    for ps in paramsums[1:])

                        with self.timed("vf"):
                            #vf_loss = []
                            for _ in range(self.vf_iters):
                                # NOTE: for recurrent policies, use shuffle=False?
                                for (mbob, mbret) in dataset.iterbatches(
                                    (seg["observations"], seg["tdlamret"]),
                                        include_final_partial_batch=False,
                                        batch_size=128,
                                        shuffle=True):
                                    grad = self.allmean(
                                        self.compute_vflossandgrad(
                                            mbob, mbob, mbret, sess=self.sess))
                                    self.vfadam.update(grad, self.vf_stepsize)
                                    vf_loss.append(
                                        self.compute_vf_loss(mbob,
                                                             mbob,
                                                             mbret,
                                                             sess=self.sess))

                        vf_phi_loss = []
                        #if iters_so_far >= iters_val: #self.vf_phi_update_interval == 0:
                        print("evaluating policy...")
                        with self.timed("vf_phi"):
                            #vf_phi_loss = []
                            for _ in range(self.vf_iters):  # alternatively self.vf_phi_update_interval
                                with self.temp_seed(self.phi_seed):
                                    for (mbob, mbret) in dataset.iterbatches(
                                        (seg["observations"],
                                         seg["tdlamret_phi"]),
                                            include_final_partial_batch=False,
                                            batch_size=128,
                                            shuffle=True):
                                        grad = self.allmean(
                                            self.compute_vf_phi_lossandgrad(
                                                mbob,
                                                mbob,
                                                mbret,
                                                sess=self.sess))
                                        #print("vf phi loss before", self.compute_vf_phi_loss(mbob, mbob, mbret, sess=self.sess))
                                        #print("vf loss before phi is updated", self.compute_vf_loss(mbob, mbob, mbret_vf, sess=self.sess))
                                        self.vf_phi_adam.update(
                                            grad, self.vf_stepsize)
                                        vf_phi_loss.append(
                                            self.compute_vf_phi_loss(
                                                mbob,
                                                mbob,
                                                mbret,
                                                sess=self.sess))
                                        #print("gradient value", grad)
                                        #print("vf phi loss after", self.compute_vf_phi_loss(mbob, mbob, mbret, sess=self.sess))
                                        #print("vf loss after phi is updated", self.compute_vf_loss(mbob, mbob, mbret_vf, sess=self.sess))
                                    self.phi_seed += 1

                        #print("vf phi old loss", self.compute_vf_phi_old_loss(self.phi_old_obs, sess=self.sess))
                        if iters_so_far % self.vf_phi_update_interval == 0:
                            with self.timed("vf_phi_old_update"):
                                self.update_phi_old(sess=self.sess)
                                #update_variables = set(self.policy_vars + self.oldpolicy_vars)
                                #self.sess.run(tf.variables_initializer(update_variables))

                        if iters_so_far % self.eval_freq == 0:
                            if self.test_env is not None:
                                self.eval_reward = self.evaluate_agent(
                                    self.test_env, self.eval_freq)

                    if mean_losses is None:
                        mean_losses = [None, None, None, None, None]
                    for (loss_name, loss_val) in zip(self.loss_names,
                                                     mean_losses):
                        logger.record_tabular(loss_name, loss_val)

                    if vpredbefore is not None:
                        logger.record_tabular(
                            "explained_variance_tdlam_before",
                            explained_variance(vpredbefore, tdlamret))
                    else:
                        logger.record_tabular(
                            "explained_variance_tdlam_before", None)

                    # lr: lengths and rewards
                    lr_local = (seg["ep_lens"], seg["ep_rets"])  # local values
                    list_lr_pairs = MPI.COMM_WORLD.allgather(
                        lr_local)  # list of tuples
                    lens, rews = map(flatten_lists, zip(*list_lr_pairs))
                    len_buffer.extend(lens)
                    reward_buffer.extend(rews)
                    vf_loss_buffer.extend(vf_loss)
                    vf_phi_loss_buffer.extend(vf_phi_loss)

                    if len(len_buffer) > 0:
                        logger.record_tabular("EpLenMean", np.mean(len_buffer))
                        logger.record_tabular("EpRewMean",
                                              np.mean(reward_buffer))
                        logger.record_tabular("VfLoss",
                                              np.mean(vf_loss_buffer))
                        logger.record_tabular("VfPhiLoss",
                                              np.mean(vf_phi_loss_buffer))
                    logger.record_tabular("EpThisIter", len(lens))
                    episodes_so_far += len(lens)
                    current_it_timesteps = MPI.COMM_WORLD.allreduce(
                        seg["total_timestep"])
                    timesteps_so_far += current_it_timesteps
                    self.num_timesteps += current_it_timesteps
                    iters_so_far += 1

                    logger.record_tabular("EpisodesSoFar", episodes_so_far)
                    logger.record_tabular("TimestepsSoFar", self.num_timesteps)
                    logger.record_tabular("TimeElapsed", time.time() - t_start)
                    logger.record_tabular("EvaluationScore", self.eval_reward)

                    if self.verbose >= 1 and self.rank == 0:
                        logger.dump_tabular()

        return self

    def evaluate_agent(self, test_env, test_episodes):

        step = 0
        games = 0
        total_rewards = 0.0
        reset = test_env.reset()

        while games < test_episodes:
            terminal = False
            state = reset
            while not terminal:
                action, _, _, _ = self.policy_pi.step(state.reshape(
                    -1, *state.shape),
                                                      deterministic=True)
                state, reward, terminal, info = test_env.step(action[0])

                total_rewards += reward

                step += 1

            games += 1
            reset = test_env.reset()

        return total_rewards / games

    def save(self, save_path):
        data = {
            "gamma": self.gamma,
            "kappa": self.kappa,
            "timesteps_per_batch": self.timesteps_per_batch,
            "max_kl": self.max_kl,
            "cg_iters": self.cg_iters,
            "entcoeff": self.entcoeff,
            "cg_damping": self.cg_damping,
            "vf_stepsize": self.vf_stepsize,
            "vf_iters": self.vf_iters,
            "g_step": self.g_step,
            "verbose": self.verbose,
            "policy": self.policy,
            "observation_space": self.observation_space,
            "action_space": self.action_space,
            "n_envs": self.n_envs,
            "_vectorize_action": self._vectorize_action,
            "policy_kwargs": self.policy_kwargs,
            "eval_freq": self.eval_freq,
            "vf_phi_update_interval": self.vf_phi_update_interval
        }

        params_to_save = self.get_parameters()

        self._save_to_file(save_path, data=data, params=params_to_save)
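
# A hedged usage sketch for this TRPO variant (not from the source snippet): the
# environment id, policy import, and hyperparameter values below are illustrative
# assumptions, and the function is only a sketch, never called here.
def _train_trpo_variant_sketch():
    import gym
    from stable_baselines.common.policies import MlpPolicy  # assumed policy import

    env = gym.make("Pendulum-v0")
    test_env = gym.make("Pendulum-v0")     # separate env consumed by evaluate_agent()
    model = TRPO(MlpPolicy, env,
                 test_env=test_env,
                 kappa=1.0,                 # plays the role of lam in add_vtarg_and_adv
                 vf_phi_update_interval=5,  # how often the phi value net is copied to oldphi
                 eval_freq=10,              # evaluate every 10 iterations, 10 episodes each
                 seed=0,
                 verbose=1)
    model.learn(total_timesteps=100000)
    model.save("trpo_variant_pendulum")     # stores the data dict plus get_parameters()
    return model
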
Example n. 6
0
class TRPO(BaseRLModel):
    def __init__(self,
                 policy,
                 env,
                 gamma=0.99,
                 timesteps_per_batch=1024,
                 max_kl=0.01,
                 cg_iters=10,
                 lam=0.98,
                 entcoeff=0.0,
                 cg_damping=1e-2,
                 vf_stepsize=3e-4,
                 vf_iters=3,
                 verbose=0,
                 _init_setup_model=True):
        """
        learns a TRPO policy using the given environment

        :param policy: (function (str, Gym Space, Gym Space, bool): MLPPolicy) policy generator
        :param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str)
        :param gamma: (float) the discount value
        :param timesteps_per_batch: (int) the number of timesteps to run per batch (horizon)
        :param max_kl: (float) the Kullback-Leibler loss threshold
        :param cg_iters: (int) the number of iterations for the conjugate gradient calculation
        :param lam: (float) GAE factor
        :param entcoeff: (float) the weight for the entropy loss
        :param cg_damping: (float) the compute gradient dampening factor
        :param vf_stepsize: (float) the value function stepsize
        :param vf_iters: (int) the value function's number iterations for learning
        :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
        :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
        """
        super(TRPO, self).__init__(policy=policy,
                                   env=env,
                                   requires_vec_env=False,
                                   verbose=verbose)

        self.using_gail = False
        self.timesteps_per_batch = timesteps_per_batch
        self.cg_iters = cg_iters
        self.cg_damping = cg_damping
        self.gamma = gamma
        self.lam = lam
        self.max_kl = max_kl
        self.vf_iters = vf_iters
        self.vf_stepsize = vf_stepsize
        self.entcoeff = entcoeff

        # GAIL Params
        self.pretrained_weight = None
        self.hidden_size_adversary = 100
        self.adversary_entcoeff = 1e-3
        self.expert_dataset = None
        self.save_per_iter = 1
        self.checkpoint_dir = "/tmp/gail/ckpt/"
        self.g_step = 1
        self.d_step = 1
        self.task_name = "task_name"
        self.d_stepsize = 3e-4

        self.graph = None
        self.sess = None
        self.policy_pi = None
        self.loss_names = None
        self.assign_old_eq_new = None
        self.compute_losses = None
        self.compute_lossandgrad = None
        self.compute_fvp = None
        self.compute_vflossandgrad = None
        self.d_adam = None
        self.vfadam = None
        self.get_flat = None
        self.set_from_flat = None
        self.timed = None
        self.allmean = None
        self.nworkers = None
        self.rank = None
        self.reward_giver = None
        self.step = None
        self.proba_step = None
        self.initial_state = None
        self.params = None

        if _init_setup_model:
            self.setup_model()

    def setup_model(self):
        # prevent import loops
        from stable_baselines.gail.adversary import TransitionClassifier

        with SetVerbosity(self.verbose):

            self.nworkers = MPI.COMM_WORLD.Get_size()
            self.rank = MPI.COMM_WORLD.Get_rank()
            np.set_printoptions(precision=3)

            self.graph = tf.Graph()
            with self.graph.as_default():
                self.sess = tf_util.single_threaded_session(graph=self.graph)

                if self.using_gail:
                    self.reward_giver = TransitionClassifier(
                        self.env,
                        self.hidden_size_adversary,
                        entcoeff=self.adversary_entcoeff)

                # Construct network for new policy
                with tf.variable_scope("pi", reuse=False):
                    self.policy_pi = self.policy(self.sess,
                                                 self.observation_space,
                                                 self.action_space,
                                                 self.n_envs,
                                                 1,
                                                 None,
                                                 reuse=False)

                # Network for old policy
                with tf.variable_scope("oldpi", reuse=False):
                    old_policy = self.policy(self.sess,
                                             self.observation_space,
                                             self.action_space,
                                             self.n_envs,
                                             1,
                                             None,
                                             reuse=False)

                atarg = tf.placeholder(
                    dtype=tf.float32,
                    shape=[None])  # Target advantage function (if applicable)
                ret = tf.placeholder(dtype=tf.float32,
                                     shape=[None])  # Empirical return

                observation = self.policy_pi.obs_ph
                action = self.policy_pi.pdtype.sample_placeholder([None])

                kloldnew = old_policy.proba_distribution.kl(
                    self.policy_pi.proba_distribution)
                ent = self.policy_pi.proba_distribution.entropy()
                meankl = tf.reduce_mean(kloldnew)
                meanent = tf.reduce_mean(ent)
                entbonus = self.entcoeff * meanent

                vferr = tf.reduce_mean(
                    tf.square(self.policy_pi.value_fn[:, 0] - ret))

                # advantage * pnew / pold
                ratio = tf.exp(
                    self.policy_pi.proba_distribution.logp(action) -
                    old_policy.proba_distribution.logp(action))
                surrgain = tf.reduce_mean(ratio * atarg)

                optimgain = surrgain + entbonus
                losses = [optimgain, meankl, entbonus, surrgain, meanent]
                self.loss_names = [
                    "optimgain", "meankl", "entloss", "surrgain", "entropy"
                ]

                dist = meankl

                all_var_list = tf_util.get_trainable_vars("pi")
                var_list = [
                    v for v in all_var_list
                    if "/vf" not in v.name and "/q/" not in v.name
                ]
                vf_var_list = [
                    v for v in all_var_list
                    if "/pi" not in v.name and "/logstd" not in v.name
                ]
                self.vfadam = MpiAdam(vf_var_list, sess=self.sess)
                self.get_flat = tf_util.GetFlat(var_list, sess=self.sess)
                self.set_from_flat = tf_util.SetFromFlat(var_list,
                                                         sess=self.sess)

                if self.using_gail:
                    self.d_adam = MpiAdam(
                        self.reward_giver.get_trainable_variables())

                klgrads = tf.gradients(dist, var_list)
                flat_tangent = tf.placeholder(dtype=tf.float32,
                                              shape=[None],
                                              name="flat_tan")
                shapes = [var.get_shape().as_list() for var in var_list]
                start = 0
                tangents = []
                for shape in shapes:
                    var_size = tf_util.intprod(shape)
                    tangents.append(
                        tf.reshape(flat_tangent[start:start + var_size],
                                   shape))
                    start += var_size
                gvp = tf.add_n([
                    tf.reduce_sum(grad * tangent)
                    for (grad, tangent) in zipsame(klgrads, tangents)
                ])  # pylint: disable=E1111
                fvp = tf_util.flatgrad(gvp, var_list)

                self.assign_old_eq_new = tf_util.function(
                    [], [],
                    updates=[
                        tf.assign(oldv, newv) for (
                            oldv,
                            newv) in zipsame(tf_util.get_globals_vars("oldpi"),
                                             tf_util.get_globals_vars("pi"))
                    ])
                self.compute_losses = tf_util.function(
                    [observation, old_policy.obs_ph, action, atarg], losses)
                self.compute_lossandgrad = tf_util.function(
                    [observation, old_policy.obs_ph, action, atarg],
                    losses + [tf_util.flatgrad(optimgain, var_list)])
                self.compute_fvp = tf_util.function([
                    flat_tangent, observation, old_policy.obs_ph, action, atarg
                ], fvp)
                self.compute_vflossandgrad = tf_util.function(
                    [observation, old_policy.obs_ph, ret],
                    tf_util.flatgrad(vferr, vf_var_list))

                @contextmanager
                def timed(msg):
                    if self.rank == 0 and self.verbose >= 1:
                        print(colorize(msg, color='magenta'))
                        start_time = time.time()
                        yield
                        print(
                            colorize("done in %.3f seconds" %
                                     (time.time() - start_time),
                                     color='magenta'))
                    else:
                        yield

                def allmean(arr):
                    assert isinstance(arr, np.ndarray)
                    out = np.empty_like(arr)
                    MPI.COMM_WORLD.Allreduce(arr, out, op=MPI.SUM)
                    out /= self.nworkers
                    return out

                tf_util.initialize(sess=self.sess)

                th_init = self.get_flat()
                MPI.COMM_WORLD.Bcast(th_init, root=0)
                self.set_from_flat(th_init)

                if self.using_gail:
                    self.d_adam.sync()
                self.vfadam.sync()

                self.timed = timed
                self.allmean = allmean

                self.step = self.policy_pi.step
                self.proba_step = self.policy_pi.proba_step
                self.initial_state = self.policy_pi.initial_state

                self.params = find_trainable_variables("pi")
                if self.using_gail:
                    self.params.extend(
                        self.reward_giver.get_trainable_variables())

    def learn(self,
              total_timesteps,
              callback=None,
              seed=None,
              log_interval=100):
        with SetVerbosity(self.verbose):
            self._setup_learn(seed)

            with self.sess.as_default():
                seg_gen = traj_segment_generator(
                    self.policy_pi,
                    self.env,
                    self.timesteps_per_batch,
                    reward_giver=self.reward_giver,
                    gail=self.using_gail)

                episodes_so_far = 0
                timesteps_so_far = 0
                iters_so_far = 0
                t_start = time.time()
                lenbuffer = deque(
                    maxlen=40)  # rolling buffer for episode lengths
                rewbuffer = deque(
                    maxlen=40)  # rolling buffer for episode rewards

                true_rewbuffer = None
                if self.using_gail:
                    true_rewbuffer = deque(maxlen=40)
                    #  Stats not used for now
                    #  g_loss_stats = Stats(loss_names)
                    #  d_loss_stats = Stats(reward_giver.loss_name)
                    #  ep_stats = Stats(["True_rewards", "Rewards", "Episode_length"])

                    # if provide pretrained weight
                    if self.pretrained_weight is not None:
                        tf_util.load_state(
                            self.pretrained_weight,
                            var_list=tf_util.get_globals_vars("pi"),
                            sess=self.sess)

                while True:
                    if callback:
                        callback(locals(), globals())
                    if total_timesteps and timesteps_so_far >= total_timesteps:
                        break

                    logger.log("********** Iteration %i ************" %
                               iters_so_far)

                    def fisher_vector_product(vec):
                        return self.allmean(
                            self.compute_fvp(
                                vec, *fvpargs,
                                sess=self.sess)) + self.cg_damping * vec

                    # ------------------ Update G ------------------
                    logger.log("Optimizing Policy...")
                    # g_step = 1 when not using GAIL
                    mean_losses = None
                    vpredbefore = None
                    tdlamret = None
                    observation = None
                    action = None
                    seg = None
                    for _ in range(self.g_step):
                        with self.timed("sampling"):
                            seg = seg_gen.__next__()
                        add_vtarg_and_adv(seg, self.gamma, self.lam)
                        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
                        observation, action, atarg, tdlamret = seg["ob"], seg[
                            "ac"], seg["adv"], seg["tdlamret"]
                        vpredbefore = seg[
                            "vpred"]  # predicted value function before udpate
                        atarg = (atarg - atarg.mean()) / atarg.std(
                        )  # standardized advantage function estimate

                        args = seg["ob"], seg["ob"], seg["ac"], atarg
                        fvpargs = [arr[::5] for arr in args]

                        self.assign_old_eq_new(sess=self.sess)

                        with self.timed("computegrad"):
                            *lossbefore, grad = self.compute_lossandgrad(
                                *args, sess=self.sess)
                        lossbefore = self.allmean(np.array(lossbefore))
                        grad = self.allmean(grad)
                        if np.allclose(grad, 0):
                            logger.log("Got zero gradient. not updating")
                        else:
                            with self.timed("cg"):
                                stepdir = conjugate_gradient(
                                    fisher_vector_product,
                                    grad,
                                    cg_iters=self.cg_iters,
                                    verbose=self.rank == 0
                                    and self.verbose >= 1)
                            assert np.isfinite(stepdir).all()
                            shs = .5 * stepdir.dot(
                                fisher_vector_product(stepdir))
                            # abs(shs) to avoid taking square root of negative values
                            lagrange_multiplier = np.sqrt(
                                abs(shs) / self.max_kl)
                            # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
                            fullstep = stepdir / lagrange_multiplier
                            expectedimprove = grad.dot(fullstep)
                            surrbefore = lossbefore[0]
                            stepsize = 1.0
                            thbefore = self.get_flat()
                            thnew = None
                            for _ in range(10):
                                thnew = thbefore + fullstep * stepsize
                                self.set_from_flat(thnew)
                                mean_losses = surr, kl_loss, *_ = self.allmean(
                                    np.array(
                                        self.compute_losses(*args,
                                                            sess=self.sess)))
                                improve = surr - surrbefore
                                logger.log("Expected: %.3f Actual: %.3f" %
                                           (expectedimprove, improve))
                                if not np.isfinite(mean_losses).all():
                                    logger.log(
                                        "Got non-finite value of losses -- bad!"
                                    )
                                elif kl_loss > self.max_kl * 1.5:
                                    logger.log(
                                        "violated KL constraint. shrinking step."
                                    )
                                elif improve < 0:
                                    logger.log(
                                        "surrogate didn't improve. shrinking step."
                                    )
                                else:
                                    logger.log("Stepsize OK!")
                                    break
                                stepsize *= .5
                            else:
                                logger.log("couldn't compute a good step")
                                self.set_from_flat(thbefore)
                            if self.nworkers > 1 and iters_so_far % 20 == 0:
                                # list of tuples
                                paramsums = MPI.COMM_WORLD.allgather(
                                    (thnew.sum(), self.vfadam.getflat().sum()))
                                assert all(
                                    np.allclose(ps, paramsums[0])
                                    for ps in paramsums[1:])

                        with self.timed("vf"):
                            for _ in range(self.vf_iters):
                                for (mbob, mbret) in dataset.iterbatches(
                                    (seg["ob"], seg["tdlamret"]),
                                        include_final_partial_batch=False,
                                        batch_size=128):
                                    grad = self.allmean(
                                        self.compute_vflossandgrad(
                                            mbob, mbob, mbret, sess=self.sess))
                                    self.vfadam.update(grad, self.vf_stepsize)

                    for (loss_name, loss_val) in zip(self.loss_names,
                                                     mean_losses):
                        logger.record_tabular(loss_name, loss_val)

                    logger.record_tabular(
                        "ev_tdlam_before",
                        explained_variance(vpredbefore, tdlamret))

                    if self.using_gail:
                        # ------------------ Update D ------------------
                        logger.log("Optimizing Discriminator...")
                        logger.log(fmt_row(13, self.reward_giver.loss_name))
                        ob_expert, ac_expert = self.expert_dataset.get_next_batch(
                            len(observation))
                        batch_size = len(observation) // self.d_step
                        d_losses = [
                        ]  # list of tuples, each of which gives the loss for a minibatch
                        for ob_batch, ac_batch in dataset.iterbatches(
                            (observation, action),
                                include_final_partial_batch=False,
                                batch_size=batch_size):
                            ob_expert, ac_expert = self.expert_dataset.get_next_batch(
                                len(ob_batch))
                            # update running mean/std for reward_giver
                            if hasattr(self.reward_giver, "obs_rms"):
                                self.reward_giver.obs_rms.update(
                                    np.concatenate((ob_batch, ob_expert), 0))
                            *newlosses, grad = self.reward_giver.lossandgrad(
                                ob_batch, ac_batch, ob_expert, ac_expert)
                            self.d_adam.update(self.allmean(grad),
                                               self.d_stepsize)
                            d_losses.append(newlosses)
                        logger.log(fmt_row(13, np.mean(d_losses, axis=0)))

                        lrlocal = (seg["ep_lens"], seg["ep_rets"],
                                   seg["ep_true_rets"])  # local values
                        listoflrpairs = MPI.COMM_WORLD.allgather(
                            lrlocal)  # list of tuples
                        lens, rews, true_rets = map(flatten_lists,
                                                    zip(*listoflrpairs))
                        true_rewbuffer.extend(true_rets)
                    else:
                        lrlocal = (seg["ep_lens"], seg["ep_rets"]
                                   )  # local values
                        listoflrpairs = MPI.COMM_WORLD.allgather(
                            lrlocal)  # list of tuples
                        lens, rews = map(flatten_lists, zip(*listoflrpairs))
                    lenbuffer.extend(lens)
                    rewbuffer.extend(rews)

                    logger.record_tabular("EpLenMean", np.mean(lenbuffer))
                    logger.record_tabular("EpRewMean", np.mean(rewbuffer))
                    if self.using_gail:
                        logger.record_tabular("EpTrueRewMean",
                                              np.mean(true_rewbuffer))
                    logger.record_tabular("EpThisIter", len(lens))
                    episodes_so_far += len(lens)
                    timesteps_so_far += seg["total_timestep"]
                    iters_so_far += 1

                    logger.record_tabular("EpisodesSoFar", episodes_so_far)
                    logger.record_tabular("TimestepsSoFar", timesteps_so_far)
                    logger.record_tabular("TimeElapsed", time.time() - t_start)

                    if self.verbose >= 1 and self.rank == 0:
                        logger.dump_tabular()

        return self

    def predict(self, observation, state=None, mask=None):
        if state is None:
            state = self.initial_state
        if mask is None:
            mask = [False for _ in range(self.n_envs)]
        observation = np.array(observation).reshape(
            (-1, ) + self.observation_space.shape)

        actions, _, states, _ = self.step(observation, state, mask)
        return actions, states

    def action_probability(self, observation, state=None, mask=None):
        if state is None:
            state = self.initial_state
        if mask is None:
            mask = [False for _ in range(self.n_envs)]
        observation = np.array(observation).reshape(
            (-1, ) + self.observation_space.shape)

        return self.proba_step(observation, state, mask)

    def save(self, save_path):
        data = {
            "gamma": self.gamma,
            "timesteps_per_batch": self.timesteps_per_batch,
            "max_kl": self.max_kl,
            "cg_iters": self.cg_iters,
            "lam": self.lam,
            "entcoeff": self.entcoeff,
            "cg_damping": self.cg_damping,
            "vf_stepsize": self.vf_stepsize,
            "vf_iters": self.vf_iters,
            "pretrained_weight": self.pretrained_weight,
            "reward_giver": self.reward_giver,
            "expert_dataset": self.expert_dataset,
            "save_per_iter": self.save_per_iter,
            "checkpoint_dir": self.checkpoint_dir,
            "g_step": self.g_step,
            "d_step": self.d_step,
            "task_name": self.task_name,
            "d_stepsize": self.d_stepsize,
            "using_gail": self.using_gail,
            "verbose": self.verbose,
            "policy": self.policy,
            "observation_space": self.observation_space,
            "action_space": self.action_space,
            "n_envs": self.n_envs,
            "_vectorize_action": self._vectorize_action
        }

        params = self.sess.run(self.params)

        self._save_to_file(save_path, data=data, params=params)

    @classmethod
    def load(cls, load_path, env=None, **kwargs):
        data, params = cls._load_from_file(load_path)

        model = cls(policy=data["policy"], env=None, _init_setup_model=False)
        model.__dict__.update(data)
        model.__dict__.update(kwargs)
        model.set_env(env)
        model.setup_model()

        restores = []
        for param, loaded_p in zip(model.params, params):
            restores.append(param.assign(loaded_p))
        model.sess.run(restores)

        return model
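A minimal usage sketch for the TRPO class above, assuming it is the stable-baselines implementation and importable from the stable_baselines package; the environment id, hyperparameter values and save path below are illustrative assumptions, not taken from the example itself:

# Hedged usage sketch: 'CartPole-v1', the hyperparameters and the save path are
# assumptions; only the TRPO constructor/learn/save/load/predict API comes from the class above.
import gym
from stable_baselines import TRPO
from stable_baselines.common.policies import MlpPolicy

env = gym.make("CartPole-v1")
model = TRPO(MlpPolicy, env, timesteps_per_batch=1024, max_kl=0.01, verbose=1)
model.learn(total_timesteps=10000)

# save() stores the hyperparameters in `data` plus the session parameters;
# load() rebuilds the graph via setup_model() and then restores the weights.
model.save("trpo_cartpole")
loaded = TRPO.load("trpo_cartpole", env=env)
action, _states = loaded.predict(env.reset())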
Example no. 7
class DDPG(object):
    def __init__(self,
                 rank,
                 batch_size,
                 priority,
                 use_n_step=False,
                 n_step_return=5,
                 LAMBDA_BC=100,
                 LAMBDA_predict=0.5,
                 policy_delay=1,
                 use_TD3=False,
                 experiment_name='none',
                 Q_value_range=(-250, 250),
                 **kwargs):

        self.batch_size = batch_size
        self.use_prioritiy = priority
        self.n_step_return = n_step_return
        self.use_n_step = use_n_step
        self.LAMBDA_BC = LAMBDA_BC
        self.LAMBDA_predict = LAMBDA_predict
        self.use_TD3 = use_TD3
        self.experiment_name = experiment_name
        self.Q_value_range = Q_value_range  # clip Q values to this range to prevent overestimation
        self.demo_percent = []  # fraction of demo transitions in each sampled batch

        self.pointer = 0  # transition counter for the memory
        self.num_timesteps = 0  # steps for tensorboard
        self.lambda_1_step = 0.5  # weight of the 1-step return loss
        self.lambda_n_step = 0.5  # weight of the n-step return loss
        self.beta = 0.6

        # the actor is updated less frequently than the critic
        self.policy_delay_iterate = 0
        self.policy_delay = policy_delay

        self._setup_model(rank, **kwargs)

    def _setup_model(self, rank, memory_size, alpha, obs_space, a_space,
                     noise_target_action, **kwargs):

        self.graph = tf.Graph()
        with self.graph.as_default():
            self.sess = tf_util.single_threaded_session(graph=self.graph)
            if self.use_prioritiy:
                from .priority_memory import PrioritizedMemory
                self.memory = PrioritizedMemory(capacity=memory_size,
                                                alpha=alpha)
            else:
                from .memory import Memory
                self.memory = Memory(limit=memory_size,
                                     action_shape=a_space.shape,
                                     observation_shape=obs_space.shape)
            # define placeholders
            self.observe_Input = tf.placeholder(tf.float32,
                                                [None] + list(obs_space.shape),
                                                name='observe_Input')
            self.observe_Input_ = tf.placeholder(tf.float32, [None] +
                                                 list(obs_space.shape),
                                                 name='observe_Input_')
            self.R = tf.placeholder(tf.float32, [None, 1], 'r')
            self.terminals1 = tf.placeholder(tf.float32,
                                             shape=(None, 1),
                                             name='terminals1')
            self.ISWeights = tf.placeholder(tf.float32, [None, 1],
                                            name='IS_weights')
            self.n_step_steps = tf.placeholder(tf.float32,
                                               shape=(None, 1),
                                               name='n_step_reached')
            self.q_demo = tf.placeholder(tf.float32, [None, 1],
                                         name='Q_of_actions_from_memory')
            self.come_from_demo = tf.placeholder(tf.float32, [None, 1],
                                                 name='Demo_index')
            self.action_memory = tf.placeholder(tf.float32,
                                                [None] + list(a_space.shape),
                                                name='actions_from_memory')

            with tf.variable_scope('obs_rms'):
                self.obs_rms = RunningMeanStd(shape=obs_space.shape)

            with tf.name_scope('obs_preprocess'):
                self.normalized_observe_Input = tf.clip_by_value(
                    normalize(self.observe_Input, self.obs_rms), -5., 5.)
                self.normalized_observe_Input_ = tf.clip_by_value(
                    normalize(self.observe_Input_, self.obs_rms), -5., 5.)

            with tf.variable_scope('Actor'):
                self.action = self.build_actor(self.normalized_observe_Input,
                                               scope='eval',
                                               trainable=True,
                                               a_space=a_space)
                self.action_ = self.build_actor(self.normalized_observe_Input_,
                                                scope='target',
                                                trainable=False,
                                                a_space=a_space)

                # Target policy smoothing, by adding clipped noise to target actions
                if noise_target_action:
                    epsilon = tf.random_normal(tf.shape(self.action_),
                                               stddev=0.007)
                    epsilon = tf.clip_by_value(epsilon, -0.01, 0.01)
                    a2 = self.action_ + epsilon
                    noised_action_ = tf.clip_by_value(a2, -1, 1)
                else:
                    noised_action_ = self.action_

            with tf.variable_scope('Critic'):
                # clip all Q values to prevent overestimation
                self.q_1 = tf.clip_by_value(
                    self.build_critic(self.normalized_observe_Input,
                                      self.action,
                                      scope='eval_1',
                                      trainable=True), self.Q_value_range[0],
                    self.Q_value_range[1])

                q_1_ = self.build_critic(self.normalized_observe_Input_,
                                         noised_action_,
                                         scope='target_1',
                                         trainable=False)

                if self.use_TD3:
                    q_2 = tf.clip_by_value(
                        self.build_critic(self.normalized_observe_Input,
                                          self.action,
                                          scope='eval_2',
                                          trainable=True),
                        self.Q_value_range[0], self.Q_value_range[1])

                    q_2_ = self.build_critic(self.normalized_observe_Input_,
                                             noised_action_,
                                             scope='target_2',
                                             trainable=False)

            # Collect network parameters; this makes them easier to manage.
            self.ae_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                               scope='Actor/eval')
            self.at_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                               scope='Actor/target')
            self.ce1_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                                scope='Critic/eval_1')
            self.ct1_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                                scope='Critic/target_1')

            if self.use_TD3:
                self.ce2_params = tf.get_collection(
                    tf.GraphKeys.GLOBAL_VARIABLES, scope='Critic/eval_2')
                self.ct2_params = tf.get_collection(
                    tf.GraphKeys.GLOBAL_VARIABLES, scope='Critic/target_2')

            with tf.variable_scope('Soft_Update'):
                self.soft_replace_a = [
                    tf.assign(t, (1 - TAU) * t + TAU * e)
                    for t, e in zip(self.at_params, self.ae_params)
                ]
                self.soft_replace_c = [
                    tf.assign(t, (1 - TAU) * t + TAU * e)
                    for t, e in zip(self.ct1_params, self.ce1_params)
                ]
                if self.use_TD3:
                    self.soft_replace_c += [
                        tf.assign(t, (1 - TAU) * t + TAU * e)
                        for t, e in zip(self.ct2_params, self.ce2_params)
                    ]

            # The critic loss is (one-step TD error + n-step TD error + L2 penalty on the online critic).
            # With TD3 there are four critic networks in total, so two critic losses are computed.
            with tf.variable_scope('Critic_Lose'):
                if self.use_TD3:
                    min_q_ = tf.minimum(q_1_, q_2_)
                else:
                    min_q_ = q_1_

                self.q_target = self.R + (1. -
                                          self.terminals1) * GAMMA * min_q_
                if self.use_n_step:
                    self.n_step_target_q = self.R + (
                        1. - self.terminals1) * tf.pow(
                            GAMMA, self.n_step_steps) * min_q_
                    cliped_n_step_target_q = tf.clip_by_value(
                        self.n_step_target_q, self.Q_value_range[0],
                        self.Q_value_range[1])

                cliped_q_target = tf.clip_by_value(self.q_target,
                                                   self.Q_value_range[0],
                                                   self.Q_value_range[1])

                self.td_error_1 = tf.abs(cliped_q_target - self.q_1)
                if self.use_TD3:
                    self.td_error_2 = tf.abs(cliped_q_target - q_2)

                if self.use_n_step:
                    self.nstep_td_error_1 = tf.abs(cliped_n_step_target_q -
                                                   self.q_1)
                    if self.use_TD3:
                        self.nstep_td_error_2 = tf.abs(cliped_n_step_target_q -
                                                       q_2)

                L2_regular_1 = tf.contrib.layers.apply_regularization(
                    tf.contrib.layers.l2_regularizer(0.001),
                    weights_list=self.ce1_params)
                if self.use_TD3:
                    L2_regular_2 = tf.contrib.layers.apply_regularization(
                        tf.contrib.layers.l2_regularizer(0.001),
                        weights_list=self.ce2_params)

                one_step_losse_1 = tf.reduce_mean(
                    tf.multiply(self.ISWeights, tf.square(
                        self.td_error_1))) * self.lambda_1_step
                if self.use_TD3:
                    one_step_losse_2 = tf.reduce_mean(
                        tf.multiply(self.ISWeights, tf.square(
                            self.td_error_2))) * self.lambda_1_step

                if self.use_n_step:
                    n_step_td_losses_1 = tf.reduce_mean(
                        tf.multiply(
                            self.ISWeights, tf.square(
                                self.nstep_td_error_1))) * self.lambda_n_step
                    c_loss_1 = one_step_losse_1 + n_step_td_losses_1 + L2_regular_1

                    if self.use_TD3:
                        n_step_td_losses_2 = tf.reduce_mean(
                            tf.multiply(self.ISWeights,
                                        tf.square(self.nstep_td_error_2))
                        ) * self.lambda_n_step
                        c_loss_2 = one_step_losse_2 + n_step_td_losses_2 + L2_regular_2
                else:
                    c_loss_1 = one_step_losse_1 + L2_regular_1

                    if self.use_TD3:
                        c_loss_2 = one_step_losse_2 + L2_regular_2

            # The actor loss maximizes q(s, a) while minimizing the behavior-cloning error.
            # (The cloning term only applies to demo transitions whose demo action has a higher
            #  q_1(s, a) than the action generated by the actor.)
            with tf.variable_scope('Actor_lose'):
                Is_worse_than_demo = self.q_1 < self.q_demo
                Is_worse_than_demo = tf.cast(Is_worse_than_demo, tf.float32)
                worse_than_demo = tf.cast(tf.reduce_sum(Is_worse_than_demo),
                                          tf.int8)

                # The action error is a sum of squares; a mean (reduce_mean) would also work.
                # The actions here are already small numbers.
                action_diffs = Is_worse_than_demo * tf.reduce_sum(
                    self.come_from_demo *
                    tf.square(self.action - self.action_memory),
                    1,
                    keepdims=True)

                L_BC = self.LAMBDA_BC * tf.reduce_sum(action_diffs)
                a_loss = -tf.reduce_mean(self.q_1) + L_BC

            # Setting optimizer for Actor and Critic
            with tf.variable_scope('Critic_Optimizer'):
                if self.use_TD3:
                    self.critic_grads_1 = tf_util.flatgrad(
                        loss=c_loss_1, var_list=self.ce1_params)
                    self.critic_grads_2 = tf_util.flatgrad(
                        loss=c_loss_2, var_list=self.ce2_params)

                    self.critic_optimizer_1 = MpiAdam(var_list=self.ce1_params,
                                                      beta1=0.9,
                                                      beta2=0.999,
                                                      epsilon=1e-08)
                    self.critic_optimizer_2 = MpiAdam(var_list=self.ce2_params,
                                                      beta1=0.9,
                                                      beta2=0.999,
                                                      epsilon=1e-08)
                else:
                    self.critic_grads = tf_util.flatgrad(
                        loss=c_loss_1, var_list=self.ce1_params)
                    self.critic_optimizer = MpiAdam(var_list=self.ce1_params,
                                                    beta1=0.9,
                                                    beta2=0.999,
                                                    epsilon=1e-08)
            with tf.variable_scope('Actor_Optimizer'):
                self.actor_grads = tf_util.flatgrad(a_loss, self.ae_params)
                self.actor_optimizer = MpiAdam(var_list=self.ae_params,
                                               beta1=0.9,
                                               beta2=0.999,
                                               epsilon=1e-08)
            with self.sess.as_default():
                self._initialize(self.sess)

            # model saving
            var_list = tf.global_variables()
            print("Saver var_list:")
            for v in var_list:
                print(v)
            self.saver = tf.train.Saver(var_list=var_list, max_to_keep=1)
            self.writer = tf.summary.FileWriter(
                "logs/" + self.experiment_name + "/DDPG_" + str(rank),
                self.graph)
            # TensorBoard summary
            self.a_summary = tf.summary.merge([
                tf.summary.scalar('a_loss', a_loss, family='actor'),
                tf.summary.scalar('L_BC', L_BC, family='actor'),
                tf.summary.scalar('worse_than_demo',
                                  worse_than_demo,
                                  family='actor')
            ])

            if self.use_TD3:
                self.c_summary = tf.summary.merge([
                    tf.summary.scalar('c_loss_1', c_loss_1, family='critic'),
                    tf.summary.scalar('c_loss_2', c_loss_2, family='critic')
                ])
            else:
                self.c_summary = tf.summary.merge(
                    [tf.summary.scalar('c_loss_1', c_loss_1, family='critic')])

            # episode summary
            self.episode_cumulate_reward = tf.placeholder(
                tf.float32, name='episode_cumulate_reward')
            self.episoed_length = tf.placeholder(
                tf.int16, name='episoed_length')
            self.success_or_not = tf.placeholder(
                tf.int8, name='success_or_not')

            self.eval_episode_cumulate_reward = tf.placeholder(
                tf.float32, name='eval_episode_cumulate_reward')
            self.eval_episoed_length = tf.placeholder(
                tf.int16, name='eval_episoed_length')
            self.eval_success_or_not = tf.placeholder(
                tf.int8, name='eval_success_or_not')

            self.episode_summary = tf.summary.merge([
                tf.summary.scalar('episode_cumulate_reward',
                                  self.episode_cumulate_reward,
                                  family='episoed_result'),
                tf.summary.scalar('episoed_length',
                                  self.episoed_length,
                                  family='episoed_result'),
                tf.summary.scalar('success_or_not',
                                  self.success_or_not,
                                  family='episoed_result'),
            ])

            self.eval_episode_summary = tf.summary.merge([
                tf.summary.scalar('eval_episode_cumulate_reward',
                                  self.eval_episode_cumulate_reward,
                                  family='Eval_episoed_result'),
                tf.summary.scalar('eval_episoed_length',
                                  self.eval_episoed_length,
                                  family='Eval_episoed_result'),
                tf.summary.scalar('eval_success_or_not',
                                  self.eval_success_or_not,
                                  family='Eval_episoed_result'),
            ])

    def _initialize(self, sess):
        """
        initialize the model parameters and optimizers

        :param sess: (TensorFlow Session) the current TensorFlow session
        """
        self.sess = sess
        self.sess.run(tf.global_variables_initializer())
        self.actor_optimizer.sync()
        self.critic_optimizer.sync()

        # initialize the target networks with the evaluation networks' parameters
        init_a_t = [
            tf.assign(t, e) for t, e in zip(self.at_params, self.ae_params)
        ]
        init_c_t = [
            tf.assign(t, e) for t, e in zip(self.ct1_params, self.ce1_params)
        ]
        if self.use_TD3:
            init_c_t += [
                tf.assign(t, e)
                for t, e in zip(self.ct2_params, self.ce2_params)
            ]
        self.sess.run([init_a_t, init_c_t])

    def pi(self, obs):
        obs = obs.astype(dtype=np.float32)
        return self.sess.run(self.action,
                             {self.observe_Input: obs[np.newaxis, :]})[0]

    def Save(self):
        # save only the weights, not the computation graph
        self.saver.save(self.sess,
                        save_path="model/" + self.experiment_name +
                        "/model.ckpt")

    def load(self):
        self.saver.restore(self.sess,
                           save_path="model/" + self.experiment_name +
                           "/model.ckpt")

    def _init_num_timesteps(self, reset_num_timesteps=True):
        """
        Initialize and reset num_timesteps (total timesteps since the beginning of training)
        if needed. Mainly used for logging and plotting (tensorboard).

        :param reset_num_timesteps: (bool) Set it to false when continuing training
            to not create new plotting curves in tensorboard.
        :return: (bool) Whether a new tensorboard log needs to be created
        """
        if reset_num_timesteps:
            self.num_timesteps = 0

        new_tb_log = self.num_timesteps == 0
        return new_tb_log

    def save_episoed_result(self, epi_cumulate_reward, episoed_length,
                            success_or_not, episodes):
        s = self.sess.run(self.episode_summary,
                          feed_dict={
                              self.episode_cumulate_reward:
                              epi_cumulate_reward,
                              self.episoed_length: episoed_length,
                              self.success_or_not: success_or_not
                          })

        self.writer.add_summary(s, global_step=int(episodes))

    def save_eval_episoed_result(self, eval_epi_reward, eval_episoed_length,
                                 eval_success_or_not, eval_episodes):
        eval_s = self.sess.run(self.eval_episode_summary,
                               feed_dict={
                                   self.eval_episode_cumulate_reward:
                                   eval_epi_reward,
                                   self.eval_episoed_length:
                                   eval_episoed_length,
                                   self.eval_success_or_not:
                                   eval_success_or_not
                               })

        self.writer.add_summary(eval_s, global_step=int(eval_episodes))

    def learn(self, learn_step):
        if self.use_prioritiy:
            batch, n_step_batch, percentage = self.memory.sample_rollout(
                batch_size=self.batch_size,
                nsteps=self.n_step_return,
                beta=self.beta,
                gamma=GAMMA)
            self.demo_percent.append(float(percentage))
        else:
            batch = self.memory.sample(self.batch_size)

        one_step_target_q = self.sess.run(
            self.q_target,
            feed_dict={
                self.observe_Input_: batch['obs1'],  # low dim input
                self.R: batch['rewards'],
                self.terminals1: batch['terminals1']
            })

        if self.use_TD3:
            opt = [
                self.td_error_1, self.td_error_2, self.critic_grads_1,
                self.critic_grads_2, self.c_summary, self.q_1
            ]
        else:
            opt = [
                self.td_error_1, self.critic_grads, self.c_summary, self.q_1
            ]

        if self.use_prioritiy and self.use_n_step:
            n_step_target_q = self.sess.run(self.n_step_target_q,
                                            feed_dict={
                                                self.terminals1:
                                                n_step_batch["terminals1"],
                                                self.n_step_steps:
                                                n_step_batch["step_reached"],
                                                self.R:
                                                n_step_batch['rewards'],
                                                self.observe_Input_:
                                                n_step_batch['obs1']
                                            })
            res = self.sess.run(opt,
                                feed_dict={
                                    self.observe_Input: batch['obs0'],
                                    self.q_target: one_step_target_q,
                                    self.n_step_target_q: n_step_target_q,
                                    self.action: batch['actions'],
                                    self.ISWeights: batch['weights']
                                })
        else:
            res = self.sess.run(opt,
                                feed_dict={
                                    self.observe_Input: batch['obs0'],
                                    self.q_target: one_step_target_q,
                                    self.action: batch['actions'],
                                    self.ISWeights: batch['weights']
                                })

        # critic update
        if self.use_TD3:
            td_error_1, td_error_2, critic_grads_1, critic_grads_2, c_s, q_demo = res
            td_error = (td_error_1 + td_error_2) / 2.0

            self.critic_optimizer_1.update(critic_grads_1, learning_rate=LR_C)
            self.critic_optimizer_2.update(critic_grads_2, learning_rate=LR_C)
        else:
            td_error, critic_grads, c_s, q_demo = res
            self.critic_optimizer.update(critic_grads, LR_C)
        self.sess.run(self.soft_replace_c)

        # actor update
        if self.policy_delay_iterate % self.policy_delay == 0:
            actor_grads, a_s, = self.sess.run(
                [self.actor_grads, self.a_summary], {
                    self.observe_Input: batch['obs0'],
                    self.q_demo: q_demo,
                    self.come_from_demo: batch['demos'],
                    self.action_memory: batch['actions']
                })

            self.actor_optimizer.update(actor_grads, LR_A)
            self.sess.run(self.soft_replace_a)
            self.writer.add_summary(a_s, learn_step)

        # update priority
        if self.use_prioritiy:
            self.memory.update_priorities(batch['idxes'], td_error)

        self.writer.add_summary(c_s, learn_step)
        self.policy_delay_iterate += 1

    def store_transition(self,
                         obs0,
                         action,
                         reward,
                         obs1,
                         terminal1,
                         demo=False):
        obs0 = obs0.astype(np.float32)
        obs1 = obs1.astype(np.float32)
        if demo:
            self.memory.append_demo(obs0=obs0,
                                    action=action,
                                    reward=reward,
                                    obs1=obs1,
                                    terminal1=terminal1)
        else:
            self.memory.append(obs0=obs0,
                               action=action,
                               reward=reward,
                               obs1=obs1,
                               terminal1=terminal1)

        # incrementally update the observation running mean/std
        self.obs_rms.update(np.array([obs0]))
        self.obs_rms.update(np.array([obs1]))

        self.pointer += 1

    def build_actor(self, observe_input, scope, trainable, a_space):
        fc_a = partial(tf.layers.dense, activation=None, trainable=trainable)
        conv2_a = partial(conv2_, trainable=trainable)
        relu = partial(tf.nn.relu)

        with tf.variable_scope(scope):
            net = tf.layers.conv2d(observe_input,
                                   filters=64,
                                   kernel_size=3,
                                   activation=tf.nn.relu,
                                   strides=2,
                                   padding='valid',
                                   name='conv2_1',
                                   trainable=trainable)
            net = tf.layers.max_pooling2d(net, pool_size=2, strides=2)
            net = tf.layers.conv2d(net,
                                   filters=128,
                                   kernel_size=4,
                                   activation=tf.nn.relu,
                                   trainable=trainable,
                                   strides=1,
                                   padding='valid',
                                   name='conv2_2')
            net = tf.layers.max_pooling2d(net, pool_size=2, strides=2)
            net = tf.layers.conv2d(net,
                                   filters=64,
                                   kernel_size=3,
                                   name='conv2_3',
                                   activation=tf.nn.relu,
                                   trainable=trainable,
                                   strides=2,
                                   padding='valid')
            net_max = tf.layers.max_pooling2d(net, pool_size=2, strides=2)

            net = tf.layers.conv2d(net_max,
                                   filters=96,
                                   kernel_size=2,
                                   name='conv2_4',
                                   activation=tf.nn.relu,
                                   trainable=trainable,
                                   strides=2,
                                   padding='valid')

            net = tf.layers.flatten(net, name='cnn_flatten')

            # concatenate the conv output with the earlier max-pooled features
            net = tf.concat([net, tf.layers.flatten(net_max)], axis=1)
            net = tf.layers.flatten(net)
            # conv -> relu

            net = relu(fc_a(net, 128))
            action_output = fc_a(
                net,
                a_space.shape[0],
                activation=tf.nn.tanh,
                kernel_initializer=tf.initializers.random_uniform(
                    minval=-0.0003, maxval=0.0003))
            # output shape (1, 4)

            return action_output

    def build_critic(self, observe_input, a, scope, trainable):
        relu = partial(tf.nn.relu)
        conv2_a = partial(conv2_, trainable=trainable)
        fc_c = partial(tf.layers.dense, activation=None, trainable=trainable)
        with tf.variable_scope(scope):

            net = relu(conv2_a(observe_input, 32, 7, 4))
            net = relu(conv2_a(net, 64, 5, 2))
            net = relu(conv2_a(net, 64, 3, 2))
            net = relu(conv2_a(net, 64, 3, 1))

            net = tf.layers.flatten(net)
            net = tf.concat([net, a], axis=1)
            net = relu(fc_c(net, 128))
            net = relu(fc_c(net, 128))

            q = fc_c(net,
                     1,
                     kernel_initializer=tf.initializers.random_uniform(
                         minval=-0.0003, maxval=0.0003))
            # Q(s, a): outputs a tensor of shape [None, 1]
            return q
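A hedged driver sketch for the DDPG/TD3 class above. The environment, observation/action spaces, and the module-level constants GAMMA, TAU, LR_A and LR_C that the class references are assumed to be provided by the surrounding project; priority=True is assumed because learn() reads 'weights' and 'demos' from the sampled batch:

# Hedged driver sketch: `env`, `obs_space`, `a_space` and the constructor keyword
# values are assumptions; only the DDPG method names and signatures come from above.
def run_training(env, obs_space, a_space, total_steps=10000, warmup=1000):
    agent = DDPG(rank=0, batch_size=64, priority=True, use_TD3=True,
                 memory_size=int(1e5), alpha=0.6,
                 obs_space=obs_space, a_space=a_space,
                 noise_target_action=True)
    obs = env.reset()  # image observation matching obs_space.shape (conv actor/critic)
    for step in range(total_steps):
        action = agent.pi(obs)                       # deterministic actor output in [-1, 1]
        new_obs, reward, done, _ = env.step(action)
        agent.store_transition(obs, action, reward, new_obs, float(done))
        if step > warmup:
            agent.learn(learn_step=step)             # critic update + delayed actor update
        obs = env.reset() if done else new_obs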
def learn(env,
          policy_func,
          dataset,
          optim_batch_size=128,
          max_iters=1e4,
          adam_epsilon=1e-5,
          optim_stepsize=3e-4,
          ckpt_dir=None,
          task_name=None,
          verbose=False):
    """
    Learn a behavior clone policy, and return the save location

    :param env: (Gym Environment) the environment
    :param policy_func: (function (str, Gym Space, Gym Space): TensorFlow Tensor) creates the policy
    :param dataset: (Dset or MujocoDset) the dataset manager
    :param optim_batch_size: (int) the batch size
    :param max_iters: (int) the maximum number of iterations
    :param adam_epsilon: (float) the epsilon value for the adam optimizer
    :param optim_stepsize: (float) the optimizer stepsize
    :param ckpt_dir: (str) the save directory, can be None for temporary directory
    :param task_name: (str) the save name, can be None for saving directly to the directory name
    :param verbose: (bool) whether to log training and validation losses
    :return: (str) the save location for the TensorFlow model
    """

    val_per_iter = int(max_iters / 10)
    ob_space = env.observation_space
    ac_space = env.action_space
    policy = policy_func("pi", ob_space,
                         ac_space)  # Construct network for new policy
    # placeholder
    obs_ph = policy.obs_ph
    action_ph = policy.pdtype.sample_placeholder([None])
    stochastic_ph = policy.stochastic_ph
    loss = tf.reduce_mean(tf.square(action_ph - policy.ac))
    var_list = policy.get_trainable_variables()
    adam = MpiAdam(var_list, epsilon=adam_epsilon)
    lossandgrad = tf_util.function([obs_ph, action_ph, stochastic_ph],
                                   [loss] + [tf_util.flatgrad(loss, var_list)])

    tf_util.initialize()
    adam.sync()
    logger.log("Pretraining with Behavior Cloning...")
    for iter_so_far in tqdm(range(int(max_iters))):
        ob_expert, ac_expert = dataset.get_next_batch(optim_batch_size,
                                                      'train')
        train_loss, grad = lossandgrad(ob_expert, ac_expert, True)
        adam.update(grad, optim_stepsize)
        if verbose and iter_so_far % val_per_iter == 0:
            ob_expert, ac_expert = dataset.get_next_batch(-1, 'val')
            val_loss, _ = lossandgrad(ob_expert, ac_expert, True)
            logger.log("Training loss: {}, Validation loss: {}".format(
                train_loss, val_loss))

    if ckpt_dir is None:
        savedir_fname = tempfile.TemporaryDirectory().name
    else:
        savedir_fname = os.path.join(ckpt_dir, task_name)
    tf_util.save_state(savedir_fname, var_list=policy.get_variables())
    return savedir_fname
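A hedged example of calling the behavior-cloning routine above; the policy factory and dataset objects are placeholders for whatever the surrounding project provides (e.g. an expert-trajectory dataset exposing get_next_batch(batch_size, split)):

# Hedged usage sketch: `make_policy` and `expert_dataset` are placeholders;
# only the learn() signature and return value come from the function above.
def pretrain_with_bc(env, make_policy, expert_dataset):
    ckpt_path = learn(env,
                      policy_func=make_policy,
                      dataset=expert_dataset,
                      optim_batch_size=128,
                      max_iters=int(1e4),
                      ckpt_dir="checkpoints/bc",
                      task_name="bc_pretrain",
                      verbose=True)
    # The returned path can later be passed to tf_util.load_state() to warm-start a policy.
    return ckpt_path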
Example no. 9
class PPO1(BaseRLModel):
    """
    Proximal Policy Optimization algorithm (MPI version).
    Paper: https://arxiv.org/abs/1707.06347

    :param policy: (ActorCriticPolicy or str) the policy model to use
    :param env: (Gym environment or str) the environment to learn from (if registered in Gym, can be str)
    :param timesteps_per_actorbatch: (int) timesteps per actor per update
    :param clip_param: (float) clipping parameter epsilon
    :param entcoeff: (float) the entropy loss weight
    :param optim_epochs: (float) the optimizer's number of epochs
    :param optim_stepsize: (float) the optimizer's stepsize
    :param optim_batchsize: (int) the optimizer's batch size
    :param gamma: (float) discount factor
    :param lam: (float) the GAE factor for advantage estimation
    :param adam_epsilon: (float) the epsilon value for the adam optimizer
    :param schedule: (str) The type of scheduler for the learning rate update ('linear', 'constant',
        'double_linear_con', 'middle_drop' or 'double_middle_drop')
    :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
    :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
    """
    def __init__(self,
                 policy,
                 env,
                 gamma=0.99,
                 timesteps_per_actorbatch=256,
                 clip_param=0.2,
                 entcoeff=0.01,
                 optim_epochs=4,
                 optim_stepsize=1e-3,
                 optim_batchsize=64,
                 lam=0.95,
                 adam_epsilon=1e-5,
                 schedule='linear',
                 verbose=0,
                 _init_setup_model=True):
        super().__init__(policy=policy,
                         env=env,
                         requires_vec_env=False,
                         verbose=verbose)

        self.gamma = gamma
        self.timesteps_per_actorbatch = timesteps_per_actorbatch
        self.clip_param = clip_param
        self.entcoeff = entcoeff
        self.optim_epochs = optim_epochs
        self.optim_stepsize = optim_stepsize
        self.optim_batchsize = optim_batchsize
        self.lam = lam
        self.adam_epsilon = adam_epsilon
        self.schedule = schedule

        self.graph = None
        self.sess = None
        self.policy_pi = None
        self.loss_names = None
        self.lossandgrad = None
        self.adam = None
        self.assign_old_eq_new = None
        self.compute_losses = None
        self.params = None
        self.step = None
        self.proba_step = None
        self.initial_state = None

        if _init_setup_model:
            self.setup_model()

    def setup_model(self):
        with SetVerbosity(self.verbose):

            self.graph = tf.Graph()
            with self.graph.as_default():
                self.sess = tf_util.single_threaded_session(graph=self.graph)

                # Construct network for new policy
                with tf.variable_scope("pi", reuse=False):
                    self.policy_pi = self.policy(self.sess,
                                                 self.observation_space,
                                                 self.action_space,
                                                 self.n_envs,
                                                 1,
                                                 None,
                                                 reuse=False)

                # Network for old policy
                with tf.variable_scope("oldpi", reuse=False):
                    old_pi = self.policy(self.sess,
                                         self.observation_space,
                                         self.action_space,
                                         self.n_envs,
                                         1,
                                         None,
                                         reuse=False)

                # Target advantage function (if applicable)
                atarg = tf.placeholder(dtype=tf.float32, shape=[None])

                # Empirical return
                ret = tf.placeholder(dtype=tf.float32, shape=[None])

                # learning rate multiplier, updated with schedule
                lrmult = tf.placeholder(name='lrmult',
                                        dtype=tf.float32,
                                        shape=[])

                # Annealed clipping parameter epsilon
                clip_param = self.clip_param * lrmult

                obs_ph = self.policy_pi.obs_ph
                action_ph = self.policy_pi.pdtype.sample_placeholder([None])

                kloldnew = old_pi.proba_distribution.kl(
                    self.policy_pi.proba_distribution)
                ent = self.policy_pi.proba_distribution.entropy()
                meankl = tf.reduce_mean(kloldnew)
                meanent = tf.reduce_mean(ent)
                pol_entpen = (-self.entcoeff) * meanent

                # pnew / pold
                ratio = tf.exp(
                    self.policy_pi.proba_distribution.logp(action_ph) -
                    old_pi.proba_distribution.logp(action_ph))

                # surrogate from conservative policy iteration
                surr1 = ratio * atarg
                surr2 = tf.clip_by_value(ratio, 1.0 - clip_param,
                                         1.0 + clip_param) * atarg

                # PPO's pessimistic surrogate (L^CLIP)
                pol_surr = -tf.reduce_mean(tf.minimum(surr1, surr2))
                vf_loss = tf.reduce_mean(
                    tf.square(self.policy_pi.value_fn[:, 0] - ret))
                total_loss = pol_surr + pol_entpen + vf_loss
                losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
                self.loss_names = [
                    "pol_surr", "pol_entpen", "vf_loss", "kl", "ent"
                ]

                self.params = tf_util.get_trainable_vars("pi")
                self.lossandgrad = tf_util.function(
                    [obs_ph, old_pi.obs_ph, action_ph, atarg, ret, lrmult],
                    losses + [tf_util.flatgrad(total_loss, self.params)])
                self.adam = MpiAdam(self.params,
                                    epsilon=self.adam_epsilon,
                                    sess=self.sess)

                self.assign_old_eq_new = tf_util.function(
                    [], [],
                    updates=[
                        tf.assign(oldv, newv) for (
                            oldv,
                            newv) in zipsame(tf_util.get_globals_vars("oldpi"),
                                             tf_util.get_globals_vars("pi"))
                    ])
                self.compute_losses = tf_util.function(
                    [obs_ph, old_pi.obs_ph, action_ph, atarg, ret, lrmult],
                    losses)

                self.step = self.policy_pi.step
                self.proba_step = self.policy_pi.proba_step
                self.initial_state = self.policy_pi.initial_state

                tf_util.initialize(sess=self.sess)

    def learn(self,
              total_timesteps,
              callback=None,
              seed=None,
              log_interval=100):
        with SetVerbosity(self.verbose):
            self._setup_learn(seed)

            with self.sess.as_default():
                self.adam.sync()

                # Prepare for rollouts
                seg_gen = traj_segment_generator(self.policy_pi, self.env,
                                                 self.timesteps_per_actorbatch)

                episodes_so_far = 0
                timesteps_so_far = 0
                iters_so_far = 0
                t_start = time.time()

                # rolling buffer for episode lengths
                lenbuffer = deque(maxlen=100)
                # rolling buffer for episode rewards
                rewbuffer = deque(maxlen=100)

                while True:
                    if callback:
                        callback(locals(), globals())
                    if total_timesteps and timesteps_so_far >= total_timesteps:
                        break

                    if self.schedule == 'constant':
                        cur_lrmult = 1.0
                    elif self.schedule == 'linear':
                        cur_lrmult = max(
                            1.0 - float(timesteps_so_far) / total_timesteps, 0)
                    else:
                        raise NotImplementedError

                    logger.log("********** Iteration %i ************" %
                               iters_so_far)

                    seg = seg_gen.__next__()
                    add_vtarg_and_adv(seg, self.gamma, self.lam)

                    # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
                    obs_ph, action_ph, atarg, tdlamret = seg["ob"], seg[
                        "ac"], seg["adv"], seg["tdlamret"]

                    # predicted value function before update
                    vpredbefore = seg["vpred"]

                    # standardized advantage function estimate
                    atarg = (atarg - atarg.mean()) / atarg.std()
                    dataset = Dataset(
                        dict(ob=obs_ph,
                             ac=action_ph,
                             atarg=atarg,
                             vtarg=tdlamret),
                        shuffle=not issubclass(self.policy, LstmPolicy))
                    optim_batchsize = self.optim_batchsize or obs_ph.shape[0]

                    # set old parameter values to new parameter values
                    self.assign_old_eq_new(sess=self.sess)
                    logger.log("Optimizing...")
                    logger.log(fmt_row(13, self.loss_names))

                    # Here we do a bunch of optimization epochs over the data
                    for _ in range(self.optim_epochs):
                        # list of tuples, each of which gives the loss for a minibatch
                        losses = []
                        for batch in dataset.iterate_once(optim_batchsize):
                            *newlosses, grad = self.lossandgrad(batch["ob"],
                                                                batch["ob"],
                                                                batch["ac"],
                                                                batch["atarg"],
                                                                batch["vtarg"],
                                                                cur_lrmult,
                                                                sess=self.sess)
                            self.adam.update(grad,
                                             self.optim_stepsize * cur_lrmult)
                            losses.append(newlosses)
                        logger.log(fmt_row(13, np.mean(losses, axis=0)))

                    logger.log("Evaluating losses...")
                    losses = []
                    for batch in dataset.iterate_once(optim_batchsize):
                        newlosses = self.compute_losses(batch["ob"],
                                                        batch["ob"],
                                                        batch["ac"],
                                                        batch["atarg"],
                                                        batch["vtarg"],
                                                        cur_lrmult,
                                                        sess=self.sess)
                        losses.append(newlosses)
                    mean_losses, _, _ = mpi_moments(losses, axis=0)
                    logger.log(fmt_row(13, mean_losses))
                    for (loss_val, name) in zipsame(mean_losses,
                                                    self.loss_names):
                        logger.record_tabular("loss_" + name, loss_val)
                    logger.record_tabular(
                        "ev_tdlam_before",
                        explained_variance(vpredbefore, tdlamret))

                    # local values
                    lrlocal = (seg["ep_lens"], seg["ep_rets"])

                    # list of tuples
                    listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)
                    lens, rews = map(flatten_lists, zip(*listoflrpairs))
                    lenbuffer.extend(lens)
                    rewbuffer.extend(rews)
                    logger.record_tabular("EpLenMean", np.mean(lenbuffer))
                    logger.record_tabular("EpRewMean", np.mean(rewbuffer))
                    logger.record_tabular("EpThisIter", len(lens))
                    episodes_so_far += len(lens)
                    timesteps_so_far += seg["total_timestep"]
                    iters_so_far += 1
                    logger.record_tabular("EpisodesSoFar", episodes_so_far)
                    logger.record_tabular("TimestepsSoFar", timesteps_so_far)
                    logger.record_tabular("TimeElapsed", time.time() - t_start)
                    if self.verbose >= 1 and MPI.COMM_WORLD.Get_rank() == 0:
                        logger.dump_tabular()

        return self

    def predict(self, observation, state=None, mask=None):
        if state is None:
            state = self.initial_state
        if mask is None:
            mask = [False for _ in range(self.n_envs)]
        observation = np.array(observation).reshape(
            (-1, ) + self.observation_space.shape)

        actions, _, states, _ = self.step(observation, state, mask)
        return actions, states

    def action_probability(self, observation, state=None, mask=None):
        if state is None:
            state = self.initial_state
        if mask is None:
            mask = [False for _ in range(self.n_envs)]
        observation = np.array(observation).reshape(
            (-1, ) + self.observation_space.shape)

        return self.proba_step(observation, state, mask)

    def save(self, save_path):
        data = {
            "gamma": self.gamma,
            "timesteps_per_actorbatch": self.timesteps_per_actorbatch,
            "clip_param": self.clip_param,
            "entcoeff": self.entcoeff,
            "optim_epochs": self.optim_epochs,
            "optim_stepsize": self.optim_stepsize,
            "optim_batchsize": self.optim_batchsize,
            "lam": self.lam,
            "adam_epsilon": self.adam_epsilon,
            "schedule": self.schedule,
            "verbose": self.verbose,
            "policy": self.policy,
            "observation_space": self.observation_space,
            "action_space": self.action_space,
            "n_envs": self.n_envs,
            "_vectorize_action": self._vectorize_action
        }

        params = self.sess.run(self.params)

        self._save_to_file(save_path, data=data, params=params)

    @classmethod
    def load(cls, load_path, env=None, **kwargs):
        data, params = cls._load_from_file(load_path)

        model = cls(None, env=None, _init_setup_model=False)
        model.__dict__.update(data)
        model.__dict__.update(kwargs)
        model.set_env(env)
        model.setup_model()

        restores = []
        for param, loaded_p in zip(model.params, params):
            restores.append(param.assign(loaded_p))
        model.sess.run(restores)

        return model
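As with TRPO above, a minimal hedged usage sketch for PPO1, assuming it is the stable-baselines MPI implementation (launching several copies under mpirun averages gradients across workers). The environment id and hyperparameters below are illustrative assumptions:

# Hedged usage sketch: 'Pendulum-v0' and the hyperparameters are assumptions;
# only the PPO1 constructor/learn/save API comes from the class above.
import gym
from stable_baselines import PPO1
from stable_baselines.common.policies import MlpPolicy

env = gym.make("Pendulum-v0")
model = PPO1(MlpPolicy, env, timesteps_per_actorbatch=256,
             clip_param=0.2, schedule='linear', verbose=1)
model.learn(total_timesteps=50000)
model.save("ppo1_pendulum")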
Example no. 10
class TRPO(ActorCriticRLModel):
    def __init__(self,
                 policy,
                 env,
                 gamma=0.99,
                 timesteps_per_batch=1024,
                 max_kl=0.01,
                 cg_iters=10,
                 lam=0.98,
                 entcoeff=0.0,
                 cg_damping=1e-2,
                 vf_stepsize=3e-4,
                 vf_iters=3,
                 verbose=0,
                 tensorboard_log=None,
                 _init_setup_model=True,
                 policy_kwargs=None,
                 full_tensorboard_log=False,
                 seed=None,
                 n_cpu_tf_sess=1):
        super(TRPO, self).__init__(policy=policy,
                                   env=env,
                                   verbose=verbose,
                                   requires_vec_env=False,
                                   _init_setup_model=_init_setup_model,
                                   policy_kwargs=policy_kwargs,
                                   seed=seed,
                                   n_cpu_tf_sess=n_cpu_tf_sess)

        self.timesteps_per_batch = timesteps_per_batch
        self.cg_iters = cg_iters
        self.cg_damping = cg_damping
        self.gamma = gamma
        self.lam = lam
        self.max_kl = max_kl
        self.vf_iters = vf_iters
        self.vf_stepsize = vf_stepsize
        self.entcoeff = entcoeff
        self.tensorboard_log = tensorboard_log
        self.full_tensorboard_log = full_tensorboard_log

        # GAIL Params
        self.hidden_size_adversary = 100
        self.adversary_entcoeff = 1e-3
        self.expert_dataset = None
        self.g_step = 1
        self.d_step = 1
        self.d_stepsize = 3e-4

        self.graph = None
        self.sess = None
        self.policy_pi = None
        self.loss_names = None
        self.assign_old_eq_new = None
        self.compute_losses = None
        self.compute_lossandgrad = None
        self.compute_fvp = None
        self.compute_vflossandgrad = None
        self.d_adam = None
        self.vfadam = None
        self.get_flat = None
        self.set_from_flat = None
        self.timed = None
        self.allmean = None
        self.nworkers = None
        self.rank = None
        self.reward_giver = None
        self.step = None
        self.proba_step = None
        self.initial_state = None
        self.params = None
        self.summary = None

        if _init_setup_model:
            self.setup_model()

    def _get_pretrain_placeholders(self):
        policy = self.policy_pi
        action_ph = policy.pdtype.sample_placeholder([None])
        if isinstance(self.action_space, gym.spaces.Discrete):
            return policy.obs_ph, action_ph, policy.policy
        return policy.obs_ph, action_ph, policy.deterministic_action

    def setup_model(self):
        # prevent import loops
        from stable_baselines.gail.adversary import TransitionClassifier

        with SetVerbosity(self.verbose):

            assert issubclass(self.policy, ActorCriticPolicy), \
                "Error: the input policy for the TRPO model must be " \
                "an instance of common.policies.ActorCriticPolicy."

            self.nworkers = MPI.COMM_WORLD.Get_size()
            self.rank = MPI.COMM_WORLD.Get_rank()
            np.set_printoptions(precision=3)

            self.graph = tf.Graph()
            with self.graph.as_default():
                self.set_random_seed(self.seed)
                self.sess = tf_util.make_session(num_cpu=self.n_cpu_tf_sess,
                                                 graph=self.graph)

                # Construct network for new policy
                self.policy_pi = self.policy(self.sess,
                                             self.observation_space,
                                             self.action_space,
                                             self.n_envs,
                                             1,
                                             None,
                                             reuse=False,
                                             **self.policy_kwargs)

                # Network for old policy
                with tf.variable_scope("oldpi", reuse=False):
                    old_policy = self.policy(self.sess,
                                             self.observation_space,
                                             self.action_space,
                                             self.n_envs,
                                             1,
                                             None,
                                             reuse=False,
                                             **self.policy_kwargs)

                with tf.variable_scope("loss", reuse=False):
                    atarg = tf.placeholder(dtype=tf.float32, shape=[
                        None
                    ])  # Target advantage function (if applicable)
                    ret = tf.placeholder(dtype=tf.float32,
                                         shape=[None])  # Empirical return

                    observation = self.policy_pi.obs_ph
                    action = self.policy_pi.pdtype.sample_placeholder([None])

                    kloldnew = old_policy.proba_distribution.kl(
                        self.policy_pi.proba_distribution)
                    ent = self.policy_pi.proba_distribution.entropy()
                    meankl = tf.reduce_mean(kloldnew)
                    meanent = tf.reduce_mean(ent)
                    entbonus = self.entcoeff * meanent

                    vferr = tf.reduce_mean(
                        tf.square(self.policy_pi.value_flat - ret))

                    # advantage * pnew / pold
                    ratio = tf.exp(
                        self.policy_pi.proba_distribution.logp(action) -
                        old_policy.proba_distribution.logp(action))
                    surrgain = tf.reduce_mean(ratio * atarg)

                    optimgain = surrgain + entbonus
                    losses = [optimgain, meankl, entbonus, surrgain, meanent]
                    self.loss_names = [
                        "optimgain", "meankl", "entloss", "surrgain", "entropy"
                    ]

                    dist = meankl

                    all_var_list = tf_util.get_trainable_vars("model")
                    var_list = [
                        v for v in all_var_list
                        if "/vf" not in v.name and "/q/" not in v.name
                    ]
                    vf_var_list = [
                        v for v in all_var_list
                        if "/pi" not in v.name and "/logstd" not in v.name
                    ]

                    self.get_flat = tf_util.GetFlat(var_list, sess=self.sess)
                    self.set_from_flat = tf_util.SetFromFlat(var_list,
                                                             sess=self.sess)

                    klgrads = tf.gradients(dist, var_list)
                    flat_tangent = tf.placeholder(dtype=tf.float32,
                                                  shape=[None],
                                                  name="flat_tan")
                    shapes = [var.get_shape().as_list() for var in var_list]
                    start = 0
                    tangents = []
                    for shape in shapes:
                        var_size = tf_util.intprod(shape)
                        tangents.append(
                            tf.reshape(flat_tangent[start:start + var_size],
                                       shape))
                        start += var_size
                    gvp = tf.add_n([
                        tf.reduce_sum(grad * tangent)
                        for (grad, tangent) in zipsame(klgrads, tangents)
                    ])  # pylint: disable=E1111
                    # Fisher vector products
                    fvp = tf_util.flatgrad(gvp, var_list)

                    tf.summary.scalar('entropy_loss', meanent)
                    tf.summary.scalar('policy_gradient_loss', optimgain)
                    tf.summary.scalar('value_function_loss', vferr)
                    tf.summary.scalar('approximate_kullback-leibler', meankl)
                    tf.summary.scalar(
                        'loss',
                        optimgain + meankl + entbonus + surrgain + meanent)

                    self.assign_old_eq_new = tf_util.function(
                        [], [],
                        updates=[tf.assign(oldv, newv) for (oldv, newv) in
                                 zipsame(tf_util.get_globals_vars("oldpi"),
                                         tf_util.get_globals_vars("model"))])
                    self.compute_losses = tf_util.function(
                        [observation, old_policy.obs_ph, action, atarg],
                        losses)
                    self.compute_fvp = tf_util.function([
                        flat_tangent, observation, old_policy.obs_ph, action,
                        atarg
                    ], fvp)
                    self.compute_vflossandgrad = tf_util.function(
                        [observation, old_policy.obs_ph, ret],
                        tf_util.flatgrad(vferr, vf_var_list))

                    @contextmanager
                    def timed(msg):
                        if self.rank == 0 and self.verbose >= 1:
                            print(colorize(msg, color='magenta'))
                            start_time = time.time()
                            yield
                            print(
                                colorize("done in {:.3f} seconds".format(
                                    (time.time() - start_time)),
                                         color='magenta'))
                        else:
                            yield

                    def allmean(arr):
                        assert isinstance(arr, np.ndarray)
                        out = np.empty_like(arr)
                        MPI.COMM_WORLD.Allreduce(arr, out, op=MPI.SUM)
                        out /= self.nworkers
                        return out

                    tf_util.initialize(sess=self.sess)

                    th_init = self.get_flat()
                    MPI.COMM_WORLD.Bcast(th_init, root=0)
                    self.set_from_flat(th_init)

                with tf.variable_scope("Adam_mpi", reuse=False):
                    self.vfadam = MpiAdam(vf_var_list, sess=self.sess)

                    self.vfadam.sync()

                with tf.variable_scope("input_info", reuse=False):
                    tf.summary.scalar('discounted_rewards',
                                      tf.reduce_mean(ret))
                    tf.summary.scalar('learning_rate',
                                      tf.reduce_mean(self.vf_stepsize))
                    tf.summary.scalar('advantage', tf.reduce_mean(atarg))
                    tf.summary.scalar('kl_clip_range',
                                      tf.reduce_mean(self.max_kl))

                    if self.full_tensorboard_log:
                        tf.summary.histogram('discounted_rewards', ret)
                        tf.summary.histogram('learning_rate', self.vf_stepsize)
                        tf.summary.histogram('advantage', atarg)
                        tf.summary.histogram('kl_clip_range', self.max_kl)
                        if tf_util.is_image(self.observation_space):
                            tf.summary.image('observation', observation)
                        else:
                            tf.summary.histogram('observation', observation)

                self.timed = timed
                self.allmean = allmean

                self.step = self.policy_pi.step
                self.proba_step = self.policy_pi.proba_step
                self.initial_state = self.policy_pi.initial_state

                self.params = tf_util.get_trainable_vars(
                    "model") + tf_util.get_trainable_vars("oldpi")

                self.summary = tf.summary.merge_all()

                self.compute_lossandgrad = tf_util.function(
                    [observation, old_policy.obs_ph, action, atarg, ret],
                    [self.summary, tf_util.flatgrad(optimgain, var_list)] + losses)

    def learn(self,
              total_timesteps,
              callback=None,
              log_interval=100,
              tb_log_name="TRPO",
              reset_num_timesteps=True):

        new_tb_log = self._init_num_timesteps(reset_num_timesteps)
        callback = self._init_callback(callback)

        with SetVerbosity(self.verbose), TensorboardWriter(
                self.graph, self.tensorboard_log, tb_log_name,
                new_tb_log) as writer:
            self._setup_learn()

            with self.sess.as_default():
                callback.on_training_start(locals(), globals())

                seg_gen = traj_segment_generator(self.policy_pi,
                                                 self.env,
                                                 self.timesteps_per_batch,
                                                 callback=callback)

                episodes_so_far = 0
                timesteps_so_far = 0
                iters_so_far = 0
                t_start = time.time()
                # per-env episode reward accumulator (needed by total_episode_reward_logger)
                self.episode_reward = np.zeros((self.n_envs,))
                len_buffer = deque(maxlen=40)  # rolling buffer for episode lengths
                reward_buffer = deque(maxlen=40)  # rolling buffer for episode rewards

                while True:
                    if timesteps_so_far >= total_timesteps:
                        break

                    logger.log("********** Iteration %i ************" %
                               iters_so_far)

                    def fisher_vector_product(vec):
                        return self.allmean(
                            self.compute_fvp(
                                vec, *fvpargs,
                                sess=self.sess)) + self.cg_damping * vec

                    # ------------------ Update G ------------------
                    logger.log("Optimizing Policy...")
                    # g_step = 1 when not using GAIL
                    mean_losses = None
                    vpredbefore = None
                    tdlamret = None
                    observation = None
                    action = None
                    seg = None
                    for k in range(self.g_step):
                        with self.timed("sampling"):
                            seg = seg_gen.__next__()

                        # Stop training early (triggered by the callback)
                        if not seg.get('continue_training', True):  # pytype: disable=attribute-error
                            break

                        add_vtarg_and_adv(seg, self.gamma, self.lam)
                        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
                        observation, action = seg["observations"], seg[
                            "actions"]
                        atarg, tdlamret = seg["adv"], seg["tdlamret"]

                        vpredbefore = seg[
                            "vpred"]  # predicted value function before update
                        atarg = (atarg - atarg.mean()) / (
                            atarg.std() + 1e-8
                        )  # standardized advantage function estimate

                        print('advantages: ', np.min(atarg), np.max(atarg),
                              np.mean(atarg))
                        # true_rew is the reward without discount
                        if writer is not None:
                            total_episode_reward_logger(
                                self.episode_reward,
                                seg["true_rewards"].reshape(
                                    (self.n_envs, -1)), seg["dones"].reshape(
                                        (self.n_envs, -1)), writer,
                                self.num_timesteps)

                        args = seg["observations"], seg["observations"], seg[
                            "actions"], atarg
                        # Subsampling: see p40-42 of John Schulman thesis
                        # http://joschu.net/docs/thesis.pdf
                        fvpargs = [arr[::5] for arr in args]

                        self.assign_old_eq_new(sess=self.sess)

                        with self.timed("computegrad"):
                            steps = self.num_timesteps + (k + 1) * (
                                seg["total_timestep"] / self.g_step)
                            run_options = tf.RunOptions(
                                trace_level=tf.RunOptions.FULL_TRACE)
                            run_metadata = tf.RunMetadata(
                            ) if self.full_tensorboard_log else None

                            _, grad, *lossbefore = self.compute_lossandgrad(
                                *args,
                                tdlamret,
                                sess=self.sess,
                                options=run_options,
                                run_metadata=run_metadata)

                        print("losses before", lossbefore)
                        lossbefore = self.allmean(np.array(lossbefore))
                        grad = self.allmean(grad)
                        if np.allclose(grad, 0):
                            logger.log("Got zero gradient. not updating")
                        else:
                            with self.timed("conjugate_gradient"):
                                stepdir = conjugate_gradient(
                                    fisher_vector_product,
                                    grad,
                                    cg_iters=self.cg_iters,
                                    verbose=self.rank == 0
                                    and self.verbose >= 1)
                            assert np.isfinite(stepdir).all()
                            shs = .5 * stepdir.dot(
                                fisher_vector_product(stepdir))
                            # abs(shs) to avoid taking square root of negative values
                            lagrange_multiplier = np.sqrt(
                                abs(shs) / self.max_kl)
                            # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
                            fullstep = stepdir / lagrange_multiplier
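                            # shs approximates 0.5 * s^T H s for the conjugate-gradient step s, so
                            # dividing s by sqrt(shs / max_kl) gives a step whose quadratic KL
                            # estimate equals max_kl; the backtracking loop below then halves the
                            # step until the measured KL and surrogate improvement are acceptable.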
                            expectedimprove = grad.dot(fullstep)
                            surrbefore = lossbefore[0]
                            stepsize = 1.0
                            thbefore = self.get_flat()
                            for _ in range(10):
                                thnew = thbefore + fullstep * stepsize
                                self.set_from_flat(thnew)
                                mean_losses = surr, kl_loss, *_ = self.allmean(
                                    np.array(
                                        self.compute_losses(*args,
                                                            sess=self.sess)))
                                improve = surr - surrbefore
                                logger.log("Expected: %.3f Actual: %.3f" %
                                           (expectedimprove, improve))
                                if not np.isfinite(mean_losses).all():
                                    logger.log(
                                        "Got non-finite value of losses -- bad!"
                                    )
                                elif kl_loss > self.max_kl * 1.5:
                                    logger.log(
                                        "violated KL constraint. shrinking step."
                                    )
                                elif improve < 0:
                                    logger.log(
                                        "surrogate didn't improve. shrinking step."
                                    )
                                else:
                                    logger.log("Stepsize OK!")
                                    break
                                stepsize *= .5
                            else:
                                logger.log("couldn't compute a good step")
                                self.set_from_flat(thbefore)
                            if self.nworkers > 1 and iters_so_far % 20 == 0:
                                # list of tuples
                                paramsums = MPI.COMM_WORLD.allgather(
                                    (thnew.sum(), self.vfadam.getflat().sum()))
                                assert all(
                                    np.allclose(ps, paramsums[0])
                                    for ps in paramsums[1:])

                            for (loss_name,
                                 loss_val) in zip(self.loss_names,
                                                  mean_losses):
                                logger.record_tabular(loss_name, loss_val)

                        with self.timed("vf"):
                            for _ in range(self.vf_iters):
                                # NOTE: for recurrent policies, use shuffle=False?
                                for (mbob, mbret) in dataset.iterbatches(
                                    (seg["observations"], seg["tdlamret"]),
                                        include_final_partial_batch=False,
                                        batch_size=128,
                                        shuffle=True):
                                    grad = self.allmean(
                                        self.compute_vflossandgrad(
                                            mbob, mbob, mbret, sess=self.sess))
                                    self.vfadam.update(grad, self.vf_stepsize)

                    # Stop training early (triggered by the callback)
                    if not seg.get('continue_training', True):  # pytype: disable=attribute-error
                        break

                    logger.record_tabular(
                        "explained_variance_tdlam_before",
                        explained_variance(vpredbefore, tdlamret))

                    # lr: lengths and rewards
                    lr_local = (seg["ep_lens"], seg["ep_rets"])  # local values
                    list_lr_pairs = MPI.COMM_WORLD.allgather(
                        lr_local)  # list of tuples
                    lens, rews = map(flatten_lists, zip(*list_lr_pairs))

                    len_buffer.extend(lens)
                    reward_buffer.extend(rews)

                    if len(len_buffer) > 0:
                        logger.record_tabular("EpLenMean", np.mean(len_buffer))
                        logger.record_tabular("EpRewMean",
                                              np.mean(reward_buffer))

                    logger.record_tabular("EpThisIter", len(lens))
                    episodes_so_far += len(lens)
                    current_it_timesteps = MPI.COMM_WORLD.allreduce(
                        seg["total_timestep"])
                    timesteps_so_far += current_it_timesteps
                    self.num_timesteps += current_it_timesteps
                    iters_so_far += 1

                    logger.record_tabular("EpisodesSoFar", episodes_so_far)
                    logger.record_tabular("TimestepsSoFar", self.num_timesteps)
                    logger.record_tabular("TimeElapsed", time.time() - t_start)

                    if self.verbose >= 1 and self.rank == 0:
                        logger.dump_tabular()

        callback.on_training_end()
        return self

    def save(self, save_path, cloudpickle=False):
        pass
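
# A minimal usage sketch for the TRPO variant above (not part of the original listing).
# Assumptions: stable-baselines' MlpPolicy and a registered Gym env id such as
# "Pendulum-v0" are available, and MPI is initialised the way this class expects.
if __name__ == "__main__":
    from stable_baselines.common.policies import MlpPolicy

    model = TRPO(MlpPolicy, "Pendulum-v0", timesteps_per_batch=1024, max_kl=0.01, verbose=1)
    model.learn(total_timesteps=100000)
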
Example no. 11
0
class MDPO_ON(ActorCriticRLModel):
    """
    Mirror Descent Policy Optimization (On-policy)

    :param policy: (ActorCriticPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, CnnLstmPolicy, ...)
    :param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str)
    :param gamma: (float) the discount value
    :param timesteps_per_batch: (int) the number of timesteps to run per batch (horizon)
    :param max_kl: (float) the Kullback-Leibler loss threshold
    :param cg_iters: (int) the number of iterations for the conjugate gradient calculation
    :param lam: (float) GAE factor
    :param entcoeff: (float) the weight for the entropy loss
    :param cg_damping: (float) the compute gradient dampening factor
    :param vf_stepsize: (float) the value function stepsize
    :param vf_iters: (int) the value function's number iterations for learning
    :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
    :param tensorboard_log: (str) the log location for tensorboard (if None, no logging)
    :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
    :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
    :param full_tensorboard_log: (bool) enable additional logging when using tensorboard
        WARNING: this logging can take a lot of space quickly
    """

    def __init__(self, policy, env, gamma=0.99, timesteps_per_batch=1024, max_kl=0.01, cg_iters=10, lam=0.98,
                 entcoeff=0.0, cg_damping=1e-2, vf_stepsize=3e-4, vf_iters=3, verbose=0, tensorboard_log=None,
                 _init_setup_model=True, policy_kwargs=None, full_tensorboard_log=False, seed=0, sgd_steps=10,
                 klcoeff=0.1, method="multistep-SGD", tsallis_q=1.0, t_pi=1.0, t_c=0.01):
        super(MDPO_ON, self).__init__(policy=policy, env=env, verbose=verbose, requires_vec_env=False,
                                      _init_setup_model=_init_setup_model, policy_kwargs=policy_kwargs)

        self.using_gail = False
        self.using_mdal = False
        # Assumed defaults: these attributes are read later in setup_model()/learn()
        # (self.neural unconditionally) but are not set anywhere else in this listing.
        self.neural = False
        self.is_action_features = False
        self.exploration_bonus = False
        self.bonus_coef = 0.0

        self.timesteps_per_batch = timesteps_per_batch
        self.cg_iters = cg_iters
        self.cg_damping = cg_damping
        self.gamma = gamma
        self.lam = lam
        self.max_kl = max_kl
        self.vf_iters = vf_iters
        self.vf_stepsize = vf_stepsize
        self.entcoeff = entcoeff
        self.tensorboard_log = tensorboard_log
        self.full_tensorboard_log = full_tensorboard_log

        # GAIL Params
        self.hidden_size_adversary = 100
        self.adversary_entcoeff = 1e-3
        self.expert_dataset = None
        self.g_step = 1
        self.d_step = 1
        self.d_stepsize = 3e-4

        self.graph = None
        self.sess = None
        self.policy_pi = None
        self.loss_names = None
        self.assign_old_eq_new = None
        self.compute_losses = None
        self.compute_lossandgrad = None
        self.compute_fvp = None
        self.compute_vflossandgrad = None
        self.d_adam = None
        self.vfadam = None
        self.get_flat = None
        self.set_from_flat = None
        self.timed = None
        self.allmean = None
        self.nworkers = None
        self.rank = None
        self.reward_giver = None
        self.step = None
        self.proba_step = None
        self.initial_state = None
        self.params = None
        self.summary = None
        self.episode_reward = None
        self.seed = seed
        self.sgd_steps = sgd_steps
        self.klcoeff = klcoeff
        self.cliprange_vf = 0.2
        self.method = method
        self.tsallis_q = tsallis_q
        self.t_pi = t_pi
        self.t_c = t_c


        if _init_setup_model:
            self.setup_model()

    def _get_pretrain_placeholders(self):
        policy = self.policy_pi
        action_ph = policy.pdtype.sample_placeholder([None])
        if isinstance(self.action_space, gym.spaces.Discrete):
            return policy.obs_ph, action_ph, policy.policy
        return policy.obs_ph, action_ph, policy.deterministic_action

    def setup_model(self):
        # prevent import loops
        from stable_baselines.gail.adversary import TransitionClassifier
        from stable_baselines.mdal.adversary import TabularAdversaryTF, NeuralAdversaryTRPO


        with SetVerbosity(self.verbose):

            assert issubclass(self.policy, ActorCriticPolicy), "Error: the input policy for the MDPO model must be " \
                                                               "an instance of common.policies.ActorCriticPolicy."

            self.nworkers = MPI.COMM_WORLD.Get_size()
            self.rank = MPI.COMM_WORLD.Get_rank()
            np.set_printoptions(precision=3)

            self.graph = tf.Graph()
            with self.graph.as_default():
                self.sess = tf_util.single_threaded_session(graph=self.graph)
                # self._setup_learn(self.seed)
                self._setup_learn()

                if self.using_gail:
                    self.reward_giver = TransitionClassifier(self.observation_space, self.action_space,
                                                             self.hidden_size_adversary,
                                                             entcoeff=self.adversary_entcoeff)
                elif self.using_mdal:
                    if self.neural:
                        self.reward_giver = NeuralAdversaryTRPO(self.sess, self.observation_space, self.action_space,
                                                                self.hidden_size_adversary,
                                                                entcoeff=self.adversary_entcoeff)
                    else:
                        self.reward_giver = TabularAdversaryTF(self.sess, self.observation_space, self.action_space,
                                                                 self.hidden_size_adversary,
                                                                 entcoeff=self.adversary_entcoeff,
                                                                 expert_features=self.expert_dataset.successor_features,
                                                                 exploration_bonus=self.exploration_bonus,
                                                                 bonus_coef=self.bonus_coef,
                                                                 t_c=self.t_c,
                                                                 is_action_features=self.is_action_features)
                # Construct network for new policy
                self.policy_pi = self.policy(self.sess, self.observation_space, self.action_space, self.n_envs, 1,
                                             None, reuse=False, **self.policy_kwargs)

                # Network for old policy
                with tf.variable_scope("oldpi", reuse=False):
                    self.old_policy = self.policy(self.sess, self.observation_space, self.action_space,
                                                  self.n_envs, 1, None, reuse=False, **self.policy_kwargs)

                # Network for fitting closed form
                with tf.variable_scope("closedpi", reuse=False):
                    self.closed_policy = self.policy(self.sess, self.observation_space, self.action_space,
                                                     self.n_envs, 1, None, reuse=False, **self.policy_kwargs)

                with tf.variable_scope("loss", reuse=False):
                    self.atarg = tf.placeholder(dtype=tf.float32, shape=[None])  # Target advantage function (if applicable)
                    self.vtarg = tf.placeholder(dtype=tf.float32, shape=[None])
                    self.ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return
                    self.learning_rate_ph = tf.placeholder(dtype=tf.float32, shape=[], name="learning_rate_ph")
                    self.outer_learning_rate_ph = tf.placeholder(dtype=tf.float32, shape=[], name="outer_learning_rate_ph")
                    self.old_vpred_ph = tf.placeholder(dtype=tf.float32, shape=[None], name="old_vpred_ph")
                    self.clip_range_vf_ph = tf.placeholder(dtype=tf.float32, shape=[], name="clip_range_ph")

                    observation = self.policy_pi.obs_ph
                    self.action = self.policy_pi.pdtype.sample_placeholder([None])

                    if self.tsallis_q == 1.0:
                        kloldnew = self.policy_pi.proba_distribution.kl(self.old_policy.proba_distribution)
                        ent = self.policy_pi.proba_distribution.entropy()
                        meankl = tf.reduce_mean(kloldnew)

                    else:
                        logp_pi = self.policy_pi.proba_distribution.logp(self.action)
                        logp_pi_old = self.old_policy.proba_distribution.logp(self.action)
                        ent = self.policy_pi.proba_distribution.entropy()
                        # kloldnew = self.policy_pi.proba_distribution.kl_tsallis(self.old_policy.proba_distribution, self.tsallis_q)
                        tsallis_q = 2.0 - self.tsallis_q
                        meankl = tf.reduce_mean(tf_log_q(tf.exp(logp_pi), tsallis_q) -
                                                tf_log_q(tf.exp(logp_pi_old), tsallis_q))  # tf.reduce_mean(kloldnew)
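                        # tf_log_q is assumed to implement the Tsallis q-logarithm (it reduces to the
                        # natural log as q -> 1), making this a sample-based Tsallis-KL surrogate
                        # between the new and old policies.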

                    meanent = tf.reduce_mean(ent)
                    entbonus = self.entcoeff * meanent

                    if self.cliprange_vf is None:
                        vpred_clipped = self.policy_pi.value_flat
                    else:
                        vpred_clipped = self.old_vpred_ph + \
                            tf.clip_by_value(self.policy_pi.value_flat - self.old_vpred_ph,
                                             - self.clip_range_vf_ph, self.clip_range_vf_ph)

                    vf_losses1 = tf.square(self.policy_pi.value_flat - self.ret)
                    vf_losses2 = tf.square(vpred_clipped - self.ret)
                    vferr = tf.reduce_mean(tf.maximum(vf_losses1, vf_losses2))
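                    # PPO-style clipped value loss: take the elementwise maximum of the clipped and
                    # unclipped squared errors so the value prediction cannot move too far from its
                    # pre-update value (bounded by clip_range_vf_ph).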

                    # advantage * pnew / pold
                    ratio = tf.exp(self.policy_pi.proba_distribution.logp(self.action) -
                                   self.old_policy.proba_distribution.logp(self.action))

                    if self.method == "multistep-SGD":
                        surrgain = tf.reduce_mean(ratio * self.atarg) - meankl / self.learning_rate_ph
                    elif self.method == "closedreverse-KL":
                        surrgain = tf.reduce_mean(tf.exp(self.atarg) * self.policy_pi.proba_distribution.logp(self.action))
                    else:
                        policygain = tf.reduce_mean(tf.exp(self.atarg) * tf.log(self.closed_policy.proba_distribution.mean))
                        surrgain = tf.reduce_mean(ratio * self.atarg) - tf.reduce_mean(self.learning_rate_ph * ratio * self.policy_pi.proba_distribution.logp(self.action))

                    optimgain = surrgain #+ entbonus - self.learning_rate_ph * meankl
                    losses = [optimgain, meankl, entbonus, surrgain, meanent]
                    self.loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

                    dist = meankl

                    all_var_list = tf_util.get_trainable_vars("model")
                    var_list = [v for v in all_var_list if "/vf" not in v.name and "/q/" not in v.name]
                    vf_var_list = [v for v in all_var_list if "/pi" not in v.name and "/logstd" not in v.name]
                    print("policy vars", var_list)

                    all_closed_var_list = tf_util.get_trainable_vars("closedpi")
                    closed_var_list = [v for v in all_closed_var_list if "/vf" not in v.name and "/q" not in v.name]

                    self.get_flat = tf_util.GetFlat(var_list, sess=self.sess)
                    self.set_from_flat = tf_util.SetFromFlat(var_list, sess=self.sess)

                    klgrads = tf.gradients(dist, var_list)
                    flat_tangent = tf.placeholder(dtype=tf.float32, shape=[None], name="flat_tan")
                    shapes = [var.get_shape().as_list() for var in var_list]
                    start = 0
                    tangents = []
                    for shape in shapes:
                        var_size = tf_util.intprod(shape)
                        tangents.append(tf.reshape(flat_tangent[start: start + var_size], shape))
                        start += var_size
                    gvp = tf.add_n([tf.reduce_sum(grad * tangent)
                                    for (grad, tangent) in zipsame(klgrads, tangents)])  # pylint: disable=E1111
                    fvp = tf_util.flatgrad(gvp, var_list)

                    # tf.summary.scalar('entropy_loss', meanent)
                    # tf.summary.scalar('policy_gradient_loss', optimgain)
                    # tf.summary.scalar('value_function_loss', surrgain)
                    # tf.summary.scalar('approximate_kullback-leibler', meankl)
                    # tf.summary.scalar('loss', optimgain + meankl + entbonus + surrgain + meanent)

                    self.assign_old_eq_new = \
                        tf_util.function([], [], updates=[tf.assign(oldv, newv) for (oldv, newv) in
                                                          zipsame(tf_util.get_globals_vars("oldpi"),
                                                                  tf_util.get_globals_vars("model"))])
                    self.compute_losses = tf_util.function([observation, self.old_policy.obs_ph, self.action, self.atarg, self.learning_rate_ph, self.vtarg], losses)
                    self.compute_fvp = tf_util.function([flat_tangent, observation, self.old_policy.obs_ph, self.action, self.atarg],
                                                        fvp)
                    self.compute_vflossandgrad = tf_util.function([observation, self.old_policy.obs_ph, self.ret, self.old_vpred_ph, self.clip_range_vf_ph],
                                                                  tf_util.flatgrad(vferr, vf_var_list))

                    grads = tf.gradients(-optimgain, var_list)
                    grads, _grad_norm = tf.clip_by_global_norm(grads, 0.5)
                    trainer = tf.train.AdamOptimizer(learning_rate=self.outer_learning_rate_ph, epsilon=1e-5)
                    # trainer = tf.train.AdamOptimizer(learning_rate=3e-4, epsilon=1e-5)
                    grads = list(zip(grads, var_list))
                    self._train = trainer.apply_gradients(grads)
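                    # _train takes one Adam step on the negated surrogate; learn() runs it
                    # sgd_steps times per sampled batch, i.e. the multi-step inner loop of the
                    # mirror-descent update (in contrast to TRPO's single constrained step).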

                    @contextmanager
                    def timed(msg):
                        if self.rank == 0 and self.verbose >= 1:
                            # print(colorize(msg, color='magenta'))
                            # start_time = time.time()
                            yield
                            # print(colorize("done in {:.3f} seconds".format((time.time() - start_time)),
                            #                color='magenta'))
                        else:
                            yield

                    def allmean(arr):
                        assert isinstance(arr, np.ndarray)
                        out = np.empty_like(arr)
                        MPI.COMM_WORLD.Allreduce(arr, out, op=MPI.SUM)
                        out /= self.nworkers
                        return out

                    tf_util.initialize(sess=self.sess)

                    th_init = self.get_flat()
                    MPI.COMM_WORLD.Bcast(th_init, root=0)
                    self.set_from_flat(th_init)

                with tf.variable_scope("Adam_mpi", reuse=False):
                    self.vfadam = MpiAdam(vf_var_list, sess=self.sess)
                    if self.using_gail or self.using_mdal:
                        self.d_adam = MpiAdam(self.reward_giver.get_trainable_variables(), sess=self.sess)
                        self.d_adam.sync()
                    self.vfadam.sync()

                with tf.variable_scope("input_info", reuse=False):
                    tf.summary.scalar('discounted_rewards', tf.reduce_mean(self.ret))
                    tf.summary.scalar('learning_rate', tf.reduce_mean(self.vf_stepsize))
                    tf.summary.scalar('advantage', tf.reduce_mean(self.atarg))
                    tf.summary.scalar('kl_clip_range', tf.reduce_mean(self.max_kl))

                    if self.full_tensorboard_log:
                        tf.summary.histogram('discounted_rewards', self.ret)
                        tf.summary.histogram('learning_rate', self.vf_stepsize)
                        tf.summary.histogram('advantage', self.atarg)
                        tf.summary.histogram('kl_clip_range', self.max_kl)
                        if tf_util.is_image(self.observation_space):
                            tf.summary.image('observation', observation)
                        else:
                            tf.summary.histogram('observation', observation)

                self.timed = timed
                self.allmean = allmean

                self.step = self.policy_pi.step
                self.proba_step = self.policy_pi.proba_step
                self.initial_state = self.policy_pi.initial_state

                self.params = tf_util.get_trainable_vars("model") + tf_util.get_trainable_vars("oldpi")
                if self.using_gail:
                    self.params.extend(self.reward_giver.get_trainable_variables())

                self.summary = tf.summary.merge_all()

                self.compute_lossandgrad = \
                    tf_util.function([observation, self.old_policy.obs_ph, self.action, self.atarg, self.ret, self.learning_rate_ph, self.vtarg, self.closed_policy.obs_ph],
                                     [self.summary, tf_util.flatgrad(optimgain, var_list)] + losses)

    def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="MDPO",
              reset_num_timesteps=True):

        new_tb_log = self._init_num_timesteps(reset_num_timesteps)
        callback = self._init_callback(callback)
        print("got seed {}, sgd_steps {}".format(seed, self.sgd_steps))

        with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name, new_tb_log) \
                as writer:

            with self.sess.as_default():
                callback.on_training_start(locals(), globals())

                seg_gen = traj_segment_generator(self.old_policy, self.env, self.timesteps_per_batch,
                                                     reward_giver=self.reward_giver,
                                                     gail=self.using_gail, mdal=self.using_mdal, neural=self.neural,
                                                     action_space=self.action_space, gamma=self.gamma, callback=callback)


                episodes_so_far = 0
                timesteps_so_far = 0
                iters_so_far = 0
                t_start = time.time()
                len_buffer = deque(maxlen=40)  # rolling buffer for episode lengths
                reward_buffer = deque(maxlen=40)  # rolling buffer for episode rewards

                self.episode_reward = np.zeros((self.n_envs,))
                self.outer_learning_rate = get_schedule_fn(3e-4)
                self.cliprange_vf = get_schedule_fn(0.2)

                true_reward_buffer = None
                if self.using_gail or self.using_mdal:
                    true_reward_buffer = deque(maxlen=40)

                    # Initialize dataloader
                    batchsize = self.timesteps_per_batch // self.d_step
                    self.expert_dataset.init_dataloader(batchsize)

                    #  Stats not used for now
                    # TODO: replace with normal tb logging
                    #  g_loss_stats = Stats(loss_names)
                    #  d_loss_stats = Stats(reward_giver.loss_name)
                    #  ep_stats = Stats(["True_rewards", "Rewards", "Episode_length"])

                while True:
                    # if callback is not None:
                    #     # Only stop training if return value is False, not when it is None. This is for backwards
                    #     # compatibility with callbacks that have no return statement.
                    #     if callback(locals(), globals()) is False:
                    #         break
                    if total_timesteps and timesteps_so_far >= total_timesteps:
                        break

                    logger.log("********** Iteration %i ************" % iters_so_far)

                    #def fisher_vector_product(vec):
                    #    return self.allmean(self.compute_fvp(vec, *fvpargs, sess=self.sess)) + self.cg_damping * vec

                    # ------------------ Update G ------------------
                    # logger.log("Optimizing Policy...")
                    # g_step = 1 when not using GAIL
                    mean_losses = None
                    vpredbefore = None
                    tdlamret = None
                    observation = None
                    action = None
                    seg = None
                    for k in range(self.g_step):
                        with self.timed("sampling"):
                            seg = seg_gen.__next__()
                        if not seg.get('continue_training', True):  # pytype: disable=attribute-error
                            break

                        add_vtarg_and_adv(seg, self.gamma, self.lam)
                        if self.using_mdal:
                            policy_successor_features = add_successor_features(seg, self.gamma,
                                                                           is_action_features=self.is_action_features)
                        else:
                            policy_successor_features = add_successor_features(seg, self.gamma)
                        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
                        observation, action = seg["observations"], seg["actions"]
                        atarg, tdlamret = seg["adv"], seg["tdlamret"]


                        vpredbefore = seg["vpred"]  # predicted value function before update
                        atarg = (atarg - atarg.mean()) / (atarg.std() + 1e-8)  # standardized advantage function estimate

                        # true_rew is the reward without discount
                        if writer is not None:
                            self.episode_reward = total_episode_reward_logger(self.episode_reward,
                                                                              seg["true_rewards"].reshape(
                                                                                  (self.n_envs, -1)),
                                                                              seg["dones"].reshape((self.n_envs, -1)),
                                                                              writer, self.num_timesteps)

                        n_updates = int(total_timesteps / self.timesteps_per_batch)
                        lr_now = np.float32(1.0 - (iters_so_far - 1.0) / n_updates)
                        outer_lr_now = self.outer_learning_rate(1.0 - (iters_so_far - 1.0) / n_updates)
                        clip_now = self.cliprange_vf(1.0 - (iters_so_far - 1.0) / n_updates)
                        args = seg["observations"], seg["observations"], seg["actions"], atarg
                        # Subsampling: see p40-42 of John Schulman thesis
                        # http://joschu.net/docs/thesis.pdf
                        #fvpargs = [arr[::5] for arr in args]

                        with self.timed("computegrad"):
                            steps = self.num_timesteps + (k + 1) * (seg["total_timestep"] / self.g_step)
                            run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                            run_metadata = tf.RunMetadata() if self.full_tensorboard_log else None
                            # run loss backprop with summary, and save the metadata (memory, compute time, ...)
                            if writer is not None:
                                summary, grad, *lossbefore = self.compute_lossandgrad(*args, tdlamret,
                                                                                      lr_now, seg["vpred"],
                                                                                      seg["observations"],
                                                                                      sess=self.sess,
                                                                                      options=run_options,
                                                                                      run_metadata=run_metadata)
                                if self.full_tensorboard_log:
                                    writer.add_run_metadata(run_metadata, 'step%d' % steps)
                                writer.add_summary(summary, steps)
                            else:
                                _, grad, *lossbefore = self.compute_lossandgrad(*args, tdlamret,
                                                                                lr_now, seg["vpred"],
                                                                                seg["observations"],
                                                                                sess=self.sess,
                                                                                options=run_options,
                                                                                run_metadata=run_metadata)
                                td_map = {self.policy_pi.obs_ph: seg["observations"],
                                            self.old_policy.obs_ph: seg["observations"],
                                            self.closed_policy.obs_ph: seg["observations"],
                                            self.action: seg["actions"], self.atarg: atarg, self.ret: tdlamret,
                                            self.learning_rate_ph: lr_now, self.outer_learning_rate_ph: outer_lr_now,
                                            self.vtarg: seg["vpred"]}
                                for _ in range(int(self.sgd_steps)):
                                    _ = self.sess.run(self._train, td_map)
                                    #if self.method == "closed-KL":
                                    #    _ = self.sess.run(self._train_policy, td_map)

                        if np.allclose(grad, 0):
                            logger.log("Got zero gradient. not updating")
                        else:
                            for _ in range(1):
                                mean_losses = surr, kl_loss, *_ = self.allmean(
                                    np.array(self.compute_losses(*args, lr_now, seg["vpred"], sess=self.sess)))

                        with self.timed("vf"):
                            for _ in range(self.vf_iters):
                                # NOTE: for recurrent policies, use shuffle=False?
                                for (mbob, mbret, mbval) in dataset.iterbatches((seg["observations"], seg["tdlamret"], seg["vpred"]),
                                                                         include_final_partial_batch=False,
                                                                         batch_size=128,
                                                                         shuffle=True):
                                    grad = self.allmean(self.compute_vflossandgrad(mbob, mbob, mbret, mbval, clip_now, sess=self.sess))
                                    self.vfadam.update(grad, outer_lr_now) #self.vf_stepsize)

                        if iters_so_far % 1 == 0:
                            # print("updating theta now")
                            self.assign_old_eq_new(sess=self.sess)

                    for (loss_name, loss_val) in zip(self.loss_names, mean_losses):
                        logger.record_tabular(loss_name, loss_val)

                    logger.record_tabular("explained_variance_tdlam_before",
                                          explained_variance(vpredbefore, tdlamret))

                    if self.using_gail:
                        # ------------------ Update D ------------------
                        logger.log("Optimizing Discriminator...")
                        logger.log(fmt_row(13, self.reward_giver.loss_name))
                        assert len(observation) == self.timesteps_per_batch
                        batch_size = self.timesteps_per_batch // self.d_step

                        # NOTE: uses only the last g step for observation
                        d_losses = []  # list of tuples, each of which gives the loss for a minibatch
                        # NOTE: for recurrent policies, use shuffle=False?
                        for ob_batch, ac_batch in dataset.iterbatches((observation, action),
                                                                      include_final_partial_batch=False,
                                                                      batch_size=batch_size,
                                                                      shuffle=True):
                            ob_expert, ac_expert = self.expert_dataset.get_next_batch()
                            # update running mean/std for reward_giver
                            if self.reward_giver.normalize:
                                self.reward_giver.obs_rms.update(np.concatenate((ob_batch, ob_expert), 0))

                            # Reshape actions if needed when using discrete actions
                            if isinstance(self.action_space, gym.spaces.Discrete):
                                if len(ac_batch.shape) == 2:
                                    ac_batch = ac_batch[:, 0]
                                if len(ac_expert.shape) == 2:
                                    ac_expert = ac_expert[:, 0]
                            *newlosses, grad = self.reward_giver.lossandgrad(ob_batch, ac_batch, ob_expert, ac_expert)
                            self.d_adam.update(self.allmean(grad), self.d_stepsize)
                            d_losses.append(newlosses)
                        logger.log(fmt_row(13, np.mean(d_losses, axis=0)))

                    elif self.using_mdal:
                        batch_sampling = True

                        if self.neural:

                            if batch_sampling:
                                batch_size = self.timesteps_per_batch // self.d_step

                                # NOTE: uses only the last g step for observation
                                d_losses = []  # list of tuples, each of which gives the loss for a minibatch
                                # NOTE: for recurrent policies, use shuffle=False?
                                for ob_batch, ac_batch in dataset.iterbatches((observation, action),
                                                                              include_final_partial_batch=False,
                                                                              batch_size=batch_size,
                                                                              shuffle=True):
                                    # ob_batch, ac_batch, gamma_batch = np.array(batch_buffer['obs']), np.array(
                                    #     batch_buffer['acs']), np.array(batch_buffer['gammas'])
                                    gamma_batch = np.ones((ob_batch.shape[0]))
                                    ob_expert, ac_expert = self.expert_dataset.get_next_batch()
                                    gamma_expert = np.ones((ob_expert.shape[0]))
                                    # ob_expert, ac_expert, gamma_expert = np.concatenate(self.expert_dataset.ep_obs),\
                                    #                                      np.concatenate(self.expert_dataset.ep_acs),\
                                    #                                      np.concatenate(self.expert_dataset.ep_gammas)

                                    # update running mean/std for reward_giver
                                    if self.reward_giver.normalize:
                                        self.reward_giver.obs_rms.update(np.concatenate((ob_batch, ob_expert), 0))

                                    # Reshape actions if needed when using discrete actions
                                    if isinstance(self.action_space, gym.spaces.Discrete):
                                        if len(ac_batch.shape) == 2:
                                            ac_batch = ac_batch[:, 0]
                                        if len(ac_expert.shape) == 2:
                                            ac_expert = ac_expert[:, 0]

                                    ob_reg_expert, ac_reg_expert = np.array(ob_expert), np.array(ac_expert)

                                    # while True:
                                    #     if ob_reg_expert.shape[0] == ob_batch.shape[0] and ac_reg_expert.shape[0] == \
                                    #             ac_batch.shape[0]:
                                    #         break
                                    #     ob_reg_expert, ac_reg_expert = self.expert_dataset.get_next_batch()
                                    #     ob_reg_expert, ac_reg_expert = np.array(ob_reg_expert), np.array(ac_reg_expert)


                                    alpha = np.random.uniform(0.0, 1.0, size=(ob_reg_expert.shape[0], 1))
                                    ob_mix_batch = alpha * ob_batch[:ob_reg_expert.shape[0]] + (1 - alpha) * ob_reg_expert
                                    ac_mix_batch = alpha * ac_batch[:ac_reg_expert.shape[0]] + (1 - alpha) * ac_reg_expert
                                    with self.sess.as_default():
                                        # self.reward_giver.train(ob_batch, ac_batch, np.expand_dims(gamma_batch, axis=1),
                                        #                         ob_expert, ac_expert, np.expand_dims(gamma_expert, axis=1))
                                        *newlosses, grad = self.reward_giver.lossandgrad(
                                                                ob_batch, ac_batch, np.expand_dims(gamma_batch, axis=1),
                                                                ob_expert, ac_expert, np.expand_dims(gamma_expert, axis=1),
                                                                ob_mix_batch, ac_mix_batch)
                                        self.d_adam.update(self.allmean(grad), self.d_stepsize)
                            else:
                                # assert len(observation) == self.timesteps_per_batch
                                # Comment out if you want only the latest rewards:
                                obs_batch, acs_batch, gammas_batch = seg['obs_batch'], seg['acs_batch'], seg['gammas_batch']
                                batch_successor_features = seg['successor_features_batch']


                                if self.reward_giver.normalize:
                                    ob_reg_batch, ac_reg_batch = observation, action
                                    ob_expert, _ = self.expert_dataset.get_next_batch()
                                    self.reward_giver.obs_rms.update(np.concatenate((ob_reg_batch, ob_expert), 0))
                                #     self.reward_giver.obs_rms.update(
                                #         np.array(batch_successor_features)[:, :self.observation_space.shape[0]])

                                for idx, (ob_batch, ac_batch, gamma_batch) in enumerate(
                                        zip(obs_batch, acs_batch, gammas_batch)):
                                    rand_traj = np.random.randint(self.expert_dataset.num_traj)
                                    ob_expert, ac_expert, gamma_expert = self.expert_dataset.ep_obs[rand_traj], \
                                                                         self.expert_dataset.ep_acs[rand_traj], \
                                                                         self.expert_dataset.ep_gammas[rand_traj]

                                    ob_batch, ac_batch, gamma_batch = np.array(ob_batch), np.array(ac_batch), np.array(
                                        gamma_batch)

                                    while True:
                                        ob_reg_expert, ac_reg_expert = self.expert_dataset.get_next_batch()
                                        ob_reg_expert, ac_reg_expert = np.array(ob_reg_expert), np.array(ac_reg_expert)

                                        if ob_reg_expert.shape[0] == ob_reg_batch.shape[0] and ac_reg_expert.shape[0] == \
                                                ac_reg_batch.shape[0]:
                                            break
                                    alpha = np.random.uniform(0.0, 1.0, size=(ob_reg_batch.shape[0], 1))
                                    ob_mix_batch = alpha * ob_reg_batch + (1 - alpha) * ob_reg_expert
                                    ac_mix_batch = alpha * ac_reg_batch + (1 - alpha) * ac_reg_expert

                                    with self.sess.as_default():
                                        *newlosses, grad = self.reward_giver.lossandgrad(
                                                                ob_batch, ac_batch, np.expand_dims(gamma_batch, axis=1),
                                                                ob_expert, ac_expert, np.expand_dims(gamma_expert, axis=1),
                                                                ob_mix_batch, ac_mix_batch)
                                        self.d_adam.update(self.allmean(grad), self.d_stepsize)
                                        # self.reward_giver.train(ob_batch, ac_batch, np.expand_dims(gamma_batch, axis=1),
                                        #                         ob_expert, ac_expert,
                                        #                         np.expand_dims(gamma_expert, axis=1),
                                        #                         ob_mix_batch, ac_mix_batch)

                    if self.using_gail or self.using_mdal:
                        # lr: lengths and rewards
                        lr_local = (seg["ep_lens"], seg["ep_rets"], seg["ep_true_rets"])  # local values
                        list_lr_pairs = MPI.COMM_WORLD.allgather(lr_local)  # list of tuples
                        lens, rews, true_rets = map(flatten_lists, zip(*list_lr_pairs))
                        true_reward_buffer.extend(true_rets)
                    else:
                        lr_local = (seg["ep_lens"], seg["ep_rets"])  # local values
                        list_lr_pairs = MPI.COMM_WORLD.allgather(lr_local)  # list of tuples
                        lens, rews = map(flatten_lists, zip(*list_lr_pairs))
                    len_buffer.extend(lens)
                    reward_buffer.extend(rews)

                    if len(len_buffer) > 0:
                        if self.using_gail or self.using_mdal:
                            logger.record_tabular("EpTrueRewMean", np.mean(true_reward_buffer))

                        logger.record_tabular("EpRewMean", np.mean(reward_buffer))
                        logger.record_tabular("EpLenMean", np.mean(len_buffer))


                    logger.record_tabular("EpThisIter", len(lens))
                    episodes_so_far += len(lens)
                    current_it_timesteps = MPI.COMM_WORLD.allreduce(seg["total_timestep"])
                    timesteps_so_far += current_it_timesteps
                    self.num_timesteps += current_it_timesteps
                    iters_so_far += 1

                    logger.record_tabular("EpisodesSoFar", episodes_so_far)
                    logger.record_tabular("TimestepsSoFar", self.num_timesteps)
                    logger.record_tabular("TimeElapsed", time.time() - t_start)
                    logger.record_tabular("Tsallis-q", self.tsallis_q)
                    logger.record_tabular("steps", self.num_timesteps)
                    logger.record_tabular("seed", self.seed)

                    if self.verbose >= 1 and self.rank == 0:
                        logger.dump_tabular()
        callback.on_training_end()

        return self

    def save(self, save_path):
        if (self.using_gail or self.using_mdal) and self.expert_dataset is not None:
            # Exit processes to pickle the dataset
            self.expert_dataset.prepare_pickling()
        data = {
            "gamma": self.gamma,
            "timesteps_per_batch": self.timesteps_per_batch,
            "max_kl": self.max_kl,
            "cg_iters": self.cg_iters,
            "lam": self.lam,
            "entcoeff": self.entcoeff,
            "cg_damping": self.cg_damping,
            "vf_stepsize": self.vf_stepsize,
            "vf_iters": self.vf_iters,
            "hidden_size_adversary": self.hidden_size_adversary,
            "adversary_entcoeff": self.adversary_entcoeff,
            "expert_dataset": self.expert_dataset,
            "g_step": self.g_step,
            "d_step": self.d_step,
            "d_stepsize": self.d_stepsize,
            "using_gail": self.using_gail,
            "verbose": self.verbose,
            "policy": self.policy,
            "observation_space": self.observation_space,
            "action_space": self.action_space,
            "n_envs": self.n_envs,
            "_vectorize_action": self._vectorize_action,
            "policy_kwargs": self.policy_kwargs,
        }

        params_to_save = self.get_parameters()

        self._save_to_file(save_path, data=data, params=params_to_save)
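
For orientation, a minimal usage sketch of this class follows. It assumes the model is exposed with the stable-baselines-style API shown above (string policy name, Gym environment, learn/save/load from ActorCriticRLModel); the GAIL/MDAL-specific options are left at their defaults, and the environment id and hyperparameters are only illustrative.

# Minimal usage sketch; environment id and hyperparameters are illustrative,
# and this assumes the TRPO class above is importable from your own module.
import gym

env = gym.make("Pendulum-v0")
model = TRPO("MlpPolicy", env, timesteps_per_batch=1024, max_kl=0.01, verbose=1)
model.learn(total_timesteps=100000)
model.save("trpo_pendulum")        # invokes the save() method shown above

# later: restore the saved hyperparameters and network weights
loaded = TRPO.load("trpo_pendulum")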
Example n. 12
0
# Imports assumed by this snippet (not shown in the original source):
#   numpy as np, tensorflow (TF1-style API) as tf, the Keras backend as K,
#   the Input/Reshape/LSTM layers, plus MPI-aware helpers (MpiAdam, tf_util,
#   zipsame) and the project-specific ppo_continuous / ppo_continuous_loss.
def general_actor_critic(input_shape_vec,
                         act_output_shape,
                         comm,
                         learn_rate=(0.001, 0.001),
                         trainable=True,
                         label=""):
    """Build paired "new"/"old" PPO actor-critic networks with an LSTM trunk
    and return closures for syncing their weights, sampling actions and
    values, and running an MpiAdam training step."""

    sess = K.get_session()
    np.random.seed(0)
    tf.set_random_seed(0)

    # network 1 (new policy)
    with tf.variable_scope(label + "_pi_new", reuse=False):
        inp = Input(shape=input_shape_vec)  # [5,6,3]
        # rc_lyr = Lambda(lambda x:  ned_to_ripCoords_tf(x, 4000))(inp)
        trunk_x = Reshape([input_shape_vec[0], input_shape_vec[1] * 3])(inp)
        trunk_x = LSTM(128)(trunk_x)
        # note: the action dimension is hard-coded to 3 here; the function's
        # act_output_shape argument is not used by this snippet
        dist, sample_action_op, action_ph, value_output = ppo_continuous(
            3, trunk_x)

    # network 2 (old policy)
    with tf.variable_scope(label + "_pi_old", reuse=False):
        inp_old = Input(shape=input_shape_vec)  # [5,6,3]
        # rc_lyr = Lambda(lambda x:  ned_to_ripCoords_tf(x, 4000))(inp_old)
        trunk_x = Reshape([input_shape_vec[0],
                           input_shape_vec[1] * 3])(inp_old)
        trunk_x = LSTM(128)(trunk_x)
        dist_old, sample_action_op_old, action_ph_old, value_output_old = ppo_continuous(
            3, trunk_x)

    # additional placeholders
    adv_ph = tf.placeholder(tf.float32, [None], name="advantages_ph")
    alpha_ph = tf.placeholder(tf.float32, shape=(), name="alpha_ph")
    vtarg = tf.placeholder(tf.float32, [None])  # target value placeholder

    # loss
    loss = ppo_continuous_loss(dist, dist_old, value_output, action_ph,
                               alpha_ph, adv_ph, vtarg)

    # gradient
    with tf.variable_scope("grad", reuse=False):
        gradient = tf_util.flatgrad(
            loss, tf_util.get_trainable_vars(label + "_pi_new"))
        adam = MpiAdam(tf_util.get_trainable_vars(label + "_pi_new"),
                       epsilon=0.00001,
                       sess=sess,
                       comm=comm)

    # method for sync'ing the two policies
    assign_old_eq_new = tf_util.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(tf_util.get_globals_vars(label + "_pi_old"),
                                  tf_util.get_globals_vars(label + "_pi_new"))
        ])

    # initialize all the things
    init_op = tf.global_variables_initializer()
    sess.run(init_op)

    # methods for interacting with this model

    def sync_weights():
        assign_old_eq_new(sess=sess)

    def sample_action(states, logstd_override=None):
        # note: logstd_override is accepted but not used in this snippet
        a = sess.run(sample_action_op, feed_dict={inp: states})
        return a

    def sample_value(states):
        v = sess.run(value_output, feed_dict={inp: states})
        return v

    def train(states, actions, vtarget, advs, alpha):
        # effective step size is learn_rate[0] * alpha, with alpha clipped at 0
        alpha = max(alpha, 0.0)
        adam_lr = learn_rate[0]

        g = sess.run(
            [gradient],
            feed_dict={
                inp: states,
                inp_old: states,
                action_ph: actions,
                adv_ph: advs,
                alpha_ph: alpha,
                vtarg: vtarget
            })

        adam.update(g[0], adam_lr * alpha)

    # initial sync
    adam.sync()
    sync_weights()

    return sync_weights, sample_action, sample_value, train
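
A hedged usage sketch of the returned closures follows; the batch size, dummy data, and MPI communicator choice are illustrative assumptions, not part of the original source.

# Illustrative only: shapes and data are made up to show the calling pattern.
import numpy as np
from mpi4py import MPI

sync_weights, sample_action, sample_value, train = general_actor_critic(
    input_shape_vec=[5, 6, 3],   # matches the [5, 6, 3] comment above
    act_output_shape=3,
    comm=MPI.COMM_WORLD,
    label="agent0")

states = np.zeros((8, 5, 6, 3), dtype=np.float32)  # dummy batch of 8 states
actions = sample_action(states)                    # sample from the new policy
values = sample_value(states)                      # value-head estimates

# one training step with dummy targets/advantages and full step size
train(states, actions, vtarget=np.zeros(8), advs=np.zeros(8), alpha=1.0)
sync_weights()                                     # copy pi_new -> pi_old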