Example #1
0
    def __init__(
        self,
        policy_map: Dict[PolicyID, Policy],
        clip_rewards: bool,
        callbacks: "DefaultCallbacks",
    ):
        """Create an empty MultiAgentSampleBatchBuilder (deprecated class).

        Args:
            policy_map (Dict[str,Policy]): Maps policy ids to policy
                instances. One SampleBatchBuilder is created per policy id.
            clip_rewards (Union[bool,float]): Whether to clip rewards before
                postprocessing (at +/-1.0) or the actual value to +/- clip.
                The value is only stored here.
            callbacks (DefaultCallbacks): RLlib callbacks. Only stored here.
        """
        # Announce the deprecation; the notice is gated by `log_once`.
        if log_once("MultiAgentSampleBatchBuilder"):
            deprecation_warning(old="MultiAgentSampleBatchBuilder", error=False)

        # Plain references to the constructor arguments.
        self.policy_map = policy_map
        self.clip_rewards = clip_rewards
        self.callbacks = callbacks

        # One SampleBatchBuilder per policy known at construction time.
        self.policy_builders = {
            policy_id: SampleBatchBuilder() for policy_id in policy_map.keys()
        }
        # Per-agent SampleBatchBuilders; empty until agents are observed.
        self.agent_builders = {}
        # Internal agent-to-policy map; starts out empty as well.
        self.agent_to_policy = {}

        # Number of "inference" steps taken in the environment, regardless
        # of the number of agents involved in each of these steps.
        self.count = 0
Example #2
0
    def get_model(input_dict,
                  obs_space,
                  action_space,
                  num_outputs,
                  options,
                  state_in=None,
                  seq_lens=None):
        """Deprecated: Use get_model_v2() instead.

        Builds a model through `ModelCatalog._get_model()` and, when
        `options["use_lstm"]` is truthy, feeds that model's last layer into
        an `LSTM`, which then becomes the returned model.

        Args:
            input_dict (dict): Model inputs; must be a dict.
            obs_space: Observation space handed to the model constructor.
            action_space: Action space handed to the model constructor.
            num_outputs: Output size handed to the model constructor.
            options (dict): Model options; MODEL_DEFAULTS is used if falsy.
            state_in: Optional state inputs, passed through unchanged.
            seq_lens: Optional sequence lengths, passed through unchanged.

        Returns:
            The created model, after its `_validate_output_shape()` ran.
        """
        deprecation_warning("get_model", "get_model_v2", error=False)
        assert isinstance(input_dict, dict)

        # Fall back to the default model options if none were given.
        if not options:
            options = MODEL_DEFAULTS

        model = ModelCatalog._get_model(
            input_dict, obs_space, action_space, num_outputs, options,
            state_in, seq_lens)

        if options.get("use_lstm"):
            # The LSTM consumes the base model's last layer as its "obs".
            lstm_inputs = dict(input_dict)
            lstm_inputs["obs"] = model.last_layer
            feature_dim = model.last_layer.shape[1]
            feature_space = gym.spaces.Box(-1, 1, shape=(feature_dim, ))
            model = LSTM(lstm_inputs, feature_space, action_space,
                         num_outputs, options, state_in, seq_lens)

        summary = "Created model {}: ({} of {}, {}, {}, {}) -> {}, {}".format(
            model, input_dict, obs_space, action_space, state_in, seq_lens,
            model.outputs, model.state_out)
        logger.debug(summary)

        model._validate_output_shape()
        return model
Example #3
0
def setup_config(policy: Policy, obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space,
                 config: TrainerConfigDict) -> None:
    """Executed before Policy is "initialized" (at beginning of constructor).

    Moves a deprecated top-level `vf_share_layers` setting into the model
    sub-config and logs a reminder about `vf_loss_coeff` if layers are shared.

    Args:
        policy (Policy): The Policy object (not used here).
        obs_space (gym.spaces.Space): The Policy's observation space
            (not used here).
        action_space (gym.spaces.Space): The Policy's action space
            (not used here).
        config (TrainerConfigDict): The Policy's config. May be modified
            in place (`config["model"]["vf_share_layers"]`).
    """
    # A top-level `vf_share_layers` is deprecated: users who (correctly) set
    # it in their model config would otherwise not notice that it gets
    # overwritten by the top-level value here.
    top_level_setting = config["vf_share_layers"]
    if top_level_setting != DEPRECATED_VALUE:
        deprecation_warning(
            old="config[vf_share_layers]",
            new="config[model][vf_share_layers]",
            error=False,
        )
        config["model"]["vf_share_layers"] = top_level_setting

    # Shared value-function layers: remind the user to tune vf_loss_coeff.
    model_config = config.get("model", {})
    if model_config.get("vf_share_layers") is True:
        logger.info(
            "`vf_share_layers=True` in your model. "
            "Therefore, remember to tune the value of `vf_loss_coeff`!")
Example #4
0
    def __getitem__(self, key: Union[str, slice]) -> TensorType:
        """Returns one column (by key) from the data or a sliced new batch.

        Args:
            key (Union[str, slice]): The key (column name) to return or
                a slice object for slicing this SampleBatch.

        Returns:
            TensorType: The data under the given key (passed through the
                get-interceptor, if one is set) or a sliced version of
                this batch.
        """
        # Slice objects are delegated to `_slice`.
        if isinstance(key, slice):
            return self._slice(key)

        # Track accesses to stored columns whose name is not also an
        # attribute of this object.
        if not hasattr(self, key):
            if key in self:
                self.accessed_keys.add(key)

        # Backward compatibility for when "input-dicts" were used.
        if key == "is_training":
            if log_once("SampleBatch['is_training']"):
                deprecation_warning(old="SampleBatch['is_training']",
                                    new="SampleBatch.is_training",
                                    error=False)
            return self.is_training

        column = dict.__getitem__(self, key)
        interceptor = self.get_interceptor
        if interceptor is None:
            return column

        # Apply the interceptor once per key and cache its result.
        cache = self.intercepted_values
        if key not in cache:
            cache[key] = interceptor(column)
        return cache[key]
