示例#1
0
文件: agent.py 项目: zhy52/ray
    def _setup(self, config):
        """Resolve the env creator, merge the config, and initialize the agent.

        Args:
            config: User-supplied config dict. When ``self._env_id`` is set,
                its "env" entry is overwritten with that id.
        """
        env = self._env_id
        if not env:
            # No env id: the creator ignores its argument and yields None.
            self.env_creator = lambda env_config: None
        else:
            config["env"] = env
            if not _global_registry.contains(ENV_CREATOR, env):
                # Unregistered ids are handed to gym.
                import gym  # soft dependency
                self.env_creator = lambda env_config: gym.make(env)
            else:
                self.env_creator = _global_registry.get(ENV_CREATOR, env)

        # Overlay the user config on a private deep copy of the class default.
        merged = deep_update(
            copy.deepcopy(self._default_config), config,
            self._allow_unknown_configs, self._allow_unknown_subkeys)
        self.raw_user_config = config
        self.config = merged
        Agent._validate_config(self.config)

        log_level = self.config.get("log_level")
        if log_level:
            logging.getLogger("ray.rllib").setLevel(log_level)

        # TODO(ekl) setting the graph is unnecessary for PyTorch agents
        graph = tf.Graph()
        with graph.as_default():
            self._init()
示例#2
0
    def _setup(self, config):
        """Resolve the env creator, merge the config, and build the trainer.

        Steps, in order: pick an env creator from ``self._env_id``, merge
        ``config`` over a deep copy of ``self._default_config``, optionally
        wrap the creator in ``NormalizeActionWrapper``, validate the merged
        config, apply the log level, call ``self._init`` and, when
        "evaluation_interval" is set, create the evaluation workers.

        Args:
            config: User-supplied config dict. When ``self._env_id`` is set,
                its "env" entry is overwritten with that id.
        """
        env = self._env_id
        if env:
            config["env"] = env
            if _global_registry.contains(ENV_CREATOR, env):
                self.env_creator = _global_registry.get(ENV_CREATOR, env)
            else:
                import gym  # soft dependency
                self.env_creator = lambda env_config: gym.make(env)
        else:
            self.env_creator = lambda env_config: None

        # Merge the supplied config with the class default
        merged_config = copy.deepcopy(self._default_config)
        merged_config = deep_update(merged_config, config,
                                    self._allow_unknown_configs,
                                    self._allow_unknown_subkeys)
        self.raw_user_config = config
        self.config = merged_config

        if self.config["normalize_actions"]:
            # Capture the current creator so the wrapper lambda does not
            # recurse into itself once self.env_creator is rebound.
            inner = self.env_creator
            self.env_creator = (
                lambda env_config: NormalizeActionWrapper(inner(env_config)))

        Trainer._validate_config(self.config)
        log_level = self.config.get("log_level")
        if log_level in ["WARN", "ERROR"]:
            logger.info("Current log_level is {}. For more information, "
                        "set 'log_level': 'INFO' / 'DEBUG' or use the -v and "
                        "-vv flags.".format(log_level))
        if log_level:
            logging.getLogger("ray.rllib").setLevel(log_level)

        def get_scope():
            """Return a TF graph scope, or a throwaway no-op context."""
            if tf and not tf.executing_eagerly():
                return tf.Graph().as_default()
            else:
                # Use the platform's null device instead of a hard-coded
                # "/dev/null", which does not exist on every OS.
                import os
                return open(os.devnull)  # fake a no-op scope

        with get_scope():
            self._init(self.config, self.env_creator)

            # Evaluation setup.
            if self.config.get("evaluation_interval"):
                # Update env_config with evaluation settings:
                extra_config = copy.deepcopy(self.config["evaluation_config"])
                extra_config.update({
                    "batch_mode": "complete_episodes",
                    "batch_steps": 1,
                })
                logger.debug(
                    "using evaluation_config: {}".format(extra_config))

                self.evaluation_workers = self._make_workers(
                    self.env_creator,
                    self._policy,
                    merge_dicts(self.config, extra_config),
                    num_workers=self.config["evaluation_num_workers"])
                self.evaluation_metrics = {}
