Ejemplo n.º 1
0
    def test_polynomial_schedule(self):
        """Compare PolynomialSchedule outputs with the closed-form polynomial.

        The schedule is built for the "tf", "torch" and None frameworks and
        then once more for "tf" inside `eager_mode()`. Each output must match
        `0.5 + (2.0 - 0.5) * (1 - t / 100) ** 2` to 4 decimals.
        """
        timesteps = [0, 5, 10, 100, 90, 2, 1, 99, 23]
        config = {
            "type": "ray.rllib.utils.schedules.polynomial_schedule."
                    "PolynomialSchedule",
            "schedule_timesteps": 100,
            "initial_p": 2.0,
            "final_p": 0.5,
            "power": 2.0,
        }

        def verify(schedule):
            # Query the schedule at every timestep and compare against the
            # analytically computed value.
            for step in timesteps:
                out = schedule(step)
                check(
                    out,
                    0.5 + (2.0 - 0.5) * (1.0 - step / 100)**2,
                    decimals=4)

        for framework in ("tf", "torch", None):
            # The framework is passed to `from_config` through the config.
            config["framework"] = framework
            verify(from_config(config))

        # Repeat the tf case with eager mode active.
        with eager_mode():
            config["framework"] = "tf"
            verify(from_config(config))
Ejemplo n.º 2
0
 def test_linear_schedule(self):
     """Compare LinearSchedule outputs with a hand-computed linear decay.

     For each of the "tf", "torch" and None frameworks the schedule output
     must match `2.1 - (t / 100) * (2.1 - 0.6)` to 4 decimals.
     """
     timesteps = [0, 50, 10, 100, 90, 2, 1, 99, 23]
     for framework in ("tf", "torch", None):
         schedule_config = {
             "schedule_timesteps": 100,
             "initial_p": 2.1,
             "final_p": 0.6,
             "framework": framework,
         }
         linear = from_config(LinearSchedule, schedule_config)
         # Eager execution is switched on only after the tf schedule has
         # been constructed.
         if framework == "tf":
             tf.enable_eager_execution()
         for step in timesteps:
             out = linear(step)
             check(
                 out,
                 2.1 - (step / 100) * (2.1 - 0.6),
                 decimals=4)
Ejemplo n.º 3
0
    def __init__(
        self,
        action_space: gym.spaces.Space,
        *,
        framework: str,
        initial_epsilon: float = 1.0,
        final_epsilon: float = 0.05,
        epsilon_timesteps: int = int(1e5),
        epsilon_schedule: Optional[Schedule] = None,
        **kwargs,
    ):
        """Set up an EpsilonGreedy exploration instance.

        Args:
            action_space: The action space the exploration should occur in.
            framework: The framework specifier. Must not be None.
            initial_epsilon: Epsilon at timestep 0 of the default schedule.
            final_epsilon: Epsilon at `epsilon_timesteps`; also used as the
                default schedule's `outside_value`.
            epsilon_timesteps: The time step after which epsilon should
                always be `final_epsilon`.
            epsilon_schedule: Optional Schedule spec, resolved through
                `from_config`. If the result is falsy, a PiecewiseSchedule
                is built from the three epsilon parameters instead.
            **kwargs: Forwarded to the parent constructor.
        """
        assert framework is not None
        super().__init__(
            action_space=action_space, framework=framework, **kwargs)

        # Prefer the user-provided schedule; fall back to a piecewise
        # schedule from `initial_epsilon` to `final_epsilon`.
        schedule = from_config(
            Schedule, epsilon_schedule, framework=framework)
        if not schedule:
            schedule = PiecewiseSchedule(
                endpoints=[
                    (0, initial_epsilon),
                    (epsilon_timesteps, final_epsilon),
                ],
                outside_value=final_epsilon,
                framework=self.framework,
            )
        self.epsilon_schedule = schedule

        # The current timestep value (tf-var or python int).
        initial_timestep = np.array(0, np.int64)
        self.last_timestep = get_variable(
            initial_timestep,
            framework=framework,
            tf_name="timestep",
            dtype=np.int64,
        )

        # For the "tf" framework, the state op is created once up front.
        if self.framework == "tf":
            self._tf_state_op = self.get_state()
Ejemplo n.º 4
0
    def test_polynomial_schedule(self):
        """Compare PolynomialSchedule outputs with the closed-form polynomial.

        Timesteps beyond 100 are clipped to 100 when computing the expected
        value (the list includes t=1000 to cover that case).
        """
        timesteps = [0, 5, 10, 100, 90, 2, 1, 99, 23, 1000]
        config = {
            "type": "ray.rllib.utils.schedules.polynomial_schedule."
                    "PolynomialSchedule",
            "schedule_timesteps": 100,
            "initial_p": 2.0,
            "final_p": 0.5,
            "power": 2.0,
        }

        for fw in framework_iterator(frameworks=["tf", "tfe", "torch", None]):
            # The schedule itself is built with "tf" when iterating "tfe".
            schedule_fw = "tf" if fw == "tfe" else fw
            polynomial = from_config(config, framework=schedule_fw)
            for step in timesteps:
                out = polynomial(step)
                clipped = min(step, 100)
                check(
                    out,
                    0.5 + (2.0 - 0.5) * (1.0 - clipped / 100)**2,
                    decimals=4)
Ejemplo n.º 5
0
 def test_polynomial_schedule(self):
     """Compare PolynomialSchedule outputs with the closed-form polynomial.

     For each of the "tf", "torch" and None frameworks the schedule output
     must match `0.5 + (2.0 - 0.5) * (1 - t / 100) ** 2` to 4 decimals.
     """
     timesteps = [0, 5, 10, 100, 90, 2, 1, 99, 23]
     for framework in ("tf", "torch", None):
         schedule_config = {
             "type": "ray.rllib.utils.schedules.polynomial_schedule."
                     "PolynomialSchedule",
             "schedule_timesteps": 100,
             "initial_p": 2.0,
             "final_p": 0.5,
             "power": 2.0,
             "framework": framework,
         }
         polynomial = from_config(schedule_config)
         # Eager execution is switched on only after the tf schedule has
         # been constructed.
         if framework == "tf":
             tf.enable_eager_execution()
         for step in timesteps:
             out = polynomial(step)
             check(
                 out,
                 0.5 + (2.0 - 0.5) * (1.0 - step / 100)**2,
                 decimals=4)
Ejemplo n.º 6
0
    def test_linear_schedule(self):
        """Compare LinearSchedule outputs with a hand-computed linear decay.

        Every framework is queried twice: with plain python ints and with
        the values returned by `self._get_framework_tensors`. Timesteps
        beyond 100 are clipped to 100 for the expected values.
        """
        timesteps = [0, 50, 10, 100, 90, 2, 1, 99, 23, 1000]
        expected = []
        for step in timesteps:
            expected.append(2.1 - (min(step, 100) / 100) * (2.1 - 0.6))
        config = dict(schedule_timesteps=100, initial_p=2.1, final_p=0.6)

        for fw in framework_iterator(
                frameworks=["tf2", "tf", "tfe", "torch", None]):
            linear = from_config(LinearSchedule, config, framework=fw)

            # Python ints as inputs.
            for step, target in zip(timesteps, expected):
                check(linear(step), target, decimals=4)

            # Framework-specific inputs.
            tensor_steps = self._get_framework_tensors(timesteps, fw)
            for tensor_step, target in zip(tensor_steps, expected):
                out = linear(tensor_step)
                if fw == "tf":
                    assert isinstance(out, tf.Tensor)
                check(out, target, decimals=4)
Ejemplo n.º 7
0
    def test_constant_schedule(self):
        """Check that ConstantSchedule returns its value for every timestep.

        Every framework is queried twice: with plain python ints and with
        the values returned by `self._get_framework_tensors`.
        """
        value = 2.3
        timesteps = [100, 0, 10, 2, 3, 4, 99, 56, 10000, 23, 234, 56]
        config = dict(value=value)

        for fw in framework_iterator(
                frameworks=["tf2", "tf", "tfe", "torch", None]):
            constant = from_config(ConstantSchedule, config, framework=fw)

            # Python ints as inputs.
            for step in timesteps:
                check(constant(step), value)

            # Framework-specific inputs.
            for tensor_step in self._get_framework_tensors(timesteps, fw):
                out = constant(tensor_step)
                if fw == "tf":
                    assert isinstance(out, tf.Tensor)
                check(out, value, decimals=4)
Ejemplo n.º 8
0
    def test_piecewise_schedule(self):
        """Compare PiecewiseSchedule outputs with hand-computed values.

        The schedule has endpoints (0, 50.0), (25, 100.0), (30, 200.0) and
        an outside value of 14.5. Every framework is queried twice: with
        plain python ints and with the values returned by
        `self._get_framework_tensors`.
        """
        timesteps = [0, 5, 10, 100, 90, 2, 1, 99, 27]
        expected = [50.0, 60.0, 70.0, 14.5, 14.5, 54.0, 52.0, 14.5, 140.0]
        config = {
            "endpoints": [(0, 50.0), (25, 100.0), (30, 200.0)],
            "outside_value": 14.5,
        }

        for fw in framework_iterator(
                frameworks=["tf2", "tf", "tfe", "torch", None]):
            piecewise = from_config(PiecewiseSchedule, config, framework=fw)

            # Python ints as inputs.
            for step, target in zip(timesteps, expected):
                check(piecewise(step), target, decimals=4)

            # Framework-specific inputs.
            tensor_steps = self._get_framework_tensors(timesteps, fw)
            for tensor_step, target in zip(tensor_steps, expected):
                out = piecewise(tensor_step)
                if fw == "tf":
                    assert isinstance(out, tf.Tensor)
                check(out, target, decimals=4)
Ejemplo n.º 9
0
    def _create_exploration(self, action_space, config):
        """Build the Policy's Exploration object and store it in `config`.

        This method only exists b/c some Trainers do not use TfPolicy nor
        TorchPolicy, but inherit directly from Policy. Others inherit from
        TfPolicy w/o using DynamicTfPolicy.
        TODO(sven): unify these cases.

        Args:
            action_space: Action space handed to the Exploration.
            config: Config dict. Its "exploration_config", "num_workers" and
                "worker_index" entries are read (with defaults), and
                "exploration_config" is overwritten with the created object.

        Returns:
            The object produced by `from_config`.
        """
        # Fall back to StochasticSampling when no exploration is configured.
        exploration_spec = config.get(
            "exploration_config", {"type": "StochasticSampling"})
        exploration = from_config(
            Exploration,
            exploration_spec,
            action_space=action_space,
            num_workers=config.get("num_workers", 0),
            worker_index=config.get("worker_index", 0),
            framework=getattr(self, "framework", "tf"),
        )
        # If config is further passed around, it'll contain an already
        # instantiated object.
        config["exploration_config"] = exploration
        return exploration
Ejemplo n.º 10
0
    def __init__(
        self,
        action_space,
        *,
        framework,
        initial_temperature=1.0,
        final_temperature=1e-6,
        temperature_timesteps=int(1e5),
        temperature_schedule=None,
        **kwargs,
    ):
        """Set up a SoftQ Exploration object.

        Args:
            action_space (Space): The action space; must be a `Discrete`.
            framework (str): The framework specifier, forwarded to the
                parent constructor, the schedule and the timestep variable.
            initial_temperature (float): Temperature at timestep 0 of the
                default schedule.
            final_temperature (float): Temperature at
                `temperature_timesteps`; also used as the default schedule's
                `outside_value`.
            temperature_timesteps (int): Timestep of the default schedule's
                last endpoint.
            temperature_schedule (Optional[Schedule]): Optional Schedule
                spec, resolved through `from_config`. If the result is
                falsy, a PiecewiseSchedule is built from the three
                temperature parameters instead.
            **kwargs: Forwarded to the parent constructor.
        """
        assert isinstance(action_space, Discrete)
        super().__init__(action_space, framework=framework, **kwargs)

        # Prefer the user-provided schedule; fall back to a piecewise
        # schedule from `initial_temperature` to `final_temperature`.
        schedule = from_config(
            Schedule, temperature_schedule, framework=framework)
        if not schedule:
            default_endpoints = [
                (0, initial_temperature),
                (temperature_timesteps, final_temperature),
            ]
            schedule = PiecewiseSchedule(
                endpoints=default_endpoints,
                outside_value=final_temperature,
                framework=self.framework,
            )
        self.temperature_schedule = schedule

        # The current timestep value (tf-var or python int).
        self.last_timestep = get_variable(
            0, framework=framework, tf_name="timestep")
        # Temperature of the schedule at the initial timestep.
        self.temperature = self.temperature_schedule(self.last_timestep)
