Example #1
0
    def build_trainer(self, use_gpu: bool) -> SACTrainer:
        """Assemble a SACTrainer from this config's network builders.

        Builds the actor network, the Q1 critic (plus an optional Q2 critic
        when ``self.use_2_q_functions`` is set), and an optional state-value
        network, then forwards the flattened trainer parameters to the
        trainer constructor.

        NOTE(review): ``use_gpu`` is accepted but never read in this body —
        confirm whether device placement is handled elsewhere.
        """
        # pyre-fixme[16]: `SAC` has no attribute `_actor_network`.
        self._actor_network = self.actor_net_builder.value.build_actor(
            self.state_normalization_data, self.action_normalization_data)

        critic_builder = self.critic_net_builder.value
        # pyre-fixme[16]: `SAC` has no attribute `_q1_network`.
        self._q1_network = critic_builder.build_q_network(
            self.state_normalization_data, self.action_normalization_data)
        q2_network = None
        if self.use_2_q_functions:
            q2_network = critic_builder.build_q_network(
                self.state_normalization_data, self.action_normalization_data)

        value_network = None
        if self.value_net_builder:
            # pyre-fixme[16]: `Optional` has no attribute `value`.
            value_network = self.value_net_builder.value.build_value_network(
                self.state_normalization_data)

        # pyre-fixme[16]: `SACTrainerParameters` has no attribute `asdict`.
        return SACTrainer(
            actor_network=self._actor_network,
            q1_network=self._q1_network,
            value_network=value_network,
            q2_network=q2_network,
            **self.trainer_param.asdict(),
        )
Example #2
0
    def build_trainer(self) -> SACTrainer:
        """Build the SAC networks and wire them into a SACTrainer.

        Constructs actor, Q1 (and optional Q2) critic, and optional value
        networks from this config's normalization data, moves every built
        network onto the GPU when ``self.use_gpu`` is set, and passes the
        flattened trainer parameters through to the constructor.
        """
        # pyre-fixme[16]: `SAC` has no attribute `_actor_network`.
        self._actor_network = self.actor_net_builder.value.build_actor(
            self.state_normalization_data, self.action_normalization_data
        )

        critic_builder = self.critic_net_builder.value
        # pyre-fixme[16]: `SAC` has no attribute `_q1_network`.
        self._q1_network = critic_builder.build_q_network(
            self.state_normalization_data, self.action_normalization_data
        )
        q2_network = None
        if self.use_2_q_functions:
            q2_network = critic_builder.build_q_network(
                self.state_normalization_data, self.action_normalization_data
            )

        value_network = None
        if self.value_net_builder:
            # pyre-fixme[16]: `Optional` has no attribute `value`.
            value_network = self.value_net_builder.value.build_value_network(
                self.state_normalization_data
            )

        if self.use_gpu:
            # Same placement order as before: q1, q2, value, then actor.
            for network in (
                self._q1_network,
                q2_network,
                value_network,
                self._actor_network,
            ):
                if network is not None:
                    network.cuda()

        # pyre-fixme[29]: `Type[reagent.training.sac_trainer.SACTrainer]` is not a
        #  function.
        # pyre-fixme[16]: `SACTrainerParameters` has no attribute `asdict`.
        return SACTrainer(
            actor_network=self._actor_network,
            q1_network=self._q1_network,
            value_network=value_network,
            q2_network=q2_network,
            use_gpu=self.use_gpu,
            **self.trainer_param.asdict(),
        )
