Example #1
def ppo(env_fn,
        actor_critic=core.mlp_actor_critic,
        ac_kwargs=dict(),
        seed=0,
        batch_size=250000,
        n=100,
        epochs=100,
        gamma=0.99,
        clip_ratio=0.2,
        pi_lr=3e-4,
        vf_lr=1e-3,
        train_pi_iters=1000,
        train_v_iters=80,
        lam=0.97,
        max_ep_len=1000,
        target_kl=0.01,
        logger_kwargs=dict(),
        save_freq=10):
    """

    Args:
        env_fn : A function which creates a copy of the environment.
            The environment must satisfy the OpenAI Gym API.

        actor_critic: A function which takes in placeholder symbols 
            for state, ``x_ph``, and action, ``a_ph``, and returns the main 
            outputs from the agent's Tensorflow computation graph:

            ===========  ================  ======================================
            Symbol       Shape             Description
            ===========  ================  ======================================
            ``pi``       (batch, act_dim)  | Samples actions from policy given 
                                           | states.
            ``logp``     (batch,)          | Gives log probability, according to
                                           | the policy, of taking actions ``a_ph``
                                           | in states ``x_ph``.
            ``logp_pi``  (batch,)          | Gives log probability, according to
                                           | the policy, of the action sampled by
                                           | ``pi``.
            ``v``        (batch,)          | Gives the value estimate for states
                                           | in ``x_ph``. (Critical: make sure 
                                           | to flatten this!)
            ===========  ================  ======================================

        ac_kwargs (dict): Any kwargs appropriate for the actor_critic 
            function you provided to PPO.

        seed (int): Seed for random number generators.

        batch_size (int): Total number of steps of interaction (state-action
            pairs) collected per epoch; the number of trials per epoch is
            ``batch_size // (n * max_ep_len)``.

        n (int): Number of episodes per trial; together with ``max_ep_len`` it
            sets the sequence length ``n * max_ep_len`` fed to the actor-critic.

        epochs (int): Number of epochs of interaction (equivalent to
            number of policy updates) to perform.

        gamma (float): Discount factor. (Always between 0 and 1.)

        clip_ratio (float): Hyperparameter for clipping in the policy objective.
            Roughly: how far can the new policy go from the old policy while 
            still profiting (improving the objective function)? The new policy 
            can still go farther than the clip_ratio says, but it doesn't help
            on the objective anymore. (Usually small, 0.1 to 0.3.)

        pi_lr (float): Learning rate for policy optimizer.

        vf_lr (float): Learning rate for value function optimizer.

        train_pi_iters (int): Maximum number of gradient descent steps to take 
            on policy loss per epoch. (Early stopping may cause optimizer
            to take fewer than this.)

        train_v_iters (int): Number of gradient descent steps to take on 
            value function per epoch.

        lam (float): Lambda for GAE-Lambda. (Always between 0 and 1,
            close to 1.)

        max_ep_len (int): Maximum length of trajectory / episode / rollout.

        target_kl (float): Roughly what KL divergence we think is appropriate
            between new and old policies after an update. This will get used 
            for early stopping. (Usually small, 0.01 or 0.05.)

        logger_kwargs (dict): Keyword args for EpochLogger.

        save_freq (int): How often (in terms of gap between epochs) to save
            the current policy and value function.

    """

    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    tf.set_random_seed(seed)
    np.random.seed(seed)

    env = env_fn()
    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape

    sequence_length = n * max_ep_len
    trials = batch_size // sequence_length

    # Share information about action space with policy architecture
    ac_kwargs['action_space'] = env.action_space

    # Inputs to computation graph
    # x_ph, a_ph = core.placeholders_from_spaces(env.observation_space, env.action_space)
    # rew_ph, adv_ph, ret_ph, logp_old_ph = core.placeholders(1, None, None, None)
    x_ph = tf.placeholder(dtype=tf.int32,
                          shape=(None, sequence_length),
                          name='x_ph')
    t_ph = tf.placeholder(dtype=tf.int32,
                          shape=(None, sequence_length),
                          name='t_ph')
    a_ph = tf.placeholder(dtype=tf.int32,
                          shape=(None, sequence_length),
                          name='a_ph')
    r_ph = tf.placeholder(dtype=tf.float32,
                          shape=(None, sequence_length),
                          name='r_ph')
    #    input_ph = tf.placeholder(dtype=tf.float32, shape=(None, None, n, None), name='rew_ph')
    adv_ph = tf.placeholder(dtype=tf.float32, shape=(None), name='adv_ph')
    ret_ph = tf.placeholder(dtype=tf.float32, shape=(None), name='ret_ph')
    logp_old_ph = tf.placeholder(dtype=tf.float32,
                                 shape=(None),
                                 name='logp_old_ph')
    # Main outputs from computation graph
    pi, logp, logp_pi, v = actor_critic(x_ph, t_ph, a_ph, r_ph,
                                        sequence_length, env.action_space.n,
                                        env.observation_space.shape[0])

    # Need all placeholders in *this* order later (to zip with data from buffer)
    all_phs = [x_ph, t_ph, a_ph, r_ph, adv_ph, ret_ph, logp_old_ph]
    #    for ph in all_phs:
    #        print(ph.shape)

    # Every step, get: action, value, and logprob
    get_action_ops = [pi, v, logp_pi]

    # Experience buffer
    buf = PPOBuffer(obs_dim, act_dim, batch_size, gamma, lam)

    # Count variables
    var_counts = tuple(core.count_vars(scope) for scope in ['pi', 'v'])
    logger.log('\nNumber of parameters: \t pi: %d, \t v: %d\n' % var_counts)

    # PPO objectives
    ratio = tf.exp(logp - logp_old_ph)  # pi(a|s) / pi_old(a|s)
    min_adv = tf.where(adv_ph > 0, (1 + clip_ratio) * adv_ph,
                       (1 - clip_ratio) * adv_ph)
    pi_loss = -tf.reduce_mean(tf.minimum(ratio * adv_ph, min_adv))
    v_loss = tf.reduce_mean((ret_ph - v)**2)

    # Info (useful to watch during learning)
    approx_kl = tf.reduce_mean(
        logp_old_ph -
        logp)  # a sample estimate for KL-divergence, easy to compute
    approx_ent = tf.reduce_mean(
        -logp)  # a sample estimate for entropy, also easy to compute
    clipped = tf.logical_or(ratio > (1 + clip_ratio), ratio < (1 - clip_ratio))
    clipfrac = tf.reduce_mean(tf.cast(clipped, tf.float32))

    # Optimizers
    train_pi = MpiAdamOptimizer(learning_rate=pi_lr).minimize(pi_loss)
    train_v = MpiAdamOptimizer(learning_rate=vf_lr).minimize(v_loss)

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    # Sync params across processes
    sess.run(sync_all_params())

    # Setup model saving
    model_inputs = {'x': x_ph, 't': t_ph, 'a': a_ph, 'r': r_ph}
    model_outputs = {'pi': pi}
    logger.setup_tf_saver(sess, inputs=model_inputs, outputs=model_outputs)

    def update():
        inputs = {k: v for k, v in zip(all_phs, buf.get())}
        #        inputs[a_ph] = np.tril(np.transpose(np.repeat(inputs[a_ph], n).reshape(trials, n, n), [0, 2, 1]))
        #        inputs[rew_ph] = np.tril(np.transpose(np.repeat(inputs[rew_ph], n).reshape(trials, n, n), [0, 2, 1]))
        #        print(inputs[x_ph])
        #        print(inputs[t_ph])
        #        print(inputs[a_ph])
        #        print(inputs[r_ph])
        inputs[x_ph] = inputs[x_ph].reshape(trials, sequence_length)
        inputs[t_ph] = inputs[t_ph].reshape(trials, sequence_length)
        inputs[a_ph] = inputs[a_ph].reshape(trials, sequence_length)
        inputs[r_ph] = inputs[r_ph].reshape(trials, sequence_length)
        #        print('x:', inputs[x_ph])
        #        print('t:', inputs[t_ph])
        #        print('a:', inputs[a_ph])
        #        print('r:', inputs[r_ph])
        #        print('ret:', inputs[ret_ph])
        #        print('adv:', inputs[adv_ph])
        #        print('logp_old:', inputs[logp_old_ph])
        pi_l_old, v_l_old, ent = sess.run([pi_loss, v_loss, approx_ent],
                                          feed_dict=inputs)

        # Training
        for i in range(train_pi_iters):
            _, kl = sess.run([train_pi, approx_kl], feed_dict=inputs)


            # kl = mpi_avg(kl)
            # if kl > 1.5 * target_kl:
            #     logger.log('Early stopping at step %d due to reaching max kl.' % i)
            #     break
        logger.store(StopIter=i)
        for _ in range(train_v_iters):
            sess.run(train_v, feed_dict=inputs)

        # Log changes from update
        pi_l_new, v_l_new, kl, cf = sess.run(
            [pi_loss, v_loss, approx_kl, clipfrac], feed_dict=inputs)
        logger.store(LossPi=pi_l_old,
                     LossV=v_l_old,
                     KL=kl,
                     Entropy=ent,
                     ClipFrac=cf,
                     DeltaLossPi=(pi_l_new - pi_l_old),
                     DeltaLossV=(v_l_new - v_l_old))

    start_time = time.time()
    save_itr = 0
    # Main loop: collect experience in env and update/log each epoch
    for epoch in range(epochs):
        for trial in range(trials):
            print('trial:', trial)
            #            last_a = np.zeros(n).reshape(1, n)
            #            last_r = np.zeros(n).reshape(1, n)
            o_deque = deque(sequence_length * [0], sequence_length)
            t_deque = deque(sequence_length * [0], sequence_length)
            last_a = deque(sequence_length * [0], sequence_length)
            last_r = deque(sequence_length * [0], sequence_length)
            means = env.sample_tasks(1)[0]
            #            print('task means:', means)
            action_dict = defaultdict(int)
            total_reward = 0
            env.reset_task(means)
            o, r, d, ep_ret, ep_len = env.reset(), np.zeros(1), False, 0, 0

            for episode in range(sequence_length):
                #                print('episode:', episode)
                #                print('o:', o_deque)
                #                print('d:', t_deque)
                #                print('a:', last_a)
                #                print('r:', last_r)
                a, v_t, logp_t = sess.run(
                    get_action_ops,
                    feed_dict={
                        x_ph: np.array(o_deque).reshape(1, sequence_length),
                        t_ph: np.array(t_deque).reshape(1, sequence_length),
                        a_ph: np.array(last_a).reshape(1, sequence_length),
                        r_ph: np.array(last_r).reshape(1, sequence_length)
                    })
                #                print("a shape:", a.shape)
                #                print("v_t shape:", v_t.shape)
                #                print("logp_t shape:", logp_t.shape)
                #                choosen_a = a[episode, 0]
                #                choosen_v_t = v_t[0, episode]
                #                choosen_logp_t = logp_t[episode]
                #                print('a:', a)
                choosen_a = a[-1]
                choosen_v_t = v_t[-1]
                choosen_logp_t = logp_t[-1]
                action_dict[choosen_a] += 1
                o, r, d, _ = env.step(choosen_a)

                ep_ret += r
                ep_len += 1
                t = ep_len == max_ep_len
                total_reward += r

                o_deque.append(o)
                t_deque.append(int(d))
                last_a.append(choosen_a)
                last_r.append(r)

                # save and log
                buf.store(o, int(t), choosen_a, r, choosen_v_t, choosen_logp_t)
                logger.store(VVals=v_t)

                terminal = d or t
                if terminal or (episode == sequence_length - 1):
                    if not (terminal):
                        print(
                            'Warning: trajectory cut off by epoch at %d steps.'
                            % ep_len)
                    # if trajectory didn't reach terminal state, bootstrap value target
                    if d:
                        last_val = r
                    else:
                        last_val = sess.run(
                            v,
                            feed_dict={
                                x_ph:
                                np.array(o_deque).reshape(1, sequence_length),
                                t_ph:
                                np.array(t_deque).reshape(1, sequence_length),
                                a_ph:
                                np.array(last_a).reshape(1, sequence_length),
                                r_ph:
                                np.array(last_r).reshape(1, sequence_length)
                            })
                        last_val = last_val[-1]
                    buf.finish_path(last_val)
                    if terminal:
                        # only save EpRet / EpLen if trajectory finished
                        logger.store(EpRet=ep_ret, EpLen=ep_len)
                    o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
                    o_deque[-1] = 0
                    t_deque[-1] = 0
                    last_a[-1] = 0
                    last_r[-1] = 0
            print(action_dict)
            print('average reward:', total_reward / sequence_length)
        # Save model
        if (epoch % save_freq == 0) or (epoch == epochs - 1):
            logger.save_state({'env': env}, save_itr)
            save_itr += 1
        # Perform PPO update!
        update()

        # Log info about epoch
        logger.log_tabular('Epoch', epoch)
        logger.log_tabular('EpRet', with_min_and_max=True)
        logger.log_tabular('EpLen', average_only=True)
        logger.log_tabular('VVals', with_min_and_max=True)
        logger.log_tabular('TotalEnvInteracts', (epoch + 1) * batch_size)
        logger.log_tabular('LossPi', average_only=True)
        logger.log_tabular('LossV', average_only=True)
        logger.log_tabular('DeltaLossPi', average_only=True)
        logger.log_tabular('DeltaLossV', average_only=True)
        logger.log_tabular('Entropy', average_only=True)
        logger.log_tabular('KL', average_only=True)
        logger.log_tabular('ClipFrac', average_only=True)
        logger.log_tabular('StopIter', average_only=True)
        logger.log_tabular('Time', time.time() - start_time)
        logger.dump_tabular()
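A note on the clipped objective in Example 1: the loss is built from a tf.where-based min_adv term rather than an explicit clip on the probability ratio. The two formulations are equivalent; the short NumPy sketch below (not part of the example; names are illustrative) checks that numerically.

import numpy as np

def ppo_surrogate_where(ratio, adv, clip_ratio=0.2):
    # Form used in Example 1: choose the clipped bound by the sign of the advantage.
    min_adv = np.where(adv > 0, (1 + clip_ratio) * adv, (1 - clip_ratio) * adv)
    return -np.mean(np.minimum(ratio * adv, min_adv))

def ppo_surrogate_clip(ratio, adv, clip_ratio=0.2):
    # More common form: clip the probability ratio directly.
    clipped_ratio = np.clip(ratio, 1 - clip_ratio, 1 + clip_ratio)
    return -np.mean(np.minimum(ratio * adv, clipped_ratio * adv))

rng = np.random.default_rng(0)
ratio = np.exp(rng.normal(scale=0.1, size=1000))   # stand-in for pi(a|s) / pi_old(a|s)
adv = rng.normal(size=1000)                        # stand-in for GAE advantages
assert np.isclose(ppo_surrogate_where(ratio, adv), ppo_surrogate_clip(ratio, adv))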
Example #2
def ppo(env_fn, actor_critic=core.mlp_actor_critic, ac_kwargs=dict(), seed=0, gru_units=256,
        trials_per_epoch=100, episodes_per_trial=2, n = 100, epochs=100, gamma=0.99, clip_ratio=0.2, pi_lr=3e-4,
        vf_lr=1e-3, train_pi_iters=1000, train_v_iters=80, lam=0.97, max_ep_len=1000,
        target_kl=0.01, logger_kwargs=dict(), save_freq=10):
    """

    Args:
        env_fn : A function which creates a copy of the environment.
            The environment must satisfy the OpenAI Gym API.

        actor_critic: A function which takes in placeholder symbols 
            for state, ``x_ph``, and action, ``a_ph``, and returns the main 
            outputs from the agent's Tensorflow computation graph:

            ===========  ================  ======================================
            Symbol       Shape             Description
            ===========  ================  ======================================
            ``pi``       (batch, act_dim)  | Samples actions from policy given 
                                           | states.
            ``logp``     (batch,)          | Gives log probability, according to
                                           | the policy, of taking actions ``a_ph``
                                           | in states ``x_ph``.
            ``logp_pi``  (batch,)          | Gives log probability, according to
                                           | the policy, of the action sampled by
                                           | ``pi``.
            ``v``        (batch,)          | Gives the value estimate for states
                                           | in ``x_ph``. (Critical: make sure 
                                           | to flatten this!)
            ===========  ================  ======================================

        ac_kwargs (dict): Any kwargs appropriate for the actor_critic 
            function you provided to PPO.

        seed (int): Seed for random number generators.

        gru_units (int): Number of hidden units in the GRU cell of the
            recurrent actor-critic.

        trials_per_epoch (int): Number of trials collected per epoch.

        episodes_per_trial (int): Number of episodes run in each trial; the
            padded sequence length per trial is
            ``episodes_per_trial * max_ep_len``.

        n (int): Per-episode step cutoff; an episode is truncated and its value
            bootstrapped after ``n`` steps.

        epochs (int): Number of epochs of interaction (equivalent to
            number of policy updates) to perform.

        gamma (float): Discount factor. (Always between 0 and 1.)

        clip_ratio (float): Hyperparameter for clipping in the policy objective.
            Roughly: how far can the new policy go from the old policy while 
            still profiting (improving the objective function)? The new policy 
            can still go farther than the clip_ratio says, but it doesn't help
            on the objective anymore. (Usually small, 0.1 to 0.3.)

        pi_lr (float): Learning rate for policy optimizer.

        vf_lr (float): Learning rate for value function optimizer.

        train_pi_iters (int): Maximum number of gradient descent steps to take 
            on policy loss per epoch. (Early stopping may cause optimizer
            to take fewer than this.)

        train_v_iters (int): Number of gradient descent steps to take on 
            value function per epoch.

        lam (float): Lambda for GAE-Lambda. (Always between 0 and 1,
            close to 1.)

        max_ep_len (int): Maximum length of trajectory / episode / rollout.

        target_kl (float): Roughly what KL divergence we think is appropriate
            between new and old policies after an update. This will get used 
            for early stopping. (Usually small, 0.01 or 0.05.)

        logger_kwargs (dict): Keyword args for EpochLogger.

        save_freq (int): How often (in terms of gap between epochs) to save
            the current policy and value function.

    """

    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    tf.set_random_seed(seed)
    np.random.seed(seed)

    env = env_fn()
    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape
    
    # Share information about action space with policy architecture
    ac_kwargs['action_space'] = env.action_space

    # Inputs to computation graph
    raw_input_ph = tf.placeholder(dtype=tf.float32, shape=obs_dim, name='raw_input_ph')
    rescale_image_op = tf.image.resize_images(raw_input_ph, [30, 40])
    max_seq_len_ph = tf.placeholder(dtype=tf.int32, shape=(), name='max_seq_len_ph')
    seq_len_ph = tf.placeholder(dtype=tf.int32, shape=(None,))

    # Because we pad zeros at the end of every sequence of length less than max length, we need to mask these zeros out
    # when computing loss
    seq_len_mask_ph = tf.placeholder(dtype=tf.int32, shape=(trials_per_epoch, episodes_per_trial * max_ep_len))

    # rescaled_image_in_ph is a placeholder (rather than the resize op's output) so that values can be fed into this node manually
    rescaled_image_in_ph = tf.placeholder(dtype=tf.float32, shape=[None, 30, 40, 3], name='rescaled_image_in_ph')
    a_ph = core.placeholders_from_spaces( env.action_space)[0]
    conv1 = slim.conv2d(activation_fn=tf.nn.relu, inputs=rescaled_image_in_ph, num_outputs=16, kernel_size=[5,5],
                        stride=2)
    image_out = slim.flatten(slim.conv2d(activation_fn=tf.nn.relu, inputs=conv1, num_outputs=16, kernel_size=[5,5],
                        stride=2))

    rew_ph, adv_ph, ret_ph, logp_old_ph = core.placeholders(1, None, None, None)
    rnn_state_ph = tf.placeholder(tf.float32, [None, gru_units], name='pi_rnn_state_ph')
    # Main outputs from computation graph

    action_encoder_matrix = np.load(r'encoder.npy')
    pi, logp, logp_pi, v, rnn_state, logits, seq_len_vec, tmp_vec = actor_critic(
            image_out, a_ph, rew_ph, rnn_state_ph, gru_units,
            max_seq_len_ph, action_encoder_matrix, seq_len=seq_len_ph, action_space=env.action_space)

    # Need all placeholders in *this* order later (to zip with data from buffer)
    all_phs = [rescaled_image_in_ph, a_ph, adv_ph, ret_ph, logp_old_ph, rew_ph]

    # Every step, get: action, value, and logprob
    get_action_ops = [pi, v, logp_pi, rnn_state, logits]

    # Experience buffer
    buffer_size = trials_per_epoch * episodes_per_trial * max_ep_len
    buf = PPOBuffer(rescaled_image_in_ph.get_shape().as_list()[1:], act_dim, buffer_size, trials_per_epoch, gamma, lam)

    # Count variables
    var_counts = tuple(core.count_vars(scope) for scope in ['pi', 'v'])
    logger.log('\nNumber of parameters: \t pi: %d, \t v: %d\n'%var_counts)

    # PPO objectives
    ratio = tf.exp(logp - logp_old_ph)          # pi(a|s) / pi_old(a|s)
    min_adv = tf.where(adv_ph>0, (1+clip_ratio)*adv_ph, (1-clip_ratio)*adv_ph)

    # Need to mask out the padded zeros when computing loss
    sequence_mask = tf.sequence_mask(seq_len_ph, episodes_per_trial*max_ep_len)
    # Convert bool tensor to int tensor with 1 and 0
    sequence_mask = tf.where(sequence_mask,
                             np.ones(dtype=np.float32, shape=(trials_per_epoch, episodes_per_trial*max_ep_len)),
                             np.zeros(dtype=np.float32, shape=(trials_per_epoch, episodes_per_trial*max_ep_len)))

    # ratio is a 1-D vector (a concatenation of all sequences), so reshape it to match the
    # mask, apply the mask, and then reshape it back
    pi_loss_vec = tf.multiply(sequence_mask, tf.reshape(tf.minimum(ratio * adv_ph, min_adv), tf.shape(sequence_mask)))
    pi_loss = -tf.reduce_mean(tf.reshape(pi_loss_vec, tf.shape(ratio)))
    aaa = (ret_ph - v)**2

    v_loss_vec = tf.multiply(sequence_mask, tf.reshape((ret_ph - v)**2, tf.shape(sequence_mask)))
    ccc = tf.reshape(v_loss_vec, tf.shape(v))

    v_loss = tf.reduce_mean(tf.reshape(v_loss_vec, tf.shape(v)))


    # Info (useful to watch during learning)
    approx_kl = tf.reduce_mean(logp_old_ph - logp)      # a sample estimate for KL-divergence, easy to compute
    approx_ent = tf.reduce_mean(-logp)                  # a sample estimate for entropy, also easy to compute
    clipped = tf.logical_or(ratio > (1+clip_ratio), ratio < (1-clip_ratio))
    clipfrac = tf.reduce_mean(tf.cast(clipped, tf.float32))

    # Optimizers
    train_pi = MpiAdamOptimizer(learning_rate=pi_lr).minimize(pi_loss)
    train_v = MpiAdamOptimizer(learning_rate=vf_lr).minimize(v_loss)

    train = MpiAdamOptimizer(learning_rate=1e-4).minimize(pi_loss + 0.01 * v_loss - 0.001 * approx_ent)


    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    # Sync params across processes
    sess.run(sync_all_params())

    # Setup model saving
    logger.setup_tf_saver(sess, inputs={'rescaled_image_in': rescaled_image_in_ph}, outputs={'pi': pi, 'v': v})



    def update():
        print(f'Start updating at {datetime.now()}')
        inputs = {k:v for k,v in zip(all_phs, buf.get())}

        inputs[rnn_state_ph] = np.zeros((trials_per_epoch, gru_units), np.float32)
        inputs[max_seq_len_ph] = int(episodes_per_trial * max_ep_len)
        inputs[seq_len_ph] = buf.seq_len_buf
        pi_l_old, v_l_old, ent = sess.run([pi_loss, v_loss, approx_ent], feed_dict=inputs)

        buf.reset()

        
        # Training
        print(f'sequence length = {sess.run(seq_len_vec, feed_dict=inputs)}')


        for i in range(train_pi_iters):
            _, kl, pi_loss_i, v_loss_i, ent = sess.run([train_pi, approx_kl, pi_loss, v_loss, approx_ent], feed_dict=inputs)
            print(f'i: {i}, pi_loss: {pi_loss_i}, v_loss: {v_loss_i}, entropy: {ent}')


        logger.store(StopIter=i)


        # Log changes from update
        pi_l_new, v_l_new, kl, cf = sess.run(
                [pi_loss, v_loss, approx_kl, clipfrac], feed_dict=inputs)
        logger.store(LossPi=pi_l_old, LossV=v_l_old, 
                     KL=kl, Entropy=ent, ClipFrac=cf,
                     DeltaLossPi=(pi_l_new - pi_l_old),
                     DeltaLossV=(v_l_new - v_l_old))
        print(f'Updating finished at {datetime.now()}')


    start_time = time.time()
    o, r, d, ep_ret, ep_len = env.reset(), np.zeros(1), False, 0, 0

    def recenter_rgb(image, min=0.0, max=255.0):
        '''

        :param image:
        :param min:
        :param max:
        :return: an image with rgb value re-centered to [-1, 1]
        '''
        mid = (min + max) / 2.0
        return np.apply_along_axis(func1d=lambda x: (x - mid) / mid, axis=2, arr=image)

    o_rescaled = recenter_rgb(sess.run(rescale_image_op, feed_dict={raw_input_ph: o}))
    # Main loop: collect experience in env and update/log each epoch
    for epoch in range(epochs):
        for trial in range(trials_per_epoch):
            # TODO: tweak settings to match the paper

            # TODO: find a way to generate mazes
            last_a = np.array(0)
            last_r = np.array(r)
            last_rnn_state = np.zeros((1, gru_units), np.float32)

            step_counter = 0
            for episode in range(episodes_per_trial):
                o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
                o_rescaled = recenter_rgb(sess.run(rescale_image_op, feed_dict={raw_input_ph: o}))

                action_dict = defaultdict(int)

                # hard-code the keys so the action counts print in a fixed order
                action_dict[0] = 0
                action_dict[1] = 0
                action_dict[2] = 0

                for step in range(max_ep_len):
                    a, v_t, logp_t, rnn_state_t, logits_t = sess.run(
                            get_action_ops, feed_dict={
                                    rescaled_image_in_ph: np.expand_dims(o_rescaled, 0),
                                    a_ph: last_a.reshape(-1,),
                                    rew_ph: last_r.reshape(-1,1),
                                    rnn_state_ph: last_rnn_state,
                                    # v_rnn_state_ph: last_v_rnn_state,
                                    max_seq_len_ph: 1,
                        seq_len_ph: [1]})
                    action_dict[a[0]] += 1
                    # save and log
                    buf.store(o_rescaled, a, r, v_t, logp_t)
                    logger.store(VVals=v_t)
                    o, r, d, _ = env.step(a[0])
                    step_counter += 1
                    o_rescaled = recenter_rgb(sess.run(rescale_image_op, feed_dict={raw_input_ph: o}))
                    ep_ret += r
                    ep_len += 1

                    last_a = a[0]
                    last_r = np.array(r)
                    last_rnn_state = rnn_state_t

                    terminal = d or (ep_len == max_ep_len)
                    if terminal or (step==n-1):
                        if not(terminal):
                            print('Warning: trajectory cut off by epoch at %d steps.'%ep_len)
                        # if trajectory didn't reach terminal state, bootstrap value target
                        last_val = r if d else sess.run(v, feed_dict={rescaled_image_in_ph: np.expand_dims(o_rescaled, 0),
                                    a_ph: last_a.reshape(-1,),
                                    rew_ph: last_r.reshape(-1,1),
                                    rnn_state_ph: last_rnn_state,
                                    max_seq_len_ph: 1,
                                    seq_len_ph: [1]})
                        buf.finish_path(last_val)
                        logger.store(EpRet=ep_ret, EpLen=ep_len)


                        print(f'episode terminated with {step} steps. epoch:{epoch} trial:{trial} episode:{episode}')
                        break
                print(action_dict)
            if step_counter < episodes_per_trial * max_ep_len:
                buf.pad_zeros(episodes_per_trial * max_ep_len - step_counter)
            buf.seq_len_buf[trial] = step_counter



            # pad zeros to sequence buffer after each trial
        # Save model
        if (epoch % save_freq == 0) or (epoch == epochs-1):
            logger.save_state({'env': env}, None)
        # Perform PPO update!
        update()

        # Log info about epoch
        logger.log_tabular('Epoch', epoch)
        logger.log_tabular('EpRet', with_min_and_max=True)
        logger.log_tabular('EpLen', average_only=True)
        logger.log_tabular('VVals', with_min_and_max=True)
        logger.log_tabular('TotalEnvInteracts', (epoch+1)*trials_per_epoch*episodes_per_trial*max_ep_len)
        logger.log_tabular('LossPi', average_only=True)
        logger.log_tabular('LossV', average_only=True)
        logger.log_tabular('DeltaLossPi', average_only=True)
        logger.log_tabular('DeltaLossV', average_only=True)
        logger.log_tabular('Entropy', average_only=True)
        logger.log_tabular('KL', average_only=True)
        logger.log_tabular('ClipFrac', average_only=True)
        logger.log_tabular('StopIter', average_only=True)
        logger.log_tabular('Time', time.time()-start_time)
        logger.dump_tabular()
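Example 2 pads every trial to a fixed length (episodes_per_trial * max_ep_len, via buf.pad_zeros) and uses tf.sequence_mask to zero out the padded timesteps before reducing the losses. The NumPy sketch below is independent of the TF graph above (function and argument names are illustrative) and shows the same masking idea in isolation.

import numpy as np

def masked_mean_loss(per_step_loss, seq_lens, max_len):
    # per_step_loss: (num_trials, max_len); seq_lens: true number of steps per trial.
    steps = np.arange(max_len)[None, :]                                  # (1, max_len)
    mask = (steps < np.asarray(seq_lens)[:, None]).astype(np.float32)    # 1.0 for real steps
    masked = per_step_loss * mask
    # Example 2 averages over the full padded tensor; dividing by mask.sum()
    # instead would average over the real steps only.
    return masked.mean()

losses = np.ones((2, 5), dtype=np.float32)                   # two trials padded to length 5
print(masked_mean_loss(losses, seq_lens=[5, 3], max_len=5))  # 0.8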
Example #3
def ppo(env_fn,
        actor_critic=core.MLPActorCritic,
        ac_kwargs=dict(),
        seed=0,
        steps_per_epoch=4000,
        epochs=50,
        gamma=0.99,
        clip_ratio=0.2,
        pi_lr=3e-4,
        vf_lr=1e-3,
        train_pi_iters=80,
        train_v_iters=80,
        lam=0.97,
        max_ep_len=None,
        target_kl=0.01,
        logger_kwargs=dict(),
        save_freq=10,
        TensorBoard=True,
        save_nn=True,
        save_every=1000,
        load_latest=False,
        load_custom=False,
        LoadPath=None,
        RTA_type=None):
    """
	Proximal Policy Optimization (by clipping),

	with early stopping based on approximate KL

	Args:
		env_fn : A function which creates a copy of the environment.
			The environment must satisfy the OpenAI Gym API.

		actor_critic: The constructor method for a PyTorch Module with a
			``step`` method, an ``act`` method, a ``pi`` module, and a ``v``
			module. The ``step`` method should accept a batch of observations
			and return:

			===========  ================  ======================================
			Symbol       Shape             Description
			===========  ================  ======================================
			``a``        (batch, act_dim)  | Numpy array of actions for each
										   | observation.
			``v``        (batch,)          | Numpy array of value estimates
										   | for the provided observations.
			``logp_a``   (batch,)          | Numpy array of log probs for the
										   | actions in ``a``.
			===========  ================  ======================================

			The ``act`` method behaves the same as ``step`` but only returns ``a``.

			The ``pi`` module's forward call should accept a batch of
			observations and optionally a batch of actions, and return:

			===========  ================  ======================================
			Symbol       Shape             Description
			===========  ================  ======================================
			``pi``       N/A               | Torch Distribution object, containing
										   | a batch of distributions describing
										   | the policy for the provided observations.
			``logp_a``   (batch,)          | Optional (only returned if batch of
										   | actions is given). Tensor containing
										   | the log probability, according to
										   | the policy, of the provided actions.
										   | If actions not given, will contain
										   | ``None``.
			===========  ================  ======================================

			The ``v`` module's forward call should accept a batch of observations
			and return:

			===========  ================  ======================================
			Symbol       Shape             Description
			===========  ================  ======================================
			``v``        (batch,)          | Tensor containing the value estimates
										   | for the provided observations. (Critical:
										   | make sure to flatten this!)
			===========  ================  ======================================


		ac_kwargs (dict): Any kwargs appropriate for the ActorCritic object
			you provided to PPO.

		seed (int): Seed for random number generators.

		steps_per_epoch (int): Number of steps of interaction (state-action pairs)
			for the agent and the environment in each epoch.

		epochs (int): Number of epochs of interaction (equivalent to
			number of policy updates) to perform.

		gamma (float): Discount factor. (Always between 0 and 1.)

		clip_ratio (float): Hyperparameter for clipping in the policy objective.
			Roughly: how far can the new policy go from the old policy while
			still profiting (improving the objective function)? The new policy
			can still go farther than the clip_ratio says, but it doesn't help
			on the objective anymore. (Usually small, 0.1 to 0.3.) Typically
			denoted by :math:`\epsilon`.

		pi_lr (float): Learning rate for policy optimizer.

		vf_lr (float): Learning rate for value function optimizer.

		train_pi_iters (int): Maximum number of gradient descent steps to take
			on policy loss per epoch. (Early stopping may cause optimizer
			to take fewer than this.)

		train_v_iters (int): Number of gradient descent steps to take on
			value function per epoch.

		lam (float): Lambda for GAE-Lambda. (Always between 0 and 1,
			close to 1.)

		max_ep_len (int): Maximum length of trajectory / episode / rollout.

		target_kl (float): Roughly what KL divergence we think is appropriate
			between new and old policies after an update. This will get used
			for early stopping. (Usually small, 0.01 or 0.05.)

		logger_kwargs (dict): Keyword args for EpochLogger.

		save_freq (int): How often (in terms of gap between epochs) to save
			the current policy and value function.

		TensorBoard (bool): True plots to TensorBoard, False does not

		save_nn (bool): True saves neural network data, False does not

		save_every (int): How often to save neural network

		load_latest (bool): Load last saved neural network data before training

		load_custom (bool): Load custom neural network data file before training

		LoadPath (str): Path for custom neural network data file

		RTA_type (str): RTA framework, either 'CBF', 'SVL', 'ASIF', or
			'SBSF'; set to 'off' to run without RTA

	"""

    # Special function to avoid certain slowdowns from PyTorch + MPI combo.
    setup_pytorch_for_mpi()

    # Set up logger and save configuration
    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    # Instantiate environment
    env = env_fn()
    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape

    # Random seed for each cpu
    seed += 1 * proc_id()
    env.seed(seed)

    # Create actor-critic module
    ac = actor_critic(env.observation_space, env.action_space, **ac_kwargs)

    # Load model if True
    if load_latest:
        models = glob.glob(f"{PATH}/models/PPO/*")
        LoadPath = max(models, key=os.path.getctime)
        ac.load_state_dict(torch.load(LoadPath))
    elif load_custom:
        ac.load_state_dict(torch.load(LoadPath))

    # Sync params across processes
    sync_params(ac)

    # Count variables
    var_counts = tuple(core.count_vars(module) for module in [ac.pi, ac.v])
    logger.log('\nNumber of parameters: \t pi: %d, \t v: %d\n' % var_counts)

    # Set up experience buffer
    local_steps_per_epoch = int(steps_per_epoch / num_procs())
    buf = PPOBuffer(obs_dim, act_dim, local_steps_per_epoch, gamma, lam)

    # Set up function for computing PPO policy loss
    def compute_loss_pi(data):
        obs, act, adv, logp_old = data['obs'], data['act'], data['adv'], data[
            'logp']

        # Policy loss
        pi, logp = ac.pi(obs, act)
        ratio = torch.exp(logp - logp_old)
        clip_adv = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio) * adv
        loss_pi = -(torch.min(ratio * adv, clip_adv)).mean()

        # Useful extra info
        approx_kl = (logp_old - logp).mean().item()
        ent = pi.entropy().mean().item()
        clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
        clipfrac = torch.as_tensor(clipped, dtype=torch.float32).mean().item()
        pi_info = dict(kl=approx_kl, ent=ent, cf=clipfrac)

        return loss_pi, pi_info

    # Set up function for computing value loss
    def compute_loss_v(data):
        obs, ret = data['obs'], data['ret']
        return ((ac.v(obs) - ret)**2).mean()

    # Set up optimizers for policy and value function
    pi_optimizer = Adam(ac.pi.parameters(), lr=pi_lr)
    vf_optimizer = Adam(ac.v.parameters(), lr=vf_lr)

    # Set up model saving
    logger.setup_pytorch_saver(ac)

    def update():
        data = buf.get()

        pi_l_old, pi_info_old = compute_loss_pi(data)
        pi_l_old = pi_l_old.item()
        v_l_old = compute_loss_v(data).item()

        # Train policy with multiple steps of gradient descent
        for i in range(train_pi_iters):
            pi_optimizer.zero_grad()
            loss_pi, pi_info = compute_loss_pi(data)
            kl = mpi_avg(pi_info['kl'])
            if kl > 1.5 * target_kl:
                logger.log(
                    'Early stopping at step %d due to reaching max kl.' % i)
                break
            loss_pi.backward()
            mpi_avg_grads(ac.pi)  # average grads across MPI processes
            pi_optimizer.step()

        logger.store(StopIter=i)

        # Value function learning
        for i in range(train_v_iters):
            vf_optimizer.zero_grad()
            loss_v = compute_loss_v(data)
            loss_v.backward()
            mpi_avg_grads(ac.v)  # average grads across MPI processes
            vf_optimizer.step()

        # Log changes from update
        kl, ent, cf = pi_info['kl'], pi_info_old['ent'], pi_info['cf']
        logger.store(LossPi=pi_l_old,
                     LossV=v_l_old,
                     KL=kl,
                     Entropy=ent,
                     ClipFrac=cf,
                     DeltaLossPi=(loss_pi.item() - pi_l_old),
                     DeltaLossV=(loss_v.item() - v_l_old))

    # Import RTA
    if RTA_type == 'CBF':
        from CBF_for_speed_limit import RTA
    elif RTA_type == 'SVL':
        from Simple_velocity_limit import RTA
    elif RTA_type == 'ASIF':
        from IASIF import RTA
    elif RTA_type == 'SBSF':
        from ISimplex import RTA

    # Call RTA, define action conversion
    if RTA_type != 'off':
        env.RTA_reward = RTA_type

        rta = RTA(env)

        def RTA_act(obs, act):
            act = np.clip(act, -env.force_magnitude, env.force_magnitude)
            x0 = [obs[0], obs[1], 0, obs[2], obs[3], 0]
            u_des = np.array([[act[0]], [act[1]], [0]])
            u = rta.main(x0, u_des)
            new_act = [u[0, 0], u[1, 0]]
            if np.sqrt((act[0] - new_act[0])**2 +
                       (act[1] - new_act[1])**2) < 0.0001:
                env.RTA_on = False
            else:
                env.RTA_on = True
            return new_act

    # Prepare for interaction with environment
    start_time = time.time()
    o, ep_ret, ep_len = env.reset(), 0, 0
    total_episodes = 0
    RTA_percent = 0

    # Create TensorBoard file if True
    if TensorBoard and proc_id() == 0:
        if env_name == 'spacecraft-docking-continuous-v0' or env_name == 'spacecraft-docking-v0':
            Name = f"{PATH}/runs/Spacecraft-docking-" + current_time
        elif env_name == 'dubins-aircraft-v0' or env_name == 'dubins-aircraft-continuous-v0':
            Name = f"{PATH}/runs/Dubins-aircraft-" + current_time
        writer = SummaryWriter(Name)

    # Main loop: collect experience in env and update/log each epoch
    for epoch in range(epochs):
        batch_ret = []  # Track episode returns
        batch_len = []  # Track episode lengths
        batch_RTA_percent = []  # Track precentage of time RTA is on
        env.success = 0  # Track episode success rate
        env.failure = 0  # Track episode failure rate
        env.crash = 0  # Track episode crash rate
        env.overtime = 0  # Track episode over max time/control rate
        episodes = 0  # Track episodes
        delta_v = []  # Track episode total delta v
        for t in range(local_steps_per_epoch):
            a, v, logp = ac.step(torch.as_tensor(o, dtype=torch.float32))
            if RTA_type != 'off':  # If RTA is on, get RTA action
                RTA_a = RTA_act(o, a)
                if env.RTA_on:
                    RTA_percent += 1
                next_o, r, d, _ = env.step(RTA_a)
            else:  # If RTA is off, pass through desired action
                next_o, r, d, _ = env.step(a)
                if env_name == 'spacecraft-docking-continuous-v0' or env_name == 'spacecraft-docking-v0':
                    over_max_vel, _, _ = env.check_velocity(a[0], a[1])
                    if over_max_vel:
                        RTA_percent += 1
            ep_ret += r
            ep_len += 1

            # save and log
            buf.store(o, a, r, v, logp)
            logger.store(VVals=v)

            # Update obs (critical!)
            o = next_o

            timeout = ep_len == max_ep_len
            terminal = d or timeout
            epoch_ended = t == local_steps_per_epoch - 1

            if terminal or epoch_ended:
                if epoch_ended and not (terminal):
                    print('Warning: trajectory cut off by epoch at %d steps.' %
                          ep_len,
                          flush=True)
                # if trajectory didn't reach terminal state, bootstrap value target
                if timeout or epoch_ended:
                    _, v, _ = ac.step(torch.as_tensor(o, dtype=torch.float32))
                else:
                    v = 0
                buf.finish_path(v)
                if terminal:
                    # only save EpRet / EpLen if trajectory finished
                    logger.store(EpRet=ep_ret, EpLen=ep_len)
                    batch_ret.append(ep_ret)
                    batch_len.append(ep_len)
                    episodes += 1
                    if env_name == 'spacecraft-docking-continuous-v0' or env_name == 'spacecraft-docking-v0':
                        delta_v.append(env.control_input / env.mass_deputy)
                batch_RTA_percent.append(RTA_percent / ep_len * 100)
                RTA_percent = 0
                o, ep_ret, ep_len = env.reset(), 0, 0

        total_episodes += episodes
        # Track success, failure, crash, overtime rates
        if episodes != 0:
            success_rate = env.success / episodes
            failure_rate = env.failure / episodes
            crash_rate = env.crash / episodes
            overtime_rate = env.overtime / episodes
        else:
            success_rate = 0
            failure_rate = 0
            crash_rate = 0
            overtime_rate = 0
            raise RuntimeError(
                "No completed episodes; logging will break (increase steps_per_epoch)"
            )

        # Save model
        if (epoch % save_freq == 0) or (epoch == epochs - 1):
            logger.save_state({'env': env}, None)

        # Perform PPO update!
        update()

        # Log info about epoch
        logger.log_tabular('Epoch', epoch)
        logger.log_tabular('EpRet', with_min_and_max=True)
        logger.log_tabular('EpLen', average_only=True)
        logger.log_tabular('VVals', with_min_and_max=True)
        logger.log_tabular('TotalEnvInteracts', (epoch + 1) * steps_per_epoch)
        logger.log_tabular('LossPi', average_only=True)
        logger.log_tabular('LossV', average_only=True)
        logger.log_tabular('DeltaLossPi', average_only=True)
        logger.log_tabular('DeltaLossV', average_only=True)
        logger.log_tabular('Entropy', average_only=True)
        logger.log_tabular('KL', average_only=True)
        logger.log_tabular('ClipFrac', average_only=True)
        logger.log_tabular('StopIter', average_only=True)
        logger.log_tabular('Time', time.time() - start_time)
        logger.dump_tabular()

        # Average data over all cpus
        avg_batch_ret = mpi_avg(np.mean(batch_ret))
        avg_batch_len = mpi_avg(np.mean(batch_len))
        avg_success_rate = mpi_avg(success_rate)
        avg_failure_rate = mpi_avg(failure_rate)
        avg_crash_rate = mpi_avg(crash_rate)
        avg_overtime_rate = mpi_avg(overtime_rate)
        if env_name == 'spacecraft-docking-continuous-v0' or env_name == 'spacecraft-docking-v0':
            avg_delta_v = mpi_avg(np.mean(delta_v))
            avg_RTA_percent = mpi_avg(np.mean(batch_RTA_percent))

        if proc_id() == 0:  # Only on one cpu
            # Plot to TensorBoard if True, only on one cpu
            if TensorBoard:
                writer.add_scalar('Return', avg_batch_ret, epoch)
                writer.add_scalar('Episode-Length', avg_batch_len * env.tau,
                                  epoch)
                writer.add_scalar('Success-Rate', avg_success_rate * 100,
                                  epoch)
                writer.add_scalar('Failure-Rate', avg_failure_rate * 100,
                                  epoch)
                writer.add_scalar('Crash-Rate', avg_crash_rate * 100, epoch)
                writer.add_scalar('Overtime-Rate', avg_overtime_rate * 100,
                                  epoch)
                if env_name == 'spacecraft-docking-continuous-v0' or env_name == 'spacecraft-docking-v0':
                    writer.add_scalar('Delta-V', avg_delta_v, epoch)
                    writer.add_scalar('RTA-on-percent', avg_RTA_percent, epoch)

            # Save neural network if true, can change to desired location
            if save_nn and epoch % save_every == 0 and epoch != 0:
                if not os.path.isdir(f"{PATH}/models"):
                    os.mkdir(f"{PATH}/models")
                if not os.path.isdir(f"{PATH}/models/PPO"):
                    os.mkdir(f"{PATH}/models/PPO")
                if env_name == 'spacecraft-docking-continuous-v0' or env_name == 'spacecraft-docking-v0':
                    Name2 = f"{PATH}/models/PPO/Spacecraft-docking-" + current_time + f"-epoch{epoch}.dat"
                elif env_name == 'dubins-aircraft-v0' or env_name == 'dubins-aircraft-continuous-v0':
                    Name2 = f"{PATH}/models/PPO/Dubins-aircraft-" + current_time + f"-epoch{epoch}.dat"
                torch.save(ac.state_dict(), Name2)

    # Average episodes per hour, episode per epoch
    ep_hr = mpi_avg(total_episodes) * args.cpu / (time.time() -
                                                  start_time) * 3600
    ep_Ep = mpi_avg(total_episodes) * args.cpu / (epoch + 1)

    # Plot on one cpu
    if proc_id() == 0:
        # Save neural network
        if save_nn:
            if not os.path.isdir(f"{PATH}/models"):
                os.mkdir(f"{PATH}/models")
            if not os.path.isdir(f"{PATH}/models/PPO"):
                os.mkdir(f"{PATH}/models/PPO")
            if env_name == 'spacecraft-docking-continuous-v0' or env_name == 'spacecraft-docking-v0':
                Name2 = f"{PATH}/models/PPO/Spacecraft-docking-" + current_time + "-final.dat"
            elif env_name == 'dubins-aircraft-v0' or env_name == 'dubins-aircraft-continuous-v0':
                Name2 = f"{PATH}/models/PPO/Dubins-aircraft-" + current_time + "-final.dat"
            torch.save(ac.state_dict(), Name2)

        # Print statistics on episodes
        print(
            f"Episodes per hour: {ep_hr:.0f}, Episodes per epoch: {ep_Ep:.0f}, Epochs per hour: {(epoch+1)/(time.time()-start_time)*3600:.0f}"
        )
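Example 3 relies on buf.finish_path(v) to close each trajectory, bootstrapping with a value estimate when the trajectory is cut off. The buffer itself is not shown in the snippet; as a rough, hedged sketch (illustrative names), a finish_path-style computation of GAE-lambda advantages and rewards-to-go over a single trajectory typically looks like this:

import numpy as np

def discount_cumsum(x, discount):
    # y[t] = x[t] + discount * y[t + 1]
    out = np.zeros(len(x))
    running = 0.0
    for t in reversed(range(len(x))):
        running = x[t] + discount * running
        out[t] = running
    return out

def finish_path_sketch(rews, vals, last_val, gamma=0.99, lam=0.97):
    rews = np.append(rews, last_val)   # bootstrap value closes the trajectory
    vals = np.append(vals, last_val)
    deltas = rews[:-1] + gamma * vals[1:] - vals[:-1]   # one-step TD residuals
    adv = discount_cumsum(deltas, gamma * lam)          # GAE-lambda advantages
    ret = discount_cumsum(rews, gamma)[:-1]             # rewards-to-go (value targets)
    return adv, ret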
Example #4
File: ddpg.py Project: zhc134/l2s
def ddpg(env_config, ac_type, ac_kwargs, rb_type, rb_kwargs, gamma, lr, polyak,
         batch_size, epochs, start_steps, steps_per_epoch, inc_ep, max_ep_len,
         test_max_ep_len, number_of_tests_per_epoch, act_noise, logger_kwargs,
         seed):
    logger = EpochLogger(**logger_kwargs)
    configs = locals().copy()
    configs.pop("logger")
    logger.save_config(configs)

    tf.set_random_seed(seed)
    np.random.seed(seed)

    env, test_env = make_env(env_config), make_env(env_config)
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]

    # Action limit for clamping: critically, assumes all dimensions share the same bound!
    act_high = env.action_space.high

    # Inputs to computation graph
    x_ph, a_ph, x2_ph, r_ph, d_ph = core.placeholders(obs_dim, act_dim,
                                                      obs_dim, None, None)

    actor_critic = core.get_ddpg_actor_critic(ac_type)
    # Main outputs from computation graph
    with tf.variable_scope('main'):
        pi, q, q_pi = actor_critic(x_ph, a_ph, **ac_kwargs)

    # Target networks
    with tf.variable_scope('target'):
        pi_targ, _, q_pi_targ = actor_critic(x2_ph, a_ph, **ac_kwargs)

    # Experience buffer
    RB = get_replay_buffer(rb_type)
    replay_buffer = RB(obs_dim, act_dim, **rb_kwargs)

    # Count variables
    var_counts = tuple(
        core.count_vars(scope) for scope in ['main/pi', 'main/q', 'main'])
    print('\nNumber of parameters: \t pi: %d, \t q: %d, \t total: %d\n' %
          var_counts)

    # Bellman backup for Q function
    backup = tf.stop_gradient(r_ph + gamma * (1 - d_ph) * q_pi_targ)

    # DDPG losses
    pi_loss = -tf.reduce_mean(q_pi)
    q_loss = tf.reduce_mean((q - backup)**2)

    # Separate train ops for pi, q
    pi_optimizer = tf.train.AdamOptimizer(learning_rate=lr)
    q_optimizer = tf.train.AdamOptimizer(learning_rate=lr)
    train_pi_op = pi_optimizer.minimize(pi_loss, var_list=get_vars('main/pi'))
    train_q_op = q_optimizer.minimize(q_loss, var_list=get_vars('main/q'))

    # Polyak averaging for target variables
    target_update = tf.group([
        tf.assign(v_targ, polyak * v_targ + (1 - polyak) * v_main)
        for v_main, v_targ in zip(get_vars('main'), get_vars('target'))
    ])

    # Initializing targets to match main variables
    target_init = tf.group([
        tf.assign(v_targ, v_main)
        for v_main, v_targ in zip(get_vars('main'), get_vars('target'))
    ])

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    sess.run(target_init)

    def get_action(o, noise_scale):
        pi_a = sess.run(pi, feed_dict={x_ph: o.reshape(1, -1)})[0]
        pi_a += noise_scale * np.random.randn(act_dim)
        pi_a = np.clip(pi_a, 0, 1)
        real_a = pi_a * act_high
        return pi_a, real_a

    def test_agent(n=10):
        test_actions = []
        for j in range(n):
            test_actions_ep = []
            o, r, d, ep_ret, ep_len = test_env.reset(), 0, False, 0, 0
            while not (d or (ep_len == test_max_ep_len)):
                # Take deterministic actions at test time (noise_scale=0)
                _, real_a = get_action(o, 0)
                test_actions_ep.append(real_a)
                o, r, d, _ = test_env.step(real_a)
                ep_ret += r
                ep_len += 1
            logger.store(TestEpRet=ep_ret, TestEpLen=ep_len)
            test_actions.append(test_actions_ep)
        return test_actions

    start_time = time.time()
    o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
    total_steps = steps_per_epoch * epochs

    actions = []
    epoch_actions = []
    rewards = []
    rets = []
    test_rets = []
    max_ret = None
    # Main loop: collect experience in env and update/log each epoch
    for t in range(total_steps):
        """
        Until start_steps have elapsed, randomly sample actions
        from a uniform distribution for better exploration. Afterwards, 
        use the learned policy (with some noise, via act_noise). 
        """
        if t > start_steps:
            pi_a, real_a = get_action(o, act_noise)
        else:
            # action_space.sample() returns a single action array and cannot be unpacked
            # into (pi_a, real_a); mirror get_action(): real action, then normalized one.
            real_a = env.action_space.sample()
            pi_a = real_a / act_high

        # Step the env
        o2, r, d, _ = env.step(real_a)
        ep_ret += r
        ep_len += 1
        epoch_actions.append(pi_a)

        # Ignore the "done" signal if it comes from hitting the time
        # horizon (that is, when it's an artificial terminal signal
        # that isn't based on the agent's state)
        d = False if ep_len == max_ep_len else d

        # Store experience to replay buffer
        replay_buffer.store(o, pi_a, r, o2, d)

        # Super critical, easy to overlook step: make sure to update
        # most recent observation!
        o = o2

        if d or (ep_len == max_ep_len):
            """
            Perform all DDPG updates at the end of the trajectory,
            in accordance with tuning done by TD3 paper authors.
            """
            for _ in range(ep_len):
                batch = replay_buffer.sample_batch(batch_size)
                feed_dict = {
                    x_ph: batch['obs1'],
                    x2_ph: batch['obs2'],
                    a_ph: batch['acts'],
                    r_ph: batch['rews'],
                    d_ph: batch['done']
                }

                # Q-learning update
                outs = sess.run([q_loss, q, train_q_op], feed_dict)
                logger.store(LossQ=outs[0], QVals=outs[1])

                # Policy update
                outs = sess.run([pi_loss, train_pi_op, target_update],
                                feed_dict)
                logger.store(LossPi=outs[0])

            logger.store(EpRet=ep_ret, EpLen=ep_len)
            actions.append(np.mean(epoch_actions))
            epoch_actions = []
            rewards.append(ep_ret)
            o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0

        # End of epoch wrap-up
        if (t + 1) % steps_per_epoch == 0:
            epoch = (t + 1) // steps_per_epoch

            # Test the performance of the deterministic version of the agent.
            test_actions = test_agent(number_of_tests_per_epoch)

            # Log info about epoch
            logger.log_tabular('Epoch', epoch)
            ret = logger.log_tabular('EpRet', average_only=True)
            test_ret = logger.log_tabular('TestEpRet', average_only=True)[0]
            logger.log_tabular('EpLen', average_only=True)
            logger.log_tabular('TestEpLen', average_only=True)
            logger.log_tabular('QVals', average_only=True)
            logger.log_tabular('LossPi', average_only=True)
            logger.log_tabular('LossQ', average_only=True)
            logger.log_tabular('Time', time.time() - start_time)
            logger.dump_tabular()

            rets.append(ret)
            test_rets.append(test_ret)

            if max_ret is None or test_ret > max_ret:
                max_ret = test_ret
                best_test_actions = test_actions

            max_ep_len += inc_ep
            util.plot_actions(test_actions, act_high,
                              logger.output_dir + '/actions%s.png' % epoch)

    logger.save_state(
        {
            "actions": actions,
            "rewards": rewards,
            "best_test_actions": best_test_actions,
            "rets": rets,
            "test_rets": test_rets,
            "max_ret": max_ret
        }, None)

    util.plot_actions(best_test_actions, act_high,
                      logger.output_dir + '/best_test_actions.png')
    logger.log("max ret: %f" % max_ret)
Example #5
def sac(args, steps_per_epoch=1500, replay_size=int(1e6), gamma=0.99,
        polyak=0.995, lr=1e-3, alpha=3e-4, batch_size=128, start_steps=1000,
        update_after=1000, update_every=1, num_test_episodes=10, max_ep_len=150,
        logger_kwargs=dict(), save_freq=1):

    logger_kwargs = setup_logger_kwargs(args.exp_name, args.seed)

    torch.set_num_threads(torch.get_num_threads())

    actor_critic = core.MLPActorCritic
    ac_kwargs = dict(hidden_sizes=[args.hid] * args.l)
    gamma = args.gamma
    seed = args.seed
    epochs = args.epochs
    logger_tensor = Logger(logdir=args.logdir, run_name="{}-{}".format(args.model_name, time.ctime()))

    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    torch.manual_seed(seed)
    np.random.seed(seed)

    env = ML1.get_train_tasks('reach-v1')  # Create an environment for the `reach-v1` task
    tasks = env.sample_tasks(1)  # Sample a task (in this case, a goal variation)
    env.set_task(tasks[0])  # Set task

    test_env = ML1.get_train_tasks('reach-v1')  # Create a test environment for the `reach-v1` task
    tasks = env.sample_tasks(1)  # Sample a task (in this case, a goal variation)
    test_env.set_task(tasks[0])  # Set task

    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape[0]

    # Action limit for clamping: critically, assumes all dimensions share the same bound!
    act_limit = env.action_space.high[0]

    # Create actor-critic module and target networks
    ac = actor_critic(env.observation_space, env.action_space, **ac_kwargs)
    ac_targ = deepcopy(ac)

    # Freeze target networks with respect to optimizers (only update via polyak averaging)
    for p in ac_targ.parameters():
        p.requires_grad = False

    # List of parameters for both Q-networks (save this for convenience)
    q_params = itertools.chain(ac.q1.parameters(), ac.q2.parameters())

    # Experience buffer
    replay_buffer = ReplayBuffer(obs_dim=obs_dim, act_dim=act_dim, size=replay_size)

    # Count variables (protip: try to get a feel for how different size networks behave!)
    var_counts = tuple(core.count_vars(module) for module in [ac.pi, ac.q1, ac.q2])
    logger.log('\nNumber of parameters: \t pi: %d, \t q1: %d, \t q2: %d\n' % var_counts)

    # Set up function for computing SAC Q-losses
    def compute_loss_q(data):
        o, a, r, o2, d = data['obs'], data['act'], data['rew'], data['obs2'], data['done']

        q1 = ac.q1(o, a)
        q2 = ac.q2(o, a)

        # Bellman backup for Q functions
        with torch.no_grad():
            # Target actions come from *current* policy
            a2, logp_a2 = ac.pi(o2)

            # Target Q-values
            q1_pi_targ = ac_targ.q1(o2, a2)
            q2_pi_targ = ac_targ.q2(o2, a2)
            q_pi_targ = torch.min(q1_pi_targ, q2_pi_targ)
            backup = r + gamma * (1 - d) * (q_pi_targ - alpha * logp_a2)

        # MSE loss against Bellman backup
        loss_q1 = ((q1 - backup) ** 2).mean()
        loss_q2 = ((q2 - backup) ** 2).mean()
        loss_q = loss_q1 + loss_q2

        # Useful info for logging
        q_info = dict(Q1Vals=q1.detach().numpy(),
                      Q2Vals=q2.detach().numpy())

        return loss_q, q_info

    # Set up function for computing SAC pi loss
    def compute_loss_pi(data):
        o = data['obs']
        pi, logp_pi = ac.pi(o)
        q1_pi = ac.q1(o, pi)
        q2_pi = ac.q2(o, pi)
        q_pi = torch.min(q1_pi, q2_pi)

        # Entropy-regularized policy loss
        loss_pi = (alpha * logp_pi - q_pi).mean()

        # Useful info for logging
        pi_info = dict(LogPi=logp_pi.detach().numpy())

        return loss_pi, pi_info

    # Set up optimizers for policy and q-function
    # (note: the learning rate is hardcoded here and overrides the `lr` argument)
    pi_optimizer = Adam(ac.pi.parameters(), lr=3e-4)
    q_optimizer = Adam(q_params, lr=3e-4)

    # Set up model saving
    logger.setup_pytorch_saver(ac)

    def update(data, logger_tensor, t):
        # First run one gradient descent step for Q1 and Q2
        q_optimizer.zero_grad()
        loss_q, q_info = compute_loss_q(data)
        loss_q.backward()
        q_optimizer.step()

        # Record things
        logger.store(LossQ=loss_q.item(), **q_info)
        logger_tensor.log_value(t, loss_q.item(), "loss q")

        # Freeze Q-networks so you don't waste computational effort
        # computing gradients for them during the policy learning step.
        for p in q_params:
            p.requires_grad = False

        # Next run one gradient descent step for pi.
        pi_optimizer.zero_grad()
        loss_pi, pi_info = compute_loss_pi(data)
        loss_pi.backward()
        pi_optimizer.step()

        # Unfreeze Q-networks so you can optimize them at the next SAC step.
        for p in q_params:
            p.requires_grad = True

        # Record things
        logger.store(LossPi=loss_pi.item(), **pi_info)
        logger_tensor.log_value(t, loss_pi.item(), "loss pi")

        # Finally, update target networks by polyak averaging.
        with torch.no_grad():
            for p, p_targ in zip(ac.parameters(), ac_targ.parameters()):
                # NB: We use an in-place operations "mul_", "add_" to update target
                # params, as opposed to "mul" and "add", which would make new tensors.
                p_targ.data.mul_(polyak)
                p_targ.data.add_((1 - polyak) * p.data)

    def get_action(o, deterministic=False):
        return ac.act(torch.as_tensor(o, dtype=torch.float32),
                      deterministic)

    def test_agent():
        for j in range(num_test_episodes):
            o, d, ep_ret, ep_len = test_env.reset(), False, 0, 0
            while not (d or (ep_len == max_ep_len)):
                # Take deterministic actions at test time
                o, r, d, _ = test_env.step(get_action(o, True))
                ep_ret += r
                ep_len += 1
            logger.store(TestEpRet=ep_ret, TestEpLen=ep_len)
            logger_tensor.log_value(t, ep_ret, "test ep reward")
            logger_tensor.log_value(t, ep_len, "test ep length")

    # Prepare for interaction with environment
    total_steps = steps_per_epoch * epochs
    start_time = time.time()
    o, ep_ret, ep_len = env.reset(), 0, 0
    # Main loop: collect experience in env and update/log each epoch
    for t in range(total_steps):

        # Until start_steps have elapsed, randomly sample actions
        # from a uniform distribution for better exploration. Afterwards,
        # use the learned policy.
        if t > start_steps:
            a = get_action(o)
        else:
            a = env.action_space.sample()

        # Step the env
        o2, r, d, _ = env.step(a)
        ep_ret += r
        ep_len += 1
        # Ignore the "done" signal if it comes from hitting the time
        # horizon (that is, when it's an artificial terminal signal
        # that isn't based on the agent's state)
        d = False if ep_len == max_ep_len else d
        # Store experience to replay buffer
        replay_buffer.store(o, a, r, o2, d)

        # Super critical, easy to overlook step: make sure to update
        # most recent observation!
        o = o2

        # End of trajectory handling
        if d or (ep_len == max_ep_len):
            logger_tensor.log_value(t, ep_ret, "reward")
            logging.info("> total_steps={} | reward={}".format(t, ep_ret))
            logger.store(EpRet=ep_ret, EpLen=ep_len)
            o, ep_ret, ep_len = env.reset(), 0, 0


        # Update handling
        if t >= update_after and t % update_every == 0:
            for j in range(update_every):
                batch = replay_buffer.sample_batch(batch_size)
                update(data=batch, logger_tensor=logger_tensor, t=t)

        # End of epoch handling
        if (t + 1) % steps_per_epoch == 0:
            epoch = (t + 1) // steps_per_epoch

            # Save model
            if (epoch % save_freq == 0) or (epoch == epochs):
                logger.save_state({'env': env}, None)

            # Test the performance of the deterministic version of the agent.
            test_agent()

            # Log info about epoch
            logger.log_tabular('Epoch', epoch)
            logger.log_tabular('EpRet', with_min_and_max=True)
            logger.log_tabular('TestEpRet', with_min_and_max=True)
            logger.log_tabular('EpLen', average_only=True)
            logger.log_tabular('TestEpLen', average_only=True)
            logger.log_tabular('TotalEnvInteracts', t)
            logger.log_tabular('Q1Vals', with_min_and_max=True)
            logger.log_tabular('Q2Vals', with_min_and_max=True)
            logger.log_tabular('LogPi', with_min_and_max=True)
            logger.log_tabular('LossPi', average_only=True)
            logger.log_tabular('LossQ', average_only=True)
            logger.log_tabular('Time', time.time() - start_time)

            logger_tensor.log_value(t, epoch, "epoch")
            logger.dump_tabular(logger_tensor=logger_tensor, epoch=epoch)
            ac.save(args.save_model_dir, args.model_name)
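The core of `compute_loss_q` above is the entropy-regularized Bellman backup `r + gamma * (1 - d) * (min(Q1', Q2') - alpha * logp_a2)`. A tiny self-contained check with made-up numbers (toy values only, not the hyperparameters used above) makes the clipped double-Q and entropy terms concrete:

import torch

gamma, alpha = 0.99, 0.2           # toy values, not the ones used above
r = torch.tensor([1.0, 0.0])       # rewards for a batch of 2 transitions
d = torch.tensor([0.0, 1.0])       # done flags (second transition is terminal)
q1_pi_targ = torch.tensor([10.0, 5.0])
q2_pi_targ = torch.tensor([12.0, 4.0])
logp_a2 = torch.tensor([-1.0, -2.0])   # log-probs of the sampled next actions

q_pi_targ = torch.min(q1_pi_targ, q2_pi_targ)             # clipped double-Q: [10., 4.]
backup = r + gamma * (1 - d) * (q_pi_targ - alpha * logp_a2)
print(backup)   # tensor([11.0980, 0.0000]): the terminal transition keeps only r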
Ejemplo n.º 6
def ppo(env_fn,
        actor_critic=a2c,
        ac_kwargs=dict(),
        seed=0,
        steps_per_epoch=4000,
        epochs=50,
        gamma=.99,
        clip_ratio=.2,
        pi_lr=3e-4,
        vf_lr=1e-3,
        train_pi_iters=80,
        train_v_iters=80,
        lam=.97,
        max_ep_len=1000,
        target_kl=.01,
        logger_kwargs=dict(),
        save_freq=10):

    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    seed += 10000 * proc_id()
    tf.set_random_seed(seed)
    np.random.seed(seed)

    env = env_fn()
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]

    # Share action space structure with the actor_critic
    ac_kwargs['action_space'] = env.action_space

    x_ph = tf.placeholder(name="x_ph", shape=[None, obs_dim], dtype=tf.float32)
    a_ph = tf.placeholder(name="a_ph", shape=[None, act_dim], dtype=tf.float32)
    adv_ph = tf.placeholder(name="adv_ph", shape=[None], dtype=tf.float32)
    ret_ph = tf.placeholder(name="ret_ph", shape=[None], dtype=tf.float32)
    logp_old_ph = tf.placeholder(name="logp_old_ph", shape=[None], dtype=tf.float32)

    # Main outputs from computation graph
    # print( actor_critic( x_ph, a_ph, **ac_kwargs))
    pi, logp, logp_pi, v = actor_critic(x_ph, a_ph, **ac_kwargs)

    all_phs = [x_ph, a_ph, adv_ph, ret_ph, logp_old_ph]

    get_action_ops = [pi, v, logp_pi]

    local_steps_per_epoch = int(steps_per_epoch / num_procs())
    buf = PPOBuffer(obs_dim, act_dim, local_steps_per_epoch, gamma, lam)

    # helpers for var count
    def get_vars(scope=''):
        return [x for x in tf.trainable_variables() if scope in x.name]

    def count_vars(scope=''):
        v = get_vars(scope)
        return sum([np.prod(var.shape.as_list()) for var in v])

    var_counts = tuple(count_vars(scope) for scope in ['pi', 'v'])
    logger.log('\nNumber of parameters: \t pi: %d, \t v: %d\n' % var_counts)

    # PPO Objectives
    ratio = tf.exp(logp - logp_old_ph)
    min_adv = tf.where(adv_ph > 0, (1 + clip_ratio) * adv_ph,
                       (1 - clip_ratio) * adv_ph)
    pi_loss = -tf.reduce_mean(tf.minimum(ratio * adv_ph, min_adv))
    v_loss = tf.reduce_mean((ret_ph - v)**2)

    # Stats to watch
    approx_kl = tf.reduce_mean(
        logp_old_ph -
        logp)  # a sample estimate for KL-divergence, easy to compute
    approx_ent = tf.reduce_mean(-logp)

    clipped = tf.logical_or(ratio > (1 + clip_ratio), ratio < (1 - clip_ratio))
    clipfrac = tf.reduce_mean(tf.cast(clipped, tf.float32))

    train_pi = MpiAdamOptimizer(learning_rate=pi_lr).minimize(pi_loss)
    train_v = MpiAdamOptimizer(learning_rate=vf_lr).minimize(v_loss)

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    # Sync params across processes
    sess.run(sync_all_params())

    # Setup model saving
    logger.setup_tf_saver(sess, inputs={'x': x_ph}, outputs={'pi': pi, 'v': v})

    def mpi_avg(x):
        """Average a scalar or vector over MPI processes."""
        return mpi_sum(x) / num_procs()

    def update():
        inputs = {k: v for k, v in zip(all_phs, buf.get())}
        pi_l_old, v_l_old, ent = sess.run([pi_loss, v_loss, approx_ent],
                                          feed_dict=inputs)

        for i in range(train_pi_iters):
            _, kl = sess.run([train_pi, approx_kl], feed_dict=inputs)
            kl = mpi_avg(kl)

            if kl > 1.5 * target_kl:
                logger.log(
                    'Early stopping at step %d due to reaching max kl.' % i)
                break

        logger.store(StopIter=i)
        for _ in range(train_v_iters):
            sess.run(train_v, feed_dict=inputs)

        # Log changes from update
        pi_l_new, v_l_new, kl, cf = sess.run(
            [pi_loss, v_loss, approx_kl, clipfrac], feed_dict=inputs)
        logger.store(LossPi=pi_l_old,
                     LossV=v_l_old,
                     KL=kl,
                     Entropy=ent,
                     ClipFrac=cf,
                     DeltaLossPi=(pi_l_new - pi_l_old),
                     DeltaLossV=(v_l_new - v_l_old))

    start_time = time.time()
    o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0

    for epoch in range(epochs):
        for t in range(local_steps_per_epoch):
            a, v_t, logp_t = sess.run(get_action_ops,
                                      feed_dict={x_ph: o.reshape(1, -1)})

            # save and log
            buf.store(o, a, r, v_t, logp_t)
            logger.store(VVals=v_t)

            o, r, d, _ = env.step(a[0])
            ep_ret += r
            ep_len += 1

            terminal = d or (ep_len == max_ep_len)
            if terminal or (t == local_steps_per_epoch - 1):
                if not (terminal):
                    print('Warning: trajectory cut off by epoch at %d steps.' %
                          ep_len)
                # if trajectory didn't reach terminal state, bootstrap value target
                last_val = r if d else sess.run(
                    v, feed_dict={x_ph: o.reshape(1, -1)})
                buf.finish_path(last_val)
                if terminal:
                    # only save EpRet / EpLen if trajectory finished
                    logger.store(EpRet=ep_ret, EpLen=ep_len)
                o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0

        # Save model
        if (epoch % save_freq == 0) or (epoch == epochs - 1):
            logger.save_state({'env': env}, None)

        # Perform PPO update!
        update()

        # Log info about epoch
        logger.log_tabular('Epoch', epoch)
        logger.log_tabular('EpRet', with_min_and_max=True)
        logger.log_tabular('EpLen', average_only=True)
        logger.log_tabular('VVals', with_min_and_max=True)
        logger.log_tabular('TotalEnvInteracts', (epoch + 1) * steps_per_epoch)
        logger.log_tabular('LossPi', average_only=True)
        logger.log_tabular('LossV', average_only=True)
        logger.log_tabular('DeltaLossPi', average_only=True)
        logger.log_tabular('DeltaLossV', average_only=True)
        logger.log_tabular('Entropy', average_only=True)
        logger.log_tabular('KL', average_only=True)
        logger.log_tabular('ClipFrac', average_only=True)
        logger.log_tabular('StopIter', average_only=True)
        logger.log_tabular('Time', time.time() - start_time)
        logger.dump_tabular()
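The `pi_loss` above uses the `tf.where` form of the PPO clip: `min_adv` is `(1 + clip_ratio) * adv` when the advantage is positive and `(1 - clip_ratio) * adv` otherwise, which makes `min(ratio * adv, min_adv)` identical to the more familiar `min(ratio * adv, clip(ratio, 1 - eps, 1 + eps) * adv)` used in the PyTorch examples below. A quick NumPy check of that equivalence (illustration only):

import numpy as np

eps = 0.2
rng = np.random.default_rng(0)
ratio = np.exp(rng.normal(scale=0.5, size=1000))   # positive ratios pi / pi_old
adv = rng.normal(size=1000)                        # advantages of either sign

# tf.where formulation used above
min_adv = np.where(adv > 0, (1 + eps) * adv, (1 - eps) * adv)
loss_where = -np.mean(np.minimum(ratio * adv, min_adv))

# clamp formulation (as in the PyTorch examples below)
clip_adv = np.clip(ratio, 1 - eps, 1 + eps) * adv
loss_clip = -np.mean(np.minimum(ratio * adv, clip_adv))

assert np.allclose(loss_where, loss_clip)

Ejemplo n.º 7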
def ppo(env_fn,
        actor_critic=core.MLPActorCritic,
        ac_kwargs=dict(),
        seed=0,
        steps_per_epoch=4000,
        epochs=50,
        gamma=0.99,
        clip_ratio=0.2,
        pi_lr=3e-4,
        vf_lr=1e-3,
        train_pi_iters=80,
        train_v_iters=80,
        lam=0.97,
        max_ep_len=1000,
        target_kl=0.01,
        logger_kwargs=dict(),
        save_freq=10):
    """
    Proximal Policy Optimization (by clipping), 

    with early stopping based on approximate KL

    Args:
        env_fn : A function which creates a copy of the environment.
            The environment must satisfy the OpenAI Gym API.

        actor_critic: The constructor method for a PyTorch Module with a 
            ``step`` method, an ``act`` method, a ``pi`` module, and a ``v`` 
            module. The ``step`` method should accept a batch of observations 
            and return:

            ===========  ================  ======================================
            Symbol       Shape             Description
            ===========  ================  ======================================
            ``a``        (batch, act_dim)  | Numpy array of actions for each 
                                           | observation.
            ``v``        (batch,)          | Numpy array of value estimates
                                           | for the provided observations.
            ``logp_a``   (batch,)          | Numpy array of log probs for the
                                           | actions in ``a``.
            ===========  ================  ======================================

            The ``act`` method behaves the same as ``step`` but only returns ``a``.

            The ``pi`` module's forward call should accept a batch of 
            observations and optionally a batch of actions, and return:

            ===========  ================  ======================================
            Symbol       Shape             Description
            ===========  ================  ======================================
            ``pi``       N/A               | Torch Distribution object, containing
                                           | a batch of distributions describing
                                           | the policy for the provided observations.
            ``logp_a``   (batch,)          | Optional (only returned if batch of
                                           | actions is given). Tensor containing 
                                           | the log probability, according to 
                                           | the policy, of the provided actions.
                                           | If actions not given, will contain
                                           | ``None``.
            ===========  ================  ======================================

            The ``v`` module's forward call should accept a batch of observations
            and return:

            ===========  ================  ======================================
            Symbol       Shape             Description
            ===========  ================  ======================================
            ``v``        (batch,)          | Tensor containing the value estimates
                                           | for the provided observations. (Critical: 
                                           | make sure to flatten this!)
            ===========  ================  ======================================


        ac_kwargs (dict): Any kwargs appropriate for the ActorCritic object 
            you provided to PPO.

        seed (int): Seed for random number generators.

        steps_per_epoch (int): Number of steps of interaction (state-action pairs) 
            for the agent and the environment in each epoch.

        epochs (int): Number of epochs of interaction (equivalent to
            number of policy updates) to perform.

        gamma (float): Discount factor. (Always between 0 and 1.)

        clip_ratio (float): Hyperparameter for clipping in the policy objective.
            Roughly: how far can the new policy go from the old policy while 
            still profiting (improving the objective function)? The new policy 
            can still go farther than the clip_ratio says, but it doesn't help
            on the objective anymore. (Usually small, 0.1 to 0.3.) Typically
            denoted by :math:`\epsilon`. 

        pi_lr (float): Learning rate for policy optimizer.

        vf_lr (float): Learning rate for value function optimizer.

        train_pi_iters (int): Maximum number of gradient descent steps to take 
            on policy loss per epoch. (Early stopping may cause optimizer
            to take fewer than this.)

        train_v_iters (int): Number of gradient descent steps to take on 
            value function per epoch.

        lam (float): Lambda for GAE-Lambda. (Always between 0 and 1,
            close to 1.)

        max_ep_len (int): Maximum length of trajectory / episode / rollout.

        target_kl (float): Roughly what KL divergence we think is appropriate
            between new and old policies after an update. This will get used 
            for early stopping. (Usually small, 0.01 or 0.05.)

        logger_kwargs (dict): Keyword args for EpochLogger.

        save_freq (int): How often (in terms of gap between epochs) to save
            the current policy and value function.

    """

    # GAedit
    # Special function to avoid certain slowdowns from PyTorch + MPI combo.
    # setup_pytorch_for_mpi()

    # Set up logger and save configuration
    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    # GAedit
    # Seed
    seed = 333
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Instantiate environment
    env = env_fn()
    #GAedit
    # obs_dim = env.observation_space.shape
    # act_dim = env.action_space.shape
    # get the default brain
    brain_name = env.brain_names[0]
    brain = env.brains[brain_name]
    # reset the environment
    env_info = env.reset(train_mode=True)[brain_name]
    # number of agents
    num_agents = len(env_info.agents)
    # size of each action
    act_dim = brain.vector_action_space_size
    # examine the state space
    obs_dim = env_info.vector_observations.shape[1]

    #GAedit
    # Create actor-critic module
    # ac = actor_critic(env.observation_space, env.action_space, **ac_kwargs)
    ac = actor_critic(obs_dim, act_dim, **ac_kwargs)

    # GAedit - don't think we need to sync
    # Sync params across processes
    # sync_params(ac)

    # Count variables
    var_counts = tuple(core.count_vars(module) for module in [ac.pi, ac.v])
    logger.log('\nNumber of parameters: \t pi: %d, \t v: %d\n' % var_counts)

    # Set up experience buffer
    # GAedit
    # local_steps_per_epoch = int(steps_per_epoch / num_procs())
    local_steps_per_epoch = int(steps_per_epoch / num_agents)
    #GAedit
    buf = PPOBuffer(obs_dim, act_dim, local_steps_per_epoch * num_agents,
                    gamma, lam)

    # buf = PPOBuffer(obs_dim, act_dim, local_steps_per_epoch, gamma, lam)

    # Set up function for computing PPO policy loss
    def compute_loss_pi(data):
        obs, act, adv, logp_old = data['obs'], data['act'], data['adv'], data[
            'logp']

        # Policy loss
        pi, logp = ac.pi(obs, act)
        ratio = torch.exp(logp - logp_old)
        clip_adv = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio) * adv
        loss_pi = -(torch.min(ratio * adv, clip_adv)).mean()

        # Useful extra info
        approx_kl = (logp_old - logp).mean().item()
        ent = pi.entropy().mean().item()
        clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
        clipfrac = torch.as_tensor(clipped, dtype=torch.float32).mean().item()
        pi_info = dict(kl=approx_kl, ent=ent, cf=clipfrac)

        return loss_pi, pi_info

    # Set up function for computing value loss
    def compute_loss_v(data):
        obs, ret = data['obs'], data['ret']
        return ((ac.v(obs) - ret)**2).mean()

    # Set up optimizers for policy and value function
    pi_optimizer = Adam(ac.pi.parameters(), lr=pi_lr)
    vf_optimizer = Adam(ac.v.parameters(), lr=vf_lr)

    # Set up model saving
    logger.setup_pytorch_saver(ac)

    def update():
        data = buf.get()

        pi_l_old, pi_info_old = compute_loss_pi(data)
        pi_l_old = pi_l_old.item()
        v_l_old = compute_loss_v(data).item()

        # Train policy with multiple steps of gradient descent
        for i in range(train_pi_iters):
            pi_optimizer.zero_grad()
            loss_pi, pi_info = compute_loss_pi(data)
            #GAedit
            # kl = mpi_avg(pi_info['kl'])
            kl = pi_info['kl']
            if kl > 1.5 * target_kl:
                logger.log(
                    'Early stopping at step %d due to reaching max kl.' % i)
                break
            loss_pi.backward()
            #GAedit
            # mpi_avg_grads(ac.pi)    # average grads across MPI processes
            # ac.pi.mean()
            pi_optimizer.step()

        logger.store(StopIter=i)

        # Value function learning
        for i in range(train_v_iters):
            vf_optimizer.zero_grad()
            loss_v = compute_loss_v(data)
            loss_v.backward()
            #GAedit
            # mpi_avg_grads(ac.v)    # average grads across MPI processes
            vf_optimizer.step()

        # Log changes from update
        kl, ent, cf = pi_info['kl'], pi_info_old['ent'], pi_info['cf']
        logger.store(LossPi=pi_l_old,
                     LossV=v_l_old,
                     KL=kl,
                     Entropy=ent,
                     ClipFrac=cf,
                     DeltaLossPi=(loss_pi.item() - pi_l_old),
                     DeltaLossV=(loss_v.item() - v_l_old))

    # Prepare for interaction with environment
    start_time = time.time()
    #GAedit
    # o, ep_ret, ep_len = env.reset(), 0, 0
    ep_ret, ep_len = 0, 0
    env_info = env.reset(train_mode=True)[brain_name]
    o = env_info.vector_observations
    # Main loop: collect experience in env and update/log each epoch
    for epoch in range(epochs):
        for t in range(local_steps_per_epoch):
            a, v, logp = ac.step(torch.as_tensor(o, dtype=torch.float32))
            # GAedit
            # next_o, r, d, _ = env.step(a)
            env_info = env.step(a)[brain_name]
            next_o, r, d = env_info.vector_observations, env_info.rewards, env_info.local_done
            #GAedit
            # ep_ret += r
            ep_ret += np.mean(r)
            ep_len += 1

            # save and log
            #GAedit
            # buf.store(o, a, r, v, logp)
            for i in range(num_agents):
                buf.store(o[i], a[i], r[i], v[i], logp[i])
            logger.store(VVals=v)

            # Update obs (critical!)
            o = next_o

            timeout = ep_len == max_ep_len
            # GAedit
            # terminal = d or timeout
            terminal = any(d) or timeout
            epoch_ended = t == local_steps_per_epoch - 1

            if terminal or epoch_ended:
                if epoch_ended and not (terminal):
                    print('Warning: trajectory cut off by epoch at %d steps.' %
                          ep_len,
                          flush=True)
                # if trajectory didn't reach terminal state, bootstrap value target
                if timeout or epoch_ended:
                    _, v, _ = ac.step(torch.as_tensor(o, dtype=torch.float32))
                else:
                    v = 0
                buf.finish_path(v)
                if terminal:
                    # only save EpRet / EpLen if trajectory finished
                    logger.store(EpRet=ep_ret, EpLen=ep_len)
                # GAedit
                # o, ep_ret, ep_len = env.reset(), 0, 0
                ep_ret, ep_len = 0, 0
                env_info = env.reset(train_mode=True)[brain_name]
                o = env_info.vector_observations

        # Save model
        if (epoch % save_freq == 0) or (epoch == epochs - 1):
            logger.save_state({'env': env}, None)

        # Perform PPO update!
        update()

        # Log info about epoch
        logger.log_tabular('Epoch', epoch)
        logger.log_tabular('EpRet', with_min_and_max=True)
        logger.log_tabular('EpLen', average_only=True)
        logger.log_tabular('VVals', with_min_and_max=True)
        logger.log_tabular('TotalEnvInteracts', (epoch + 1) * steps_per_epoch)
        logger.log_tabular('LossPi', average_only=True)
        logger.log_tabular('LossV', average_only=True)
        logger.log_tabular('DeltaLossPi', average_only=True)
        logger.log_tabular('DeltaLossV', average_only=True)
        logger.log_tabular('Entropy', average_only=True)
        logger.log_tabular('KL', average_only=True)
        logger.log_tabular('ClipFrac', average_only=True)
        logger.log_tabular('StopIter', average_only=True)
        logger.log_tabular('Time', time.time() - start_time)
        logger.dump_tabular()
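`buf.finish_path(v)` above closes a trajectory by running GAE-lambda: each advantage is a discounted sum of TD residuals `delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)` with factor `gamma * lam`. A minimal sketch of that computation for one finished trajectory (a hypothetical `gae_advantages` helper, not the `PPOBuffer` used above; `rets` here is the lambda-return `adv + V`, one common choice of value target, while the buffer may instead use plain discounted rewards-to-go):

import numpy as np

def gae_advantages(rews, vals, gamma=0.99, lam=0.97):
    """rews: rewards r_0..r_{T-1}; vals: values V(s_0)..V(s_T) incl. bootstrap."""
    T = len(rews)
    deltas = rews + gamma * vals[1:] - vals[:-1]       # TD residuals
    adv = np.zeros(T)
    running = 0.0
    for t in reversed(range(T)):                       # discounted sum with gamma * lam
        running = deltas[t] + gamma * lam * running
        adv[t] = running
    rets = adv + vals[:-1]                             # lambda-return value targets
    return adv, rets

rews = np.array([1.0, 0.0, 1.0])
vals = np.array([0.5, 0.4, 0.6, 0.0])                  # last entry = bootstrap value
adv, rets = gae_advantages(rews, vals)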
Ejemplo n.º 8
def ppo(env_fn,
        actor_critic=core.mlp_actor_critic,
        ac_kwargs=dict(),
        seed=0,
        trials_per_epoch=2500,
        steps_per_trial=100,
        epochs=50,
        gamma=0.99,
        clip_ratio=0.2,
        pi_lr=3e-4,
        vf_lr=1e-3,
        train_pi_iters=1000,
        train_v_iters=80,
        lam=0.97,
        max_ep_len=1000,
        target_kl=0.01,
        logger_kwargs=dict(),
        save_freq=10):
    """

    Args:
        env_fn : A function which creates a copy of the environment.
            The environment must satisfy the OpenAI Gym API.

        actor_critic: A function which takes in placeholder symbols
            for state, ``x_ph``, and action, ``a_ph``, and returns the main
            outputs from the agent's Tensorflow computation graph:

            ===========  ================  ======================================
            Symbol       Shape             Description
            ===========  ================  ======================================
            ``pi``       (batch, act_dim)  | Samples actions from policy given
                                           | states.
            ``logp``     (batch,)          | Gives log probability, according to
                                           | the policy, of taking actions ``a_ph``
                                           | in states ``x_ph``.
            ``logp_pi``  (batch,)          | Gives log probability, according to
                                           | the policy, of the action sampled by
                                           | ``pi``.
            ``v``        (batch,)          | Gives the value estimate for states
                                           | in ``x_ph``. (Critical: make sure
                                           | to flatten this!)
            ===========  ================  ======================================

        ac_kwargs (dict): Any kwargs appropriate for the actor_critic
            function you provided to PPO.

        seed (int): Seed for random number generators.

        steps_per_epoch (int): Number of steps of interaction (state-action pairs)
            for the agent and the environment in each epoch.

        epochs (int): Number of epochs of interaction (equivalent to
            number of policy updates) to perform.

        gamma (float): Discount factor. (Always between 0 and 1.)

        clip_ratio (float): Hyperparameter for clipping in the policy objective.
            Roughly: how far can the new policy go from the old policy while
            still profiting (improving the objective function)? The new policy
            can still go farther than the clip_ratio says, but it doesn't help
            on the objective anymore. (Usually small, 0.1 to 0.3.)

        pi_lr (float): Learning rate for policy optimizer.

        vf_lr (float): Learning rate for value function optimizer.

        train_pi_iters (int): Maximum number of gradient descent steps to take
            on policy loss per epoch. (Early stopping may cause optimizer
            to take fewer than this.)

        train_v_iters (int): Number of gradient descent steps to take on
            value function per epoch.

        lam (float): Lambda for GAE-Lambda. (Always between 0 and 1,
            close to 1.)

        max_ep_len (int): Maximum length of trajectory / episode / rollout.

        target_kl (float): Roughly what KL divergence we think is appropriate
            between new and old policies after an update. This will get used
            for early stopping. (Usually small, 0.01 or 0.05.)

        logger_kwargs (dict): Keyword args for EpochLogger.

        save_freq (int): How often (in terms of gap between epochs) to save
            the current policy and value function.

    """

    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    seed += 10000 * proc_id()
    tf.set_random_seed(seed)
    np.random.seed(seed)

    env = env_fn()
    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape

    # Share information about action space with policy architecture
    ac_kwargs['action_space'] = env.action_space

    # Inputs to computation graph
    # x_ph, a_ph = core.placeholders_from_spaces(env.observation_space, env.action_space)
    x_ph = tf.placeholder(dtype=tf.float32, shape=(None, None, 1), name='x_ph')
    a_ph = tf.placeholder(dtype=tf.int32, shape=(None, None), name='a_ph')
    # adv_ph, ret_ph, logp_old_ph, rew_ph = core.placeholders(None, None, None, 1)
    adv_ph = tf.placeholder(dtype=tf.float32,
                            shape=(None, None),
                            name='adv_ph')
    ret_ph = tf.placeholder(dtype=tf.float32,
                            shape=(None, None),
                            name='ret_ph')
    logp_old_ph = tf.placeholder(dtype=tf.float32,
                                 shape=(None, None),
                                 name='logp_old_ph')
    rew_ph = tf.placeholder(dtype=tf.float32,
                            shape=(None, None, 1),
                            name='rew_ph')
    pi_state_ph = tf.placeholder(dtype=tf.float32,
                                 shape=(None, NUM_GRU_UNITS),
                                 name='pi_state_ph')
    v_state_ph = tf.placeholder(dtype=tf.float32,
                                shape=(None, NUM_GRU_UNITS),
                                name='v_state_ph')

    # Initialize rnn states for pi and v

    # Main outputs from computation graph
    pi, logp, logp_pi, v, new_pi_state, new_v_state = actor_critic(
        x_ph,
        a_ph,
        rew_ph,
        pi_state_ph,
        v_state_ph,
        NUM_GRU_UNITS,
        action_space=env.action_space)

    # Need all placeholders in *this* order later (to zip with data from buffer)
    all_phs = [x_ph, a_ph, adv_ph, ret_ph, logp_old_ph, rew_ph]

    # Every step, get: action, value, and logprob and reward
    get_action_ops = [pi, v, logp_pi, new_pi_state, new_v_state]

    # Experience buffer
    steps_per_epoch = trials_per_epoch * steps_per_trial
    local_steps_per_epoch = int(steps_per_epoch / num_procs())
    buf = PPOBuffer(obs_dim, act_dim, local_steps_per_epoch, gamma, lam)

    # Count variables
    var_counts = tuple(core.count_vars(scope) for scope in ['pi', 'v'])
    logger.log('\nNumber of parameters: \t pi: %d, \t v: %d\n' % var_counts)

    # PPO objectives
    ratio = tf.exp(logp - logp_old_ph)  # pi(a|s) / pi_old(a|s)
    min_adv = tf.where(adv_ph > 0, (1 + clip_ratio) * adv_ph,
                       (1 - clip_ratio) * adv_ph)
    pi_loss = -tf.reduce_mean(tf.minimum(ratio * adv_ph, min_adv))
    v_loss = tf.reduce_mean((ret_ph - v)**2)

    # Info (useful to watch during learning)
    approx_kl = tf.reduce_mean(
        logp_old_ph -
        logp)  # a sample estimate for KL-divergence, easy to compute
    approx_ent = tf.reduce_mean(
        -logp)  # a sample estimate for entropy, also easy to compute
    clipped = tf.logical_or(ratio > (1 + clip_ratio), ratio < (1 - clip_ratio))
    clipfrac = tf.reduce_mean(tf.cast(clipped, tf.float32))

    # Optimizers
    train_pi = MpiAdamOptimizer(
        learning_rate=pi_lr).minimize(pi_loss - 0.01 * approx_ent)
    train_v = MpiAdamOptimizer(learning_rate=vf_lr).minimize(v_loss)

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    # Sync params across processes
    sess.run(sync_all_params())

    # Setup model saving
    logger.setup_tf_saver(sess, inputs={'x': x_ph}, outputs={'pi': pi, 'v': v})

    # tf.reset_default_graph()
    # restore_tf_graph(sess, '..//data//ppo//ppo_s0//simple_save')

    def update():
        inputs = {k: v for k, v in zip(all_phs, buf.get())}
        inputs[pi_state_ph] = np.zeros((trials_per_epoch, NUM_GRU_UNITS))
        inputs[v_state_ph] = np.zeros((trials_per_epoch, NUM_GRU_UNITS))
        pi_l_old, v_l_old, ent = sess.run([pi_loss, v_loss, approx_ent],
                                          feed_dict=inputs)
        print(pi_l_old, v_l_old)
        # Training
        for i in range(train_pi_iters):
            # print(f'pi:{i}')
            _, kl = sess.run([train_pi, approx_kl], feed_dict=inputs)
            # print(sess.run(pi_loss, feed_dict=inputs))
            kl = mpi_avg(kl)
            if kl > 1.5 * target_kl:
                logger.log(
                    'Early stopping at step %d due to reaching max kl.' % i)
                break
        logger.store(StopIter=i)
        for _ in range(train_v_iters):
            # print(f'v:{_}')
            sess.run(train_v, feed_dict=inputs)

        # Log changes from update
        import datetime
        print(f'finish one batch training at {datetime.datetime.now()}')
        pi_l_new, v_l_new, kl, cf = sess.run(
            [pi_loss, v_loss, approx_kl, clipfrac], feed_dict=inputs)
        logger.store(LossPi=pi_l_old,
                     LossV=v_l_old,
                     KL=kl,
                     Entropy=ent,
                     ClipFrac=cf,
                     DeltaLossPi=(pi_l_new - pi_l_old),
                     DeltaLossV=(v_l_new - v_l_old))

    start_time = time.time()
    o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0

    # Main loop: collect experience in env and update/log each epoch

    for epoch in range(epochs):
        for trial in range(trials_per_epoch):
            print(f'trial: {trial}')
            old_a = np.array([0]).reshape(1, 1)
            old_r = np.array([0]).reshape((1, 1, 1))
            means = env.sample_tasks(1)[0]
            action_dict = defaultdict(int)
            for i in range(env.action_space.n):
                action_dict[i] = 0

            env.reset_task_simple(means)
            task_avg = 0.0
            pi_state_t = np.zeros((1, NUM_GRU_UNITS))
            v_state_t = np.zeros((1, NUM_GRU_UNITS))
            for step in range(steps_per_trial):
                a, v_t, logp_t, pi_state_t, v_state_t = sess.run(
                    get_action_ops,
                    feed_dict={
                        x_ph: o.reshape(1, 1, -1),
                        a_ph: old_a,
                        rew_ph: old_r,
                        pi_state_ph: pi_state_t,
                        v_state_ph: v_state_t
                    })
                # save and log
                buf.store(o, a, r, v_t, logp_t)
                logger.store(VVals=v_t)

                try:
                    o, r, d, _ = env.step(a[0][0])
                except Exception:
                    print(a)
                    raise

                action_dict[a[0][0]] += 1

                old_a = np.array(a).reshape(1, 1)
                old_r = np.array([r]).reshape(1, 1, 1)
                ep_ret += r
                task_avg += r
                ep_len += 1

                terminal = d or (ep_len == max_ep_len)
                if terminal or (step == steps_per_trial - 1):
                    if not (terminal):
                        print(
                            'Warning: trajectory cut off by epoch at %d steps.'
                            % ep_len)
                    # if trajectory didn't reach terminal state, bootstrap value target
                    last_val = r if d else sess.run(
                        v,
                        feed_dict={
                            x_ph: o.reshape(1, 1, -1),
                            a_ph: old_a,
                            rew_ph: old_r,
                            pi_state_ph: pi_state_t,
                            v_state_ph: v_state_t
                        })
                    buf.finish_path(last_val)
                    if terminal:
                        # only save EpRet / EpLen if trajectory finished
                        logger.store(EpRet=ep_ret, EpLen=ep_len)

                    o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0

            # logger.log_tabular('Epoch', epoch)
            # logger.log_tabular('EpRet', with_min_and_max=True)
            # logger.log_tabular('Means', means)
            # logger.dump_tabular()
            print(f'avg in trial {trial}: {task_avg / steps_per_trial}')
            print(f'Means in trial {trial}: {means}')

            print(action_dict)

        # Save model
        if (epoch % save_freq == 0) or (epoch == epochs - 1):
            logger.save_state({'env': env}, None)
            # saved_path = saver.save(sess, f"/tmp/model_epoch{epoch}.ckpt")
            # print(f'Model saved in {saved_path}')
        # Perform PPO update!

        update()
        logger.log_tabular('Epoch', epoch)
        logger.log_tabular('EpRet', with_min_and_max=True)
        logger.log_tabular('EpLen', average_only=True)
        logger.log_tabular('VVals', with_min_and_max=True)
        logger.log_tabular('TotalEnvInteracts', (epoch + 1) * steps_per_epoch)
        logger.log_tabular('LossPi', average_only=True)
        logger.log_tabular('LossV', average_only=True)
        logger.log_tabular('DeltaLossPi', average_only=True)
        logger.log_tabular('DeltaLossV', average_only=True)
        logger.log_tabular('Entropy', average_only=True)
        logger.log_tabular('KL', average_only=True)
        logger.log_tabular('ClipFrac', average_only=True)
        logger.log_tabular('StopIter', average_only=True)
        logger.log_tabular('Time', time.time() - start_time)
        logger.dump_tabular()
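`approx_kl` above is the sample estimate `mean(logp_old - logp)`, evaluated on actions drawn from the old policy, and it drives the early-stopping check against `1.5 * target_kl`. A small NumPy sketch comparing that estimate to the analytic KL divergence between two one-dimensional Gaussians (illustration only, arbitrary parameters):

import numpy as np

rng = np.random.default_rng(0)
mu_old, std_old = 0.0, 1.0
mu_new, std_new = 0.3, 1.2

def log_prob(x, mu, std):
    return -0.5 * ((x - mu) / std) ** 2 - np.log(std) - 0.5 * np.log(2 * np.pi)

a = rng.normal(mu_old, std_old, size=200000)           # actions from the old policy
approx_kl = np.mean(log_prob(a, mu_old, std_old) - log_prob(a, mu_new, std_new))

# Analytic KL(old || new) for Gaussians
exact_kl = (np.log(std_new / std_old)
            + (std_old ** 2 + (mu_old - mu_new) ** 2) / (2 * std_new ** 2) - 0.5)
print(approx_kl, exact_kl)                             # the two values are close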
Ejemplo n.º 9
def ppo(env_fn,
        actor_critic=core.MLPActorCritic,
        ac_kwargs=dict(),
        seed=0,
        steps_per_epoch=4000,
        epochs=50,
        gamma=0.99,
        clip_ratio=0.2,
        pi_lr=3e-4,
        vf_lr=1e-3,
        train_pi_iters=80,
        train_v_iters=80,
        lam=0.97,
        max_ep_len=2000,
        target_kl=0.01,
        logger_kwargs=dict(),
        save_freq=10):

    global RENDER, BONUS
    """
    Proximal Policy Optimization (by clipping), 

    with early stopping based on approximate KL

    Args:
        env_fn : A function which creates a copy of the environment.
            The environment must satisfy the OpenAI Gym API.

        actor_critic: The constructor method for a PyTorch Module with a 
            ``step`` method, an ``act`` method, a ``pi`` module, and a ``v`` 
            module. The ``step`` method should accept a batch of observations 
            and return:

            ===========  ================  ======================================
            Symbol       Shape             Description
            ===========  ================  ======================================
            ``a``        (batch, act_dim)  | Numpy array of actions for each 
                                           | observation.
            ``v``        (batch,)          | Numpy array of value estimates
                                           | for the provided observations.
            ``logp_a``   (batch,)          | Numpy array of log probs for the
                                           | actions in ``a``.
            ===========  ================  ======================================

            The ``act`` method behaves the same as ``step`` but only returns ``a``.

            The ``pi`` module's forward call should accept a batch of 
            observations and optionally a batch of actions, and return:

            ===========  ================  ======================================
            Symbol       Shape             Description
            ===========  ================  ======================================
            ``pi``       N/A               | Torch Distribution object, containing
                                           | a batch of distributions describing
                                           | the policy for the provided observations.
            ``logp_a``   (batch,)          | Optional (only returned if batch of
                                           | actions is given). Tensor containing 
                                           | the log probability, according to 
                                           | the policy, of the provided actions.
                                           | If actions not given, will contain
                                           | ``None``.
            ===========  ================  ======================================

            The ``v`` module's forward call should accept a batch of observations
            and return:

            ===========  ================  ======================================
            Symbol       Shape             Description
            ===========  ================  ======================================
            ``v``        (batch,)          | Tensor containing the value estimates
                                           | for the provided observations. (Critical: 
                                           | make sure to flatten this!)
            ===========  ================  ======================================


        ac_kwargs (dict): Any kwargs appropriate for the ActorCritic object 
            you provided to PPO.

        seed (int): Seed for random number generators.

        steps_per_epoch (int): Number of steps of interaction (state-action pairs) 
            for the agent and the environment in each epoch.

        epochs (int): Number of epochs of interaction (equivalent to
            number of policy updates) to perform.

        gamma (float): Discount factor. (Always between 0 and 1.)

        clip_ratio (float): Hyperparameter for clipping in the policy objective.
            Roughly: how far can the new policy go from the old policy while 
            still profiting (improving the objective function)? The new policy 
            can still go farther than the clip_ratio says, but it doesn't help
            on the objective anymore. (Usually small, 0.1 to 0.3.) Typically
            denoted by :math:`\epsilon`. 

        pi_lr (float): Learning rate for policy optimizer.

        vf_lr (float): Learning rate for value function optimizer.

        train_pi_iters (int): Maximum number of gradient descent steps to take 
            on policy loss per epoch. (Early stopping may cause optimizer
            to take fewer than this.)

        train_v_iters (int): Number of gradient descent steps to take on 
            value function per epoch.

        lam (float): Lambda for GAE-Lambda. (Always between 0 and 1,
            close to 1.)

        max_ep_len (int): Maximum length of trajectory / episode / rollout.

        target_kl (float): Roughly what KL divergence we think is appropriate
            between new and old policies after an update. This will get used 
            for early stopping. (Usually small, 0.01 or 0.05.)

        logger_kwargs (dict): Keyword args for EpochLogger.

        save_freq (int): How often (in terms of gap between epochs) to save
            the current policy and value function.

    """

    # Reachability Trainer
    r_network = R_Network().to(device)
    trainer = R_Network_Trainer(r_network=r_network, exp_name="random1")
    episodic_memory = EpisodicMemory(embedding_shape=[EMBEDDING_DIM])

    # Special function to avoid certain slowdowns from PyTorch + MPI combo.
    setup_pytorch_for_mpi()

    # Set up logger and save configuration
    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    # Random seed
    seed += 10000 * proc_id()
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Instantiate environment
    env = env_fn()
    observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(3, 64, 64))
    action_space = gym.spaces.Discrete(3)
    obs_dim = observation_space.shape
    act_dim = action_space.shape

    # Create actor-critic module
    ac = actor_critic(observation_space, action_space, **ac_kwargs)

    # Sync params across processes
    sync_params(ac)

    # Count variables
    var_counts = tuple(core.count_vars(module) for module in [ac.pi, ac.v])
    logger.log('\nNumber of parameters: \t pi: %d, \t v: %d\n' % var_counts)

    # Set up experience buffer
    local_steps_per_epoch = int(steps_per_epoch / num_procs())
    buf = PPOBuffer(obs_dim, act_dim, local_steps_per_epoch, gamma, lam)

    # Set up function for computing PPO policy loss
    def compute_loss_pi(data):
        obs, act, adv, logp_old = data['obs'], data['act'], data['adv'], data[
            'logp']

        # Policy loss
        pi, logp = ac.pi(obs, act)
        ratio = torch.exp(logp - logp_old)
        clip_adv = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio) * adv
        loss_pi = -(torch.min(ratio * adv, clip_adv)).mean()

        # Useful extra info
        approx_kl = (logp_old - logp).mean().item()
        ent = pi.entropy().mean().item()
        clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
        clipfrac = torch.as_tensor(clipped, dtype=torch.float32).mean().item()
        pi_info = dict(kl=approx_kl, ent=ent, cf=clipfrac)

        return loss_pi, pi_info

    # Set up function for computing value loss
    def compute_loss_v(data):
        obs, ret = data['obs'], data['ret']
        return ((ac.v(obs) - ret)**2).mean()

    # Set up optimizers for policy and value function
    pi_optimizer = Adam(ac.pi.parameters(), lr=pi_lr)
    vf_optimizer = Adam(ac.v.parameters(), lr=vf_lr)

    # Set up model saving
    logger.setup_pytorch_saver(ac)

    def update():
        data = buf.get()

        pi_l_old, pi_info_old = compute_loss_pi(data)
        pi_l_old = pi_l_old.item()
        v_l_old = compute_loss_v(data).item()

        # Train policy with multiple steps of gradient descent
        for i in range(train_pi_iters):
            pi_optimizer.zero_grad()
            loss_pi, pi_info = compute_loss_pi(data)
            # Entropy bonus (note: pi_info['ent'] is a detached Python float, so
            # this shifts the reported loss but does not affect the gradient)
            loss_pi += pi_info['ent'] * 0.0021
            kl = mpi_avg(pi_info['kl'])
            if kl > 1.5 * target_kl:
                logger.log(
                    'Early stopping at step %d due to reaching max kl.' % i)
                break
            loss_pi.backward()
            mpi_avg_grads(ac.pi)  # average grads across MPI processes
            pi_optimizer.step()

        logger.store(StopIter=i)

        # Value function learning
        for i in range(train_v_iters):
            vf_optimizer.zero_grad()
            loss_v = compute_loss_v(data)
            loss_v.backward()
            mpi_avg_grads(ac.v)  # average grads across MPI processes
            vf_optimizer.step()

        # Log changes from update
        kl, ent, cf = pi_info['kl'], pi_info_old['ent'], pi_info['cf']
        logger.store(LossPi=pi_l_old,
                     LossV=v_l_old,
                     KL=kl,
                     Entropy=ent,
                     ClipFrac=cf,
                     DeltaLossPi=(loss_pi.item() - pi_l_old),
                     DeltaLossV=(loss_v.item() - v_l_old))

    # Prepare for interaction with environment
    start_time = time.time()
    o, _ = env.reset()
    env.render()
    o = o.astype(np.float32) / 255.
    o = o.transpose(2, 0, 1)
    ep_ret, ep_len = 0, 0
    indices = []

    # Main loop: collect experience in env and update/log each epoch
    for epoch in range(epochs):
        for t in range(local_steps_per_epoch):
            state = torch.as_tensor(o[np.newaxis, ...], dtype=torch.float32)
            a, v, logp = ac.step(state)

            next_o, r, d, info = env.step(a)
            next_o = next_o.astype(np.float32) / 255.

            d = ep_len == max_ep_len
            trainer.store_new_state([next_o], [r], [d], [None])

            r_network.eval()
            with torch.no_grad():
                state_embedding = r_network.embed_observation(
                    torch.FloatTensor([o]).to(device)).cpu().numpy()[0]
                aggregated, _, _ = similarity_to_memory(
                    state_embedding, episodic_memory, r_network)
                curiosity_bonus = 0.03 * (0.5 - aggregated)
                if BONUS:
                    print(f'{curiosity_bonus:.3f}')
                if curiosity_bonus > 0 or len(episodic_memory) == 0:
                    idx = episodic_memory.store_new_state(state_embedding)
                    x = int(env.map_scale * info['pose']['x'])
                    y = int(env.map_scale * info['pose']['y'])
                    if idx == len(indices):
                        indices.append((x, y))
                    else:
                        indices[idx] = (x, y)

            r_network.train()

            next_o = next_o.transpose(2, 0, 1)
            ep_ret += r + curiosity_bonus
            ep_len += 1

            # save and log
            buf.store(o, a, r, v, logp)
            logger.store(VVals=v)

            k = cv2.waitKey(1)
            if k == ord('s'):
                RENDER = 1 - RENDER
            elif k == ord('b'):
                BONUS = 1 - BONUS

            if RENDER:
                env.info['map'] = cv2.flip(env.info['map'], 0)
                for index in indices:
                    cv2.circle(env.info['map'], index, 3, (0, 0, 255), -1)
                env.info['map'] = cv2.flip(env.info['map'], 0)
                env.render()

            # Update obs (critical!)
            o = next_o

            timeout = ep_len == max_ep_len
            terminal = d or timeout
            epoch_ended = t == local_steps_per_epoch - 1

            if terminal or epoch_ended:
                if epoch_ended and not (terminal):
                    print('Warning: trajectory cut off by epoch at %d steps.' %
                          ep_len,
                          flush=True)
                # if trajectory didn't reach terminal state, bootstrap value target
                if timeout or epoch_ended:
                    state = torch.as_tensor(o[np.newaxis, ...],
                                            dtype=torch.float32)
                    _, v, _ = ac.step(state)
                else:
                    v = 0
                buf.finish_path(v)
                if terminal:
                    # only save EpRet / EpLen if trajectory finished
                    logger.store(EpRet=ep_ret, EpLen=ep_len)
                print(ep_ret, ep_len, len(episodic_memory))
                ep_ret, ep_len = 0, 0
                o, _ = env.reset()
                o = o.astype(np.float32) / 255.
                o = o.transpose(2, 0, 1)
                episodic_memory.reset()
                indices = []

        # Save model
        if (epoch % save_freq == 0) or (epoch == epochs - 1):
            logger.save_state({'env': env}, None)

        # Perform PPO update!
        if epoch > 4:
            update()
            # Log info about epoch
            logger.log_tabular('Epoch', epoch)
            logger.log_tabular('EpRet', with_min_and_max=True)
            logger.log_tabular('EpLen', average_only=True)
            logger.log_tabular('VVals', with_min_and_max=True)
            logger.log_tabular('TotalEnvInteracts',
                               (epoch + 1) * steps_per_epoch)
            logger.log_tabular('LossPi', average_only=True)
            logger.log_tabular('LossV', average_only=True)
            logger.log_tabular('DeltaLossPi', average_only=True)
            logger.log_tabular('DeltaLossV', average_only=True)
            logger.log_tabular('Entropy', average_only=True)
            logger.log_tabular('KL', average_only=True)
            logger.log_tabular('ClipFrac', average_only=True)
            logger.log_tabular('StopIter', average_only=True)
            logger.log_tabular('Time', time.time() - start_time)
            logger.dump_tabular()

        else:
            buf.get()
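The exploration bonus above comes from the episodic-curiosity machinery this snippet imports (`R_Network`, `similarity_to_memory`, `EpisodicMemory`): a state receives `0.03 * (0.5 - aggregated)`, where `aggregated` is its similarity to the states stored so far this episode, and sufficiently novel states are added to the memory. A rough standalone sketch of the same idea, using plain cosine similarity as a stand-in for the learned reachability score (hypothetical toy classes, not the imported ones):

import numpy as np

class ToyEpisodicMemory:
    def __init__(self):
        self.embeddings = []

    def store_new_state(self, emb):
        self.embeddings.append(emb)
        return len(self.embeddings) - 1

    def max_similarity(self, emb):
        if not self.embeddings:
            return 0.0
        mem = np.stack(self.embeddings)
        sims = mem @ emb / (np.linalg.norm(mem, axis=1) * np.linalg.norm(emb) + 1e-8)
        return float(sims.max())

def curiosity_bonus(emb, memory, beta=0.03, tau=0.5):
    # Positive bonus for states that look unlike anything stored this episode.
    aggregated = memory.max_similarity(emb)
    bonus = beta * (tau - aggregated)
    if bonus > 0 or len(memory.embeddings) == 0:
        memory.store_new_state(emb)
    return bonus

memory = ToyEpisodicMemory()
print(curiosity_bonus(np.array([1.0, 0.0]), memory))   # novel state: bonus 0.015
print(curiosity_bonus(np.array([1.0, 0.0]), memory))   # identical state: bonus -0.015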
Ejemplo n.º 10
def trpo(env_fn,
         actor_critic,
         ac_kwargs=dict(),
         seed=0,
         steps_per_epoch=4000,
         epochs=50,
         gamma=.99,
         delta=.01,
         vf_lr=1e-3,
         train_v_iters=80,
         damping_coeff=.1,
         cg_iters=10,
         backtrack_iters=10,
         backtrack_coeff=.8,
         lam=.97,
         max_ep_len=1000,
         logger_kwargs=dict(),
         save_freq=10,
         algo="trpo"):

    # Logger tools
    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    # Seed inits
    seed += 10000 * proc_id()
    tf.set_random_seed(seed)
    np.random.seed(seed)

    # Environment recreation
    env = env_fn()

    # Getting obs dims
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]

    ac_kwargs['action_space'] = env.action_space

    # Placeholders
    x_ph = tf.placeholder(name="x_ph", shape=[None, obs_dim], dtype=tf.float32)
    a_ph = tf.placeholder(name="a_ph", shape=[None, act_dim], dtype=tf.float32)
    adv_ph = tf.placeholder(name="adv_ph", shape=[None], dtype=tf.float32)
    ret_ph = tf.placeholder(name="ret_ph", shape=[None], dtype=tf.float32)
    logp_old_ph = tf.placeholder(name="logp_old_ph", shape=[None], dtype=tf.float32)

    pi, logp, logp_pi, info, info_phs, d_kl, v = actor_critic(
        x_ph, a_ph, **ac_kwargs)

    def keys_as_sorted_list(d):
        return sorted(d.keys())

    def values_as_sorted_list(d):
        return [d[k] for k in keys_as_sorted_list(d)]

    all_phs = [x_ph, a_ph, adv_ph, ret_ph, logp_old_ph
               ] + values_as_sorted_list(info_phs)

    get_action_ops = [pi, v, logp_pi] + values_as_sorted_list(info)

    # Experience buffer init
    local_steps_per_epoch = int(steps_per_epoch / num_procs())
    info_shapes = {k: v.shape.as_list()[1:] for k, v in info_phs.items()}
    buf = GAEBuffer(obs_dim, act_dim, local_steps_per_epoch, info_shapes,
                    gamma, lam)

    # Count variables
    def get_vars(scope=''):
        return [x for x in tf.trainable_variables() if scope in x.name]

    def count_vars(scope=''):
        v = get_vars(scope)
        return sum([np.prod(var.shape.as_list()) for var in v])

    var_counts = tuple(count_vars(scope) for scope in ["pi", "v"])
    logger.log('\nNumber of parameters: \t pi: %d, \t v: %d\n' % var_counts)

    # TRPO Losses
    ratio = tf.exp(logp - logp_old_ph)
    pi_loss = -tf.reduce_mean(ratio * adv_ph)
    v_loss = tf.reduce_mean((ret_ph - v)**2)

    # Optimizer for value function
    train_vf = MpiAdamOptimizer(learning_rate=vf_lr).minimize(v_loss)

    # CG solver requirements
    pi_params = get_vars("pi")

    # Some helpers
    def flat_concat(xs):
        return tf.concat([tf.reshape(x, (-1, )) for x in xs], axis=0)

    def flat_grad(f, params):
        return flat_concat(tf.gradients(xs=params, ys=f))

    def hessian_vector_product(f, params):
        # Hessian-vector product via the double-gradient trick:
        # H @ x = grad( grad(f)^T x ), so H is never formed explicitly.
        g = flat_grad(f, params)
        x = tf.placeholder(tf.float32, shape=g.shape)
        return x, flat_grad(tf.reduce_sum(g * x), params)

    def assign_params_from_flat(x, params):
        flat_size = lambda p: int(np.prod(p.shape.as_list()))  # the 'int' is important for scalars
        splits = tf.split(x, [flat_size(p) for p in params])
        new_params = [
            tf.reshape(p_new, p.shape) for p, p_new in zip(params, splits)
        ]

        return tf.group(
            [tf.assign(p, p_new) for p, p_new in zip(params, new_params)])

    # Flat policy gradient and (damped) Hessian-vector product of the mean KL:
    # the two ingredients conjugate gradient needs to solve H x ~= g.
    gradient = flat_grad(pi_loss, pi_params)
    v_ph, hvp = hessian_vector_product(d_kl, pi_params)
    if damping_coeff > 0:
        hvp += damping_coeff * v_ph

    # Symbols for getting and setting params
    get_pi_params = flat_concat(pi_params)
    set_pi_params = assign_params_from_flat(v_ph, pi_params)

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    # Sync params across processes
    sess.run(sync_all_params())

    # Setup model saving
    logger.setup_tf_saver(sess, inputs={'x': x_ph}, outputs={'pi': pi, 'v': v})

    def cg(Ax, b):
        """Conjugate gradient: approximately solve A x = b for symmetric
        positive-definite A, given only the matrix-vector product Ax(p)."""
        x = np.zeros_like(b)
        r = b.copy()  # residual of the initial guess x = 0, i.e. r = b
        p = r.copy()
        r_dot_old = np.dot(r, r)

        for _ in range(cg_iters):
            z = Ax(p)
            alpha = r_dot_old / (np.dot(p, z) + EPS)
            x += alpha * p
            r -= alpha * z
            r_dot_new = np.dot(r, r)
            p = r + (r_dot_new / r_dot_old) * p
            r_dot_old = r_dot_new
        return x

    def update():
        # Map every placeholder to the corresponding array from the buffer
        # (obs, act, adv, ret, logp_old, plus the policy info terms).
        inputs = {k: v for k, v in zip(all_phs, buf.get())}

        def mpi_avg(x):
            """Average a scalar or vector over MPI processes."""
            return mpi_sum(x) / num_procs()

        Hx = lambda x: mpi_avg(sess.run(hvp, feed_dict={**inputs, v_ph: x}))
        g, pi_l_old, v_l_old = sess.run([gradient, pi_loss, v_loss],
                                        feed_dict=inputs)
        g, pi_l_old = mpi_avg(g), mpi_avg(pi_l_old)

        # Core calculations for TRPO or NPG
        x = cg(Hx, g)
        alpha = np.sqrt(2 * delta / (np.dot(x, Hx(x)) + EPS))  # max step along x satisfying the KL constraint
        old_params = sess.run(get_pi_params)

        def set_and_eval(step):
            sess.run(set_pi_params,
                     feed_dict={v_ph: old_params - alpha * x * step})

            return mpi_avg(sess.run([d_kl, pi_loss], feed_dict=inputs))

        if algo == 'npg':
            # npg has no backtracking or hard kl constraint enforcement
            kl, pi_l_new = set_and_eval(step=1.)
        elif algo == "trpo":
            for j in range(backtrack_iters):
                kl, pi_l_new = set_and_eval(step=backtrack_coeff**j)
                if kl <= delta and pi_l_new <= pi_l_old:
                    logger.log(
                        'Accepting new params at step %d of line search.' % j)
                    logger.store(BacktrackIters=j)
                    break

                if j == backtrack_iters - 1:
                    logger.log('Line search failed! Keeping old params.')
                    logger.store(BacktrackIters=j)
                    kl, pi_l_new = set_and_eval(step=0.)

        # Value function updates
        for _ in range(train_v_iters):
            sess.run(train_vf, feed_dict=inputs)
        v_l_new = sess.run(v_loss, feed_dict=inputs)

        # Log changes from update
        logger.store(LossPi=pi_l_old,
                     LossV=v_l_old,
                     KL=kl,
                     DeltaLossPi=(pi_l_new - pi_l_old),
                     DeltaLossV=(v_l_new - v_l_old))

    start_time = time.time()
    o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0

    # Main loop: collect experience in env and update/log each epoch
    for epoch in range(epochs):
        for t in range(local_steps_per_epoch):
            agent_outs = sess.run(get_action_ops,
                                  feed_dict={x_ph: o.reshape(1, -1)})
            a, v_t, logp_t, info_t = agent_outs[0][0], agent_outs[
                1], agent_outs[2], agent_outs[3:]

            # Save and log
            buf.store(o, a, r, v_t, logp_t, info_t)
            logger.store(VVals=v_t)

            o, r, d, _ = env.step(a)
            ep_ret += r
            ep_len += 1

            terminal = d or (ep_len == max_ep_len)
            if terminal or (t == local_steps_per_epoch - 1):
                if not terminal:
                    print('Warning: trajectory cut off by epoch at %d steps.' %
                          ep_len)

                last_val = r if d else sess.run(
                    v, feed_dict={x_ph: o.reshape(1, -1)})
                buf.finish_path(last_val)
                if terminal:
                    # only save EpRet / EpLen if trajectory finished
                    logger.store(EpRet=ep_ret, EpLen=ep_len)
                o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0

        # Save model
        if (epoch % save_freq == 0) or (epoch == epochs - 1):
            logger.save_state({'env': env}, None)

        # Perform TRPO or NPG update!
        update()

        # Log info about epoch
        logger.log_tabular('Epoch', epoch)
        logger.log_tabular('EpRet', with_min_and_max=True)
        logger.log_tabular('EpLen', average_only=True)
        logger.log_tabular('VVals', with_min_and_max=True)
        logger.log_tabular('TotalEnvInteracts', (epoch + 1) * steps_per_epoch)
        logger.log_tabular('LossPi', average_only=True)
        logger.log_tabular('LossV', average_only=True)
        logger.log_tabular('DeltaLossPi', average_only=True)
        logger.log_tabular('DeltaLossV', average_only=True)
        logger.log_tabular('KL', average_only=True)
        if algo == 'trpo':
            logger.log_tabular('BacktrackIters', average_only=True)
        logger.log_tabular('Time', time.time() - start_time)
        logger.dump_tabular()
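The TRPO update above never forms the KL Hessian explicitly: cg only ever calls the matrix-vector product Hx. Below is a self-contained sketch of that matrix-free pattern, using a small made-up symmetric positive-definite matrix in place of the damped Hessian and checking the result against a direct solve:

import numpy as np

EPS = 1e-8

def cg(Ax, b, iters=10):
    # Conjugate gradient with a matrix-free operator Ax(p) ~= A @ p.
    x = np.zeros_like(b)
    r = b.copy()
    p = r.copy()
    r_dot_old = np.dot(r, r)
    for _ in range(iters):
        z = Ax(p)
        alpha = r_dot_old / (np.dot(p, z) + EPS)
        x += alpha * p
        r -= alpha * z
        r_dot_new = np.dot(r, r)
        p = r + (r_dot_new / r_dot_old) * p
        r_dot_old = r_dot_new
    return x

# Symmetric positive-definite test matrix standing in for the damped Hessian.
A = np.array([[4.0, 1.0], [1.0, 3.0]])
b = np.array([1.0, 2.0])
x_cg = cg(lambda p: A @ p, b)
print(np.allclose(x_cg, np.linalg.solve(A, b), atol=1e-6))  # True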
Example #11
0
File: iac.py Project: zhc134/l2s
def iac(env_config, ac_type, ac_kwargs, rb_type, rb_kwargs, gamma, lr, polyak,
        batch_size, epochs, start_steps, steps_per_epoch, inc_ep, max_ep_len,
        test_max_ep_len, number_of_tests_per_epoch, q_pi_sample_size, z_dim,
        z_type, act_noise, test_without_state, logger_kwargs, seed):
    logger = EpochLogger(**logger_kwargs)
    configs = locals().copy()
    configs.pop("logger")
    logger.save_config(configs)

    tf.set_random_seed(seed)
    np.random.seed(seed)

    env, test_env = make_env(env_config), make_env(env_config)
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]

    act_high = env.action_space.high

    # Inputs to computation graph
    x_ph, a_ph, z_ph, x2_ph, r_ph, d_ph = core.placeholders(
        obs_dim, act_dim, z_dim, obs_dim, None, None)

    actor_critic = core.get_iac_actor_critic(ac_type)
    # Main outputs from computation graph
    with tf.variable_scope('main'):
        pi, q1, q2, q1_pi, q2_pi, v = actor_critic(x_ph, a_ph, z_ph,
                                                   **ac_kwargs)

    # Target networks
    with tf.variable_scope('target'):
        _, _, _, _, _, v_targ = actor_critic(x2_ph, a_ph, z_ph, **ac_kwargs)

    # Experience buffer
    RB = get_replay_buffer(rb_type)
    replay_buffer = RB(obs_dim, act_dim, **rb_kwargs)

    # Count variables
    var_counts = tuple(
        core.count_vars(scope)
        for scope in ['main/pi', 'main/q', 'main/v', 'main'])
    print(
        '\nNumber of parameters: \t pi: %d, \t q: %d, \t v: %d, \t total: %d\n'
        % var_counts)

    # Bellman backup for Q and V function
    q_backup = tf.stop_gradient(r_ph + gamma * (1 - d_ph) * v_targ)
    min_q_pi = tf.minimum(q1_pi, q2_pi)
    v_backup = tf.stop_gradient(min_q_pi)

    # Losses: deterministic policy gradient through Q1, TD3-style clipped
    # double-Q critics, and a state-value function fit to min(Q1, Q2)
    pi_loss = -tf.reduce_mean(q1_pi)
    q1_loss = 0.5 * tf.reduce_mean((q1 - q_backup)**2)
    q2_loss = 0.5 * tf.reduce_mean((q2 - q_backup)**2)
    v_loss = 0.5 * tf.reduce_mean((v - v_backup)**2)
    value_loss = q1_loss + q2_loss + v_loss

    # Separate train ops for pi, q
    policy_optimizer = tf.train.AdamOptimizer(learning_rate=lr)
    value_optimizer = tf.train.AdamOptimizer(learning_rate=lr)
    train_policy_op = policy_optimizer.minimize(pi_loss,
                                                var_list=get_vars('main/pi'))
    if ac_kwargs["pi_separate"]:
        train_policy_emb_op = policy_optimizer.minimize(
            pi_loss, var_list=get_vars('main/pi/emb'))
        train_policy_d_op = policy_optimizer.minimize(
            pi_loss, var_list=get_vars('main/pi/d'))
    train_value_op = value_optimizer.minimize(value_loss,
                                              var_list=get_vars('main/q') +
                                              get_vars('main/v'))

    # Polyak averaging for target variables
    target_update = tf.group([
        tf.assign(v_targ, polyak * v_targ + (1 - polyak) * v_main)
        for v_main, v_targ in zip(get_vars('main'), get_vars('target'))
    ])

    # Initializing targets to match main variables
    target_init = tf.group([
        tf.assign(v_targ, v_main)
        for v_main, v_targ in zip(get_vars('main'), get_vars('target'))
    ])

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    sess.run(target_init)

    def sample_z(size):
        if z_type == "uniform":
            return np.random.random_sample(size=size)
        elif z_type == "gaussian":
            return np.random.normal(size=size)
        else:
            raise ValueError("unknown z_type: %s" % z_type)

    def get_action(o, noise_scale):
        pi_a = sess.run(pi,
                        feed_dict={
                            x_ph: o.reshape(1, -1),
                            z_ph: sample_z((1, z_dim))
                        })[0]
        pi_a += noise_scale * np.random.randn(act_dim)
        pi_a = np.clip(pi_a, 0, 1)
        real_a = pi_a * act_high
        return pi_a, real_a

    def test_agent(n=10):
        test_actions = []
        for j in range(n):
            test_actions_ep = []
            o, r, d, ep_ret, ep_len = test_env.reset(), 0, False, 0, 0
            while not (d or (ep_len == test_max_ep_len)):
                # Take deterministic actions at test time (noise_scale=0)
                if test_without_state:
                    _, real_a = get_action(np.zeros(o.shape), 0)
                else:
                    _, real_a = get_action(o, 0)
                test_actions_ep.append(real_a)
                o, r, d, _ = test_env.step(real_a)
                ep_ret += r
                ep_len += 1
            logger.store(TestEpRet=ep_ret, TestEpLen=ep_len)
            test_actions.append(test_actions_ep)
        return test_actions

    start_time = time.time()
    o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
    total_steps = steps_per_epoch * epochs

    rewards = []
    rets = []
    test_rets = []
    max_ret = None
    # Main loop: collect experience in env and update/log each epoch
    for t in range(total_steps):
        """
        Until start_steps have elapsed, randomly sample actions
        from a uniform distribution for better exploration. Afterwards, 
        use the learned policy (with some noise, via act_noise). 
        """
        if t > start_steps:
            pi_a, real_a = get_action(o, act_noise)
        else:
            # env.action_space.sample() returns a single raw action; recover the
            # normalized action the buffer expects (real_a = pi_a * act_high).
            real_a = env.action_space.sample()
            pi_a = real_a / act_high

        # Step the env
        o2, r, d, _ = env.step(real_a)
        ep_ret += r
        ep_len += 1

        # Ignore the "done" signal if it comes from hitting the time
        # horizon (that is, when it's an artificial terminal signal
        # that isn't based on the agent's state)
        d = False if ep_len == max_ep_len else d

        # Store experience to replay buffer
        replay_buffer.store(o, pi_a, r, o2, d)

        # Super critical, easy to overlook step: make sure to update
        # most recent observation!
        o = o2

        if d or (ep_len == max_ep_len):

            for _ in range(ep_len):
                batch = replay_buffer.sample_batch(batch_size)
                feed_dict = {
                    x_ph: batch['obs1'],
                    x2_ph: batch['obs2'],
                    a_ph: batch['acts'],
                    r_ph: batch['rews'],
                    d_ph: batch['done']
                }
                feed_dict[z_ph] = sample_z((batch_size, z_dim))

                # Policy update: tile the batch q_pi_sample_size times and draw
                # a fresh z for every copy before the policy-gradient step
                for key in feed_dict:
                    feed_dict[key] = np.repeat(feed_dict[key],
                                               q_pi_sample_size,
                                               axis=0)
                feed_dict[z_ph] = sample_z(
                    (batch_size * q_pi_sample_size, z_dim))
                if ac_kwargs["pi_separate"]:
                    if len(rewards) % 2 == 0:
                        outs = sess.run([pi_loss, train_policy_emb_op],
                                        feed_dict)
                    else:
                        outs = sess.run([pi_loss, train_policy_d_op],
                                        feed_dict)
                else:
                    outs = sess.run([pi_loss, train_policy_op], feed_dict)
                logger.store(LossPi=outs[0])

                # Q-learning update
                outs = sess.run([q1_loss, v_loss, q1, v, train_value_op],
                                feed_dict)
                logger.store(LossQ=outs[0],
                             LossV=outs[1],
                             ValueQ=outs[2],
                             ValueV=outs[3])

            logger.store(EpRet=ep_ret, EpLen=ep_len)
            rewards.append(ep_ret)
            o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0

        # End of epoch wrap-up
        if (t + 1) % steps_per_epoch == 0:
            epoch = (t + 1) // steps_per_epoch

            # Test the performance of the deterministic version of the agent.
            test_actions = test_agent(number_of_tests_per_epoch)

            # Log info about epoch
            logger.log_tabular('Epoch', epoch)
            ret = logger.log_tabular('EpRet', average_only=True)[0]
            test_ret = logger.log_tabular('TestEpRet', average_only=True)[0]
            logger.log_tabular('EpLen', average_only=True)
            logger.log_tabular('TestEpLen', average_only=True)
            logger.log_tabular('LossPi', average_only=True)
            logger.log_tabular('LossQ', average_only=True)
            logger.log_tabular('LossV', average_only=True)
            logger.log_tabular('ValueQ', average_only=True)
            logger.log_tabular('ValueV', average_only=True)
            logger.log_tabular('Time', time.time() - start_time)
            logger.dump_tabular()

            rets.append(ret)
            test_rets.append(test_ret)

            if max_ret is None or test_ret > max_ret:
                max_ret = test_ret
                best_test_actions = test_actions

            max_ep_len += inc_ep
            # Polyak-average target network weights; the assign ops have no
            # placeholder dependencies, so no feed_dict is required.
            sess.run(target_update)

    logger.save_state(
        {
            "rewards": rewards,
            "best_test_actions": best_test_actions,
            "rets": rets,
            "test_rets": test_rets,
            "max_ret": max_ret
        }, None)

    util.plot_actions(best_test_actions, act_high,
                      logger.output_dir + '/best_test_actions.png')
    logger.log("max ret: %f" % max_ret)