Ejemplo n.º 11
0
    def __init__(self,
                 action_space,
                 *,
                 framework: str,
                 initial_epsilon=1.0,
                 final_epsilon=0.05,
                 epsilon_timesteps=int(1e5),
                 epsilon_schedule=None,
                 **kwargs):
        """Set up an EpsilonGreedy exploration instance.

        Args:
            action_space: Action space forwarded to the parent constructor.
            framework (str): The framework specifier. Must not be None;
                "tf" is rejected with a ValueError.
            initial_epsilon (float): The initial epsilon value to use.
            final_epsilon (float): The final epsilon value to use.
            epsilon_timesteps (int): The time step after which epsilon should
                always be `final_epsilon`.
            epsilon_schedule (Optional[Schedule]): Optional Schedule spec,
                resolved through `from_config`. If the result is falsy, a
                PiecewiseSchedule is built from the epsilon parameters.
            **kwargs: Forwarded to the parent constructor.

        Raises:
            ValueError: If `self.framework` is "tf" after parent init.
        """
        assert framework is not None
        super().__init__(
            action_space=action_space, framework=framework, **kwargs)

        # Prefer the user-provided schedule; fall back to a piecewise
        # schedule from `initial_epsilon` to `final_epsilon`.
        schedule = from_config(
            Schedule, epsilon_schedule, framework=framework)
        if not schedule:
            schedule = PiecewiseSchedule(
                endpoints=[
                    (0, initial_epsilon),
                    (epsilon_timesteps, final_epsilon),
                ],
                outside_value=final_epsilon,
                framework=self.framework)
        self.epsilon_schedule = schedule

        # The current timestep value (tf-var or python int).
        self.last_timestep = get_variable(
            0, framework=framework, tf_name="timestep")

        # The "tf" framework is refused here.
        # NOTE(review): the message speaks of the "Torch version" although
        # the check is for "tf" - confirm the intended wording.
        if self.framework == "tf":
            raise ValueError("Torch version does not support "
                             "multiobj episilon-greedy yet!")
Ejemplo n.º 12
0
    def _create_exploration(self):
        """Return the Policy's Exploration object, building it if needed.

        This method only exists b/c some Trainers do not use TfPolicy nor
        TorchPolicy, but inherit directly from Policy. Others inherit from
        TfPolicy w/o using DynamicTfPolicy.
        TODO(sven): unify these cases.

        Returns:
            `self.exploration` if that attribute exists and is not None,
            otherwise a new object produced by `from_config` (which is NOT
            stored on `self` by this method).
        """
        if getattr(self, "exploration", None) is None:
            # Fall back to StochasticSampling when nothing is configured.
            exploration_spec = self.config.get(
                "exploration_config", {"type": "StochasticSampling"})
            return from_config(
                Exploration,
                exploration_spec,
                action_space=self.action_space,
                policy_config=self.config,
                model=getattr(self, "model", None),
                num_workers=self.config.get("num_workers", 0),
                worker_index=self.config.get("worker_index", 0),
                framework=getattr(self, "framework", "tf"),
            )

        # An Exploration object already exists; reuse it.
        return self.exploration
Ejemplo n.º 13
0
    def test_exponential_schedule(self):
        """Compare ExponentialSchedule outputs with the closed-form decay.

        Expected values are `2.0 * decay_rate ** (t / 100)`. Every framework
        is queried twice: with plain python ints and with the values
        returned by `self._get_framework_tensors`.
        """
        decay_rate = 0.2
        timesteps = [0, 5, 10, 100, 90, 2, 1, 99, 23]
        expected = []
        for step in timesteps:
            expected.append(2.0 * decay_rate**(step / 100))
        config = {
            "initial_p": 2.0,
            "decay_rate": decay_rate,
            "schedule_timesteps": 100,
        }

        for fw in framework_iterator(
                frameworks=["tf2", "tf", "tfe", "torch", None]):
            exponential = from_config(
                ExponentialSchedule, config, framework=fw)

            # Python ints as inputs.
            for step, target in zip(timesteps, expected):
                check(exponential(step), target, decimals=4)

            # Framework-specific inputs.
            tensor_steps = self._get_framework_tensors(timesteps, fw)
            for tensor_step, target in zip(tensor_steps, expected):
                out = exponential(tensor_step)
                if fw == "tf":
                    assert isinstance(out, tf.Tensor)
                check(out, target, decimals=4)
Ejemplo n.º 14
0
    def test_dummy_components(self):
        """Exercise `from_config` with dict, file and string specifications."""
        # Switch on eager for testing purposes.
        tf.enable_eager_execution()

        # Try to create from an abstract class w/o default constructor.
        # Expect None.
        abstract_spec = dict(type=AbstractDummyComponent, framework="torch")
        check(from_config(abstract_spec), None)

        # Create a Component via python API (config dict).
        from_dict = from_config({
            "type": DummyComponent,
            "prop_a": 1.0,
            "prop_d": "non_default",
        })
        check(from_dict.prop_d, "non_default")

        # Create a tf Component from json file.
        from_json_file = from_config("dummy_config.json")
        check(from_json_file.prop_c, "default")
        check(from_json_file.prop_d, 4)  # default
        check(from_json_file.add(3.3).numpy(), 5.3)  # prop_b == 2.0

        # Create a torch Component from yaml file.
        from_yaml_file = from_config("dummy_config.yml")
        check(from_yaml_file.prop_a, "something else")
        check(from_yaml_file.prop_d, 3)
        check(from_yaml_file.add(1.2), torch.Tensor([2.2]))  # prop_b == 1.0

        # Create tf Component from json-string (e.g. on command line).
        json_spec = (
            '{"type": "ray.rllib.utils.tests.'
            'test_framework_agnostic_components.DummyComponent", '
            '"prop_a": "A", "prop_b": -1.0, "prop_c": "non-default"}')
        from_json_str = from_config(json_spec)
        check(from_json_str.prop_a, "A")
        check(from_json_str.prop_d, 4)  # default
        check(from_json_str.add(-1.1).numpy(), -2.1)  # prop_b == -1.0

        # Create torch Component from yaml-string.
        yaml_spec = (
            "type: ray.rllib.utils.tests."
            "test_framework_agnostic_components.DummyComponent\n"
            "prop_a: B\nprop_b: -1.5\nprop_c: non-default\nframework: torch")
        from_yaml_str = from_config(yaml_spec)
        check(from_yaml_str.prop_a, "B")
        check(from_yaml_str.prop_d, 4)  # default
        check(from_yaml_str.add(-5.1), torch.Tensor([-6.6]))  # prop_b == -1.5
Ejemplo n.º 15
0
    def test_polynomial_schedule(self):
        """Compare PolynomialSchedule outputs with the closed-form polynomial.

        Timesteps beyond 100 are clipped to 100 for the expected values.
        Every framework is queried twice: with plain python ints and with
        the values returned by `self._get_framework_tensors`.
        """
        timesteps = [0, 5, 10, 100, 90, 2, 1, 99, 23, 1000]
        expected = []
        for step in timesteps:
            expected.append(
                0.5 + (2.0 - 0.5) * (1.0 - min(step, 100) / 100)**2)
        config = {
            "type": "ray.rllib.utils.schedules.polynomial_schedule."
                    "PolynomialSchedule",
            "schedule_timesteps": 100,
            "initial_p": 2.0,
            "final_p": 0.5,
            "power": 2.0,
        }

        for fw in framework_iterator(
                frameworks=["tf2", "tf", "tfe", "torch", None]):
            polynomial = from_config(config, framework=fw)

            # Python ints as inputs.
            for step, target in zip(timesteps, expected):
                check(polynomial(step), target, decimals=4)

            # Framework-specific inputs.
            tensor_steps = self._get_framework_tensors(timesteps, fw)
            for tensor_step, target in zip(tensor_steps, expected):
                out = polynomial(tensor_step)
                if fw == "tf":
                    assert isinstance(out, tf.Tensor)
                check(out, target, decimals=4)
Ejemplo n.º 16
0
    def test_dummy_components(self):
        """Exercise `from_config` with dict, file and string specifications.

        Also covers type names that are resolved relative to the class
        passed as first argument (`DummyComponent`, `Exploration`).
        """
        # Switch on eager for testing purposes.
        tf.enable_eager_execution()

        # Bazel makes it hard to find files specified in `args` (and `data`).
        # Use the true absolute path.
        test_dir = Path(__file__).parent.absolute()

        # Try to create from an abstract class w/o default constructor.
        # Expect None.
        abstract_spec = dict(type=AbstractDummyComponent, framework="torch")
        check(from_config(abstract_spec), None)

        # Create a Component via python API (config dict).
        from_dict = from_config({
            "type": DummyComponent,
            "prop_a": 1.0,
            "prop_d": "non_default",
        })
        check(from_dict.prop_d, "non_default")

        # Create a tf Component from json file.
        json_path = str(test_dir.joinpath("dummy_config.json"))
        from_json_file = from_config(json_path)
        check(from_json_file.prop_c, "default")
        check(from_json_file.prop_d, 4)  # default
        check(from_json_file.add(3.3).numpy(), 5.3)  # prop_b == 2.0

        # Create a torch Component from yaml file.
        yaml_path = str(test_dir.joinpath("dummy_config.yml"))
        from_yaml_file = from_config(yaml_path)
        check(from_yaml_file.prop_a, "something else")
        check(from_yaml_file.prop_d, 3)
        check(from_yaml_file.add(1.2), np.array([2.2]))  # prop_b == 1.0

        # Create tf Component from json-string (e.g. on command line).
        json_spec = (
            '{"type": "ray.rllib.utils.tests.'
            'test_framework_agnostic_components.DummyComponent", '
            '"prop_a": "A", "prop_b": -1.0, "prop_c": "non-default"}')
        from_json_str = from_config(json_spec)
        check(from_json_str.prop_a, "A")
        check(from_json_str.prop_d, 4)  # default
        check(from_json_str.add(-1.1).numpy(), -2.1)  # prop_b == -1.0

        # Test recognizing default module path.
        child_spec = (
            '{"type": "NonAbstractChildOfDummyComponent", '
            '"prop_a": "A", "prop_b": -1.0, "prop_c": "non-default"}')
        child = from_config(DummyComponent, child_spec)
        check(child.prop_a, "A")
        check(child.prop_d, 4)  # default
        check(child.add(-1.1).numpy(), -2.1)  # prop_b == -1.0

        # Test recognizing default package path.
        exploration_spec = dict(type="EpsilonGreedy", action_space=Discrete(2))
        exploration = from_config(Exploration, exploration_spec)
        check(exploration.epsilon_schedule.outside_value, 0.05)  # default

        # Create torch Component from yaml-string.
        yaml_spec = (
            "type: ray.rllib.utils.tests."
            "test_framework_agnostic_components.DummyComponent\n"
            "prop_a: B\nprop_b: -1.5\nprop_c: non-default\nframework: torch")
        from_yaml_str = from_config(yaml_spec)
        check(from_yaml_str.prop_a, "B")
        check(from_yaml_str.prop_d, 4)  # default
        check(from_yaml_str.add(-5.1), np.array([-6.6]))  # prop_b == -1.5
