Ejemplo n.º 1
0
        def __init__(self, obs_space, action_space, config):
            """Build the model and action dist, then initialize TorchPolicy.

            Args:
                obs_space: Observation space, forwarded to the hooks, the
                    model factory and `TorchPolicy.__init__`.
                action_space: Action space, forwarded the same way.
                config: User config. When `get_default_config` is given,
                    it is merged on top of that factory's result.
            """
            # Keys supplied by the caller override the defaults.
            if get_default_config:
                config = dict(get_default_config(), **config)
            self.config = config

            # Optional hook, invoked before any model is created.
            if before_init:
                before_init(self, obs_space, action_space, config)

            if not make_model_and_action_dist:
                # No custom factory: ask the ModelCatalog for the action
                # distribution first; its second return value is handed
                # to the model factory as the third argument.
                self.dist_class, num_outputs = ModelCatalog.get_action_dist(
                    action_space, self.config["model"], torch=True)
                self.model = ModelCatalog.get_model_v2(
                    obs_space,
                    action_space,
                    num_outputs,
                    self.config["model"],
                    framework="torch")
            else:
                # Custom factory returns a (model, dist class) pair.
                model_and_dist = make_model_and_action_dist(
                    self, obs_space, action_space, config)
                self.model, self.dist_class = model_and_dist

            TorchPolicy.__init__(
                self, obs_space, action_space, config, self.model, loss_fn,
                self.dist_class)

            # Optional hook, invoked after the base class is initialized.
            if after_init:
                after_init(self, obs_space, action_space, config)
Ejemplo n.º 2
0
        def __init__(self, obs_space, action_space, config):
            """Create model and action dist class, then init the base policy.

            Args:
                obs_space: Observation space passed to hooks, factories and
                    `TorchPolicy.__init__`.
                action_space: Action space passed along the same way.
                config: User config; overlaid on `get_default_config()`
                    when that factory is available.
            """
            # Start from the caller's config; overlay it on the defaults
            # only when a default-config factory exists.
            settings = config
            if get_default_config:
                settings = dict(get_default_config(), **config)
            self.config = settings

            if before_init:
                before_init(self, obs_space, action_space, settings)

            if make_model_and_action_dist:
                model, dist_class = make_model_and_action_dist(
                    self, obs_space, action_space, settings)
                self.model = model
                self.dist_class = dist_class
                # The custom factory must hand back a TorchModelV2.
                assert isinstance(self.model, TorchModelV2), (
                    "ERROR: TorchPolicy::make_model_and_action_dist must "
                    "return a TorchModelV2 object!")
            else:
                # Fall back to the catalog's default dist class and model.
                self.dist_class, num_outputs = ModelCatalog.get_action_dist(
                    action_space, self.config["model"], framework="torch")
                self.model = ModelCatalog.get_model_v2(
                    obs_space,
                    action_space,
                    num_outputs,
                    self.config["model"],
                    framework="torch")

            TorchPolicy.__init__(
                self, obs_space, action_space, settings, self.model,
                loss_fn, self.dist_class)

            if after_init:
                after_init(self, obs_space, action_space, settings)
Ejemplo n.º 3
0
    def __init__(self, observation_space, action_space, config):
        """Initialize the policy, its mixins and its KL bookkeeping.

        Args:
            observation_space: Forwarded to `setup_config` and
                `TorchPolicy.__init__`.
            action_space: Forwarded to `setup_config` and
                `TorchPolicy.__init__`.
            config: User config, merged on top of the PPO `DEFAULT_CONFIG`.
        """
        # Caller-provided keys take precedence over the PPO defaults.
        ppo_defaults = ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG
        config = dict(ppo_defaults, **config)
        setup_config(self, observation_space, action_space, config)

        TorchPolicy.__init__(
            self, observation_space, action_space, config,
            max_seq_len=config["model"]["max_seq_len"])

        # Each mixin is initialized explicitly, in this fixed order.
        ValueNetworkMixin.__init__(self, config)
        EntropyCoeffSchedule.__init__(
            self, config["entropy_coeff"], config["entropy_coeff_schedule"])
        LearningRateSchedule.__init__(
            self, config["lr"], config["lr_schedule"])

        # Current KL coefficient (as python float).
        self.kl_coeff = self.config["kl_coeff"]
        # KL target; read once from the config here.
        self.kl_target = self.config["kl_target"]

        # TODO: Don't require users to call this manually.
        self._initialize_loss_from_dummy_batch()
