Example #1
import numpy as np
import tensorflow as tf  # TF 1.x API (tf.Session, tf.global_variables_initializer)


def main():
    with tf.Session() as sess:

        actor = ActorNetwork(sess, STATE_DIM, ACTION_DIM, ACTION_BOUND,
                             ACTOR_LEARNING_RATE, TAU, MINIBATCH_SIZE)
        critic = CriticNetwork(sess, STATE_DIM, ACTION_DIM,
                               CRITIC_LEARNING_RATE, TAU,
                               actor.get_num_trainable_vars())

        #actor_noise = OrnsteinUhlenbeckActionNoise(mu=np.zeros(ACTION_DIM))

        #TODO: Ornstein-Uhlenbeck noise.

        sess.run(tf.global_variables_initializer())

        # Initialize the actor and critic target networks.
        actor.update_target_network()
        critic.update_target_network()

        # initialize replay memory
        replay_buffer = ReplayBuffer(BUFFER_SIZE)

        # main loop.
        for ep in range(MAX_EPISODES):

            episode_reward = 0
            ep_batch_avg_q = 0

            s = ENV.reset()

            for step in range(MAX_EP_STEPS):

                a = actor.predict(np.reshape(s, (1, STATE_DIM)))  # + actor_noise()
                s2, r, terminal, info = ENV.step(a[0])

                replay_buffer.add(np.reshape(s, (STATE_DIM,)),
                                  np.reshape(a, (ACTION_DIM,)),
                                  r,
                                  terminal,
                                  np.reshape(s2, (STATE_DIM,)))

                # Train on a random minibatch every TRAIN_INTERVAL steps once the buffer has enough samples.
                if replay_buffer.size() > MINIBATCH_SIZE and \
                    step % TRAIN_INTERVAL == 0:
                    s_batch, a_batch, r_batch, t_batch, s2_batch = \
                        replay_buffer.sample_batch(MINIBATCH_SIZE)

                    # Compute the target Q values with the target actor and target critic.
                    target_action = actor.predict_target(s2_batch)
                    target_q = critic.predict_target(s2_batch, target_action)

                    # Compute the critic's training targets: r for terminal transitions, r + GAMMA * Q' otherwise.
                    targets = []
                    for i in range(MINIBATCH_SIZE):
                        if t_batch[i]:
                            # terminal
                            targets.append(r_batch[i])
                        else:
                            targets.append(r_batch[i] + GAMMA * target_q[i])

                    # Train the critic.
                    #TODO: pred_q comes from a random minibatch, not from this episode,
                    # so an episode-level statistic such as episode_avg_max is misleading.
                    pred_q, _ = critic.train(
                        s_batch, a_batch,
                        np.reshape(targets, (MINIBATCH_SIZE, 1)))

                    # Train the actor using the critic's action gradients (deterministic policy gradient).
                    a_outs = actor.predict(s_batch)
                    grads = critic.action_gradients(s_batch, a_outs)
                    actor.train(s_batch, grads[0])

                    # Update target networks.
                    # Should this only be done once every few batches?
                    actor.update_target_network()
                    critic.update_target_network()

                    ep_batch_avg_q += np.mean(pred_q)

                s = s2
                episode_reward += r

                if terminal:
                    print('Episode:', ep, 'Reward:', episode_reward)
                    reward_log.append(episode_reward)
                    # max() guards against dividing by zero when the episode ends on step 0.
                    q_log.append(ep_batch_avg_q / max(step, 1))

                    break
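
A note on exploration: Example #1 leaves the Ornstein-Uhlenbeck noise commented out (see the TODO near the top of main()). A minimal sketch of such a noise process is given below; the class name matches the commented-out call, but the sigma/theta/dt defaults and the Euler-Maruyama update are assumptions modeled on the standard DDPG setup, not part of the original snippet.

import numpy as np


class OrnsteinUhlenbeckActionNoise:
    """Temporally correlated exploration noise for continuous actions (sketch).

    Discretized OU process: x_{t+1} = x_t + theta*(mu - x_t)*dt + sigma*sqrt(dt)*N(0, 1).
    """

    def __init__(self, mu, sigma=0.2, theta=0.15, dt=1e-2, x0=None):
        self.mu = mu
        self.sigma = sigma
        self.theta = theta
        self.dt = dt
        self.x0 = x0
        self.reset()

    def __call__(self):
        # One Euler-Maruyama step of the OU process.
        x = (self.x_prev
             + self.theta * (self.mu - self.x_prev) * self.dt
             + self.sigma * np.sqrt(self.dt) * np.random.normal(size=self.mu.shape))
        self.x_prev = x
        return x

    def reset(self):
        # Restart from x0 (or zero) at the beginning of an episode.
        self.x_prev = self.x0 if self.x0 is not None else np.zeros_like(self.mu)

With such a class available, the commented-out actor_noise line and the "+ actor_noise()" term in the action computation could be re-enabled, which is what Example #2 does.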
Example #2
        for step in range(MAX_EP_STEPS):

            a = actor.predict(np.reshape(s, (1, STATE_DIM))) + actor_noise()
            s2, r, terminal, info = ENV.step(a[0])

            replay_buffer.add(np.reshape(s, (STATE_DIM,)),
                              np.reshape(a, (ACTION_DIM,)),
                              r,
                              terminal,
                              np.reshape(s2, (STATE_DIM,)))

            # Train on a random minibatch once the buffer holds more than MINIBATCH_SIZE transitions.
            if replay_buffer.size() > MINIBATCH_SIZE:
                s_batch, a_batch, r_batch, t_batch, s2_batch = \
                    replay_buffer.sample_batch(MINIBATCH_SIZE)

                # Compute the target Q values with the target actor and target critic.
                target_action = actor.predict_target(s2_batch)
                target_q = critic.predict_target(s2_batch, target_action)

                # Compute the critic's training targets: r for terminal transitions, r + GAMMA * Q' otherwise.
                targets = []
                for i in range(MINIBATCH_SIZE):
                    if t_batch[i]:
                        # terminal
                        targets.append(r_batch[i])
                    else:
                        targets.append(r_batch[i] + GAMMA * target_q[i])

                # Train the critic.
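
Both examples assume a ReplayBuffer with add(), size() and sample_batch() methods. A minimal sketch consistent with how those calls are used above might look as follows; the deque storage and uniform sampling are assumptions, not the original implementation.

import random
from collections import deque

import numpy as np


class ReplayBuffer:
    def __init__(self, buffer_size):
        # Oldest transitions are discarded once the buffer is full.
        self.buffer = deque(maxlen=buffer_size)

    def add(self, s, a, r, t, s2):
        self.buffer.append((s, a, r, t, s2))

    def size(self):
        return len(self.buffer)

    def sample_batch(self, batch_size):
        # Uniformly sample a minibatch and split it into per-field arrays.
        batch = random.sample(self.buffer, batch_size)
        s, a, r, t, s2 = map(np.array, zip(*batch))
        return s, a, r, t, s2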