Code Example #1
File: mbrl_algorithm.py  Project: emailweixu/alf
    def _predict_with_planning(self, time_step: ActionTimeStep, state):
        action = self._planner_module.generate_plan(time_step, state)
        dynamics_state = self._dynamics_module.update_state(
            time_step, state.dynamics)

        return PolicyStep(
            action=action,
            state=MbrlState(dynamics=dynamics_state, reward=(), planner=()),
            info=MbrlInfo())
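
Every example on this page constructs a `PolicyStep`. For reference, `PolicyStep` in ALF follows the TF-Agents convention of a plain namedtuple with `action`, `state`, and `info` fields (each defaulting to an empty tuple), which is why the same container can hold concrete tensors in `predict`/`rollout` and `TensorSpec`s in `experience_spec` further down. A minimal sketch of the assumed definition, for illustration only:

import collections

# Assumed shape of PolicyStep; the actual definition lives in the
# tf_agents / alf code base.
PolicyStep = collections.namedtuple('PolicyStep', ('action', 'state', 'info'))
# Each field defaults to an empty tuple when there is nothing to report.
PolicyStep.__new__.__defaults__ = ((), (), ())

# Usage: a step with an action but no recurrent state or extra info.
step = PolicyStep(action=42)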
Code Example #2
File: sarsa_algorithm.py  Project: emailweixu/alf
    def predict(self, time_step: ActionTimeStep, state, epsilon_greedy):
        _, action, actor_state = self._get_action(self._actor_network,
                                                  time_step, state.actor,
                                                  epsilon_greedy)
        return PolicyStep(action=action,
                          state=SarsaState(
                              actor=actor_state,
                              prev_observation=time_step.observation,
                              prev_step_type=time_step.step_type),
                          info=())
Code Example #3
    def experience_spec(self):
        """Spec for experience."""
        policy_step_spec = PolicyStep(action=self.action_spec,
                                      state=self.train_state_spec,
                                      info=self.rollout_info_spec)
        exp_spec = make_experience(self.time_step_spec, policy_step_spec,
                                   policy_step_spec.state)
        if not self._use_rollout_state:
            exp_spec = exp_spec._replace(state=())
        return exp_spec
Code Example #4
File: sarsa_algorithm.py  Project: emailweixu/alf
    def rollout(self, time_step: ActionTimeStep, state: SarsaState, mode):
        not_first_step = tf.not_equal(time_step.step_type, StepType.FIRST)
        prev_critic, critic_state = self._critic_network(
            inputs=(state.prev_observation, time_step.prev_action),
            step_type=state.prev_step_type,
            network_state=state.critic)
        critic_state = tf.nest.map_structure(
            lambda new_s, s: tf.where(not_first_step, new_s, s), critic_state,
            state.critic)

        action_distribution, action, actor_state = self._get_action(
            self._actor_network, time_step, state.actor)

        with tf.GradientTape(watch_accessed_variables=False) as tape:
            tape.watch(action)
            critic, _ = self._critic_network((time_step.observation, action),
                                             step_type=time_step.step_type,
                                             network_state=critic_state)
        dqda = tape.gradient(critic, action)

        def actor_loss_fn(dqda, action):
            if self._dqda_clipping:
                dqda = tf.clip_by_value(dqda, -self._dqda_clipping,
                                        self._dqda_clipping)
            loss = 0.5 * losses.element_wise_squared_loss(
                tf.stop_gradient(dqda + action), action)
            loss = tf.reduce_sum(loss, axis=list(range(1, len(loss.shape))))
            return loss

        actor_loss = tf.nest.map_structure(actor_loss_fn, dqda, action)
        actor_loss = tf.add_n(tf.nest.flatten(actor_loss))

        _, target_action, target_actor_state = self._get_action(
            self._target_actor_network, time_step, state.target_actor)

        target_critic, target_critic_state = self._target_critic_network(
            (time_step.observation, target_action),
            step_type=time_step.step_type,
            network_state=state.target_critic)

        prev_return = tf.stop_gradient(time_step.reward +
                                       time_step.discount * target_critic)
        info = SarsaInfo(action_distribution=action_distribution,
                         actor_loss=actor_loss,
                         critic=prev_critic,
                         returns=prev_return)

        rl_state = SarsaState(prev_observation=time_step.observation,
                              prev_step_type=time_step.step_type,
                              actor=actor_state,
                              target_actor=target_actor_state,
                              critic=critic_state,
                              target_critic=target_critic_state)

        return PolicyStep(action, rl_state, info)
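
In `actor_loss_fn` above, `dqda + action` is wrapped in `tf.stop_gradient`, so the squared loss has gradient `-dqda` with respect to `action`; minimizing it therefore pushes the action in the direction that increases the critic (the usual DPG-style surrogate). A small sketch, not taken from the source, that checks this numerically:

import tensorflow as tf

action = tf.Variable([0.3, -0.1])
dqda = tf.constant([0.5, 2.0])
with tf.GradientTape() as tape:
    # Same surrogate as actor_loss_fn above, minus the optional clipping.
    loss = 0.5 * tf.reduce_sum(
        tf.math.squared_difference(tf.stop_gradient(dqda + action), action))
print(tape.gradient(loss, action))  # [-0.5, -2.0], i.e. exactly -dqda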
Code Example #5
    def predict(self, time_step: ActionTimeStep, state: ActorCriticState,
                epsilon_greedy):
        """Predict for one step."""
        action_dist, actor_state = self._actor_network(
            time_step.observation,
            step_type=time_step.step_type,
            network_state=state.actor)

        action = common.epsilon_greedy_sample(action_dist, epsilon_greedy)
        return PolicyStep(
            action=action,
            state=ActorCriticState(actor=actor_state),
            info=ActorCriticInfo(action_distribution=action_dist))
Code Example #6
File: mbrl_algorithm.py  Project: emailweixu/alf
    def train_step(self, exp: Experience, state: MbrlState):
        action = exp.action
        dynamics_step = self._dynamics_module.train_step(exp, state.dynamics)
        reward_step = self._reward_module.train_step(exp, state.reward)
        plan_step = self._planner_module.train_step(exp, state.planner)
        state = MbrlState(
            dynamics=dynamics_step.state,
            reward=reward_step.state,
            planner=plan_step.state)
        info = MbrlInfo(
            dynamics=dynamics_step.info,
            reward=reward_step.info,
            planner=plan_step.info)
        return PolicyStep(action, state, info)
Code Example #7
    def _predict(self,
                 time_step: ActionTimeStep,
                 state=None,
                 epsilon_greedy=1.):
        action_dist, state = self._actor_network(
            time_step.observation,
            step_type=time_step.step_type,
            network_state=state.share.actor)
        empty_state = tf.nest.map_structure(lambda x: (),
                                            self.train_state_spec)
        state = empty_state._replace(share=SacShareState(actor=state))
        action = common.epsilon_greedy_sample(action_dist, epsilon_greedy)
        return PolicyStep(action=action,
                          state=state,
                          info=SacInfo(action_distribution=action_dist))
Code Example #8
    def rollout(self, time_step: ActionTimeStep, state: ActorCriticState,
                mode):
        """Rollout for one step."""
        value, value_state = self._value_network(
            time_step.observation,
            step_type=time_step.step_type,
            network_state=state.value)

        action_distribution, actor_state = self._actor_network(
            time_step.observation,
            step_type=time_step.step_type,
            network_state=state.actor)

        action = common.sample_action_distribution(action_distribution)
        return PolicyStep(
            action=action,
            state=ActorCriticState(actor=actor_state, value=value_state),
            info=ActorCriticInfo(
                value=value, action_distribution=action_distribution))
Code Example #9
    def train_step(self, exp: Experience, state: SacState):
        action_distribution, share_actor_state = self._actor_network(
            exp.observation,
            step_type=exp.step_type,
            network_state=state.share.actor)
        action = tf.nest.map_structure(lambda d: d.sample(),
                                       action_distribution)
        log_pi = tfa_common.log_probability(action_distribution, action,
                                            self._action_spec)

        actor_state, actor_info = self._actor_train_step(
            exp, state.actor, action_distribution, action, log_pi)
        critic_state, critic_info = self._critic_train_step(
            exp, state.critic, action, log_pi)
        alpha_info = self._alpha_train_step(log_pi)
        state = SacState(share=SacShareState(actor=share_actor_state),
                         actor=actor_state,
                         critic=critic_state)
        info = SacInfo(action_distribution=action_distribution,
                       actor=actor_info,
                       critic=critic_info,
                       alpha=alpha_info)
        return PolicyStep(action, state, info)
Code Example #10
File: goal_generator.py  Project: runjerry/alf
    def rollout(self, time_step: ActionTimeStep, state, mode):
        """Perform one step of predicting and training computation.

        Note that as RandomCategoricalGoalGenerator is a non-trainable module,
        this function just passes the goal from state as outputs and
        the input state as output state.

        Args:
            time_step (ActionTimeStep): input time_step data
            state (nested Tensor): consistent with train_state_spec
            mode (int): See alf.algorithms.rl_algorithm.RLAlgorithm.rollout
        Returns:
            TrainStep:
                outputs: goal vector; currently just output the one from state
                state: state
                info (GoalInfo):
        """
        observation = time_step.observation
        step_type = time_step.step_type
        new_goal = self._update_goal(observation, state, step_type)
        return PolicyStep(action=new_goal,
                          state=GoalState(goal=new_goal),
                          info=GoalInfo(goal=new_goal))
Code Example #11
    def __init__(self,
                 envs,
                 algorithm: OffPolicyAlgorithm,
                 num_actor_queues=1,
                 unroll_length=8,
                 learn_queue_cap=1,
                 actor_queue_cap=1,
                 observers=[],
                 metrics=[],
                 exp_replayer="one_time"):
        """
        Args:
            envs (list[TFEnvironment]):  list of TFEnvironment
            algorithm (OffPolicyAlgorithm):
            num_actor_queues (int): number of actor queues. Each queue is
                exclusively owned by just one actor thread.
            unroll_length (int): number of time steps each environment proceeds
                before sending the steps to the learner queue
            learn_queue_cap (int): the learner queue capacity determines how many
                environments contribute to the training data for each training
                iteration
            actor_queue_cap (int): the actor queue capacity determines how many
                environments contribute to the data for each forward prediction
                pass in an `ActorThread`. To prevent deadlock, it is required
                that `actor_queue_cap` * `num_actor_queues` <= `num_envs`.
            observers (list[Callable]): An optional list of observers that are
                updated after every step in the environment. Each observer is a
                callable(time_step.Trajectory).
            metrics (list[TFStepMetric]): An optional list of metrics.
            exp_replayer (str): a string that indicates which ExperienceReplayer
                to use.
        """
        super(AsyncOffPolicyDriver, self).__init__(
            env=envs[0],
            num_envs=len(envs),
            algorithm=algorithm,
            exp_replayer=exp_replayer,
            observers=observers,
            metrics=metrics)

        # create threads
        self._coord = tf.train.Coordinator()
        num_envs = len(envs)
        policy_step_spec = PolicyStep(
            action=algorithm.action_spec,
            state=algorithm.train_state_spec,
            info=algorithm.rollout_info_spec)
        self._tfq = TFQueues(
            num_envs,
            self._env.batch_size,
            learn_queue_cap,
            actor_queue_cap,
            time_step_spec=algorithm.time_step_spec,
            policy_step_spec=policy_step_spec,
            unroll_length=unroll_length,
            num_actor_queues=num_actor_queues)
        actor_threads = [
            ActorThread(
                name="actor{}".format(i),
                coord=self._coord,
                algorithm=self._algorithm,
                tf_queues=self._tfq,
                id=i) for i in range(num_actor_queues)
        ]
        env_threads = [
            EnvThread(
                name="env{}".format(i),
                coord=self._coord,
                env=envs[i],
                tf_queues=self._tfq,
                unroll_length=unroll_length,
                id=i,
                actor_id=i % num_actor_queues) for i in range(num_envs)
        ]
        self._log_thread = LogThread(
            name="logging",
            num_envs=num_envs,
            env_batch_size=self._env.batch_size,
            observers=observers,
            metrics=metrics,
            coord=self._coord,
            queue=self._tfq.log_queue)
        self._threads = actor_threads + env_threads + [self._log_thread]
        algorithm.set_metrics(self.get_metrics())
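
The deadlock condition spelled out in the docstring (`actor_queue_cap` * `num_actor_queues` <= `num_envs`) could also be asserted explicitly at construction time. A minimal sketch of such a guard; it is an assumption and not part of the source constructor:

# Hypothetical sanity check mirroring the documented deadlock condition.
assert actor_queue_cap * num_actor_queues <= len(envs), (
    "actor_queue_cap * num_actor_queues must not exceed the number of "
    "environments, otherwise the actor queues can never fill up")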
Code Example #12
File: goal_generator.py  Project: runjerry/alf
    def train_step(self, exp: Experience, state):
        return PolicyStep(action=state.goal,
                          state=state,
                          info=GoalInfo(goal=state.goal))