Code example #1
File: agent.py  Project: zhy52/ray
    def _setup(self, config):
        env = self._env_id
        if env:
            config["env"] = env
            if _global_registry.contains(ENV_CREATOR, env):
                self.env_creator = _global_registry.get(ENV_CREATOR, env)
            else:
                import gym  # soft dependency
                self.env_creator = lambda env_config: gym.make(env)
        else:
            self.env_creator = lambda env_config: None

        # Merge the supplied config with the class default
        merged_config = copy.deepcopy(self._default_config)
        merged_config = deep_update(merged_config, config,
                                    self._allow_unknown_configs,
                                    self._allow_unknown_subkeys)
        self.raw_user_config = config
        self.config = merged_config
        Agent._validate_config(self.config)
        if self.config.get("log_level"):
            logging.getLogger("ray.rllib").setLevel(self.config["log_level"])

        # TODO(ekl) setting the graph is unnecessary for PyTorch agents
        with tf.Graph().as_default():
            self._init()
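
All of these _setup variants hinge on RLlib's deep_update helper. The sketch below is a rough, hypothetical reimplementation of its merge semantics, not the real function (the actual signature varies across Ray versions; allow_new_subkey_list was called whitelist in older releases):

def deep_update_sketch(original, new_dict, new_keys_allowed=False,
                       allow_new_subkey_list=None):
    """Recursively merge new_dict into original (mutates original)."""
    allow_new_subkey_list = allow_new_subkey_list or []
    for key, value in new_dict.items():
        if key not in original and not new_keys_allowed:
            raise KeyError("Unknown config parameter `{}`".format(key))
        if isinstance(original.get(key), dict) and isinstance(value, dict):
            # Recurse; keys on the allow list may introduce new subkeys.
            deep_update_sketch(original[key], value,
                               new_keys_allowed=key in allow_new_subkey_list,
                               allow_new_subkey_list=allow_new_subkey_list)
        else:
            original[key] = value
    return original
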
Code example #2
File: trainer.py  Project: skyofwinter/ray
    def _setup(self, config):
        env = self._env_id
        if env:
            config["env"] = env
            if _global_registry.contains(ENV_CREATOR, env):
                self.env_creator = _global_registry.get(ENV_CREATOR, env)
            else:
                import gym  # soft dependency
                self.env_creator = lambda env_config: gym.make(env)
        else:
            self.env_creator = lambda env_config: None

        # Merge the supplied config with the class default
        merged_config = copy.deepcopy(self._default_config)
        merged_config = deep_update(merged_config, config,
                                    self._allow_unknown_configs,
                                    self._allow_unknown_subkeys)
        self.raw_user_config = config
        self.config = merged_config

        if self.config["normalize_actions"]:
            inner = self.env_creator
            self.env_creator = (
                lambda env_config: NormalizeActionWrapper(inner(env_config)))

        Trainer._validate_config(self.config)
        log_level = self.config.get("log_level")
        if log_level in ["WARN", "ERROR"]:
            logger.info("Current log_level is {}. For more information, "
                        "set 'log_level': 'INFO' / 'DEBUG' or use the -v and "
                        "-vv flags.".format(log_level))
        if self.config.get("log_level"):
            logging.getLogger("ray.rllib").setLevel(self.config["log_level"])

        def get_scope():
            if tf and not tf.executing_eagerly():
                return tf.Graph().as_default()
            else:
                return open("/dev/null")  # fake a no-op scope

        with get_scope():
            self._init(self.config, self.env_creator)

            # Evaluation setup.
            if self.config.get("evaluation_interval"):
                # Update env_config with evaluation settings:
                extra_config = copy.deepcopy(self.config["evaluation_config"])
                extra_config.update({
                    "batch_mode": "complete_episodes",
                    "batch_steps": 1,
                })
                logger.debug(
                    "using evaluation_config: {}".format(extra_config))

                self.evaluation_workers = self._make_workers(
                    self.env_creator,
                    self._policy,
                    merge_dicts(self.config, extra_config),
                    num_workers=self.config["evaluation_num_workers"])
                self.evaluation_metrics = {}
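
Note the capture inner = self.env_creator in the normalize_actions branch above: Python closures bind names late, so a lambda that referenced self.env_creator directly would resolve the name at call time, i.e. to the wrapping lambda itself, and recurse forever. A minimal, self-contained illustration of the same pattern (make_env and the tuple wrapper are stand-ins):

def make_env(env_config):
    return "base_env"  # stand-in for a real environment instance

env_creator = make_env
inner = env_creator  # snapshot the current creator before rebinding
env_creator = lambda env_config: ("wrapped", inner(env_config))

assert env_creator({}) == ("wrapped", "base_env")
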
Code example #3
File: agent.py  Project: robertnishihara/ray
    def _setup(self, config):
        env = self._env_id
        if env:
            config["env"] = env
            if _global_registry.contains(ENV_CREATOR, env):
                self.env_creator = _global_registry.get(ENV_CREATOR, env)
            else:
                import gym  # soft dependency
                self.env_creator = lambda env_config: gym.make(env)
        else:
            self.env_creator = lambda env_config: None

        # Merge the supplied config with the class default
        merged_config = copy.deepcopy(self._default_config)
        merged_config = deep_update(merged_config, config,
                                    self._allow_unknown_configs,
                                    self._allow_unknown_subkeys)
        self.raw_user_config = config
        self.config = merged_config
        Agent._validate_config(self.config)
        if self.config.get("log_level"):
            logging.getLogger("ray.rllib").setLevel(self.config["log_level"])

        # TODO(ekl) setting the graph is unnecessary for PyTorch agents
        with tf.Graph().as_default():
            self._init()
Code example #4
File: trainer.py  Project: zommiommy/ray
    def merge_trainer_configs(cls, config1, config2):
        config1 = copy.deepcopy(config1)
        # Error if trainer default has deprecated value.
        if config1["sample_batch_size"] != DEPRECATED_VALUE:
            deprecation_warning("sample_batch_size",
                                new="rollout_fragment_length",
                                error=True)
        # Warning if user override config has deprecated value.
        if ("sample_batch_size" in config2
                and config2["sample_batch_size"] != DEPRECATED_VALUE):
            deprecation_warning("sample_batch_size",
                                new="rollout_fragment_length")
            config2["rollout_fragment_length"] = config2["sample_batch_size"]
            del config2["sample_batch_size"]
        if "callbacks" in config2 and type(config2["callbacks"]) is dict:
            legacy_callbacks_dict = config2["callbacks"]

            def make_callbacks():
                # Deprecation warning will be logged by DefaultCallbacks.
                return DefaultCallbacks(
                    legacy_callbacks_dict=legacy_callbacks_dict)

            config2["callbacks"] = make_callbacks
        return deep_update(config1, config2, cls._allow_unknown_configs,
                           cls._allow_unknown_subkeys,
                           cls._override_all_subkeys_if_type_changes)
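
A hypothetical call to the method above; MyTrainer and my_on_episode_end are illustrative stand-ins for a Trainer subclass and a user-defined callback, not names from the source:

merged = MyTrainer.merge_trainer_configs(
    MyTrainer._default_config,
    {
        "rollout_fragment_length": 200,
        # A legacy dict of callbacks is wrapped into a DefaultCallbacks
        # factory by the method above.
        "callbacks": {"on_episode_end": my_on_episode_end},
    },
)
assert callable(merged["callbacks"])
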
Code example #5
    def merge_defaults_with(self, config: Config) -> Config:
        """Deep merge the given config with the defaults."""
        defaults = copy.deepcopy(self.defaults)
        new = deep_update(
            defaults,
            config,
            new_keys_allowed=False,
            allow_new_subkey_list=self.allow_unknown_subkeys,
            override_all_if_type_changes=self.override_all_if_type_changes,
        )
        return new
Code example #6
    def _setup(self, config):
        env = self._env_id
        if env:
            config["env"] = env
            if _global_registry.contains(ENV_CREATOR, env):
                self.env_creator = _global_registry.get(ENV_CREATOR, env)
            else:
                import gym  # soft dependency
                self.env_creator = lambda env_config: gym.make(env)
        else:
            self.env_creator = lambda env_config: None

        # Merge the supplied config with the class default
        merged_config = copy.deepcopy(self._default_config)
        merged_config = deep_update(merged_config, config,
                                    self._allow_unknown_configs,
                                    self._allow_unknown_subkeys)
        self.raw_user_config = config
        self.config = merged_config
        Trainer._validate_config(self.config)
        if self.config.get("log_level"):
            logging.getLogger("ray.rllib").setLevel(self.config["log_level"])

        def get_scope():
            if tf:
                return tf.Graph().as_default()
            else:
                return open("/dev/null")  # fake a no-op scope

        with get_scope():
            self._init(self.config, self.env_creator)

            # Evaluation related
            if self.config.get("evaluation_interval"):
                # Update env_config with evaluation settings:
                extra_config = copy.deepcopy(self.config["evaluation_config"])
                extra_config.update({
                    "batch_mode": "complete_episodes",
                    "batch_steps": 1,
                })
                logger.debug(
                    "using evaluation_config: {}".format(extra_config))
                self.evaluation_workers = self._make_workers(self.env_creator,
                                                             self._policy,
                                                             merge_dicts(
                                                                 self.config,
                                                                 extra_config),
                                                             num_workers=0)
                self.evaluation_metrics = self._evaluate()
Code example #7
def deep_merge(
    dict1,
    dict2,
    new_keys_allowed: bool = False,
    allow_new_subkey_list: Optional[list] = None,
    override_all_if_type_changes: Optional[list] = None,
):
    """Deep copy original dict and pass it to RLlib's deep_update."""
    clone = copy.deepcopy(dict1)
    return deep_update(
        clone,
        dict2,
        new_keys_allowed=new_keys_allowed,
        allow_new_subkey_list=allow_new_subkey_list,
        override_all_if_type_changes=override_all_if_type_changes,
    )
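
A hypothetical usage of the wrapper above, with illustrative values (it assumes deep_merge and RLlib's deep_update are in scope):

defaults = {"lr": 1e-3, "env_config": {"size": 4, "seed": 0}}
overrides = {"env_config": {"render": True}}  # introduces a new subkey

merged = deep_merge(defaults, overrides,
                    allow_new_subkey_list=["env_config"])
assert merged["env_config"] == {"size": 4, "seed": 0, "render": True}
assert defaults["env_config"] == {"size": 4, "seed": 0}  # input untouched
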
Code example #8
File: trainer.py  Project: w0617/ray
    def merge_trainer_configs(cls, config1, config2):
        config1 = copy.deepcopy(config1)
        # Error if trainer default has deprecated value.
        if config1["sample_batch_size"] != DEPRECATED_VALUE:
            deprecation_warning("sample_batch_size",
                                new="rollout_fragment_length",
                                error=True)
        # Warning if user override config has deprecated value.
        if ("sample_batch_size" in config2
                and config2["sample_batch_size"] != DEPRECATED_VALUE):
            deprecation_warning("sample_batch_size",
                                new="rollout_fragment_length")
            config2["rollout_fragment_length"] = config2["sample_batch_size"]
            del config2["sample_batch_size"]
        return deep_update(config1, config2, cls._allow_unknown_configs,
                           cls._allow_unknown_subkeys,
                           cls._override_all_subkeys_if_type_changes)
Code example #9
File: trainer.py  Project: zhaokang1228/ray
    def _setup(self, config):
        env = self._env_id
        if env:
            config["env"] = env
            if _global_registry.contains(ENV_CREATOR, env):
                self.env_creator = _global_registry.get(ENV_CREATOR, env)
            else:
                import gym  # soft dependency
                self.env_creator = lambda env_config: gym.make(env)
        else:
            self.env_creator = lambda env_config: None

        # Merge the supplied config with the class default
        merged_config = copy.deepcopy(self._default_config)
        merged_config = deep_update(merged_config, config,
                                    self._allow_unknown_configs,
                                    self._allow_unknown_subkeys)
        self.raw_user_config = config
        self.config = merged_config
        Trainer._validate_config(self.config)
        if self.config.get("log_level"):
            logging.getLogger("ray.rllib").setLevel(self.config["log_level"])

        # TODO(ekl) setting the graph is unnecessary for PyTorch agents
        with tf.Graph().as_default():
            self._init(self.config, self.env_creator)

            # Evaluation related
            if self.config.get("evaluation_interval"):
                # Update env_config with evaluation settings:
                extra_config = copy.deepcopy(self.config["evaluation_config"])
                extra_config.update({
                    "batch_mode": "complete_episodes",
                    "batch_steps": 1,
                })
                logger.debug(
                    "using evaluation_config: {}".format(extra_config))
                # Make local evaluation evaluators
                self.evaluation_ev = self.make_local_evaluator(
                    self.env_creator,
                    self._policy_graph,
                    extra_config=extra_config)
                self.evaluation_metrics = self._evaluate()
Code example #10
File: agent.py  Project: gavinljj/ray
    def _setup(self, config):
        env = self._env_id
        if env:
            config["env"] = env
            if _global_registry.contains(ENV_CREATOR, env):
                self.env_creator = _global_registry.get(ENV_CREATOR, env)
            else:
                import gym  # soft dependency
                self.env_creator = lambda env_config: gym.make(env)
        else:
            self.env_creator = lambda env_config: None

        # Merge the supplied config with the class default
        merged_config = self._default_config.copy()
        merged_config = deep_update(merged_config, config,
                                    self._allow_unknown_configs,
                                    self._allow_unknown_subkeys)
        self.config = merged_config

        # TODO(ekl) setting the graph is unnecessary for PyTorch agents
        with tf.Graph().as_default():
            self._init()
Code example #11
File: trainer.py  Project: zqxyz73/ray
    def merge_trainer_configs(cls, config1, config2):
        config1 = copy.deepcopy(config1)
        return deep_update(config1, config2, cls._allow_unknown_configs,
                           cls._allow_unknown_subkeys,
                           cls._override_all_subkeys_if_type_changes)
Code example #12
File: simple_q.py  Project: tchordia/ray
    def training(
        self,
        *,
        target_network_update_freq: Optional[int] = None,
        replay_buffer_config: Optional[dict] = None,
        store_buffer_in_checkpoints: Optional[bool] = None,
        lr_schedule: Optional[List[List[Union[int, float]]]] = None,
        adam_epsilon: Optional[float] = None,
        grad_clip: Optional[int] = None,
        **kwargs,
    ) -> "SimpleQConfig":
        """Sets the training related configuration.

        Args:
            timesteps_per_iteration: Minimum env steps to optimize for per train call.
                This value does not affect learning, only the length of iterations.
            target_network_update_freq: Update the target network every
                `target_network_update_freq` sample steps.
            replay_buffer_config: Replay buffer config.
                Examples:
                {
                "_enable_replay_buffer_api": True,
                "type": "MultiAgentReplayBuffer",
                "learning_starts": 1000,
                "capacity": 50000,
                "replay_sequence_length": 1,
                }
                - OR -
                {
                "_enable_replay_buffer_api": True,
                "type": "MultiAgentPrioritizedReplayBuffer",
                "capacity": 50000,
                "prioritized_replay_alpha": 0.6,
                "prioritized_replay_beta": 0.4,
                "prioritized_replay_eps": 1e-6,
                "replay_sequence_length": 1,
                }
                - Where -
                prioritized_replay_alpha: Alpha parameter controls the degree of
                prioritization in the buffer. In other words, when a buffer sample
                has a higher temporal-difference error, how much more likely
                should it be drawn and used to update the parameterized
                Q-network. 0.0 corresponds to uniform probability. Setting this
                much above 1.0 may quickly make the sampling distribution heavily
                "pointy", with low entropy.
                prioritized_replay_beta: Beta parameter controls the degree of
                importance sampling, which suppresses the influence of gradient
                updates from samples that are drawn with higher probability due
                to the alpha parameter and the temporal-difference error.
                prioritized_replay_eps: Epsilon parameter sets the baseline
                probability for sampling, so that when the temporal-difference
                error of a sample is zero, there is still a chance of it being
                drawn.
            store_buffer_in_checkpoints: Set this to True, if you want the contents of
                your buffer(s) to be stored in any saved checkpoints as well.
                Warnings will be created if:
                - This is True AND restoring from a checkpoint that contains no buffer
                data.
                - This is False AND restoring from a checkpoint that does contain
                buffer data.
            lr_schedule: Learning rate schedule. In the format of [[timestep, value],
                [timestep, value], ...]. A schedule should normally start from
                timestep 0.
            adam_epsilon: Adam optimizer's epsilon hyperparameter.
            grad_clip: If not None, clip gradients during optimization at this value.

        Returns:
            This updated TrainerConfig object.
        """
        # Pass kwargs onto super's `training()` method.
        super().training(**kwargs)

        if target_network_update_freq is not None:
            self.target_network_update_freq = target_network_update_freq
        if replay_buffer_config is not None:
            # Override entire `replay_buffer_config` if `type` key changes.
            # Update, if `type` key remains the same or is not specified.
            new_replay_buffer_config = deep_update(
                {"replay_buffer_config": self.replay_buffer_config},
                {"replay_buffer_config": replay_buffer_config},
                False,
                ["replay_buffer_config"],
                ["replay_buffer_config"],
            )
            self.replay_buffer_config = new_replay_buffer_config[
                "replay_buffer_config"]
        if store_buffer_in_checkpoints is not None:
            self.store_buffer_in_checkpoints = store_buffer_in_checkpoints
        if lr_schedule is not None:
            self.lr_schedule = lr_schedule
        if adam_epsilon is not None:
            self.adam_epsilon = adam_epsilon
        if grad_clip is not None:
            self.grad_clip = grad_clip

        return self
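
A hypothetical end-to-end call, reusing the prioritized-replay values from the docstring above (the exact import path for SimpleQConfig depends on the Ray version):

config = SimpleQConfig().training(
    target_network_update_freq=500,
    replay_buffer_config={
        "_enable_replay_buffer_api": True,
        "type": "MultiAgentPrioritizedReplayBuffer",
        "capacity": 50000,
        "prioritized_replay_alpha": 0.6,
        "prioritized_replay_beta": 0.4,
        "prioritized_replay_eps": 1e-6,
        "replay_sequence_length": 1,
    },
    lr_schedule=[[0, 1e-3], [100000, 1e-4]],
)
# Because "replay_buffer_config" appears in both allow-lists of the
# deep_update call above, a changed "type" replaces the whole dict, while
# an unchanged "type" merges only the keys given here into the defaults.
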