Example #5
0
    def __init__(self,
                 capacity: int = 10000,
                 size: Optional[int] = DEPRECATED_VALUE):
        """Initializes a Replaybuffer instance.

        Args:
            capacity (int): Max number of timesteps to store in the FIFO
                buffer. After reaching this number, older samples will be
                dropped to make space for new ones.
            size (Optional[int]): Deprecated alias for `capacity`. If set to
                anything other than DEPRECATED_VALUE, it overrides
                `capacity`.
        """
        # Deprecated `size` arg: takes precedence over `capacity` if given.
        size_was_given = size != DEPRECATED_VALUE
        if size_was_given:
            deprecation_warning(
                "ReplayBuffer(size)", "ReplayBuffer(capacity)", error=False)
            capacity = size
        self.capacity = capacity

        # The actual storage (list of SampleBatches).
        self._storage = []
        # The next index to override in the buffer.
        self._next_idx = 0
        # One counter slot per storage index.
        self._hit_count = np.zeros(self.capacity)
        # Whether we have already hit our capacity (and have therefore
        # started to evict older samples).
        self._eviction_started = False

        # Bookkeeping counters and statistics, all starting at zero.
        self._num_timesteps_added = 0
        self._num_timesteps_added_wrap = 0
        self._num_timesteps_sampled = 0
        self._est_size_bytes = 0
        self._evicted_hit_stats = WindowStat("evicted_hit", 1000)
Example #6
0
    def __getitem__(self, key: str) -> TensorType:
        """Returns one column (by key) from the data.

        The special keys "is_training" and "seq_lens" are served from the
        attributes of the same name instead of the underlying dict.

        Args:
            key (str): The key (column name) to return.

        Returns:
            TensorType: The data under the given key (passed through the
                get-interceptor, if one is set).
        """
        # Record the access unless `key` names an attribute of this object.
        if not hasattr(self, key):
            self.accessed_keys.add(key)

        # Backward compatibility for when "input-dicts" were used.
        if key == "is_training":
            if log_once("SampleBatch['is_training']"):
                deprecation_warning(
                    old="SampleBatch['is_training']",
                    new="SampleBatch.is_training",
                    error=False)
            return self.is_training

        if key == "seq_lens":
            # Without an interceptor (or without seq_lens), hand out the
            # raw attribute.
            if self.get_interceptor is None or self.seq_lens is None:
                return self.seq_lens
            # Otherwise intercept once and serve the cached result.
            if "seq_lens" not in self.intercepted_values:
                self.intercepted_values["seq_lens"] = self.get_interceptor(
                    self.seq_lens)
            return self.intercepted_values["seq_lens"]

        column = dict.__getitem__(self, key)
        if self.get_interceptor is None:
            return column

        # Apply the interceptor once per key and cache its result.
        if key not in self.intercepted_values:
            self.intercepted_values[key] = self.get_interceptor(column)
        return self.intercepted_values[key]
Example #7
0
 def __init__(self, legacy_callbacks_dict: Dict[str, callable] = None):
     """Store an optional legacy callbacks dict.

     Args:
         legacy_callbacks_dict: Deprecated dict-style callbacks. If it is
             non-empty, a deprecation notice is issued before storing it;
             otherwise an empty dict is stored.
     """
     if legacy_callbacks_dict:
         deprecation_warning(
             old="callbacks dict interface",
             new="a class extending rllib.agents.callbacks.DefaultCallbacks",
         )
         self.legacy_callbacks = legacy_callbacks_dict
     else:
         self.legacy_callbacks = {}
Example #8
0
    def __setitem__(self, key, item) -> None:
        """Inserts (overrides) an entire column (by key) in the data buffer.

        Args:
            key: The column name to set a value for. The special key
                "is_training" is stored in `self._is_training` instead of
                the underlying dict.
            item: The data to insert.
        """
        # Defend against creating SampleBatch via pickle (no property
        # `added_keys` and first item is already set): plain dict write.
        if not hasattr(self, "added_keys"):
            dict.__setitem__(self, key, item)
            return

        # Backward compatibility for when "input-dicts" were used.
        if key == "is_training":
            if log_once("SampleBatch['is_training']"):
                deprecation_warning(
                    old="SampleBatch['is_training']",
                    new="SampleBatch.is_training",
                    error=False,
                )
            self._is_training = item
            return

        # Remember columns that did not exist before this write.
        is_new_column = key not in self
        if is_new_column:
            self.added_keys.add(key)

        dict.__setitem__(self, key, item)

        # Keep an already cached intercepted value in sync with the new data.
        cache = self.intercepted_values
        if key in cache:
            cache[key] = item
Example #9
0
    def export_checkpoint(self, export_dir: str) -> None:
        """Deprecated in favor of `save`; not implemented in this class.

        Args:
            export_dir (str): Local writable directory (unused here).

        Raises:
            NotImplementedError: Raised after the deprecation notice has
                been issued.
        """
        deprecation_warning(old="export_checkpoint", new="save")
        raise NotImplementedError