Ejemplo n.º 17
0
    def __init__(self,
                 action_space: Space,
                 *,
                 framework: str,
                 model: ModelV2,
                 feature_dim: int = 288,
                 feature_net_config: Optional[ModelConfigDict] = None,
                 inverse_net_hiddens: Tuple[int] = (256, ),
                 inverse_net_activation: str = "relu",
                 forward_net_hiddens: Tuple[int] = (256, ),
                 forward_net_activation: str = "relu",
                 beta: float = 0.2,
                 eta: float = 1.0,
                 lr: float = 1e-3,
                 sub_exploration: Optional[FromConfigSpec] = None,
                 **kwargs):
        """Set up a Curiosity object.

        Uses as defaults the hyperparameters described in [1].

        Args:
             action_space (Space): The action space; only Discrete and
                MultiDiscrete are accepted.
             framework (str): The framework specifier, forwarded to the
                parent constructor.
             model (ModelV2): The model, forwarded to the parent constructor.
             feature_dim (int): The dimensionality of the feature (phi)
                vectors.
             feature_net_config (Optional[ModelConfigDict]): Optional model
                configuration for the feature network, producing feature
                vectors (phi) from observations. This can be used to configure
                fcnet- or conv_net setups to properly process any observation
                space. If None, a copy of `self.policy_config["model"]` is
                used.
             inverse_net_hiddens (Tuple[int]): Tuple of the layer sizes of the
                inverse (action predicting) NN head (on top of the feature
                outputs for phi and phi').
             inverse_net_activation (str): Activation specifier for the inverse
                net.
             forward_net_hiddens (Tuple[int]): Tuple of the layer sizes of the
                forward (phi' predicting) NN head.
             forward_net_activation (str): Activation specifier for the forward
                net.
             beta (float): Weight for the forward loss (over the inverse loss,
                which gets weight=1.0-beta) in the common loss term.
             eta (float): Weight for intrinsic rewards before being added to
                extrinsic ones.
             lr (float): The learning rate for the curiosity-specific
                optimizer, optimizing feature-, inverse-, and forward nets.
             sub_exploration (Optional[FromConfigSpec]): The config dict for
                the underlying Exploration to use (e.g. epsilon-greedy for
                DQN). None is not supported yet and raises
                NotImplementedError.
             **kwargs: Forwarded to the parent constructor.

        Raises:
            ValueError: If `action_space` is neither Discrete nor
                MultiDiscrete, or if the policy config's `num_workers` is
                not 0.
            NotImplementedError: If `sub_exploration` is None.
        """
        supported_spaces = (Discrete, MultiDiscrete)
        if not isinstance(action_space, supported_spaces):
            raise ValueError(
                "Only (Multi)Discrete action spaces supported for Curiosity "
                "so far!")

        super().__init__(
            action_space, model=model, framework=framework, **kwargs)

        if self.policy_config["num_workers"] != 0:
            raise ValueError(
                "Curiosity exploration currently does not support parallelism."
                " `num_workers` must be 0!")

        self.feature_dim = feature_dim
        # Default the feature net's config to a copy of the policy's model
        # config.
        self.feature_net_config = (
            self.policy_config["model"].copy()
            if feature_net_config is None else feature_net_config)
        self.inverse_net_hiddens = inverse_net_hiddens
        self.inverse_net_activation = inverse_net_activation
        self.forward_net_hiddens = forward_net_hiddens
        self.forward_net_activation = forward_net_activation

        # Discrete: number of actions; MultiDiscrete: sum over `nvec`.
        if isinstance(self.action_space, Discrete):
            self.action_dim = self.action_space.n
        else:
            self.action_dim = np.sum(self.action_space.nvec)

        self.beta = beta
        self.eta = eta
        self.lr = lr
        # TODO: (sven) if sub_exploration is None, use Trainer's default
        #  Exploration config.
        if sub_exploration is None:
            raise NotImplementedError
        self.sub_exploration = sub_exploration

        # Creates modules/layers inside the actual ModelV2.
        self._curiosity_feature_net = ModelCatalog.get_model_v2(
            self.model.obs_space,
            self.action_space,
            self.feature_dim,
            model_config=self.feature_net_config,
            framework=self.framework,
            name="feature_net",
        )

        # Inverse net layer sizes: 2 * feature_dim -> hiddens -> action_dim.
        inverse_layer_dims = [
            2 * self.feature_dim,
            *self.inverse_net_hiddens,
            self.action_dim,
        ]
        self._curiosity_inverse_fcnet = self._create_fc_net(
            inverse_layer_dims,
            self.inverse_net_activation,
            name="inverse_net")

        # Forward net layer sizes:
        # feature_dim + action_dim -> hiddens -> feature_dim.
        forward_layer_dims = [
            self.feature_dim + self.action_dim,
            *self.forward_net_hiddens,
            self.feature_dim,
        ]
        self._curiosity_forward_fcnet = self._create_fc_net(
            forward_layer_dims,
            self.forward_net_activation,
            name="forward_net")

        # This is only used to select the correct action
        self.exploration_submodule = from_config(
            cls=Exploration,
            config=self.sub_exploration,
            action_space=self.action_space,
            framework=self.framework,
            policy_config=self.policy_config,
            model=self.model,
            num_workers=self.num_workers,
            worker_index=self.worker_index,
        )
Ejemplo n.º 18
0
def check_multi_agent(
    config: PartialTrainerConfigDict,
) -> Tuple[MultiAgentPolicyConfigDict, bool]:
    """Checks, whether a (partial) config defines a multi-agent setup.

    Besides validating, this normalizes the "multiagent" sub-config in place:
    a missing/empty or set/list/tuple "policies" entry is replaced by a dict
    of PolicySpecs, list/tuple policy specs inside a given policies dict are
    converted to PolicySpec objects, None policy configs become `{}`, and a
    dict-valued "policy_mapping_fn" is instantiated via `from_config`.

    Args:
        config: The user/Trainer/Policy config to check for multi-agent.

    Returns:
        Tuple consisting of the resulting (all fixed) multi-agent policy
        dict and bool indicating whether we have a multi-agent setup or not.

    Raises:
        KeyError: If `config` does not contain a "multiagent" key or if there
            is an invalid key inside the "multiagent" config or if any policy
            in the "policies" dict has a non-str ID (key).
        ValueError: If any subkey of the "multiagent" dict has an invalid
            value.
    """
    if "multiagent" not in config:
        raise KeyError(
            "Your `config` to be checked for a multi-agent setup must have "
            "the 'multiagent' key defined!"
        )
    multiagent_config = config["multiagent"]

    policies = multiagent_config.get("policies")

    # Check for invalid sub-keys of multiagent config.
    # NOTE(review): imported inside the function, presumably to avoid a
    # circular import - confirm before moving it to module level.
    from ray.rllib.agents.trainer import COMMON_CONFIG

    allowed = list(COMMON_CONFIG["multiagent"].keys())
    if any(k not in allowed for k in multiagent_config.keys()):
        raise KeyError(
            f"You have invalid keys in your 'multiagent' config dict! "
            f"The only allowed keys are: {allowed}."
        )

    # Nothing specified in config dict -> Assume simple single agent setup
    # with DEFAULT_POLICY_ID as only policy.
    if not policies:
        policies = {DEFAULT_POLICY_ID}
    # Policies given as set/list/tuple (of PolicyIDs) -> Setup each policy
    # automatically via empty PolicySpec (will make RLlib infer obs- and
    # action spaces as well as the Policy's class).
    if isinstance(policies, (set, list, tuple)):
        policies = multiagent_config["policies"] = {
            pid: PolicySpec() for pid in policies
        }

    # Check each defined policy ID and spec. Iterate over a copy, because
    # entries of `policies` are replaced inside the loop.
    for pid, policy_spec in policies.copy().items():
        # Policy IDs must be strings.
        if not isinstance(pid, str):
            raise KeyError(f"Policy IDs must always be of type `str`, got {type(pid)}")
        # Convert to PolicySpec if plain list/tuple.
        if not isinstance(policy_spec, PolicySpec):
            # Values must be lists/tuples of len 4.
            if not isinstance(policy_spec, (list, tuple)) or len(policy_spec) != 4:
                raise ValueError(
                    "Policy specs must be tuples/lists of "
                    "(cls or None, obs_space, action_space, config), "
                    f"got {policy_spec}"
                )
            policies[pid] = PolicySpec(*policy_spec)

        # Config is None -> Set to {}.
        if policies[pid].config is None:
            policies[pid] = policies[pid]._replace(config={})
        # Config not a dict.
        elif not isinstance(policies[pid].config, dict):
            raise ValueError(
                f"Multiagent policy config for {pid} must be a dict, "
                f"but got {type(policies[pid].config)}!"
            )

    # Check other "multiagent" sub-keys' values.
    if multiagent_config.get("count_steps_by", "env_steps") not in [
        "env_steps",
        "agent_steps",
    ]:
        raise ValueError(
            "config.multiagent.count_steps_by must be one of "
            "[env_steps|agent_steps], not "
            f"{multiagent_config['count_steps_by']}!"
        )
    if multiagent_config.get("replay_mode", "independent") not in [
        "independent",
        "lockstep",
    ]:
        raise ValueError(
            "`config.multiagent.replay_mode` must be "
            "[independent|lockstep], not "
            f"{multiagent_config['replay_mode']}!"
        )
    # Attempt to create a `policy_mapping_fn` from config dict. Helpful
    # if users would like to specify custom callable classes in yaml files.
    if isinstance(multiagent_config.get("policy_mapping_fn"), dict):
        multiagent_config["policy_mapping_fn"] = from_config(
            multiagent_config["policy_mapping_fn"]
        )
    # Check `policies_to_train` for invalid entries. Use `.get()` like for
    # the other sub-keys, so that a partial config without this key is
    # treated as "nothing to validate" instead of raising a KeyError.
    policies_to_train = multiagent_config.get("policies_to_train")
    if isinstance(policies_to_train, (list, set, tuple)):
        if not policies_to_train:
            logger.warning(
                "`config.multiagent.policies_to_train` is empty! "
                "Make sure - if you would like to learn at least one policy - "
                "to add its ID to that list."
            )
        for pid in policies_to_train:
            if pid not in policies:
                raise ValueError(
                    "`config.multiagent.policies_to_train` contains policy "
                    f"ID ({pid}) that was not defined in "
                    "`config.multiagent.policies`!"
                )

    # Is this a multi-agent setup? False, iff DEFAULT_POLICY_ID is the only
    # PolicyID found in policies dict.
    is_multiagent = len(policies) > 1 or DEFAULT_POLICY_ID not in policies
    return policies, is_multiagent
Ejemplo n.º 19
0
    def __init__(self,
                 action_space,
                 *,
                 framework: str,
                 policy_config: dict,
                 model: ModelV2,
                 initial_stddev: float = 1.0,
                 random_timesteps: int = 10000,
                 sub_exploration: Optional[dict] = None,
                 **kwargs):
        """Initializes a ParameterNoise Exploration object.

        Args:
            initial_stddev (float): The initial stddev to use for the noise.
            random_timesteps (int): The number of timesteps to act completely
                randomly (see [1]).
            sub_exploration (Optional[dict]): Optional sub-exploration config.
                None for auto-detection/setup.
        """
        assert framework is not None
        super().__init__(action_space,
                         policy_config=policy_config,
                         model=model,
                         framework=framework,
                         **kwargs)

        self.stddev = get_variable(initial_stddev,
                                   framework=self.framework,
                                   tf_name="stddev")
        self.stddev_val = initial_stddev  # Out-of-graph tf value holder.

        # The weight variables of the Model where noise should be applied to.
        # This excludes any variable, whose name contains "LayerNorm" (those
        # are BatchNormalization layers, which should not be perturbed).
        self.model_variables = [
            v for k, v in self.model.trainable_variables(as_dict=True).items()
            if "LayerNorm" not in k
        ]
        # Our noise to be added to the weights. Each item in `self.noise`
        # corresponds to one Model variable and holding the Gaussian noise to
        # be added to that variable (weight).
        self.noise = []
        for var in self.model_variables:
            name_ = var.name.split(":")[0] + "_noisy" if var.name else ""
            self.noise.append(
                get_variable(np.zeros(var.shape, dtype=np.float32),
                             framework=self.framework,
                             tf_name=name_,
                             torch_tensor=True,
                             device=self.device))

        # tf-specific ops to sample, assign and remove noise.
        if self.framework == "tf" and not tf.executing_eagerly():
            self.tf_sample_new_noise_op = \
                self._tf_sample_new_noise_op()
            self.tf_add_stored_noise_op = \
                self._tf_add_stored_noise_op()
            self.tf_remove_noise_op = \
                self._tf_remove_noise_op()
            # Create convenience sample+add op for tf.
            with tf1.control_dependencies([self.tf_sample_new_noise_op]):
                add_op = self._tf_add_stored_noise_op()
            with tf1.control_dependencies([add_op]):
                self.tf_sample_new_noise_and_add_op = tf.no_op()

        # Whether the Model's weights currently have noise added or not.
        self.weights_are_currently_noisy = False

        # Auto-detection of underlying exploration functionality.
        if sub_exploration is None:
            # For discrete action spaces, use an underlying EpsilonGreedy with
            # a special schedule.
            if isinstance(self.action_space, Discrete):
                sub_exploration = {
                    "type": "EpsilonGreedy",
                    "epsilon_schedule": {
                        "type":
                        "PiecewiseSchedule",
                        # Step function (see [2]).
                        "endpoints": [(0, 1.0), (random_timesteps + 1, 1.0),
                                      (random_timesteps + 2, 0.01)],
                        "outside_value":
                        0.01
                    }
                }
            elif isinstance(self.action_space, Box):
                sub_exploration = {
                    "type": "OrnsteinUhlenbeckNoise",
                    "random_timesteps": random_timesteps,
                }
            # TODO(sven): Implement for any action space.
            else:
                raise NotImplementedError

        self.sub_exploration = from_config(Exploration,
                                           sub_exploration,
                                           framework=self.framework,
                                           action_space=self.action_space,
                                           policy_config=self.policy_config,
                                           model=self.model,
                                           **kwargs)

        # Whether we need to call `self._delayed_on_episode_start` before
        # the forward pass.
        self.episode_started = False