示例#3
0
    def _setup(self, config):
        """Resolve the env creator, merge the config, and initialize the agent.

        Args:
            config: User-supplied config dict. When ``self._env_id`` is set,
                its "env" entry is overwritten with that id.
        """
        env = self._env_id
        if not env:
            # No env id: the creator ignores its argument and yields None.
            self.env_creator = lambda env_config: None
        else:
            config["env"] = env
            if not _global_registry.contains(ENV_CREATOR, env):
                # Unregistered ids are handed to gym.
                import gym  # soft dependency
                self.env_creator = lambda env_config: gym.make(env)
            else:
                self.env_creator = _global_registry.get(ENV_CREATOR, env)

        # Overlay the user config on a private deep copy of the class default.
        merged = deep_update(
            copy.deepcopy(self._default_config), config,
            self._allow_unknown_configs, self._allow_unknown_subkeys)
        self.raw_user_config = config
        self.config = merged
        Agent._validate_config(self.config)

        log_level = self.config.get("log_level")
        if log_level:
            logging.getLogger("ray.rllib").setLevel(log_level)

        # TODO(ekl) setting the graph is unnecessary for PyTorch agents
        graph = tf.Graph()
        with graph.as_default():
            self._init()
示例#4
0
def validate_config(config):
    """Validate ``config`` and fill in its multi-agent and model entries.

    Mutates ``config`` in place: one policy entry is created per agent id of
    a temporary environment, the policy mapping is set to the identity, and
    the custom model entries are cleared.

    Raises:
        AssertionError: If the action-normalization flags are not set as
            required or the env is not registered.
        NotImplementedError: If the diversity value network is requested.
    """
    validate_config_and_setup_param_noise(config)

    # Action normalization must be off at the top level and on inside the
    # environment config.
    assert not config["normalize_actions"]
    assert config["env_config"]["normalize_actions"]

    # Instantiate a temporary env to read its agent ids and spaces.
    env_name = config["env"]
    assert _global_registry.contains(ENV_CREATOR, env_name)
    make_env = _global_registry.get(ENV_CREATOR, env_name)
    probe_env = make_env(config["env_config"])

    policies = {}
    for agent_id in probe_env.agent_ids:
        policies[agent_id] = (
            None, probe_env.observation_space, probe_env.action_space, {})
    config["multiagent"]["policies"] = policies
    config["multiagent"]["policy_mapping_fn"] = lambda x: x

    # The diversity value network is not supported by this variant.
    if config[USE_DIVERSITY_VALUE_NETWORK]:
        raise NotImplementedError()

    config['model']['custom_model'] = None
    config['model']['custom_options'] = None
示例#5
0
def validate_config(config):
    """Fill in the multi-agent and model entries, then run base validation.

    Mutates ``config`` in place: one policy entry is created per agent id of
    a temporary environment, the policy mapping is set to the identity, and
    the custom model entries are set according to
    ``config[USE_DIVERSITY_VALUE_NETWORK]``.

    Raises:
        AssertionError: If ``config["env"]`` is not a registered env.
    """
    # Instantiate a temporary env to read its agent ids and spaces.
    env_name = config["env"]
    assert _global_registry.contains(ENV_CREATOR, env_name)
    make_env = _global_registry.get(ENV_CREATOR, env_name)
    probe_env = make_env(config["env_config"])

    policies = {}
    for agent_id in probe_env.agent_ids:
        policies[agent_id] = (
            None, probe_env.observation_space, probe_env.action_space, {})
    config["multiagent"]["policies"] = policies
    config["multiagent"]["policy_mapping_fn"] = lambda x: x

    # Select the custom model settings.
    if config[USE_DIVERSITY_VALUE_NETWORK]:
        ModelCatalog.register_custom_model(
            "ActorDoubleCriticNetwork", ActorDoubleCriticNetwork)
        model_name = "ActorDoubleCriticNetwork"
        model_options = {
            "use_diversity_value_network": config[USE_DIVERSITY_VALUE_NETWORK]
        }
    else:
        model_name = None
        model_options = None
    config['model']['custom_model'] = model_name
    config['model']['custom_options'] = model_options

    # Delegate the remaining checks to the original validator.
    validate_config_original(config)
