Example #1
    def train_step(self, exp: Experience, state, trainable=True):
        """This function trains the discriminator or generates intrinsic rewards.

        If ``trainable=True``, then it only computes and returns the prediction
        loss; otherwise it only generates rewards with no grad.
        """
        # Discriminator training from its own replay buffer, or
        # discriminator computing intrinsic rewards for training lower_rl.
        untrans_observation, prev_skill, switch_skill, steps = exp.observation

        observation = self._observation_transformer(untrans_observation)
        loss = self._predict_skill_loss(observation, exp.prev_action,
                                        prev_skill, steps, state)

        first_observation = self._update_state_if_necessary(
            switch_skill, observation, state.first_observation)
        subtrajectory = self._clear_subtrajectory_if_necessary(
            state.subtrajectory, switch_skill)
        new_state = DiscriminatorState(first_observation=first_observation,
                                       untrans_observation=untrans_observation,
                                       subtrajectory=subtrajectory)

        valid_masks = (exp.step_type != StepType.FIRST)
        if self._sparse_reward:
            # Only give intrinsic rewards at the last step of the skill
            valid_masks &= switch_skill
        loss *= valid_masks.to(torch.float32)

        if trainable:
            info = LossInfo(loss=loss, extra=dict(discriminator_loss=loss))
            return AlgStep(state=new_state, info=info)
        else:
            intrinsic_reward = -loss.detach() / self._skill_dim
            return AlgStep(state=common.detach(new_state),
                           info=intrinsic_reward)
Example #2
    def predict_step(self,
                     inputs=None,
                     noise=None,
                     batch_size=None,
                     training=False,
                     state=None):
        """Generate outputs given inputs.

        Args:
            inputs (nested Tensor): if None, the output is generated only from
                noise.
            noise (Tensor): input to the generator.
            batch_size (int): batch size. Must be provided if ``inputs`` is
                None; it is ignored if ``inputs`` is not None.
            training (bool): whether to train the generator.
            state: not used

        Returns:
            AlgStep: the output has shape (batch_size, output_dim)
        """
        outputs, _ = self._predict(inputs=inputs,
                                   noise=noise,
                                   batch_size=batch_size,
                                   training=training)
        return AlgStep(output=outputs, state=(), info=())
Example #3
    def predict_step(self, time_step: TimeStep, state, epsilon_greedy):
        switch_action = self._should_switch_action(time_step, state)

        @torch.no_grad()
        def _generate_new_action(time_step, state):
            repr_state = ()
            if self._repr_learner is not None:
                repr_step = self._repr_learner.predict_step(
                    time_step, state.repr)
                time_step = time_step._replace(observation=repr_step.output)
                repr_state = repr_step.state

            rl_step = self._rl.predict_step(time_step, state.rl,
                                            epsilon_greedy)
            steps, action = rl_step.output
            return ActionRepeatState(
                action=action,
                steps=steps + 1,  # [0, K-1] -> [1, K]
                rl=rl_step.state,
                repr=repr_state)

        new_state = conditional_update(
            target=state,
            cond=switch_action,
            func=_generate_new_action,
            time_step=time_step,
            state=state)
        new_state = new_state._replace(steps=new_state.steps - 1)

        return AlgStep(
            output=new_state.action,
            state=new_state,
            # plot steps and action when rendering video
            info=dict(action=(new_state.action, new_state.steps)))
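
Examples #3 and #8 both rely on a ``conditional_update`` helper that replaces the state only for the batch entries whose condition is True. Below is a minimal illustrative sketch of that pattern, assuming ``cond`` is a boolean batch mask and the state is a flat dict of tensors; the library's helper handles arbitrary nests and may evaluate ``func`` more selectively, so the names and behavior here are a simplification, not the actual implementation.

import torch


def conditional_update_sketch(target, cond, func, **kwargs):
    """Return a copy of ``target`` where batch entries with ``cond=True``
    are replaced by the corresponding entries of ``func(**kwargs)``.

    ``target`` and the result of ``func`` are dicts of tensors whose first
    dimension is the batch dimension; ``cond`` is a bool tensor of shape [B].
    """
    candidate = func(**kwargs)
    updated = {}
    for key, old in target.items():
        new = candidate[key]
        # Broadcast cond from [B] to the value's shape.
        mask = cond.reshape(cond.shape[0], *([1] * (old.ndim - 1)))
        updated[key] = torch.where(mask, new, old)
    return updated


state = {"steps": torch.tensor([3, 7, 5]), "action": torch.zeros(3, 2)}
switch = torch.tensor([True, False, True])
new_state = conditional_update_sketch(
    state, switch,
    lambda: {"steps": torch.zeros(3, dtype=torch.int64),
             "action": torch.ones(3, 2)})
print(new_state["steps"])  # tensor([0, 7, 0])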
Example #4
    def rollout_step(self, time_step: TimeStep, state: AgentState):
        """Rollout for one step."""
        new_state = AgentState()
        info = AgentInfo()

        time_step = transform_nest(time_step, "observation",
                                   self._observation_transformer)

        subtrajectory = self._skill_generator.update_disc_subtrajectory(
            time_step, state.skill_generator)

        skill_step = self._skill_generator.rollout_step(
            time_step, state.skill_generator)
        new_state = new_state._replace(skill_generator=skill_step.state)
        info = info._replace(skill_generator=skill_step.info)

        observation = self._make_low_level_observation(
            subtrajectory, skill_step.output, skill_step.info.switch_skill,
            skill_step.state.steps,
            skill_step.state.discriminator.first_observation)

        rl_step = self._rl_algorithm.rollout_step(
            time_step._replace(observation=observation), state.rl)
        new_state = new_state._replace(rl=rl_step.state)
        info = info._replace(rl=rl_step.info)

        skill_discount = ((
            (skill_step.state.steps == 1)
            & (time_step.step_type != StepType.LAST)).to(torch.float32) *
                          (1 - self._skill_boundary_discount))
        info = info._replace(skill_discount=1 - skill_discount)

        return AlgStep(output=rl_step.output, state=new_state, info=info)
Example #5
    def train_step(self, exp: Experience, state):
        def _hook(grad, name):
            alf.summary.scalar("MCTS_state_grad_norm/" + name, grad.norm())

        model_output = self._model.initial_inference(exp.observation)
        if alf.summary.should_record_summaries():
            model_output.state.register_hook(partial(_hook, name="s0"))
        model_output_spec = dist_utils.extract_spec(model_output)
        model_outputs = [dist_utils.distributions_to_params(model_output)]
        info = exp.rollout_info

        for i in range(self._num_unroll_steps):
            model_output = self._model.recurrent_inference(
                model_output.state, info.action[:, i, ...])
            if alf.summary.should_record_summaries():
                model_output.state.register_hook(
                    partial(_hook, name="s" + str(i + 1)))
            model_output = model_output._replace(state=scale_gradient(
                model_output.state, self._recurrent_gradient_scaling_factor))
            model_outputs.append(
                dist_utils.distributions_to_params(model_output))

        model_outputs = alf.nest.utils.stack_nests(model_outputs, dim=1)
        model_outputs = dist_utils.params_to_distributions(
            model_outputs, model_output_spec)
        return AlgStep(info=self._model.calc_loss(model_outputs, info.target))
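
Example #5 scales the gradient flowing through the unrolled model state with ``scale_gradient``, the trick used in MuZero-style training. A minimal sketch of such a function follows; the exact behavior of the library's version is assumed here.

import torch


def scale_gradient(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Forward pass returns x unchanged; backward pass multiplies the
    # incoming gradient by ``scale``.
    return x * scale + x.detach() * (1.0 - scale)


x = torch.ones(2, requires_grad=True)
scale_gradient(x, 0.5).sum().backward()
print(x.grad)  # tensor([0.5000, 0.5000])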
Example #6
    def train_step(self, exp: Experience, state):
        new_state = AgentState()
        info = AgentInfo()

        skill_generator_info = exp.rollout_info.skill_generator

        subtrajectory = self._skill_generator.update_disc_subtrajectory(
            exp, state.skill_generator)
        skill_step = self._skill_generator.train_step(
            exp._replace(rollout_info=skill_generator_info),
            state.skill_generator)
        info = info._replace(skill_generator=skill_step.info)
        new_state = new_state._replace(skill_generator=skill_step.state)

        exp = transform_nest(exp, "observation", self._observation_transformer)

        observation = self._make_low_level_observation(
            subtrajectory, skill_step.output,
            skill_generator_info.switch_skill, skill_generator_info.steps,
            skill_step.state.discriminator.first_observation)

        rl_step = self._rl_algorithm.train_step(
            exp._replace(observation=observation,
                         rollout_info=exp.rollout_info.rl), state.rl)

        new_state = new_state._replace(rl=rl_step.state)
        info = info._replace(rl=rl_step.info)

        return AlgStep(output=rl_step.output, state=new_state, info=info)
Example #7
    def predict_step(self, time_step: TimeStep, state: AgentState,
                     epsilon_greedy):
        """Predict for one step."""
        new_state = AgentState()

        time_step = transform_nest(time_step, "observation",
                                   self._observation_transformer)

        subtrajectory = self._skill_generator.update_disc_subtrajectory(
            time_step, state.skill_generator)

        skill_step = self._skill_generator.predict_step(
            time_step, state.skill_generator, epsilon_greedy)
        new_state = new_state._replace(skill_generator=skill_step.state)

        observation = self._make_low_level_observation(
            subtrajectory, skill_step.output, skill_step.info.switch_skill,
            skill_step.state.steps,
            skill_step.state.discriminator.first_observation)

        rl_step = self._rl_algorithm.predict_step(
            time_step._replace(observation=observation), state.rl,
            epsilon_greedy)
        new_state = new_state._replace(rl=rl_step.state)

        return AlgStep(output=rl_step.output, state=new_state)
Example #8
    def predict_step(self, time_step: TimeStep, state, epsilon_greedy):
        """This function does one thing, i.e., every ``self._num_steps_per_skill``
        it calls ``self._rl`` to generate new skills.
        """
        switch_skill = self._should_switch_skills(time_step, state)
        discriminator_step = self._discriminator_predict_step(
            time_step, state, switch_skill)

        def _generate_new_skills(time_step, state):
            rl_step = self._rl.predict_step(time_step, state.rl,
                                            epsilon_greedy)
            return SkillGeneratorState(skill=rl_step.output,
                                       steps=torch.zeros_like(state.steps),
                                       rl=rl_step.state)

        new_state = conditional_update(target=state,
                                       cond=switch_skill,
                                       func=_generate_new_skills,
                                       time_step=time_step,
                                       state=state)

        new_state = new_state._replace(steps=new_state.steps + 1,
                                       discriminator=discriminator_step.state)

        return AlgStep(output=new_state.skill,
                       state=new_state,
                       info=SkillGeneratorInfo(switch_skill=switch_skill))
Example #9
    def train_step(self, time_step: TimeStep, state: DynamicsState):
        """
        Args:
            time_step (TimeStep): input data for dynamics learning
            state (DynamicsState): state for dynamics learning (previous observation)
        Returns:
            AlgStep:
                output: empty tuple ()
                state (DynamicsState): state for training
                info (DynamicsInfo):
        """
        feature = time_step.observation
        dynamics_step = self.predict_step(time_step, state)
        forward_pred = dynamics_step.output
        forward_loss = (feature - forward_pred)**2
        forward_loss = 0.5 * forward_loss.mean(
            list(range(1, forward_loss.ndim)))

        # we mask out FIRST as its state is invalid
        valid_masks = (time_step.step_type != StepType.FIRST).to(torch.float32)
        forward_loss = forward_loss * valid_masks

        info = DynamicsInfo(loss=LossInfo(
            loss=forward_loss, extra=dict(forward_loss=forward_loss)))

        state = state._replace(feature=feature)

        return AlgStep(output=(), state=state, info=info)
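
Several of these snippets (Examples #1, #9, #20, and #30) mask out losses at ``StepType.FIRST`` because the state carried into the first step of an episode is invalid. A small self-contained illustration of that masking pattern; the numeric value of ``StepType.FIRST`` is assumed to be 0 here.

import torch

FIRST = 0  # assumed value of StepType.FIRST

step_type = torch.tensor([FIRST, 1, 1, FIRST, 2])
forward_loss = torch.tensor([0.9, 0.4, 0.3, 1.2, 0.5])

# Zero the loss wherever the step is the first step of an episode.
valid_masks = (step_type != FIRST).to(torch.float32)
masked_loss = forward_loss * valid_masks
print(masked_loss)  # tensor([0.0000, 0.4000, 0.3000, 0.0000, 0.5000])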
Example #10
    def predict_step(self, time_step: TimeStep, state, epsilon_greedy=1.):
        action, state = self._actor_network(time_step.observation,
                                            state=state.actor.actor)
        empty_state = nest.map_structure(lambda x: (), self.train_state_spec)

        def _sample(a, ou):
            if epsilon_greedy == 0:
                return a
            elif epsilon_greedy >= 1.0:
                return a + ou()
            else:
                ind_explore = torch.where(
                    torch.rand(a.shape[:1]) < epsilon_greedy)
                noisy_a = a + ou()
                a[ind_explore[0], :] = noisy_a[ind_explore[0], :]
                return a

        noisy_action = nest.map_structure(_sample, action, self._ou_process)
        noisy_action = nest.map_structure(spec_utils.clip_to_spec,
                                          noisy_action, self._action_spec)
        state = empty_state._replace(
            actor=DdpgActorState(actor=state, critics=()))
        return AlgStep(output=noisy_action,
                       state=state,
                       info=DdpgInfo(action_distribution=action))
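
The ``_sample`` closure above perturbs a random ``epsilon_greedy`` fraction of the batch with OU noise and leaves the rest greedy. Here is a standalone sketch of the same idea that avoids mutating the network output in place; the function name and the precomputed ``noise`` tensor are our own, not the library's.

import torch


def mix_noisy_actions(action: torch.Tensor, noise: torch.Tensor,
                      epsilon_greedy: float) -> torch.Tensor:
    """Add exploration noise to a random ``epsilon_greedy`` fraction of the
    batch; keep the remaining actions greedy."""
    if epsilon_greedy == 0:
        return action
    if epsilon_greedy >= 1.0:
        return action + noise
    explore = torch.rand(action.shape[0]) < epsilon_greedy
    noisy = action.clone()
    noisy[explore] = action[explore] + noise[explore]
    return noisy


actions = torch.zeros(4, 2)
noise = 0.1 * torch.randn(4, 2)
print(mix_noisy_actions(actions, noise, epsilon_greedy=0.5))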
Example #11
    def train_step(self, exp: TimeStep, state):
        # [B, num_unroll_steps + 1]
        info = exp.rollout_info
        targets = common.as_list(info.target)
        batch_size = exp.step_type.shape[0]
        latent, state = self._encoding_net(exp.observation, state)

        sim_latent = self._multi_step_latent_rollout(latent,
                                                     self._num_unroll_steps,
                                                     info.action, state)

        loss = 0
        for i, decoder in enumerate(self._decoders):
            # [(num_unroll_steps + 1)*B, ...]
            train_info = decoder.train_step(sim_latent).info
            train_info_spec = dist_utils.extract_spec(train_info)
            train_info = dist_utils.distributions_to_params(train_info)
            train_info = alf.nest.map_structure(
                lambda x: x.reshape(self._num_unroll_steps + 1, batch_size, *x.
                                    shape[1:]), train_info)
            # [num_unroll_steps + 1, B, ...]
            train_info = dist_utils.params_to_distributions(
                train_info, train_info_spec)
            target = alf.nest.map_structure(lambda x: x.transpose(0, 1),
                                            targets[i])
            loss_info = decoder.calc_loss(target, train_info, info.mask.t())
            loss_info = alf.nest.map_structure(lambda x: x.mean(dim=0),
                                               loss_info)
            loss += loss_info.loss

        loss_info = LossInfo(loss=loss, extra=loss)

        return AlgStep(output=latent, state=state, info=loss_info)
Example #12
    def rollout_step(self, time_step: TimeStep, state: SarsaState):
        if self._on_policy:
            return self._train_step(time_step, state)

        if not self._is_rnn:
            critic_states = state.critics
        else:
            _, critic_states = self._critic_networks(
                (state.prev_observation, time_step.prev_action), state.critics)

            not_first_step = time_step.step_type != StepType.FIRST

            critic_states = common.reset_state_if_necessary(
                state.critics, critic_states, not_first_step)

        action_distribution, action, actor_state, noise_state = self._get_action(
            self._rollout_actor_network, time_step, state)

        if not self._is_rnn:
            target_critic_states = state.target_critics
        else:
            _, target_critic_states = self._target_critic_networks(
                (time_step.observation, action), state.target_critics)

        info = SarsaInfo(action_distribution=action_distribution)

        rl_state = SarsaState(noise=noise_state,
                              prev_observation=time_step.observation,
                              prev_step_type=time_step.step_type,
                              actor=actor_state,
                              critics=critic_states,
                              target_critics=target_critic_states)

        return AlgStep(action, rl_state, info)
Example #13
    def predict_step(self, time_step: TimeStep, state: AgentState,
                     epsilon_greedy):
        """Predict for one step."""
        new_state = AgentState()
        observation = time_step.observation
        info = AgentInfo()

        if self._representation_learner is not None:
            repr_step = self._representation_learner.predict_step(
                time_step, state.repr)
            new_state = new_state._replace(repr=repr_step.state)
            info = info._replace(repr=repr_step.info)
            observation = repr_step.output

        if self._goal_generator is not None:
            goal_step = self._goal_generator.predict_step(
                time_step._replace(observation=observation),
                state.goal_generator, epsilon_greedy)
            new_state = new_state._replace(goal_generator=goal_step.state)
            info = info._replace(goal_generator=goal_step.info)
            observation = [observation, goal_step.output]

        rl_step = self._rl_algorithm.predict_step(
            time_step._replace(observation=observation), state.rl,
            epsilon_greedy)
        new_state = new_state._replace(rl=rl_step.state)
        info = info._replace(rl=rl_step.info)

        return AlgStep(output=rl_step.output, state=new_state, info=info)
Example #14
 def testCreateWithDefaultInfo(self):
     action = torch.tensor(1)
     state = torch.tensor(2)
     info = ()
     step = AlgStep(output=action, state=state)
     self.assertEqual(step.output, action)
     self.assertEqual(step.state, state)
     self.assertEqual(step.info, info)
Example #15
 def testCreate(self):
     action = torch.tensor(1)
     state = torch.tensor(2)
     info = torch.tensor(3)
     step = AlgStep(output=action, state=state, info=info)
     self.assertEqual(step.output, action)
     self.assertEqual(step.state, state)
     self.assertEqual(step.info, info)
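
The two tests above pin down the interface every other example builds on: ``AlgStep`` behaves like a namedtuple with fields ``output``, ``state``, and ``info`` that default to the empty tuple. A minimal stand-in that would satisfy both tests (the real class is presumably defined with the library's own namedtuple utilities, so this is only an equivalent sketch):

import collections

# Illustrative stand-in only: a namedtuple whose three fields all default
# to the empty tuple ().
AlgStep = collections.namedtuple(
    "AlgStep", ["output", "state", "info"], defaults=((), (), ()))

step = AlgStep(output=1, state=2)
assert step.info == ()      # matches testCreateWithDefaultInfo
step = AlgStep(output=1, state=2, info=3)
assert step.info == 3       # matches testCreate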
Example #16
 def predict_step(self, time_step, state):
     observation, switch_skill = time_step.observation
     first_observation = self._update_state_if_necessary(
         switch_skill, observation, state.first_observation)
     subtrajectory = self._clear_subtrajectory_if_necessary(
         state.subtrajectory, switch_skill)
     return AlgStep(state=DiscriminatorState(
         first_observation=first_observation, subtrajectory=subtrajectory))
Example #17
    def _step(self, time_step: TimeStep, state, calc_rewards=True):
        """This step is for both `rollout_step` and `train_step`.

        Args:
            time_step (TimeStep): input time_step data for ICM
            state (Tensor): state for ICM (previous observation)
            calc_rewards (bool): whether to calculate rewards

        Returns:
            AlgStep:
                output: empty tuple ()
                state: observation
                info (ICMInfo):
        """
        feature = time_step.observation
        prev_action = time_step.prev_action.detach()

        # normalize observation for easier prediction
        if self._observation_normalizer is not None:
            feature = self._observation_normalizer.normalize(feature)

        if self._encoding_net is not None:
            feature, _ = self._encoding_net(feature)
        prev_feature = state

        forward_pred, _ = self._forward_net(
            inputs=[prev_feature.detach(),
                    self._encode_action(prev_action)])
        # nn.MSELoss doesn't support reducing along a dim
        forward_loss = 0.5 * torch.mean(
            math_ops.square(forward_pred - feature.detach()), dim=-1)

        action_pred, _ = self._inverse_net([prev_feature, feature])

        if self._action_spec.is_discrete:
            inverse_loss = torch.nn.CrossEntropyLoss(reduction='none')(
                input=action_pred, target=prev_action.to(torch.int64))
        else:
            # nn.MSELoss doesn't support reducing along a dim
            inverse_loss = 0.5 * torch.mean(
                math_ops.square(action_pred - prev_action), dim=-1)

        intrinsic_reward = ()
        if calc_rewards:
            intrinsic_reward = forward_loss.detach()
            intrinsic_reward = self._reward_normalizer.normalize(
                intrinsic_reward)

        return AlgStep(
            output=(),
            state=feature,
            info=ICMInfo(
                reward=intrinsic_reward,
                loss=LossInfo(
                    loss=forward_loss + inverse_loss,
                    extra=dict(
                        forward_loss=forward_loss,
                        inverse_loss=inverse_loss))))
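
ICM above feeds ``self._encode_action(prev_action)`` into its forward model. For a discrete action space, such an encoder typically amounts to one-hot encoding; the helper below is a hypothetical sketch of that step, not the library's implementation.

import torch
import torch.nn.functional as F


def encode_discrete_action(prev_action: torch.Tensor,
                           num_actions: int) -> torch.Tensor:
    # Turn integer actions [B] into float one-hot vectors [B, num_actions]
    # so they can be concatenated with observation features.
    return F.one_hot(prev_action.to(torch.int64), num_actions).to(torch.float32)


print(encode_discrete_action(torch.tensor([0, 2, 1]), num_actions=3))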
Example #18
 def predict_step(self, time_step: TimeStep, state, epsilon_greedy):
     action_distribution, action, actor_state, noise_state = self._get_action(
         self._rollout_actor_network, time_step, state, epsilon_greedy)
     return AlgStep(output=action,
                    state=SarsaState(noise=noise_state,
                                     actor=actor_state,
                                     prev_observation=time_step.observation,
                                     prev_step_type=time_step.step_type),
                    info=SarsaInfo(action_distribution=action_distribution))
Example #19
 def train_step(self, time_step: TimeStep, state=()):
     """
     Args:
         time_step (TimeStep): input data for reward learning
         state: state for reward learning
     Returns:
         AlgStep
     """
     return AlgStep(output=(), state=state, info=())
Example #20
    def train_step(self, time_step: TimeStep, state: DynamicsState):
        """
        Args:
            time_step (TimeStep): time step structure. The ``prev_action`` from
                time_step will be used for predicting feature of the next step.
                It should be a Tensor of the shape [B, ...] or [B, n, ...] when
                n > 1, where n denotes the number of dynamics network replicas.
                When the input tensor has the shape of [B, ...] and n > 1, it
                will be first expanded to [B, n, ...] to match the number of
                dynamics network replicas.
            state (DynamicsState): state for dynamics learning with the
                following fields:
                - feature (Tensor): features of the previous observation of the
                    shape [B, ...] or [B, n, ...] when n > 1. When
                    ``state.feature`` has the shape of [B, ...] and n > 1, it
                    will be first expanded to [B, n, ...] to match the number
                    of dynamics network replicas.
                    It is used for predicting the feature of the next step
                    together with ``time_step.prev_action``.
                - network: the input state of the dynamics network
        Returns:
            AlgStep:
                output: empty tuple ()
                state (DynamicsState): with the following fields
                    - feature (Tensor): [B, ...] (or [B, n, ...] when n > 1)
                        shape tensor representing the predicted feature of
                        the next step
                    - network: the updated state of the dynamics network
                info (DynamicsInfo): with the following fields being updated:
                    - loss (LossInfo):
                    - dist (td.Distribution): the predictive distribution which
                        can be used for further calculation or summarization.
        """

        feature = time_step.observation
        feature = self._expand_to_replica(feature, self._feature_spec)
        dynamics_step = self.predict_step(time_step, state)

        dist = dynamics_step.info.dist
        forward_loss = -dist.log_prob(feature - state.feature)

        if forward_loss.ndim > 2:
            # [B, n, ...] -> [B, ...]
            forward_loss = forward_loss.sum(1)
        if forward_loss.ndim > 1:
            forward_loss = forward_loss.mean(list(range(1, forward_loss.ndim)))

        valid_masks = (time_step.step_type != StepType.FIRST).to(torch.float32)

        forward_loss = forward_loss * valid_masks

        info = DynamicsInfo(loss=LossInfo(
            loss=forward_loss, extra=dict(forward_loss=forward_loss)),
                            dist=dist)
        state = state._replace(feature=feature)

        return AlgStep(output=(), state=state, info=info)
Example #21
    def predict_step(self, time_step: TimeStep, state: ActorCriticState,
                     epsilon_greedy):
        """Predict for one step."""
        action_dist, actor_state = self._actor_network(time_step.observation,
                                                       state=state.actor)

        action = dist_utils.epsilon_greedy_sample(action_dist, epsilon_greedy)
        return AlgStep(output=action,
                       state=ActorCriticState(actor=actor_state),
                       info=ActorCriticInfo(action_distribution=action_dist))
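
Example #21 picks its action with ``dist_utils.epsilon_greedy_sample``: sample from the action distribution with probability ``epsilon_greedy``, otherwise act greedily. A self-contained sketch of that behavior for a batched ``Normal``, using the mean as the greedy action; the library's function may treat other distribution types differently.

import torch
import torch.distributions as td


def epsilon_greedy_sample(dist: td.Distribution, eps: float) -> torch.Tensor:
    sample = dist.sample()
    greedy = dist.mean  # greedy action for a Normal; other dists differ
    explore = torch.rand(sample.shape[0]) < eps
    mask = explore.reshape(-1, *([1] * (sample.ndim - 1)))
    return torch.where(mask, sample, greedy)


dist = td.Normal(torch.zeros(3, 2), torch.ones(3, 2))
print(epsilon_greedy_sample(dist, eps=0.1))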
Example #22
File: prior_actor.py  Project: zhuboli/alf
 def predict_step(self, time_step: TimeStep, state):
     flat_prev_action = alf.nest.flatten(time_step.prev_action)
     dists = [
         self._make_dist(time_step.step_type, prev_action,
                         spec) for prev_action, spec in zip(
                             flat_prev_action, self._prepared_specs)
     ]
     return AlgStep(
         output=alf.nest.pack_sequence_as(self._action_spec, dists),
         state=(),
         info=())
Example #23
    def _predict_with_planning(self, time_step: TimeStep, state,
                               epsilon_greedy):
        # full state in
        action = self._planner_module.generate_plan(time_step, state,
                                                    epsilon_greedy)
        dynamics_state = self._dynamics_module.update_state(
            time_step, state.dynamics)

        return AlgStep(output=action,
                       state=state._replace(dynamics=dynamics_state),
                       info=MbrlInfo())
Example #24
 def rollout_step(self, time_step, state):
     """This function updates the discriminator state."""
     (observation, _, switch_skill, _) = time_step.observation
     first_observation = self._update_state_if_necessary(
         switch_skill, observation, state.first_observation)
     subtrajectory = self._clear_subtrajectory_if_necessary(
         state.subtrajectory, switch_skill)
     return AlgStep(state=DiscriminatorState(
         first_observation=first_observation,
         untrans_observation=time_step.untransformed.observation,
         subtrajectory=subtrajectory))
Example #25
 def train_step(self, exp: Experience, state: MbrlState):
     action = exp.action
     dynamics_step = self._dynamics_module.train_step(exp, state.dynamics)
     reward_step = self._reward_module.train_step(exp, state.reward)
     plan_step = self._planner_module.train_step(exp, state.planner)
     state = MbrlState(dynamics=dynamics_step.state,
                       reward=reward_step.state,
                       planner=plan_step.state)
     info = MbrlInfo(dynamics=dynamics_step.info,
                     reward=reward_step.info,
                     planner=plan_step.info)
     return AlgStep(action, state, info)
Example #26
 def predict_step(self,
                  time_step: TimeStep,
                  state: SacState,
                  epsilon_greedy=1.0):
     action_dist, action, _, action_state = self._predict_action(
         time_step.observation,
         state=state.action,
         epsilon_greedy=epsilon_greedy,
         eps_greedy_sampling=True)
     return AlgStep(output=action,
                    state=SacState(action=action_state),
                    info=SacInfo(action_distribution=action_dist))
Example #27
 def train_step(self, time_step: TimeStep, state):
     """
     Args:
         time_step (TimeStep): input data for planning
         state: state for planning (previous observation)
     Returns:
         AlgStep:
             output: empty tuple ()
             state: state for planning
             info: empty tuple ()
     """
     return AlgStep(output=(), state=state, info=())
Example #28
 def predict_step(self, time_step: TimeStep, state, epsilon_greedy):
     mbp_step = self._mbp.predict_step(inputs=(time_step.observation,
                                               time_step.prev_action),
                                       state=state.mbp_state)
     mba_step = self._mba.predict_step(
         time_step=time_step._replace(observation=mbp_step.output),
         state=state.mba_state,
         epsilon_greedy=epsilon_greedy)
     return AlgStep(output=mba_step.output,
                    state=MerlinState(mbp_state=mbp_step.state,
                                      mba_state=mba_step.state),
                    info=())
Example #29
    def predict_step(self, time_step: TimeStep, state: DynamicsState):
        """Predict the next observation given the current time_step.
                The next step is predicted using the ``prev_action`` from
                time_step and the ``feature`` from state.
        Args:
            time_step (TimeStep): time step structure. The ``prev_action`` from
                time_step will be used for predicting feature of the next step.
                It should be a Tensor of the shape [B, ...], or [B, n, ...] when
                n > 1, where n denotes the number of dynamics network replicas.
                When the input tensor has the shape of [B, ...] and n > 1,
                it will be first expanded to [B, n, ...] to match the number of
                dynamics network replicas.
            state (DynamicsState): state for dynamics learning with the
                following fields:
                - feature (Tensor): features of the previous observation of the
                    shape [B, ...], or [B, n, ...] when n > 1. When
                    ``state.feature`` has the shape of [B, ...] and n > 1,
                    it will be first expanded to [B, n, ...] to match the
                    number of dynamics network replicas.
                    It is used for predicting the feature of the next step
                    together with ``time_step.prev_action``.
                - network: the input state of the dynamics network
        Returns:
            AlgStep:
                output (Tensor): predicted feature of the next step, of the
                    shape [B, ...], or [B, n, ...] when n > 1.
                state (DynamicsState): with the following fields
                    - feature (Tensor): [B, ...] (or [B, n, ...] when n > 1)
                        shape tensor representing the predicted feature of
                        the next step
                    - network: the updated state of the dynamics network
                info (DynamicsInfo): with the following fields being updated:
                    - dist (td.Distribution): the predictive distribution which
                        can be used for further calculation or summarization.
        """
        action = self._encode_action(time_step.prev_action)
        obs = state.feature

        # perform preprocessing
        observations = self._expand_to_replica(obs, self._feature_spec)
        actions = self._expand_to_replica(action, self._action_spec)

        dist, network_states = self._dynamics_network((observations, actions),
                                                      state=state.network)

        forward_deltas = dist.sample()

        forward_preds = observations + forward_deltas
        state = state._replace(feature=forward_preds, network=network_states)
        return AlgStep(output=forward_preds,
                       state=state,
                       info=DynamicsInfo(dist=dist))
Example #30
    def _step(self, time_step: TimeStep, state, calc_rewards=True):
        """
        Args:
            time_step (TimeStep): input time step data, where the
                observation is the skill-augmented observation. The skill should
                be a one-hot vector.
            state (Tensor): state for DIAYN (previous skill) which should be
                a one-hot vector.
            calc_rewards (bool): if False, only return the losses.

        Returns:
            AlgStep:
                output: empty tuple ()
                state: skill
                info (DIAYNInfo):
        """
        observations_aug = time_step.observation
        step_type = time_step.step_type
        observation, skill = observations_aug
        prev_skill = state.detach()

        # normalize observation for easier prediction
        if self._observation_normalizer is not None:
            observation = self._observation_normalizer.normalize(observation)

        feature = observation
        if self._encoding_net is not None:
            feature, _ = self._encoding_net(observation)

        skill_pred, _ = self._discriminator_net(feature)

        if self._skill_spec.is_discrete:
            loss = torch.nn.CrossEntropyLoss(reduction='none')(
                input=skill_pred, target=torch.argmax(prev_skill, dim=-1))
        else:
            # nn.MSELoss doesn't support reducing along a dim
            loss = torch.sum(math_ops.square(skill_pred - prev_skill), dim=-1)

        valid_masks = (step_type != to_tensor(StepType.FIRST)).to(
            torch.float32)
        loss *= valid_masks

        intrinsic_reward = ()
        if calc_rewards:
            intrinsic_reward = -loss.detach()
            intrinsic_reward = self._reward_normalizer.normalize(
                intrinsic_reward)

        return AlgStep(
            output=(),
            state=skill,
            info=DIAYNInfo(reward=intrinsic_reward, loss=loss))