Ejemplo n.º 20
0
    def default_resource_request(cls, config):
        """Assemble the resource bundles this algorithm asks for.

        Args:
            config: User-supplied config overrides. They are layered on top
                of `cls.get_default_config()` before any value is read.

        Returns:
            A PlacementGroupFactory whose bundles are, in order: one for the
            driver, one per rollout worker, one per learner shard and, only
            if `evaluation_interval` is set, one per evaluation worker.
        """
        merged = dict(cls.get_default_config(), **config)

        # Build a throw-away LeagueBuilder first: it gets the chance to adjust
        # the multiagent config, which the resource computation below reads.
        from_config(
            merged["league_builder_config"], trainer=None, trainer_config=merged
        )

        # Number of policies to train: explicit setting, else derived from the
        # multiagent config (only looked up when actually needed).
        trainable_count = merged["max_num_policies_to_train"]
        if not trainable_count:
            ma_config = merged["multiagent"]
            trainable_count = len(
                ma_config.get("policies_to_train") or ma_config["policies"]
            )

        # Never use more learner shards than there are policies to train.
        shard_count = min(merged["num_gpus"] or trainable_count, trainable_count)
        gpus_per_shard = merged["num_gpus"] / shard_count
        policies_per_shard = trainable_count / shard_count

        use_fake_gpus = merged["_fake_gpus"]
        eval_settings = merged["evaluation_config"]

        # Driver (no GPUs).
        driver_bundles = [{"CPU": merged["num_cpus_for_driver"]}]

        # RolloutWorkers (no GPUs).
        rollout_bundles = [
            {"CPU": merged["num_cpus_per_worker"]}
            for _ in range(merged["num_workers"])
        ]

        # Policy learners (and replay buffer shards):
        # 1 CPU for the replay buffer, plus 1 CPU (or a fractional GPU) for
        # each learning policy.
        learner_bundles = [
            {
                "CPU": 1 + (policies_per_shard if use_fake_gpus else 0),
                "GPU": 0 if use_fake_gpus else gpus_per_shard,
            }
            for _ in range(shard_count)
        ]

        # Evaluation (remote) workers. The local eval worker is located on the
        # driver CPU or not even created iff >0 eval workers.
        eval_bundles = []
        if merged["evaluation_interval"]:
            eval_bundles = [
                {
                    "CPU": eval_settings.get(
                        "num_cpus_per_worker", merged["num_cpus_per_worker"]
                    ),
                }
                for _ in range(merged["evaluation_num_workers"])
            ]

        # All bundles are already properly defined as device bundles.
        return PlacementGroupFactory(
            bundles=(
                driver_bundles + rollout_bundles + learner_bundles + eval_bundles
            ),
            strategy=config.get("placement_strategy", "PACK"),
        )
Ejemplo n.º 21
0
    def _make_worker(
        self,
        *,
        cls: Callable,
        env_creator: EnvCreator,
        validate_env: Optional[Callable[[EnvType], None]],
        policy_cls: Type[Policy],
        worker_index: int,
        num_workers: int,
        recreated_worker: bool = False,
        config: AlgorithmConfigDict,
        spaces: Optional[Dict[PolicyID, Tuple[gym.spaces.Space,
                                              gym.spaces.Space]]] = None,
    ) -> Union[RolloutWorker, ActorHandle]:
        """Builds a single worker by calling `cls` with settings from `config`.

        Besides forwarding most `config` entries unchanged, this method
        derives three things: an input creator (from `config["input"]`), an
        output creator (from `config["output"]`) and the policy spec (from
        `config["multiagent"]["policies"]`, falling back to `policy_cls`).

        Args:
            cls: Callable that is invoked with the keyword arguments
                assembled below and whose result is returned.
            env_creator: Forwarded to `cls` as `env_creator`.
            validate_env: Forwarded to `cls` as `validate_env`.
            policy_cls: Policy class used for every multiagent policy spec
                whose `policy_class` is None, or as the policy spec itself
                if no multiagent policies are configured.
            worker_index: Index of the worker. Index 0 selects
                `extra_python_environs_for_driver`, any other index
                `extra_python_environs_for_worker`. Also selects the dataset
                shard for "dataset" input and is added to `config["seed"]`.
            num_workers: Forwarded to `cls` as `num_workers`.
            recreated_worker: Forwarded to `cls` as `recreated_worker`.
            config: The algorithm config dict to read all settings from.
            spaces: Forwarded to `cls` as `spaces`.

        Returns:
            Whatever `cls(...)` returns.
        """
        def session_creator():
            """Creates a tf1 Session configured via `tf_session_args`."""
            logger.debug("Creating TF session {}".format(
                config["tf_session_args"]))
            return tf1.Session(config=tf1.ConfigProto(
                **config["tf_session_args"]))

        def valid_module(class_path):
            """Returns True if `class_path` looks like an importable class path.

            Only strings that are not existing files and contain a "." are
            considered. The part before the last "." must resolve to a module
            spec.
            """
            if (isinstance(class_path, str) and not os.path.isfile(class_path)
                    and "." in class_path):
                module_path, class_name = class_path.rsplit(".", 1)
                try:
                    spec = importlib.util.find_spec(module_path)
                    if spec is not None:
                        return True
                except (ModuleNotFoundError, ValueError):
                    # Lookup failed: report it and fall through to `False`.
                    print(
                        f"module {module_path} not found while trying to get "
                        f"input {class_path}")
            return False

        # --- Input creator ---
        # NOTE: The branches below are checked in order; the first match wins.
        # A callable returning an InputReader object to use.
        if isinstance(config["input"], FunctionType):
            input_creator = config["input"]
        # Use RLlib's Sampler classes (SyncSampler or AsynchSampler, depending
        # on `config.sample_async` setting).
        elif config["input"] == "sampler":
            input_creator = lambda ioctx: ioctx.default_sampler_input()
        # Ray Dataset input -> Use `config.input_config` to construct DatasetReader.
        elif config["input"] == "dataset":
            # Input dataset shards should have already been prepared.
            # We just need to take the proper shard here.
            input_creator = lambda ioctx: DatasetReader(
                ioctx, self._ds_shards[worker_index])
        # Dict: Mix of different input methods with different ratios.
        elif isinstance(config["input"], dict):
            input_creator = lambda ioctx: ShuffledInput(
                MixedInput(config["input"], ioctx), config[
                    "shuffle_buffer_size"])
        # A pre-registered input descriptor (str).
        elif isinstance(config["input"], str) and registry_contains_input(
                config["input"]):
            input_creator = registry_get_input(config["input"])
        # D4RL input.
        elif "d4rl" in config["input"]:
            # The env name is the part after the last "." of the input string.
            env_name = config["input"].split(".")[-1]
            input_creator = lambda ioctx: D4RLReader(env_name, ioctx)
        # Valid python module (class path) -> Create using `from_config`.
        elif valid_module(config["input"]):
            input_creator = lambda ioctx: ShuffledInput(
                from_config(config["input"], ioctx=ioctx))
        # JSON file or list of JSON files -> Use JsonReader (shuffled).
        else:
            input_creator = lambda ioctx: ShuffledInput(
                JsonReader(config["input"], ioctx), config[
                    "shuffle_buffer_size"])

        # --- Output creator ---
        # A callable: Use it directly as the output creator.
        if isinstance(config["output"], FunctionType):
            output_creator = config["output"]
        # No output configured -> No-op output.
        elif config["output"] is None:
            output_creator = lambda ioctx: NoopOutput()
        # Write to a dataset.
        elif config["output"] == "dataset":
            output_creator = lambda ioctx: DatasetWriter(
                ioctx, compress_columns=config["output_compress_columns"])
        # Write JSON into the IO context's log dir.
        elif config["output"] == "logdir":
            output_creator = lambda ioctx: JsonWriter(
                ioctx.log_dir,
                ioctx,
                max_file_size=config["output_max_file_size"],
                compress_columns=config["output_compress_columns"],
            )
        # Anything else is passed to JsonWriter as the output location.
        else:
            output_creator = lambda ioctx: JsonWriter(
                config["output"],
                ioctx,
                max_file_size=config["output_max_file_size"],
                compress_columns=config["output_compress_columns"],
            )

        # Assert everything is correct in "multiagent" config dict (if given).
        ma_policies = config["multiagent"]["policies"]
        if ma_policies:
            # Iterate over a copy; the specs inside `ma_policies` (and thereby
            # inside `config`) are updated in place.
            for pid, policy_spec in ma_policies.copy().items():
                assert isinstance(policy_spec, PolicySpec)
                # Class is None -> Use `policy_cls`.
                if policy_spec.policy_class is None:
                    ma_policies[pid].policy_class = policy_cls
            policies = ma_policies

        # No "multiagent" policies given by the user -> Pass the plain policy
        # class as the policy spec.
        else:
            policies = policy_cls

        # Worker index 0 uses the driver's extra python environs, all other
        # indices use the workers' setting.
        if worker_index == 0:
            extra_python_environs = config.get(
                "extra_python_environs_for_driver", None)
        else:
            extra_python_environs = config.get(
                "extra_python_environs_for_worker", None)

        worker = cls(
            env_creator=env_creator,
            validate_env=validate_env,
            policy_spec=policies,
            policy_mapping_fn=config["multiagent"]["policy_mapping_fn"],
            policies_to_train=config["multiagent"]["policies_to_train"],
            tf_session_creator=(session_creator
                                if config["tf_session_args"] else None),
            rollout_fragment_length=config["rollout_fragment_length"],
            count_steps_by=config["multiagent"]["count_steps_by"],
            batch_mode=config["batch_mode"],
            episode_horizon=config["horizon"],
            preprocessor_pref=config["preprocessor_pref"],
            sample_async=config["sample_async"],
            compress_observations=config["compress_observations"],
            num_envs=config["num_envs_per_worker"],
            observation_fn=config["multiagent"]["observation_fn"],
            observation_filter=config["observation_filter"],
            clip_rewards=config["clip_rewards"],
            normalize_actions=config["normalize_actions"],
            clip_actions=config["clip_actions"],
            env_config=config["env_config"],
            policy_config=config,
            worker_index=worker_index,
            num_workers=num_workers,
            recreated_worker=recreated_worker,
            log_dir=self._logdir,
            log_level=config["log_level"],
            callbacks=config["callbacks"],
            input_creator=input_creator,
            output_creator=output_creator,
            remote_worker_envs=config["remote_worker_envs"],
            remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"],
            soft_horizon=config["soft_horizon"],
            no_done_at_end=config["no_done_at_end"],
            # Offset the seed by the worker index (if a seed is set at all).
            seed=(config["seed"] +
                  worker_index) if config["seed"] is not None else None,
            fake_sampler=config["fake_sampler"],
            extra_python_environs=extra_python_environs,
            spaces=spaces,
            disable_env_checking=config["disable_env_checking"],
        )

        return worker