Ejemplo n.º 4
0
        def __init__(self, obs_space, action_space, config):
            """Assemble model and action dist, then initialize TorchPolicy.

            The model comes from exactly one of three sources: `make_model`,
            `make_model_and_action_dist`, or the ModelCatalog defaults.

            Args:
                obs_space: Observation space passed to hooks, factories and
                    `TorchPolicy.__init__`.
                action_space: Action space passed along the same way.
                config: User config; overlaid on `get_default_config()`
                    when that factory is available.
            """
            if get_default_config:
                config = dict(get_default_config(), **config)
            self.config = config

            # Optional pre-construction hooks.
            if validate_spaces:
                validate_spaces(self, obs_space, action_space, self.config)
            if before_init:
                before_init(self, obs_space, action_space, self.config)

            if not make_model and not make_model_and_action_dist:
                # Nothing customized: default dist class and default model.
                dist_class, num_outputs = ModelCatalog.get_action_dist(
                    action_space, self.config["model"], framework="torch")
                self.model = ModelCatalog.get_model_v2(
                    obs_space=obs_space,
                    action_space=action_space,
                    num_outputs=num_outputs,
                    model_config=self.config["model"],
                    framework="torch")
            elif make_model:
                # Custom model, default action dist class. The two factory
                # options are mutually exclusive.
                assert make_model_and_action_dist is None, (
                    "Either `make_model` or `make_model_and_action_dist`"
                    " must be None!")
                self.model = make_model(self, obs_space, action_space, config)
                dist_class, _ = ModelCatalog.get_action_dist(
                    action_space, self.config["model"], framework="torch")
            else:
                # Both model and action dist class are customized.
                self.model, dist_class = make_model_and_action_dist(
                    self, obs_space, action_space, config)

            # Whatever produced the model, it has to be a TorchModelV2.
            assert isinstance(self.model, TorchModelV2), (
                "ERROR: Generated Model must be a TorchModelV2 object!")

            base_kwargs = dict(
                observation_space=obs_space,
                action_space=action_space,
                config=config,
                model=self.model,
                loss=loss_fn,
                action_distribution_class=dist_class,
                action_sampler_fn=action_sampler_fn,
                action_distribution_fn=action_distribution_fn,
                max_seq_len=config["model"]["max_seq_len"],
                get_batch_divisibility_req=get_batch_divisibility_req,
            )
            TorchPolicy.__init__(self, **base_kwargs)

            # Merge in extra training view requirements, if a callable
            # for them was supplied.
            if callable(training_view_requirements_fn):
                self.training_view_requirements.update(
                    training_view_requirements_fn(self))

            if after_init:
                after_init(self, obs_space, action_space, config)
Ejemplo n.º 5
0
 def extra_action_out(self, input_dict, state_batches, model_out):
     """Dispatch to the custom hook or the TorchPolicy default.

     Calls `extra_action_out_fn` when it is set (truthy); otherwise
     `TorchPolicy.extra_action_out`. All arguments are forwarded
     unchanged and the callee's result is returned as-is.
     """
     handler = extra_action_out_fn or TorchPolicy.extra_action_out
     return handler(self, input_dict, state_batches, model_out)
Ejemplo n.º 6
0
 def extra_grad_info(self, train_batch):
     """Collect stats for `train_batch` and convert them via
     `convert_to_non_torch_type`.

     Uses `stats_fn` when it is set (truthy), else the TorchPolicy
     default implementation. Everything runs under `torch.no_grad()`.
     """
     with torch.no_grad():
         compute_stats = stats_fn or TorchPolicy.extra_grad_info
         raw_stats = compute_stats(self, train_batch)
         return convert_to_non_torch_type(raw_stats)
