Пример #1
0
        def __init__(self, obs_space, action_space, config):
            """Create the model and action distribution, then init TorchPolicy.

            Args:
                obs_space: Observation space passed on to the hooks, the
                    model factory and `TorchPolicy.__init__`.
                action_space: Action space, passed on the same way.
                config: Settings dict. Layered over `get_default_config()`
                    when that hook is set, otherwise used as given.

            NOTE(review): `get_default_config`, `before_init`,
            `make_model_and_action_dist`, `after_init` and `loss_fn` are not
            defined in this method -- presumably they come from an enclosing
            factory function; confirm against the surrounding file.
            """
            # Caller-supplied keys win over the defaults.
            cfg = (dict(get_default_config(), **config)
                   if get_default_config else config)
            self.config = cfg

            if before_init:
                before_init(self, obs_space, action_space, cfg)

            if not make_model_and_action_dist:
                # No custom factory given: let the ModelCatalog pick both the
                # action distribution class and the model.
                model_cfg = self.config["model"]
                dist_cls, num_logits = ModelCatalog.get_action_dist(
                    action_space, model_cfg, framework="torch")
                self.dist_class = dist_cls
                self.model = ModelCatalog.get_model_v2(
                    obs_space, action_space, num_logits, model_cfg,
                    framework="torch")
            else:
                built_model, dist_cls = make_model_and_action_dist(
                    self, obs_space, action_space, cfg)
                self.model = built_model
                self.dist_class = dist_cls
                # Reject factories that hand back anything but a TorchModelV2.
                assert isinstance(self.model, TorchModelV2), (
                    "ERROR: TorchPolicy::make_model_and_action_dist must "
                    "return a TorchModelV2 object!")

            TorchPolicy.__init__(
                self, obs_space, action_space, cfg, self.model, loss_fn,
                self.dist_class)

            if after_init:
                after_init(self, obs_space, action_space, cfg)
Пример #2
0
        def __init__(self, obs_space, action_space, config):
            """Create the model and action distribution, then init TorchPolicy.

            Args:
                obs_space: Observation space passed on to the hooks, the
                    model factory and `TorchPolicy.__init__`.
                action_space: Action space, passed on the same way.
                config: Settings dict. Layered over `get_default_config()`
                    when that hook is set, otherwise used as given.

            NOTE(review): the hook names and `loss_fn` are not defined in
            this method -- presumably they are captured from an enclosing
            factory function; confirm against the surrounding file.
            """
            # Caller-supplied keys win over the defaults.
            cfg = (dict(get_default_config(), **config)
                   if get_default_config else config)
            self.config = cfg

            if before_init:
                before_init(self, obs_space, action_space, cfg)

            if not make_model_and_action_dist:
                # Default path: the ModelCatalog supplies both parts.
                # NOTE(review): this call passes `torch=True` while
                # `get_model_v2` below takes `framework="torch"`; confirm
                # that both spellings are accepted by the ModelCatalog
                # version in use.
                model_cfg = self.config["model"]
                dist_cls, num_logits = ModelCatalog.get_action_dist(
                    action_space, model_cfg, torch=True)
                self.dist_class = dist_cls
                self.model = ModelCatalog.get_model_v2(
                    obs_space, action_space, num_logits, model_cfg,
                    framework="torch")
            else:
                # Custom factory returns the (model, dist class) pair. The
                # result is used as returned; no type check is done here.
                built_model, dist_cls = make_model_and_action_dist(
                    self, obs_space, action_space, cfg)
                self.model = built_model
                self.dist_class = dist_cls

            TorchPolicy.__init__(
                self, obs_space, action_space, cfg, self.model, loss_fn,
                self.dist_class)

            if after_init:
                after_init(self, obs_space, action_space, cfg)
Пример #3
0
    def __init__(self, observation_space, action_space, config):
        """Initialise the policy, its mixins and the KL settings.

        Args:
            observation_space: Forwarded to `setup_config` and
                `TorchPolicy.__init__`.
            action_space: Forwarded the same way.
            config: Overrides laid on top of the PPO `DEFAULT_CONFIG`.

        NOTE(review): `self.config` is read below but never assigned in this
        method; presumably `setup_config` or `TorchPolicy.__init__` sets it
        -- confirm.
        """
        # User-supplied keys take precedence over the PPO defaults.
        ppo_defaults = ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG
        config = dict(ppo_defaults, **config)
        setup_config(self, observation_space, action_space, config)

        # Looked up only after `setup_config` has seen the config.
        seq_len_limit = config["model"]["max_seq_len"]
        TorchPolicy.__init__(
            self, observation_space, action_space, config,
            max_seq_len=seq_len_limit)

        ValueNetworkMixin.__init__(self, config)

        entropy_coeff = config["entropy_coeff"]
        entropy_schedule = config["entropy_coeff_schedule"]
        EntropyCoeffSchedule.__init__(self, entropy_coeff, entropy_schedule)

        learning_rate = config["lr"]
        lr_schedule = config["lr_schedule"]
        LearningRateSchedule.__init__(self, learning_rate, lr_schedule)

        policy_cfg = self.config
        # Current KL coefficient, kept as a plain attribute on the policy.
        self.kl_coeff = policy_cfg["kl_coeff"]
        # KL target; this method only copies it from the config.
        self.kl_target = policy_cfg["kl_target"]

        # TODO: Don't require users to call this manually.
        self._initialize_loss_from_dummy_batch()