Example #10
0
 def __init__(self):
     """Create an empty (deprecated) SampleBatchBuilder."""
     # Issue the deprecation notice, gated by `log_once`.
     if log_once("SampleBatchBuilder"):
         deprecation_warning(
             old="SampleBatchBuilder",
             new="child class of `SampleCollector`",
             error=False,
         )
     # Counter, starting at zero.
     self.count = 0
     # Column name -> list of collected values (lists are created lazily).
     self.buffers: Dict[str, List] = collections.defaultdict(list)
Example #11
0
 def patched_init(*args, **kwargs):
     """Issue the deprecation notice (gated by `log_once`), then run `obj_init`.

     All positional and keyword arguments are forwarded unchanged and the
     result of `obj_init` is returned.
     """
     # Fall back to the wrapped object's own name if no `old` was given.
     deprecated_name = old or obj.__name__
     if log_once(deprecated_name):
         deprecation_warning(
             old=deprecated_name,
             new=new,
             help=help,
             error=error,
         )
     return obj_init(*args, **kwargs)
Example #12
0
    def _build_layers_v2(self, input_dict, num_outputs, options):
        """Builds the LSTM layers on top of the given input features.

        Args:
            input_dict: Dict of input tensors. "obs" is always read;
                "prev_actions" and "prev_rewards" are read if
                `options["lstm_use_prev_action_reward"]` is truthy.
            num_outputs: Size of the final linear ("action") layer.
            options: Model options; reads "lstm_cell_size" and
                "lstm_use_prev_action_reward".

        Returns:
            Tuple of (logits, last_layer), where last_layer is the LSTM
            output reshaped to [-1, cell_size].
        """
        # Deprecation notice for this class (issued with error=False).
        # All Models should use the ModelV2 API from here on.
        deprecation_warning("Model->LSTM", "RecurrentNetwork", error=False)

        cell_size = options.get("lstm_cell_size")
        if options.get("lstm_use_prev_action_reward"):
            # Flattened size of one action: product of all prev_actions
            # dims after the first one.
            action_dim = int(
                np.product(
                    input_dict["prev_actions"].get_shape().as_list()[1:]))
            # Concat obs, flattened (float32) previous actions and previous
            # rewards (as a single column) along axis 1.
            features = tf.concat(
                [
                    input_dict["obs"],
                    tf.reshape(
                        tf.cast(input_dict["prev_actions"], tf.float32),
                        [-1, action_dim]),
                    tf.reshape(input_dict["prev_rewards"], [-1, 1]),
                ],
                axis=1)
        else:
            features = input_dict["obs"]
        # NOTE(review): presumably reshapes the flat features into
        # per-sequence form using seq_lens - confirm in add_time_dimension.
        last_layer = add_time_dimension(features, self.seq_lens)

        # Setup the LSTM cell
        lstm = tf1.nn.rnn_cell.LSTMCell(cell_size, state_is_tuple=True)
        # Initial state: zero vectors for the cell's c and h state.
        self.state_init = [
            np.zeros(lstm.state_size.c, np.float32),
            np.zeros(lstm.state_size.h, np.float32)
        ]

        # Setup LSTM inputs
        if self.state_in:
            c_in, h_in = self.state_in
        else:
            # No state inputs were provided: create placeholders for them.
            c_in = tf1.placeholder(
                tf.float32, [None, lstm.state_size.c], name="c")
            h_in = tf1.placeholder(
                tf.float32, [None, lstm.state_size.h], name="h")
            self.state_in = [c_in, h_in]

        # Setup LSTM outputs
        state_in = tf1.nn.rnn_cell.LSTMStateTuple(c_in, h_in)
        lstm_out, lstm_state = tf1.nn.dynamic_rnn(
            lstm,
            last_layer,
            initial_state=state_in,
            sequence_length=self.seq_lens,
            time_major=False,
            dtype=tf.float32)

        self.state_out = list(lstm_state)

        # Compute outputs
        last_layer = tf.reshape(lstm_out, [-1, cell_size])
        logits = linear(last_layer, num_outputs, "action",
                        normc_initializer(0.01))
        return logits, last_layer