Ejemplo n.º 22
0
def validate_buffer_config(config: dict):
    """Validates and normalizes the replay buffer settings of a config dict.

    Values found at deprecated top-level locations are copied (with a
    deprecation warning) into `config["replay_buffer_config"]`, the buffer
    "type" is expanded to a full `[module].[class]` string where needed and
    setting combinations that prioritized replay does not support are
    rejected. `config` is modified in place.

    Args:
        config: The config dict to check and update.

    Raises:
        ValueError: If prioritized replay is combined with
            `replay_mode="lockstep"` or `replay_sequence_length > 1`, or if
            worker side prioritization is requested without a prioritized
            buffer.
        AssertionError: If the replay buffer config has no "type" key, or if
            an unsupported buffer type is requested while the replay buffer
            API is not enabled.
    """
    if config.get("replay_buffer_config", None) is None:
        config["replay_buffer_config"] = {}
    # Alias of the (never re-bound) nested dict; all writes go into `config`.
    buffer_config = config["replay_buffer_config"]

    # Default to DEPRECATED_VALUE, so that an absent key is not mistaken for a
    # user-provided value (which would overwrite the new location with None).
    prioritized_replay = config.get("prioritized_replay", DEPRECATED_VALUE)
    if prioritized_replay != DEPRECATED_VALUE:
        deprecation_warning(
            old="config['prioritized_replay']",
            help="Replay prioritization specified at new location config["
            "'replay_buffer_config']["
            "'prioritized_replay'] will be overwritten.",
            error=False,
        )
        buffer_config["prioritized_replay"] = prioritized_replay

    capacity = config.get("buffer_size", DEPRECATED_VALUE)
    if capacity != DEPRECATED_VALUE:
        deprecation_warning(
            old="config['buffer_size']",
            help="Buffer size specified at new location config["
            "'replay_buffer_config']["
            "'capacity'] will be overwritten.",
            error=False,
        )
        buffer_config["capacity"] = capacity

    # Deprecation of old-style replay buffer args.
    # Warn before checking whether we need a local buffer, so that algorithms
    # without a local buffer also get warned.
    deprecated_replay_buffer_keys = [
        "prioritized_replay_alpha",
        "prioritized_replay_beta",
        "prioritized_replay_eps",
        "learning_starts",
    ]
    for k in deprecated_replay_buffer_keys:
        if config.get(k, DEPRECATED_VALUE) != DEPRECATED_VALUE:
            deprecation_warning(
                old="config['{}']".format(k),
                help="config['replay_buffer_config']['{}'] should be used "
                "for Q-Learning algorithms. Ignore this warning if "
                "you are not using a Q-Learning algorithm and still "
                "provide {}."
                "".format(k, k),
                error=False,
            )
            # Copy values over to new location in config to support new
            # and old configuration style.
            buffer_config[k] = config[k]

    # Old Ape-X configs may contain no_local_replay_buffer.
    no_local_replay_buffer = config.get("no_local_replay_buffer", False)
    if no_local_replay_buffer:
        deprecation_warning(
            old="config['no_local_replay_buffer']",
            help="no_local_replay_buffer specified at new location config["
            "'replay_buffer_config']["
            "'no_local_replay_buffer'] will be overwritten.",
            error=False,
        )
        buffer_config["no_local_replay_buffer"] = no_local_replay_buffer

    # Without a local replay buffer, there is nothing left to validate.
    if buffer_config.get("no_local_replay_buffer", False):
        return

    assert (
        "type" in buffer_config
    ), "Can not instantiate ReplayBuffer from config without 'type' key."

    # Check if old replay buffer should be instantiated.
    buffer_type = buffer_config["type"]
    if not buffer_config.get("_enable_replay_buffer_api", False):
        if isinstance(buffer_type, str) and buffer_type.find(".") == -1:
            # Prepend old-style buffers' path.
            assert buffer_type == "MultiAgentReplayBuffer", (
                "Without "
                "ReplayBuffer "
                "API, only "
                "MultiAgentReplayBuffer "
                "is supported!")
            # Create valid full [module].[class] string for from_config.
            buffer_type = "ray.rllib.execution.MultiAgentReplayBuffer"
        else:
            assert buffer_type in [
                "ray.rllib.execution.MultiAgentReplayBuffer",
                Legacy_MultiAgentReplayBuffer,
            ], ("Without ReplayBuffer API, only "
                "MultiAgentReplayBuffer is supported!")

        buffer_config["type"] = buffer_type

        # Remove from config, so it's not passed into the buffer c'tor.
        buffer_config.pop("_enable_replay_buffer_api", None)

        # We need to deprecate the old-style location of the following
        # buffer arguments and make users put them into the
        # "replay_buffer_config" field of their config.
        replay_batch_size = config.get("replay_batch_size", DEPRECATED_VALUE)
        if replay_batch_size != DEPRECATED_VALUE:
            buffer_config["replay_batch_size"] = replay_batch_size
            deprecation_warning(
                old="config['replay_batch_size']",
                help="Replay batch size specified at new "
                "location config['replay_buffer_config']["
                "'replay_batch_size'] will be overwritten.",
                error=False,
            )

        # NOTE(review): the value is read from the top-level "replay_mode"
        # key, while the warning names config['multiagent']['replay_mode'] as
        # the old location - confirm which one is intended.
        replay_mode = config.get("replay_mode", DEPRECATED_VALUE)
        if replay_mode != DEPRECATED_VALUE:
            buffer_config["replay_mode"] = replay_mode
            deprecation_warning(
                old="config['multiagent']['replay_mode']",
                help="Replay mode specified at new "
                "location config['replay_buffer_config']["
                "'replay_mode'] will be overwritten.",
                error=False,
            )

        # Can't use DEPRECATED_VALUE here because this is also a deliberate
        # value set for some algorithms.
        # TODO: (Artur): Compare to DEPRECATED_VALUE on deprecation
        replay_sequence_length = config.get("replay_sequence_length", None)
        if replay_sequence_length is not None:
            buffer_config["replay_sequence_length"] = replay_sequence_length
            deprecation_warning(
                old="config['replay_sequence_length']",
                help="Replay sequence length specified at new "
                "location config['replay_buffer_config']["
                "'replay_sequence_length'] will be overwritten.",
                error=False,
            )

        replay_burn_in = config.get("burn_in", DEPRECATED_VALUE)
        if replay_burn_in != DEPRECATED_VALUE:
            buffer_config["replay_burn_in"] = replay_burn_in
            deprecation_warning(
                old="config['burn_in']",
                help="Burn in specified at new location config["
                "'replay_buffer_config']["
                "'replay_burn_in'] will be overwritten.",
            )

        replay_zero_init_states = config.get("replay_zero_init_states",
                                             DEPRECATED_VALUE)
        if replay_zero_init_states != DEPRECATED_VALUE:
            buffer_config["replay_zero_init_states"] = replay_zero_init_states
            deprecation_warning(
                old="config['replay_zero_init_states']",
                help="Replay zero init states specified at new location "
                "config["
                "'replay_buffer_config']["
                "'replay_zero_init_states'] will be overwritten.",
                error=False,
            )

        # TODO (Artur): Move this logic into config objects
        is_prioritized_buffer = bool(
            buffer_config.get("prioritized_replay", False))
        if not is_prioritized_buffer:
            # This triggers non-prioritization in old-style replay buffer.
            buffer_config["prioritized_replay_alpha"] = 0.0
    else:
        if isinstance(buffer_type, str) and buffer_type.find(".") == -1:
            # Create valid full [module].[class] string for from_config.
            buffer_config["type"] = (
                "ray.rllib.utils.replay_buffers." + buffer_type)
        # Build a test buffer; a buffer counts as prioritized if it offers an
        # `update_priorities` attribute.
        test_buffer = from_config(buffer_type, buffer_config)
        is_prioritized_buffer = hasattr(test_buffer, "update_priorities")

    if is_prioritized_buffer:
        if config["multiagent"]["replay_mode"] == "lockstep":
            raise ValueError(
                "Prioritized replay is not supported when replay_mode=lockstep."
            )
        elif buffer_config.get("replay_sequence_length", 0) > 1:
            raise ValueError("Prioritized replay is not supported when "
                             "replay_sequence_length > 1.")
    else:
        if config.get("worker_side_prioritization"):
            raise ValueError(
                "Worker side prioritization is not supported when "
                "prioritized_replay=False.")

    if buffer_config.get("replay_batch_size", None) is None:
        # Fall back to train batch size if no replay batch size was provided.
        buffer_config["replay_batch_size"] = config["train_batch_size"]

    # Pop prioritized replay because it's not a valid parameter for older
    # replay buffers.
    buffer_config.pop("prioritized_replay", None)
