示例#1
0
    def make_optimizers(self):
        """Build the actor and critic optimizers.

        The ``self.config["torch_optimizer"]`` settings are combined with
        ``DEFAULT_OPTIM_CONFIG`` through ``dutil.deep_merge`` before use.

        Returns:
            A dict with the keys ``"actor"`` and ``"critic"``, each holding
            the result of ``build_optimizer`` for the matching submodule of
            ``self.module``.

        Raises:
            AssertionError: If the merged actor optimizer ``"type"`` is
                neither ``"KFAC"`` nor ``"EKFAC"``.
        """
        # NOTE(review): the trailing positional arguments are forwarded
        # unchanged; presumably they control new-key handling and which
        # sub-configs are replaced wholesale -- confirm against
        # dutil.deep_merge.
        config = dutil.deep_merge(
            DEFAULT_OPTIM_CONFIG,
            self.config["torch_optimizer"],
            False,
            [],
            ["actor", "critic"],
        )
        actor_config = config["actor"]

        # Raise explicitly instead of using ``assert`` so the check is not
        # stripped when Python runs with ``-O``. The exception type and
        # message are kept identical for existing callers.
        if actor_config["type"] not in ("KFAC", "EKFAC"):
            raise AssertionError(
                "ACKTR must use optimizer with Kronecker Factored "
                "curvature estimation."
            )

        return {
            "actor": build_optimizer(self.module.actor, actor_config),
            "critic": build_optimizer(self.module.critic, config["critic"]),
        }
示例#2
0
    def make_optimizers(self):
        """Build one optimizer per named component of ``self.module``.

        Returns:
            A dict keyed by ``"models"``, ``"actor"``, ``"critics"`` and
            ``"alpha"``. Each value is the result of ``build_optimizer``
            called with ``self.module[name]`` and the entry of the same name
            in ``self.config["torch_optimizer"]``.
        """
        optim_config = self.config["torch_optimizer"]

        optimizers = {}
        for key in ("models", "actor", "critics", "alpha"):
            optimizers[key] = build_optimizer(self.module[key], optim_config[key])
        return optimizers
示例#3
0
    def make_optimizers(self):
        """Build one optimizer per trainable component of ``self.module``.

        The ``"model"`` component is left out when
        ``self.config["true_model"]`` is truthy.

        Returns:
            A dict keyed by component name (``"model"`` when included, then
            ``"actor"`` and ``"critics"``). Each value is the result of
            ``build_optimizer`` called with ``self.module[name]`` and the
            entry of the same name in ``self.config["torch_optimizer"]``.
        """
        optim_config = self.config["torch_optimizer"]

        names = ["actor", "critics"]
        if not self.config["true_model"]:
            # Only optimize the model when no true model is configured.
            names.insert(0, "model")

        optimizers = {}
        for key in names:
            optimizers[key] = build_optimizer(self.module[key], optim_config[key])
        return optimizers
示例#4
0
    def make_optimizers(self):
        """Build optimizers for the model, actor, critic and alpha submodules.

        Returns:
            A dict keyed by ``"model"``, ``"actor"``, ``"critic"`` and
            ``"alpha"``. Each value is the result of ``build_optimizer``
            called with the matching attribute of ``self.module`` and the
            entry of the same name in ``self.config["torch_optimizer"]``.
        """
        optim_config = self.config["torch_optimizer"]

        # Resolve every submodule up front, before any optimizer is built.
        named_modules = (
            ("model", self.module.model),
            ("actor", self.module.actor),
            ("critic", self.module.critic),
            ("alpha", self.module.alpha),
        )

        optimizers = {}
        for key, submodule in named_modules:
            optimizers[key] = build_optimizer(submodule, optim_config[key])
        return optimizers
示例#5
0
    def make_optimizers(self):
        """Build the PyTorch optimizers to use.

        Returns:
            A dict with two entries: ``"on_policy"``, built for
            ``self.module.actor``, and ``"off_policy"``, built for an
            ``nn.ModuleList`` holding ``self.module.model`` and
            ``self.module.critic``. Each optimizer is created by
            ``build_optimizer`` using the entry of the same name in
            ``self.config["torch_optimizer"]``.
        """
        optim_config = self.config["torch_optimizer"]

        # Resolve both parameter groups before building any optimizer.
        on_policy_module = self.module.actor
        off_policy_module = nn.ModuleList([self.module.model, self.module.critic])
        groups = (
            ("on_policy", on_policy_module),
            ("off_policy", off_policy_module),
        )

        optimizers = {}
        for key, submodule in groups:
            optimizers[key] = build_optimizer(submodule, optim_config[key])
        return optimizers
示例#6
0
 def make_optimizers(self):
     """Build the single optimizer for ``self.module.models``.

     Returns:
         A dict with the key ``"models"`` holding the result of
         ``build_optimizer`` called with the config ``{"type": "Adam"}``.
     """
     models_optimizer = build_optimizer(self.module.models, {"type": "Adam"})
     return {"models": models_optimizer}
示例#7
0
 def make_optimizers(self):
     """Build the single optimizer for ``self.module.critic``.

     Returns:
         A dict with the key ``"critic"`` holding the result of
         ``build_optimizer`` called with ``self.config["critic_optimizer"]``.
     """
     critic = self.module.critic
     optim_config = self.config["critic_optimizer"]
     return {"critic": build_optimizer(critic, optim_config)}
示例#8
0
 def make_optimizers(self):
     """Build the single optimizer for ``self.module.critics``.

     Returns:
         A dict with the key ``"naf"`` holding the result of
         ``build_optimizer`` called with ``self.config["torch_optimizer"]``.
     """
     critics = self.module.critics
     optim_config = self.config["torch_optimizer"]
     return {"naf": build_optimizer(critics, optim_config)}