示例#6
0
def validate_config(config):
    """Fill in the multi-agent and model entries and check the config.

    Mutates ``config`` in place: one policy entry is created per agent id of
    a temporary environment, the policy mapping is set to the identity, and
    the custom model entries are set or cleared.

    Raises:
        AssertionError: If the env is not registered, the minibatch size is
            incompatible with the replay-values setting, diversity-only
            options are enabled without diversity encouragement, or
            ``config[CONSTRAIN_NOVELTY]`` is not 'soft', 'hard' or None.
    """
    # Instantiate a temporary env to read its agent ids and spaces.
    env_name = config["env"]
    assert _global_registry.contains(ENV_CREATOR, env_name)
    make_env = _global_registry.get(ENV_CREATOR, env_name)
    probe_env = make_env(config["env_config"])

    policies = {}
    for agent_id in probe_env.agent_ids:
        policies[agent_id] = (
            None, probe_env.observation_space, probe_env.action_space, {})
    config["multiagent"]["policies"] = policies
    config["multiagent"]["policy_mapping_fn"] = lambda x: x

    # The double-critic model is used only when both flags are on.
    if config[DIVERSITY_ENCOURAGING] and config[USE_DIVERSITY_VALUE_NETWORK]:
        ModelCatalog.register_custom_model(
            "ActorDoubleCriticNetwork", ActorDoubleCriticNetwork)
        model_name = "ActorDoubleCriticNetwork"
        # The option key keeps the deprecated 'novelty' name.
        model_options = {
            "use_novelty_value_network": config[USE_DIVERSITY_VALUE_NETWORK]
        }
    else:
        model_name = None
        model_options = None
    config['model']['custom_model'] = model_name
    config['model']['custom_options'] = model_options

    # The train batch size is deliberately left untouched here (see the
    # 2019.01.12 note: the SampleBatch counting issue was corrected).

    if config[REPLAY_VALUES]:
        # With replayed values (vtrace), the SGD minibatch must be a
        # non-zero multiple of num_agents * sample_batch_size.
        assert config['sgd_minibatch_size'] % (
            config['env_config']['num_agents'] * config['sample_batch_size']
        ) == 0, (
            "sgd_minibatch_size: {}, num_agents: {}, sample_batch_size: {}"
            "".format(config['sgd_minibatch_size'],
                      config['env_config']['num_agents'],
                      config['sample_batch_size']))
        assert config['sgd_minibatch_size'] >= (
            config['env_config']['num_agents'] * config['sample_batch_size'])

    validate_config_original(config)

    # Diversity-only options require diversity encouragement.
    if not config[DIVERSITY_ENCOURAGING]:
        assert not config[USE_BISECTOR]
        assert not config[USE_DIVERSITY_VALUE_NETWORK]

    assert config[CONSTRAIN_NOVELTY] in ('soft', 'hard', None)
示例#7
0
def get_connector(ctx: ConnectorContext, name: str, params: Tuple[Any]) -> Connector:
    """Build a connector from its registered name and serialized config.

    Args:
        ctx: Connector context, passed through to ``from_config``.
        name: Name under which the connector class is registered.
        params: Serialized parameters of the connector.

    Returns:
        The connector returned by the registered class's ``from_config``.

    Raises:
        NameError: If no connector is registered under ``name``.
    """
    if _global_registry.contains(RLLIB_CONNECTOR, name):
        connector_cls = _global_registry.get(RLLIB_CONNECTOR, name)
        return connector_cls.from_config(ctx, params)
    raise NameError("connector not found.", name)
