Example #1
def __init__(self, obs_space, action_space, config):
    # Merge the user-supplied config over the A3C defaults.
    config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config)
    self.config = config
    # ModelCatalog.get_action_dist returns (dist_class, num_logits);
    # only the logit dimension is needed here.
    _, self.logit_dim = ModelCatalog.get_action_dist(
        action_space, self.config["model"])
    self.model = ModelCatalog.get_torch_model(obs_space, self.logit_dim,
                                              self.config["model"])
    loss = PGLoss(self.model)

    # The loss is fed the named sample-batch columns, in this order.
    TorchPolicyGraph.__init__(self,
                              obs_space,
                              action_space,
                              self.model,
                              loss,
                              loss_inputs=["obs", "actions", "advantages"])
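The snippet references PGLoss without defining it. Below is a hypothetical minimal policy-gradient loss module consistent with the loss_inputs above (obs, actions, advantages, in that order). It is an illustrative sketch assuming a discrete action space and a model that returns logits, not RLlib's actual PGLoss:

import torch.nn as nn
from torch.distributions import Categorical

class PGLoss(nn.Module):
    # Hypothetical REINFORCE-style loss; the real class lives in RLlib.
    def __init__(self, policy_model):
        nn.Module.__init__(self)
        self.policy_model = policy_model

    def forward(self, observations, actions, advantages):
        logits = self.policy_model(observations)  # assumed to return logits
        log_probs = Categorical(logits=logits).log_prob(actions)
        # Maximize the advantage-weighted log-likelihood of taken actions.
        return -(log_probs * advantages).mean()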
Example #2
def __init__(self, obs_space, action_space, config):
    config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config)
    self.config = config
    _, self.logit_dim = ModelCatalog.get_action_dist(
        action_space, self.config["model"])
    # Note: this variant passes obs_space.shape rather than the space itself.
    self.model = ModelCatalog.get_torch_model(
        obs_space.shape, self.logit_dim, self.config["model"])
    # A3C additionally learns a value function and uses an entropy bonus,
    # so the loss also consumes the "value_targets" column.
    loss = A3CLoss(self.model, self.config["vf_loss_coeff"],
                   self.config["entropy_coeff"])
    TorchPolicyGraph.__init__(
        self,
        obs_space,
        action_space,
        self.model,
        loss,
        loss_inputs=["obs", "actions", "advantages", "value_targets"])
Example #3
def __init__(self, obs_space, action_space, config):
    config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config)
    self.config = config
    _, self.logit_dim = ModelCatalog.get_action_dist(
        action_space, self.config["model"])
    # Identical to Example #2 except that the full obs_space object is
    # passed to get_torch_model instead of obs_space.shape (likely two
    # snapshots of the same policy from different RLlib versions).
    self.model = ModelCatalog.get_torch_model(obs_space, self.logit_dim,
                                              self.config["model"])
    loss = A3CLoss(self.model, self.config["vf_loss_coeff"],
                   self.config["entropy_coeff"])
    TorchPolicyGraph.__init__(
        self,
        obs_space,
        action_space,
        self.model,
        loss,
        loss_inputs=["obs", "actions", "advantages", "value_targets"])
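A3CLoss is likewise referenced but not shown. The sketch below is a hypothetical module matching the constructor arguments (model, vf_loss_coeff, entropy_coeff) and the four loss_inputs columns. It assumes a discrete action space and a model returning (logits, value estimates), and is not RLlib's actual implementation:

import torch.nn as nn
from torch.distributions import Categorical

class A3CLoss(nn.Module):
    # Hypothetical sketch; the real class lives in RLlib.
    def __init__(self, policy_model, vf_loss_coeff, entropy_coeff):
        nn.Module.__init__(self)
        self.policy_model = policy_model
        self.vf_loss_coeff = vf_loss_coeff
        self.entropy_coeff = entropy_coeff

    def forward(self, observations, actions, advantages, value_targets):
        # Assumes the model returns (action logits, state-value estimates).
        logits, values = self.policy_model(observations)
        dist = Categorical(logits=logits)
        pi_loss = -(dist.log_prob(actions) * advantages).mean()
        vf_loss = 0.5 * (values.squeeze(-1) - value_targets).pow(2).mean()
        # Entropy is subtracted: higher entropy lowers the loss,
        # encouraging exploration.
        return (pi_loss + self.vf_loss_coeff * vf_loss
                - self.entropy_coeff * dist.entropy().mean())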
Example #4
def __init__(self, obs_space, action_space, config):
    config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config)
    self.config = config
    _, self.logit_dim = ModelCatalog.get_action_dist(
        action_space, self.config["model"])
    self.model = ModelCatalog.get_torch_model(obs_space, self.logit_dim,
                                              self.config["model"])
    loss = PGLoss(self.model)

    # Same as Example #1, but the batch columns are named via shared
    # constants instead of string literals.
    TorchPolicyGraph.__init__(
        self,
        obs_space,
        action_space,
        self.model,
        loss,
        loss_inputs=[
            SampleBatch.CUR_OBS, SampleBatch.ACTIONS,
            Postprocessing.ADVANTAGES
        ])
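In the RLlib versions these snippets appear to come from, those constants resolve to the same strings used literally in Example #1, so the two constructors are functionally equivalent. The import paths below have moved between releases and are a best guess:

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.evaluation.postprocessing import Postprocessing

assert SampleBatch.CUR_OBS == "obs"
assert SampleBatch.ACTIONS == "actions"
assert Postprocessing.ADVANTAGES == "advantages"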
Example #5
def __init__(self, obs_space, action_space, config):
    # get_default_config, before_init, make_model_and_action_dist,
    # loss_fn, and after_init are closure variables captured from the
    # enclosing policy-template builder, not globals.
    if get_default_config:
        config = dict(get_default_config(), **config)
    self.config = config

    if before_init:
        before_init(self, obs_space, action_space, config)

    # Either the user supplies the model and action distribution, or
    # both are derived from the model config via the catalog.
    if make_model_and_action_dist:
        self.model, self.dist_class = make_model_and_action_dist(
            self, obs_space, action_space, config)
    else:
        self.dist_class, logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"], torch=True)
        self.model = ModelCatalog.get_torch_model(
            obs_space, logit_dim, self.config["model"])

    TorchPolicyGraph.__init__(self, obs_space, action_space,
                              self.model, loss_fn, self.dist_class)

    if after_init:
        after_init(self, obs_space, action_space, config)
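Examples #6 through #9 are hook methods generated by the same template: each checks whether the corresponding user hook was provided and otherwise falls back to the TorchPolicyGraph base implementation. A sketch of how all of these pieces assemble into one builder follows Example #9.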
Example #6
def extra_grad_info(self, batch_tensors):
    # Report user-defined training stats if a stats_fn hook was given.
    if stats_fn:
        return stats_fn(self, batch_tensors)
    else:
        return TorchPolicyGraph.extra_grad_info(self, batch_tensors)
Example #7
def optimizer(self):
    if optimizer_fn:
        return optimizer_fn(self, self.config)
    else:
        return TorchPolicyGraph.optimizer(self)
Example #8
def extra_action_out(self, model_out):
    if extra_action_out_fn:
        return extra_action_out_fn(self, model_out)
    else:
        return TorchPolicyGraph.extra_action_out(self, model_out)
Example #9
def extra_grad_process(self):
    if extra_grad_process_fn:
        return extra_grad_process_fn(self)
    else:
        return TorchPolicyGraph.extra_grad_process(self)
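The following self-contained sketch assembles Examples #5 through #9 into the kind of template builder they were evidently extracted from. The function name build_torch_policy_graph and the import paths are assumptions modeled on RLlib's policy-template style, not the library's verbatim code:

# Import paths follow old RLlib layouts and may differ by version.
from ray.rllib.models import ModelCatalog
from ray.rllib.evaluation.torch_policy_graph import TorchPolicyGraph

def build_torch_policy_graph(name,
                             loss_fn,
                             get_default_config=None,
                             stats_fn=None,
                             optimizer_fn=None,
                             extra_action_out_fn=None,
                             extra_grad_process_fn=None,
                             before_init=None,
                             after_init=None,
                             make_model_and_action_dist=None):
    # Each hook is captured as a closure variable by the generated class,
    # which is why the methods in Examples #5-#9 can test names like
    # stats_fn that are neither arguments nor attributes.

    class policy_cls(TorchPolicyGraph):
        def __init__(self, obs_space, action_space, config):
            if get_default_config:
                config = dict(get_default_config(), **config)
            self.config = config
            if before_init:
                before_init(self, obs_space, action_space, config)
            if make_model_and_action_dist:
                self.model, self.dist_class = make_model_and_action_dist(
                    self, obs_space, action_space, config)
            else:
                self.dist_class, logit_dim = ModelCatalog.get_action_dist(
                    action_space, self.config["model"], torch=True)
                self.model = ModelCatalog.get_torch_model(
                    obs_space, logit_dim, self.config["model"])
            TorchPolicyGraph.__init__(self, obs_space, action_space,
                                      self.model, loss_fn, self.dist_class)
            if after_init:
                after_init(self, obs_space, action_space, config)

        def extra_grad_info(self, batch_tensors):
            if stats_fn:
                return stats_fn(self, batch_tensors)
            return TorchPolicyGraph.extra_grad_info(self, batch_tensors)

        def optimizer(self):
            if optimizer_fn:
                return optimizer_fn(self, self.config)
            return TorchPolicyGraph.optimizer(self)

        def extra_action_out(self, model_out):
            if extra_action_out_fn:
                return extra_action_out_fn(self, model_out)
            return TorchPolicyGraph.extra_action_out(self, model_out)

        def extra_grad_process(self):
            if extra_grad_process_fn:
                return extra_grad_process_fn(self)
            return TorchPolicyGraph.extra_grad_process(self)

    policy_cls.__name__ = name
    policy_cls.__qualname__ = name
    return policy_cls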