Example #13
0
    def __init__(self,
                 input_dict,
                 obs_space,
                 action_space,
                 num_outputs,
                 options,
                 state_in=None,
                 seq_lens=None):
        """Builds the model's layers and stores its inputs and outputs.

        Tries `_build_layers_v2` first and falls back to `_build_layers`
        if that raises NotImplementedError.

        Args:
            input_dict: Dict of model inputs; must be a dict and contain
                the key "obs".
            obs_space: Observation space, used to restore the original
                dimensions of `input_dict["obs"]`.
            action_space: Action space (only stored here).
            num_outputs: Number of model outputs. Must be even if
                `options["free_log_std"]` is set.
            options: Model options; "free_log_std" is read here.
            state_in: Optional list of state inputs (defaults to []).
            seq_lens: Optional sequence lengths tensor. If None, an int32
                placeholder named "seq_lens" is created.
        """
        # Soft-deprecate this class. All Models should use the ModelV2
        # API from here on.
        deprecation_warning("Model", "ModelV2", error=False)
        assert isinstance(input_dict, dict), input_dict

        # Default attribute values for the non-RNN case
        self.state_init = []
        self.state_in = state_in or []
        self.state_out = []
        self.obs_space = obs_space
        self.action_space = action_space
        self.num_outputs = num_outputs
        self.options = options
        self.scope = tf.get_variable_scope()
        self.session = tf.get_default_session()
        self.input_dict = input_dict
        if seq_lens is not None:
            self.seq_lens = seq_lens
        else:
            self.seq_lens = tf.placeholder(dtype=tf.int32,
                                           shape=[None],
                                           name="seq_lens")

        # Keep the originally requested output size; with free_log_std,
        # the layers themselves are built with only half of it.
        self._num_outputs = num_outputs
        if options.get("free_log_std"):
            assert num_outputs % 2 == 0
            num_outputs = num_outputs // 2

        ok = True
        try:
            # Work on a copy so the caller's input_dict keeps its own "obs".
            restored = input_dict.copy()
            restored["obs"] = restore_original_dimensions(
                input_dict["obs"], obs_space)
            self.outputs, self.last_layer = self._build_layers_v2(
                restored, num_outputs, options)
        except NotImplementedError:
            ok = False
        # In TF 1.14, you cannot construct variable scopes in exception
        # handlers so we have to set the OK flag and check it here:
        if not ok:
            self.outputs, self.last_layer = self._build_layers(
                input_dict["obs"], num_outputs, options)

        if options.get("free_log_std", False):
            # Append a log_std variable (initialized to zeros) to the
            # outputs along axis 1; `0.0 * self.outputs + log_std` yields a
            # tensor with the outputs' shape.
            log_std = tf.get_variable(name="log_std",
                                      shape=[num_outputs],
                                      initializer=tf.zeros_initializer)
            self.outputs = tf.concat(
                [self.outputs, 0.0 * self.outputs + log_std], 1)
Example #14
0
 def _ctor(*args, **kwargs):
     """Issue the deprecation notice (gated by `log_once`), then call `obj`.

     All positional and keyword arguments are forwarded unchanged and the
     result of the wrapped callable is returned.
     """
     # Fall back to the wrapped object's own name if no `old` was given.
     deprecated_name = old or obj.__name__
     if log_once(deprecated_name):
         deprecation_warning(
             old=deprecated_name,
             new=new,
             help=help,
             error=error,
         )
     # Delegate to the deprecated method/function.
     result = obj(*args, **kwargs)
     return result
Example #15
0
    def _validate_config(
        config: ModelConfigDict, action_space: gym.spaces.Space, framework: str
    ) -> None:
        """Checks a model config dict for unsupported setting combinations.

        Args:
            config: The "model" sub-config dict
                within the Trainer's config dict.
            action_space: The action space of the model whose config is
                validated.
            framework: One of "jax", "tf2", "tf", "tfe", or "torch".

        Raises:
            ValueError: If something is wrong with the given config.
        """
        # Soft-deprecate custom preprocessors.
        if config.get("custom_preprocessor") is not None:
            deprecation_warning(
                old="model.custom_preprocessor",
                new="gym.ObservationWrapper around your env or handle complex "
                "inputs inside your Model",
                error=False,
            )

        wants_attention = config.get("use_attention")
        wants_lstm = config.get("use_lstm")

        # LSTM and attention wrapping are mutually exclusive.
        if wants_attention and wants_lstm:
            raise ValueError(
                "Only one of `use_lstm` or `use_attention` may be set to True!"
            )

        # For complex action spaces, only allow prev action inputs to
        # LSTMs and attention nets iff `_disable_action_flattening=True`.
        # TODO: `_disable_action_flattening=True` will be the default in
        #  the future.
        feeds_prev_actions = (
            config.get("lstm_use_prev_action")
            or config.get("attention_use_n_prev_actions", 0) > 0
        )
        if (
            feeds_prev_actions
            and not config.get("_disable_action_flattening")
            and isinstance(action_space, (Tuple, Dict))
        ):
            raise ValueError(
                "For your complex action space (Tuple|Dict) and your model's "
                "`prev-actions` setup of your model, you must set "
                "`_disable_action_flattening=True` in your main config dict!"
            )

        # The remaining checks only concern the JAX framework.
        if framework != "jax":
            return
        if wants_attention:
            raise ValueError(
                "`use_attention` not available for framework=jax so far!"
            )
        if wants_lstm:
            raise ValueError("`use_lstm` not available for framework=jax so far!")
Example #16
0
    def timeslices(
        self,
        size: Optional[int] = None,
        num_slices: Optional[int] = None,
        k: Optional[int] = None,
    ) -> List["SampleBatch"]:
        """Returns SampleBatches, each one representing a slice of this one.

        Slicing starts at timestep 0. If `size` is given (or derived from the
        deprecated `k`), consecutive slices of `size` timesteps are produced;
        the last slice is shorter if `len(self)` is not a multiple of `size`.
        Otherwise, the batch is split into `num_slices` slices whose lengths
        differ by at most one timestep. `size` takes precedence if both
        `size` and `num_slices` are given.

        Args:
            size: The size (in timesteps) of each returned SampleBatch.
            num_slices: The number of slices to produce.
            k: Deprecated: Use size or num_slices instead. The size
                (in timesteps) of each returned SampleBatch.

        Returns:
            The list of `num_slices` (new) SampleBatches or n (new)
            SampleBatches each one of size `size` (the last one possibly
            being shorter). An empty list if this batch is empty.

        Raises:
            ValueError: If this batch is non-empty and the requested `size`
                or `num_slices` is not a positive number.
        """
        if size is None and num_slices is None:
            deprecation_warning("k", "size or num_slices")
            assert k is not None
            size = k

        total = len(self)

        if size is None:
            assert isinstance(num_slices, int)
            if total > 0 and num_slices <= 0:
                raise ValueError(
                    "`num_slices` must be > 0, got {}!".format(num_slices))

            slices = []
            left = total
            start = 0
            # Spread the remaining timesteps evenly over the remaining
            # slices; the final slice always takes whatever is left.
            while left > 0:
                len_ = left // (num_slices - len(slices))
                stop = start + len_
                slices.append(self[start:stop])
                left -= len_
                start = stop

            return slices

        assert isinstance(size, int)
        if total == 0:
            return []
        if size <= 0:
            raise ValueError("`size` must be > 0, got {}!".format(size))

        # Iterate over start offsets rather than counting `left` down to
        # exactly 0: the old countdown never terminated when `total` was not
        # a multiple of `size`. The stop index is clamped so that the last
        # slice simply ends at the end of the batch.
        return [
            self[start:min(start + size, total)]
            for start in range(0, total, size)
        ]
