def train_step(self, exp: Experience, state):
    """Train one step: convert `exp` to an `ActionTimeStep` and delegate to `rollout()`."""
    time_step = ActionTimeStep(
        step_type=exp.step_type,
        reward=exp.reward,
        discount=exp.discount,
        observation=exp.observation,
        prev_action=exp.prev_action)
    return self.rollout(time_step, state)
def rollout(self, time_step: ActionTimeStep, state: AgentState):
    """Rollout for one step."""
    new_state = AgentState()
    info = AgentInfo()
    observation = self._encode(time_step)
    if self._icm is not None:
        icm_step = self._icm.train_step(
            (observation, time_step.prev_action), state=state.icm)
        info = info._replace(icm=icm_step.info)
        new_state = new_state._replace(icm=icm_step.state)
    rl_step = self._rl_algorithm.rollout(
        time_step._replace(observation=observation), state.rl)
    new_state = new_state._replace(rl=rl_step.state)
    info = info._replace(rl=rl_step.info)
    # TODO: avoid computing this when rollout() is used for off-policy training
    if self._entropy_target_algorithm:
        et_step = self._entropy_target_algorithm.train_step(
            rl_step.action, step_type=time_step.step_type)
        info = info._replace(entropy_target=et_step.info)
    return PolicyStep(action=rl_step.action, state=new_state, info=info)
def _calc_change(self, exp_array):
    """Calculate the distance between old/new action distributions.

    The distance is:
    - ||logits_1 - logits_2||^2 for Categorical distributions
    - KL(d1||d2) + KL(d2||d1) for other distributions
    """

    def _dist(d1, d2):
        if isinstance(d1, tfp.distributions.Categorical):
            return tf.reduce_sum(tf.square(d1.logits - d2.logits), axis=-1)
        elif isinstance(d1, tfp.distributions.Deterministic):
            return tf.reduce_sum(tf.square(d1.loc - d2.loc), axis=-1)
        else:
            if isinstance(d1, SquashToSpecNormal):
                # TODO: `SquashToSpecNormal.kl_divergence` checks that the two
                # distributions have the same action mean and magnitude, but
                # this check fails in graph mode
                d1 = d1.input_distribution
                d2 = d2.input_distribution
            return tf.reduce_sum(
                d1.kl_divergence(d2) + d2.kl_divergence(d1), axis=-1)

    def _update_total_dists(new_action, exp, total_dists):
        old_action = nested_distributions_from_specs(
            common.to_distribution_spec(self.action_distribution_spec),
            exp.action_param)
        dists = nest_map(_dist, old_action, new_action)
        valid_masks = tf.cast(
            tf.not_equal(exp.step_type, StepType.LAST), tf.float32)
        dists = nest_map(lambda kl: tf.reduce_sum(kl * valid_masks), dists)
        return nest_map(lambda x, y: x + y, total_dists, dists)

    num_steps = exp_array.step_type.size()
    # element_shape for `TensorArray` can be (None, ...)
    batch_size = tf.shape(exp_array.step_type.read(0))[0]
    state = tf.nest.map_structure(lambda x: x.read(0), exp_array.state)
    # exp_array.state is no longer needed
    exp_array = exp_array._replace(state=())
    initial_state = common.zero_tensor_from_nested_spec(
        self.predict_state_spec, batch_size)
    total_dists = nest_map(lambda _: tf.zeros(()), self.action_spec)
    for t in tf.range(num_steps):
        exp = tf.nest.map_structure(lambda x: x.read(t), exp_array)
        state = common.reset_state_if_necessary(
            state, initial_state, exp.step_type == StepType.FIRST)
        time_step = ActionTimeStep(
            observation=exp.observation, step_type=exp.step_type)
        policy_step = self._ac_algorithm.predict(
            time_step=time_step, state=state)
        new_action, state = policy_step.action, policy_step.state
        new_action = common.to_distribution(new_action)
        total_dists = _update_total_dists(new_action, exp, total_dists)
    size = tf.cast(num_steps * batch_size, tf.float32)
    total_dists = nest_map(lambda d: d / size, total_dists)
    return total_dists
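# Minimal, self-contained sketch (illustration only, not part of the algorithm
# above) of the two distance measures used by `_calc_change`: the squared
# logits difference for Categorical distributions, and the symmetric KL
# divergence KL(d1||d2) + KL(d2||d1) for continuous distributions. The
# distribution parameters below are made up for illustration.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Categorical case: ||logits_1 - logits_2||^2, summed over the last axis.
c1 = tfd.Categorical(logits=[[0.1, 0.2, 0.7]])
c2 = tfd.Categorical(logits=[[0.3, 0.3, 0.4]])
categorical_change = tf.reduce_sum(tf.square(c1.logits - c2.logits), axis=-1)

# Continuous case: symmetric KL divergence between the two distributions.
n1 = tfd.Normal(loc=[0.0], scale=[1.0])
n2 = tfd.Normal(loc=[0.5], scale=[1.5])
symmetric_kl = tf.reduce_sum(
    n1.kl_divergence(n2) + n2.kl_divergence(n1), axis=-1)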
def greedy_predict(self, time_step: ActionTimeStep, state=None, eps=0.1):
    """Greedy predict for one step."""
    observation = self._encode(time_step)
    new_state = AgentState()
    rl_step = self._rl_algorithm.greedy_predict(
        time_step._replace(observation=observation), state.rl)
    new_state = new_state._replace(rl=rl_step.state)
    return PolicyStep(action=rl_step.action, state=new_state, info=())
def predict(self, time_step: ActionTimeStep, state: AgentState):
    """Predict for one step."""
    observation = self._encode(time_step)
    new_state = AgentState()
    rl_step = self._rl_algorithm.predict(
        time_step._replace(observation=observation), state.rl)
    new_state = new_state._replace(rl=rl_step.state)
    return PolicyStep(action=rl_step.action, state=new_state, info=())
def rollout(self, time_step: ActionTimeStep, state):
    """Train one step."""
    mbp_step = self._mbp.train_step(
        inputs=(time_step.observation, time_step.prev_action),
        state=state.mbp_state)
    mba_step = self._mba.rollout(
        time_step=time_step._replace(observation=mbp_step.outputs),
        state=state.mba_state)
    return PolicyStep(
        action=mba_step.action,
        state=MerlinState(mbp_state=mbp_step.state, mba_state=mba_step.state),
        info=MerlinInfo(mbp_info=mbp_step.info, mba_info=mba_step.info))