示例#8
0
    def _setup(self, config):
        """Resolve the env creator, merge the config, and build the trainer.

        Steps, in order: pick an env creator from ``self._env_id``, merge
        ``config`` over a deep copy of ``self._default_config``, validate the
        merged config, apply the log level, call ``self._init`` and, when
        "evaluation_interval" is set, create the evaluation workers and run
        one evaluation.

        Args:
            config: User-supplied config dict. When ``self._env_id`` is set,
                its "env" entry is overwritten with that id.
        """
        env = self._env_id
        if env:
            config["env"] = env
            if _global_registry.contains(ENV_CREATOR, env):
                self.env_creator = _global_registry.get(ENV_CREATOR, env)
            else:
                import gym  # soft dependency
                self.env_creator = lambda env_config: gym.make(env)
        else:
            self.env_creator = lambda env_config: None

        # Merge the supplied config with the class default
        merged_config = copy.deepcopy(self._default_config)
        merged_config = deep_update(merged_config, config,
                                    self._allow_unknown_configs,
                                    self._allow_unknown_subkeys)
        self.raw_user_config = config
        self.config = merged_config
        Trainer._validate_config(self.config)
        log_level = self.config.get("log_level")
        if log_level:
            logging.getLogger("ray.rllib").setLevel(log_level)

        def get_scope():
            """Return a TF graph scope, or a throwaway no-op context."""
            if tf:
                return tf.Graph().as_default()
            else:
                # Use the platform's null device instead of a hard-coded
                # "/dev/null", which does not exist on every OS.
                import os
                return open(os.devnull)  # fake a no-op scope

        with get_scope():
            self._init(self.config, self.env_creator)

            # Evaluation related
            if self.config.get("evaluation_interval"):
                # Update env_config with evaluation settings:
                extra_config = copy.deepcopy(self.config["evaluation_config"])
                extra_config.update({
                    "batch_mode": "complete_episodes",
                    "batch_steps": 1,
                })
                logger.debug(
                    "using evaluation_config: {}".format(extra_config))
                # Evaluation workers are created with num_workers=0.
                self.evaluation_workers = self._make_workers(
                    self.env_creator,
                    self._policy,
                    merge_dicts(self.config, extra_config),
                    num_workers=0)
                self.evaluation_metrics = self._evaluate()
示例#9
0
def get_termination_fn(env_id, env_config=None):
    """Return the termination function for an environment id and config.

    Only works for environments that have an entry in ``TERMINATIONS`` and
    are registered with Tune. When the config sets "time_aware", the result
    is wrapped in ``TimeAwareTerminationFn``.
    """
    assert env_id in TERMINATIONS, f"{env_id} environment termination not registered."
    is_registered = _global_registry.contains(ENV_CREATOR, env_id)
    assert is_registered, f"{env_id} environment not registered with Tune."

    # A missing or empty config falls back to an empty dict.
    settings = env_config if env_config else {}
    base_fn = TERMINATIONS[env_id](settings)
    if not settings.get("time_aware", False):
        return base_fn
    return TimeAwareTerminationFn(base_fn)
示例#10
0
    def setup(self, config):
        """Store the config, resolve the env creator, and build the workers.

        Args:
            config: Config dict; its "env" entry is overwritten with
                ``self._env_id``.

        Raises:
            Exception: If ``self._env_id`` is None.
        """
        self.config = config

        env = self._env_id
        if env is None:
            raise Exception('self._env_id should not be None.')

        config["env"] = env
        if not _global_registry.contains(ENV_CREATOR, env):
            # Unregistered ids are handed to gym.
            import gym
            self.env_creator = lambda env_config: gym.make(env)
        else:
            self.env_creator = _global_registry.get(ENV_CREATOR, env)

        self._policy = MuZeroTFPolicy

        # The global op is built from the workers created just before it.
        self.workers = self._build_workers(self._global_vars)
        self._global_op = self._build_global_op(self.workers)