Example #17
0
def get_activation_fn(name: Optional[str] = None, framework: str = "tf"):
    """Looks up a framework-specific activation function by name.

    Deprecated location: the function now lives in
    rllib/models/utils.py::get_activation_fn.

    Args:
        name (Optional[str]): Name of the activation, e.g. "relu", "tanh",
            "swish", or "linear"/None for no activation. Torch also accepts
            "silu"; for tf, any attribute name of `tf.nn` is accepted.
        framework (str): "torch" or "jax"; any other value is handled as
            tf.

    Returns:
        A framework-specific activation function, e.g. tf.nn.tanh or
            torch.nn.ReLU. None if name in ["linear", None].

    Raises:
        ValueError: If name is an unknown activation function.
    """
    deprecation_warning(
        "rllib/utils/framework.py::get_activation_fn",
        "rllib/models/utils.py::get_activation_fn",
        error=False)

    # "linear"/None means "no activation" for every framework.
    if name in ["linear", None]:
        return None

    if framework == "torch":
        if name in ["swish", "silu"]:
            from ray.rllib.utils.torch_ops import Swish
            return Swish
        _, nn = try_import_torch()
        if name == "relu":
            return nn.ReLU
        if name == "tanh":
            return nn.Tanh
    elif framework == "jax":
        jax, flax = try_import_jax()
        if name == "swish":
            return jax.nn.swish
        if name == "relu":
            return jax.nn.relu
        if name == "tanh":
            # NOTE(review): "tanh" maps to `hard_tanh` here - confirm this
            # is intended.
            return jax.nn.hard_tanh
    else:
        # "swish" is looked up under the name "silu" in tf.nn.
        if name == "swish":
            name = "silu"
        tf1, tf, tfv = try_import_tf()
        activation = getattr(tf.nn, name, None)
        if activation is not None:
            return activation

    # Nothing matched for the requested framework.
    raise ValueError("Unknown activation ({}) for framework={}!".format(
        name, framework))
Example #18
0
    def register_custom_model(model_name: str, model_class: type) -> None:
        """Register a custom model class by name.

        The model can be later used by specifying {"custom_model": model_name}
        in the model config.

        A deprecation notice is issued if `model_class` is a subclass of
        `tf.keras.Model`; the class is registered in either case.

        Args:
            model_name (str): Name to register the model under.
            model_class (type): Python class of the model.
        """
        # Guard against `tf` being None (presumably the case when TensorFlow
        # could not be imported): registering e.g. a torch model must not
        # fail with an AttributeError on `tf.keras`.
        if tf is not None and issubclass(model_class, tf.keras.Model):
            deprecation_warning(old="register_custom_model", error=False)
        _global_registry.register(RLLIB_MODEL, model_name, model_class)
Example #19
0
 def __init__(self, legacy_callbacks_dict: Dict[str, callable] = None):
     """Store an optional legacy callbacks dict and check for old hooks.

     Args:
         legacy_callbacks_dict: Deprecated dict-style callbacks. If it is
             non-empty, a deprecation notice is issued before storing it;
             otherwise an empty dict is stored.
     """
     if legacy_callbacks_dict:
         deprecation_warning(
             old="callbacks dict interface",
             new="a class extending rllib.algorithms.callbacks.DefaultCallbacks",
         )
         self.legacy_callbacks = legacy_callbacks_dict
     else:
         self.legacy_callbacks = {}

     # Subclasses overriding the old `on_trainer_init` hook get a notice
     # pointing at its replacement.
     overrides_old_hook = is_overridden(self.on_trainer_init)
     if overrides_old_hook:
         deprecation_warning(
             old="on_trainer_init(trainer, **kwargs)",
             new="on_algorithm_init(algorithm, **kwargs)",
             error=False,
         )