Ejemplo n.º 23
0
def validate_buffer_config(config: dict) -> None:
    """Checks and fixes values in the replay buffer config.

    Checks the replay buffer config for common misconfigurations, warns or
    raises an error in case validation fails. Values found at deprecated
    locations that are still supported are copied into
    `config["replay_buffer_config"]`. The "type" key is changed into the
    inferred replay buffer class. `config` is modified in place.

    Args:
        config: The config dict whose replay buffer settings are validated.

    Raises:
        ValueError: When detecting severe misconfiguration.
        AssertionError: If the replay buffer config has no "type" key.
    """
    if config.get("replay_buffer_config", None) is None:
        config["replay_buffer_config"] = {}
    # Alias of the (never re-bound) nested dict; all writes go into `config`.
    buffer_config = config["replay_buffer_config"]

    if config.get("worker_side_prioritization",
                  DEPRECATED_VALUE) != DEPRECATED_VALUE:
        deprecation_warning(
            old="config['worker_side_prioritization']",
            new="config['replay_buffer_config']['worker_side_prioritization']",
            error=True,
        )

    prioritized_replay = config.get("prioritized_replay", DEPRECATED_VALUE)
    if prioritized_replay != DEPRECATED_VALUE:
        deprecation_warning(
            old="config['prioritized_replay'] or config['replay_buffer_config']["
            "'prioritized_replay']",
            help=
            "Replay prioritization specified by config key. RLlib's new replay "
            "buffer API requires setting `config["
            "'replay_buffer_config']['type']`, e.g. `config["
            "'replay_buffer_config']['type'] = "
            "'MultiAgentPrioritizedReplayBuffer'` to change the default "
            "behaviour.",
            error=True,
        )

    # "buffer_size" is deprecated at the top level as well as inside the
    # replay buffer config.
    capacity = config.get("buffer_size", DEPRECATED_VALUE)
    if capacity == DEPRECATED_VALUE:
        capacity = buffer_config.get("buffer_size", DEPRECATED_VALUE)
    if capacity != DEPRECATED_VALUE:
        deprecation_warning(
            old="config['buffer_size'] or config['replay_buffer_config']["
            "'buffer_size']",
            new="config['replay_buffer_config']['capacity']",
            error=True,
        )

    replay_burn_in = config.get("burn_in", DEPRECATED_VALUE)
    if replay_burn_in != DEPRECATED_VALUE:
        buffer_config["replay_burn_in"] = replay_burn_in
        deprecation_warning(
            old="config['burn_in']",
            help="config['replay_buffer_config']['replay_burn_in']",
        )

    # "replay_batch_size" is deprecated at both locations, too.
    replay_batch_size = config.get("replay_batch_size", DEPRECATED_VALUE)
    if replay_batch_size == DEPRECATED_VALUE:
        replay_batch_size = buffer_config.get("replay_batch_size",
                                              DEPRECATED_VALUE)
    if replay_batch_size != DEPRECATED_VALUE:
        deprecation_warning(
            old="config['replay_batch_size'] or config['replay_buffer_config']["
            "'replay_batch_size']",
            help=
            "Specification of replay_batch_size is not supported anymore but is "
            "derived from `train_batch_size`. Specify the number of "
            "items you want to replay upon calling the sample() method of replay "
            "buffers if this does not work for you.",
            error=True,
        )

    # Deprecation of old-style replay buffer args.
    # Warn before checking whether we need a local buffer, so that algorithms
    # without a local buffer also get warned.
    keys_with_deprecated_positions = [
        "prioritized_replay_alpha",
        "prioritized_replay_beta",
        "prioritized_replay_eps",
        "no_local_replay_buffer",
        "replay_zero_init_states",
        "learning_starts",
        "replay_buffer_shards_colocated_with_driver",
    ]
    for k in keys_with_deprecated_positions:
        if config.get(k, DEPRECATED_VALUE) != DEPRECATED_VALUE:
            deprecation_warning(
                old="config['{}']".format(k),
                help="config['replay_buffer_config']['{}']"
                "".format(k),
                error=False,
            )
            # Copy values over to new location in config to support new
            # and old configuration style.
            buffer_config[k] = config[k]

    replay_mode = config.get("multiagent", {}).get("replay_mode",
                                                   DEPRECATED_VALUE)
    if replay_mode != DEPRECATED_VALUE:
        deprecation_warning(
            old="config['multiagent']['replay_mode']",
            help="config['replay_buffer_config']['replay_mode']",
            error=False,
        )
        buffer_config["replay_mode"] = replay_mode

    # Can't use DEPRECATED_VALUE here because this is also a deliberate
    # value set for some algorithms.
    # TODO: (Artur): Compare to DEPRECATED_VALUE on deprecation
    replay_sequence_length = config.get("replay_sequence_length", None)
    if replay_sequence_length is not None:
        buffer_config["replay_sequence_length"] = replay_sequence_length
        deprecation_warning(
            old="config['replay_sequence_length']",
            help="Replay sequence length specified at new "
            "location config['replay_buffer_config']["
            "'replay_sequence_length'] will be overwritten.",
            error=False,
        )

    assert (
        "type" in buffer_config
    ), "Can not instantiate ReplayBuffer from config without 'type' key."

    buffer_type = buffer_config["type"]

    if isinstance(buffer_type, str) and buffer_type.find(".") == -1:
        # Create valid full [module].[class] string for from_config.
        buffer_config["type"] = (
            "ray.rllib.utils.replay_buffers." + buffer_type)

    # Instantiate a dummy buffer to fail early on misconfiguration and find
    # out about the inferred buffer class.
    dummy_buffer = from_config(buffer_type, buffer_config)

    buffer_config["type"] = type(dummy_buffer)

    # A buffer counts as prioritized if it offers `update_priorities`.
    if hasattr(dummy_buffer, "update_priorities"):
        # Read the replay mode from its normalized location: a value given at
        # the deprecated config['multiagent']['replay_mode'] position has
        # already been copied there above. This also covers configs that only
        # use the new location or have no "multiagent" key at all.
        if buffer_config.get("replay_mode") == "lockstep":
            raise ValueError(
                "Prioritized replay is not supported when replay_mode=lockstep."
            )
        elif buffer_config.get("replay_sequence_length", 0) > 1:
            raise ValueError("Prioritized replay is not supported when "
                             "replay_sequence_length > 1.")
    else:
        if buffer_config.get("worker_side_prioritization"):
            raise ValueError(
                "Worker side prioritization is not supported when "
                "prioritized_replay=False.")
Ejemplo n.º 24
0
    def _setup(self, config: dict):
        """Sets up this Trainer from the user-provided config.

        Resolves the env creator from `self._env_id`, merges `config` into
        the class' default config, resolves the deprecated framework flags,
        optionally wraps the env creator for action normalization, validates
        the config, creates the callbacks object, configures logging and
        finally calls `self._init()` (plus the evaluation worker setup if an
        evaluation interval is configured).

        Args:
            config: The user-provided config dict. If an env ID is set on
                this object, it is written into `config["env"]`.

        Raises:
            ValueError: If `config["callbacks"]` (after merging) is not
                callable. The wrapped env creator raises a ValueError as
                well if action normalization is enabled and the created env
                is not a `gym.Env`.
        """
        env = self._env_id
        if env:
            config["env"] = env
            # An already registered env.
            if _global_registry.contains(ENV_CREATOR, env):
                self.env_creator = _global_registry.get(ENV_CREATOR, env)
            # A class specifier.
            elif "." in env:
                self.env_creator = \
                    lambda env_config: from_config(env, env_config)
            # Try gym.
            else:
                import gym  # soft dependency
                self.env_creator = lambda env_config: gym.make(env)
        else:
            self.env_creator = lambda env_config: None

        # Merge the supplied config with the class default, but store the
        # user-provided one.
        self.raw_user_config = config
        self.config = Trainer.merge_trainer_configs(self._default_config,
                                                    config)

        # Check and resolve DL framework settings.
        if "use_pytorch" in self.config and \
                self.config["use_pytorch"] != DEPRECATED_VALUE:
            deprecation_warning("use_pytorch", "framework=torch", error=False)
            if self.config["use_pytorch"]:
                self.config["framework"] = "torch"
            self.config.pop("use_pytorch")
        if "eager" in self.config and self.config["eager"] != DEPRECATED_VALUE:
            deprecation_warning("eager", "framework=tfe", error=False)
            if self.config["eager"]:
                self.config["framework"] = "tfe"
            self.config.pop("eager")

        # Enable eager/tracing support.
        if tf and self.config["framework"] == "tfe":
            if not tf.executing_eagerly():
                tf.enable_eager_execution()
            logger.info("Executing eagerly, with eager_tracing={}".format(
                self.config["eager_tracing"]))
        if tf and not tf.executing_eagerly() and \
                self.config["framework"] != "torch":
            logger.info("Tip: set framework=tfe or the --eager flag to enable "
                        "TensorFlow eager execution")

        if self.config["normalize_actions"]:
            # Keep a reference to the un-wrapped creator for the closure below.
            inner = self.env_creator

            def normalize(env):
                import gym  # soft dependency
                if not isinstance(env, gym.Env):
                    # Format the offending type into the message (instead of
                    # passing it as a second, unformatted exception arg).
                    raise ValueError(
                        "Cannot apply NormalizeActionWrapper to env of "
                        "type {}, which does not subclass gym.Env.".format(
                            type(env)))
                return NormalizeActionWrapper(env)

            self.env_creator = lambda env_config: normalize(inner(env_config))

        Trainer._validate_config(self.config)
        if not callable(self.config["callbacks"]):
            raise ValueError(
                "`callbacks` must be a callable method that "
                "returns a subclass of DefaultCallbacks, got {}".format(
                    self.config["callbacks"]))
        self.callbacks = self.config["callbacks"]()
        log_level = self.config.get("log_level")
        if log_level in ["WARN", "ERROR"]:
            logger.info("Current log_level is {}. For more information, "
                        "set 'log_level': 'INFO' / 'DEBUG' or use the -v and "
                        "-vv flags.".format(log_level))
        if log_level:
            logging.getLogger("ray.rllib").setLevel(log_level)

        def get_scope():
            # In tf graph mode, run everything inside a fresh default graph.
            if tf and not tf.executing_eagerly():
                return tf.Graph().as_default()
            else:
                return open(os.devnull)  # fake a no-op scope

        with get_scope():
            self._init(self.config, self.env_creator)

            # Evaluation setup.
            if self.config.get("evaluation_interval"):
                # Update env_config with evaluation settings:
                extra_config = copy.deepcopy(self.config["evaluation_config"])
                # Assert that user has not unset "in_evaluation".
                assert "in_evaluation" not in extra_config or \
                    extra_config["in_evaluation"] is True
                # These settings always override the user's evaluation config.
                extra_config.update({
                    "batch_mode": "complete_episodes",
                    "rollout_fragment_length": 1,
                    "in_evaluation": True,
                })
                logger.debug(
                    "using evaluation_config: {}".format(extra_config))

                self.evaluation_workers = self._make_workers(
                    self.env_creator,
                    self._policy,
                    merge_dicts(self.config, extra_config),
                    num_workers=self.config["evaluation_num_workers"])
                self.evaluation_metrics = {}