示例#11
0
    def _setup(self, config):
        """Resolve the env creator, merge the config, and build the trainer.

        When "evaluation_interval" is set, a local evaluator is also created
        and one evaluation is run.

        Args:
            config: User-supplied config dict. When ``self._env_id`` is set,
                its "env" entry is overwritten with that id.
        """
        env = self._env_id
        if not env:
            # No env id: the creator ignores its argument and yields None.
            self.env_creator = lambda env_config: None
        else:
            config["env"] = env
            if not _global_registry.contains(ENV_CREATOR, env):
                # Unregistered ids are handed to gym.
                import gym  # soft dependency
                self.env_creator = lambda env_config: gym.make(env)
            else:
                self.env_creator = _global_registry.get(ENV_CREATOR, env)

        # Overlay the user config on a private deep copy of the class default.
        merged = deep_update(
            copy.deepcopy(self._default_config), config,
            self._allow_unknown_configs, self._allow_unknown_subkeys)
        self.raw_user_config = config
        self.config = merged
        Trainer._validate_config(self.config)

        log_level = self.config.get("log_level")
        if log_level:
            logging.getLogger("ray.rllib").setLevel(log_level)

        # TODO(ekl) setting the graph is unnecessary for PyTorch agents
        graph = tf.Graph()
        with graph.as_default():
            self._init(self.config, self.env_creator)

            if not self.config.get("evaluation_interval"):
                return

            # Evaluation always uses complete episodes with batch_steps=1,
            # layered on a copy of the user's evaluation_config.
            eval_overrides = copy.deepcopy(self.config["evaluation_config"])
            eval_overrides["batch_mode"] = "complete_episodes"
            eval_overrides["batch_steps"] = 1
            logger.debug(
                "using evaluation_config: {}".format(eval_overrides))

            # Make the local evaluator used for evaluation.
            self.evaluation_ev = self.make_local_evaluator(
                self.env_creator,
                self._policy_graph,
                extra_config=eval_overrides)
            self.evaluation_metrics = self._evaluate()
示例#12
0
文件: agent.py 项目: gavinljj/ray
    def _setup(self, config):
        """Resolve the env creator, merge the config, and initialize the agent.

        Args:
            config: User-supplied config dict. When ``self._env_id`` is set,
                its "env" entry is overwritten with that id.
        """
        import copy

        env = self._env_id
        if env:
            config["env"] = env
            if _global_registry.contains(ENV_CREATOR, env):
                self.env_creator = _global_registry.get(ENV_CREATOR, env)
            else:
                import gym  # soft dependency
                self.env_creator = lambda env_config: gym.make(env)
        else:
            self.env_creator = lambda env_config: None

        # Merge the supplied config with the class default. Deep-copy the
        # default first: a shallow ``.copy()`` shares nested dicts with
        # ``self._default_config``, so a merge that updates nested dicts in
        # place would leak this instance's settings into the shared default.
        # NOTE(review): assumes deep_update merges nested dicts in place --
        # confirm against its definition.
        merged_config = copy.deepcopy(self._default_config)
        merged_config = deep_update(merged_config, config,
                                    self._allow_unknown_configs,
                                    self._allow_unknown_subkeys)
        self.config = merged_config

        # TODO(ekl) setting the graph is unnecessary for PyTorch agents
        with tf.Graph().as_default():
            self._init()