Ejemplo n.º 7
0
 def extra_compute_grad_fetches(self):
     """Return extra fetches from the custom hook or the base class.

     Without `extra_learn_fetches_fn`, defers to
     `TorchPolicy.extra_compute_grad_fetches`. With it, the hook's
     result is passed through `convert_to_non_torch_type` and merged
     over a dict holding an empty `LEARNER_STATS_KEY` entry.
     """
     if not extra_learn_fetches_fn:
         return TorchPolicy.extra_compute_grad_fetches(self)

     custom_fetches = convert_to_non_torch_type(
         extra_learn_fetches_fn(self))
     # The empty learner-stats entry is only a default: a matching key
     # in the custom fetches overrides it.
     return dict({LEARNER_STATS_KEY: {}}, **custom_fetches)
Ejemplo n.º 8
0
 def optimizer(self):
     """Return the policy's optimizers as produced by the hooks.

     The optimizers come from `optimizer_fn(self, self.config)` when
     that hook is set, else from `TorchPolicy.optimizer`. The result is
     passed through `force_list`. If the policy has a truthy
     `exploration` attribute, its `get_exploration_optimizer` receives
     that list and its return value replaces it.
     """
     if optimizer_fn:
         base_optimizers = optimizer_fn(self, self.config)
     else:
         base_optimizers = TorchPolicy.optimizer(self)
     optimizer_list = force_list(base_optimizers)

     exploration = getattr(self, "exploration", None)
     if not exploration:
         return optimizer_list
     return exploration.get_exploration_optimizer(optimizer_list)
Ejemplo n.º 9
0
 def extra_action_out(self, input_dict, state_batches, model,
                      action_dist):
     """Compute extra action outputs and convert them via
     `convert_to_non_torch_type`.

     Uses `extra_action_out_fn` when it is set (truthy), else
     `TorchPolicy.extra_action_out`. Everything runs under
     `torch.no_grad()`.
     """
     with torch.no_grad():
         compute_extras = (
             extra_action_out_fn or TorchPolicy.extra_action_out)
         extras = compute_extras(
             self, input_dict, state_batches, model, action_dist)
         return convert_to_non_torch_type(extras)
Ejemplo n.º 10
0
        def extra_grad_process(self, optimizer, loss):
            """Hook that runs between the backward pass and the optimizer step.

            Invoked once optimizer.zero_grad() and loss.backward() have been
            called, so gradients can be processed (e.g. clipped) before
            optimizer.step() runs.

            Delegates to `extra_grad_process_fn` when it is set, else to
            `TorchPolicy.extra_grad_process`, and returns the callee's
            result unchanged.
            """
            if not extra_grad_process_fn:
                return TorchPolicy.extra_grad_process(self, optimizer, loss)
            return extra_grad_process_fn(self, optimizer, loss)
Ejemplo n.º 11
0
        def __init__(self, obs_space, action_space, config):
            """Create model and action dist class, then init the base policy.

            Args:
                obs_space: Observation space passed to hooks, factories and
                    `TorchPolicy.__init__`.
                action_space: Action space passed along the same way.
                config: User config; overlaid on `get_default_config()`
                    when that factory is available.
            """
            if get_default_config:
                config = dict(get_default_config(), **config)
            self.config = config

            if before_init:
                before_init(self, obs_space, action_space, config)

            if not make_model_and_action_dist:
                # Catalog defaults; any "custom_options" found in the model
                # config are forwarded as extra keyword arguments.
                dist_class, num_outputs = ModelCatalog.get_action_dist(
                    action_space, self.config["model"], framework="torch")
                self.model = ModelCatalog.get_model_v2(
                    obs_space=obs_space,
                    action_space=action_space,
                    num_outputs=num_outputs,
                    model_config=self.config["model"],
                    framework="torch",
                    **self.config["model"].get("custom_options", {}))
            else:
                self.model, dist_class = make_model_and_action_dist(
                    self, obs_space, action_space, config)
                # The custom factory must hand back a TorchModelV2.
                assert isinstance(self.model, TorchModelV2), (
                    "ERROR: TorchPolicy::make_model_and_action_dist must "
                    "return a TorchModelV2 object!")

            base_kwargs = dict(
                model=self.model,
                loss=loss_fn,
                action_distribution_class=dist_class,
                action_sampler_fn=action_sampler_fn,
                action_distribution_fn=action_distribution_fn,
                max_seq_len=config["model"]["max_seq_len"],
                get_batch_divisibility_req=get_batch_divisibility_req,
            )
            TorchPolicy.__init__(
                self, obs_space, action_space, config, **base_kwargs)

            if after_init:
                after_init(self, obs_space, action_space, config)
