def _predict_with_planning(self, time_step: ActionTimeStep, state):
    action = self._planner_module.generate_plan(time_step, state)
    dynamics_state = self._dynamics_module.update_state(
        time_step, state.dynamics)
    return PolicyStep(
        action=action,
        state=MbrlState(dynamics=dynamics_state, reward=(), planner=()),
        info=MbrlInfo())
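# --- Hedged sketch (not the actual planner module) -------------------------
# The only interface `_predict_with_planning` relies on is that
# `generate_plan(time_step, state)` returns an action. Below is a minimal,
# self-contained random-shooting illustration of such a planner: sample
# candidate action sequences, roll them through a dynamics model, score them
# with a reward model, and return the first action of the best sequence.
# `dynamics_fn`, `reward_fn`, and all shapes here are assumptions made for
# this sketch only.
import tensorflow as tf


def random_shooting_plan(obs, dynamics_fn, reward_fn, action_dim,
                         population=64, horizon=10):
    # Candidate action sequences: [population, horizon, action_dim] in [-1, 1].
    actions = tf.random.uniform([population, horizon, action_dim], -1., 1.)
    obs = tf.repeat(obs[tf.newaxis, :], population, axis=0)
    total_reward = tf.zeros([population])
    for t in range(horizon):
        obs = dynamics_fn(obs, actions[:, t])          # predicted next obs
        total_reward += reward_fn(obs, actions[:, t])  # predicted reward
    best = tf.argmax(total_reward)
    return actions[best, 0]  # first action of the best candidate sequence


# Toy usage with stand-in models (for shape checking only):
plan = random_shooting_plan(
    obs=tf.zeros([4]),
    dynamics_fn=lambda o, a: o + 0.1 * tf.reduce_sum(a, -1, keepdims=True),
    reward_fn=lambda o, a: -tf.reduce_sum(o**2, -1),
    action_dim=2)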
def predict(self, time_step: ActionTimeStep, state, epsilon_greedy):
    _, action, actor_state = self._get_action(self._actor_network, time_step,
                                              state.actor, epsilon_greedy)
    return PolicyStep(
        action=action,
        state=SarsaState(
            actor=actor_state,
            prev_observation=time_step.observation,
            prev_step_type=time_step.step_type),
        info=())
def experience_spec(self):
    """Spec for experience."""
    policy_step_spec = PolicyStep(
        action=self.action_spec,
        state=self.train_state_spec,
        info=self.rollout_info_spec)
    exp_spec = make_experience(self.time_step_spec, policy_step_spec,
                               policy_step_spec.state)
    if not self._use_rollout_state:
        exp_spec = exp_spec._replace(state=())
    return exp_spec
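# --- Hedged illustration (generic namedtuple, not ALF's Experience type) ---
# The `exp_spec._replace(state=())` call above relies on spec containers
# being namedtuples: `_replace` returns a copy with one field swapped out,
# here erasing the (potentially large) network-state spec so that rollout
# states are not stored with the experience. A toy example of the idiom:
from collections import namedtuple

Spec = namedtuple('Spec', ['observation', 'action', 'state'])
spec = Spec(observation='obs_spec', action='act_spec', state='rnn_state_spec')
slim_spec = spec._replace(state=())   # same spec, but no per-step state
assert slim_spec.state == () and slim_spec.action == 'act_spec'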
def rollout(self, time_step: ActionTimeStep, state: SarsaState, mode):
    not_first_step = tf.not_equal(time_step.step_type, StepType.FIRST)

    # Critic value for the previous (observation, action) pair. The updated
    # critic state is only carried over for non-FIRST steps; at episode
    # starts we fall back to the state stored in `state.critic`.
    prev_critic, critic_state = self._critic_network(
        inputs=(state.prev_observation, time_step.prev_action),
        step_type=state.prev_step_type,
        network_state=state.critic)
    critic_state = tf.nest.map_structure(
        lambda new_s, s: tf.where(not_first_step, new_s, s), critic_state,
        state.critic)

    action_distribution, action, actor_state = self._get_action(
        self._actor_network, time_step, state.actor)

    with tf.GradientTape(watch_accessed_variables=False) as tape:
        tape.watch(action)
        critic, _ = self._critic_network((time_step.observation, action),
                                         step_type=time_step.step_type,
                                         network_state=critic_state)
    dqda = tape.gradient(critic, action)

    def actor_loss_fn(dqda, action):
        # DPG-style surrogate loss: its gradient w.r.t. `action` is -dqda,
        # i.e. gradient ascent on the critic value.
        if self._dqda_clipping:
            dqda = tf.clip_by_value(dqda, -self._dqda_clipping,
                                    self._dqda_clipping)
        loss = 0.5 * losses.element_wise_squared_loss(
            tf.stop_gradient(dqda + action), action)
        loss = tf.reduce_sum(loss, axis=list(range(1, len(loss.shape))))
        return loss

    actor_loss = tf.nest.map_structure(actor_loss_fn, dqda, action)
    actor_loss = tf.add_n(tf.nest.flatten(actor_loss))

    _, target_action, target_actor_state = self._get_action(
        self._target_actor_network, time_step, state.target_actor)
    target_critic, target_critic_state = self._target_critic_network(
        (time_step.observation, target_action),
        step_type=time_step.step_type,
        network_state=state.target_critic)

    # One-step SARSA target for the previous (observation, action) pair,
    # computed with the target networks.
    prev_return = tf.stop_gradient(time_step.reward +
                                   time_step.discount * target_critic)

    info = SarsaInfo(
        action_distribution=action_distribution,
        actor_loss=actor_loss,
        critic=prev_critic,
        returns=prev_return)
    rl_state = SarsaState(
        prev_observation=time_step.observation,
        prev_step_type=time_step.step_type,
        actor=actor_state,
        target_actor=target_actor_state,
        critic=critic_state,
        target_critic=target_critic_state)
    return PolicyStep(action, rl_state, info)
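# --- Hedged sketch: why the `stop_gradient(dqda + action)` trick works -----
# Minimizing 0.5 * (stop_gradient(dqda + a) - a)^2 w.r.t. a has gradient
# -(dqda + a - a) = -dqda, so a gradient-descent step on this loss increases
# the critic value. Standalone check with a toy critic Q(a) = -(a - 2)^2;
# all names below are local to this sketch.
import tensorflow as tf

a = tf.Variable([0.5])
with tf.GradientTape() as tape:
    q = -(a - 2.0) ** 2                      # toy critic
dqda = tape.gradient(q, a)                   # dQ/da = -2 * (a - 2) = 3

with tf.GradientTape() as tape:
    loss = 0.5 * tf.square(tf.stop_gradient(dqda + a) - a)
grad = tape.gradient(loss, a)
tf.debugging.assert_near(grad, -dqda)        # descending on `loss` ascends Q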
def predict(self, time_step: ActionTimeStep, state: ActorCriticState,
            epsilon_greedy):
    """Predict for one step."""
    action_dist, actor_state = self._actor_network(
        time_step.observation,
        step_type=time_step.step_type,
        network_state=state.actor)
    action = common.epsilon_greedy_sample(action_dist, epsilon_greedy)
    return PolicyStep(
        action=action,
        state=ActorCriticState(actor=actor_state),
        info=ActorCriticInfo(action_distribution=action_dist))
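# --- Hedged sketch of `common.epsilon_greedy_sample` -----------------------
# The helper is assumed to behave roughly as below: with probability
# `epsilon` draw a random sample from the action distribution, otherwise act
# greedily (the distribution's mode). This is an illustrative stand-in, not
# ALF's implementation.
import tensorflow as tf
import tensorflow_probability as tfp


def epsilon_greedy_sample(dist, epsilon):
    greedy = dist.mode()
    sampled = dist.sample()
    explore = tf.random.uniform(tf.shape(greedy)[:1]) < epsilon
    return tf.where(explore, sampled, greedy)


# Toy usage: a categorical policy over 3 actions for a batch of 4.
dist = tfp.distributions.Categorical(logits=tf.zeros([4, 3]))
actions = epsilon_greedy_sample(dist, epsilon=0.1)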
def train_step(self, exp: Experience, state: MbrlState):
    action = exp.action
    dynamics_step = self._dynamics_module.train_step(exp, state.dynamics)
    reward_step = self._reward_module.train_step(exp, state.reward)
    plan_step = self._planner_module.train_step(exp, state.planner)
    state = MbrlState(
        dynamics=dynamics_step.state,
        reward=reward_step.state,
        planner=plan_step.state)
    info = MbrlInfo(
        dynamics=dynamics_step.info,
        reward=reward_step.info,
        planner=plan_step.info)
    return PolicyStep(action, state, info)
def _predict(self, time_step: ActionTimeStep, state=None, epsilon_greedy=1.):
    action_dist, state = self._actor_network(
        time_step.observation,
        step_type=time_step.step_type,
        network_state=state.share.actor)
    empty_state = tf.nest.map_structure(lambda x: (), self.train_state_spec)
    state = empty_state._replace(share=SacShareState(actor=state))
    action = common.epsilon_greedy_sample(action_dist, epsilon_greedy)
    return PolicyStep(
        action=action,
        state=state,
        info=SacInfo(action_distribution=action_dist))
def rollout(self, time_step: ActionTimeStep, state: ActorCriticState, mode):
    """Rollout for one step."""
    value, value_state = self._value_network(
        time_step.observation,
        step_type=time_step.step_type,
        network_state=state.value)
    action_distribution, actor_state = self._actor_network(
        time_step.observation,
        step_type=time_step.step_type,
        network_state=state.actor)
    action = common.sample_action_distribution(action_distribution)
    return PolicyStep(
        action=action,
        state=ActorCriticState(actor=actor_state, value=value_state),
        info=ActorCriticInfo(
            value=value, action_distribution=action_distribution))
def train_step(self, exp: Experience, state: SacState):
    action_distribution, share_actor_state = self._actor_network(
        exp.observation,
        step_type=exp.step_type,
        network_state=state.share.actor)
    # Sample a fresh action from the current policy and compute its log-prob;
    # both are reused by the actor, critic, and temperature updates below.
    action = tf.nest.map_structure(lambda d: d.sample(), action_distribution)
    log_pi = tfa_common.log_probability(action_distribution, action,
                                        self._action_spec)
    actor_state, actor_info = self._actor_train_step(
        exp, state.actor, action_distribution, action, log_pi)
    critic_state, critic_info = self._critic_train_step(
        exp, state.critic, action, log_pi)
    alpha_info = self._alpha_train_step(log_pi)
    state = SacState(
        share=SacShareState(actor=share_actor_state),
        actor=actor_state,
        critic=critic_state)
    info = SacInfo(
        action_distribution=action_distribution,
        actor=actor_info,
        critic=critic_info,
        alpha=alpha_info)
    return PolicyStep(action, state, info)
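# --- Hedged sketch of the temperature update behind `_alpha_train_step` ----
# A common SAC formulation tunes the temperature by minimizing
#   J(alpha) = -log_alpha * stop_gradient(log_pi + target_entropy),
# which raises alpha when the policy's entropy drops below the target and
# lowers it otherwise (some implementations use alpha instead of log_alpha).
# The names below are assumptions for this standalone sketch.
import tensorflow as tf

log_alpha = tf.Variable(0.0)
target_entropy = -2.0                        # e.g. -dim(action) heuristic


def alpha_loss(log_pi):
    return -log_alpha * tf.stop_gradient(log_pi + target_entropy)


with tf.GradientTape() as tape:
    loss = tf.reduce_mean(alpha_loss(log_pi=tf.constant([-1.5, -2.5])))
grad = tape.gradient(loss, log_alpha)        # apply with any optimizer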
def rollout(self, time_step: ActionTimeStep, state, mode):
    """Perform one step of predicting and training computation.

    Note that as RandomCategoricalGoalGenerator is a non-trainable module,
    this function just passes the goal from state as outputs and the input
    state as output state.

    Args:
        time_step (ActionTimeStep): input time_step data
        state (nested Tensor): consistent with train_state_spec
        mode (int): See alf.algorithms.rl_algorithm.RLAlgorithm.rollout
    Returns:
        PolicyStep:
            action: goal vector; currently just the one carried in state
            state: state
            info (GoalInfo):
    """
    observation = time_step.observation
    step_type = time_step.step_type
    new_goal = self._update_goal(observation, state, step_type)
    return PolicyStep(
        action=new_goal,
        state=GoalState(goal=new_goal),
        info=GoalInfo(goal=new_goal))
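# --- Hedged sketch (an assumption about what `_update_goal` might do) ------
# For a random categorical goal generator, a plausible behavior is to keep
# the previous goal and only resample a one-hot goal when an episode starts
# (step_type == FIRST, which is 0 in tf_agents). The real module may resample
# under other conditions as well; everything below is illustrative.
import tensorflow as tf


def update_goal(prev_goal, step_type, num_goals, first_step_type=0):
    batch_size = tf.shape(prev_goal)[0]
    new_goal = tf.one_hot(
        tf.random.uniform([batch_size], 0, num_goals, dtype=tf.int32),
        depth=num_goals)
    is_first = tf.cast(
        tf.equal(step_type, first_step_type), tf.float32)[:, tf.newaxis]
    # Keep the old goal except where a new episode has just started.
    return is_first * new_goal + (1. - is_first) * prev_goal


# Toy usage: batch of 2 environments, 4 possible goals, first env restarts.
goal = update_goal(
    prev_goal=tf.one_hot([1, 2], depth=4),
    step_type=tf.constant([0, 1]),
    num_goals=4)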
def __init__(self,
             envs,
             algorithm: OffPolicyAlgorithm,
             num_actor_queues=1,
             unroll_length=8,
             learn_queue_cap=1,
             actor_queue_cap=1,
             observers=[],
             metrics=[],
             exp_replayer="one_time"):
    """
    Args:
        envs (list[TFEnvironment]): list of TFEnvironment
        algorithm (OffPolicyAlgorithm):
        num_actor_queues (int): number of actor queues. Each queue is
            exclusively owned by just one actor thread.
        unroll_length (int): number of time steps each environment proceeds
            before sending the steps to the learner queue
        learn_queue_cap (int): the learner queue capacity determines how many
            environments contribute to the training data for each training
            iteration
        actor_queue_cap (int): the actor queue capacity determines how many
            environments contribute to the data for each prediction forward
            in an `ActorThread`. To prevent deadlock, it's required that
            `actor_queue_cap` * `num_actor_queues` <= `num_envs`.
        observers (list[Callable]): An optional list of observers that are
            updated after every step in the environment. Each observer is a
            callable(time_step.Trajectory).
        metrics (list[TFStepMetric]): An optional list of metrics.
        exp_replayer (str): a string that indicates which ExperienceReplayer
            to use.
    """
    super(AsyncOffPolicyDriver, self).__init__(
        env=envs[0],
        num_envs=len(envs),
        algorithm=algorithm,
        exp_replayer=exp_replayer,
        observers=observers,
        metrics=metrics)

    # create threads
    self._coord = tf.train.Coordinator()
    num_envs = len(envs)
    policy_step_spec = PolicyStep(
        action=algorithm.action_spec,
        state=algorithm.train_state_spec,
        info=algorithm.rollout_info_spec)
    self._tfq = TFQueues(
        num_envs,
        self._env.batch_size,
        learn_queue_cap,
        actor_queue_cap,
        time_step_spec=algorithm.time_step_spec,
        policy_step_spec=policy_step_spec,
        unroll_length=unroll_length,
        num_actor_queues=num_actor_queues)
    actor_threads = [
        ActorThread(
            name="actor{}".format(i),
            coord=self._coord,
            algorithm=self._algorithm,
            tf_queues=self._tfq,
            id=i) for i in range(num_actor_queues)
    ]
    env_threads = [
        EnvThread(
            name="env{}".format(i),
            coord=self._coord,
            env=envs[i],
            tf_queues=self._tfq,
            unroll_length=unroll_length,
            id=i,
            actor_id=i % num_actor_queues) for i in range(num_envs)
    ]
    self._log_thread = LogThread(
        name="logging",
        num_envs=num_envs,
        env_batch_size=self._env.batch_size,
        observers=observers,
        metrics=metrics,
        coord=self._coord,
        queue=self._tfq.log_queue)
    self._threads = actor_threads + env_threads + [self._log_thread]
    algorithm.set_metrics(self.get_metrics())
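# --- Hedged construction sketch (placeholders, not part of the API) --------
# Construction only; it highlights the deadlock constraint stated in the
# docstring: `actor_queue_cap * num_actor_queues` must not exceed the number
# of environments. `make_env()` and `my_algorithm` are hypothetical names
# used only for this sketch.
num_envs, num_actor_queues, actor_queue_cap = 8, 2, 4
assert actor_queue_cap * num_actor_queues <= num_envs, "would deadlock"

envs = [make_env() for _ in range(num_envs)]  # hypothetical env factory
driver = AsyncOffPolicyDriver(
    envs=envs,
    algorithm=my_algorithm,                   # an OffPolicyAlgorithm instance
    num_actor_queues=num_actor_queues,
    unroll_length=8,
    learn_queue_cap=1,
    actor_queue_cap=actor_queue_cap)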
def train_step(self, exp: Experience, state):
    return PolicyStep(
        action=state.goal, state=state, info=GoalInfo(goal=state.goal))