Пример #4
0
        def __init__(self, obs_space, action_space, config):
            """Build model and action dist class, then init TorchPolicy.

            The model comes from one of three sources, checked in order:
            `make_model` (with the catalog's default action distribution),
            `make_model_and_action_dist`, or the ModelCatalog defaults.

            Args:
                obs_space: Observation space passed to the hooks, the model
                    factories and `TorchPolicy.__init__`.
                action_space: Action space, passed on the same way.
                config: Settings dict. Layered over `get_default_config()`
                    when that hook is set, otherwise used as given.

            NOTE(review): the hooks, factories and `loss_fn` are not defined
            in this method -- presumably they are captured from an enclosing
            factory function; confirm against the surrounding file.
            """
            # Caller-supplied keys win over the defaults.
            cfg = (dict(get_default_config(), **config)
                   if get_default_config else config)
            self.config = cfg

            if validate_spaces:
                validate_spaces(self, obs_space, action_space, self.config)

            if before_init:
                before_init(self, obs_space, action_space, self.config)

            if make_model:
                # Custom model, default action distribution class.
                assert make_model_and_action_dist is None, (
                    "Either `make_model` or `make_model_and_action_dist`"
                    " must be None!")
                self.model = make_model(self, obs_space, action_space, cfg)
                action_dist_cls, _ = ModelCatalog.get_action_dist(
                    action_space, self.config["model"], framework="torch")
            elif make_model_and_action_dist:
                # Custom model and custom action distribution class.
                built_model, action_dist_cls = make_model_and_action_dist(
                    self, obs_space, action_space, cfg)
                self.model = built_model
            else:
                # Neither factory given: take both from the ModelCatalog.
                model_cfg = self.config["model"]
                action_dist_cls, num_logits = ModelCatalog.get_action_dist(
                    action_space, model_cfg, framework="torch")
                self.model = ModelCatalog.get_model_v2(
                    obs_space=obs_space,
                    action_space=action_space,
                    num_outputs=num_logits,
                    model_config=model_cfg,
                    framework="torch")

            # Whatever path was taken, the result must be a TorchModelV2.
            assert isinstance(self.model, TorchModelV2), \
                "ERROR: Generated Model must be a TorchModelV2 object!"

            seq_len_limit = cfg["model"]["max_seq_len"]
            TorchPolicy.__init__(
                self,
                observation_space=obs_space,
                action_space=action_space,
                config=cfg,
                model=self.model,
                loss=loss_fn,
                action_distribution_class=action_dist_cls,
                action_sampler_fn=action_sampler_fn,
                action_distribution_fn=action_distribution_fn,
                max_seq_len=seq_len_limit,
                get_batch_divisibility_req=get_batch_divisibility_req,
            )

            # Merge in extra training view requirements when a callable
            # was supplied for them.
            if callable(training_view_requirements_fn):
                self.training_view_requirements.update(
                    training_view_requirements_fn(self))

            if after_init:
                after_init(self, obs_space, action_space, cfg)
Пример #5
0
        def __init__(self, obs_space, action_space, config):
            """Build model and action dist class, then init TorchPolicy.

            Args:
                obs_space: Observation space passed to the hooks, the model
                    factory and `TorchPolicy.__init__`.
                action_space: Action space, passed on the same way.
                config: Settings dict. Layered over `get_default_config()`
                    when that hook is set, otherwise used as given.

            NOTE(review): the hooks, `loss_fn` and the action sampler /
            distribution functions are not defined in this method --
            presumably they are captured from an enclosing factory function;
            confirm against the surrounding file.
            """
            # Caller-supplied keys win over the defaults.
            cfg = (dict(get_default_config(), **config)
                   if get_default_config else config)
            self.config = cfg

            if before_init:
                before_init(self, obs_space, action_space, cfg)

            if not make_model_and_action_dist:
                # Default path: the ModelCatalog supplies both parts. Any
                # "custom_options" in the model config are forwarded to the
                # model as extra keyword arguments.
                model_cfg = self.config["model"]
                action_dist_cls, num_logits = ModelCatalog.get_action_dist(
                    action_space, model_cfg, framework="torch")
                self.model = ModelCatalog.get_model_v2(
                    obs_space=obs_space,
                    action_space=action_space,
                    num_outputs=num_logits,
                    model_config=model_cfg,
                    framework="torch",
                    **model_cfg.get("custom_options", {}))
            else:
                built_model, action_dist_cls = make_model_and_action_dist(
                    self, obs_space, action_space, cfg)
                self.model = built_model
                # Reject factories that hand back anything but a TorchModelV2.
                assert isinstance(self.model, TorchModelV2), (
                    "ERROR: TorchPolicy::make_model_and_action_dist must "
                    "return a TorchModelV2 object!")

            seq_len_limit = cfg["model"]["max_seq_len"]
            TorchPolicy.__init__(
                self,
                obs_space,
                action_space,
                cfg,
                model=self.model,
                loss=loss_fn,
                action_distribution_class=action_dist_cls,
                action_sampler_fn=action_sampler_fn,
                action_distribution_fn=action_distribution_fn,
                max_seq_len=seq_len_limit,
                get_batch_divisibility_req=get_batch_divisibility_req,
            )

            if after_init:
                after_init(self, obs_space, action_space, cfg)