Example #3
0
File: sac.py — Project: lwzbuaa/ReAgent
    def build_trainer(self) -> SACTrainer:
        """Construct the SAC networks and return a trainer over them.

        This variant passes the trainer parameter struct positionally,
        matching an older ``SACTrainer`` constructor signature, instead of
        flattening it via ``asdict``.
        """
        # pyre-fixme[16]: `SAC` has no attribute `_actor_network`.
        self._actor_network = self.actor_net_builder.value.build_actor(
            self.state_normalization_data, self.action_normalization_data
        )

        critic_builder = self.critic_net_builder.value
        # pyre-fixme[16]: `SAC` has no attribute `_q1_network`.
        self._q1_network = critic_builder.build_q_network(
            self.state_normalization_data, self.action_normalization_data
        )
        q2_network = (
            critic_builder.build_q_network(
                self.state_normalization_data, self.action_normalization_data
            )
            if self.use_2_q_functions
            else None
        )

        value_network = None
        if self.value_net_builder:
            # pyre-fixme[16]: `Optional` has no attribute `value`.
            value_network = self.value_net_builder.value.build_value_network(
                self.state_normalization_data
            )

        if self.use_gpu:
            # GPU placement in the original order: q1, q2, value, actor.
            self._q1_network.cuda()
            if q2_network is not None:
                q2_network.cuda()
            if value_network is not None:
                value_network.cuda()
            self._actor_network.cuda()

        return SACTrainer(
            self._q1_network,
            self._actor_network,
            self.trainer_param,
            value_network=value_network,
            q2_network=q2_network,
            use_gpu=self.use_gpu,
        )
Example #4
0
    def build_trainer(
        self,
        normalization_data_map: Dict[str, NormalizationData],
        use_gpu: bool,
        reward_options: Optional[RewardOptions] = None,
    ) -> SACTrainer:
        """Build a SACTrainer from externally supplied normalization data.

        Unlike the other variants, this one does not cache networks on
        ``self``; all networks are locals handed straight to the trainer.

        NOTE(review): ``use_gpu`` and ``reward_options`` are accepted but
        not read in this body — confirm they are consumed by the caller or
        a base class.
        """
        # Hoist the two lookups used by every builder below.
        state_norm = normalization_data_map[NormalizationKey.STATE]
        action_norm = normalization_data_map[NormalizationKey.ACTION]

        actor_network = self.actor_net_builder.value.build_actor(
            self.state_feature_config, state_norm, action_norm
        )

        critic_builder = self.critic_net_builder.value
        q1_network = critic_builder.build_q_network(state_norm, action_norm)
        q2_network = None
        if self.use_2_q_functions:
            q2_network = critic_builder.build_q_network(state_norm, action_norm)

        value_network = None
        if self.value_net_builder:
            value_network = self.value_net_builder.value.build_value_network(
                state_norm
            )

        # pyre-fixme[16]: `SACTrainerParameters` has no attribute `asdict`.
        return SACTrainer(
            actor_network=actor_network,
            q1_network=q1_network,
            value_network=value_network,
            q2_network=q2_network,
            **self.trainer_param.asdict(),
        )
Example #5
0
    def build_trainer(self) -> SACTrainer:
        """Build the SAC networks and return a trainer (legacy signature).

        NOTE(review): the actor and value networks are built from
        ``get_normalization_data(NormalizationKey.*)`` while the critics use
        ``self.state/action_normalization_parameters`` — presumably these
        resolve to the same data; confirm against the enclosing class.
        """
        # pyre-fixme[16]: `SoftActorCritic` has no attribute `_actor_network`.
        self._actor_network = self.actor_net_builder.value.build_actor(
            self.get_normalization_data(NormalizationKey.STATE),
            self.get_normalization_data(NormalizationKey.ACTION),
        )

        critic_builder = self.critic_net_builder.value
        q1_network = critic_builder.build_q_network(
            self.state_normalization_parameters,
            self.action_normalization_parameters,
        )
        q2_network = None
        if self.use_2_q_functions:
            q2_network = critic_builder.build_q_network(
                self.state_normalization_parameters,
                self.action_normalization_parameters,
            )

        value_network = None
        if self.value_net_builder:
            # pyre-fixme[16]: `Optional` has no attribute `value`.
            value_network = self.value_net_builder.value.build_value_network(
                self.get_normalization_data(NormalizationKey.STATE)
            )

        if self.use_gpu:
            # GPU placement in the original order: q1, q2, value, actor.
            q1_network.cuda()
            if q2_network is not None:
                q2_network.cuda()
            if value_network is not None:
                value_network.cuda()
            self._actor_network.cuda()

        return SACTrainer(
            q1_network,
            self._actor_network,
            self.trainer_param,
            value_network=value_network,
            q2_network=q2_network,
            use_gpu=self.use_gpu,
        )