Ejemplo n.º 12
0
 def optimizer(self):
     """Return the list of optimizers for this policy.

     The base optimizers come from `optimizer_fn(self, self.config)`
     when that hook is set, else from `TorchPolicy.optimizer`, and are
     passed through `force_list`. If the policy has a non-None
     `exploration` attribute, the result of its
     `get_exploration_optimizer(self.config)` is appended (also via
     `force_list`).

     Returns:
         The list produced by `force_list`, extended with any
         exploration optimizers.
     """
     if optimizer_fn:
         optimizers = optimizer_fn(self, self.config)
     else:
         optimizers = TorchPolicy.optimizer(self)
     optimizers = force_list(optimizers)

     # `hasattr` alone is not enough: the attribute may exist but still
     # be None, in which case calling a method on it would raise
     # AttributeError. Treat a None exploration like a missing one.
     exploration = getattr(self, "exploration", None)
     if exploration is not None:
         exploration_optimizers = force_list(
             exploration.get_exploration_optimizer(self.config))
         optimizers.extend(exploration_optimizers)
     return optimizers
Ejemplo n.º 13
0
 def set_weights(self, weights):
     """Set the policy's weights and refresh the target network.

     Whenever weights are restored for this policy's model, the target
     network is synced (from the main model) in the same call, by
     running `self.update_target()` right after
     `TorchPolicy.set_weights`.
     """
     TorchPolicy.set_weights(self, weights)
     self.update_target()
Ejemplo n.º 14
0
 def extra_grad_info(self, batch_tensors):
     """Return stats for `batch_tensors` from the hook or the base class.

     Uses `stats_fn` when it is set (truthy); otherwise defers to
     `TorchPolicy.extra_grad_info`. The result is returned unchanged.
     """
     if not stats_fn:
         return TorchPolicy.extra_grad_info(self, batch_tensors)
     return stats_fn(self, batch_tensors)
Ejemplo n.º 15
0
 def optimizer(self):
     """Return the optimizer from the custom hook or the base class.

     With `optimizer_fn` set, returns `optimizer_fn(self, self.config)`;
     otherwise returns `TorchPolicy.optimizer(self)`.
     """
     if not optimizer_fn:
         return TorchPolicy.optimizer(self)
     return optimizer_fn(self, self.config)
Ejemplo n.º 16
0
 def extra_grad_process(self):
     """Run the custom gradient-processing hook or the base-class one.

     Calls `extra_grad_process_fn` when it is set (truthy), else
     `TorchPolicy.extra_grad_process`, and returns the callee's result.
     """
     process = extra_grad_process_fn or TorchPolicy.extra_grad_process
     return process(self)
Ejemplo n.º 17
0
 def extra_grad_info(self, train_batch):
     """Return stats for `train_batch` from the hook or the base class.

     `stats_fn` is used when it is set (truthy); otherwise
     `TorchPolicy.extra_grad_info` is. The result is returned as-is.
     """
     compute_stats = stats_fn or TorchPolicy.extra_grad_info
     return compute_stats(self, train_batch)
Ejemplo n.º 18
0
 def apply_gradients(self, gradients):
     """Apply `gradients` through the custom hook or the base class.

     Uses `apply_gradients_fn` when it is set (truthy), else
     `TorchPolicy.apply_gradients`. The callee's return value is
     discarded; this method always returns None.
     """
     apply_fn = apply_gradients_fn or TorchPolicy.apply_gradients
     apply_fn(self, gradients)
