Example #1
    def print(self):
        """Print a helpful report about the experiment grid."""
        print('='*DIV_LINE_WIDTH)

        # Prepare announcement at top of printing. If the ExperimentGrid has a
        # short name, write this as one line. If the name is long, break the
        # announcement over two lines.
        base_msg = 'ExperimentGrid %s runs over parameters:\n'
        name_insert = '['+self._name+']'
        if len(base_msg%name_insert) <= 80:
            msg = base_msg%name_insert
        else:
            msg = base_msg%(name_insert+'\n')
        print(colorize(msg, color='green', bold=True))

        # List off parameters, shorthands, and possible values.
        for k, v, sh in zip(self.keys, self.vals, self.shs):
            color_k = colorize(k.ljust(40), color='cyan', bold=True)
            print('', color_k, '['+sh+']' if sh is not None else '', '\n')
            for i, val in enumerate(v):
                print('\t' + str(convert_json(val)))
            print()

        # Count up the number of variants. The number counting seeds
        # is the total number of experiments that will run; the number not
        # counting seeds is the total number of otherwise-unique configs
        # being investigated.
        nvars_total = int(np.prod([len(v) for v in self.vals]))
        if 'seed' in self.keys:
            num_seeds = len(self.vals[self.keys.index('seed')])
            nvars_seedless = int(nvars_total / num_seeds)
        else:
            nvars_seedless = nvars_total
        print(' Variants, counting seeds: '.ljust(40), nvars_total)
        print(' Variants, not counting seeds: '.ljust(40), nvars_seedless)
        print()
        print('='*DIV_LINE_WIDTH)
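For context, a minimal usage sketch follows (it assumes the usual Spinning Up ``ExperimentGrid`` API, in particular ``add(key, vals, shorthand)``; the grid contents are made up for illustration):

# Hypothetical usage: build a small grid and print its report.
eg = ExperimentGrid(name='ppo-bench')
eg.add('seed', [0, 10, 20])
eg.add('ac_kwargs:hidden_sizes', [(32,), (64, 64)], 'hid')
eg.print()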
Example #2
def call_experiment(exp_name, thunk, seed=0, num_cpu=1, data_dir=None,
                    datestamp=False, **kwargs):
    """
    Run a function (thunk) with hyperparameters (kwargs), plus configuration.

    This wraps a few pieces of functionality which are useful when you want
    to run many experiments in sequence, including logger configuration and
    splitting into multiple processes for MPI.

    There's also a SpinningUp-specific convenience added into executing the
    thunk: if ``env_name`` is one of the kwargs passed to call_experiment, it's
    assumed that the thunk accepts an argument called ``env_fn``, and that
    the ``env_fn`` should make a gym environment with the given ``env_name``.

    The way the experiment is actually executed is slightly complicated: the
    function is serialized to a string, and then ``run_entrypoint.py`` is
    executed in a subprocess call with the serialized string as an argument.
    ``run_entrypoint.py`` unserializes the function call and executes it.
    We choose to do it this way---instead of just calling the function
    directly here---to avoid leaking state between successive experiments.

    Args:

        exp_name (string): Name for experiment.

        thunk (callable): A python function.

        seed (int): Seed for random number generators.

        num_cpu (int): Number of MPI processes to split into. Also accepts
            'auto', which will set up as many procs as there are cpus on
            the machine.

        data_dir (string): Used in configuring the logger, to decide where
            to store experiment results. Note: if left as None, data_dir will
            default to ``DEFAULT_DATA_DIR`` from ``spinup/user_config.py``.

        **kwargs: All kwargs to pass to thunk.

    """

    # Determine number of CPU cores to run on
    num_cpu = psutil.cpu_count(logical=False) if num_cpu=='auto' else num_cpu

    # Send random seed to thunk
    kwargs['seed'] = seed

    # Be friendly and print out your kwargs, so we all know what's up
    print(colorize('Running experiment:\n', color='cyan', bold=True))
    print(exp_name + '\n')
    print(colorize('with kwargs:\n', color='cyan', bold=True))
    kwargs_json = convert_json(kwargs)
    print(json.dumps(kwargs_json, separators=(',',':\t'), indent=4, sort_keys=True))
    print('\n')

    # Set up logger output directory
    if 'logger_kwargs' not in kwargs:
        kwargs['logger_kwargs'] = setup_logger_kwargs(exp_name, seed, data_dir, datestamp)
    else:
        print('Note: Call experiment is not handling logger_kwargs.\n')

    def thunk_plus():
        # Make 'env_fn' from 'env_name'
        if 'env_name' in kwargs:
            import gym
            env_name = kwargs['env_name']
            kwargs['env_fn'] = lambda : gym.make(env_name)
            del kwargs['env_name']

        # Fork into multiple processes
        mpi_fork(num_cpu)

        # Run thunk
        thunk(**kwargs)

    # Prepare to launch a script to run the experiment
    pickled_thunk = cloudpickle.dumps(thunk_plus)
    encoded_thunk = base64.b64encode(zlib.compress(pickled_thunk)).decode('utf-8')

    entrypoint = osp.join(osp.abspath(osp.dirname(__file__)),'run_entrypoint.py')
    cmd = [sys.executable if sys.executable else 'python', entrypoint, encoded_thunk]
    try:
        subprocess.check_call(cmd, env=os.environ)
    except CalledProcessError:
        err_msg = '\n'*3 + '='*DIV_LINE_WIDTH + '\n' + dedent("""

            There appears to have been an error in your experiment.

            Check the traceback above to see what actually went wrong. The
            traceback below, included for completeness (but probably not useful
            for diagnosing the error), shows the stack leading up to the
            experiment launch.

            """) + '='*DIV_LINE_WIDTH + '\n'*3
        print(err_msg)
        raise

    # Tell the user about where results are, and how to check them
    logger_kwargs = kwargs['logger_kwargs']

    plot_cmd = 'python3 -m spinup.run plot '+logger_kwargs['output_dir']
    plot_cmd = colorize(plot_cmd, 'green')

    test_cmd = 'python3 -m spinup.run test_policy '+logger_kwargs['output_dir']
    test_cmd = colorize(test_cmd, 'green')

    output_msg = '\n'*5 + '='*DIV_LINE_WIDTH +'\n' + dedent("""\
    End of experiment.


    Plot results from this run with:

    %s


    Watch the trained agent with:

    %s


    """%(plot_cmd,test_cmd)) + '='*DIV_LINE_WIDTH + '\n'*5

    print(output_msg)
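The decoding side is not shown here; a minimal sketch of what ``run_entrypoint.py`` presumably does (reverse the base64/zlib encoding and call the unpickled thunk):

# Hypothetical sketch of run_entrypoint.py (assumed, not shown above).
import argparse
import base64
import pickle
import zlib

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('encoded_thunk')
    args = parser.parse_args()

    # Undo base64 -> zlib -> pickle; cloudpickle output loads with plain pickle.
    thunk = pickle.loads(zlib.decompress(base64.b64decode(args.encoded_thunk)))
    thunk()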
Example #3
    def run(self, thunk, num_cpu=1, data_dir=None, datestamp=False):
        """
        Run each variant in the grid with function 'thunk'.

        Note: 'thunk' must be either a callable function, or a string. If it is
        a string, it must be the name of a parameter whose values are all
        callable functions.

        Uses ``call_experiment`` to actually launch each experiment, and gives
        each variant a name using ``self.variant_name()``.

        Maintenance note: the args for ExperimentGrid.run should track closely
        to the args for call_experiment. However, ``seed`` is omitted because
        we presume the user may add it as a parameter in the grid.
        """

        # Print info about self.
        self.print()

        # Make the list of all variants.
        variants = self.variants()

        # Print variant names for the user.
        var_names = set([self.variant_name(var) for var in variants])
        var_names = sorted(list(var_names))
        line = '='*DIV_LINE_WIDTH
        preparing = colorize('Preparing to run the following experiments...',
                             color='green', bold=True)
        joined_var_names = '\n'.join(var_names)
        announcement = f"\n{preparing}\n\n{joined_var_names}\n\n{line}"
        print(announcement)


        if WAIT_BEFORE_LAUNCH > 0:
            delay_msg = colorize(dedent("""
            Launch delayed to give you a few seconds to review your experiments.

            To customize or disable this behavior, change WAIT_BEFORE_LAUNCH in
            spinup/user_config.py.

            """), color='cyan', bold=True)+line
            print(delay_msg)
            wait, steps = WAIT_BEFORE_LAUNCH, 100
            prog_bar = trange(steps, desc='Launching in...',
                              leave=False, ncols=DIV_LINE_WIDTH,
                              mininterval=0.25,
                              bar_format='{desc}: {bar}| {remaining} {elapsed}')
            for _ in prog_bar:
                time.sleep(wait/steps)

        # Run the variants.
        for var in variants:
            exp_name = self.variant_name(var)

            # Figure out what the thunk is.
            if isinstance(thunk, str):
                # Assume one of the variant parameters has the same
                # name as the string you passed for thunk, and that
                # variant[thunk] is a valid callable function.
                thunk_ = var[thunk]
                del var[thunk]
            else:
                # Assume thunk is given as a function.
                thunk_ = thunk

            call_experiment(exp_name, thunk_, num_cpu=num_cpu,
                            data_dir=data_dir, datestamp=datestamp, **var)
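A hedged end-to-end sketch of how the grid might be driven (the import path and hyperparameter choices are assumptions for illustration):

# Hypothetical usage: sweep a few hyperparameters and launch every variant.
from spinup import ppo_tf1 as ppo  # assumed import path

eg = ExperimentGrid(name='ppo-sweep')
eg.add('env_name', 'CartPole-v1', '', True)
eg.add('seed', [0, 10, 20])
eg.add('ac_kwargs:hidden_sizes', [(32,), (64, 64)], 'hid')
eg.run(ppo, num_cpu=2)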
Example #4
def ppo(
        env_fn,
        ac_kwargs=dict(),  # ac_kwargs stores the network architecture parameters
        seed=0,
        steps_per_epoch=4000,
        epochs=50,
        gamma=0.99,
        lam=0.97,  # settings for gamma and lambda
        clip_ratio=0.2,
        pi_lr=3e-4,
        vf_lr=1e-3,  # learning rate settings
        train_pi_iters=80,
        train_v_iters=80,
        max_ep_len=1000,
        target_kl=0.01):

    env = env_fn()
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]

    tf.set_random_seed(seed)
    np.random.seed(seed)

    # Main outputs from computation graph
    actor_critic = ActorCritic(ac_kwargs["hidden_sizes"],
                               activation=tf.nn.tanh,
                               output_activation=None,
                               action_space=env.action_space)
    # warm-up forward pass (presumably to build the network variables before counting them)
    test_y = actor_critic.choose_action_prob(
        s=tf.convert_to_tensor(np.random.random((1, obs_dim)),
                               dtype=tf.float32),
        a=tf.convert_to_tensor(np.random.random((1, act_dim)),
                               dtype=tf.float32))

    # var counts
    var_counts = np.sum(
        [int(np.prod(v.shape)) for v in actor_critic.trainable_variables])
    print('\nNumber of parameters: ', var_counts)

    # Experience buffer
    buf = PPOBuffer(obs_dim, act_dim, steps_per_epoch, gamma, lam)

    # Optimizers
    actor_optimizer = tf.train.AdamOptimizer(learning_rate=pi_lr)
    critic_optimizer = tf.train.AdamOptimizer(learning_rate=vf_lr)

    # Reset
    o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
    ep_ret_old, ep_len_old, best_rew = 0, 0, -10000.
    # Main loop: collect experience in env and update/log each epoch
    for epoch in range(epochs):
        for t in range(steps_per_epoch):
            # a.shape=(1,6), logp_pi.shape=(1,), v_t.shape=(1,)
            a, _, logp_pi, v_t = actor_critic.choose_action_prob(
                tf.convert_to_tensor(o.reshape(1, -1), tf.float32), a=None)
            buf.store(
                o, a, r, v_t,
                logp_pi)  # save to buffer; shapes are (1,6), None, (1,), (1,) respectively

            o, r, d, _ = env.step(a[0])  # take action
            ep_ret += r
            ep_len += 1

            terminal = d or (ep_len == max_ep_len)
            if terminal or (t == steps_per_epoch - 1):
                if not terminal:
                    print('Warning: trajectory cut off by epoch at %d steps.' %
                          ep_len)
                last_val = r if d else actor_critic.get_critic_output(
                    tf.convert_to_tensor(o.reshape(1, -1), tf.float32))
                buf.finish_path(
                    last_val
                )  # calculate advantage function and discount return
                if terminal:
                    ep_ret_old = ep_ret
                    ep_len_old = ep_len

                o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
        print("----------------------", buf.ptr)
        # update actor
        stop_iter = False
        buf_tensors = [tf.convert_to_tensor(x) for x in buf.get()]
        obs, act, adv, ret, logp_old = buf_tensors
        # compute the loss only (no update) to get an initial reference value
        pi_loss_old, _ = update_actor(buf_tensors,
                                      actor_critic,
                                      actor_optimizer,
                                      clip_ratio,
                                      update=False)
        for i in range(train_pi_iters):
            pi_loss, logp = update_actor(buf_tensors, actor_critic,
                                         actor_optimizer, clip_ratio)
            # for record
            kl = tf.reduce_mean(logp_old - logp)
            if kl > 1.5 * target_kl:
                print(
                    colorize("Early stopping at step " + str(i) +
                             " due to reaching max kl.",
                             color='green',
                             bold=True,
                             highlight=False))
                stop_iter = True
                break

        # update critic multiple times
        v_loss_old, _ = update_critic(buf_tensors,
                                      actor_critic,
                                      critic_optimizer,
                                      update=False)
        for i in range(train_v_iters):
            v_loss, _ = update_critic(buf_tensors, actor_critic,
                                      critic_optimizer)

        # Log info about actor
        pi_loss, logp = update_actor(buf_tensors,
                                     actor_critic,
                                     actor_optimizer,
                                     clip_ratio,
                                     update=False)
        kl = tf.reduce_mean(logp_old - logp)
        ratio = tf.exp(logp - logp_old)  # pi(a|s) / pi_old(a|s)
        clipped = tf.logical_or(ratio > (1 + clip_ratio), ratio <
                                (1 - clip_ratio))
        clip_frac = tf.reduce_mean(tf.cast(clipped, tf.float32))
        delta_loss_pi = pi_loss - pi_loss_old
        ent = tf.reduce_mean(-logp)
        # log info about critic
        v_loss, v = update_critic(buf_tensors,
                                  actor_critic,
                                  critic_optimizer,
                                  update=False)
        delta_loss_v = v_loss - v_loss_old

        print("\n\n---------------------------")
        print("Epoch: \t\t", epoch)
        print("EpRet: \t\t", ep_ret_old)
        print("EpLen: \t\t", ep_len_old)
        print("VVals: \t\t", np.mean(v.numpy()))
        print("TotalEnvInteracts: \t", (epoch + 1) * steps_per_epoch)
        print("LossPi: \t\t", pi_loss.numpy())
        print("LossV: \t\t", v_loss.numpy())
        print("DeltaLossPi: \t\t", delta_loss_pi.numpy())
        print("DeltaLossV: \t\t", delta_loss_v.numpy())
        print("Entropy: \t\t", ent.numpy())
        print("KL: \t\t", kl.numpy())
        print('ClipFrac:  \t\t', clip_frac.numpy())
        print("StopIter: \t\t", stop_iter)

        if ep_ret_old > best_rew:
            print("new best rewards:", ep_ret_old)
            actor_critic.save_weights("actor_critic.h5")
            best_rew = ep_ret_old
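
# update_actor / update_critic are called above but not defined in this snippet.
# Below is a minimal, illustrative sketch of update_actor only (NOT the original
# implementation): it assumes eager execution and that
# choose_action_prob(obs, a=act) returns (action, logp_of_act, logp_pi, value).
# It computes the PPO clipped surrogate loss and, unless update=False, applies
# one gradient step with the given optimizer.
def update_actor(buf_tensors, actor_critic, optimizer, clip_ratio, update=True):
    obs, act, adv, ret, logp_old = buf_tensors
    with tf.GradientTape() as tape:
        _, logp, _, _ = actor_critic.choose_action_prob(obs, a=act)
        ratio = tf.exp(logp - logp_old)  # pi(a|s) / pi_old(a|s)
        min_adv = tf.where(adv > 0, (1 + clip_ratio) * adv,
                           (1 - clip_ratio) * adv)
        pi_loss = -tf.reduce_mean(tf.minimum(ratio * adv, min_adv))
    if update:
        grads = tape.gradient(pi_loss, actor_critic.trainable_variables)
        optimizer.apply_gradients(zip(grads, actor_critic.trainable_variables))
    return pi_loss, logp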
def ppo(env_fn,
        expert=None,
        policy_path=None,
        actor_critic=core.mlp_actor_critic_m,
        ac_kwargs=dict(),
        seed=0,
        steps_per_epoch=5000,
        epochs=10000,
        dagger_epochs=500,
        pretrain_epochs=50,
        gamma=0.99,
        clip_ratio=0.2,
        pi_lr=1e-4,
        dagger_noise=0.01,
        batch_size=64,
        replay_size=int(5e3),
        vf_lr=1e-4,
        train_pi_iters=80,
        train_v_iters=80,
        lam=0.999,
        max_ep_len=500,
        target_kl=0.01,
        logger_kwargs=dict(),
        save_freq=10,
        test_freq=10):
    """

    Args:
        env_fn : A function which creates a copy of the environment.
            The environment must satisfy the OpenAI Gym API.

        actor_critic: A function which takes in placeholder symbols 
            for state, ``x_ph``, and action, ``a_ph``, and returns the main 
            outputs from the agent's Tensorflow computation graph:

            ===========  ================  ======================================
            Symbol       Shape             Description
            ===========  ================  ======================================
            ``pi``       (batch, act_dim)  | Samples actions from policy given 
                                           | states.
            ``logp``     (batch,)          | Gives log probability, according to
                                           | the policy, of taking actions ``a_ph``
                                           | in states ``x_ph``.
            ``logp_pi``  (batch,)          | Gives log probability, according to
                                           | the policy, of the action sampled by
                                           | ``pi``.
            ``v``        (batch,)          | Gives the value estimate for states
                                           | in ``x_ph``. (Critical: make sure 
                                           | to flatten this!)
            ===========  ================  ======================================

        ac_kwargs (dict): Any kwargs appropriate for the actor_critic 
            function you provided to PPO.

        seed (int): Seed for random number generators.

        steps_per_epoch (int): Number of steps of interaction (state-action pairs) 
            for the agent and the environment in each epoch.

        epochs (int): Number of epochs of interaction (equivalent to
            number of policy updates) to perform.

        gamma (float): Discount factor. (Always between 0 and 1.)

        clip_ratio (float): Hyperparameter for clipping in the policy objective.
            Roughly: how far can the new policy go from the old policy while 
            still profiting (improving the objective function)? The new policy 
            can still go farther than the clip_ratio says, but it doesn't help
            on the objective anymore. (Usually small, 0.1 to 0.3.)

        pi_lr (float): Learning rate for policy optimizer.

        vf_lr (float): Learning rate for value function optimizer.

        train_pi_iters (int): Maximum number of gradient descent steps to take 
            on policy loss per epoch. (Early stopping may cause optimizer
            to take fewer than this.)

        train_v_iters (int): Number of gradient descent steps to take on 
            value function per epoch.

        lam (float): Lambda for GAE-Lambda. (Always between 0 and 1,
            close to 1.)

        max_ep_len (int): Maximum length of trajectory / episode / rollout.

        target_kl (float): Roughly what KL divergence we think is appropriate
            between new and old policies after an update. This will get used 
            for early stopping. (Usually small, 0.01 or 0.05.)

        policy_path (str): path of pretrained policy model
            train from scratch if None

        logger_kwargs (dict): Keyword args for EpochLogger.

        save_freq (int): How often (in terms of gap between epochs) to save
            the current policy and value function.

    """

    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())
    test_logger_kwargs = dict()
    test_logger_kwargs['output_dir'] = osp.join(logger_kwargs['output_dir'],
                                                "test")
    test_logger_kwargs['exp_name'] = logger_kwargs['exp_name']
    test_logger = EpochLogger(**test_logger_kwargs)
    test_logger.save_config(locals())

    seed += 10000 * proc_id()
    tf.set_random_seed(seed)
    np.random.seed(seed)

    env = env_fn()
    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape

    # Share information about action space with policy architecture
    ac_kwargs['action_space'] = env.action_space
    act_high_limit = env.action_space.high
    act_low_limit = env.action_space.low

    sess = tf.Session()
    if policy_path is None:
        # Inputs to computation graph
        x_ph, a_ph = core.placeholders_from_spaces(env.observation_space,
                                                   env.action_space)
        adv_ph, ret_ph, logp_old_ph = core.placeholders(None, None, None)
        tfa_ph = core.placeholder(act_dim)

        # Main outputs from computation graph
        mu, pi, logp, logp_pi, v = actor_critic(x_ph, a_ph, **ac_kwargs)
        sess.run(tf.global_variables_initializer())

    else:
        # load pretrained model
        # sess, x_ph, a_ph, mu, pi, logp, logp_pi, v = load_policy(policy_path, itr='last', deterministic=False, act_high=env.action_space.high)
        # # get_action_2 = lambda x : sess.run(mu, feed_dict={x_ph: x[None,:]})[0]
        # adv_ph, ret_ph, logp_old_ph = core.placeholders(None, None, None)
        model = restore_tf_graph(sess, osp.join(policy_path, 'simple_save'))
        x_ph, a_ph, adv_ph, ret_ph, logp_old_ph = model['x_ph'], model[
            'a_ph'], model['adv_ph'], model['ret_ph'], model['logp_old_ph']
        mu, pi, logp, logp_pi, v = model['mu'], model['pi'], model[
            'logp'], model['logp_pi'], model['v']
        # tfa_ph = core.placeholder(act_dim)
        tfa_ph = model['tfa_ph']

    # Need all placeholders in *this* order later (to zip with data from buffer)
    all_phs = [x_ph, a_ph, adv_ph, ret_ph, logp_old_ph]

    # Every step, get: action, value, and logprob
    get_action_ops = [pi, v, logp_pi]

    # Experience buffer
    local_steps_per_epoch = int(steps_per_epoch / num_procs())
    print("---------------", local_steps_per_epoch)
    buf = PPOBuffer(obs_dim, act_dim, steps_per_epoch, gamma, lam)
    # print(obs_dim)
    # print(act_dim)
    dagger_replay_buffer = DaggerReplayBuffer(obs_dim=obs_dim[0],
                                              act_dim=act_dim[0],
                                              size=replay_size)
    # Count variables
    var_counts = tuple(core.count_vars(scope) for scope in ['pi', 'v'])
    logger.log('\nNumber of parameters: \t pi: %d, \t v: %d\n' % var_counts)

    # PPO objectives
    if policy_path is None:
        ratio = tf.exp(logp - logp_old_ph)  # pi(a|s) / pi_old(a|s)
        min_adv = tf.where(adv_ph > 0, (1 + clip_ratio) * adv_ph,
                           (1 - clip_ratio) * adv_ph)
        pi_loss = -tf.reduce_mean(tf.minimum(ratio * adv_ph, min_adv))
        v_loss = tf.reduce_mean((ret_ph - v)**2)
        dagger_pi_loss = tf.reduce_mean(tf.square(mu - tfa_ph))

        # Info (useful to watch during learning)
        approx_kl = tf.reduce_mean(
            logp_old_ph -
            logp)  # a sample estimate for KL-divergence, easy to compute
        approx_ent = tf.reduce_mean(
            -logp)  # a sample estimate for entropy, also easy to compute
        clipped = tf.logical_or(ratio > (1 + clip_ratio), ratio <
                                (1 - clip_ratio))
        clipfrac = tf.reduce_mean(tf.cast(clipped, tf.float32))

        # Optimizers
        dagger_pi_optimizer = tf.train.AdamOptimizer(learning_rate=pi_lr)
        optimizer_pi = tf.train.AdamOptimizer(learning_rate=pi_lr)
        optimizer_v = tf.train.AdamOptimizer(learning_rate=vf_lr)
        train_dagger_pi_op = dagger_pi_optimizer.minimize(
            dagger_pi_loss, name='train_dagger_pi_op')
        train_pi = optimizer_pi.minimize(pi_loss, name='train_pi_op')
        train_v = optimizer_v.minimize(v_loss, name='train_v_op')

        sess.run(tf.variables_initializer(optimizer_pi.variables()))
        sess.run(tf.variables_initializer(optimizer_v.variables()))
        sess.run(tf.variables_initializer(dagger_pi_optimizer.variables()))
    else:
        graph = tf.get_default_graph()
        dagger_pi_loss = model['dagger_pi_loss']
        pi_loss = model['pi_loss']
        v_loss = model['v_loss']
        approx_ent = model['approx_ent']
        approx_kl = model['approx_kl']
        clipfrac = model['clipfrac']

        train_dagger_pi_op = graph.get_operation_by_name('train_dagger_pi_op')
        train_pi = graph.get_operation_by_name('train_pi_op')
        train_v = graph.get_operation_by_name('train_v_op')
    # sess = tf.Session()
    # sess.run(tf.global_variables_initializer())

    # Sync params across processes
    # sess.run(sync_all_params())

    tf.summary.FileWriter("log/", sess.graph)
    # Setup model saving
    logger.setup_tf_saver(sess, inputs={'x_ph': x_ph, 'a_ph': a_ph, 'tfa_ph': tfa_ph, 'adv_ph': adv_ph, 'ret_ph': ret_ph, 'logp_old_ph': logp_old_ph}, \
        outputs={'mu': mu, 'pi': pi, 'v': v, 'logp': logp, 'logp_pi': logp_pi, 'clipfrac': clipfrac, 'approx_kl': approx_kl, \
            'pi_loss': pi_loss, 'v_loss': v_loss, 'dagger_pi_loss': dagger_pi_loss, 'approx_ent': approx_ent})

    def update():
        inputs = {k: v for k, v in zip(all_phs, buf.get())}
        pi_l_old, v_l_old, ent = sess.run([pi_loss, v_loss, approx_ent],
                                          feed_dict=inputs)

        # Training
        for i in range(train_pi_iters):
            _, kl = sess.run([train_pi, approx_kl], feed_dict=inputs)
            kl = mpi_avg(kl)
            if kl > 1.5 * target_kl:
                logger.log(
                    'Early stopping at step %d due to reaching max kl.' % i)
                break
        logger.store(StopIter=i)
        for _ in range(train_v_iters):
            sess.run(train_v, feed_dict=inputs)

        # Log changes from update
        pi_l_new, v_l_new, kl, cf = sess.run(
            [pi_loss, v_loss, approx_kl, clipfrac], feed_dict=inputs)
        logger.store(LossPi=pi_l_old,
                     LossV=v_l_old,
                     KL=kl,
                     Entropy=ent,
                     ClipFrac=cf,
                     DeltaLossPi=(pi_l_new - pi_l_old),
                     DeltaLossV=(v_l_new - v_l_old))

    def choose_action(s, add_noise=False):
        s = s[np.newaxis, :]
        a = sess.run(mu, {x_ph: s})[0]
        if add_noise:
            noise = dagger_noise * act_high_limit * np.random.normal(
                size=a.shape)
            a = a + noise
        return np.clip(a, act_low_limit, act_high_limit)

    def test_agent(n=81, test_num=1):
        n = env.unwrapped._set_test_mode(True)
        con_flag = False
        for j in range(n):
            o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
            while not (d or (ep_len == max_ep_len)):
                # Take deterministic actions at test time (noise_scale=0)
                o, r, d, info = env.step(choose_action(np.array(o), 0))
                ep_ret += r
                ep_len += 1
                if d:
                    test_logger.store(TestEpRet=ep_ret, TestEpLen=ep_len)
                    test_logger.store(arrive_des=info['arrive_des'])
                    test_logger.store(
                        arrive_des_appro=info['arrive_des_appro'])
                    if not info['out_of_range']:
                        test_logger.store(converge_dis=info['converge_dis'])
                        con_flag = True
                    test_logger.store(out_of_range=info['out_of_range'])
                    # print(info)
        # test_logger.dump_tabular()
        # time.sleep(10)
        if not con_flag:
            test_logger.store(converge_dis=10000)
        env.unwrapped._set_test_mode(False)

    def ref_test_agent(n=81, test_num=1):
        n = env.unwrapped._set_test_mode(True)
        con_flag = False
        for j in range(n):
            o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
            while not (d or (ep_len == max_ep_len)):
                # Take deterministic actions at test time (noise_scale=0)
                a = call_ref_controller(env, expert)
                o, r, d, info = env.step(a)
                ep_ret += r
                ep_len += 1
                if d:
                    test_logger.store(TestEpRet=ep_ret, TestEpLen=ep_len)
                    test_logger.store(arrive_des=info['arrive_des'])
                    test_logger.store(
                        arrive_des_appro=info['arrive_des_appro'])
                    if not info['out_of_range']:
                        test_logger.store(converge_dis=info['converge_dis'])
                        con_flag = True
                    test_logger.store(out_of_range=info['out_of_range'])
                    # print(info)
        # test_logger.dump_tabular()
        if not con_flag:
            test_logger.store(converge_dis=10000)
        env.unwrapped._set_test_mode(False)

    ref_test_agent(test_num=-1)
    test_logger.log_tabular('epoch', -1)
    test_logger.log_tabular('TestEpRet', average_only=True)
    test_logger.log_tabular('TestEpLen', average_only=True)
    test_logger.log_tabular('arrive_des', average_only=True)
    test_logger.log_tabular('arrive_des_appro', average_only=True)
    test_logger.log_tabular('converge_dis', average_only=True)
    test_logger.log_tabular('out_of_range', average_only=True)
    test_logger.dump_tabular()

    start_time = time.time()
    o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
    test_policy_epochs = 91
    episode_steps = 500
    total_env_t = 0
    test_num = 0
    print(colorize("begin dagger training", 'green', bold=True))
    for epoch in range(1, dagger_epochs + 1, 1):
        # test policy
        if epoch > 0 and (epoch % save_freq == 0) or (epoch == epochs):
            # Save model
            logger.save_state({}, None)

            # Test the performance of the deterministic version of the agent.
            test_num += 1
            test_agent(test_num=test_num)

            test_logger.log_tabular('epoch', epoch)
            test_logger.log_tabular('TestEpRet', average_only=True)
            test_logger.log_tabular('TestEpLen', average_only=True)
            test_logger.log_tabular('arrive_des', average_only=True)
            test_logger.log_tabular('arrive_des_appro', average_only=True)
            test_logger.log_tabular('converge_dis', average_only=True)
            test_logger.log_tabular('out_of_range', average_only=True)
            test_logger.dump_tabular()

        # train policy
        o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
        env.unwrapped._set_test_mode(False)
        obs, acs, rewards = [], [], []
        for t in range(local_steps_per_epoch):
            a, v_t, logp_t = sess.run(
                get_action_ops, feed_dict={x_ph: np.array(o).reshape(1, -1)})
            # a = get_action_2(np.array(o))
            # save and log
            obs.append(o)
            ref_action = call_ref_controller(env, expert)
            if (epoch < pretrain_epochs):
                action = ref_action
            else:
                action = choose_action(np.array(o), True)

            buf.store(o, action, r, v_t, logp_t)
            logger.store(VVals=v_t)

            o, r, d, _ = env.step(action)
            acs.append(ref_action)
            rewards.append(r)

            ep_ret += r
            ep_len += 1
            total_env_t += 1

            terminal = d or (ep_len == max_ep_len)
            if terminal or (t == local_steps_per_epoch - 1):
                if not (terminal):
                    print('Warning: trajectory cut off by epoch at %d steps.' %
                          ep_len)
                # if trajectory didn't reach terminal state, bootstrap value target
                last_val = r if d else sess.run(
                    v, feed_dict={x_ph: np.array(o).reshape(1, -1)})
                buf.finish_path(last_val)
                if terminal:
                    # only save EpRet / EpLen if trajectory finished
                    logger.store(EpRet=ep_ret, EpLen=ep_len)
                o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0

        # Perform DAgger and partial PPO update!
        inputs = {k: v for k, v in zip(all_phs, buf.get())}
        # pi_l_old, v_l_old, ent = sess.run([pi_loss, v_loss, approx_ent], feed_dict=inputs)
        for _ in range(train_v_iters):
            sess.run(train_v, feed_dict=inputs)

        # Log changes from update
        max_step = len(np.array(rewards))
        dagger_replay_buffer.stores(obs, acs, rewards)
        for _ in range(int(local_steps_per_epoch / 10)):
            batch = dagger_replay_buffer.sample_batch(batch_size)
            feed_dict = {x_ph: batch['obs1'], tfa_ph: batch['acts']}
            q_step_ops = [dagger_pi_loss, train_dagger_pi_op]
            for j in range(10):
                outs = sess.run(q_step_ops, feed_dict)
            logger.store(LossPi=outs[0])

        c_v_loss = sess.run(v_loss, feed_dict=inputs)
        logger.store(LossV=c_v_loss,
                     KL=0,
                     Entropy=0,
                     ClipFrac=0,
                     DeltaLossPi=0,
                     DeltaLossV=0,
                     StopIter=0)

        # Log info about epoch
        logger.log_tabular('Epoch', epoch)
        logger.log_tabular('EpRet', with_min_and_max=True)
        logger.log_tabular('EpLen', average_only=True)
        logger.log_tabular('VVals', with_min_and_max=True)
        logger.log_tabular('TotalEnvInteracts', (epoch + 1) * steps_per_epoch)
        logger.log_tabular('LossPi', average_only=True)
        logger.log_tabular('LossV', average_only=True)
        logger.log_tabular('DeltaLossPi', average_only=True)
        logger.log_tabular('DeltaLossV', average_only=True)
        logger.log_tabular('Entropy', average_only=True)
        logger.log_tabular('KL', average_only=True)
        logger.log_tabular('ClipFrac', average_only=True)
        logger.log_tabular('StopIter', average_only=True)
        logger.log_tabular('Time', time.time() - start_time)
        logger.dump_tabular()

    # Main loop: collect experience in env and update/log each epoch
    print(colorize("begin ppo training", 'green', bold=True))
    for epoch in range(1, epochs + 1, 1):
        # test policy
        if epoch > 0 and (epoch % save_freq == 0) or (epoch
                                                      == epochs) or epoch == 1:
            # Save model
            logger.save_state({}, None)

            # Test the performance of the deterministic version of the agent.
            test_num += 1
            test_agent(test_num=test_num)

            test_logger.log_tabular('epoch', epoch)
            test_logger.log_tabular('TestEpRet', average_only=True)
            test_logger.log_tabular('TestEpLen', average_only=True)
            test_logger.log_tabular('arrive_des', average_only=True)
            test_logger.log_tabular('arrive_des_appro', average_only=True)
            test_logger.log_tabular('converge_dis', average_only=True)
            test_logger.log_tabular('out_of_range', average_only=True)
            test_logger.dump_tabular()

        # train policy
        o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
        env.unwrapped._set_test_mode(False)
        for t in range(local_steps_per_epoch):
            a, v_t, logp_t = sess.run(
                get_action_ops, feed_dict={x_ph: np.array(o).reshape(1, -1)})
            # a = a[0]
            # a = get_action_2(np.array(o))
            # a = np.clip(a, act_low_limit, act_high_limit)
            # if epoch < pretrain_epochs:
            #     a = env.action_space.sample()
            # a = np.clip(a, act_low_limit, act_high_limit)
            # save and log
            buf.store(o, a, r, v_t, logp_t)
            logger.store(VVals=v_t)

            o, r, d, _ = env.step(a[0])
            ep_ret += r
            ep_len += 1

            terminal = d or (ep_len == max_ep_len)
            if terminal or (t == local_steps_per_epoch - 1):
                if not (terminal):
                    print('Warning: trajectory cut off by epoch at %d steps.' %
                          ep_len)
                # if trajectory didn't reach terminal state, bootstrap value target
                last_val = r if d else sess.run(
                    v, feed_dict={x_ph: np.array(o).reshape(1, -1)})
                buf.finish_path(last_val)
                if terminal:
                    # only save EpRet / EpLen if trajectory finished
                    logger.store(EpRet=ep_ret, EpLen=ep_len)
                o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0

        # Perform PPO update!
        update()

        # Log info about epoch
        logger.log_tabular('Epoch', epoch)
        logger.log_tabular('EpRet', with_min_and_max=True)
        logger.log_tabular('EpLen', average_only=True)
        logger.log_tabular('VVals', with_min_and_max=True)
        logger.log_tabular('TotalEnvInteracts', (epoch + 1) * steps_per_epoch)
        logger.log_tabular('LossPi', average_only=True)
        logger.log_tabular('LossV', average_only=True)
        logger.log_tabular('DeltaLossPi', average_only=True)
        logger.log_tabular('DeltaLossV', average_only=True)
        logger.log_tabular('Entropy', average_only=True)
        logger.log_tabular('KL', average_only=True)
        logger.log_tabular('ClipFrac', average_only=True)
        logger.log_tabular('StopIter', average_only=True)
        logger.log_tabular('Time', time.time() - start_time)
        logger.dump_tabular()
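
# PPOBuffer.finish_path is called above but not shown. The sketch below is an
# illustrative stand-in for the standard GAE-Lambda bookkeeping it presumably
# performs on one trajectory: advantages from discounted TD residuals, and
# rewards-to-go as value-function targets (helper names here are made up).
def _discount_cumsum(x, discount):
    # discounted cumulative sum, computed right to left
    out = np.zeros_like(x, dtype=np.float64)
    running = 0.0
    for i in reversed(range(len(x))):
        running = x[i] + discount * running
        out[i] = running
    return out

def _gae_lambda(rews, vals, last_val, gamma, lam):
    # append the bootstrap value, then form TD residuals delta_t
    rews = np.append(rews, last_val)
    vals = np.append(vals, last_val)
    deltas = rews[:-1] + gamma * vals[1:] - vals[:-1]
    adv = _discount_cumsum(deltas, gamma * lam)   # GAE-Lambda advantages
    ret = _discount_cumsum(rews, gamma)[:-1]      # rewards-to-go (value targets)
    return adv, ret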
def sac(env_fn,  expert=None, policy_path=None, actor_critic=core.mlp_actor_critic, ac_kwargs=dict(), seed=0, 
        steps_per_epoch=500, epochs=100000, replay_size=int(5e3), gamma=0.99, 
        dagger_noise=0.02, polyak=0.995, lr=1e-4, alpha=0.2, batch_size=64, dagger_epochs=200, pretrain_epochs=50,
        max_ep_len=500, logger_kwargs=dict(), save_freq=50, update_steps=10):
    """

    Args:
        env_fn : A function which creates a copy of the environment.
            The environment must satisfy the OpenAI Gym API.

        actor_critic: A function which takes in placeholder symbols 
            for state, ``x_ph``, and action, ``a_ph``, and returns the main 
            outputs from the agent's Tensorflow computation graph:

            ===========  ================  ======================================
            Symbol       Shape             Description
            ===========  ================  ======================================
            ``mu``       (batch, act_dim)  | Computes mean actions from policy
                                           | given states.
            ``pi``       (batch, act_dim)  | Samples actions from policy given 
                                           | states.
            ``logp_pi``  (batch,)          | Gives log probability, according to
                                           | the policy, of the action sampled by
                                           | ``pi``. Critical: must be differentiable
                                           | with respect to policy parameters all
                                           | the way through action sampling.
            ``q1``       (batch,)          | Gives one estimate of Q* for 
                                           | states in ``x_ph`` and actions in
                                           | ``a_ph``.
            ``q2``       (batch,)          | Gives another estimate of Q* for 
                                           | states in ``x_ph`` and actions in
                                           | ``a_ph``.
            ``q1_pi``    (batch,)          | Gives the composition of ``q1`` and 
                                           | ``pi`` for states in ``x_ph``: 
                                           | q1(x, pi(x)).
            ``q2_pi``    (batch,)          | Gives the composition of ``q2`` and 
                                           | ``pi`` for states in ``x_ph``: 
                                           | q2(x, pi(x)).
            ``v``        (batch,)          | Gives the value estimate for states
                                           | in ``x_ph``. 
            ===========  ================  ======================================

        ac_kwargs (dict): Any kwargs appropriate for the actor_critic 
            function you provided to SAC.

        seed (int): Seed for random number generators.

        steps_per_epoch (int): Number of steps of interaction (state-action pairs) 
            for the agent and the environment in each epoch.

        epochs (int): Number of epochs to run and train agent.

        replay_size (int): Maximum length of replay buffer.

        gamma (float): Discount factor. (Always between 0 and 1.)

        polyak (float): Interpolation factor in polyak averaging for target 
            networks. Target networks are updated towards main networks 
            according to:

            .. math:: \\theta_{\\text{targ}} \\leftarrow 
                \\rho \\theta_{\\text{targ}} + (1-\\rho) \\theta

            where :math:`\\rho` is polyak. (Always between 0 and 1, usually 
            close to 1.)

        lr (float): Learning rate (used for both policy and value learning).

        alpha (float): Entropy regularization coefficient. (Equivalent to 
            inverse of reward scale in the original SAC paper.)

        batch_size (int): Minibatch size for SGD.

        max_ep_len (int): Maximum length of trajectory / episode / rollout.

        logger_kwargs (dict): Keyword args for EpochLogger.

        save_freq (int): How often (in terms of gap between epochs) to save
            the current policy and value function.

    """

    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())
    test_logger_kwargs = dict()
    test_logger_kwargs['output_dir'] = osp.join(logger_kwargs['output_dir'], "test")
    test_logger_kwargs['exp_name'] = logger_kwargs['exp_name']
    test_logger = EpochLogger(**test_logger_kwargs)

    tf.set_random_seed(seed)
    np.random.seed(seed)

    env = env_fn()
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]
    print(obs_dim)
    print(act_dim)
    
    # Share information about action space with policy architecture
    ac_kwargs['action_space'] = env.action_space
    act_high_limit = env.action_space.high
    act_low_limit = env.action_space.low

    sess = tf.Session()
    if policy_path is None:
        # Inputs to computation graph
        x_ph, a_ph, x2_ph, r_ph, d_ph = core.placeholders(obs_dim, act_dim, obs_dim, None, None)
        tfa_ph = core.placeholder(act_dim)
        # Main outputs from computation graph
        with tf.variable_scope('main'):
            mu, pi, logp_pi, q1, q2, q1_pi, q2_pi, v = actor_critic(x_ph, a_ph, **ac_kwargs)
        
        # Target value network
        with tf.variable_scope('target'):
            _, _, _, _, _, _, _, v_targ  = actor_critic(x2_ph, a_ph, **ac_kwargs)
        # sess.run(tf.global_variables_initializer())
    
    else:
        # load pretrained model
        model = restore_tf_graph(sess, osp.join(policy_path, 'simple_save'))
        x_ph, a_ph, x2_ph, r_ph, d_ph = model['x_ph'], model['a_ph'], model['x2_ph'], model['r_ph'], model['d_ph']
        mu, pi, logp_pi, q1, q2, q1_pi, q2_pi, v = model['mu'], model['pi'], model['logp_pi'], model['q1'], model['q2'], model['q1_pi'], model['q2_pi'], model['v']
        # tfa_ph = core.placeholder(act_dim)
        tfa_ph = model['tfa_ph']

    # Experience buffer
    replay_buffer = ReplayBuffer(obs_dim=obs_dim, act_dim=act_dim, size=replay_size)
    dagger_replay_buffer = DaggerReplayBuffer(obs_dim=obs_dim, act_dim=act_dim, size=replay_size)
    # Count variables
    var_counts = tuple(core.count_vars(scope) for scope in 
                       ['main/pi', 'main/q1', 'main/q2', 'main/v', 'main'])
    print(('\nNumber of parameters: \t pi: %d, \t' + \
           'q1: %d, \t q2: %d, \t v: %d, \t total: %d\n')%var_counts)


    # print(obs_dim)
    # print(act_dim)

    # SAC objectives
    if policy_path is None:
        # Min Double-Q:
        min_q_pi = tf.minimum(q1_pi, q2_pi)

        # Targets for Q and V regression
        q_backup = tf.stop_gradient(r_ph + gamma*(1-d_ph)*v_targ)
        v_backup = tf.stop_gradient(min_q_pi - alpha * logp_pi)

        # Soft actor-critic losses
        dagger_pi_loss = tf.reduce_mean(tf.square(mu-tfa_ph))
        pi_loss = tf.reduce_mean(alpha * logp_pi - q1_pi)
        q1_loss = 0.5 * tf.reduce_mean((q_backup - q1)**2)
        q2_loss = 0.5 * tf.reduce_mean((q_backup - q2)**2)
        v_loss = 0.5 * tf.reduce_mean((v_backup - v)**2)
        value_loss = q1_loss + q2_loss + v_loss

        # Policy train op 
        # (has to be separate from value train op, because q1_pi appears in pi_loss)
        dagger_pi_optimizer = tf.train.AdamOptimizer(learning_rate=lr)
        train_dagger_pi_op = dagger_pi_optimizer.minimize(dagger_pi_loss, name='train_dagger_pi_op')

        pi_optimizer = tf.train.AdamOptimizer(learning_rate=lr)
        train_pi_op = pi_optimizer.minimize(pi_loss, var_list=get_vars('main/pi'), name='train_pi_op')
        # sess.run(tf.variables_initializer(pi_optimizer.variables()))

        # Value train op
        # (control dep of train_pi_op because sess.run otherwise evaluates in nondeterministic order)
        value_optimizer = tf.train.AdamOptimizer(learning_rate=lr)
        value_params = get_vars('main/q') + get_vars('main/v')
        with tf.control_dependencies([train_pi_op]):
            train_value_op = value_optimizer.minimize(value_loss, var_list=value_params, name='train_value_op')
            # sess.run(tf.variables_initializer(value_optimizer.variables()))

        # Polyak averaging for target variables
        # (control flow because sess.run otherwise evaluates in nondeterministic order)
        with tf.control_dependencies([train_value_op]):
            target_update = tf.group([tf.assign(v_targ, polyak*v_targ + (1-polyak)*v_main)
                                    for v_main, v_targ in zip(get_vars('main'), get_vars('target'))])

        # All ops to call during one training step
        step_ops = [pi_loss, q1_loss, q2_loss, v_loss, q1, q2, v, logp_pi, 
                    train_pi_op, train_value_op, target_update]

        # Initializing targets to match main variables
        target_init = tf.group([tf.assign(v_targ, v_main)
                                for v_main, v_targ in zip(get_vars('main'), get_vars('target'))])
        sess.run(tf.global_variables_initializer())
    else:
        graph = tf.get_default_graph()
        dagger_pi_loss = model['dagger_pi_loss']
        pi_loss = model['pi_loss']
        q1_loss = model['q1_loss']
        q2_loss = model['q2_loss']        
        v_loss = model['v_loss']

        train_dagger_pi_op = graph.get_operation_by_name('train_dagger_pi_op')
        train_value_op = graph.get_operation_by_name('train_value_op')
        train_pi_op = graph.get_operation_by_name('train_pi_op')
        
        # Polyak averaging for target variables
        # (control flow because sess.run otherwise evaluates in nondeterministic order)
        with tf.control_dependencies([train_value_op]):
            target_update = tf.group([tf.assign(v_targ, polyak*v_targ + (1-polyak)*v_main)
                                    for v_main, v_targ in zip(get_vars('main'), get_vars('target'))])

        # All ops to call during one training step
        step_ops = [pi_loss, q1_loss, q2_loss, v_loss, q1, q2, v, logp_pi, 
                    train_pi_op, train_value_op, target_update]

        # Initializing targets to match main variables
        target_init = tf.group([tf.assign(v_targ, v_main)
                                for v_main, v_targ in zip(get_vars('main'), get_vars('target'))])
    # sess = tf.Session()
    # sess.run(tf.global_variables_initializer())
    dagger_step_ops = [q1_loss, q2_loss, v_loss, q1, q2, v, logp_pi, train_value_op, target_update]
    tf.summary.FileWriter("log/", sess.graph)
    # Setup model saving
    logger.setup_tf_saver(sess, inputs={'x_ph': x_ph, 'a_ph': a_ph, 'tfa_ph': tfa_ph, 'x2_ph': x2_ph, 'r_ph': r_ph, 'd_ph': d_ph}, \
        outputs={'mu': mu, 'pi': pi, 'v': v, 'logp_pi': logp_pi, 'q1': q1, 'q2': q2, 'q1_pi': q1_pi, 'q2_pi': q2_pi, \
            'pi_loss': pi_loss, 'v_loss': v_loss, 'dagger_pi_loss': dagger_pi_loss, 'q1_loss': q1_loss, 'q2_loss': q2_loss})
    
    def get_action(o, deterministic=False):
        act_op = mu if deterministic else pi
        a = sess.run(act_op, feed_dict={x_ph: o.reshape(1,-1)})[0]
        return np.clip(a, act_low_limit, act_high_limit)

    def choose_action(s, add_noise=False):
        s = s[np.newaxis, :]
        a = sess.run(mu, {x_ph: s})[0]
        if add_noise:
            noise = dagger_noise * act_high_limit * np.random.normal(size=a.shape)
            a = a + noise
        return np.clip(a, act_low_limit, act_high_limit)

    def test_agent(n=81, test_num=1):
        n = env.unwrapped._set_test_mode(True)
        con_flag = False
        for j in range(n):
            o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
            while not(d or (ep_len == max_ep_len)):
                # Take deterministic actions at test time (noise_scale=0)
                o, r, d, info = env.step(choose_action(np.array(o), 0))
                ep_ret += r
                ep_len += 1
                if d:
                    test_logger.store(TestEpRet=ep_ret, TestEpLen=ep_len)
                    test_logger.store(arrive_des=info['arrive_des'])
                    test_logger.store(arrive_des_appro=info['arrive_des_appro'])
                    if not info['out_of_range']:
                        test_logger.store(converge_dis=info['converge_dis'])
                        con_flag = True
                    test_logger.store(out_of_range=info['out_of_range'])
                    # print(info)
        # test_logger.dump_tabular()
        # time.sleep(10)
        if not con_flag:
            test_logger.store(converge_dis=10000)
        env.unwrapped._set_test_mode(False)

    def ref_test_agent(n=81, test_num=1):
        n = env.unwrapped._set_test_mode(True)
        con_flag = False
        for j in range(n):
            o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
            while not(d or (ep_len == max_ep_len)):
                # Take deterministic actions at test time (noise_scale=0)
                a  = call_ref_controller(env, expert)
                o, r, d, info = env.step(a)
                ep_ret += r
                ep_len += 1
                if d:
                    test_logger.store(TestEpRet=ep_ret, TestEpLen=ep_len)
                    test_logger.store(arrive_des=info['arrive_des'])
                    test_logger.store(arrive_des_appro=info['arrive_des_appro'])
                    if not info['out_of_range']:
                        test_logger.store(converge_dis=info['converge_dis'])
                        con_flag = True
                    test_logger.store(out_of_range=info['out_of_range'])
                    # print(info)
        # test_logger.dump_tabular()
        if not con_flag:
            test_logger.store(converge_dis=10000)
        env.unwrapped._set_test_mode(False)

    # ref_test_agent(test_num = -1)
    # test_logger.log_tabular('epoch', -1)
    # test_logger.log_tabular('TestEpRet', average_only=True)
    # test_logger.log_tabular('TestEpLen', average_only=True)
    # test_logger.log_tabular('arrive_des', average_only=True)
    # test_logger.log_tabular('arrive_des_appro', average_only=True)
    # test_logger.log_tabular('converge_dis', average_only=True)
    # test_logger.log_tabular('out_of_range', average_only=True)
    # test_logger.dump_tabular()



    start_time = time.time()
    o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
    episode_steps = 500
    total_env_t = 0
    test_num = 0
    print(colorize("begin dagger training", 'green', bold=True))
    for epoch in range(1, dagger_epochs + 1, 1):
        # test policy
        if epoch > 0 and (epoch % save_freq == 0) or (epoch == epochs):
            # Save model
            logger.save_state({}, None)
            
            # Test the performance of the deterministic version of the agent.
            test_num += 1
            test_agent(test_num=test_num)
            
            test_logger.log_tabular('epoch', epoch)
            test_logger.log_tabular('TestEpRet', average_only=True)
            test_logger.log_tabular('TestEpLen', average_only=True)
            test_logger.log_tabular('arrive_des', average_only=True)
            test_logger.log_tabular('arrive_des_appro', average_only=True)
            test_logger.log_tabular('converge_dis', average_only=True)
            test_logger.log_tabular('out_of_range', average_only=True)
            test_logger.dump_tabular()

            # Log info about epoch
            logger.log_tabular('Epoch', epoch)
            logger.log_tabular('EpRet', with_min_and_max=True)
            logger.log_tabular('EpLen', average_only=True)
            logger.log_tabular('TotalEnvInteracts', t)
            logger.log_tabular('Q1Vals', with_min_and_max=True) 
            logger.log_tabular('Q2Vals', with_min_and_max=True) 
            logger.log_tabular('VVals', with_min_and_max=True) 
            logger.log_tabular('LogPi', with_min_and_max=True)
            logger.log_tabular('LossPi', average_only=True)
            logger.log_tabular('LossQ1', average_only=True)
            logger.log_tabular('LossQ2', average_only=True)
            logger.log_tabular('LossV', average_only=True)
            logger.log_tabular('Time', time.time()-start_time)
            logger.dump_tabular()

        # train policy
        o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
        env.unwrapped._set_test_mode(False)
        obs, acs, rewards = [], [], []
        for t in range(steps_per_epoch):
            obs.append(o)
            ref_action = call_ref_controller(env, expert)
            if(epoch < pretrain_epochs):
                action = ref_action
            else:
                action = choose_action(np.array(o), True)
            
            o2, r, d, _ = env.step(action)
            o = o2
            acs.append(ref_action)
            rewards.append(r)

            if (t == steps_per_epoch-1):
                # print ("reached the end")
                d = True

            # Store experience to replay buffer
            replay_buffer.store(o, action, r, o2, d)

            ep_ret += r
            ep_len += 1
            total_env_t += 1

            if d:
                # Perform partial SAC update!
                for j in range(ep_len):
                    batch = replay_buffer.sample_batch(batch_size)
                    feed_dict = {x_ph: batch['obs1'],
                                x2_ph: batch['obs2'],
                                a_ph: batch['acts'],
                                r_ph: batch['rews'],
                                d_ph: batch['done'],
                                }
                    outs = sess.run(dagger_step_ops, feed_dict)
                    logger.store(LossQ1=outs[0], LossQ2=outs[1],
                                LossV=outs[2], Q1Vals=outs[3], Q2Vals=outs[4],
                                VVals=outs[5], LogPi=outs[6])

                # Perform dagger policy update
                dagger_replay_buffer.stores(obs, acs, rewards)
                for _ in range(int(ep_len/5)):
                    batch = dagger_replay_buffer.sample_batch(batch_size)
                    feed_dict = {x_ph: batch['obs1'], tfa_ph: batch['acts']}
                    q_step_ops = [dagger_pi_loss, train_dagger_pi_op]
                    for j in range(10):
                        outs = sess.run(q_step_ops, feed_dict)
                    logger.store(LossPi = outs[0])

                logger.store(EpRet=ep_ret, EpLen=ep_len)
                o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
                break

    # Main loop: collect experience in env and update/log each epoch
    print(colorize("begin sac training", 'green', bold=True))
    for epoch in range(1, epochs + 1, 1):
        # test policy
        if epoch > 0 and (epoch % save_freq == 0) or (epoch == epochs):
            # Save model
            logger.save_state({}, None)
            
            # Test the performance of the deterministic version of the agent.
            test_num += 1
            test_agent(test_num=test_num)
            
            test_logger.log_tabular('epoch', epoch)
            test_logger.log_tabular('TestEpRet', average_only=True)
            test_logger.log_tabular('TestEpLen', average_only=True)
            test_logger.log_tabular('arrive_des', average_only=True)
            # test_logger.log_tabular('arrive_des_appro', average_only=True)
            test_logger.log_tabular('converge_dis', average_only=True)
            test_logger.log_tabular('out_of_range', average_only=True)
            test_logger.dump_tabular()

            # Log info about epoch
            logger.log_tabular('Epoch', epoch)
            logger.log_tabular('EpRet', with_min_and_max=True)
            logger.log_tabular('EpLen', average_only=True)
            logger.log_tabular('VVals', with_min_and_max=True)
            logger.log_tabular('TotalEnvInteracts', (epoch+1)*steps_per_epoch)
            logger.log_tabular('LossPi', average_only=True)
            logger.log_tabular('LossV', average_only=True)
            # logger.log_tabular('DeltaLossPi', average_only=True)
            # logger.log_tabular('DeltaLossV', average_only=True)
            # logger.log_tabular('Entropy', average_only=True)
            # logger.log_tabular('KL', average_only=True)
            # logger.log_tabular('ClipFrac', average_only=True)
            # logger.log_tabular('StopIter', average_only=True)
            logger.log_tabular('Time', time.time()-start_time)
            logger.dump_tabular()

        # train policy
        o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
        env.unwrapped._set_test_mode(False)
        for t in range(steps_per_epoch):
            a = get_action(np.array(o))

            o2, r, d, _ = env.step(a)
            ep_ret += r
            ep_len += 1
            if (t == steps_per_epoch-1):
                # print ("reached the end")
                d = True

            replay_buffer.store(o, a, r, o2, d)
            o = o2
            if d:
                """
                Perform all SAC updates at the end of the trajectory.
                This is a slight difference from the SAC specified in the
                original paper.
                """
                for j in range(ep_len):
                    batch = replay_buffer.sample_batch(batch_size)
                    feed_dict = {x_ph: batch['obs1'],
                                x2_ph: batch['obs2'],
                                a_ph: batch['acts'],
                                r_ph: batch['rews'],
                                d_ph: batch['done'],
                                }
                    outs = sess.run(step_ops, feed_dict)
                    logger.store(LossPi=outs[0], LossQ1=outs[1], LossQ2=outs[2],
                                LossV=outs[3], Q1Vals=outs[4], Q2Vals=outs[5],
                                VVals=outs[6], LogPi=outs[7])

                logger.store(EpRet=ep_ret, EpLen=ep_len)
                o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0


def td3(env_fn,
        expert=None,
        policy_path=None,
        actor_critic=core.mlp_actor_critic,
        ac_kwargs=dict(),
        seed=0,
        steps_per_epoch=500,
        epochs=1000,
        replay_size=int(5e3),
        gamma=0.99,
        polyak=0.995,
        pi_lr=1e-4,
        q_lr=1e-4,
        batch_size=64,
        start_epochs=500,
        dagger_epochs=500,
        pretrain_epochs=50,
        dagger_noise=0.02,
        act_noise=0.02,
        target_noise=0.02,
        noise_clip=0.5,
        policy_delay=2,
        max_ep_len=500,
        logger_kwargs=dict(),
        save_freq=50,
        UPDATE_STEP=10):
    """

    Args:
        env_fn : A function which creates a copy of the environment.
            The environment must satisfy the OpenAI Gym API.

        actor_critic: A function which takes in placeholder symbols 
            for state, ``x_ph``, and action, ``a_ph``, and returns the main 
            outputs from the agent's Tensorflow computation graph:

            ===========  ================  ======================================
            Symbol       Shape             Description
            ===========  ================  ======================================
            ``pi``       (batch, act_dim)  | Deterministically computes actions
                                           | from policy given states.
            ``q1``       (batch,)          | Gives one estimate of Q* for 
                                           | states in ``x_ph`` and actions in
                                           | ``a_ph``.
            ``q2``       (batch,)          | Gives another estimate of Q* for 
                                           | states in ``x_ph`` and actions in
                                           | ``a_ph``.
            ``q1_pi``    (batch,)          | Gives the composition of ``q1`` and 
                                           | ``pi`` for states in ``x_ph``: 
                                           | q1(x, pi(x)).
            ===========  ================  ======================================

        ac_kwargs (dict): Any kwargs appropriate for the actor_critic 
            function you provided to TD3.

        seed (int): Seed for random number generators.

        steps_per_epoch (int): Number of steps of interaction (state-action pairs) 
            for the agent and the environment in each epoch.

        epochs (int): Number of epochs to run and train agent.

        replay_size (int): Maximum length of replay buffer.

        gamma (float): Discount factor. (Always between 0 and 1.)

        polyak (float): Interpolation factor in polyak averaging for target 
            networks. Target networks are updated towards main networks 
            according to:

            .. math:: \\theta_{\\text{targ}} \\leftarrow 
                \\rho \\theta_{\\text{targ}} + (1-\\rho) \\theta

            where :math:`\\rho` is polyak. (Always between 0 and 1, usually 
            close to 1.)

        pi_lr (float): Learning rate for policy.

        q_lr (float): Learning rate for Q-networks.

        batch_size (int): Minibatch size for SGD.

        start_epochs (int): Number of epochs of uniform-random action selection,
            before running the learned policy. Helps exploration.

        act_noise (float): Stddev for Gaussian exploration noise added to 
            policy at training time. (At test time, no noise is added.)

        target_noise (float): Stddev for smoothing noise added to target 
            policy.

        noise_clip (float): Limit for absolute value of target policy 
            smoothing noise.

        policy_delay (int): Policy will only be updated once every 
            policy_delay times for each update of the Q-networks.

        max_ep_len (int): Maximum length of trajectory / episode / rollout.

        logger_kwargs (dict): Keyword args for EpochLogger.

        save_freq (int): How often (in terms of gap between epochs) to save
            the current policy and value function.

    """

    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())
    test_logger_kwargs = dict()
    test_logger_kwargs['output_dir'] = osp.join(logger_kwargs['output_dir'],
                                                "test")
    test_logger_kwargs['exp_name'] = logger_kwargs['exp_name']
    test_logger = EpochLogger(**test_logger_kwargs)

    # test_logger_kwargs = dict()
    # test_logger_kwargs['output_dir'] = osp.join(logger_kwargs['output_dir'], "test")
    # test_logger_kwargs['exp_name'] = logger_kwargs['exp_name']
    # test_logger = EpochLogger(**test_logger_kwargs)

    # pretrain_logger_kwargs = dict()
    # pretrain_logger_kwargs['output_dir'] = osp.join(logger_kwargs['output_dir'], "pretrain")
    # pretrain_logger_kwargs['exp_name'] = logger_kwargs['exp_name']
    # pretrain_logger = EpochLogger(**pretrain_logger_kwargs)

    tf.set_random_seed(seed)
    np.random.seed(seed)

    env, test_env = env_fn(), env_fn()
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]

    # Action limits for clamping: critically, do not assume all dimensions share the same bound!
    act_limit = env.action_space.high / 2
    act_high_limit = env.action_space.high
    act_low_limit = env.action_space.low

    act_noise_limit = act_noise * act_limit
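    # Per-dimension exploration noise scale used below: act_noise * (high / 2)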
    sess = tf.Session()
    if policy_path is None:
        # Share information about action space with policy architecture
        ac_kwargs['action_space'] = env.action_space

        # Inputs to computation graph
        x_ph, a_ph, x2_ph, r_ph, d_ph = core.placeholders(
            obs_dim, act_dim, obs_dim, None, None)
        tfa_ph = core.placeholder(act_dim)
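        # tfa_ph holds reference/expert actions; it feeds the DAgger
        # behavioral-cloning loss defined below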

        # Main outputs from computation graph
        with tf.variable_scope('main'):
            pi, q1, q2, q1_pi = actor_critic(x_ph, a_ph, **ac_kwargs)

        # Target policy network
        with tf.variable_scope('target'):
            pi_targ, _, _, _ = actor_critic(x2_ph, a_ph, **ac_kwargs)

        # Target Q networks
        with tf.variable_scope('target', reuse=True):

            # Target policy smoothing, by adding clipped noise to target actions
            epsilon = tf.random_normal(tf.shape(pi_targ), stddev=target_noise)
            epsilon = tf.clip_by_value(epsilon, -noise_clip, noise_clip)
            a2 = pi_targ + epsilon
            a2 = tf.clip_by_value(a2, act_low_limit, act_high_limit)
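            # i.e. a2 = clip(pi_targ(s') + clip(eps, -noise_clip, noise_clip),
            #                act_low_limit, act_high_limit)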

            # Target Q-values, using action from target policy
            _, q1_targ, q2_targ, _ = actor_critic(x2_ph, a2, **ac_kwargs)

    else:
        # sess = tf.Session()
        model = restore_tf_graph(sess, osp.join(policy_path, 'simple_save'))
        x_ph, a_ph, x2_ph, r_ph, d_ph = model['x_ph'], model['a_ph'], model[
            'x2_ph'], model['r_ph'], model['d_ph']
        pi, q1, q2, q1_pi = model['pi'], model['q1'], model['q2'], model[
            'q1_pi']
        pi_targ, q1_targ, q2_targ = model['pi_targ'], model['q1_targ'], model[
            'q2_targ']
        tfa_ph = core.placeholder(act_dim)
        dagger_epochs = 0

    # Experience buffer
    replay_buffer = ReplayBuffer(obs_dim=obs_dim,
                                 act_dim=act_dim,
                                 size=replay_size)
    dagger_replay_buffer = DaggerReplayBuffer(obs_dim=obs_dim,
                                              act_dim=act_dim,
                                              size=replay_size)
    # Count variables
    var_counts = tuple(
        core.count_vars(scope)
        for scope in ['main/pi', 'main/q1', 'main/q2', 'main'])
    print(
        '\nNumber of parameters: \t pi: %d, \t q1: %d, \t q2: %d, \t total: %d\n'
        % var_counts)

    if policy_path is None:
        # Bellman backup for Q functions, using Clipped Double-Q targets
        min_q_targ = tf.minimum(q1_targ, q2_targ)
        backup = tf.stop_gradient(r_ph + gamma * (1 - d_ph) * min_q_targ)
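        # i.e. y = r + gamma * (1 - d) * min(Q1_targ(s', a'), Q2_targ(s', a'))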

        # DAgger loss: behavioral cloning (MSE) toward the reference action
        dagger_pi_loss = tf.reduce_mean(tf.square(pi - tfa_ph))
        # TD3 losses
        pi_loss = -tf.reduce_mean(q1_pi)
        q1_loss = tf.reduce_mean((q1 - backup)**2)
        q2_loss = tf.reduce_mean((q2 - backup)**2)
        q_loss = tf.add(q1_loss, q2_loss)
        pi_loss = tf.identity(pi_loss, name="pi_loss")
        q1_loss = tf.identity(q1_loss, name="q1_loss")
        q2_loss = tf.identity(q2_loss, name="q2_loss")
        q_loss = tf.identity(q_loss, name="q_loss")

        # Separate train ops for pi, q
        dagger_pi_optimizer = tf.train.AdamOptimizer(learning_rate=pi_lr)
        pi_optimizer = tf.train.AdamOptimizer(learning_rate=pi_lr)
        q_optimizer = tf.train.AdamOptimizer(learning_rate=q_lr)
        train_dagger_pi_op = dagger_pi_optimizer.minimize(
            dagger_pi_loss,
            var_list=get_vars('main/pi'),
            name='train_dagger_pi_op')
        train_pi_op = pi_optimizer.minimize(pi_loss,
                                            var_list=get_vars('main/pi'),
                                            name='train_pi_op')
        train_q_op = q_optimizer.minimize(q_loss,
                                          var_list=get_vars('main/q'),
                                          name='train_q_op')

        # Polyak averaging for target variables
        target_update = tf.group([
            tf.assign(v_targ, polyak * v_targ + (1 - polyak) * v_main)
            for v_main, v_targ in zip(get_vars('main'), get_vars('target'))
        ])
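        # i.e. theta_targ <- polyak * theta_targ + (1 - polyak) * theta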

        # Initializing targets to match main variables
        target_init = tf.group([
            tf.assign(v_targ, v_main)
            for v_main, v_targ in zip(get_vars('main'), get_vars('target'))
        ])
        sess.run(tf.global_variables_initializer())
    else:
        graph = tf.get_default_graph()
        # opts = graph.get_operations()
        # print (opts)
        pi_loss = model['pi_loss']
        q1_loss = model['q1_loss']
        q2_loss = model['q2_loss']
        q_loss = model['q_loss']
        train_q_op = graph.get_operation_by_name('train_q_op')
        train_pi_op = graph.get_operation_by_name('train_pi_op')
        # target_update = graph.get_operation_by_name('target_update')
        # target_init = graph.get_operation_by_name('target_init')
        # Polyak averaging for target variables
        target_update = tf.group([
            tf.assign(v_targ, polyak * v_targ + (1 - polyak) * v_main)
            for v_main, v_targ in zip(get_vars('main'), get_vars('target'))
        ])

        # Initializing targets to match main variables
        target_init = tf.group([
            tf.assign(v_targ, v_main)
            for v_main, v_targ in zip(get_vars('main'), get_vars('target'))
        ])

    # sess = tf.Session()
    # sess.run(tf.global_variables_initializer())
    sess.run(target_init)

    # Setup model saving
    logger.setup_tf_saver(sess, inputs={'x_ph': x_ph, 'a_ph': a_ph, 'x2_ph': x2_ph, 'r_ph': r_ph, 'd_ph': d_ph}, \
         outputs={'pi': pi, 'q1': q1, 'q2': q2, 'q1_pi': q1_pi, 'pi_targ': pi_targ, 'q1_targ': q1_targ, 'q2_targ': q2_targ, \
             'pi_loss': pi_loss, 'q1_loss': q1_loss, 'q2_loss': q2_loss, 'q_loss': q_loss})

    def get_action(o, noise_scale):
        a = sess.run(pi, feed_dict={x_ph: o.reshape(1, -1)})[0]
        # todo: add act_limit scale noise
        a += noise_scale * np.random.randn(act_dim)
        return np.clip(a, act_low_limit, act_high_limit)

    def choose_action(s, add_noise=False):
        s = s[np.newaxis, :]
        a = sess.run(pi, {x_ph: s})[0]
        if add_noise:
            noise = dagger_noise * act_high_limit * np.random.normal(
                size=a.shape)
            a = a + noise
        return np.clip(a, act_low_limit, act_high_limit)
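    # choose_action is used for DAgger rollouts (add_noise=True, noise scaled
    # by dagger_noise * act_high_limit) and, with add_noise left falsy, for
    # deterministic test rollouts in test_agent below.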

    def test_agent(n=81, test_num=1):
        n = env.unwrapped._set_test_mode(True)
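        # assumption: _set_test_mode(True) returns the number of test
        # episodes to run, overriding the default n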
        con_flag = False
        for j in range(n):
            o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
            while not (d or (ep_len == max_ep_len)):
                # Take deterministic actions at test time (noise_scale=0)
                o, r, d, info = env.step(choose_action(np.array(o), 0))
                ep_ret += r
                ep_len += 1
                if d:
                    test_logger.store(TestEpRet=ep_ret, TestEpLen=ep_len)
                    test_logger.store(arrive_des=info['arrive_des'])
                    test_logger.store(
                        arrive_des_appro=info['arrive_des_appro'])
                    if not info['out_of_range']:
                        test_logger.store(converge_dis=info['converge_dis'])
                        con_flag = True
                    test_logger.store(out_of_range=info['out_of_range'])
                    # print(info)
        # test_logger.dump_tabular()
        # time.sleep(10)
        if not con_flag:
            test_logger.store(converge_dis=10000)
        env.unwrapped._set_test_mode(False)

    start_time = time.time()
    env.unwrapped._set_test_mode(False)
    o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
    total_steps = steps_per_epoch * epochs
    test_num = 0

    total_env_t = 0
    print(colorize("begin dagger training", 'green', bold=True))
    # Main loop for dagger pretrain
    for epoch in range(1, dagger_epochs + 1, 1):
        obs, acs, rewards = [], [], []
        # number of timesteps
        for t in range(steps_per_epoch):
            # action = env.action_space.sample()
            # action = ppo.choose_action(np.array(observation))
            obs.append(o)
            ref_action = call_ref_controller(env, expert)
            if (epoch < pretrain_epochs):
                action = ref_action
            else:
                action = choose_action(np.array(o), True)

            o2, r, d, info = env.step(action)
            ep_ret += r
            ep_len += 1
            total_env_t += 1

            acs.append(ref_action)
            rewards.append(r)
            # Store experience to replay buffer
            replay_buffer.store(o, action, r, o2, d)

            o = o2

            if (t == steps_per_epoch - 1):
                # print ("reached the end")
                d = True

            if d:
                # compute discounted returns and add the collected rollout
                # to the DAgger replay buffer
                max_step = len(np.array(rewards))
                q = [
                    np.sum(
                        np.power(gamma, np.arange(max_step - t)) * rewards[t:])
                    for t in range(max_step)
                ]
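                # q[t] is the discounted return-to-go:
                # sum_{k=t}^{T-1} gamma^(k-t) * rewards[k]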
                dagger_replay_buffer.stores(obs, acs, rewards, q)

                # update policy
                for _ in range(int(max_step / 5)):
                    batch = dagger_replay_buffer.sample_batch(batch_size)
                    feed_dict = {x_ph: batch['obs1'], tfa_ph: batch['acts']}
                    q_step_ops = [dagger_pi_loss, train_dagger_pi_op]
                    for j in range(UPDATE_STEP):
                        outs = sess.run(q_step_ops, feed_dict)
                    logger.store(LossPi=outs[0])

                # train q function
                for j in range(int(max_step / 5)):
                    batch = replay_buffer.sample_batch(batch_size)
                    feed_dict = {
                        x_ph: batch['obs1'],
                        x2_ph: batch['obs2'],
                        a_ph: batch['acts'],
                        r_ph: batch['rews'],
                        d_ph: batch['done']
                    }
                    q_step_ops = [q_loss, q1, q2, train_q_op]
                    # for _ in range(UPDATE_STEP):
                    outs = sess.run(q_step_ops, feed_dict)
                    logger.store(LossQ=outs[0], Q1Vals=outs[1], Q2Vals=outs[2])

                    if j % policy_delay == 0:
                        # Delayed target update
                        outs = sess.run([target_update], feed_dict)
                        # logger.store(LossPi=outs[0])

                # logger.store(LossQ=1000000, Q1Vals=1000000, Q2Vals=1000000)
                logger.store(EpRet=ep_ret, EpLen=ep_len)
                o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
                break

        # End of epoch wrap-up
        if (epoch > 0 and epoch % save_freq == 0) or (epoch == dagger_epochs):
            # Save model
            logger.save_state({}, None)

            # Test the performance of the deterministic version of the agent.
            test_num += 1
            test_agent(test_num=test_num)

            # Log info about epoch
            test_logger.log_tabular('epoch', epoch)
            test_logger.log_tabular('TestEpRet', average_only=True)
            test_logger.log_tabular('TestEpLen', average_only=True)
            test_logger.log_tabular('arrive_des', average_only=True)
            test_logger.log_tabular('converge_dis', average_only=True)
            test_logger.log_tabular('out_of_range', average_only=True)
            test_logger.dump_tabular()

    sess.run(target_init)
    print(colorize("begin td3 training", 'green', bold=True))
    # Main loop: collect experience in env and update/log each epoch
    # total_env_t = 0
    for epoch in range(1, epochs + 1, 1):

        # End of epoch wrap-up
        if (epoch > 0 and epoch % save_freq == 0) or (epoch == epochs):

            # Save model
            logger.save_state({}, None)

            # Test the performance of the deterministic version of the agent.
            test_num += 1
            test_agent(test_num=test_num)

            # Log info about epoch
            test_logger.log_tabular('epoch', epoch)
            test_logger.log_tabular('TestEpRet', average_only=True)
            test_logger.log_tabular('TestEpLen', average_only=True)
            test_logger.log_tabular('arrive_des', average_only=True)
            test_logger.log_tabular('converge_dis', average_only=True)
            test_logger.log_tabular('out_of_range', average_only=True)
            test_logger.dump_tabular()
        """
        Until start_steps have elapsed, randomly sample actions
        from a uniform distribution for better exploration. Afterwards, 
        use the learned policy (with some noise, via act_noise). 
        """
        # o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
        for t in range(steps_per_epoch):
            if epoch > start_epochs:
                a = get_action(np.array(o), act_noise_limit)
            else:
                a = env.action_space.sample()
                # ref_action = call_ref_controller(env, expert)

            # Step the env
            o2, r, d, _ = env.step(a)
            ep_ret += r
            ep_len += 1
            total_env_t += 1

            # (Disabled here) Ignore the "done" signal if it comes from
            # hitting the time horizon, i.e. an artificial terminal signal
            # that isn't based on the agent's state:
            # d = False if ep_len==max_ep_len else d

            # Store experience to replay buffer
            replay_buffer.store(o, a, r, o2, d)

            # Super critical, easy to overlook step: make sure to update
            # most recent observation!
            o = o2

            if (t == steps_per_epoch - 1):
                # print ("reached the end")
                d = True

            if d:
                """
                Perform all TD3 updates at the end of the trajectory
                (in accordance with source code of TD3 published by
                original authors).
                """
                for j in range(ep_len):
                    batch = replay_buffer.sample_batch(batch_size)
                    feed_dict = {
                        x_ph: batch['obs1'],
                        x2_ph: batch['obs2'],
                        a_ph: batch['acts'],
                        r_ph: batch['rews'],
                        d_ph: batch['done']
                    }
                    q_step_ops = [q_loss, q1, q2, train_q_op]
                    # for _ in range(UPDATE_STEP):
                    outs = sess.run(q_step_ops, feed_dict)
                    logger.store(LossQ=outs[0], Q1Vals=outs[1], Q2Vals=outs[2])

                    if j % policy_delay == 0:
                        # Delayed policy update
                        outs = sess.run([pi_loss, train_pi_op, target_update],
                                        feed_dict)
                        logger.store(LossPi=outs[0])

                logger.store(EpRet=ep_ret, EpLen=ep_len)
                o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
                break