Example 1
    def _step(self, time_step, policy_state):
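        """Perform a single interaction step: reset state at episode starts, run
        the algorithm, sample an action, step the environment, and notify any
        registered observers."""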
        policy_state = common.reset_state_if_necessary(policy_state,
                                                       self._initial_state,
                                                       time_step.is_first())
        policy_step = common.algorithm_step(
            self._algorithm,
            self._observation_transformer,
            time_step,
            state=policy_state,
            training=self._training,
            greedy_predict=self._greedy_predict)
        action = common.sample_action_distribution(policy_step.action)
        next_time_step = self._env_step(action)
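        # Trajectory observers receive the trajectory built from this transition;
        # experience observers additionally receive the action distribution
        # parameters and, if enabled, the rollout policy state.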
        if self._observers:
            traj = from_transition(time_step,
                                   policy_step._replace(action=action),
                                   next_time_step)
            for observer in self._observers:
                observer(traj)
        if self._exp_observers:
            action_distribution_param = common.get_distribution_params(
                policy_step.action)
            exp = make_experience(
                time_step,
                policy_step._replace(action=action),
                action_distribution=action_distribution_param,
                state=policy_state if self._use_rollout_state else ())
            for observer in self._exp_observers:
                observer(exp)

        return next_time_step, policy_step, action
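
The `common.sample_action_distribution` call above (and in the remaining examples) turns the algorithm's output into a concrete action. Below is a minimal sketch of such a helper, assuming `policy_step.action` is a (possibly nested) structure of `tfp` distributions and/or plain action tensors; the actual ALF helper may differ in detail:

    import tensorflow as tf
    import tensorflow_probability as tfp

    def sample_action_distribution(actions_or_distributions, seed=None):
        """Sample every distribution in a nested structure; pass tensors through."""
        seed_stream = tfp.util.SeedStream(seed=seed, salt='sample_action')

        def _sample(d):
            if isinstance(d, tfp.distributions.Distribution):
                return d.sample(seed=seed_stream())
            return d  # already a concrete (deterministic) action tensor

        return tf.nest.map_structure(_sample, actions_or_distributions)
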
Example 2
    def rollout(self, time_step: ActionTimeStep, state: ActorCriticState,
                mode):
        """Rollout for one step."""
        value, value_state = self._value_network(
            time_step.observation,
            step_type=time_step.step_type,
            network_state=state.value)

        action_distribution, actor_state = self._actor_network(
            time_step.observation,
            step_type=time_step.step_type,
            network_state=state.actor)

        action = common.sample_action_distribution(action_distribution)
        return PolicyStep(
            action=action,
            state=ActorCriticState(actor=actor_state, value=value_state),
            info=ActorCriticInfo(
                value=value, action_distribution=action_distribution))
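
`ActorCriticState` and `ActorCriticInfo` are small containers whose fields can be read directly off the code above. A sketch using plain namedtuples; ALF defines them with its own namedtuple helper, so the exact definitions may differ:

    import collections

    # Recurrent network state carried between rollout calls, one slot per sub-network.
    ActorCriticState = collections.namedtuple('ActorCriticState', ['actor', 'value'])

    # Per-step side information kept for training: the value estimate and the
    # distribution the action was sampled from.
    ActorCriticInfo = collections.namedtuple('ActorCriticInfo',
                                             ['value', 'action_distribution'])
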
Example 3
    def _iter(self, time_step, policy_state):
        """One training iteration."""
        counter = tf.zeros((), tf.int32)
        batch_size = self._env.batch_size

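        # One TensorArray per field of the training-info spec, each sized to hold
        # train_interval unrolled steps plus the final step written after the loop.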
        def create_ta(s):
            return tf.TensorArray(dtype=s.dtype,
                                  size=self._train_interval + 1,
                                  element_shape=tf.TensorShape(
                                      [batch_size]).concatenate(s.shape))

        training_info_ta = tf.nest.map_structure(create_ta,
                                                 self._training_info_spec)

        with tf.GradientTape(watch_accessed_variables=False,
                             persistent=True) as tape:
            tape.watch(self._trainable_variables)
            [counter, time_step, policy_state, training_info_ta
             ] = tf.while_loop(cond=lambda *_: True,
                               body=self._train_loop_body,
                               loop_vars=[
                                   counter, time_step, policy_state,
                                   training_info_ta
                               ],
                               back_prop=True,
                               parallel_iterations=1,
                               maximum_iterations=self._train_interval,
                               name='iter_loop')

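        # Produce the final entry of the training info: either take one more
        # environment step (FINAL_STEP_SKIP) or evaluate the policy on the current
        # time_step without stepping the environment.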
        if self._final_step_mode == OnPolicyDriver.FINAL_STEP_SKIP:
            next_time_step, policy_step, action = self._step(
                time_step, policy_state)
            next_state = policy_step.state
        else:
            policy_step = common.algorithm_step(self._algorithm.rollout,
                                                self._observation_transformer,
                                                time_step, policy_state)
            action = common.sample_action_distribution(policy_step.action)
            next_time_step = time_step
            next_state = policy_state

        action_distribution_param = common.get_distribution_params(
            policy_step.action)

        final_training_info = make_training_info(
            action_distribution=action_distribution_param,
            action=action,
            reward=time_step.reward,
            discount=time_step.discount,
            step_type=time_step.step_type,
            info=policy_step.info)

        with tape:
            training_info_ta = tf.nest.map_structure(
                lambda ta, x: ta.write(counter, x), training_info_ta,
                final_training_info)
            training_info = tf.nest.map_structure(lambda ta: ta.stack(),
                                                  training_info_ta)

            action_distribution = nested_distributions_from_specs(
                self._algorithm.action_distribution_spec,
                training_info.action_distribution)

            training_info = training_info._replace(
                action_distribution=action_distribution)

        loss_info, grads_and_vars = self._algorithm.train_complete(
            tape, training_info)

        del tape

        self._training_summary(training_info, loss_info, grads_and_vars)

        self._train_step_counter.assign_add(1)

        return next_time_step, next_state
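
The core of `_iter` is a standard TF pattern: unroll a fixed number of steps with `tf.while_loop`, write each step's tensors into `tf.TensorArray`s while a persistent `GradientTape` is recording, then stack them into time-major tensors for the gradient computation in `train_complete`. Below is a stripped-down, self-contained illustration of that pattern; every name in it is illustrative rather than an ALF API:

    import tensorflow as tf

    TRAIN_INTERVAL = 4
    BATCH_SIZE = 3

    w = tf.Variable(tf.ones([BATCH_SIZE]))      # stand-in for the trainable variables
    x = tf.zeros([BATCH_SIZE])                  # stand-in for the rolling time_step

    def loop_body(counter, x, ta):
        y = w * x + 1.0                         # stand-in for one rollout step
        ta = ta.write(counter, y)               # record this step's tensors
        return counter + 1, y, ta

    ta = tf.TensorArray(tf.float32, size=TRAIN_INTERVAL,
                        element_shape=tf.TensorShape([BATCH_SIZE]))

    with tf.GradientTape(watch_accessed_variables=False, persistent=True) as tape:
        tape.watch([w])
        counter, x, ta = tf.while_loop(
            cond=lambda *_: True,
            body=loop_body,
            loop_vars=[tf.zeros((), tf.int32), x, ta],
            maximum_iterations=TRAIN_INTERVAL,
            parallel_iterations=1)
        stacked = ta.stack()                    # [TRAIN_INTERVAL, BATCH_SIZE], time-major
        loss = tf.reduce_sum(stacked)

    grads = tape.gradient(loss, [w])            # gradients flow back through the loop
    del tape
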
Example 4
    def _action(self, time_step, policy_state=(), seed=None):
        """Compute one policy step and replace its distribution with a sampled action."""
        policy_step = self._action_fn1(time_step, policy_state)
        action = common.sample_action_distribution(policy_step.action, seed)
        policy_step = policy_step._replace(action=action)
        return policy_step