Ejemplo n.º 19
0
 def extra_action_out(self, model_out):
     """Dispatch to the custom hook or the TorchPolicy default.

     Calls `extra_action_out_fn(self, model_out)` when the hook is set
     (truthy); otherwise `TorchPolicy.extra_action_out(self, model_out)`.
     The callee's result is returned unchanged.
     """
     if not extra_action_out_fn:
         return TorchPolicy.extra_action_out(self, model_out)
     return extra_action_out_fn(self, model_out)
Ejemplo n.º 20
0
        def __init__(self, obs_space, action_space, config):
            """Assemble model and action dist, init TorchPolicy, then run the
            loss-initialization pass.

            The model comes from exactly one of three sources: `make_model`,
            `make_model_and_action_dist`, or the ModelCatalog defaults.

            Args:
                obs_space: Observation space passed to hooks, factories and
                    `TorchPolicy.__init__`.
                action_space: Action space passed along the same way.
                config: User config; overlaid on `get_default_config()`
                    when that factory is available.
            """
            if get_default_config:
                config = dict(get_default_config(), **config)
            self.config = config

            # Optional pre-construction hooks.
            if validate_spaces:
                validate_spaces(self, obs_space, action_space, self.config)
            if before_init:
                before_init(self, obs_space, action_space, self.config)

            if not make_model and not make_model_and_action_dist:
                # Nothing customized: default dist class and default model.
                dist_class, num_outputs = ModelCatalog.get_action_dist(
                    action_space, self.config["model"], framework="torch")
                self.model = ModelCatalog.get_model_v2(
                    obs_space=obs_space,
                    action_space=action_space,
                    num_outputs=num_outputs,
                    model_config=self.config["model"],
                    framework="torch")
            elif make_model:
                # Custom model, default action dist class. The two factory
                # options are mutually exclusive.
                assert make_model_and_action_dist is None, (
                    "Either `make_model` or `make_model_and_action_dist`"
                    " must be None!")
                self.model = make_model(self, obs_space, action_space, config)
                dist_class, _ = ModelCatalog.get_action_dist(
                    action_space, self.config["model"], framework="torch")
            else:
                # Both model and action dist class are customized.
                self.model, dist_class = make_model_and_action_dist(
                    self, obs_space, action_space, config)

            # Whatever produced the model, it has to be a TorchModelV2.
            assert isinstance(self.model, TorchModelV2), (
                "ERROR: Generated Model must be a TorchModelV2 object!")

            base_kwargs = dict(
                observation_space=obs_space,
                action_space=action_space,
                config=config,
                model=self.model,
                loss=loss_fn,
                action_distribution_class=dist_class,
                action_sampler_fn=action_sampler_fn,
                action_distribution_fn=action_distribution_fn,
                max_seq_len=config["model"]["max_seq_len"],
                get_batch_divisibility_req=get_batch_divisibility_req,
            )
            TorchPolicy.__init__(self, **base_kwargs)

            # Policy-level view requirements first (if a function was
            # given), then the model's inference view requirements on top.
            if callable(view_requirements_fn):
                self.view_requirements.update(view_requirements_fn(self))
            self.view_requirements.update(
                self.model.inference_view_requirements)

            # `before_loss_init` wins; `after_init` is used in its place
            # when `before_loss_init` is not set.
            pre_loss_hook = before_loss_init or after_init
            if pre_loss_hook:
                pre_loss_hook(
                    self, self.observation_space, self.action_space, config)

            # Test run through the postprocessing and loss functions.
            self._initialize_loss_from_dummy_batch(
                auto_remove_unneeded_view_reqs=True,
                stats_fn=stats_fn,
            )

            if _after_loss_init:
                _after_loss_init(self, obs_space, action_space, config)

            # The fake run-through above is followed by a reset of the
            # global timestep.
            self.global_timestep = 0