Example #1
File: curiosity.py  Project: tchordia/ray
    def get_exploration_optimizer(self, optimizers):
        # Create, but don't add Adam for curiosity NN updating to the policy.
        # If we added and returned it here, it would be used in the policy's
        # update loop, which we don't want (curiosity updating happens inside
        # `postprocess_trajectory`).
        if self.framework == "torch":
            feature_params = list(self._curiosity_feature_net.parameters())
            inverse_params = list(self._curiosity_inverse_fcnet.parameters())
            forward_params = list(self._curiosity_forward_fcnet.parameters())

            # Now that the Policy's own optimizer(s) have been created from
            # the Model's parameters (IMPORTANT: without the curiosity
            # params), we can add our curiosity sub-modules to the Policy's
            # Model.
            self.model._curiosity_feature_net = self._curiosity_feature_net.to(
                self.device
            )
            self.model._curiosity_inverse_fcnet = self._curiosity_inverse_fcnet.to(
                self.device
            )
            self.model._curiosity_forward_fcnet = self._curiosity_forward_fcnet.to(
                self.device
            )
            self._optimizer = torch.optim.Adam(
                forward_params + inverse_params + feature_params, lr=self.lr
            )
        else:
            self.model._curiosity_feature_net = self._curiosity_feature_net
            self.model._curiosity_inverse_fcnet = self._curiosity_inverse_fcnet
            self.model._curiosity_forward_fcnet = self._curiosity_forward_fcnet
            # The feature net is an RLlib ModelV2; the other two are Keras
            # Models.
            self._optimizer_var_list = (
                self._curiosity_feature_net.base_model.variables
                + self._curiosity_inverse_fcnet.variables
                + self._curiosity_forward_fcnet.variables
            )
            self._optimizer = tf1.train.AdamOptimizer(learning_rate=self.lr)
            # Create placeholders and initialize the loss.
            if self.framework == "tf":
                self._obs_ph = get_placeholder(
                    space=self.model.obs_space, name="_curiosity_obs"
                )
                self._next_obs_ph = get_placeholder(
                    space=self.model.obs_space, name="_curiosity_next_obs"
                )
                self._action_ph = get_placeholder(
                    space=self.model.action_space, name="_curiosity_action"
                )
                (
                    self._forward_l2_norm_sqared,
                    self._update_op,
                ) = self._postprocess_helper_tf(
                    self._obs_ph, self._next_obs_ph, self._action_ph
                )

        return optimizers
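The torch branch above simply concatenates the parameter lists of the three curiosity sub-networks and hands them to a single Adam optimizer that is kept separate from the policy's own optimizer(s). A minimal standalone sketch of that wiring, using hypothetical stand-in modules rather than the actual curiosity architecture:

import torch
import torch.nn as nn

# Stand-in modules (hypothetical sizes, not the real curiosity nets).
feature_net = nn.Linear(8, 16)
inverse_net = nn.Linear(32, 4)
forward_net = nn.Linear(20, 16)

# Same pattern as above: one Adam over all three parameter lists, held by
# the exploration module instead of being returned to the policy.
params = (
    list(feature_net.parameters())
    + list(inverse_net.parameters())
    + list(forward_net.parameters())
)
optimizer = torch.optim.Adam(params, lr=1e-3)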
Example #2
    def _init_state_inputs(self, existing_inputs: Dict[str,
                                                       "tf1.placeholder"]):
        """Initialize input placeholders.

        Args:
            existing_inputs: existing placeholders.
        """
        if existing_inputs:
            self._state_inputs = [
                v for k, v in existing_inputs.items()
                if k.startswith("state_in_")
            ]
            # Placeholder for RNN time-chunk valid lengths.
            if self._state_inputs:
                self._seq_lens = existing_inputs[SampleBatch.SEQ_LENS]
        # Create new input placeholders.
        else:
            self._state_inputs = [
                get_placeholder(
                    space=vr.space,
                    time_axis=not isinstance(vr.shift, int),
                    name=k,
                ) for k, vr in self.model.view_requirements.items()
                if k.startswith("state_in_")
            ]
            # Placeholder for RNN time-chunk valid lengths.
            if self._state_inputs:
                self._seq_lens = tf1.placeholder(dtype=tf.int32,
                                                 shape=[None],
                                                 name="seq_lens")
Example #3
    def _initialize_loss_from_dummy_batch(
            self, auto_remove_unneeded_view_reqs: bool = True) -> None:
        # Test calls depend on variable init, so initialize model first.
        self.get_session().run(tf1.global_variables_initializer())

        # Fields that have not been accessed are not needed for action
        # computations -> Tag them as `used_for_compute_actions=False`.
        for key, view_req in self.view_requirements.items():
            if (not key.startswith("state_in_")
                    and key not in self._input_dict.accessed_keys):
                view_req.used_for_compute_actions = False
        for key, value in self.extra_action_out_fn().items():
            self._dummy_batch[key] = get_dummy_batch_for_space(
                gym.spaces.Box(-1.0,
                               1.0,
                               shape=value.shape.as_list()[1:],
                               dtype=value.dtype.name),
                batch_size=len(self._dummy_batch),
            )
            self._input_dict[key] = get_placeholder(value=value, name=key)
            if key not in self.view_requirements:
                logger.info(
                    "Adding extra-action-fetch `{}` to view-reqs.".format(key))
                self.view_requirements[key] = ViewRequirement(
                    space=gym.spaces.Box(-1.0,
                                         1.0,
                                         shape=value.shape[1:],
                                         dtype=value.dtype.name),
                    used_for_compute_actions=False,
                )
        dummy_batch = self._dummy_batch

        logger.info("Testing `postprocess_trajectory` w/ dummy batch.")
        self.exploration.postprocess_trajectory(self, dummy_batch,
                                                self.get_session())
        _ = self.postprocess_trajectory(dummy_batch)
        # Add new columns automatically to (loss) input_dict.
        for key in dummy_batch.added_keys:
            if key not in self._input_dict:
                self._input_dict[key] = get_placeholder(value=dummy_batch[key],
                                                        name=key)
            if key not in self.view_requirements:
                self.view_requirements[key] = ViewRequirement(
                    space=gym.spaces.Box(
                        -1.0,
                        1.0,
                        shape=dummy_batch[key].shape[1:],
                        dtype=dummy_batch[key].dtype,
                    ),
                    used_for_compute_actions=False,
                )

        train_batch = SampleBatch(
            dict(self._input_dict, **self._loss_input_dict),
            _is_training=True,
        )

        if self._state_inputs:
            train_batch[SampleBatch.SEQ_LENS] = self._seq_lens
            self._loss_input_dict.update(
                {SampleBatch.SEQ_LENS: train_batch[SampleBatch.SEQ_LENS]})

        self._loss_input_dict.update({k: v for k, v in train_batch.items()})

        if log_once("loss_init"):
            logger.debug(
                "Initializing loss function with dummy input:\n\n{}\n".format(
                    summarize(train_batch)))

        losses = self._do_loss_init(train_batch)

        all_accessed_keys = (train_batch.accessed_keys
                             | dummy_batch.accessed_keys
                             | dummy_batch.added_keys
                             | set(self.model.view_requirements.keys()))

        TFPolicy._initialize_loss(
            self,
            losses,
            [(k, v)
             for k, v in train_batch.items() if k in all_accessed_keys] + ([
                 (SampleBatch.SEQ_LENS, train_batch[SampleBatch.SEQ_LENS])
             ] if SampleBatch.SEQ_LENS in train_batch else []),
        )

        if "is_training" in self._loss_input_dict:
            del self._loss_input_dict["is_training"]

        # Call the grads stats fn.
        # TODO: (sven) rename to simply stats_fn to match eager and torch.
        self._stats_fetches.update(self.grad_stats_fn(train_batch,
                                                      self._grads))

        # Add new columns automatically to view-reqs.
        if auto_remove_unneeded_view_reqs:
            # Add those needed for postprocessing and training.
            all_accessed_keys = train_batch.accessed_keys | dummy_batch.accessed_keys
            # Tag those only needed for post-processing (with some exceptions).
            for key in dummy_batch.accessed_keys:
                if (key not in train_batch.accessed_keys
                        and key not in self.model.view_requirements
                        and key not in [
                            SampleBatch.EPS_ID,
                            SampleBatch.AGENT_INDEX,
                            SampleBatch.UNROLL_ID,
                            SampleBatch.DONES,
                            SampleBatch.REWARDS,
                            SampleBatch.INFOS,
                            SampleBatch.OBS_EMBEDS,
                        ]):
                    if key in self.view_requirements:
                        self.view_requirements[key].used_for_training = False
                    if key in self._loss_input_dict:
                        del self._loss_input_dict[key]
            # Remove those not needed at all (leave those that are needed
            # by Sampler to properly execute sample collection).
            # Also always leave DONES, REWARDS, and INFOS, no matter what.
            for key in list(self.view_requirements.keys()):
                if (key not in all_accessed_keys and key not in [
                        SampleBatch.EPS_ID,
                        SampleBatch.AGENT_INDEX,
                        SampleBatch.UNROLL_ID,
                        SampleBatch.DONES,
                        SampleBatch.REWARDS,
                        SampleBatch.INFOS,
                ] and key not in self.model.view_requirements):
                    # If user deleted this key manually in postprocessing
                    # fn, warn about it and do not remove from
                    # view-requirements.
                    if key in dummy_batch.deleted_keys:
                        logger.warning(
                            "SampleBatch key '{}' was deleted manually in "
                            "postprocessing function! RLlib will "
                            "automatically remove non-used items from the "
                            "data stream. Remove the `del` from your "
                            "postprocessing function.".format(key))
                    # If we are not writing output to disk, it is safe to
                    # erase this key to save space in the sample batch.
                    elif self.config["output"] is None:
                        del self.view_requirements[key]

                    if key in self._loss_input_dict:
                        del self._loss_input_dict[key]
            # Add back any data_cols that are missing, but that other
            # view_cols depend on.
            for key in list(self.view_requirements.keys()):
                vr = self.view_requirements[key]
                if (vr.data_col is not None
                        and vr.data_col not in self.view_requirements):
                    used_for_training = vr.data_col in train_batch.accessed_keys
                    self.view_requirements[vr.data_col] = ViewRequirement(
                        space=vr.space, used_for_training=used_for_training)

        self._loss_input_dict_no_rnn = {
            k: v
            for k, v in self._loss_input_dict.items()
            if (v not in self._state_inputs and v != self._seq_lens)
        }
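The dummy batch used above is just a fixed-size array per view column, shaped like that column's space. A rough, self-contained sketch of what such a batch looks like for a plain Box space (RLlib's get_dummy_batch_for_space handles many more cases; dummy_batch_for_box is a hypothetical helper):

import numpy as np
import gym

def dummy_batch_for_box(space, batch_size=32):
    # Zero-filled batch: batch dim followed by the space's own shape.
    return np.zeros((batch_size,) + space.shape, dtype=space.dtype)

obs_space = gym.spaces.Box(-1.0, 1.0, shape=(4,), dtype=np.float32)
dummy_obs = dummy_batch_for_box(obs_space)
assert dummy_obs.shape == (32, 4)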
Example #4
    def _create_input_dict_and_dummy_batch(self, view_requirements,
                                           existing_inputs):
        """Creates input_dict and dummy_batch for loss initialization.

        Used for managing the Policy's input placeholders and for loss
        initialization.
        input_dict: str -> tf.placeholder; dummy_batch: str -> np.ndarray.

        Args:
            view_requirements: The view requirements dict.
            existing_inputs (Dict[str, tf.placeholder]): A dict of already
                existing placeholders.

        Returns:
            Tuple[Dict[str, tf.placeholder], Dict[str, np.ndarray]]: The
                input_dict/dummy_batch tuple.
        """
        input_dict = {}
        for view_col, view_req in view_requirements.items():
            # Point state_in to the already existing self._state_inputs.
            mo = re.match(r"state_in_(\d+)", view_col)
            if mo is not None:
                input_dict[view_col] = self._state_inputs[int(mo.group(1))]
            # State-outs (no placeholders needed).
            elif view_col.startswith("state_out_"):
                continue
            # Skip action dist inputs placeholder (do later).
            elif view_col == SampleBatch.ACTION_DIST_INPUTS:
                continue
            # This is a tower: Input placeholders already exist.
            elif view_col in existing_inputs:
                input_dict[view_col] = existing_inputs[view_col]
            # All others.
            else:
                time_axis = not isinstance(view_req.shift, int)
                if view_req.used_for_training:
                    # Create a +time-axis placeholder if the shift is not an
                    # int (range or list of ints).
                    # Do not flatten actions if action flattening is disabled.
                    if (self.config.get("_disable_action_flattening")
                            and view_col in [
                                SampleBatch.ACTIONS,
                                SampleBatch.PREV_ACTIONS,
                            ]):
                        flatten = False
                    # Do not flatten observations if no preprocessor API used.
                    elif (view_col in [SampleBatch.OBS, SampleBatch.NEXT_OBS]
                          and self.config["_disable_preprocessor_api"]):
                        flatten = False
                    # Flatten everything else.
                    else:
                        flatten = True
                    input_dict[view_col] = get_placeholder(
                        space=view_req.space,
                        name=view_col,
                        time_axis=time_axis,
                        flatten=flatten,
                    )
        dummy_batch = self._get_dummy_batch_from_view_requirements(
            batch_size=32)

        return SampleBatch(input_dict, seq_lens=self._seq_lens), dummy_batch
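The "state_in_<index>" convention above relies on a small regex to recover the index of each recurrent-state column; columns that do not match fall through to the other branches. A tiny illustration of that matching:

import re

for col in ["state_in_0", "state_in_12", "state_out_0", "obs"]:
    mo = re.match(r"state_in_(\d+)", col)
    # Only "state_in_*" columns yield an index; the others return None.
    print(col, "->", int(mo.group(1)) if mo else None)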
Example #5
    def __init__(self,
                 action_space: Space,
                 *,
                 framework: str,
                 model: ModelV2,
                 embeds_dim: int = 128,
                 encoder_net_config: Optional[ModelConfigDict] = None,
                 beta: float = 0.2,
                 beta_schedule: str = "constant",
                 rho: float = 0.1,
                 k_nn: int = 50,
                 random_timesteps: int = 10000,
                 sub_exploration: Optional[FromConfigSpec] = None,
                 **kwargs):
        """Initialize RE3.

        Args:
            action_space: The action space in which to explore.
            framework: Supports "tf" only; this implementation does not
                support torch.
            model: The policy's model.
            embeds_dim: The dimensionality of the observation embedding
                vectors in latent space.
            encoder_net_config: Optional model
                configuration for the encoder network, producing embedding
                vectors from observations. This can be used to configure
                fcnet- or conv_net setups to properly process any
                observation space.
            beta: Hyperparameter to choose between exploration and
                exploitation.
            beta_schedule: Schedule to use for beta decay, one of
                "constant" or "linear_decay".
            rho: Beta decay factor, used for on-policy algorithms.
            k_nn: Number of neighbours to set for K-NN entropy
                estimation.
            random_timesteps: The number of timesteps to act completely
                randomly (see [1]).
            sub_exploration: The config dict for the underlying Exploration
                to use (e.g. epsilon-greedy for DQN). If None, uses the
                FromConfigSpec provided in the Policy's default config.

        Raises:
            ValueError: If the input framework is Torch.
        """
        # TODO(gjoliver): Add support for PyTorch.
        if framework == "torch":
            raise ValueError("This RE3 implementation does not support Torch.")
        super().__init__(action_space,
                         model=model,
                         framework=framework,
                         **kwargs)

        self.beta = beta
        self.rho = rho
        self.k_nn = k_nn
        self.embeds_dim = embeds_dim
        if encoder_net_config is None:
            encoder_net_config = self.policy_config["model"].copy()
        self.encoder_net_config = encoder_net_config

        # Auto-detection of underlying exploration functionality.
        if sub_exploration is None:
            # For discrete action spaces, use an underlying EpsilonGreedy with
            # a special schedule.
            if isinstance(self.action_space, Discrete):
                sub_exploration = {
                    "type": "EpsilonGreedy",
                    "epsilon_schedule": {
                        "type":
                        "PiecewiseSchedule",
                        # Step function (see [2]).
                        "endpoints": [
                            (0, 1.0),
                            (random_timesteps + 1, 1.0),
                            (random_timesteps + 2, 0.01),
                        ],
                        "outside_value":
                        0.01,
                    },
                }
            elif isinstance(self.action_space, Box):
                sub_exploration = {
                    "type": "OrnsteinUhlenbeckNoise",
                    "random_timesteps": random_timesteps,
                }
            else:
                raise NotImplementedError

        self.sub_exploration = sub_exploration

        # Creates ModelV2 embedding module / layers.
        self._encoder_net = ModelCatalog.get_model_v2(
            self.model.obs_space,
            self.action_space,
            self.embeds_dim,
            model_config=self.encoder_net_config,
            framework=self.framework,
            name="encoder_net",
        )
        if self.framework == "tf":
            self._obs_ph = get_placeholder(space=self.model.obs_space,
                                           name="_encoder_obs")
            self._obs_embeds = tf.stop_gradient(
                self._encoder_net({SampleBatch.OBS: self._obs_ph})[0])

        # This is only used to select the correct action.
        self.exploration_submodule = from_config(
            cls=Exploration,
            config=self.sub_exploration,
            action_space=self.action_space,
            framework=self.framework,
            policy_config=self.policy_config,
            model=self.model,
            num_workers=self.num_workers,
            worker_index=self.worker_index,
        )
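The epsilon schedule configured above behaves as a step function: epsilon stays at 1.0 for roughly the first random_timesteps steps and then drops to 0.01. A minimal sketch of the resulting behavior (an illustration only, not RLlib's PiecewiseSchedule class):

def step_epsilon(t, random_timesteps=10000):
    # Fully random actions at first, then a small constant exploration rate.
    return 1.0 if t <= random_timesteps + 1 else 0.01

for t in (0, 5000, 10001, 10002, 50000):
    print(t, step_epsilon(t))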