示例#13
0
文件: trainer.py 项目: zqxyz73/ray
    def _setup(self, config):
        """Resolve the env creator, merge the config, and build the trainer.

        Steps, in order: pick an env creator from ``self._env_id``, merge
        ``config`` with ``self._default_config``, optionally wrap the creator
        so created envs are passed through ``NormalizeActionWrapper``,
        validate the merged config, apply the log level, call ``self._init``
        and, when "evaluation_interval" is set, create the evaluation workers.

        Args:
            config: User-supplied config dict. When ``self._env_id`` is set,
                its "env" entry is overwritten with that id.
        """
        env = self._env_id
        if env:
            config["env"] = env
            if _global_registry.contains(ENV_CREATOR, env):
                self.env_creator = _global_registry.get(ENV_CREATOR, env)
            else:
                import gym  # soft dependency
                self.env_creator = lambda env_config: gym.make(env)
        else:
            self.env_creator = lambda env_config: None

        # Merge the supplied config with the class default, but store the
        # user-provided one.
        self.raw_user_config = config
        self.config = Trainer.merge_trainer_configs(self._default_config,
                                                    config)

        if self.config["normalize_actions"]:
            # Capture the current creator so the wrapper lambda does not
            # recurse into itself once self.env_creator is rebound.
            inner = self.env_creator

            def normalize(env):
                """Wrap a gym.Env in NormalizeActionWrapper."""
                import gym  # soft dependency
                if not isinstance(env, gym.Env):
                    # Format the type into the message; it was previously
                    # passed as a second exception argument, leaving the
                    # "{}" placeholder unfilled.
                    raise ValueError(
                        "Cannot apply NormalizeActionActionWrapper to env of "
                        "type {}, which does not subclass gym.Env.".format(
                            type(env)))
                return NormalizeActionWrapper(env)

            self.env_creator = lambda env_config: normalize(inner(env_config))

        Trainer._validate_config(self.config)
        log_level = self.config.get("log_level")
        if log_level in ["WARN", "ERROR"]:
            logger.info("Current log_level is {}. For more information, "
                        "set 'log_level': 'INFO' / 'DEBUG' or use the -v and "
                        "-vv flags.".format(log_level))
        if log_level:
            logging.getLogger("ray.rllib").setLevel(log_level)

        def get_scope():
            """Return a TF graph scope, or a throwaway no-op context."""
            if tf and not tf.executing_eagerly():
                return tf.Graph().as_default()
            else:
                # Use the platform's null device instead of a hard-coded
                # "/dev/null", which does not exist on every OS.
                import os
                return open(os.devnull)  # fake a no-op scope

        with get_scope():
            self._init(self.config, self.env_creator)

            # Evaluation setup.
            if self.config.get("evaluation_interval"):
                # Update env_config with evaluation settings:
                extra_config = copy.deepcopy(self.config["evaluation_config"])
                # Assert that user has not unset "in_evaluation".
                assert "in_evaluation" not in extra_config or \
                    extra_config["in_evaluation"] is True
                extra_config.update({
                    "batch_mode": "complete_episodes",
                    "batch_steps": 1,
                    "in_evaluation": True,
                })
                logger.debug(
                    "using evaluation_config: {}".format(extra_config))

                self.evaluation_workers = self._make_workers(
                    self.env_creator,
                    self._policy,
                    merge_dicts(self.config, extra_config),
                    num_workers=self.config["evaluation_num_workers"])
                self.evaluation_metrics = {}
