Example #1
        def _train_loop_body(counter, policy_state, training_info_ta):
            exp = tf.nest.map_structure(lambda ta: ta.read(counter),
                                        experience_ta)
            collect_action_distribution_param = exp.action_distribution
            collect_action_distribution = nested_distributions_from_specs(
                self._action_distribution_spec,
                collect_action_distribution_param)
            exp = exp._replace(action_distribution=collect_action_distribution)

            policy_state = common.reset_state_if_necessary(
                policy_state, initial_train_state,
                tf.equal(exp.step_type, StepType.FIRST))

            policy_step = common.algorithm_step(self.train_step, exp,
                                                policy_state)

            action_dist_param = common.get_distribution_params(
                policy_step.action)

            training_info = TrainingInfo(action_distribution=action_dist_param,
                                         info=policy_step.info)

            training_info_ta = tf.nest.map_structure(
                lambda ta, x: ta.write(counter, x), training_info_ta,
                training_info)

            counter += 1

            return [counter, policy_step.state, training_info_ta]
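A loop body with this `[counter, policy_state, training_info_ta]` signature is normally handed to `tf.while_loop`. The sketch below shows such a driver; it is an assumption about the surrounding (unshown) code, not part of the source, and taking the step count from the `step_type` TensorArray is likewise an assumption.

        # Hedged sketch (not from the source): driving the loop body above with
        # tf.while_loop. `experience_ta`, `initial_train_state` and
        # `training_info_ta` are the values created by the surrounding code.
        num_steps = experience_ta.step_type.size()
        counter = tf.zeros((), dtype=tf.int32)
        counter, policy_state, training_info_ta = tf.while_loop(
            cond=lambda counter, *_: tf.less(counter, num_steps),
            body=_train_loop_body,
            loop_vars=[counter, initial_train_state, training_info_ta],
            name="train_loop")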
Example #2
    def _step(self, time_step, policy_state):
        policy_state = common.reset_state_if_necessary(policy_state,
                                                       self._initial_state,
                                                       time_step.is_first())
        policy_step = common.algorithm_step(
            self._algorithm,
            self._observation_transformer,
            time_step,
            state=policy_state,
            training=self._training,
            greedy_predict=self._greedy_predict)
        action = common.sample_action_distribution(policy_step.action)
        next_time_step = self._env_step(action)
        if self._observers:
            traj = from_transition(time_step,
                                   policy_step._replace(action=action),
                                   next_time_step)
            for observer in self._observers:
                observer(traj)
        if self._exp_observers:
            action_distribution_param = common.get_distribution_params(
                policy_step.action)
            exp = make_experience(
                time_step,
                policy_step._replace(action=action),
                action_distribution=action_distribution_param,
                state=policy_state if self._use_rollout_state else ())
            for observer in self._exp_observers:
                observer(exp)

        return next_time_step, policy_step, action
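Every example on this page revolves around `common.reset_state_if_necessary(state, initial_state, reset_mask)`. Judging from the call sites, it returns the `initial_state` entries wherever `reset_mask` is true and keeps `state` otherwise. The sketch below is inferred from that usage and is not the actual ALF implementation:

import tensorflow as tf

def reset_state_if_necessary(state, initial_state, reset_mask):
    """Pick `initial_state` where `reset_mask` is True, `state` elsewhere.

    `state` and `initial_state` are nests of [B, ...] tensors; `reset_mask`
    is a bool tensor of shape [B].
    """
    def _select(s, i):
        # Reshape the mask to [B, 1, ..., 1] so it broadcasts against s and i.
        mask = tf.reshape(reset_mask, [-1] + [1] * (len(s.shape) - 1))
        return tf.where(mask, i, s)

    return tf.nest.map_structure(_select, state, initial_state)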
Example #3
    def _step(self, time_step, policy_state):
        policy_state = common.reset_state_if_necessary(policy_state,
                                                       self._initial_state,
                                                       time_step.is_first())
        if self._mode == self.PREDICT:
            step_func = functools.partial(self._algorithm.predict,
                                          epsilon_greedy=self._epsilon_greedy)
        elif self._mode == self.ON_POLICY_TRAINING:
            step_func = functools.partial(self._algorithm.rollout,
                                          mode=RLAlgorithm.ON_POLICY_TRAINING)
        elif self._mode == self.OFF_POLICY_TRAINING:
            step_func = functools.partial(self._algorithm.rollout,
                                          mode=RLAlgorithm.ROLLOUT)
        else:
            raise ValueError()
        transformed_time_step = self._algorithm.transform_timestep(time_step)
        policy_step = step_func(transformed_time_step, policy_state)

        next_time_step = self._env_step(policy_step.action)
        if self._observers:
            traj = from_transition(time_step, policy_step._replace(info=()),
                                   next_time_step)
            for observer in self._observers:
                observer(traj)
        if self._algorithm.exp_observers and self._training:
            policy_step = nest_utils.distributions_to_params(policy_step)
            exp = make_experience(time_step, policy_step, policy_state)
            self._algorithm.observe(exp)

        return next_time_step, policy_step, transformed_time_step
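For readers unfamiliar with the dispatch above: `functools.partial` pre-binds the mode-specific keyword argument so that every branch produces a `step_func` with the same `(time_step, state)` call signature. A tiny standalone illustration (the `predict` function here is a hypothetical stand-in, not ALF code):

import functools

def predict(time_step, state, epsilon_greedy=0.0):
    # Stand-in for self._algorithm.predict; just echoes its inputs.
    return time_step, state, epsilon_greedy

step_func = functools.partial(predict, epsilon_greedy=0.1)
print(step_func("time_step", "state"))  # ('time_step', 'state', 0.1)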
Example #4
    def rollout_step(self, time_step: TimeStep, state: SarsaState):
        if self._on_policy:
            return self._train_step(time_step, state)

        if not self._is_rnn:
            critic_states = state.critics
        else:
            _, critic_states = self._critic_networks(
                (state.prev_observation, time_step.prev_action), state.critics)

            not_first_step = time_step.step_type != StepType.FIRST

            critic_states = common.reset_state_if_necessary(
                state.critics, critic_states, not_first_step)

        action_distribution, action, actor_state, noise_state = self._get_action(
            self._rollout_actor_network, time_step, state)

        if not self._is_rnn:
            target_critic_states = state.target_critics
        else:
            _, target_critic_states = self._target_critic_networks(
                (time_step.observation, action), state.target_critics)

        info = SarsaInfo(action_distribution=action_distribution)

        rl_state = SarsaState(noise=noise_state,
                              prev_observation=time_step.observation,
                              prev_step_type=time_step.step_type,
                              actor=actor_state,
                              critics=critic_states,
                              target_critics=target_critic_states)

        return AlgStep(action, rl_state, info)
Example #5
    def _calc_change(self, exp_array):
        """Calculate the distance between old/new action distributions.

        The distance is:
        ||logits_1 - logits_2||^2 for Categorical distribution
        ||loc_1 - loc_2||^2 for Deterministic distribution
        KL(d1||d2) + KL(d2||d1) for others
        """

        def _dist(d1, d2):
            if isinstance(d1, tfp.distributions.Categorical):
                return tf.reduce_sum(tf.square(d1.logits - d2.logits), axis=-1)
            elif isinstance(d1, tfp.distributions.Deterministic):
                return tf.reduce_sum(tf.square(d1.loc - d2.loc), axis=-1)
            else:
                if isinstance(d1, SquashToSpecNormal):
                    # TODO: `SquashToSpecNormal.kl_divergence` checks that the
                    # two distributions have the same action mean and magnitude,
                    # but this check fails in graph mode.
                    d1 = d1.input_distribution
                    d2 = d2.input_distribution
                return tf.reduce_sum(
                    d1.kl_divergence(d2) + d2.kl_divergence(d1), axis=-1)

        def _update_total_dists(new_action, exp, total_dists):
            old_action = nested_distributions_from_specs(
                common.to_distribution_spec(self.action_distribution_spec),
                exp.action_param)
            dists = nest_map(_dist, old_action, new_action)
            valid_masks = tf.cast(
                tf.not_equal(exp.step_type, StepType.LAST), tf.float32)
            dists = nest_map(lambda kl: tf.reduce_sum(kl * valid_masks), dists)
            return nest_map(lambda x, y: x + y, total_dists, dists)

        num_steps = exp_array.step_type.size()
        # element_shape for `TensorArray` can be (None, ...)
        batch_size = tf.shape(exp_array.step_type.read(0))[0]
        state = tf.nest.map_structure(lambda x: x.read(0), exp_array.state)
        # exp_array.state is no longer needed
        exp_array = exp_array._replace(state=())
        initial_state = common.zero_tensor_from_nested_spec(
            self.predict_state_spec, batch_size)
        total_dists = nest_map(lambda _: tf.zeros(()), self.action_spec)
        for t in tf.range(num_steps):
            exp = tf.nest.map_structure(lambda x: x.read(t), exp_array)
            state = common.reset_state_if_necessary(
                state, initial_state, exp.step_type == StepType.FIRST)
            time_step = ActionTimeStep(
                observation=exp.observation, step_type=exp.step_type)
            policy_step = self._ac_algorithm.predict(
                time_step=time_step, state=state)
            new_action, state = policy_step.action, policy_step.state
            new_action = common.to_distribution(new_action)
            total_dists = _update_total_dists(new_action, exp, total_dists)

        size = tf.cast(num_steps * batch_size, tf.float32)
        total_dists = nest_map(lambda d: d / size, total_dists)
        return total_dists
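As a quick numeric check of the symmetric KL term that `_dist` reduces over the action dimensions (illustration only, not from the source): for two unit-variance Normals, KL(d1||d2) = (mu1 - mu2)^2 / 2, so the symmetric sum below is 0.25.

import tensorflow_probability as tfp

d1 = tfp.distributions.Normal(loc=0.0, scale=1.0)
d2 = tfp.distributions.Normal(loc=0.5, scale=1.0)
# The same quantity _dist() sums over the last axis for non-Categorical cases.
sym_kl = d1.kl_divergence(d2) + d2.kl_divergence(d1)
print(float(sym_kl))  # 0.25 (run eagerly)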
Example #6
    def _step(self, time_step, policy_state):
        policy_state = common.reset_state_if_necessary(
            policy_state, self._initial_policy_state, time_step.is_first())
        self._actor_q.enqueue([time_step, policy_state, self._id])
        policy_step, act_dist_param = self._action_return_q.dequeue()
        action = policy_step.action
        next_time_step = make_action_time_step(self._env.step(action), action)
        # temporarily store the transition into a local queue
        self._unroll_queue.enqueue(
            [time_step, policy_step, act_dist_param, next_time_step])
        return next_time_step, policy_step.state
Example #7
def _step(algorithm, env, time_step, policy_state, trans_state, epsilon_greedy,
          metrics):
    policy_state = common.reset_state_if_necessary(
        policy_state, algorithm.get_initial_predict_state(env.batch_size),
        time_step.is_first())
    transformed_time_step, trans_state = algorithm.transform_timestep(
        time_step, trans_state)
    policy_step = algorithm.predict_step(transformed_time_step, policy_state,
                                         epsilon_greedy)
    next_time_step = env.step(policy_step.output)
    for metric in metrics:
        metric(time_step.cpu())
    return next_time_step, policy_step, trans_state
Example #8
    def _train_step(self, time_step: TimeStep, state: SarsaState):
        not_first_step = time_step.step_type != StepType.FIRST
        prev_critics, critic_states = self._critic_networks(
            (state.prev_observation, time_step.prev_action), state.critics)

        critic_states = common.reset_state_if_necessary(
            state.critics, critic_states, not_first_step)

        action_distribution, action, actor_state, noise_state = self._get_action(
            self._actor_network, time_step, state)

        critics, _ = self._critic_networks((time_step.observation, action),
                                           critic_states)
        critic = critics.min(dim=1)[0]
        dqda = nest_utils.grad(action, critic.sum())

        def actor_loss_fn(dqda, action):
            if self._dqda_clipping:
                dqda = dqda.clamp(-self._dqda_clipping, self._dqda_clipping)
            loss = 0.5 * losses.element_wise_squared_loss(
                (dqda + action).detach(), action)
            loss = loss.sum(list(range(1, loss.ndim)))
            return loss

        actor_loss = nest_map(actor_loss_fn, dqda, action)
        actor_loss = math_ops.add_n(alf.nest.flatten(actor_loss))

        neg_entropy = ()
        if self._log_alpha is not None:
            neg_entropy = dist_utils.compute_log_probability(
                action_distribution, action)

        target_critics, target_critic_states = self._target_critic_networks(
            (time_step.observation, action), state.target_critics)

        info = SarsaInfo(action_distribution=action_distribution,
                         actor_loss=actor_loss,
                         critics=prev_critics,
                         neg_entropy=neg_entropy,
                         target_critics=target_critics.min(dim=1)[0])

        rl_state = SarsaState(noise=noise_state,
                              prev_observation=time_step.observation,
                              prev_step_type=time_step.step_type,
                              actor=actor_state,
                              critics=critic_states,
                              target_critics=target_critic_states)

        return AlgStep(action, rl_state, info)
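`actor_loss_fn` above uses the familiar DPG surrogate-loss trick: minimizing 0.5 * ((dqda + a).detach() - a)^2 gives the action a gradient of -dqda, so a descent step on the loss moves the action along +dqda, i.e. up the critic. A small standalone check of that identity (illustration only, not from the source):

import torch

a = torch.tensor([1.0], requires_grad=True)
dqda = torch.tensor([2.0])
# Same surrogate as actor_loss_fn, written out element-wise.
loss = 0.5 * ((dqda + a).detach() - a) ** 2
loss.sum().backward()
print(a.grad)  # tensor([-2.]) == -dqda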
Example #9
File: threads.py  Project: ruizhaogit/alf
    def _step(self, time_step, policy_state):
        policy_state = common.reset_state_if_necessary(
            policy_state, self._initial_policy_state, time_step.is_first())
        self._actor_q.enqueue([time_step, policy_state, self._id])
        policy_step, act_dist_param = self._action_return_q.dequeue()
        action = policy_step.action
        next_time_step = make_action_time_step(self._env.step(action), action)
        # temporarily store the transition into a local queue
        self._unroll_queue.enqueue(
            LearningBatch(time_step=time_step,
                          state=policy_state if self._tfq._store_state else (),
                          policy_step=policy_step,
                          act_dist_param=act_dist_param,
                          next_time_step=next_time_step,
                          env_id=()))
        return [next_time_step, policy_step.state]
Example #10
def unroll(env, algorithm, steps, epsilon_greedy=0.1):
    """Run `steps` environment steps using algoirthm.predict_step()."""
    time_step = common.get_initial_time_step(env)
    policy_state = algorithm.get_initial_predict_state(env.batch_size)
    trans_state = algorithm.get_initial_transform_state(env.batch_size)
    for _ in range(steps):
        policy_state = common.reset_state_if_necessary(
            policy_state, algorithm.get_initial_predict_state(env.batch_size),
            time_step.is_first())
        transformed_time_step, trans_state = algorithm.transform_timestep(
            time_step, trans_state)
        policy_step = algorithm.predict_step(transformed_time_step,
                                             policy_state,
                                             epsilon_greedy=epsilon_greedy)
        time_step = env.step(policy_step.output)
        policy_state = policy_step.state
    return time_step
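A hypothetical call of the helper above; constructing `env` and `algorithm` is outside the scope of this snippet:

# `env` and `algorithm` are assumed to be created elsewhere.
last_time_step = unroll(env, algorithm, steps=100, epsilon_greedy=0.1)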
Example #11
        def _train_loop_body(counter, policy_state, info_ta):
            exp = tf.nest.map_structure(lambda ta: ta.read(counter),
                                        experience_ta)
            exp = nest_utils.params_to_distributions(
                exp, self.processed_experience_spec)
            policy_state = common.reset_state_if_necessary(
                policy_state, initial_train_state,
                tf.equal(exp.step_type, StepType.FIRST))

            with tf.name_scope(scope):
                policy_step = self.train_step(exp, policy_state)

            info_ta = tf.nest.map_structure(
                lambda ta, x: ta.write(counter, x), info_ta,
                nest_utils.distributions_to_params(policy_step.info))

            counter += 1

            return [counter, policy_step.state, info_ta]
Example #12
    def _calc_change(self, exp_array):
        """Calculate the distance between old/new action distributions.

        The distance is:

        - :math:`||logits_1 - logits_2||^2` for Categorical distribution
        - :math:`||loc_1 - loc_2||^2` for Deterministic distribution
        - :math:`KL(d1||d2) + KL(d2||d1)` for others
        """
        def _get_base_dist(dist: td.Distribution):
            """Get the base distribution of `dist`."""
            if isinstance(dist, (td.Independent, td.TransformedDistribution)):
                return _get_base_dist(dist.base_dist)
            return dist

        def _dist(d1, d2):
            d1_base = _get_base_dist(d1)
            d2_base = _get_base_dist(d2)
            if isinstance(d1_base, td.Categorical):
                dist = (d1_base.logits - d2_base.logits)**2
            elif isinstance(d1, torch.Tensor):
                dist = (d1 - d2)**2
            else:
                dist = td.kl.kl_divergence(d1, d2) + td.kl.kl_divergence(
                    d2, d1)
            return math_ops.sum_to_leftmost(dist, 1)

        def _update_total_dists(new_action, exp, total_dists):
            old_action = dist_utils.params_to_distributions(
                exp.action_param, self._action_distribution_spec)
            valid_masks = (exp.step_type != StepType.LAST).to(torch.float32)
            return nest_map(
                lambda d1, d2, total_dist:
                (_dist(d1, d2) * valid_masks).sum() + total_dist, old_action,
                new_action, total_dists)

        num_steps, batch_size = exp_array.step_type.shape
        state = nest_map(lambda x: x[0], exp_array.state)
        # exp_array.state is no longer needed
        exp_array = exp_array._replace(state=())
        initial_state = self.get_initial_predict_state(batch_size)
        total_dists = nest_map(lambda _: torch.tensor(0.), self.action_spec)
        for t in range(num_steps):
            exp = nest_map(lambda x: x[t], exp_array)
            state = common.reset_state_if_necessary(
                state, initial_state, exp.step_type == StepType.FIRST)
            time_step = TimeStep(observation=exp.observation,
                                 step_type=exp.step_type,
                                 prev_action=exp.prev_action)
            policy_step = self._ac_algorithm.predict_step(time_step=time_step,
                                                          state=state,
                                                          epsilon_greedy=1.0)
            assert (
                alf.nest.is_namedtuple(policy_step.info)
                and "action_distribution" in policy_step.info._fields
            ), ("AlgStep.info from ac_algorithm.predict_step() should be "
                "a namedtuple containing `action_distribution` in order to "
                "use TracAlgorithm.")
            new_action = policy_step.info.action_distribution
            state = policy_step.state
            total_dists = _update_total_dists(new_action, exp, total_dists)

        size = num_steps * batch_size
        total_dists = nest_map(lambda d: torch.sqrt(d / size), total_dists)
        return total_dists
Example #13
    def unroll(self, unroll_length):
        r"""Unroll ``unroll_length`` steps using the current policy.

        Because ``self._env`` is a batched environment, the total number of
        environment steps is ``self._env.batch_size * unroll_length``.

        Args:
            unroll_length (int): number of steps to unroll.
        Returns:
            Experience: The stacked experience with shape :math:`[T, B, \ldots]`
            for each of its members.
        """
        if self._current_time_step is None:
            self._current_time_step = common.get_initial_time_step(self._env)
        if self._current_policy_state is None:
            self._current_policy_state = self.get_initial_rollout_state(
                self._env.batch_size)
        if self._current_transform_state is None:
            self._current_transform_state = self.get_initial_transform_state(
                self._env.batch_size)

        time_step = self._current_time_step
        policy_state = self._current_policy_state
        trans_state = self._current_transform_state

        experience_list = []
        initial_state = self.get_initial_rollout_state(self._env.batch_size)

        env_step_time = 0.
        store_exp_time = 0.
        for _ in range(unroll_length):
            policy_state = common.reset_state_if_necessary(
                policy_state, initial_state, time_step.is_first())
            transformed_time_step, trans_state = self.transform_timestep(
                time_step, trans_state)
            # save the untransformed time step in case sub-algorithms need to
            # store it in replay buffers
            transformed_time_step = transformed_time_step._replace(
                untransformed=time_step)
            policy_step = self.rollout_step(transformed_time_step,
                                            policy_state)
            # release the reference to ``time_step``
            transformed_time_step = transformed_time_step._replace(
                untransformed=())

            action = common.detach(policy_step.output)

            t0 = time.time()
            next_time_step = self._env.step(action)
            env_step_time += time.time() - t0

            self.observe_for_metrics(time_step.cpu())

            if self._exp_replayer_type == "one_time":
                exp = make_experience(transformed_time_step, policy_step,
                                      policy_state)
            else:
                exp = make_experience(time_step.cpu(), policy_step,
                                      policy_state)

            t0 = time.time()
            self.observe_for_replay(exp)
            store_exp_time += time.time() - t0

            exp_for_training = Experience(
                action=action,
                reward=transformed_time_step.reward,
                discount=transformed_time_step.discount,
                step_type=transformed_time_step.step_type,
                state=policy_state,
                prev_action=transformed_time_step.prev_action,
                observation=transformed_time_step.observation,
                rollout_info=dist_utils.distributions_to_params(
                    policy_step.info),
                env_id=transformed_time_step.env_id)

            experience_list.append(exp_for_training)
            time_step = next_time_step
            policy_state = policy_step.state

        alf.summary.scalar("time/unroll_env_step", env_step_time)
        alf.summary.scalar("time/unroll_store_exp", store_exp_time)
        experience = alf.nest.utils.stack_nests(experience_list)
        experience = experience._replace(
            rollout_info=dist_utils.params_to_distributions(
                experience.rollout_info, self._rollout_info_spec))

        self._current_time_step = time_step
        # Need to detach so that the graph from this unroll is disconnected from
        # the next unroll. Otherwise backward() will report an error for
        # on-policy training after the next unroll.
        self._current_policy_state = common.detach(policy_state)
        self._current_transform_state = common.detach(trans_state)

        return experience
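To illustrate the [T, B, ...] layout that the docstring promises for the stacked experience, here is a standalone sketch using plain `torch.stack` rather than ALF's own nest utilities:

import torch

# T=8 unrolled steps, each yielding a [B, ...] tensor with B=4 environments.
per_step = [torch.zeros(4, 3) for _ in range(8)]
stacked = torch.stack(per_step, dim=0)
print(stacked.shape)  # torch.Size([8, 4, 3]) -> [T, B, ...]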
Example #14
    def _calc_change(self, exp_array):
        """Calculate the distance between old/new action distributions.

        The squared distance is:
        ||logits_1 - logits_2||^2 for Categorical distribution
        ||loc_1 - loc_2||^2 for Deterministic distribution
        KL(d1||d2) + KL(d2||d1) for others
        """
        def _dist(d1, d2):
            if isinstance(d1, tfp.distributions.Categorical):
                dist = tf.square(d1.logits - d2.logits)
            elif isinstance(d1, tf.Tensor):
                dist = tf.square(d1 - d2)
            else:
                if isinstance(d1, SquashToSpecNormal):
                    # TODO `SquashToSpecNormal.kl_divergence` checks that two
                    # distributions should have same action mean and magnitude,
                    # but this check fails in graph mode
                    d1 = d1.input_distribution
                    d2 = d2.input_distribution
                dist = d1.kl_divergence(d2) + d2.kl_divergence(d1)

            if len(dist.shape) > 1:
                # reduce to shape [B]
                reduce_dims = list(range(1, len(dist.shape)))
                dist = tf.reduce_sum(dist, axis=reduce_dims)
            return dist

        def _update_total_dists(new_action, exp, total_dists):
            old_action = nest_utils.params_to_distributions(
                exp.action_param, self._action_distribution_spec)
            dists = nest_map(_dist, old_action, new_action)
            valid_masks = tf.cast(tf.not_equal(exp.step_type, StepType.LAST),
                                  tf.float32)
            dists = nest_map(lambda kl: tf.reduce_sum(kl * valid_masks), dists)
            return nest_map(lambda x, y: x + y, total_dists, dists)

        num_steps = exp_array.step_type.size()
        # element_shape for `TensorArray` can be (None, ...)
        batch_size = tf.shape(exp_array.step_type.read(0))[0]
        state = tf.nest.map_structure(lambda x: x.read(0), exp_array.state)
        # exp_array.state is no longer needed
        exp_array = exp_array._replace(state=())
        initial_state = common.zero_tensor_from_nested_spec(
            self.predict_state_spec, batch_size)
        total_dists = nest_map(lambda _: tf.zeros(()), self.action_spec)
        for t in tf.range(num_steps):
            exp = tf.nest.map_structure(lambda x: x.read(t), exp_array)
            state = common.reset_state_if_necessary(
                state, initial_state, exp.step_type == StepType.FIRST)
            time_step = ActionTimeStep(observation=exp.observation,
                                       step_type=exp.step_type)
            policy_step = self._ac_algorithm.predict(time_step=time_step,
                                                     state=state,
                                                     epsilon_greedy=1.0)
            assert (
                common.is_namedtuple(policy_step.info)
                and "action_distribution" in policy_step.info._fields
            ), ("PolicyStep.info from ac_algorithm.predict() should be "
                "a namedtuple containing `action_distribution` in order to "
                "use TracAlgorithm.")
            new_action = policy_step.info.action_distribution
            state = policy_step.state
            total_dists = _update_total_dists(new_action, exp, total_dists)

        size = tf.cast(num_steps * batch_size, tf.float32)
        total_dists = nest_map(lambda d: tf.sqrt(d / size), total_dists)
        return total_dists