Example #20
0
def postprocess_advantages(policy,
                           sample_batch,
                           other_agent_batches=None,
                           episode=None):
    """Deprecated stub forwarding to `compute_gae_for_sample_batch`.

    Args:
        policy: Forwarded unchanged.
        sample_batch: Forwarded unchanged.
        other_agent_batches: Forwarded unchanged (default: None).
        episode: Forwarded unchanged (default: None).

    Returns:
        Whatever `compute_gae_for_sample_batch` returns for these arguments.
    """
    # Kept for backward compatibility only.
    deprecation_warning(
        old="rllib.agents.a3c.a3c_tf_policy.postprocess_advantages",
        new="rllib.evaluation.postprocessing.compute_gae_for_sample_batch",
        error=False,
    )

    postprocessed = compute_gae_for_sample_batch(
        policy, sample_batch, other_agent_batches, episode)
    return postprocessed
    def _build_layers_v2(self, input_dict, num_outputs, options):
        """Builds the convolutional layers of the vision network.

        Args:
            input_dict: Dict of input tensors; only "obs" is read.
            num_outputs: Number of output channels of the final layer.
            options: Model options; reads "conv_filters", "conv_activation"
                and "no_final_linear".

        Returns:
            Tuple of (flattened outputs, flattened last hidden layer). With
            `options["no_final_linear"]` set, both entries are the flattened
            output of the single final conv layer.
        """
        # Deprecation notice for this class (issued with error=False).
        # All Models should use the ModelV2 API from here on.
        deprecation_warning(
            "Model->VisionNetwork", "ModelV2->VisionNetwork", error=False)
        inputs = input_dict["obs"]
        filters = options.get("conv_filters")
        if not filters:
            # No filters configured: derive them from the obs shape (all
            # dims after the first one).
            filters = _get_filter_config(inputs.shape.as_list()[1:])

        activation = get_activation_fn(options.get("conv_activation"))

        with tf.name_scope("vision_net"):
            # All but the last filter spec become "same"-padded conv layers
            # named conv1, conv2, ...
            for i, (out_size, kernel, stride) in enumerate(filters[:-1], 1):
                inputs = tf.layers.conv2d(
                    inputs,
                    out_size,
                    kernel,
                    stride,
                    activation=activation,
                    padding="same",
                    name="conv{}".format(i))
            # The last filter spec is used with "valid" padding below.
            out_size, kernel, stride = filters[-1]

            # skip final linear layer
            if options.get("no_final_linear"):
                fc_out = tf.layers.conv2d(
                    inputs,
                    num_outputs,
                    kernel,
                    stride,
                    activation=activation,
                    padding="valid",
                    name="fc_out")
                return flatten(fc_out), flatten(fc_out)

            fc1 = tf.layers.conv2d(
                inputs,
                out_size,
                kernel,
                stride,
                activation=activation,
                padding="valid",
                name="fc1")
            # 1x1 conv without activation, mapping fc1 to num_outputs
            # channels.
            fc2 = tf.layers.conv2d(
                fc1,
                num_outputs, [1, 1],
                activation=None,
                padding="same",
                name="fc2")
            return flatten(fc2), flatten(fc1)