Ejemplo n.º 25
0
    def get_model_v2(obs_space: gym.Space,
                     action_space: gym.Space,
                     num_outputs: int,
                     model_config: ModelConfigDict,
                     framework: str = "tf",
                     name: str = "default_model",
                     model_interface: type = None,
                     default_model: type = None,
                     **model_kwargs) -> ModelV2:
        """Returns a suitable model compatible with given spaces and output.

        If `model_config["custom_model"]` is set, that model is resolved
        (class object, dotted class path, or registry key), optionally wrapped
        with an LSTM/attention wrapper, and instantiated. Otherwise a default
        model class for the given framework is looked up and instantiated.

        Note: When a keras Model gets wrapped by an LSTM/attention wrapper,
        `model_config["wrapped_cls"]` is set on the passed-in dict.

        Args:
            obs_space (Space): Observation space of the target gym env. This
                may have an `original_space` attribute that specifies how to
                unflatten the tensor into a ragged tensor.
            action_space (Space): Action space of the target gym env.
            num_outputs (int): The size of the output vector of the model.
            model_config (ModelConfigDict): The "model" sub-config dict
                within the Trainer's config dict.
            framework (str): One of "tf2", "tf", "tfe", "torch", or "jax".
            name (str): Name (scope) for the model.
            model_interface (cls): Interface required for the model
            default_model (cls): Override the default class for the model. This
                only has an effect when not using a custom model
            model_kwargs (dict): Args to pass to the ModelV2 constructor

        Returns:
            model (ModelV2): Model to use for the policy.

        Raises:
            ValueError: If a custom model class is neither a ModelV2 nor (for
                tf frameworks) a keras Model, if only some of the created tf
                variables were registered, or if no default model class could
                be determined.
            NotImplementedError: If `framework` is not supported by the
                selected code path.
        """

        # Validate the given config dict.
        ModelCatalog._validate_config(config=model_config,
                                      action_space=action_space,
                                      framework=framework)

        if model_config.get("custom_model"):
            # Allow model kwargs to be overridden / augmented by
            # custom_model_config.
            customized_model_kwargs = dict(
                model_kwargs, **model_config.get("custom_model_config", {}))

            def _construct_model_v2(cls_):
                """Instantiates a custom ModelV2 class.

                Tries passing the customized kwargs first (custom ModelV2
                should accept these as kwargs, not get them from
                config["custom_model_config"] anymore). Falls back to the old
                way (only `model_kwargs`) on an unexpected-keyword TypeError.
                """
                try:
                    return cls_(
                        obs_space,
                        action_space,
                        num_outputs,
                        model_config,
                        name,
                        **customized_model_kwargs,
                    )
                except TypeError as e:
                    # Other error -> re-raise (keeping the original
                    # traceback). `str(e)` is used instead of `e.args[0]` so
                    # that an empty or non-str args tuple cannot mask the
                    # actual error.
                    if "__init__() got an unexpected " not in str(e):
                        raise
                    # Keyword error: Try old way w/o kwargs.
                    legacy_instance = cls_(
                        obs_space,
                        action_space,
                        num_outputs,
                        model_config,
                        name,
                        **model_kwargs,
                    )
                    logger.warning(
                        "Custom ModelV2 should accept all custom "
                        "options as **kwargs, instead of expecting"
                        " them in config['custom_model_config']!")
                    return legacy_instance

            if isinstance(model_config["custom_model"], type):
                model_cls = model_config["custom_model"]
            elif (isinstance(model_config["custom_model"], str)
                  and "." in model_config["custom_model"]):
                # Dotted path: Let `from_config` resolve and construct it.
                return from_config(
                    cls=model_config["custom_model"],
                    obs_space=obs_space,
                    action_space=action_space,
                    num_outputs=num_outputs,
                    model_config=customized_model_kwargs,
                    name=name,
                )
            else:
                model_cls = _global_registry.get(RLLIB_MODEL,
                                                 model_config["custom_model"])

            # Only allow ModelV2 or native keras Models.
            if not issubclass(model_cls, ModelV2):
                if framework not in [
                        "tf", "tf2", "tfe"
                ] or not issubclass(model_cls, tf.keras.Model):
                    raise ValueError(
                        "`model_cls` must be a ModelV2 sub-class, but is"
                        " {}!".format(model_cls))

            logger.info("Wrapping {} as {}".format(model_cls, model_interface))
            model_cls = ModelCatalog._wrap_if_needed(model_cls,
                                                     model_interface)

            if framework in ["tf2", "tf", "tfe"]:
                # Try wrapping custom model with LSTM/attention, if required.
                if model_config.get("use_lstm") or model_config.get(
                        "use_attention"):
                    from ray.rllib.models.tf.attention_net import (
                        AttentionWrapper,
                        Keras_AttentionWrapper,
                    )
                    from ray.rllib.models.tf.recurrent_net import (
                        LSTMWrapper,
                        Keras_LSTMWrapper,
                    )

                    wrapped_cls = model_cls
                    # Wrapped (custom) model is itself a keras Model ->
                    # wrap with keras LSTM/GTrXL (attention) wrappers.
                    if issubclass(wrapped_cls, tf.keras.Model):
                        model_cls = (Keras_LSTMWrapper
                                     if model_config.get("use_lstm") else
                                     Keras_AttentionWrapper)
                        model_config["wrapped_cls"] = wrapped_cls
                    # Wrapped (custom) model is ModelV2 ->
                    # wrap with ModelV2 LSTM/GTrXL (attention) wrappers.
                    else:
                        forward = wrapped_cls.forward
                        model_cls = ModelCatalog._wrap_if_needed(
                            wrapped_cls,
                            LSTMWrapper if model_config.get("use_lstm") else
                            AttentionWrapper,
                        )
                        model_cls._wrapped_forward = forward

                # Obsolete: Track and warn if vars were created but not
                # registered. Only still do this, if users do register their
                # variables. If not (which they shouldn't), don't check here.
                created = set()

                def track_var_creation(next_creator, **kw):
                    v = next_creator(**kw)
                    created.add(v.ref())
                    return v

                with tf.variable_creator_scope(track_var_creation):
                    if issubclass(model_cls, tf.keras.Model):
                        instance = model_cls(
                            input_space=obs_space,
                            action_space=action_space,
                            num_outputs=num_outputs,
                            name=name,
                            **customized_model_kwargs,
                        )
                    else:
                        instance = _construct_model_v2(model_cls)

                # User still registered TFModelV2's variables: Check, whether
                # ok.
                registered = []
                if not isinstance(instance, tf.keras.Model):
                    registered = set(instance.var_list)
                if len(registered) > 0:
                    not_registered = {
                        var
                        for var in created if var not in registered
                    }
                    if not_registered:
                        raise ValueError(
                            "It looks like you are still using "
                            "`{}.register_variables()` to register your "
                            "model's weights. This is no longer required, but "
                            "if you are still calling this method at least "
                            "once, you must make sure to register all created "
                            "variables properly. The missing variables are {},"
                            " and you only registered {}. "
                            "Did you forget to call `register_variables()` on "
                            "some of the variables in question?".format(
                                instance, not_registered, registered))
            elif framework == "torch":
                # Try wrapping custom model with LSTM/attention, if required.
                if model_config.get("use_lstm") or model_config.get(
                        "use_attention"):
                    from ray.rllib.models.torch.attention_net import AttentionWrapper
                    from ray.rllib.models.torch.recurrent_net import LSTMWrapper

                    wrapped_cls = model_cls
                    forward = wrapped_cls.forward
                    model_cls = ModelCatalog._wrap_if_needed(
                        wrapped_cls,
                        LSTMWrapper
                        if model_config.get("use_lstm") else AttentionWrapper,
                    )
                    model_cls._wrapped_forward = forward

                # PyTorch automatically tracks nn.Modules inside the parent
                # nn.Module's constructor.
                instance = _construct_model_v2(model_cls)
            else:
                # Custom models are only handled for tf and torch here.
                raise NotImplementedError(
                    "`framework` must be 'tf2|tf|tfe|torch', but is "
                    "{}!".format(framework))

            return instance

        # Find a default TFModelV2 and wrap with model_interface.
        if framework in ["tf", "tfe", "tf2"]:
            v2_class = None
            # Try to get a default v2 model.
            if not model_config.get("custom_model"):
                v2_class = default_model or ModelCatalog._get_v2_model_class(
                    obs_space, model_config, framework=framework)

            if not v2_class:
                raise ValueError("ModelV2 class could not be determined!")

            if model_config.get("use_lstm") or model_config.get(
                    "use_attention"):

                from ray.rllib.models.tf.attention_net import (
                    AttentionWrapper,
                    Keras_AttentionWrapper,
                )
                from ray.rllib.models.tf.recurrent_net import (
                    LSTMWrapper,
                    Keras_LSTMWrapper,
                )

                wrapped_cls = v2_class
                if model_config.get("use_lstm"):
                    if issubclass(wrapped_cls, tf.keras.Model):
                        v2_class = Keras_LSTMWrapper
                        model_config["wrapped_cls"] = wrapped_cls
                    else:
                        v2_class = ModelCatalog._wrap_if_needed(
                            wrapped_cls, LSTMWrapper)
                        v2_class._wrapped_forward = wrapped_cls.forward
                else:
                    if issubclass(wrapped_cls, tf.keras.Model):
                        v2_class = Keras_AttentionWrapper
                        model_config["wrapped_cls"] = wrapped_cls
                    else:
                        v2_class = ModelCatalog._wrap_if_needed(
                            wrapped_cls, AttentionWrapper)
                        v2_class._wrapped_forward = wrapped_cls.forward

            # Wrap in the requested interface.
            wrapper = ModelCatalog._wrap_if_needed(v2_class, model_interface)

            # Keras models take keyword args (incl. the model config entries).
            if issubclass(wrapper, tf.keras.Model):
                model = wrapper(
                    input_space=obs_space,
                    action_space=action_space,
                    num_outputs=num_outputs,
                    name=name,
                    **dict(model_kwargs, **model_config),
                )
                return model

            return wrapper(obs_space, action_space, num_outputs, model_config,
                           name, **model_kwargs)

        # Find a default TorchModelV2 and wrap with model_interface.
        elif framework == "torch":
            # Initialize (as in the tf branch above) so the check below can
            # never hit an unbound local.
            v2_class = None
            # Try to get a default v2 model.
            if not model_config.get("custom_model"):
                v2_class = default_model or ModelCatalog._get_v2_model_class(
                    obs_space, model_config, framework=framework)

            if not v2_class:
                raise ValueError("ModelV2 class could not be determined!")

            if model_config.get("use_lstm") or model_config.get(
                    "use_attention"):

                from ray.rllib.models.torch.attention_net import AttentionWrapper
                from ray.rllib.models.torch.recurrent_net import LSTMWrapper

                wrapped_cls = v2_class
                forward = wrapped_cls.forward
                if model_config.get("use_lstm"):
                    v2_class = ModelCatalog._wrap_if_needed(
                        wrapped_cls, LSTMWrapper)
                else:
                    v2_class = ModelCatalog._wrap_if_needed(
                        wrapped_cls, AttentionWrapper)

                v2_class._wrapped_forward = forward

            # Wrap in the requested interface.
            wrapper = ModelCatalog._wrap_if_needed(v2_class, model_interface)
            return wrapper(obs_space, action_space, num_outputs, model_config,
                           name, **model_kwargs)

        # Find a default JAXModelV2 and wrap with model_interface.
        elif framework == "jax":
            v2_class = default_model or ModelCatalog._get_v2_model_class(
                obs_space, model_config, framework=framework)
            # Wrap in the requested interface.
            wrapper = ModelCatalog._wrap_if_needed(v2_class, model_interface)
            return wrapper(obs_space, action_space, num_outputs, model_config,
                           name, **model_kwargs)
        else:
            # "jax" is a valid choice on this (default model) path, so name
            # it in the error message as well.
            raise NotImplementedError(
                "`framework` must be 'tf2|tf|tfe|torch|jax', but is "
                "{}!".format(framework))
Ejemplo n.º 26
0
    def test_dummy_components(self):
        """Builds DummyComponents via `from_config` from many config formats.

        Covers class + dict configs, json/yaml files, json/yaml strings, and
        type names resolved relative to a given base class, for each framework
        yielded by `framework_iterator`.
        """

        # Bazel makes it hard to find files specified in `args`
        # (and `data`).
        # Use the true absolute path.
        abs_path = Path(__file__).parent.absolute()

        for fw, sess in framework_iterator(session=True):
            # "tfe" is passed on to the components as plain "tf".
            fw_ = "tf" if fw == "tfe" else fw

            def evaluate(result):
                # Only run through the session, if one was provided.
                return sess.run(result) if sess else result

            # Try to create from an abstract class w/o default constructor.
            # Expect None.
            abstract = from_config(
                {"type": AbstractDummyComponent, "framework": fw_})
            check(abstract, None)

            # Create a Component via python API (config dict).
            python_api_config = dict(
                type=DummyComponent,
                prop_a=1.0,
                prop_d="non_default",
                framework=fw_,
            )
            component = from_config(python_api_config)
            check(component.prop_d, "non_default")

            # Create a tf Component from json file.
            json_path = str(abs_path.joinpath("dummy_config.json"))
            component = from_config(json_path, framework=fw_)
            check(component.prop_c, "default")
            check(component.prop_d, 4)  # default
            check(evaluate(component.add(3.3)), 5.3)  # prop_b == 2.0

            # Create a torch Component from yaml file.
            yaml_path = str(abs_path.joinpath("dummy_config.yml"))
            component = from_config(yaml_path, framework=fw_)
            check(component.prop_a, "something else")
            check(component.prop_d, 3)
            check(evaluate(component.add(1.2)),
                  np.array([2.2]))  # prop_b == 1.0

            # Create tf Component from json-string (e.g. on command line).
            json_string = (
                '{"type": "ray.rllib.utils.tests.'
                'test_framework_agnostic_components.DummyComponent", '
                '"prop_a": "A", "prop_b": -1.0, "prop_c": "non-default", '
                '"framework": "' + fw_ + '"}'
            )
            component = from_config(json_string)
            check(component.prop_a, "A")
            check(component.prop_d, 4)  # default
            check(evaluate(component.add(-1.1)), -2.1)  # prop_b == -1.0

            # Test recognizing default module path.
            child_json_string = (
                '{"type": "NonAbstractChildOfDummyComponent", '
                '"prop_a": "A", "prop_b": -1.0, "prop_c": "non-default",'
                '"framework": "' + fw_ + '"}'
            )
            component = from_config(DummyComponent, child_json_string)
            check(component.prop_a, "A")
            check(component.prop_d, 4)  # default
            check(evaluate(component.add(-1.1)), -2.1)  # prop_b == -1.0

            # Test recognizing default package path.
            exploration_config = {
                "type": "EpsilonGreedy",
                "action_space": Discrete(2),
                "framework": fw_,
                "num_workers": 0,
                "worker_index": 0,
                "policy_config": {},
                "model": None,
            }
            # With a session, build the object inside a tf1 variable scope.
            scope = tf1.variable_scope("exploration_object") if sess else None
            if sess:
                scope.__enter__()
            component = from_config(Exploration, exploration_config)
            if scope:
                scope.__exit__(None, None, None)
            check(component.epsilon_schedule.outside_value, 0.05)  # default

            # Create torch Component from yaml-string.
            yaml_string = (
                "type: ray.rllib.utils.tests."
                "test_framework_agnostic_components.DummyComponent\n"
                "prop_a: B\nprop_b: -1.5\nprop_c: non-default\nframework: "
                "{}".format(fw_)
            )
            component = from_config(yaml_string)
            check(component.prop_a, "B")
            check(component.prop_d, 4)  # default
            check(evaluate(component.add(-5.1)),
                  np.array([-6.6]))  # prop_b == -1.5
Ejemplo n.º 27
0
 def test_unregistered_envs(self):
     """Checks that an Env can be built from just its absolute class path."""
     env = from_config(
         "ray.rllib.examples.env.stateless_cartpole.StatelessCartPole",
         {"config": 42.0},
     )
     # The first observation after a reset must have shape (2,).
     first_obs = env.reset()
     self.assertTrue(first_obs.shape == (2,))