Пример #6
0
        def __init__(self, obs_space, action_space, config):
            """Build the model, init TorchPolicy and test-run the loss.

            The model comes from one of three sources, checked in order:
            `make_model` (with the catalog's default action distribution),
            `make_model_and_action_dist`, or the ModelCatalog defaults.
            After `TorchPolicy.__init__`, view requirements are merged, the
            pre-loss hook runs, the loss is initialised from a dummy batch
            and `global_timestep` is reset to 0.

            Args:
                obs_space: Observation space passed to the hooks, the model
                    factories and `TorchPolicy.__init__`.
                action_space: Action space, passed on the same way.
                config: Settings dict. Layered over `get_default_config()`
                    when that hook is set, otherwise used as given.

            NOTE(review): the hooks (including `_after_loss_init`), the
            factories, `loss_fn` and `stats_fn` are not defined in this
            method -- presumably they are captured from an enclosing factory
            function; confirm against the surrounding file.
            """
            # Caller-supplied keys win over the defaults.
            cfg = (dict(get_default_config(), **config)
                   if get_default_config else config)
            self.config = cfg

            if validate_spaces:
                validate_spaces(self, obs_space, action_space, self.config)

            if before_init:
                before_init(self, obs_space, action_space, self.config)

            if make_model:
                # Custom model, default action distribution class.
                assert make_model_and_action_dist is None, (
                    "Either `make_model` or `make_model_and_action_dist`"
                    " must be None!")
                self.model = make_model(self, obs_space, action_space, cfg)
                action_dist_cls, _ = ModelCatalog.get_action_dist(
                    action_space, self.config["model"], framework="torch")
            elif make_model_and_action_dist:
                # Custom model and custom action distribution class.
                built_model, action_dist_cls = make_model_and_action_dist(
                    self, obs_space, action_space, cfg)
                self.model = built_model
            else:
                # Neither factory given: take both from the ModelCatalog.
                model_cfg = self.config["model"]
                action_dist_cls, num_logits = ModelCatalog.get_action_dist(
                    action_space, model_cfg, framework="torch")
                self.model = ModelCatalog.get_model_v2(
                    obs_space=obs_space,
                    action_space=action_space,
                    num_outputs=num_logits,
                    model_config=model_cfg,
                    framework="torch")

            # Whatever path was taken, the result must be a TorchModelV2.
            assert isinstance(self.model, TorchModelV2), \
                "ERROR: Generated Model must be a TorchModelV2 object!"

            seq_len_limit = cfg["model"]["max_seq_len"]
            TorchPolicy.__init__(
                self,
                observation_space=obs_space,
                action_space=action_space,
                config=cfg,
                model=self.model,
                loss=loss_fn,
                action_distribution_class=action_dist_cls,
                action_sampler_fn=action_sampler_fn,
                action_distribution_fn=action_distribution_fn,
                max_seq_len=seq_len_limit,
                get_batch_divisibility_req=get_batch_divisibility_req,
            )

            # Policy-level view requirements first (if a function is given),
            # then the model's inference view requirements on top.
            if callable(view_requirements_fn):
                self.view_requirements.update(view_requirements_fn(self))
            self.view_requirements.update(
                self.model.inference_view_requirements)

            # `after_init` is used as the pre-loss hook when no
            # `before_loss_init` was supplied.
            pre_loss_hook = before_loss_init or after_init
            if pre_loss_hook:
                pre_loss_hook(
                    self, self.observation_space, self.action_space, cfg)

            # Perform test runs through postprocessing- and loss functions.
            self._initialize_loss_from_dummy_batch(
                auto_remove_unneeded_view_reqs=True,
                stats_fn=stats_fn,
            )

            if _after_loss_init:
                _after_loss_init(self, obs_space, action_space, cfg)

            # The dummy run-through above is not real experience, so the
            # timestep counter goes back to zero.
            self.global_timestep = 0