示例#14
0
    def _setup(self, config: dict):
        """Resolve the env creator, merge the config, and build the trainer.

        Steps, in order: pick an env creator from ``self._env_id``, merge
        ``config`` with ``self._default_config``, translate the deprecated
        "use_pytorch" / "eager" keys into "framework", enable TF eager
        execution when requested, optionally wrap the creator so created envs
        are passed through ``NormalizeActionWrapper``, validate the merged
        config, instantiate the callbacks, apply the log level, call
        ``self._init`` and, when "evaluation_interval" is set, create the
        evaluation workers.

        Args:
            config: User-supplied config dict. When ``self._env_id`` is set,
                its "env" entry is overwritten with that id.

        Raises:
            ValueError: If ``config["callbacks"]`` is not callable.
        """
        env = self._env_id
        if env:
            config["env"] = env
            # An already registered env.
            if _global_registry.contains(ENV_CREATOR, env):
                self.env_creator = _global_registry.get(ENV_CREATOR, env)
            # A class specifier.
            elif "." in env:
                self.env_creator = \
                    lambda env_config: from_config(env, env_config)
            # Try gym.
            else:
                import gym  # soft dependency
                self.env_creator = lambda env_config: gym.make(env)
        else:
            self.env_creator = lambda env_config: None

        # Merge the supplied config with the class default, but store the
        # user-provided one.
        self.raw_user_config = config
        self.config = Trainer.merge_trainer_configs(self._default_config,
                                                    config)

        # Check and resolve DL framework settings.
        if "use_pytorch" in self.config and \
                self.config["use_pytorch"] != DEPRECATED_VALUE:
            deprecation_warning("use_pytorch", "framework=torch", error=False)
            if self.config["use_pytorch"]:
                self.config["framework"] = "torch"
            self.config.pop("use_pytorch")
        if "eager" in self.config and self.config["eager"] != DEPRECATED_VALUE:
            deprecation_warning("eager", "framework=tfe", error=False)
            if self.config["eager"]:
                self.config["framework"] = "tfe"
            self.config.pop("eager")

        # Enable eager/tracing support.
        if tf and self.config["framework"] == "tfe":
            if not tf.executing_eagerly():
                tf.enable_eager_execution()
            logger.info("Executing eagerly, with eager_tracing={}".format(
                self.config["eager_tracing"]))
        if tf and not tf.executing_eagerly() and \
                self.config["framework"] != "torch":
            logger.info("Tip: set framework=tfe or the --eager flag to enable "
                        "TensorFlow eager execution")

        if self.config["normalize_actions"]:
            # Capture the current creator so the wrapper lambda does not
            # recurse into itself once self.env_creator is rebound.
            inner = self.env_creator

            def normalize(env):
                """Wrap a gym.Env in NormalizeActionWrapper."""
                import gym  # soft dependency
                if not isinstance(env, gym.Env):
                    # Format the type into the message; it was previously
                    # passed as a second exception argument, leaving the
                    # "{}" placeholder unfilled.
                    raise ValueError(
                        "Cannot apply NormalizeActionActionWrapper to env of "
                        "type {}, which does not subclass gym.Env.".format(
                            type(env)))
                return NormalizeActionWrapper(env)

            self.env_creator = lambda env_config: normalize(inner(env_config))

        Trainer._validate_config(self.config)
        if not callable(self.config["callbacks"]):
            raise ValueError(
                "`callbacks` must be a callable method that "
                "returns a subclass of DefaultCallbacks, got {}".format(
                    self.config["callbacks"]))
        self.callbacks = self.config["callbacks"]()
        log_level = self.config.get("log_level")
        if log_level in ["WARN", "ERROR"]:
            logger.info("Current log_level is {}. For more information, "
                        "set 'log_level': 'INFO' / 'DEBUG' or use the -v and "
                        "-vv flags.".format(log_level))
        if log_level:
            logging.getLogger("ray.rllib").setLevel(log_level)

        def get_scope():
            """Return a TF graph scope, or a throwaway no-op context."""
            if tf and not tf.executing_eagerly():
                return tf.Graph().as_default()
            else:
                return open(os.devnull)  # fake a no-op scope

        with get_scope():
            self._init(self.config, self.env_creator)

            # Evaluation setup.
            if self.config.get("evaluation_interval"):
                # Update env_config with evaluation settings:
                extra_config = copy.deepcopy(self.config["evaluation_config"])
                # Assert that user has not unset "in_evaluation".
                assert "in_evaluation" not in extra_config or \
                    extra_config["in_evaluation"] is True
                extra_config.update({
                    "batch_mode": "complete_episodes",
                    "rollout_fragment_length": 1,
                    "in_evaluation": True,
                })
                logger.debug(
                    "using evaluation_config: {}".format(extra_config))

                self.evaluation_workers = self._make_workers(
                    self.env_creator,
                    self._policy,
                    merge_dicts(self.config, extra_config),
                    num_workers=self.config["evaluation_num_workers"])
                self.evaluation_metrics = {}
示例#15
0
def get_env_creator(env_id):
    """Return the environment creator function registered for ``env_id``.

    Raises:
        ValueError: If ``env_id`` is not registered in Tune.
    """
    if _global_registry.contains(ENV_CREATOR, env_id):
        return _global_registry.get(ENV_CREATOR, env_id)
    raise ValueError(f"Environment id {env_id} not registered in Tune")
示例#16
0
def has_env_creator(env_id: str) -> bool:
    """Return whether an environment with the given id is in the global registry."""
    is_registered = _global_registry.contains(ENV_CREATOR, env_id)
    return is_registered