    def get_exploration_optimizer(self, optimizers):
        # Create an Adam optimizer for the curiosity NN updates, but do not
        # add it to the policy's own optimizers. If we added and returned it
        # here, it would be used in the policy's update loop, which we don't
        # want (curiosity updating happens inside `postprocess_trajectory`).
        if self.framework == "torch":
            feature_params = list(self._curiosity_feature_net.parameters())
            inverse_params = list(self._curiosity_inverse_fcnet.parameters())
            forward_params = list(self._curiosity_forward_fcnet.parameters())

            # Now that the Policy's own optimizer(s) have been created (from
            # the Model's parameters, IMPORTANT: without the curiosity params),
            # we can add our curiosity sub-modules to the Policy's Model.
            self.model._curiosity_feature_net = \
                self._curiosity_feature_net.to(self.device)
            self.model._curiosity_inverse_fcnet = \
                self._curiosity_inverse_fcnet.to(self.device)
            self.model._curiosity_forward_fcnet = \
                self._curiosity_forward_fcnet.to(self.device)
            self._optimizer = torch.optim.Adam(forward_params +
                                               inverse_params + feature_params,
                                               lr=self.lr)
        else:
            self.model._curiosity_feature_net = self._curiosity_feature_net
            self.model._curiosity_inverse_fcnet = self._curiosity_inverse_fcnet
            self.model._curiosity_forward_fcnet = self._curiosity_forward_fcnet
            # The feature net is an RLlib ModelV2; the other two are Keras
            # Models.
            self._optimizer_var_list = \
                self._curiosity_feature_net.base_model.variables + \
                self._curiosity_inverse_fcnet.variables + \
                self._curiosity_forward_fcnet.variables
            self.model.register_variables(self._optimizer_var_list)
            self._optimizer = tf1.train.AdamOptimizer(learning_rate=self.lr)
            # Create placeholders and initialize the loss.
            if self.framework == "tf":
                self._obs_ph = get_placeholder(space=self.model.obs_space,
                                               name="_curiosity_obs")
                self._next_obs_ph = get_placeholder(space=self.model.obs_space,
                                                    name="_curiosity_next_obs")
                self._action_ph = get_placeholder(
                    space=self.model.action_space, name="_curiosity_action")
                self._forward_l2_norm_sqared, self._update_op = \
                    self._postprocess_helper_tf(
                        self._obs_ph, self._next_obs_ph, self._action_ph)

        return optimizers

    def _get_input_dict_and_dummy_batch(self, view_requirements,
                                        existing_inputs):
        """Creates input_dict and dummy_batch for loss initialization.

        Used for managing the Policy's input placeholders and for loss
        initialization.
        input_dict maps str -> tf placeholder; dummy_batch maps str ->
        np.ndarray.

        Args:
            view_requirements (Dict[str, ViewRequirement]): The view
                requirements dict.
            existing_inputs (Dict[str, tf.placeholder]): A dict of already
                existing placeholders.

        Returns:
            Tuple[Dict[str, tf.placeholder], Dict[str, np.ndarray]]: The
                input_dict/dummy_batch tuple.
        """
        input_dict = {}
        dummy_batch = {}
        for view_col, view_req in view_requirements.items():
            # Point state_in to the already existing self._state_inputs.
            mo = re.match(r"state_in_(\d+)", view_col)
            if mo is not None:
                input_dict[view_col] = self._state_inputs[int(mo.group(1))]
                dummy_batch[view_col] = np.zeros_like(
                    [view_req.space.sample()])
            # State-outs (no placeholders needed).
            elif view_col.startswith("state_out_"):
                dummy_batch[view_col] = np.zeros_like(
                    [view_req.space.sample()])
            # Skip action dist inputs placeholder (do later).
            elif view_col == SampleBatch.ACTION_DIST_INPUTS:
                continue
            elif view_col in existing_inputs:
                input_dict[view_col] = existing_inputs[view_col]
                dummy_batch[view_col] = np.zeros(
                    shape=[
                        1 if s is None else s
                        for s in existing_inputs[view_col].shape.as_list()
                    ],
                    dtype=existing_inputs[view_col].dtype.as_numpy_dtype)
            # All others.
            else:
                if view_req.used_for_training:
                    input_dict[view_col] = get_placeholder(
                        space=view_req.space, name=view_col)
                dummy_batch[view_col] = np.zeros_like(
                    [view_req.space.sample()])
        return input_dict, dummy_batch
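
For reference, below is a minimal, self-contained PyTorch sketch (not RLlib
code; the module sizes and the squared-error inverse loss are placeholder
assumptions) of the pattern used in `get_exploration_optimizer` above: all
curiosity sub-net parameters are collected into one Adam optimizer that is
stepped manually, outside the policy's own update loop.

# Minimal PyTorch sketch of the "separate curiosity optimizer" pattern.
# The three nets stand in for _curiosity_feature_net, _curiosity_inverse_fcnet
# and _curiosity_forward_fcnet; sizes are arbitrary placeholders.
import torch
import torch.nn as nn

feature_net = nn.Linear(8, 16)   # obs -> feature vector phi
inverse_net = nn.Linear(32, 4)   # [phi, next_phi] -> predicted action
forward_net = nn.Linear(20, 16)  # [phi, action] -> predicted next_phi

# One optimizer over the union of all curiosity parameters, kept separate
# from the policy's optimizer(s).
curiosity_optim = torch.optim.Adam(
    list(feature_net.parameters()) + list(inverse_net.parameters()) +
    list(forward_net.parameters()), lr=1e-3)

# Manual update step, analogous to what would happen during postprocessing
# (squared errors are used here only to keep the sketch short).
obs, next_obs = torch.randn(32, 8), torch.randn(32, 8)
actions = torch.randn(32, 4)
phi, next_phi = feature_net(obs), feature_net(next_obs)
pred_action = inverse_net(torch.cat([phi, next_phi], dim=-1))
pred_next_phi = forward_net(torch.cat([phi, actions], dim=-1))
loss = ((pred_next_phi - next_phi) ** 2).mean() + \
    ((pred_action - actions) ** 2).mean()
curiosity_optim.zero_grad()
loss.backward()
curiosity_optim.step()
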
Example #3
    def _get_input_dict_and_dummy_batch(self, view_requirements,
                                        existing_inputs):
        """Creates input_dict and dummy_batch for loss initialization.

        Used for managing the Policy's input placeholders and for loss
        initialization.
        input_dict maps str -> tf placeholder; dummy_batch maps str ->
        np.ndarray.

        Args:
            view_requirements (Dict[str, ViewRequirement]): The view
                requirements dict.
            existing_inputs (Dict[str, tf.placeholder]): A dict of already
                existing placeholders.

        Returns:
            Tuple[Dict[str, tf.placeholder], Dict[str, np.ndarray]]: The
                input_dict/dummy_batch tuple.
        """
        input_dict = {}
        for view_col, view_req in view_requirements.items():
            # Point state_in to the already existing self._state_inputs.
            mo = re.match(r"state_in_(\d+)", view_col)
            if mo is not None:
                input_dict[view_col] = self._state_inputs[int(mo.group(1))]
            # State-outs (no placeholders needed).
            elif view_col.startswith("state_out_"):
                continue
            # Skip action dist inputs placeholder (do later).
            elif view_col == SampleBatch.ACTION_DIST_INPUTS:
                continue
            elif view_col in existing_inputs:
                input_dict[view_col] = existing_inputs[view_col]
            # All others.
            else:
                time_axis = not isinstance(view_req.shift, int)
                if view_req.used_for_training:
                    # Create a +time-axis placeholder if the shift is not an
                    # int (range or list of ints).
                    input_dict[view_col] = get_placeholder(
                        space=view_req.space,
                        name=view_col,
                        time_axis=time_axis)
        dummy_batch = self._get_dummy_batch_from_view_requirements(
            batch_size=32)

        return input_dict, dummy_batch
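
In both variants above, the dummy batch ultimately boils down to turning each
view-requirement's gym space into a zero-filled array with a leading batch
axis of 1. A standalone sketch of that idea (plain gym + numpy, not the RLlib
helper; the column names and spaces are made up for illustration):

# Sample once from each space, wrap in a list to add a batch axis, zero it out.
import numpy as np
import gym

spaces = {
    "obs": gym.spaces.Box(-1.0, 1.0, shape=(4,), dtype=np.float32),
    "actions": gym.spaces.Discrete(2),
    "rewards": gym.spaces.Box(-np.inf, np.inf, shape=(), dtype=np.float32),
}

dummy_batch = {
    col: np.zeros_like([space.sample()]) for col, space in spaces.items()
}
for col, arr in dummy_batch.items():
    print(col, arr.shape, arr.dtype)
# obs (1, 4) float32; actions (1,) int (platform dependent); rewards (1,) float32
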
Example #4
    def __init__(
            self,
            obs_space: gym.spaces.Space,
            action_space: gym.spaces.Space,
            config: TrainerConfigDict,
            loss_fn: Callable[
                [Policy, ModelV2, Type[TFActionDistribution], SampleBatch],
                TensorType],
            *,
            stats_fn: Optional[Callable[[Policy, SampleBatch],
                                        Dict[str, TensorType]]] = None,
            grad_stats_fn: Optional[
                Callable[[Policy, SampleBatch, ModelGradients],
                         Dict[str, TensorType]]] = None,
            before_loss_init: Optional[Callable[[
                Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
            ], None]] = None,
            make_model: Optional[Callable[[
                Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
            ], ModelV2]] = None,
            action_sampler_fn: Optional[Callable[
                [TensorType, List[TensorType]], Tuple[TensorType,
                                                      TensorType]]] = None,
            action_distribution_fn: Optional[
                Callable[[Policy, ModelV2, TensorType, TensorType, TensorType],
                         Tuple[TensorType, type, List[TensorType]]]] = None,
            existing_inputs: Optional[Dict[str, "tf1.placeholder"]] = None,
            existing_model: Optional[ModelV2] = None,
            get_batch_divisibility_req: Optional[Callable[[Policy],
                                                          int]] = None,
            obs_include_prev_action_reward: bool = True):
        """Initialize a dynamic TF policy.

        Args:
            obs_space (gym.spaces.Space): Observation space of the
                policy.
            action_space (gym.spaces.Space): Action space of the policy.
            config (TrainerConfigDict): Policy-specific configuration data.
            loss_fn (Callable[[Policy, ModelV2, Type[TFActionDistribution],
                SampleBatch], TensorType]): Function that returns a loss tensor
                for the policy graph.
            stats_fn (Optional[Callable[[Policy, SampleBatch],
                Dict[str, TensorType]]]): Optional function that returns a dict
                of TF fetches given the policy and batch input tensors.
            grad_stats_fn (Optional[Callable[[Policy, SampleBatch,
                ModelGradients], Dict[str, TensorType]]]):
                Optional function that returns a dict of TF fetches given the
                policy, sample batch, and loss gradient tensors.
            before_loss_init (Optional[Callable[
                [Policy, gym.spaces.Space, gym.spaces.Space,
                TrainerConfigDict], None]]): Optional function to run prior to
                loss init that takes the same arguments as __init__.
            make_model (Optional[Callable[[Policy, gym.spaces.Space,
                gym.spaces.Space, TrainerConfigDict], ModelV2]]): Optional
                function that returns a ModelV2 object given
                policy, obs_space, action_space, and policy config.
                All policy variables should be created in this function. If not
                specified, a default model will be created.
            action_sampler_fn (Optional[Callable[[Policy, ModelV2, Dict[
                str, TensorType], TensorType, TensorType], Tuple[TensorType,
                TensorType]]]): A callable returning a sampled action and its
                log-likelihood given Policy, ModelV2, input_dict, explore,
                timestep, and is_training.
            action_distribution_fn (Optional[Callable[[Policy, ModelV2,
                Dict[str, TensorType], TensorType, TensorType],
                Tuple[TensorType, type, List[TensorType]]]]): A callable
                returning distribution inputs (parameters), a dist-class to
                generate an action distribution object from, and
                internal-state outputs (or an empty list if not applicable).
                Note: No Exploration hooks need to be called from within
                `action_distribution_fn`. It should only perform a simple
                forward pass through some model.
                If None, pass inputs through `self.model()` to get distribution
                inputs.
                The callable takes as inputs: Policy, ModelV2, input_dict,
                explore, timestep, is_training.
            existing_inputs (Optional[Dict[str, tf1.placeholder]]): When
                copying a policy, this specifies an existing dict of
                placeholders to use instead of defining new ones.
            existing_model (Optional[ModelV2]): When copying a policy, this
                specifies an existing model to clone and share weights with.
            get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
                Optional callable that returns the divisibility requirement for
                sample batches. If None, will assume a value of 1.
            obs_include_prev_action_reward (bool): Whether to include the
                previous action and reward in the model input (default: True).
        """
        self.observation_space = obs_space
        self.action_space = action_space
        self.config = config
        self.framework = "tf"
        self._loss_fn = loss_fn
        self._stats_fn = stats_fn
        self._grad_stats_fn = grad_stats_fn
        self._obs_include_prev_action_reward = obs_include_prev_action_reward

        dist_class = dist_inputs = None
        if action_sampler_fn or action_distribution_fn:
            if not make_model:
                raise ValueError(
                    "`make_model` is required if `action_sampler_fn` OR "
                    "`action_distribution_fn` is given")
        else:
            dist_class, logit_dim = ModelCatalog.get_action_dist(
                action_space, self.config["model"])

        # Setup self.model.
        if existing_model:
            if isinstance(existing_model, list):
                self.model = existing_model[0]
                # TODO: (sven) hack, but works for `target_[q_]?model`.
                for i in range(1, len(existing_model)):
                    setattr(self, existing_model[i][0], existing_model[i][1])
        elif make_model:
            self.model = make_model(self, obs_space, action_space, config)
        else:
            self.model = ModelCatalog.get_model_v2(
                obs_space=obs_space,
                action_space=action_space,
                num_outputs=logit_dim,
                model_config=self.config["model"],
                framework="tf")
        # Auto-update model's inference view requirements, if recurrent.
        self._update_model_view_requirements_from_init_state()

        if existing_inputs:
            self._state_inputs = [
                v for k, v in existing_inputs.items()
                if k.startswith("state_in_")
            ]
            if self._state_inputs:
                self._seq_lens = existing_inputs["seq_lens"]
        else:
            if self.config["_use_trajectory_view_api"]:
                self._state_inputs = [
                    get_placeholder(
                        space=vr.space,
                        time_axis=not isinstance(vr.shift, int),
                    ) for k, vr in self.model.view_requirements.items()
                    if k.startswith("state_in_")
                ]
            else:
                self._state_inputs = [
                    tf1.placeholder(shape=(None, ) + s.shape, dtype=s.dtype)
                    for s in self.model.get_initial_state()
                ]

        # Use default settings.
        # Add NEXT_OBS, STATE_IN_0.., and others.
        self.view_requirements = self._get_default_view_requirements()
        # Combine view_requirements for Model and Policy.
        self.view_requirements.update(self.model.view_requirements)

        # Setup standard placeholders.
        if existing_inputs is not None:
            timestep = existing_inputs["timestep"]
            explore = existing_inputs["is_exploring"]
            self._input_dict, self._dummy_batch = \
                self._get_input_dict_and_dummy_batch(
                    self.view_requirements, existing_inputs)
        else:
            action_ph = ModelCatalog.get_action_placeholder(action_space)
            if self.config["_use_trajectory_view_api"]:
                prev_action_ph = {}
                if SampleBatch.PREV_ACTIONS not in self.view_requirements:
                    prev_action_ph = {
                        SampleBatch.PREV_ACTIONS:
                        ModelCatalog.get_action_placeholder(
                            action_space, "prev_action")
                    }
                self._input_dict, self._dummy_batch = \
                    self._get_input_dict_and_dummy_batch(
                        self.view_requirements,
                        dict({SampleBatch.ACTIONS: action_ph},
                             **prev_action_ph))
            else:
                prev_action_ph = ModelCatalog.get_action_placeholder(
                    action_space, "prev_action")
                self._input_dict = {
                    SampleBatch.CUR_OBS:
                    tf1.placeholder(tf.float32,
                                    shape=[None] + list(obs_space.shape),
                                    name="observation")
                }
                self._input_dict[SampleBatch.ACTIONS] = action_ph
                if self._obs_include_prev_action_reward:
                    self._input_dict.update({
                        SampleBatch.PREV_ACTIONS:
                        prev_action_ph,
                        SampleBatch.PREV_REWARDS:
                        tf1.placeholder(tf.float32, [None],
                                        name="prev_reward"),
                    })
            # Placeholder for (sampling steps) timestep (int).
            timestep = tf1.placeholder_with_default(tf.zeros((),
                                                             dtype=tf.int64),
                                                    (),
                                                    name="timestep")
            # Placeholder for `is_exploring` flag.
            explore = tf1.placeholder_with_default(True, (),
                                                   name="is_exploring")

        # Placeholder for RNN time-chunk valid lengths.
        self._seq_lens = tf1.placeholder(dtype=tf.int32,
                                         shape=[None],
                                         name="seq_lens")
        # Placeholder for `is_training` flag.
        self._input_dict["is_training"] = self._get_is_training_placeholder()

        # Create the Exploration object to use for this Policy.
        self.exploration = self._create_exploration()

        # Fully customized action generation (e.g., custom policy).
        if action_sampler_fn:
            sampled_action, sampled_action_logp = action_sampler_fn(
                self,
                self.model,
                obs_batch=self._input_dict[SampleBatch.CUR_OBS],
                state_batches=self._state_inputs,
                seq_lens=self._seq_lens,
                prev_action_batch=self._input_dict.get(
                    SampleBatch.PREV_ACTIONS),
                prev_reward_batch=self._input_dict.get(
                    SampleBatch.PREV_REWARDS),
                explore=explore,
                is_training=self._input_dict["is_training"])
        # Distribution generation is customized, e.g., DQN, DDPG.
        else:
            if action_distribution_fn:

                # Try new action_distribution_fn signature, supporting
                # state_batches and seq_lens.
                in_dict = self._input_dict
                try:
                    dist_inputs, dist_class, self._state_out = \
                        action_distribution_fn(
                            self,
                            self.model,
                            input_dict=in_dict,
                            state_batches=self._state_inputs,
                            seq_lens=self._seq_lens,
                            explore=explore,
                            timestep=timestep,
                            is_training=in_dict["is_training"])
                # Trying the old way (to stay backward compatible).
                # TODO: Remove in future.
                except TypeError as e:
                    if "positional argument" in e.args[0] or \
                            "unexpected keyword argument" in e.args[0]:
                        dist_inputs, dist_class, self._state_out = \
                            action_distribution_fn(
                                self, self.model,
                                obs_batch=in_dict[SampleBatch.CUR_OBS],
                                state_batches=self._state_inputs,
                                seq_lens=self._seq_lens,
                                prev_action_batch=in_dict.get(
                                    SampleBatch.PREV_ACTIONS),
                                prev_reward_batch=in_dict.get(
                                    SampleBatch.PREV_REWARDS),
                                explore=explore,
                                is_training=in_dict["is_training"])
                    else:
                        raise e

            # Default distribution generation behavior:
            # Pass through model. E.g., PG, PPO.
            else:
                dist_inputs, self._state_out = self.model(
                    self._input_dict, self._state_inputs, self._seq_lens)

            action_dist = dist_class(dist_inputs, self.model)

            # Using exploration to get final action (e.g. via sampling).
            sampled_action, sampled_action_logp = \
                self.exploration.get_exploration_action(
                    action_distribution=action_dist,
                    timestep=timestep,
                    explore=explore)

        # Phase 1 init.
        sess = tf1.get_default_session() or tf1.Session()

        batch_divisibility_req = get_batch_divisibility_req(self) if \
            callable(get_batch_divisibility_req) else \
            (get_batch_divisibility_req or 1)

        super().__init__(
            observation_space=obs_space,
            action_space=action_space,
            config=config,
            sess=sess,
            obs_input=self._input_dict[SampleBatch.OBS],
            action_input=self._input_dict[SampleBatch.ACTIONS],
            sampled_action=sampled_action,
            sampled_action_logp=sampled_action_logp,
            dist_inputs=dist_inputs,
            dist_class=dist_class,
            loss=None,  # dynamically initialized on run
            loss_inputs=[],
            model=self.model,
            state_inputs=self._state_inputs,
            state_outputs=self._state_out,
            prev_action_input=self._input_dict.get(SampleBatch.PREV_ACTIONS),
            prev_reward_input=self._input_dict.get(SampleBatch.PREV_REWARDS),
            seq_lens=self._seq_lens,
            max_seq_len=config["model"]["max_seq_len"],
            batch_divisibility_req=batch_divisibility_req,
            explore=explore,
            timestep=timestep)

        # Phase 2 init.
        if before_loss_init is not None:
            before_loss_init(self, obs_space, action_space, config)

        # Loss initialization and model/postprocessing test calls.
        if not existing_inputs:
            self._initialize_loss_from_dummy_batch(
                auto_remove_unneeded_view_reqs=True)
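
The `action_distribution_fn` call in the `__init__` above uses a
backward-compatibility trick: try the new keyword signature first and fall
back to the old one only if the raised TypeError points at the argument list.
A stripped-down, framework-free sketch of the same pattern (function and
argument names here are invented for illustration):

def call_with_fallback(fn, **context):
    try:
        # New-style signature: full input_dict plus explore/timestep.
        return fn(input_dict=context["input_dict"],
                  explore=context["explore"],
                  timestep=context["timestep"])
    except TypeError as e:
        if "positional argument" in e.args[0] or \
                "unexpected keyword argument" in e.args[0]:
            # Old-style signature: individual batch arguments.
            return fn(obs_batch=context["input_dict"]["obs"],
                      explore=context["explore"])
        raise

# An old-style user function that only understands `obs_batch` and `explore`.
def old_style_fn(obs_batch, explore):
    return ("dist_inputs_for", obs_batch, explore)

print(call_with_fallback(
    old_style_fn,
    input_dict={"obs": [0.0, 1.0]}, explore=True, timestep=0))
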
Example #5
    def _initialize_loss_from_dummy_batch(
            self,
            auto_remove_unneeded_view_reqs: bool = True,
            stats_fn=None) -> None:

        # Create the optimizer/exploration optimizer here. Some initialization
        # steps (e.g. exploration postprocessing) may need this.
        self._optimizer = self.optimizer()

        # Test calls depend on variable init, so initialize model first.
        self._sess.run(tf1.global_variables_initializer())

        if self.config["_use_trajectory_view_api"]:
            logger.info("Testing `compute_actions` w/ dummy batch.")
            actions, state_outs, extra_fetches = \
                self.compute_actions_from_input_dict(
                    self._dummy_batch, explore=False, timestep=0)
            for key, value in extra_fetches.items():
                self._dummy_batch[key] = np.zeros_like(value)
                self._input_dict[key] = get_placeholder(value=value, name=key)
                if key not in self.view_requirements:
                    logger.info("Adding extra-action-fetch `{}` to "
                                "view-reqs.".format(key))
                    self.view_requirements[key] = \
                        ViewRequirement(space=gym.spaces.Box(
                            -1.0, 1.0, shape=value.shape[1:],
                            dtype=value.dtype))
            dummy_batch = self._dummy_batch
        else:

            def fake_array(tensor):
                shape = tensor.shape.as_list()
                shape = [s if s is not None else 1 for s in shape]
                return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)

            dummy_batch = {
                SampleBatch.CUR_OBS:
                fake_array(self._obs_input),
                SampleBatch.NEXT_OBS:
                fake_array(self._obs_input),
                SampleBatch.DONES:
                np.array([False], dtype=bool),
                SampleBatch.ACTIONS:
                fake_array(
                    ModelCatalog.get_action_placeholder(self.action_space)),
                SampleBatch.REWARDS:
                np.array([0], dtype=np.float32),
            }
            if self._obs_include_prev_action_reward:
                dummy_batch.update({
                    SampleBatch.PREV_ACTIONS:
                    fake_array(self._prev_action_input),
                    SampleBatch.PREV_REWARDS:
                    fake_array(self._prev_reward_input),
                })
            state_init = self.get_initial_state()
            state_batches = []
            for i, h in enumerate(state_init):
                dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0)
                dummy_batch["state_out_{}".format(i)] = np.expand_dims(h, 0)
                state_batches.append(np.expand_dims(h, 0))
            if state_init:
                dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)
            for k, v in self.extra_compute_action_fetches().items():
                dummy_batch[k] = fake_array(v)
            dummy_batch = SampleBatch(dummy_batch)

        batch_for_postproc = UsageTrackingDict(dummy_batch)
        batch_for_postproc.count = dummy_batch.count
        logger.info("Testing `postprocess_trajectory` w/ dummy batch.")
        self.exploration.postprocess_trajectory(self, batch_for_postproc,
                                                self._sess)
        postprocessed_batch = self.postprocess_trajectory(batch_for_postproc)
        # Add new columns automatically to (loss) input_dict.
        if self.config["_use_trajectory_view_api"]:
            for key in batch_for_postproc.added_keys:
                if key not in self._input_dict:
                    self._input_dict[key] = get_placeholder(
                        value=batch_for_postproc[key], name=key)
                if key not in self.view_requirements:
                    self.view_requirements[key] = \
                        ViewRequirement(space=gym.spaces.Box(
                            -1.0, 1.0, shape=batch_for_postproc[key].shape[1:],
                            dtype=batch_for_postproc[key].dtype))

        if not self.config["_use_trajectory_view_api"]:
            train_batch = UsageTrackingDict(
                dict({
                    SampleBatch.CUR_OBS: self._obs_input,
                }, **self._loss_input_dict))
            if self._obs_include_prev_action_reward:
                train_batch.update({
                    SampleBatch.PREV_ACTIONS: self._prev_action_input,
                    SampleBatch.PREV_REWARDS: self._prev_reward_input,
                    SampleBatch.CUR_OBS: self._obs_input,
                })

            for k, v in postprocessed_batch.items():
                if k in train_batch:
                    continue
                elif v.dtype == object:
                    continue  # can't handle arbitrary objects in TF
                elif k == "seq_lens" or k.startswith("state_in_"):
                    continue
                shape = (None, ) + v.shape[1:]
                dtype = np.float32 if v.dtype == np.float64 else v.dtype
                placeholder = tf1.placeholder(dtype, shape=shape, name=k)
                train_batch[k] = placeholder

            for i, si in enumerate(self._state_inputs):
                train_batch["state_in_{}".format(i)] = si
        else:
            train_batch = UsageTrackingDict(
                dict(self._input_dict, **self._loss_input_dict))

        if self._state_inputs:
            train_batch["seq_lens"] = self._seq_lens

        if log_once("loss_init"):
            logger.debug(
                "Initializing loss function with dummy input:\n\n{}\n".format(
                    summarize(train_batch)))

        self._loss_input_dict.update({k: v for k, v in train_batch.items()})
        loss = self._do_loss_init(train_batch)

        all_accessed_keys = \
            train_batch.accessed_keys | batch_for_postproc.accessed_keys | \
            batch_for_postproc.added_keys | set(
                self.model.view_requirements.keys())

        TFPolicy._initialize_loss(
            self, loss,
            [(k, v) for k, v in train_batch.items() if k in all_accessed_keys])

        if "is_training" in self._loss_input_dict:
            del self._loss_input_dict["is_training"]

        # Call the grads stats fn.
        # TODO: (sven) rename to simply stats_fn to match eager and torch.
        if self._grad_stats_fn:
            self._stats_fetches.update(
                self._grad_stats_fn(self, train_batch, self._grads))

        # Add new columns automatically to view-reqs.
        if self.config["_use_trajectory_view_api"] and \
                auto_remove_unneeded_view_reqs:
            # Add those needed for postprocessing and training.
            all_accessed_keys = train_batch.accessed_keys | \
                                batch_for_postproc.accessed_keys
            # Tag those only needed for post-processing (with some exceptions).
            for key in batch_for_postproc.accessed_keys:
                if key not in train_batch.accessed_keys and \
                        key not in self.model.view_requirements and \
                        key not in [
                            SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
                            SampleBatch.UNROLL_ID, SampleBatch.DONES,
                            SampleBatch.REWARDS, SampleBatch.INFOS]:
                    if key in self.view_requirements:
                        self.view_requirements[key].used_for_training = False
                    if key in self._loss_input_dict:
                        del self._loss_input_dict[key]
            # Remove those not needed at all (leave those that are needed
            # by Sampler to properly execute sample collection).
            # Also always leave DONES, REWARDS, and INFOS, no matter what.
            for key in list(self.view_requirements.keys()):
                if key not in all_accessed_keys and key not in [
                    SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
                    SampleBatch.UNROLL_ID, SampleBatch.DONES,
                    SampleBatch.REWARDS, SampleBatch.INFOS] and \
                        key not in self.model.view_requirements:
                    # If user deleted this key manually in postprocessing
                    # fn, warn about it and do not remove from
                    # view-requirements.
                    if key in batch_for_postproc.deleted_keys:
                        logger.warning(
                            "SampleBatch key '{}' was deleted manually in "
                            "postprocessing function! RLlib will "
                            "automatically remove non-used items from the "
                            "data stream. Remove the `del` from your "
                            "postprocessing function.".format(key))
                    else:
                        del self.view_requirements[key]
                    if key in self._loss_input_dict:
                        del self._loss_input_dict[key]
            # Add those data_cols (again) that are missing and have
            # dependencies by view_cols.
            for key in list(self.view_requirements.keys()):
                vr = self.view_requirements[key]
                if (vr.data_col is not None
                        and vr.data_col not in self.view_requirements):
                    used_for_training = \
                        vr.data_col in train_batch.accessed_keys
                    self.view_requirements[vr.data_col] = ViewRequirement(
                        space=vr.space, used_for_training=used_for_training)

        self._loss_input_dict_no_rnn = {
            k: v
            for k, v in self._loss_input_dict.items()
            if (v not in self._state_inputs and v != self._seq_lens)
        }

        # Initialize again after loss init.
        self._sess.run(tf1.global_variables_initializer())
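
The pruning logic above works because the dummy batch records which columns
the postprocessing and loss code actually touch. Below is a toy dict that
tracks accessed and added keys, as an illustration of the idea only (it is
not RLlib's UsageTrackingDict):

class TrackingDict(dict):
    """Toy dict recording which keys are read or newly written."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.accessed_keys = set()
        self.added_keys = set()

    def __getitem__(self, key):
        self.accessed_keys.add(key)
        return super().__getitem__(key)

    def __setitem__(self, key, value):
        if key not in self:
            self.added_keys.add(key)
        super().__setitem__(key, value)

batch = TrackingDict(obs=[0.0], rewards=[1.0], infos=[{}])
advantage = batch["rewards"]        # read -> lands in accessed_keys
batch["advantages"] = advantage     # new column -> lands in added_keys
print(batch.accessed_keys)          # {'rewards'}
print(batch.added_keys)             # {'advantages'}
# Columns that never show up in accessed_keys (e.g. "infos") are the ones
# whose view requirements can have `used_for_training` switched off or be
# dropped entirely.
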
Example #6
    def _initialize_loss_from_dummy_batch(
            self,
            auto_remove_unneeded_view_reqs: bool = True,
            stats_fn=None) -> None:

        # Create the optimizer/exploration optimizer here. Some initialization
        # steps (e.g. exploration postprocessing) may need this.
        self._optimizer = self.optimizer()

        # Test calls depend on variable init, so initialize model first.
        self.get_session().run(tf1.global_variables_initializer())

        logger.info("Testing `compute_actions` w/ dummy batch.")
        actions, state_outs, extra_fetches = \
            self.compute_actions_from_input_dict(
                self._dummy_batch, explore=False, timestep=0)
        for key, value in extra_fetches.items():
            self._dummy_batch[key] = value
            self._input_dict[key] = get_placeholder(value=value, name=key)
            if key not in self.view_requirements:
                logger.info("Adding extra-action-fetch `{}` to "
                            "view-reqs.".format(key))
                self.view_requirements[key] = ViewRequirement(
                    space=gym.spaces.Box(-1.0,
                                         1.0,
                                         shape=value.shape[1:],
                                         dtype=value.dtype),
                    used_for_compute_actions=False,
                )
        dummy_batch = self._dummy_batch

        logger.info("Testing `postprocess_trajectory` w/ dummy batch.")
        self.exploration.postprocess_trajectory(self, dummy_batch,
                                                self.get_session())
        _ = self.postprocess_trajectory(dummy_batch)
        # Add new columns automatically to (loss) input_dict.
        for key in dummy_batch.added_keys:
            if key not in self._input_dict:
                self._input_dict[key] = get_placeholder(value=dummy_batch[key],
                                                        name=key)
            if key not in self.view_requirements:
                self.view_requirements[key] = ViewRequirement(
                    space=gym.spaces.Box(-1.0,
                                         1.0,
                                         shape=dummy_batch[key].shape[1:],
                                         dtype=dummy_batch[key].dtype),
                    used_for_compute_actions=False,
                )

        train_batch = SampleBatch(
            dict(self._input_dict, **self._loss_input_dict))

        if self._state_inputs:
            train_batch["seq_lens"] = self._seq_lens
            self._loss_input_dict.update({"seq_lens": train_batch["seq_lens"]})

        self._loss_input_dict.update({k: v for k, v in train_batch.items()})

        if log_once("loss_init"):
            logger.debug(
                "Initializing loss function with dummy input:\n\n{}\n".format(
                    summarize(train_batch)))

        loss = self._do_loss_init(train_batch)

        all_accessed_keys = \
            train_batch.accessed_keys | dummy_batch.accessed_keys | \
            dummy_batch.added_keys | set(
                self.model.view_requirements.keys())

        TFPolicy._initialize_loss(
            self, loss,
            [(k, v)
             for k, v in train_batch.items() if k in all_accessed_keys] +
            ([("seq_lens",
               train_batch["seq_lens"])] if "seq_lens" in train_batch else []))

        if "is_training" in self._loss_input_dict:
            del self._loss_input_dict["is_training"]

        # Call the grads stats fn.
        # TODO: (sven) rename to simply stats_fn to match eager and torch.
        if self._grad_stats_fn:
            self._stats_fetches.update(
                self._grad_stats_fn(self, train_batch, self._grads))

        # Add new columns automatically to view-reqs.
        if auto_remove_unneeded_view_reqs:
            # Add those needed for postprocessing and training.
            all_accessed_keys = train_batch.accessed_keys | \
                                dummy_batch.accessed_keys
            # Tag those only needed for post-processing (with some exceptions).
            for key in dummy_batch.accessed_keys:
                if key not in train_batch.accessed_keys and \
                        key not in self.model.view_requirements and \
                        key not in [
                            SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
                            SampleBatch.UNROLL_ID, SampleBatch.DONES,
                            SampleBatch.REWARDS, SampleBatch.INFOS]:
                    if key in self.view_requirements:
                        self.view_requirements[key].used_for_training = False
                    if key in self._loss_input_dict:
                        del self._loss_input_dict[key]
            # Remove those not needed at all (leave those that are needed
            # by Sampler to properly execute sample collection).
            # Also always leave DONES, REWARDS, and INFOS, no matter what.
            for key in list(self.view_requirements.keys()):
                if key not in all_accessed_keys and key not in [
                    SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
                    SampleBatch.UNROLL_ID, SampleBatch.DONES,
                    SampleBatch.REWARDS, SampleBatch.INFOS] and \
                        key not in self.model.view_requirements:
                    # If user deleted this key manually in postprocessing
                    # fn, warn about it and do not remove from
                    # view-requirements.
                    if key in dummy_batch.deleted_keys:
                        logger.warning(
                            "SampleBatch key '{}' was deleted manually in "
                            "postprocessing function! RLlib will "
                            "automatically remove non-used items from the "
                            "data stream. Remove the `del` from your "
                            "postprocessing function.".format(key))
                    else:
                        del self.view_requirements[key]
                    if key in self._loss_input_dict:
                        del self._loss_input_dict[key]
            # Add those data_cols (again) that are missing and have
            # dependencies by view_cols.
            for key in list(self.view_requirements.keys()):
                vr = self.view_requirements[key]
                if (vr.data_col is not None
                        and vr.data_col not in self.view_requirements):
                    used_for_training = \
                        vr.data_col in train_batch.accessed_keys
                    self.view_requirements[vr.data_col] = ViewRequirement(
                        space=vr.space, used_for_training=used_for_training)

        self._loss_input_dict_no_rnn = {
            k: v
            for k, v in self._loss_input_dict.items()
            if (v not in self._state_inputs and v != self._seq_lens)
        }

        # Initialize again after loss init.
        self.get_session().run(tf1.global_variables_initializer())
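
Both `__init__` variants (above at Example #4 and below at Example #7) create
the `timestep` and `is_exploring` inputs with `tf1.placeholder_with_default`,
so graph runs that feed nothing get a sensible default. A small TF1
graph-mode sketch of that behavior (assumes TensorFlow is installed):

import tensorflow as tf
tf1 = tf.compat.v1
tf1.disable_eager_execution()

timestep = tf1.placeholder_with_default(
    tf.zeros((), dtype=tf.int64), (), name="timestep")
explore = tf1.placeholder_with_default(True, (), name="is_exploring")

with tf1.Session() as sess:
    # Nothing fed -> defaults are used.
    print(sess.run([timestep, explore]))                       # [0, True]
    # Feeding overrides the defaults.
    print(sess.run([timestep, explore],
                   feed_dict={timestep: 42, explore: False}))  # [42, False]
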
Example #7
    def __init__(
            self,
            obs_space: gym.spaces.Space,
            action_space: gym.spaces.Space,
            config: TrainerConfigDict,
            loss_fn: Callable[[
                Policy, ModelV2, Type[TFActionDistribution], SampleBatch
            ], TensorType],
            *,
            stats_fn: Optional[Callable[[Policy, SampleBatch], Dict[
                str, TensorType]]] = None,
            grad_stats_fn: Optional[Callable[[
                Policy, SampleBatch, ModelGradients
            ], Dict[str, TensorType]]] = None,
            before_loss_init: Optional[Callable[[
                Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
            ], None]] = None,
            make_model: Optional[Callable[[
                Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
            ], ModelV2]] = None,
            action_sampler_fn: Optional[Callable[[
                TensorType, List[TensorType]
            ], Tuple[TensorType, TensorType]]] = None,
            action_distribution_fn: Optional[Callable[[
                Policy, ModelV2, TensorType, TensorType, TensorType
            ], Tuple[TensorType, type, List[TensorType]]]] = None,
            existing_inputs: Optional[Dict[str, "tf1.placeholder"]] = None,
            existing_model: Optional[ModelV2] = None,
            get_batch_divisibility_req: Optional[Callable[[Policy],
                                                          int]] = None,
            obs_include_prev_action_reward=DEPRECATED_VALUE):
        """Initializes a DynamicTFPolicy instance.

        Args:
            obs_space (gym.spaces.Space): Observation space of the
                policy.
            action_space (gym.spaces.Space): Action space of the policy.
            config (TrainerConfigDict): Policy-specific configuration data.
            loss_fn (Callable[[Policy, ModelV2, Type[TFActionDistribution],
                SampleBatch], TensorType]): Function that returns a loss tensor
                for the policy graph.
            stats_fn (Optional[Callable[[Policy, SampleBatch],
                Dict[str, TensorType]]]): Optional function that returns a dict
                of TF fetches given the policy and batch input tensors.
            grad_stats_fn (Optional[Callable[[Policy, SampleBatch,
                ModelGradients], Dict[str, TensorType]]]):
                Optional function that returns a dict of TF fetches given the
                policy, sample batch, and loss gradient tensors.
            before_loss_init (Optional[Callable[
                [Policy, gym.spaces.Space, gym.spaces.Space,
                TrainerConfigDict], None]]): Optional function to run prior to
                loss init that takes the same arguments as __init__.
            make_model (Optional[Callable[[Policy, gym.spaces.Space,
                gym.spaces.Space, TrainerConfigDict], ModelV2]]): Optional
                function that returns a ModelV2 object given
                policy, obs_space, action_space, and policy config.
                All policy variables should be created in this function. If not
                specified, a default model will be created.
            action_sampler_fn (Optional[Callable[[Policy, ModelV2, Dict[
                str, TensorType], TensorType, TensorType], Tuple[TensorType,
                TensorType]]]): A callable returning a sampled action and its
                log-likelihood given Policy, ModelV2, input_dict, explore,
                timestep, and is_training.
            action_distribution_fn (Optional[Callable[[Policy, ModelV2,
                Dict[str, TensorType], TensorType, TensorType],
                Tuple[TensorType, type, List[TensorType]]]]): A callable
                returning distribution inputs (parameters), a dist-class to
                generate an action distribution object from, and
                internal-state outputs (or an empty list if not applicable).
                Note: No Exploration hooks need to be called from within
                `action_distribution_fn`. It should only perform a simple
                forward pass through some model.
                If None, pass inputs through `self.model()` to get distribution
                inputs.
                The callable takes as inputs: Policy, ModelV2, input_dict,
                explore, timestep, is_training.
            existing_inputs (Optional[Dict[str, tf1.placeholder]]): When
                copying a policy, this specifies an existing dict of
                placeholders to use instead of defining new ones.
            existing_model (Optional[ModelV2]): When copying a policy, this
                specifies an existing model to clone and share weights with.
            get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
                Optional callable that returns the divisibility requirement for
                sample batches. If None, will assume a value of 1.
        """
        if obs_include_prev_action_reward != DEPRECATED_VALUE:
            deprecation_warning(
                old="obs_include_prev_action_reward", error=False)
        self.observation_space = obs_space
        self.action_space = action_space
        self.config = config
        self.framework = "tf"
        self._loss_fn = loss_fn
        self._stats_fn = stats_fn
        self._grad_stats_fn = grad_stats_fn
        self._seq_lens = None
        self._is_tower = existing_inputs is not None

        dist_class = None
        if action_sampler_fn or action_distribution_fn:
            if not make_model:
                raise ValueError(
                    "`make_model` is required if `action_sampler_fn` OR "
                    "`action_distribution_fn` is given")
        else:
            dist_class, logit_dim = ModelCatalog.get_action_dist(
                action_space, self.config["model"])

        # Setup self.model.
        if existing_model:
            if isinstance(existing_model, list):
                self.model = existing_model[0]
                # TODO: (sven) hack, but works for `target_[q_]?model`.
                for i in range(1, len(existing_model)):
                    setattr(self, existing_model[i][0], existing_model[i][1])
        elif make_model:
            self.model = make_model(self, obs_space, action_space, config)
        else:
            self.model = ModelCatalog.get_model_v2(
                obs_space=obs_space,
                action_space=action_space,
                num_outputs=logit_dim,
                model_config=self.config["model"],
                framework="tf")
        # Auto-update model's inference view requirements, if recurrent.
        self._update_model_view_requirements_from_init_state()

        # Input placeholders already given -> Use these.
        if existing_inputs:
            self._state_inputs = [
                v for k, v in existing_inputs.items()
                if k.startswith("state_in_")
            ]
            # Placeholder for RNN time-chunk valid lengths.
            if self._state_inputs:
                self._seq_lens = existing_inputs[SampleBatch.SEQ_LENS]
        # Create new input placeholders.
        else:
            self._state_inputs = [
                get_placeholder(
                    space=vr.space,
                    time_axis=not isinstance(vr.shift, int),
                ) for k, vr in self.model.view_requirements.items()
                if k.startswith("state_in_")
            ]
            # Placeholder for RNN time-chunk valid lengths.
            if self._state_inputs:
                self._seq_lens = tf1.placeholder(
                    dtype=tf.int32, shape=[None], name="seq_lens")

        # Use default settings.
        # Add NEXT_OBS, STATE_IN_0.., and others.
        self.view_requirements = self._get_default_view_requirements()
        # Combine view_requirements for Model and Policy.
        self.view_requirements.update(self.model.view_requirements)
        # Disable env-info placeholder.
        if SampleBatch.INFOS in self.view_requirements:
            self.view_requirements[SampleBatch.INFOS].used_for_training = False

        # Setup standard placeholders.
        if self._is_tower:
            timestep = existing_inputs["timestep"]
            explore = False
            self._input_dict, self._dummy_batch = \
                self._get_input_dict_and_dummy_batch(
                    self.view_requirements, existing_inputs)
        else:
            action_ph = ModelCatalog.get_action_placeholder(action_space)
            prev_action_ph = {}
            if SampleBatch.PREV_ACTIONS not in self.view_requirements:
                prev_action_ph = {
                    SampleBatch.PREV_ACTIONS: ModelCatalog.
                    get_action_placeholder(action_space, "prev_action")
                }
            self._input_dict, self._dummy_batch = \
                self._get_input_dict_and_dummy_batch(
                    self.view_requirements,
                    dict({SampleBatch.ACTIONS: action_ph},
                         **prev_action_ph))
            # Placeholder for (sampling steps) timestep (int).
            timestep = tf1.placeholder_with_default(
                tf.zeros((), dtype=tf.int64), (), name="timestep")
            # Placeholder for `is_exploring` flag.
            explore = tf1.placeholder_with_default(
                True, (), name="is_exploring")

        # Placeholder for `is_training` flag.
        self._input_dict.is_training = self._get_is_training_placeholder()

        # Multi-GPU towers do not need any action computing/exploration
        # graphs.
        sampled_action = None
        sampled_action_logp = None
        dist_inputs = None
        self._state_out = None
        if not self._is_tower:
            # Create the Exploration object to use for this Policy.
            self.exploration = self._create_exploration()

            # Fully customized action generation (e.g., custom policy).
            if action_sampler_fn:
                sampled_action, sampled_action_logp = action_sampler_fn(
                    self,
                    self.model,
                    obs_batch=self._input_dict[SampleBatch.CUR_OBS],
                    state_batches=self._state_inputs,
                    seq_lens=self._seq_lens,
                    prev_action_batch=self._input_dict.get(
                        SampleBatch.PREV_ACTIONS),
                    prev_reward_batch=self._input_dict.get(
                        SampleBatch.PREV_REWARDS),
                    explore=explore,
                    is_training=self._input_dict.is_training)
            # Distribution generation is customized, e.g., DQN, DDPG.
            else:
                if action_distribution_fn:

                    # Try new action_distribution_fn signature, supporting
                    # state_batches and seq_lens.
                    in_dict = self._input_dict
                    try:
                        dist_inputs, dist_class, self._state_out = \
                            action_distribution_fn(
                                self,
                                self.model,
                                input_dict=in_dict,
                                state_batches=self._state_inputs,
                                seq_lens=self._seq_lens,
                                explore=explore,
                                timestep=timestep,
                                is_training=in_dict.is_training)
                    # Trying the old way (to stay backward compatible).
                    # TODO: Remove in future.
                    except TypeError as e:
                        if "positional argument" in e.args[0] or \
                                "unexpected keyword argument" in e.args[0]:
                            dist_inputs, dist_class, self._state_out = \
                                action_distribution_fn(
                                    self, self.model,
                                    obs_batch=in_dict[SampleBatch.CUR_OBS],
                                    state_batches=self._state_inputs,
                                    seq_lens=self._seq_lens,
                                    prev_action_batch=in_dict.get(
                                        SampleBatch.PREV_ACTIONS),
                                    prev_reward_batch=in_dict.get(
                                        SampleBatch.PREV_REWARDS),
                                    explore=explore,
                                    is_training=in_dict.is_training)
                        else:
                            raise e

                # Default distribution generation behavior:
                # Pass through model. E.g., PG, PPO.
                else:
                    if isinstance(self.model, tf.keras.Model):
                        dist_inputs, self._state_out, \
                            self._extra_action_fetches = \
                            self.model(self._input_dict)
                    else:
                        dist_inputs, self._state_out = self.model(
                            self._input_dict, self._state_inputs,
                            self._seq_lens)

                action_dist = dist_class(dist_inputs, self.model)

                # Using exploration to get final action (e.g. via sampling).
                sampled_action, sampled_action_logp = \
                    self.exploration.get_exploration_action(
                        action_distribution=action_dist,
                        timestep=timestep,
                        explore=explore)

        # Phase 1 init.
        sess = tf1.get_default_session() or tf1.Session(
            config=tf1.ConfigProto(**self.config["tf_session_args"]))

        batch_divisibility_req = get_batch_divisibility_req(self) if \
            callable(get_batch_divisibility_req) else \
            (get_batch_divisibility_req or 1)

        super().__init__(
            observation_space=obs_space,
            action_space=action_space,
            config=config,
            sess=sess,
            obs_input=self._input_dict[SampleBatch.OBS],
            action_input=self._input_dict[SampleBatch.ACTIONS],
            sampled_action=sampled_action,
            sampled_action_logp=sampled_action_logp,
            dist_inputs=dist_inputs,
            dist_class=dist_class,
            loss=None,  # dynamically initialized on run
            loss_inputs=[],
            model=self.model,
            state_inputs=self._state_inputs,
            state_outputs=self._state_out,
            prev_action_input=self._input_dict.get(SampleBatch.PREV_ACTIONS),
            prev_reward_input=self._input_dict.get(SampleBatch.PREV_REWARDS),
            seq_lens=self._seq_lens,
            max_seq_len=config["model"]["max_seq_len"],
            batch_divisibility_req=batch_divisibility_req,
            explore=explore,
            timestep=timestep)

        # Phase 2 init.
        if before_loss_init is not None:
            before_loss_init(self, obs_space, action_space, config)

        # Loss initialization and model/postprocessing test calls.
        if not self._is_tower:
            self._initialize_loss_from_dummy_batch(
                auto_remove_unneeded_view_reqs=True)

            # Create MultiGPUTowerStacks, if we have at least one actual
            # GPU or >1 CPUs (fake GPUs).
            if len(self.devices) > 1 or any("gpu" in d for d in self.devices):
                # Per-GPU graph copies created here must share vars with the
                # policy. Therefore, `reuse` is set to tf1.AUTO_REUSE because
                # Adam nodes are created after all of the device copies are
                # created.
                with tf1.variable_scope("", reuse=tf1.AUTO_REUSE):
                    self.multi_gpu_tower_stacks = [
                        TFMultiGPUTowerStack(policy=self) for i in range(
                            self.config.get("num_multi_gpu_tower_stacks", 1))
                    ]

            # Initialize again after loss and tower init.
            self.get_session().run(tf1.global_variables_initializer())
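
Taken together, the flow in these `__init__` variants is: build the
action-computation graph first with `loss=None`, then run dummy data through
`compute_actions` and `postprocess_trajectory` to discover which columns the
loss actually needs, and only then build the loss. A framework-free toy of
that two-phase idea (all names below are invented for illustration):

import numpy as np

class ToyDynamicPolicy:
    """Phase 1 wires up action computation; phase 2 builds the loss lazily."""

    def __init__(self, obs_dim, loss_fn):
        self._loss_fn = loss_fn
        self.weights = np.zeros(obs_dim)   # phase 1: action path only
        self.loss = None                   # loss is built lazily
        self.loss_inputs = None
        self._initialize_loss_from_dummy_batch(obs_dim)

    def compute_actions(self, obs):
        return obs @ self.weights

    def postprocess_trajectory(self, batch):
        batch["advantages"] = batch["rewards"] - batch["rewards"].mean()
        return batch

    def _initialize_loss_from_dummy_batch(self, obs_dim):
        dummy = {"obs": np.zeros((1, obs_dim), np.float32),
                 "rewards": np.zeros(1, np.float32)}
        dummy = self.postprocess_trajectory(dummy)
        # Whatever columns exist after postprocessing become the loss inputs.
        self.loss_inputs = sorted(dummy.keys())
        self.loss = self._loss_fn(self, dummy)

policy = ToyDynamicPolicy(
    obs_dim=4, loss_fn=lambda p, b: float((b["advantages"] ** 2).mean()))
print(policy.loss_inputs)   # ['advantages', 'obs', 'rewards']
print(policy.loss)          # 0.0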