Example #1
class Learner(tf.Module):
    def __init__(self, logger, replay_buffer):
        super(Learner, self).__init__(name="Learner")
        self.device = Params.DEVICE

        with tf.device(self.device), self.name_scope:
            self.dtype = Params.DTYPE
            self.logger = logger
            self.batch_size = Params.MINIBATCH_SIZE
            self.gamma = Params.GAMMA
            self.tau = Params.TAU
            self.replay_buffer = replay_buffer
            self.priority_beta = tf.Variable(Params.BUFFER_PRIORITY_BETA_START)
            self.running = tf.Variable(True)
            self.n_steps = tf.Variable(0)

            # Init Networks
            self.actor = ActorNetwork(with_target_net=True)
            self.critic = CriticNetwork()

            # Save shared variables
            self.policy_variables = self.actor.tvariables + self.actor.nvariables

    def save_model(self, path):
        self.actor.actor_network.save(path,
                                      overwrite=True,
                                      include_optimizer=False,
                                      save_format="tf")

    @tf.function()
    def run(self):
        print("retracing learner run")

        with tf.device(self.device), self.name_scope:

            # Wait for the replay buffer to fill (skipped when sampling from Reverb)
            tf.cond(
                Params.BUFFER_FROM_REVERB,
                lambda: [],
                lambda: tf.while_loop(
                    lambda: tf.logical_and(
                        tf.less(self.replay_buffer.size(), Params.MINIBATCH_SIZE),
                        self.running),
                    lambda: [],
                    loop_vars=[],
                    parallel_iterations=1,
                ),
            )

            # Do training
            n_steps = tf.while_loop(
                lambda n_step: tf.logical_and(
                    tf.less_equal(n_step, Params.MAX_STEPS_TRAIN), self.running),
                self.train_step,
                loop_vars=[tf.constant(1)])

            # Save number of performed steps
            self.n_steps.assign(tf.squeeze(n_steps))

    def train_step(self, n_step):
        print("retracing train_step")

        with tf.device(self.device), self.name_scope:

            print("Eager Execution:  ", tf.executing_eagerly())
            print("Eager Keras Model:", self.actor.actor_network.run_eagerly)

            if Params.BUFFER_FROM_REVERB:

                # Get batch from replay memory
                (s_batch, a_batch, _, _, s2_batch, target_z_atoms_batch), weights_batch, idxes_batch = \
                    self.replay_buffer.sample_batch(self.batch_size, self.priority_beta)

            else:

                # Get batch from replay memory
                (s_batch, a_batch, r_batch, t_batch, s2_batch, g_batch), weights_batch, idxes_batch = \
                    self.replay_buffer.sample_batch(self.batch_size, self.priority_beta)

                # Compute targets (bellman update)
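                # Distributional Bellman backup: the atom values are zeroed for
                # terminal transitions, then shifted by the reward and scaled by
                # the discount. g_batch presumably holds the per-transition
                # discount factor (e.g. gamma**n when n-step returns are stored).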
                target_z_atoms_batch = tf.where(t_batch, 0., Params.Z_ATOMS)
                target_z_atoms_batch = r_batch + (target_z_atoms_batch * g_batch)

            # Predict the target Q-value distribution with the target critic network
            target_action_batch = self.actor.target_actor_network(
                s2_batch, training=False)
            target_q_probs = tf.cast(
                self.critic.target_critic_network(
                    [s2_batch, target_action_batch], training=False),
                self.dtype)

            # Train the critic on given targets
            td_error_batch = self.critic.train(
                x=[s_batch, a_batch],
                target_z_atoms=target_z_atoms_batch,
                target_q_probs=target_q_probs,
                is_weights=weights_batch)

            # Compute actions for state batch
            actions = self.actor.actor_network(s_batch, training=False)

            # Compute action values for the predicted actions (their gradients are negated below to enable gradient ascent)
            values = self.critic.critic_network([s_batch, actions],
                                                training=False)

            # Chain rule (dq/da * da/dtheta = dq/dtheta): gradients of the action value w.r.t. the actor network weights
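            # Passing Params.Z_ATOMS as grad_ys weights the gradients of the
            # per-atom output probabilities by the atom values, so the first call
            # backpropagates the expected Q-value sum_i(z_i * p_i) through the
            # actions (assuming the critic outputs per-atom probabilities).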
            action_gradients = tf.gradients(values, actions, Params.Z_ATOMS)[0]
            actor_gradients = tf.gradients(actions, self.actor.tvariables,
                                           -action_gradients)

            # Average the (summed) gradients over the batch
            actor_gradients = [
                tf.divide(gradient, tf.cast(self.batch_size, self.dtype))
                for gradient in actor_gradients
            ]

            # Apply gradients to actor net
            self.actor.actor_network.optimizer.apply_gradients(
                zip(actor_gradients, self.actor.tvariables))

            # Soft-update the target networks (see the update_weights sketch after this example)
            update_weights(
                self.actor.target_tvariables + self.actor.target_nvariables,
                self.actor.tvariables + self.actor.nvariables, self.tau)
            update_weights(
                self.critic.target_tvariables + self.critic.target_nvariables,
                self.critic.tvariables + self.critic.nvariables, self.tau)

            # Use critic TD error to update priorities
            self.replay_buffer.update_priorities(idxes_batch, td_error_batch)

            # Anneal the prioritized-replay importance-sampling beta
            self.priority_beta.assign_add(
                Params.BUFFER_PRIORITY_BETA_INCREMENT)

            # Log status
            tf.cond(
                tf.equal(
                    tf.math.mod(n_step, tf.constant(Params.LEARNER_LOG_STEPS)),
                    tf.constant(0)),
                lambda: self.logger.log_step_learner(
                    n_step,
                    tf.cast(tf.reduce_mean(td_error_batch), Params.DTYPE),
                    self.priority_beta),
                lambda: None)

            return tf.add(n_step, 1)
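
The update_weights() helper called above is not shown in this example. A minimal sketch of what it presumably does, namely a Polyak (soft) update of the target variables with rate tau, assuming both arguments are flat lists of tf.Variable in matching (target, source) order:

import tensorflow as tf

def update_weights(target_variables, source_variables, tau):
    # Soft update: target <- tau * source + (1 - tau) * target, per variable
    for target_var, source_var in zip(target_variables, source_variables):
        target_var.assign(tau * source_var + (1.0 - tau) * target_var)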