Example #22
0
def add_advantages(policy: Policy,
                   sample_batch: SampleBatch,
                   other_agent_batches: Optional[Dict[PolicyID,
                                                      SampleBatch]] = None,
                   episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
    """Deprecated stub forwarding to `compute_gae_for_sample_batch`.

    Args:
        policy: Forwarded unchanged.
        sample_batch: Forwarded unchanged.
        other_agent_batches: Forwarded unchanged (default: None).
        episode: Forwarded unchanged (default: None).

    Returns:
        Whatever `compute_gae_for_sample_batch` returns for these arguments.
    """
    # Kept for backward compatibility only.
    deprecation_warning(
        old="rllib.agents.a3c.a3c_torch_policy.add_advantages",
        new="rllib.evaluation.postprocessing.compute_gae_for_sample_batch",
        error=False,
    )

    postprocessed = compute_gae_for_sample_batch(
        policy, sample_batch, other_agent_batches, episode)
    return postprocessed
Example #23
0
def postprocess_ppo_gae(
        policy: Policy,
        sample_batch: SampleBatch,
        other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
        episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
    """Deprecated stub forwarding to `compute_gae_for_sample_batch`.

    Args:
        policy: Forwarded unchanged.
        sample_batch: Forwarded unchanged.
        other_agent_batches: Forwarded unchanged (default: None).
        episode: Forwarded unchanged (default: None).

    Returns:
        Whatever `compute_gae_for_sample_batch` returns for these arguments.
    """
    # Kept for backward compatibility only.
    deprecation_warning(
        old="rllib.agents.ppo.ppo_tf_policy.postprocess_ppo_gae",
        new="rllib.evaluation.postprocessing.compute_gae_for_sample_batch",
        error=False,
    )

    postprocessed = compute_gae_for_sample_batch(
        policy, sample_batch, other_agent_batches, episode)
    return postprocessed
Example #24
0
    def __init__(self, *args, **kwargs):
        """Create the (deprecated) dict with empty usage-tracking state.

        All arguments are forwarded to `dict.__init__`.
        """
        # DEPRECATED class: Use SampleBatch instead!
        deprecation_warning(
            old="UsageTrackingDict",
            new="SampleBatch",
            error=False,
        )

        dict.__init__(self, *args, **kwargs)

        # No interceptor is set initially and nothing has been intercepted.
        self.get_interceptor = None
        self.intercepted_values = {}
        # Bookkeeping sets, all empty at construction time.
        self.accessed_keys = set()
        self.added_keys = set()
        self.deleted_keys = set()
Example #25
0
File: ddpg.py Project: rawsh-bt/ray
def validate_config(config):
    """Validates a config dict and translates deprecated exploration keys.

    The dict is modified in place: deprecated noise-exploration keys are
    moved into `config["exploration_config"]` (if that is a dict), and some
    settings are forced for custom models and ParameterNoise exploration.

    Args:
        config: The config dict to check and update.
    """

    def _was_set(key):
        # True if the (deprecated) key carries a user-provided value.
        return config.get(key, DEPRECATED_VALUE) != DEPRECATED_VALUE

    if config["model"]["custom_model"]:
        logger.warning(
            "Setting use_state_preprocessor=True since a custom model "
            "was specified.")
        config["use_state_preprocessor"] = True

    # TODO(sven): Remove at some point.
    #  Backward compatibility of noise-based exploration config.
    schedule_max_timesteps = None
    if _was_set("schedule_max_timesteps"):
        deprecation_warning("schedule_max_timesteps",
                            "exploration_config.scale_timesteps")
        schedule_max_timesteps = config["schedule_max_timesteps"]

    if _was_set("exploration_final_scale"):
        deprecation_warning("exploration_final_scale",
                            "exploration_config.final_scale")
        if isinstance(config["exploration_config"], dict):
            final_scale = config.pop("exploration_final_scale")
            config["exploration_config"]["final_scale"] = final_scale

    if _was_set("exploration_fraction"):
        # The fraction is only meaningful relative to the max timesteps.
        assert schedule_max_timesteps is not None
        deprecation_warning("exploration_fraction",
                            "exploration_config.scale_timesteps")
        if isinstance(config["exploration_config"], dict):
            fraction = config.pop("exploration_fraction")
            config["exploration_config"]["scale_timesteps"] = \
                fraction * schedule_max_timesteps

    if _was_set("per_worker_exploration"):
        deprecation_warning(
            "per_worker_exploration",
            "exploration_config.type=PerWorkerOrnsteinUhlenbeckNoise")
        if isinstance(config["exploration_config"], dict):
            config["exploration_config"]["type"] = \
                PerWorkerOrnsteinUhlenbeckNoise

    if _was_set("parameter_noise"):
        deprecation_warning("parameter_noise", "exploration_config={"
                            "type=ParameterNoise}")

    # ParameterNoise exploration only works on complete episodes.
    uses_parameter_noise = \
        config["exploration_config"]["type"] == "ParameterNoise"
    if uses_parameter_noise and config["batch_mode"] != "complete_episodes":
        logger.warning(
            "ParameterNoise Exploration requires `batch_mode` to be "
            "'complete_episodes'. Setting batch_mode=complete_episodes.")
        config["batch_mode"] = "complete_episodes"
Example #26
0
    def policy_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> PolicyID:
        """Returns and stores the policy ID for the specified agent.

        For an agent not seen before, the policy mapping fn is called and
        its result is cached, so the agent stays bound to that policy for
        the duration of the entire episode (even if the policy_mapping_fn
        is changed in the meantime!).

        Args:
            agent_id: The agent ID to lookup the policy ID for.

        Returns:
            The policy ID for the specified agent.

        Raises:
            KeyError: If the resulting policy ID is not in `self.policy_map`.
            TypeError: If the mapping fn raises a TypeError that does not
                look like a call-signature mismatch.
        """
        if agent_id in self._agent_to_policy:
            # Use already determined PolicyID.
            policy_id = self._agent_to_policy[agent_id]
        else:
            # Perform a new policy_mapping_fn lookup and bind AgentID for
            # the duration of this episode to the returned PolicyID.
            # Try new API first: pass in agent_id and episode as named args.
            # New signature should be: (agent_id, episode, worker, **kwargs)
            try:
                policy_id = self.policy_mapping_fn(
                    agent_id, self, worker=self.worker
                )
                self._agent_to_policy[agent_id] = policy_id
            except TypeError as e:
                message = e.args[0]
                signature_mismatch = (
                    "positional argument" in message
                    or "unexpected keyword argument" in message
                )
                # Any other TypeError is not ours to handle.
                if not signature_mismatch:
                    raise e
                if log_once("policy_mapping_new_signature"):
                    deprecation_warning(
                        old="policy_mapping_fn(agent_id)",
                        new="policy_mapping_fn(agent_id, episode, "
                        "worker, **kwargs)",
                    )
                # Retry with the old, single-argument signature.
                policy_id = self.policy_mapping_fn(agent_id)
                self._agent_to_policy[agent_id] = policy_id

        # PolicyID not found in policy map -> Error.
        if policy_id not in self.policy_map:
            raise KeyError(
                "policy_mapping_fn returned invalid policy id " f"'{policy_id}'!"
            )
        return policy_id
Example #27
0
 def export_checkpoint(self,
                       export_dir: str,
                       filename_prefix: str = "model") -> None:
     """Export tensorflow checkpoint to export_dir.

     Deprecated in favor of `save`.

     Args:
         export_dir: Directory to write the checkpoint to. It is created
             (including parents) if it does not exist yet.
         filename_prefix: Name joined onto `export_dir` to form the path
             handed to the saver.

     Raises:
         OSError: If `export_dir` cannot be created, including the case
             where the path already exists but is not a directory.
     """
     deprecation_warning("export_checkpoint", "save")
     # `exist_ok=True` tolerates an already existing directory, replacing
     # the manual errno.EEXIST check. A path that exists as a non-directory
     # is reported right here rather than by the saver later on.
     os.makedirs(export_dir, exist_ok=True)
     save_path = os.path.join(export_dir, filename_prefix)
     # Create and run the saver within the session's own graph.
     with self._sess.graph.as_default():
         saver = tf1.train.Saver()
         saver.save(self._sess, save_path)
Example #28
0
def validate_config(config):
    """Check a DDPG config and migrate deprecated exploration keys.

    The config dict is modified in place:
      - `exploration_final_scale`, `exploration_fraction` and
        `per_worker_exploration` are translated into entries of
        `config["exploration_config"]` (only if that value is a dict).
      - `batch_mode` is forced to "complete_episodes" when the exploration
        type is "ParameterNoise".

    Args:
        config: The config dict to validate and update.

    Raises:
        ValueError: If `config["use_pytorch"]` is truthy.
        AssertionError: If `exploration_fraction` is set while
            `schedule_max_timesteps` is not.
    """

    def _user_provided(key):
        # A key counts as provided if it is present and does not hold the
        # DEPRECATED_VALUE placeholder.
        return config.get(key, DEPRECATED_VALUE) != DEPRECATED_VALUE

    # PyTorch is not supported here.
    if config["use_pytorch"]:
        raise ValueError("DDPG does not support PyTorch yet! Use tf instead.")

    # TODO(sven): Remove at some point.
    #  Backward compatibility of noise-based exploration config.
    max_timesteps = None
    if _user_provided("schedule_max_timesteps"):
        deprecation_warning(
            "schedule_max_timesteps", "exploration_config.scale_timesteps")
        max_timesteps = config["schedule_max_timesteps"]

    if _user_provided("exploration_final_scale"):
        deprecation_warning(
            "exploration_final_scale", "exploration_config.final_scale")
        exploration = config["exploration_config"]
        if isinstance(exploration, dict):
            exploration["final_scale"] = config.pop("exploration_final_scale")

    if _user_provided("exploration_fraction"):
        # The fraction is relative to `schedule_max_timesteps`, so that
        # value must have been provided as well.
        assert max_timesteps is not None
        deprecation_warning(
            "exploration_fraction", "exploration_config.scale_timesteps")
        exploration = config["exploration_config"]
        if isinstance(exploration, dict):
            fraction = config.pop("exploration_fraction")
            exploration["scale_timesteps"] = fraction * max_timesteps

    if _user_provided("per_worker_exploration"):
        deprecation_warning(
            "per_worker_exploration",
            "exploration_config.type=PerWorkerOrnsteinUhlenbeckNoise")
        exploration = config["exploration_config"]
        if isinstance(exploration, dict):
            exploration["type"] = PerWorkerOrnsteinUhlenbeckNoise

    if _user_provided("parameter_noise"):
        deprecation_warning(
            "parameter_noise", "exploration_config={type=ParameterNoise}")

    # With ParameterNoise exploration, switch to complete-episode batches.
    uses_param_noise = \
        config["exploration_config"]["type"] == "ParameterNoise"
    if uses_param_noise and config["batch_mode"] != "complete_episodes":
        logger.warning(
            "ParameterNoise Exploration requires `batch_mode` to be "
            "'complete_episodes'. Setting batch_mode=complete_episodes.")
        config["batch_mode"] = "complete_episodes"
Example #29
0
    def _build_layers(self, inputs, num_outputs, options):
        """Build a stack of dense layers on top of the (flattened) inputs.

        Dict inputs arrive flattened into a single vector. To handle their
        components separately, implement _build_layers_v2() instead.

        Args:
            inputs: Input tensor; flattened via `tf.layers.flatten` if its
                shape has more than 2 dimensions.
            num_outputs: Size of the output layer.
            options: Model options. Reads "fcnet_hiddens",
                "fcnet_activation" and "no_final_linear".

        Returns:
            Tuple of (output tensor, last hidden layer tensor). If
            "no_final_linear" is set, the last entry of "fcnet_hiddens" is
            replaced by an activated layer of size `num_outputs`, and that
            tensor is returned for both elements.
        """
        # Soft deprecate this class. All Models should use the ModelV2
        # API from here on.
        deprecation_warning(
            "Model->FullyConnectedNetwork",
            "ModelV2->FullyConnectedNetwork",
            error=False)

        layer_sizes = options.get("fcnet_hiddens")
        act_fn = get_activation_fn(options.get("fcnet_activation"))

        # Collapse everything beyond rank 2.
        if len(inputs.shape) > 2:
            inputs = tf.layers.flatten(inputs)

        with tf.name_scope("fc_net"):
            features = inputs
            # Layers are numbered from 1 ("fc1", "fc2", ...).
            for layer_idx, width in enumerate(layer_sizes, start=1):
                # skip final linear layer: the last hidden layer directly
                # becomes the (activated) output layer.
                if (options.get("no_final_linear")
                        and layer_idx == len(layer_sizes)):
                    activated_out = tf.layers.dense(
                        features,
                        num_outputs,
                        kernel_initializer=normc_initializer(1.0),
                        activation=act_fn,
                        name="fc_out")
                    return activated_out, activated_out

                features = tf.layers.dense(
                    features,
                    width,
                    kernel_initializer=normc_initializer(1.0),
                    activation=act_fn,
                    name=f"fc{layer_idx}")

            # Regular case: a linear output layer on top of the hidden stack.
            linear_out = tf.layers.dense(
                features,
                num_outputs,
                kernel_initializer=normc_initializer(0.01),
                activation=None,
                name="fc_out")
            return linear_out, features
Example #30
0
def validate_config(config: TrainerConfigDict) -> None:
    """Validates the Trainer's config dict.

    Args:
        config (TrainerConfigDict): The Trainer's config to check.

    Raises:
        ValueError: In case something is wrong with the config.
    """
    if config["use_state_preprocessor"] != DEPRECATED_VALUE:
        deprecation_warning(old="config['use_state_preprocessor']",
                            error=False)
        config["use_state_preprocessor"] = DEPRECATED_VALUE

    if config["grad_clip"] is not None and config["grad_clip"] <= 0.0:
        raise ValueError("`grad_clip` value must be > 0.0!")