Ejemplo n.º 28
0
 def new_buffer():
     """Builds a new underlying buffer of the configured type.

     Uses `self.underlying_buffer_config["type"]` and `ctor_args` from the
     enclosing scope.
     """
     buffer_type = self.underlying_buffer_config["type"]
     return from_config(buffer_type, ctor_args)
Ejemplo n.º 29
0
    def _make_worker(
        self,
        *,
        cls: Callable,
        env_creator: Callable[[EnvContext], EnvType],
        validate_env: Optional[Callable[[EnvType], None]],
        policy_cls: Type[Policy],
        worker_index: int,
        num_workers: int,
        config: TrainerConfigDict,
        spaces: Optional[Dict[PolicyID, Tuple[gym.spaces.Space,
                                              gym.spaces.Space]]] = None,
    ) -> Union[RolloutWorker, ActorHandle]:
        """Builds a single worker by calling `cls` with settings from `config`.

        Args:
            cls: Callable that constructs the worker from keyword args.
            env_creator: Passed through to `cls` as `env_creator`.
            validate_env: Passed through to `cls` as `validate_env`.
            policy_cls: Policy class used as the policy spec, or used to fill
                in multiagent policies whose class entry is None.
            worker_index: Index of the worker. 0 selects the driver's extra
                python environs; it is also added to a non-None seed.
            num_workers: Passed through to `cls` as `num_workers`.
            config: The config dict from which all worker settings are read.
            spaces: Passed through to `cls` as `spaces`.

        Returns:
            Whatever `cls(...)` returns.
        """

        def session_creator():
            session_args = config["tf_session_args"]
            logger.debug("Creating TF session {}".format(session_args))
            return tf1.Session(config=tf1.ConfigProto(**session_args))

        def valid_module(class_path):
            # Only dotted strings can name a "<module>.<class>" path.
            if not (isinstance(class_path, str) and "." in class_path):
                return False
            module_path = class_path.rsplit(".", 1)[0]
            try:
                return importlib.util.find_spec(module_path) is not None
            except (ModuleNotFoundError, ValueError):
                print(
                    f"module {module_path} not found while trying to get "
                    f"input {class_path}")
            return False

        # Select the input creator depending on the kind of `config["input"]`.
        input_spec = config["input"]
        if isinstance(input_spec, FunctionType):
            input_creator = input_spec
        elif input_spec == "sampler":

            def input_creator(ioctx):
                return ioctx.default_sampler_input()
        elif isinstance(input_spec, dict):

            def input_creator(ioctx):
                return ShuffledInput(
                    MixedInput(config["input"], ioctx),
                    config["shuffle_buffer_size"])
        elif isinstance(input_spec, str) and \
                registry_contains_input(input_spec):
            input_creator = registry_get_input(input_spec)
        elif "d4rl" in input_spec:
            # The last dot-separated part of the input is the env name.
            env_name = input_spec.split(".")[-1]

            def input_creator(ioctx):
                return D4RLReader(env_name, ioctx)
        elif valid_module(input_spec):

            def input_creator(ioctx):
                return ShuffledInput(
                    from_config(config["input"], ioctx=ioctx))
        else:

            def input_creator(ioctx):
                return ShuffledInput(
                    JsonReader(config["input"], ioctx),
                    config["shuffle_buffer_size"])

        # Select the output creator depending on `config["output"]`.
        output_spec = config["output"]
        if isinstance(output_spec, FunctionType):
            output_creator = output_spec
        elif output_spec is None:

            def output_creator(ioctx):
                return NoopOutput()
        elif output_spec == "logdir":

            def output_creator(ioctx):
                return JsonWriter(
                    ioctx.log_dir,
                    ioctx,
                    max_file_size=config["output_max_file_size"],
                    compress_columns=config["output_compress_columns"])
        else:

            def output_creator(ioctx):
                return JsonWriter(
                    config["output"],
                    ioctx,
                    max_file_size=config["output_max_file_size"],
                    compress_columns=config["output_compress_columns"])

        # No input evaluation when sampling from the env.
        input_evaluation = ([] if config["input"] == "sampler" else
                            config["input_evaluation"])

        multiagent = config["multiagent"]

        # Fill in the default policy_cls if 'None' is specified in multiagent.
        if multiagent["policies"]:
            policy_spec = multiagent["policies"]
            _validate_multiagent_config(policy_spec, allow_none_graph=True)
            # TODO: (sven) Allow for setting observation and action spaces to
            #  None as well, in which case, spaces are taken from env.
            #  It's tedious to have to provide these in a multi-agent config.
            for policy_id, spec in policy_spec.items():
                if spec[0] is None:
                    policy_spec[policy_id] = (policy_cls, spec[1], spec[2],
                                              spec[3])
        # Otherwise, policy spec is simply the policy class itself.
        else:
            policy_spec = policy_cls

        # Driver (index 0) and workers read different environs keys.
        environs_key = ("extra_python_environs_for_driver"
                        if worker_index == 0 else
                        "extra_python_environs_for_worker")
        extra_python_environs = config.get(environs_key, None)

        worker = cls(
            env_creator=env_creator,
            validate_env=validate_env,
            policy_spec=policy_spec,
            policy_mapping_fn=multiagent["policy_mapping_fn"],
            policies_to_train=multiagent["policies_to_train"],
            tf_session_creator=(session_creator
                                if config["tf_session_args"] else None),
            rollout_fragment_length=config["rollout_fragment_length"],
            count_steps_by=multiagent["count_steps_by"],
            batch_mode=config["batch_mode"],
            episode_horizon=config["horizon"],
            preprocessor_pref=config["preprocessor_pref"],
            sample_async=config["sample_async"],
            compress_observations=config["compress_observations"],
            num_envs=config["num_envs_per_worker"],
            observation_fn=multiagent["observation_fn"],
            observation_filter=config["observation_filter"],
            clip_rewards=config["clip_rewards"],
            normalize_actions=config["normalize_actions"],
            clip_actions=config["clip_actions"],
            env_config=config["env_config"],
            model_config=config["model"],
            policy_config=config,
            worker_index=worker_index,
            num_workers=num_workers,
            record_env=config["record_env"],
            log_dir=self._logdir,
            log_level=config["log_level"],
            callbacks=config["callbacks"],
            input_creator=input_creator,
            input_evaluation=input_evaluation,
            output_creator=output_creator,
            remote_worker_envs=config["remote_worker_envs"],
            remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"],
            soft_horizon=config["soft_horizon"],
            no_done_at_end=config["no_done_at_end"],
            # Offset a given seed by the worker index.
            seed=(None if config["seed"] is None else
                  config["seed"] + worker_index),
            fake_sampler=config["fake_sampler"],
            extra_python_environs=extra_python_environs,
            spaces=spaces,
        )

        return worker
Ejemplo n.º 30
0
    def _make_worker(
            self,
            *,
            cls: Callable,
            env_creator: Callable[[EnvContext], EnvType],
            validate_env: Optional[Callable[[EnvType], None]],
            policy_cls: Type[Policy],
            worker_index: int,
            num_workers: int,
            config: TrainerConfigDict,
            spaces: Optional[Dict[PolicyID, Tuple[gym.spaces.Space,
                                                  gym.spaces.Space]]] = None,
    ) -> Union[RolloutWorker, ActorHandle]:
        def session_creator():
            logger.debug("Creating TF session {}".format(
                config["tf_session_args"]))
            return tf1.Session(
                config=tf1.ConfigProto(**config["tf_session_args"]))

        def valid_module(class_path):
            if isinstance(class_path, str) and "." in class_path:
                module_path, class_name = class_path.rsplit(".", 1)
                try:
                    spec = importlib.util.find_spec(module_path)
                    if spec is not None:
                        return True
                except (ModuleNotFoundError, ValueError):
                    print(
                        f"module {module_path} not found while trying to get "
                        f"input {class_path}")
            return False

        if isinstance(config["input"], FunctionType):
            input_creator = config["input"]
        elif config["input"] == "sampler":
            input_creator = (lambda ioctx: ioctx.default_sampler_input())
        elif isinstance(config["input"], dict):
            input_creator = (
                lambda ioctx: ShuffledInput(MixedInput(config["input"], ioctx),
                                            config["shuffle_buffer_size"]))
        elif isinstance(config["input"], str) and \
                registry_contains_input(config["input"]):
            input_creator = registry_get_input(config["input"])
        elif "d4rl" in config["input"]:
            env_name = config["input"].split(".")[-1]
            input_creator = (lambda ioctx: D4RLReader(env_name, ioctx))
        elif valid_module(config["input"]):
            input_creator = (lambda ioctx: ShuffledInput(from_config(
                config["input"], ioctx=ioctx)))
        else:
            input_creator = (
                lambda ioctx: ShuffledInput(JsonReader(config["input"], ioctx),
                                            config["shuffle_buffer_size"]))

        if isinstance(config["output"], FunctionType):
            output_creator = config["output"]
        elif config["output"] is None:
            output_creator = (lambda ioctx: NoopOutput())
        elif config["output"] == "logdir":
            output_creator = (lambda ioctx: JsonWriter(
                ioctx.log_dir,
                ioctx,
                max_file_size=config["output_max_file_size"],
                compress_columns=config["output_compress_columns"]))
        else:
            output_creator = (lambda ioctx: JsonWriter(
                config["output"],
                ioctx,
                max_file_size=config["output_max_file_size"],
                compress_columns=config["output_compress_columns"]))

        if config["input"] == "sampler":
            input_evaluation = []
        else:
            input_evaluation = config["input_evaluation"]

        # Assert everything is correct in "multiagent" config dict (if given).
        ma_policies = config["multiagent"]["policies"]
        if ma_policies:
            for pid, policy_spec in ma_policies.copy().items():
                assert isinstance(policy_spec, (PolicySpec, list, tuple))
                # Class is None -> Use `policy_cls`.
                if policy_spec.policy_class is None:
                    ma_policies[pid] = ma_policies[pid]._replace(
                        policy_class=policy_cls)
            policies = ma_policies

        # Create a policy_spec (MultiAgentPolicyConfigDict),
        # even if no "multiagent" setup given by user.
        else:
            policies = policy_cls

        if worker_index == 0:
            extra_python_environs = config.get(
                "extra_python_environs_for_driver", None)
        else:
            extra_python_environs = config.get(
                "extra_python_environs_for_worker", None)

        worker = cls(
            env_creator=env_creator,
            validate_env=validate_env,
            policy_spec=policies,
            policy_mapping_fn=config["multiagent"]["policy_mapping_fn"],
            policies_to_train=config["multiagent"]["policies_to_train"],
            tf_session_creator=(session_creator
                                if config["tf_session_args"] else None),
            rollout_fragment_length=config["rollout_fragment_length"],
            count_steps_by=config["multiagent"]["count_steps_by"],
            batch_mode=config["batch_mode"],
            episode_horizon=config["horizon"],
            preprocessor_pref=config["preprocessor_pref"],
            sample_async=config["sample_async"],
            compress_observations=config["compress_observations"],
            num_envs=config["num_envs_per_worker"],
            observation_fn=config["multiagent"]["observation_fn"],
            observation_filter=config["observation_filter"],
            clip_rewards=config["clip_rewards"],
            normalize_actions=config["normalize_actions"],
            clip_actions=config["clip_actions"],
            env_config=config["env_config"],
            policy_config=config,
            worker_index=worker_index,
            num_workers=num_workers,
            record_env=config["record_env"],
            log_dir=self._logdir,
            log_level=config["log_level"],
            callbacks=config["callbacks"],
            input_creator=input_creator,
            input_evaluation=input_evaluation,
            output_creator=output_creator,
            remote_worker_envs=config["remote_worker_envs"],
            remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"],
            soft_horizon=config["soft_horizon"],
            no_done_at_end=config["no_done_at_end"],
            seed=(config["seed"] + worker_index)
            if config["seed"] is not None else None,
            fake_sampler=config["fake_sampler"],
            extra_python_environs=extra_python_environs,
            spaces=spaces,
        )

        return worker