Example #1
    def train_step(self, exp: Experience, state):
        """Train one step."""
        # Repack the replayed Experience into an ActionTimeStep and reuse
        # the rollout() code path for training.
        time_step = ActionTimeStep(step_type=exp.step_type,
                                   reward=exp.reward,
                                   discount=exp.discount,
                                   observation=exp.observation,
                                   prev_action=exp.prev_action)
        return self.rollout(time_step, state)
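
Example #1 above simply repacks the replayed `Experience` into an `ActionTimeStep` and hands it to `rollout`, so training reuses the same code path as data collection. Below is a minimal, self-contained sketch of that field mapping; the `Experience` and `ActionTimeStep` namedtuples here are simplified stand-ins, not the framework's real data structures (which carry additional fields).

    from collections import namedtuple

    # Simplified stand-ins for the framework's Experience/ActionTimeStep.
    Experience = namedtuple(
        'Experience',
        ['step_type', 'reward', 'discount', 'observation', 'prev_action'])
    ActionTimeStep = namedtuple(
        'ActionTimeStep',
        ['step_type', 'reward', 'discount', 'observation', 'prev_action'])

    def experience_to_time_step(exp):
        # One-to-one copy of the shared fields, as in train_step above.
        return ActionTimeStep(step_type=exp.step_type,
                              reward=exp.reward,
                              discount=exp.discount,
                              observation=exp.observation,
                              prev_action=exp.prev_action)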
Example #2
    def rollout(self, time_step: ActionTimeStep, state: AgentState):
        """Rollout for one step."""
        new_state = AgentState()
        info = AgentInfo()
        observation = self._encode(time_step)
        if self._icm is not None:
            icm_step = self._icm.train_step(
                (observation, time_step.prev_action), state=state.icm)
            info = info._replace(icm=icm_step.info)
            new_state = new_state._replace(icm=icm_step.state)

        rl_step = self._rl_algorithm.rollout(
            time_step._replace(observation=observation), state.rl)

        new_state = new_state._replace(rl=rl_step.state)
        info = info._replace(rl=rl_step.info)

        # TODO: avoid computing this during rollout when training off-policy.
        if self._entropy_target_algorithm:
            et_step = self._entropy_target_algorithm.train_step(
                rl_step.action, step_type=time_step.step_type)
            info = info._replace(entropy_target=et_step.info)

        return PolicyStep(action=rl_step.action, state=new_state, info=info)
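
A pattern worth noting in Example #2 is how each sub-algorithm's result is threaded into a single `AgentState`/`AgentInfo` via `namedtuple._replace`, so optional components (ICM, entropy target) only fill in their own slot. Below is a small sketch of that accumulation pattern, assuming the two structures are namedtuples whose fields default to empty tuples (hypothetical definitions, not the framework's own).

    from collections import namedtuple

    # Hypothetical namedtuples with empty defaults (Python 3.7+ `defaults`).
    AgentState = namedtuple('AgentState', ['rl', 'icm'], defaults=((), ()))
    AgentInfo = namedtuple('AgentInfo', ['rl', 'icm', 'entropy_target'],
                           defaults=((), (), ()))

    new_state = AgentState()   # every field starts empty
    info = AgentInfo()
    # Each sub-algorithm step overwrites only its own slot.
    new_state = new_state._replace(icm='icm_state')
    new_state = new_state._replace(rl='rl_state')
    info = info._replace(rl='rl_info')
    assert info.entropy_target == ()   # untouched slots stay empty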
Example #3
    def _calc_change(self, exp_array):
        """Calculate the distance between old/new action distributions.

        The distance is:
        ||logits_1 - logits_2||^2 for Categorical distribution
        KL(d1||d2) + KL(d2||d1) for others
        """

        def _dist(d1, d2):
            if isinstance(d1, tfp.distributions.Categorical):
                return tf.reduce_sum(tf.square(d1.logits - d2.logits), axis=-1)
            elif isinstance(d1, tfp.distributions.Deterministic):
                return tf.reduce_sum(tf.square(d1.loc - d2.loc), axis=-1)
            else:
                if isinstance(d1, SquashToSpecNormal):
                    # TODO: `SquashToSpecNormal.kl_divergence` checks that the
                    # two distributions have the same action mean and magnitude,
                    # but this check fails in graph mode.
                    d1 = d1.input_distribution
                    d2 = d2.input_distribution
                return tf.reduce_sum(
                    d1.kl_divergence(d2) + d2.kl_divergence(d1), axis=-1)

        def _update_total_dists(new_action, exp, total_dists):
            old_action = nested_distributions_from_specs(
                common.to_distribution_spec(self.action_distribution_spec),
                exp.action_param)
            dists = nest_map(_dist, old_action, new_action)
            valid_masks = tf.cast(
                tf.not_equal(exp.step_type, StepType.LAST), tf.float32)
            dists = nest_map(lambda kl: tf.reduce_sum(kl * valid_masks), dists)
            return nest_map(lambda x, y: x + y, total_dists, dists)

        num_steps = exp_array.step_type.size()
        # element_shape for `TensorArray` can be (None, ...)
        batch_size = tf.shape(exp_array.step_type.read(0))[0]
        state = tf.nest.map_structure(lambda x: x.read(0), exp_array.state)
        # exp_array.state is no longer needed
        exp_array = exp_array._replace(state=())
        initial_state = common.zero_tensor_from_nested_spec(
            self.predict_state_spec, batch_size)
        total_dists = nest_map(lambda _: tf.zeros(()), self.action_spec)
        for t in tf.range(num_steps):
            exp = tf.nest.map_structure(lambda x: x.read(t), exp_array)
            state = common.reset_state_if_necessary(
                state, initial_state, exp.step_type == StepType.FIRST)
            time_step = ActionTimeStep(
                observation=exp.observation, step_type=exp.step_type)
            policy_step = self._ac_algorithm.predict(
                time_step=time_step, state=state)
            new_action, state = policy_step.action, policy_step.state
            new_action = common.to_distribution(new_action)
            total_dists = _update_total_dists(new_action, exp, total_dists)

        size = tf.cast(num_steps * batch_size, tf.float32)
        total_dists = nest_map(lambda d: d / size, total_dists)
        return total_dists
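
The `_dist` helper implements exactly the two cases named in the docstring: squared logits distance for `Categorical`, symmetrized KL divergence otherwise (with `Deterministic` handled via its `loc`). Here is a standalone sketch of those two measures on concrete `tfp` distributions, using only TensorFlow and TensorFlow Probability:

    import tensorflow as tf
    import tensorflow_probability as tfp

    tfd = tfp.distributions

    # Categorical case: squared Euclidean distance between the logits.
    c1 = tfd.Categorical(logits=[0.1, 0.2, 0.7])
    c2 = tfd.Categorical(logits=[0.3, 0.1, 0.6])
    cat_dist = tf.reduce_sum(tf.square(c1.logits - c2.logits), axis=-1)

    # General case: symmetrized KL divergence KL(d1||d2) + KL(d2||d1).
    n1 = tfd.Normal(loc=[0.0, 1.0], scale=[1.0, 1.0])
    n2 = tfd.Normal(loc=[0.5, 1.0], scale=[1.0, 2.0])
    sym_kl = tf.reduce_sum(n1.kl_divergence(n2) + n2.kl_divergence(n1), axis=-1)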
Example #4
    def greedy_predict(self, time_step: ActionTimeStep, state=None, eps=0.1):
        observation = self._encode(time_step)

        new_state = AgentState()

        rl_step = self._rl_algorithm.greedy_predict(
            time_step._replace(observation=observation), state.rl)
        new_state = new_state._replace(rl=rl_step.state)

        return PolicyStep(action=rl_step.action, state=new_state, info=())
Example #5
    def predict(self, time_step: ActionTimeStep, state: AgentState):
        """Predict for one step."""
        observation = self._encode(time_step)

        new_state = AgentState()

        rl_step = self._rl_algorithm.predict(
            time_step._replace(observation=observation), state.rl)
        new_state = new_state._replace(rl=rl_step.state)

        return PolicyStep(action=rl_step.action, state=new_state, info=())
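
Examples #4 and #5 follow the same encode-then-delegate shape; the difference lies in how the wrapped RL algorithm turns its action distribution into an action (sampling for `predict`, acting greedily for `greedy_predict`, with `eps` presumably leaving a small amount of exploration). As a rough, hypothetical illustration of that sampling-vs-greedy distinction on a plain `tfp` distribution (not the framework's actual implementation):

    import tensorflow as tf
    import tensorflow_probability as tfp

    dist = tfp.distributions.Categorical(logits=[1.0, 2.0, 0.5])
    sampled = dist.sample()    # stochastic action, as a predict-style step might use
    greedy = dist.mode()       # greedy action
    eps = 0.1
    # eps-greedy: keep the greedy action with probability 1 - eps.
    action = tf.where(tf.random.uniform(()) < eps, sampled, greedy)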
Example #6
    def rollout(self, time_step: ActionTimeStep, state):
        """Train one step."""
        mbp_step = self._mbp.train_step(inputs=(time_step.observation,
                                                time_step.prev_action),
                                        state=state.mbp_state)
        mba_step = self._mba.rollout(
            time_step=time_step._replace(observation=mbp_step.outputs),
            state=state.mba_state)

        return PolicyStep(action=mba_step.action,
                          state=MerlinState(mbp_state=mbp_step.state,
                                            mba_state=mba_step.state),
                          info=MerlinInfo(mbp_info=mbp_step.info,
                                          mba_info=mba_step.info))
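
In Example #6 the data flow is the interesting part: the memory-based predictor (`_mbp`) consumes `(observation, prev_action)` and its `outputs` replace the raw observation fed to the memory-based actor (`_mba`); the two sub-states and infos are then packed into `MerlinState`/`MerlinInfo`. Below is a hypothetical stub sketch of that composition; the namedtuples and the `mbp`/`mba` callables are stand-ins for illustration, not the framework's definitions.

    from collections import namedtuple

    # Stand-in structures for the sketch only.
    AlgStep = namedtuple('AlgStep', ['outputs', 'action', 'state', 'info'],
                         defaults=((), (), (), ()))
    MerlinState = namedtuple('MerlinState', ['mbp_state', 'mba_state'])
    MerlinInfo = namedtuple('MerlinInfo', ['mbp_info', 'mba_info'])
    PolicyStep = namedtuple('PolicyStep', ['action', 'state', 'info'])

    def merlin_rollout(observation, prev_action, state, mbp, mba):
        # 1) The predictor encodes (observation, prev_action) into a latent.
        mbp_step = mbp(inputs=(observation, prev_action), state=state.mbp_state)
        # 2) The actor acts on that latent instead of the raw observation.
        mba_step = mba(observation=mbp_step.outputs, state=state.mba_state)
        # 3) Both sub-states/infos are packed into the combined structures.
        return PolicyStep(action=mba_step.action,
                          state=MerlinState(mbp_state=mbp_step.state,
                                            mba_state=mba_step.state),
                          info=MerlinInfo(mbp_info=mbp_step.info,
                                          mba_info=mba_step.info))

    # Stub usage: each sub-algorithm is a callable returning an AlgStep.
    step = merlin_rollout(
        observation='obs', prev_action='a',
        state=MerlinState(mbp_state=(), mba_state=()),
        mbp=lambda inputs, state: AlgStep(outputs='latent', state='mbp_s'),
        mba=lambda observation, state: AlgStep(action='act', state='mba_s'))
    assert step.state == MerlinState(mbp_state='mbp_s', mba_state='mba_s')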