def _train_loop_body(counter, policy_state, training_info_ta):
    exp = tf.nest.map_structure(lambda ta: ta.read(counter), experience_ta)
    collect_action_distribution_param = exp.action_distribution
    collect_action_distribution = nested_distributions_from_specs(
        self._action_distribution_spec, collect_action_distribution_param)
    exp = exp._replace(action_distribution=collect_action_distribution)

    policy_state = common.reset_state_if_necessary(
        policy_state, initial_train_state,
        tf.equal(exp.step_type, StepType.FIRST))

    policy_step = common.algorithm_step(self.train_step, exp, policy_state)
    action_dist_param = common.get_distribution_params(policy_step.action)
    training_info = TrainingInfo(
        action_distribution=action_dist_param, info=policy_step.info)

    training_info_ta = tf.nest.map_structure(
        lambda ta, x: ta.write(counter, x), training_info_ta, training_info)

    counter += 1
    return [counter, policy_step.state, training_info_ta]
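
# The functions in this section all call ``common.reset_state_if_necessary``.
# Below is a minimal sketch (an assumption about its semantics, not the actual
# alf implementation) of what such a helper does: for each batch element whose
# ``reset_mask`` entry is True, replace its recurrent state with the initial
# state; otherwise keep the existing state.
import tensorflow as tf


def reset_state_if_necessary_sketch(state, initial_state, reset_mask):
    """Select ``initial_state`` where ``reset_mask`` is True, else ``state``.

    ``state`` and ``initial_state`` are nests of tensors shaped [B, ...];
    ``reset_mask`` is a bool tensor shaped [B].
    """

    def _select(s, init_s):
        # Reshape the [B] mask to [B, 1, ..., 1] so it broadcasts against s.
        mask = tf.reshape(reset_mask, [-1] + [1] * (len(s.shape) - 1))
        return tf.where(mask, init_s, s)

    return tf.nest.map_structure(_select, state, initial_state)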
def _step(self, time_step, policy_state):
    policy_state = common.reset_state_if_necessary(
        policy_state, self._initial_state, time_step.is_first())
    policy_step = common.algorithm_step(
        self._algorithm,
        self._observation_transformer,
        time_step,
        state=policy_state,
        training=self._training,
        greedy_predict=self._greedy_predict)
    action = common.sample_action_distribution(policy_step.action)
    next_time_step = self._env_step(action)
    if self._observers:
        traj = from_transition(time_step,
                               policy_step._replace(action=action),
                               next_time_step)
        for observer in self._observers:
            observer(traj)
    if self._exp_observers:
        action_distribution_param = common.get_distribution_params(
            policy_step.action)
        exp = make_experience(
            time_step,
            policy_step._replace(action=action),
            action_distribution=action_distribution_param,
            state=policy_state if self._use_rollout_state else ())
        for observer in self._exp_observers:
            observer(exp)
    return next_time_step, policy_step, action
def _step(self, time_step, policy_state):
    policy_state = common.reset_state_if_necessary(
        policy_state, self._initial_state, time_step.is_first())
    if self._mode == self.PREDICT:
        step_func = functools.partial(
            self._algorithm.predict, epsilon_greedy=self._epsilon_greedy)
    elif self._mode == self.ON_POLICY_TRAINING:
        step_func = functools.partial(
            self._algorithm.rollout, mode=RLAlgorithm.ON_POLICY_TRAINING)
    elif self._mode == self.OFF_POLICY_TRAINING:
        step_func = functools.partial(
            self._algorithm.rollout, mode=RLAlgorithm.ROLLOUT)
    else:
        raise ValueError("Unsupported mode: %s" % self._mode)
    transformed_time_step = self._algorithm.transform_timestep(time_step)
    policy_step = step_func(transformed_time_step, policy_state)
    next_time_step = self._env_step(policy_step.action)
    if self._observers:
        traj = from_transition(
            time_step, policy_step._replace(info=()), next_time_step)
        for observer in self._observers:
            observer(traj)
    if self._algorithm.exp_observers and self._training:
        policy_step = nest_utils.distributions_to_params(policy_step)
        exp = make_experience(time_step, policy_step, policy_state)
        self._algorithm.observe(exp)
    return next_time_step, policy_step, transformed_time_step
def rollout_step(self, time_step: TimeStep, state: SarsaState):
    if self._on_policy:
        return self._train_step(time_step, state)

    if not self._is_rnn:
        critic_states = state.critics
    else:
        _, critic_states = self._critic_networks(
            (state.prev_observation, time_step.prev_action), state.critics)

    not_first_step = time_step.step_type != StepType.FIRST
    critic_states = common.reset_state_if_necessary(
        state.critics, critic_states, not_first_step)

    action_distribution, action, actor_state, noise_state = self._get_action(
        self._rollout_actor_network, time_step, state)

    if not self._is_rnn:
        target_critic_states = state.target_critics
    else:
        _, target_critic_states = self._target_critic_networks(
            (time_step.observation, action), state.target_critics)

    info = SarsaInfo(action_distribution=action_distribution)

    rl_state = SarsaState(
        noise=noise_state,
        prev_observation=time_step.observation,
        prev_step_type=time_step.step_type,
        actor=actor_state,
        critics=critic_states,
        target_critics=target_critic_states)

    return AlgStep(action, rl_state, info)
def _calc_change(self, exp_array):
    """Calculate the distance between old/new action distributions.

    The distance is:
    ||logits_1 - logits_2||^2 for Categorical distribution
    ||loc_1 - loc_2||^2 for Deterministic distribution
    KL(d1||d2) + KL(d2||d1) for others
    """

    def _dist(d1, d2):
        if isinstance(d1, tfp.distributions.Categorical):
            return tf.reduce_sum(tf.square(d1.logits - d2.logits), axis=-1)
        elif isinstance(d1, tfp.distributions.Deterministic):
            return tf.reduce_sum(tf.square(d1.loc - d2.loc), axis=-1)
        else:
            if isinstance(d1, SquashToSpecNormal):
                # TODO: `SquashToSpecNormal.kl_divergence` checks that the two
                # distributions have the same action mean and magnitude, but
                # this check fails in graph mode.
                d1 = d1.input_distribution
                d2 = d2.input_distribution
            return tf.reduce_sum(
                d1.kl_divergence(d2) + d2.kl_divergence(d1), axis=-1)

    def _update_total_dists(new_action, exp, total_dists):
        old_action = nested_distributions_from_specs(
            common.to_distribution_spec(self.action_distribution_spec),
            exp.action_param)
        dists = nest_map(_dist, old_action, new_action)
        valid_masks = tf.cast(
            tf.not_equal(exp.step_type, StepType.LAST), tf.float32)
        dists = nest_map(lambda kl: tf.reduce_sum(kl * valid_masks), dists)
        return nest_map(lambda x, y: x + y, total_dists, dists)

    num_steps = exp_array.step_type.size()
    # element_shape for `TensorArray` can be (None, ...)
    batch_size = tf.shape(exp_array.step_type.read(0))[0]
    state = tf.nest.map_structure(lambda x: x.read(0), exp_array.state)
    # exp_array.state is no longer needed
    exp_array = exp_array._replace(state=())
    initial_state = common.zero_tensor_from_nested_spec(
        self.predict_state_spec, batch_size)
    total_dists = nest_map(lambda _: tf.zeros(()), self.action_spec)
    for t in tf.range(num_steps):
        exp = tf.nest.map_structure(lambda x: x.read(t), exp_array)
        state = common.reset_state_if_necessary(
            state, initial_state, exp.step_type == StepType.FIRST)
        time_step = ActionTimeStep(
            observation=exp.observation, step_type=exp.step_type)
        policy_step = self._ac_algorithm.predict(
            time_step=time_step, state=state)
        new_action, state = policy_step.action, policy_step.state
        new_action = common.to_distribution(new_action)
        total_dists = _update_total_dists(new_action, exp, total_dists)

    size = tf.cast(num_steps * batch_size, tf.float32)
    total_dists = nest_map(lambda d: d / size, total_dists)
    return total_dists
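
# A small usage sketch (illustrative only, not part of the original code) of the
# distance defined in ``_calc_change`` above: squared logit difference for
# Categorical distributions, symmetric KL divergence for continuous ones.
import tensorflow as tf
import tensorflow_probability as tfp

d1 = tfp.distributions.Normal(loc=tf.zeros([4, 2]), scale=tf.ones([4, 2]))
d2 = tfp.distributions.Normal(loc=0.1 * tf.ones([4, 2]), scale=tf.ones([4, 2]))
# Symmetric KL, reduced over the action dimension -> shape [B].
sym_kl = tf.reduce_sum(d1.kl_divergence(d2) + d2.kl_divergence(d1), axis=-1)

c1 = tfp.distributions.Categorical(logits=tf.constant([[1.0, 2.0, 0.5]]))
c2 = tfp.distributions.Categorical(logits=tf.constant([[0.5, 2.5, 0.0]]))
# Squared logit distance -> shape [B].
logit_dist = tf.reduce_sum(tf.square(c1.logits - c2.logits), axis=-1)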
def _step(self, time_step, policy_state):
    policy_state = common.reset_state_if_necessary(
        policy_state, self._initial_policy_state, time_step.is_first())
    self._actor_q.enqueue([time_step, policy_state, self._id])
    policy_step, act_dist_param = self._action_return_q.dequeue()
    action = policy_step.action
    next_time_step = make_action_time_step(self._env.step(action), action)
    # temporarily store the transition into a local queue
    self._unroll_queue.enqueue(
        [time_step, policy_step, act_dist_param, next_time_step])
    return next_time_step, policy_step.state
def _step(algorithm, env, time_step, policy_state, trans_state,
          epsilon_greedy, metrics):
    policy_state = common.reset_state_if_necessary(
        policy_state, algorithm.get_initial_predict_state(env.batch_size),
        time_step.is_first())
    transformed_time_step, trans_state = algorithm.transform_timestep(
        time_step, trans_state)
    policy_step = algorithm.predict_step(transformed_time_step, policy_state,
                                         epsilon_greedy)
    next_time_step = env.step(policy_step.output)
    for metric in metrics:
        metric(time_step.cpu())
    return next_time_step, policy_step, trans_state
def _train_step(self, time_step: TimeStep, state: SarsaState):
    not_first_step = time_step.step_type != StepType.FIRST
    prev_critics, critic_states = self._critic_networks(
        (state.prev_observation, time_step.prev_action), state.critics)
    critic_states = common.reset_state_if_necessary(
        state.critics, critic_states, not_first_step)

    action_distribution, action, actor_state, noise_state = self._get_action(
        self._actor_network, time_step, state)

    critics, _ = self._critic_networks((time_step.observation, action),
                                       critic_states)
    critic = critics.min(dim=1)[0]
    dqda = nest_utils.grad(action, critic.sum())

    def actor_loss_fn(dqda, action):
        if self._dqda_clipping:
            dqda = dqda.clamp(-self._dqda_clipping, self._dqda_clipping)
        loss = 0.5 * losses.element_wise_squared_loss(
            (dqda + action).detach(), action)
        loss = loss.sum(list(range(1, loss.ndim)))
        return loss

    actor_loss = nest_map(actor_loss_fn, dqda, action)
    actor_loss = math_ops.add_n(alf.nest.flatten(actor_loss))

    neg_entropy = ()
    if self._log_alpha is not None:
        neg_entropy = dist_utils.compute_log_probability(
            action_distribution, action)

    target_critics, target_critic_states = self._target_critic_networks(
        (time_step.observation, action), state.target_critics)

    info = SarsaInfo(
        action_distribution=action_distribution,
        actor_loss=actor_loss,
        critics=prev_critics,
        neg_entropy=neg_entropy,
        target_critics=target_critics.min(dim=1)[0])

    rl_state = SarsaState(
        noise=noise_state,
        prev_observation=time_step.observation,
        prev_step_type=time_step.step_type,
        actor=actor_state,
        critics=critic_states,
        target_critics=target_critic_states)

    return AlgStep(action, rl_state, info)
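
# A standalone sketch (illustrative only) of the surrogate actor loss used in
# ``_train_step`` above: minimizing 0.5 * ((dqda + a).detach() - a)^2 yields the
# gradient -dqda w.r.t. the action a, so the actor is pushed along the critic's
# gradient (the usual DPG-style trick).
import torch

action = torch.tensor([0.3, -0.2], requires_grad=True)
dqda = torch.tensor([1.0, -0.5])  # stand-in for dQ/da computed from the critic
loss = 0.5 * ((dqda + action).detach() - action)**2
loss.sum().backward()
assert torch.allclose(action.grad, -dqda)  # gradient of the loss is -dqda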
def _step(self, time_step, policy_state):
    policy_state = common.reset_state_if_necessary(
        policy_state, self._initial_policy_state, time_step.is_first())
    self._actor_q.enqueue([time_step, policy_state, self._id])
    policy_step, act_dist_param = self._action_return_q.dequeue()
    action = policy_step.action
    next_time_step = make_action_time_step(self._env.step(action), action)
    # temporarily store the transition into a local queue
    self._unroll_queue.enqueue(
        LearningBatch(
            time_step=time_step,
            state=policy_state if self._tfq._store_state else (),
            policy_step=policy_step,
            act_dist_param=act_dist_param,
            next_time_step=next_time_step,
            env_id=()))
    return [next_time_step, policy_step.state]
def unroll(env, algorithm, steps, epsilon_greedy=0.1):
    """Run `steps` environment steps using algorithm.predict_step()."""
    time_step = common.get_initial_time_step(env)
    policy_state = algorithm.get_initial_predict_state(env.batch_size)
    trans_state = algorithm.get_initial_transform_state(env.batch_size)
    for _ in range(steps):
        policy_state = common.reset_state_if_necessary(
            policy_state, algorithm.get_initial_predict_state(env.batch_size),
            time_step.is_first())
        transformed_time_step, trans_state = algorithm.transform_timestep(
            time_step, trans_state)
        policy_step = algorithm.predict_step(
            transformed_time_step, policy_state,
            epsilon_greedy=epsilon_greedy)
        time_step = env.step(policy_step.output)
        policy_state = policy_step.state
    return time_step
def _train_loop_body(counter, policy_state, info_ta):
    exp = tf.nest.map_structure(lambda ta: ta.read(counter), experience_ta)
    exp = nest_utils.params_to_distributions(
        exp, self.processed_experience_spec)
    policy_state = common.reset_state_if_necessary(
        policy_state, initial_train_state,
        tf.equal(exp.step_type, StepType.FIRST))
    with tf.name_scope(scope):
        policy_step = self.train_step(exp, policy_state)
    info_ta = tf.nest.map_structure(
        lambda ta, x: ta.write(counter, x), info_ta,
        nest_utils.distributions_to_params(policy_step.info))
    counter += 1
    return [counter, policy_step.state, info_ta]
def _calc_change(self, exp_array):
    """Calculate the distance between old/new action distributions.

    The distance is:

    - :math:`||logits_1 - logits_2||^2` for Categorical distribution
    - :math:`||loc_1 - loc_2||^2` for Deterministic distribution
    - :math:`KL(d1||d2) + KL(d2||d1)` for others
    """

    def _get_base_dist(dist: td.Distribution):
        """Get the base distribution of `dist`."""
        if isinstance(dist, (td.Independent, td.TransformedDistribution)):
            return _get_base_dist(dist.base_dist)
        return dist

    def _dist(d1, d2):
        d1_base = _get_base_dist(d1)
        d2_base = _get_base_dist(d2)
        if isinstance(d1_base, td.Categorical):
            dist = (d1_base.logits - d2_base.logits)**2
        elif isinstance(d1, torch.Tensor):
            dist = (d1 - d2)**2
        else:
            dist = td.kl.kl_divergence(d1, d2) + td.kl.kl_divergence(d2, d1)
        return math_ops.sum_to_leftmost(dist, 1)

    def _update_total_dists(new_action, exp, total_dists):
        old_action = dist_utils.params_to_distributions(
            exp.action_param, self._action_distribution_spec)
        valid_masks = (exp.step_type != StepType.LAST).to(torch.float32)
        return nest_map(
            lambda d1, d2, total_dist:
                (_dist(d1, d2) * valid_masks).sum() + total_dist,
            old_action, new_action, total_dists)

    num_steps, batch_size = exp_array.step_type.shape
    state = nest_map(lambda x: x[0], exp_array.state)
    # exp_array.state is no longer needed
    exp_array = exp_array._replace(state=())
    initial_state = self.get_initial_predict_state(batch_size)
    total_dists = nest_map(lambda _: torch.tensor(0.), self.action_spec)
    for t in range(num_steps):
        exp = nest_map(lambda x: x[t], exp_array)
        state = common.reset_state_if_necessary(
            state, initial_state, exp.step_type == StepType.FIRST)
        time_step = TimeStep(
            observation=exp.observation,
            step_type=exp.step_type,
            prev_action=exp.prev_action)
        policy_step = self._ac_algorithm.predict_step(
            time_step=time_step, state=state, epsilon_greedy=1.0)
        assert (
            alf.nest.is_namedtuple(policy_step.info)
            and "action_distribution" in policy_step.info._fields
        ), ("AlgStep.info from ac_algorithm.predict_step() should be "
            "a namedtuple containing `action_distribution` in order to "
            "use TracAlgorithm.")
        new_action = policy_step.info.action_distribution
        state = policy_step.state
        total_dists = _update_total_dists(new_action, exp, total_dists)

    size = num_steps * batch_size
    total_dists = nest_map(lambda d: torch.sqrt(d / size), total_dists)
    return total_dists
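
# A small sketch (assumptions only, not part of the original module) of the
# per-element distance used above, written directly with torch.distributions.
import torch
import torch.distributions as td

d1 = td.Normal(loc=torch.zeros(4, 2), scale=torch.ones(4, 2))
d2 = td.Normal(loc=0.1 * torch.ones(4, 2), scale=torch.ones(4, 2))
# Symmetric KL, summed over all but the batch dimension -> shape [B].
sym_kl = (td.kl.kl_divergence(d1, d2) + td.kl.kl_divergence(d2, d1)).sum(-1)

# Categorical distances use the squared difference of logits; an Independent
# wrapper (as handled by ``_get_base_dist``) is unwrapped via ``base_dist``.
c1 = td.Independent(td.Categorical(logits=torch.randn(4, 3, 5)), 1)
c2 = td.Independent(td.Categorical(logits=torch.randn(4, 3, 5)), 1)
logit_dist = ((c1.base_dist.logits - c2.base_dist.logits)**2).sum((-2, -1))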
def unroll(self, unroll_length):
    r"""Unroll ``unroll_length`` steps using the current policy.

    Because ``self._env`` is a batched environment, the total number of
    environment steps is ``self._env.batch_size * unroll_length``.

    Args:
        unroll_length (int): number of steps to unroll.
    Returns:
        Experience: The stacked experience with shape :math:`[T, B, \ldots]`
        for each of its members.
    """
    if self._current_time_step is None:
        self._current_time_step = common.get_initial_time_step(self._env)
    if self._current_policy_state is None:
        self._current_policy_state = self.get_initial_rollout_state(
            self._env.batch_size)
    if self._current_transform_state is None:
        self._current_transform_state = self.get_initial_transform_state(
            self._env.batch_size)

    time_step = self._current_time_step
    policy_state = self._current_policy_state
    trans_state = self._current_transform_state

    experience_list = []
    initial_state = self.get_initial_rollout_state(self._env.batch_size)

    env_step_time = 0.
    store_exp_time = 0.
    for _ in range(unroll_length):
        policy_state = common.reset_state_if_necessary(
            policy_state, initial_state, time_step.is_first())
        transformed_time_step, trans_state = self.transform_timestep(
            time_step, trans_state)
        # save the untransformed time step in case sub-algorithms need to
        # store it in replay buffers
        transformed_time_step = transformed_time_step._replace(
            untransformed=time_step)
        policy_step = self.rollout_step(transformed_time_step, policy_state)
        # release the reference to ``time_step``
        transformed_time_step = transformed_time_step._replace(
            untransformed=())
        action = common.detach(policy_step.output)

        t0 = time.time()
        next_time_step = self._env.step(action)
        env_step_time += time.time() - t0

        self.observe_for_metrics(time_step.cpu())

        if self._exp_replayer_type == "one_time":
            exp = make_experience(transformed_time_step, policy_step,
                                  policy_state)
        else:
            exp = make_experience(time_step.cpu(), policy_step, policy_state)

        t0 = time.time()
        self.observe_for_replay(exp)
        store_exp_time += time.time() - t0

        exp_for_training = Experience(
            action=action,
            reward=transformed_time_step.reward,
            discount=transformed_time_step.discount,
            step_type=transformed_time_step.step_type,
            state=policy_state,
            prev_action=transformed_time_step.prev_action,
            observation=transformed_time_step.observation,
            rollout_info=dist_utils.distributions_to_params(
                policy_step.info),
            env_id=transformed_time_step.env_id)

        experience_list.append(exp_for_training)

        time_step = next_time_step
        policy_state = policy_step.state

    alf.summary.scalar("time/unroll_env_step", env_step_time)
    alf.summary.scalar("time/unroll_store_exp", store_exp_time)

    experience = alf.nest.utils.stack_nests(experience_list)
    experience = experience._replace(
        rollout_info=dist_utils.params_to_distributions(
            experience.rollout_info, self._rollout_info_spec))

    self._current_time_step = time_step
    # Need to detach so that the graph from this unroll is disconnected from
    # the next unroll. Otherwise backward() will report an error for
    # on-policy training after the next unroll.
    self._current_policy_state = common.detach(policy_state)
    self._current_transform_state = common.detach(trans_state)

    return experience
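
# A minimal sketch of the stacking performed at the end of ``unroll`` (an
# assumption about ``alf.nest.utils.stack_nests``, shown here only for plain
# tensors): a list of T per-step values, each with a leading batch dimension B,
# becomes a single tensor of shape [T, B, ...].
import torch


def stack_nests_sketch(step_list):
    # The real utility maps this stacking over arbitrary nests; for a flat
    # list of tensors it is simply a stack along a new leading time axis.
    return torch.stack(step_list, dim=0)


steps = [torch.zeros(8, 3) for _ in range(5)]  # T=5 steps, batch size B=8
stacked = stack_nests_sketch(steps)            # shape [5, 8, 3]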
def _calc_change(self, exp_array):
    """Calculate the distance between old/new action distributions.

    The squared distance is:
    ||logits_1 - logits_2||^2 for Categorical distribution
    ||loc_1 - loc_2||^2 for Deterministic distribution
    KL(d1||d2) + KL(d2||d1) for others
    """

    def _dist(d1, d2):
        if isinstance(d1, tfp.distributions.Categorical):
            dist = tf.square(d1.logits - d2.logits)
        elif isinstance(d1, tf.Tensor):
            dist = tf.square(d1 - d2)
        else:
            if isinstance(d1, SquashToSpecNormal):
                # TODO: `SquashToSpecNormal.kl_divergence` checks that the two
                # distributions have the same action mean and magnitude, but
                # this check fails in graph mode.
                d1 = d1.input_distribution
                d2 = d2.input_distribution
            dist = d1.kl_divergence(d2) + d2.kl_divergence(d1)
        if len(dist.shape) > 1:
            # reduce to shape [B]
            reduce_dims = list(range(1, len(dist.shape)))
            dist = tf.reduce_sum(dist, axis=reduce_dims)
        return dist

    def _update_total_dists(new_action, exp, total_dists):
        old_action = nest_utils.params_to_distributions(
            exp.action_param, self._action_distribution_spec)
        dists = nest_map(_dist, old_action, new_action)
        valid_masks = tf.cast(
            tf.not_equal(exp.step_type, StepType.LAST), tf.float32)
        dists = nest_map(lambda kl: tf.reduce_sum(kl * valid_masks), dists)
        return nest_map(lambda x, y: x + y, total_dists, dists)

    num_steps = exp_array.step_type.size()
    # element_shape for `TensorArray` can be (None, ...)
    batch_size = tf.shape(exp_array.step_type.read(0))[0]
    state = tf.nest.map_structure(lambda x: x.read(0), exp_array.state)
    # exp_array.state is no longer needed
    exp_array = exp_array._replace(state=())
    initial_state = common.zero_tensor_from_nested_spec(
        self.predict_state_spec, batch_size)
    total_dists = nest_map(lambda _: tf.zeros(()), self.action_spec)
    for t in tf.range(num_steps):
        exp = tf.nest.map_structure(lambda x: x.read(t), exp_array)
        state = common.reset_state_if_necessary(
            state, initial_state, exp.step_type == StepType.FIRST)
        time_step = ActionTimeStep(
            observation=exp.observation, step_type=exp.step_type)
        policy_step = self._ac_algorithm.predict(
            time_step=time_step, state=state, epsilon_greedy=1.0)
        assert (
            common.is_namedtuple(policy_step.info)
            and "action_distribution" in policy_step.info._fields
        ), ("PolicyStep.info from ac_algorithm.predict() should be "
            "a namedtuple containing `action_distribution` in order to "
            "use TracAlgorithm.")
        new_action = policy_step.info.action_distribution
        state = policy_step.state
        total_dists = _update_total_dists(new_action, exp, total_dists)

    size = tf.cast(num_steps * batch_size, tf.float32)
    total_dists = nest_map(lambda d: tf.sqrt(d / size), total_dists)
    return total_dists