Example #1
 def _lazy_tensor_dict(self, postprocessed_batch):
     train_batch = UsageTrackingDict(postprocessed_batch)
     train_batch.set_get_interceptor(self._convert_to_tensor)
     return train_batch
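A note on the pattern: UsageTrackingDict records which keys get read and pipes each value through the registered interceptor on access, which is what lets the loss-init code in the later examples discover its own inputs. A minimal sketch of that behavior (a simplified stand-in, not RLlib's actual class):

import numpy as np

class MiniUsageTrackingDict(dict):
    """Simplified stand-in for RLlib's UsageTrackingDict."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.accessed_keys = set()
        self._interceptor = None

    def set_get_interceptor(self, fn):
        self._interceptor = fn

    def __getitem__(self, key):
        # Record the access, then lazily convert the value on the way out.
        self.accessed_keys.add(key)
        value = super().__getitem__(key)
        return self._interceptor(value) if self._interceptor else value

batch = MiniUsageTrackingDict({"obs": np.zeros((4, 3), dtype=np.float64)})
batch.set_get_interceptor(lambda arr: arr.astype(np.float32))
_ = batch["obs"]                        # converted on access
assert batch.accessed_keys == {"obs"}   # and the read was recorded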
Example #2
    def _initialize_loss(self):
        def fake_array(tensor):
            shape = tensor.shape.as_list()
            shape[0] = 1
            return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)

        dummy_batch = {
            SampleBatch.CUR_OBS: fake_array(self._obs_input),
            SampleBatch.NEXT_OBS: fake_array(self._obs_input),
            SampleBatch.DONES: np.array([False], dtype=np.bool),
            SampleBatch.ACTIONS: fake_array(
                ModelCatalog.get_action_placeholder(self.action_space)),
            SampleBatch.REWARDS: np.array([0], dtype=np.float32),
        }
        if self._obs_include_prev_action_reward:
            dummy_batch.update({
                SampleBatch.PREV_ACTIONS: fake_array(self._prev_action_input),
                SampleBatch.PREV_REWARDS: fake_array(self._prev_reward_input),
            })
        state_init = self.get_initial_state()
        for i, h in enumerate(state_init):
            dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0)
            dummy_batch["state_out_{}".format(i)] = np.expand_dims(h, 0)
        if state_init:
            dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)
        for k, v in self.extra_compute_action_fetches().items():
            dummy_batch[k] = fake_array(v)

        # postprocessing might depend on variable init, so run it first here
        self._sess.run(tf.global_variables_initializer())
        postprocessed_batch = self.postprocess_trajectory(
            SampleBatch(dummy_batch))

        if self._obs_include_prev_action_reward:
            batch_tensors = UsageTrackingDict({
                SampleBatch.PREV_ACTIONS: self._prev_action_input,
                SampleBatch.PREV_REWARDS: self._prev_reward_input,
                SampleBatch.CUR_OBS: self._obs_input,
            })
            loss_inputs = [
                (SampleBatch.PREV_ACTIONS, self._prev_action_input),
                (SampleBatch.PREV_REWARDS, self._prev_reward_input),
                (SampleBatch.CUR_OBS, self._obs_input),
            ]
        else:
            batch_tensors = UsageTrackingDict({
                SampleBatch.CUR_OBS: self._obs_input,
            })
            loss_inputs = [
                (SampleBatch.CUR_OBS, self._obs_input),
            ]

        for k, v in postprocessed_batch.items():
            if k in batch_tensors:
                continue
            elif v.dtype == np.object:
                continue  # can't handle arbitrary objects in TF
            shape = (None, ) + v.shape[1:]
            dtype = np.float32 if v.dtype == np.float64 else v.dtype
            placeholder = tf.placeholder(dtype, shape=shape, name=k)
            batch_tensors[k] = placeholder

        if log_once("loss_init"):
            logger.debug(
                "Initializing loss function with dummy input:\n\n{}\n".format(
                    summarize(batch_tensors)))

        loss = self._do_loss_init(batch_tensors)
        for k in sorted(batch_tensors.accessed_keys):
            loss_inputs.append((k, batch_tensors[k]))

        TFPolicy._initialize_loss(self, loss, loss_inputs)
        if self._grad_stats_fn:
            self._stats_fetches.update(self._grad_stats_fn(self, self._grads))
        self._sess.run(tf.global_variables_initializer())
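The per-column loop above turns every numpy column of the postprocessed batch into a graph placeholder with an open batch dimension, downcasting float64 to float32. The same mapping in isolation (a sketch assuming TF1-style graph mode via tensorflow.compat.v1):

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

v = np.zeros((1, 5), dtype=np.float64)  # one dummy postprocessed column
shape = (None,) + v.shape[1:]           # leave the batch dimension open
dtype = np.float32 if v.dtype == np.float64 else v.dtype
ph = tf.placeholder(dtype, shape=shape, name="advantages")
print(ph.shape, ph.dtype)               # (None, 5) <dtype: 'float32'>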
Example #3
    def _initialize_loss_from_dummy_batch(
        self,
        auto_remove_unneeded_view_reqs: bool = True,
        stats_fn=None,
    ) -> None:
        """Performs test calls through policy's model and loss.

        NOTE: This base method should work for define-by-run Policies such as
        torch and tf-eager policies.

        If required, this will automatically detect which data views are
        required by a) the forward pass, b) the postprocessing, and c) the
        loss functions, and remove those entries from self.view_requirements
        that are not necessary for these computations (to save data storage
        and transfer).

        Args:
            auto_remove_unneeded_view_reqs (bool): Whether to automatically
                remove those ViewRequirements records from
                self.view_requirements that are not needed.
            stats_fn (Optional[Callable[[Policy, SampleBatch], Dict[str,
                TensorType]]]): An optional stats function to be called after
                the loss.
        """
        sample_batch_size = max(self.batch_divisibility_req * 4, 32)
        self._dummy_batch = self._get_dummy_batch_from_view_requirements(
            sample_batch_size)
        input_dict = self._lazy_tensor_dict(self._dummy_batch)
        actions, state_outs, extra_outs = \
            self.compute_actions_from_input_dict(input_dict, explore=False)
        # Add all extra action outputs to view requirements (these may be
        # filtered out later again, if not needed for postprocessing or loss).
        for key, value in extra_outs.items():
            self._dummy_batch[key] = np.zeros_like(value)
            if key not in self.view_requirements:
                self.view_requirements[key] = \
                    ViewRequirement(space=gym.spaces.Box(
                        -1.0, 1.0, shape=value.shape[1:], dtype=value.dtype))
        batch_for_postproc = UsageTrackingDict(self._dummy_batch)
        batch_for_postproc.count = self._dummy_batch.count
        postprocessed_batch = self.postprocess_trajectory(batch_for_postproc)
        if state_outs:
            B = 4  # For RNNs, have B=4, T=[depends on sample_batch_size]
            # TODO: (sven) This hack will not work for attention net traj.
            #  view setup.
            i = 0
            while "state_in_{}".format(i) in postprocessed_batch:
                postprocessed_batch["state_in_{}".format(i)] = \
                    postprocessed_batch["state_in_{}".format(i)][:B]
                if "state_out_{}".format(i) in postprocessed_batch:
                    postprocessed_batch["state_out_{}".format(i)] = \
                        postprocessed_batch["state_out_{}".format(i)][:B]
                i += 1
            seq_len = sample_batch_size // B
            postprocessed_batch["seq_lens"] = \
                np.array([seq_len for _ in range(B)], dtype=np.int32)
        # Remove the UsageTrackingDict wrap to prep for wrapping the
        # train batch with a to-tensor UsageTrackingDict.
        train_batch = {k: v for k, v in postprocessed_batch.items()}
        train_batch = self._lazy_tensor_dict(train_batch)
        train_batch.count = self._dummy_batch.count
        # Call the loss function, if it exists.
        if self._loss is not None:
            self._loss(self, self.model, self.dist_class, train_batch)
        # Call the stats fn, if given.
        if stats_fn is not None:
            stats_fn(self, train_batch)

        # Add new columns automatically to view-reqs.
        if self.config["_use_trajectory_view_api"] and \
                auto_remove_unneeded_view_reqs:
            # Add those needed for postprocessing and training.
            all_accessed_keys = train_batch.accessed_keys | \
                                batch_for_postproc.accessed_keys | \
                                batch_for_postproc.added_keys
            for key in all_accessed_keys:
                if key not in self.view_requirements:
                    self.view_requirements[key] = ViewRequirement()
            if self._loss:
                # Tag those only needed for post-processing.
                for key in batch_for_postproc.accessed_keys:
                    if key not in train_batch.accessed_keys and \
                            key in self.view_requirements:
                        self.view_requirements[key].used_for_training = False
                # Remove those not needed at all (leave those that are needed
                # by Sampler to properly execute sample collection).
                # Also always leave DONES and REWARDS, no matter what.
                for key in list(self.view_requirements.keys()):
                    if key not in all_accessed_keys and key not in [
                        SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
                        SampleBatch.UNROLL_ID, SampleBatch.DONES,
                        SampleBatch.REWARDS] and \
                            key not in self.model.inference_view_requirements:
                        # If user deleted this key manually in postprocessing
                        # fn, warn about it and do not remove from
                        # view-requirements.
                        if key in batch_for_postproc.deleted_keys:
                            logger.warning(
                                "SampleBatch key '{}' was deleted manually in "
                                "postprocessing function! RLlib will "
                                "automatically remove non-used items from the "
                                "data stream. Remove the `del` from your "
                                "postprocessing function.".format(key))
                        else:
                            del self.view_requirements[key]
            # Add those data_cols (again) that are missing and have
            # dependencies by view_cols.
            for key in list(self.view_requirements.keys()):
                vr = self.view_requirements[key]
                if vr.data_col is not None and \
                        vr.data_col not in self.view_requirements:
                    used_for_training = \
                        vr.data_col in train_batch.accessed_keys
                    self.view_requirements[vr.data_col] = \
                        ViewRequirement(
                            space=vr.space,
                            used_for_training=used_for_training)
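For the RNN branch above, the arithmetic with the default settings: sample_batch_size comes out to 32, the state columns are sliced to B = 4 rows, and seq_lens therefore holds four sequences of length 8. A worked sketch of just that math (not RLlib code):

import numpy as np

batch_divisibility_req = 1
sample_batch_size = max(batch_divisibility_req * 4, 32)  # -> 32
B = 4                                 # rows kept for the RNN test pass
seq_len = sample_batch_size // B      # -> 8 timesteps per row
seq_lens = np.array([seq_len] * B, dtype=np.int32)
print(sample_batch_size, seq_lens)    # 32 [8 8 8 8]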
Example #4
    def _initialize_loss_from_dummy_batch(
            self,
            auto_remove_unneeded_view_reqs: bool = True,
            stats_fn=None) -> None:

        # Test calls depend on variable init, so initialize model first.
        self._sess.run(tf1.global_variables_initializer())

        if self.config["_use_trajectory_view_api"]:
            logger.info("Testing `compute_actions` w/ dummy batch.")
            actions, state_outs, extra_fetches = \
                self.compute_actions_from_input_dict(
                    self._dummy_batch, explore=False, timestep=0)
            for key, value in extra_fetches.items():
                self._dummy_batch[key] = np.zeros_like(value)
                self._input_dict[key] = get_placeholder(value=value, name=key)
                if key not in self.view_requirements:
                    logger.info("Adding extra-action-fetch `{}` to "
                                "view-reqs.".format(key))
                    self.view_requirements[key] = \
                        ViewRequirement(space=gym.spaces.Box(
                            -1.0, 1.0, shape=value.shape[1:],
                            dtype=value.dtype))
            dummy_batch = self._dummy_batch
        else:

            def fake_array(tensor):
                shape = tensor.shape.as_list()
                shape = [s if s is not None else 1 for s in shape]
                return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)

            dummy_batch = {
                SampleBatch.CUR_OBS:
                fake_array(self._obs_input),
                SampleBatch.NEXT_OBS:
                fake_array(self._obs_input),
                SampleBatch.DONES:
                np.array([False], dtype=np.bool),
                SampleBatch.ACTIONS:
                fake_array(
                    ModelCatalog.get_action_placeholder(self.action_space)),
                SampleBatch.REWARDS:
                np.array([0], dtype=np.float32),
            }
            if self._obs_include_prev_action_reward:
                dummy_batch.update({
                    SampleBatch.PREV_ACTIONS:
                    fake_array(self._prev_action_input),
                    SampleBatch.PREV_REWARDS:
                    fake_array(self._prev_reward_input),
                })
            state_init = self.get_initial_state()
            state_batches = []
            for i, h in enumerate(state_init):
                dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0)
                dummy_batch["state_out_{}".format(i)] = np.expand_dims(h, 0)
                state_batches.append(np.expand_dims(h, 0))
            if state_init:
                dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)
            for k, v in self.extra_compute_action_fetches().items():
                dummy_batch[k] = fake_array(v)

        sb = SampleBatch(dummy_batch)
        batch_for_postproc = UsageTrackingDict(sb)
        batch_for_postproc.count = sb.count
        logger.info("Testing `postprocess_trajectory` w/ dummy batch.")
        postprocessed_batch = self.postprocess_trajectory(batch_for_postproc)
        # Add new columns automatically to (loss) input_dict.
        if self.config["_use_trajectory_view_api"]:
            for key in batch_for_postproc.added_keys:
                if key not in self._input_dict:
                    self._input_dict[key] = get_placeholder(
                        value=batch_for_postproc[key], name=key)
                if key not in self.view_requirements:
                    self.view_requirements[key] = \
                        ViewRequirement(space=gym.spaces.Box(
                            -1.0, 1.0, shape=batch_for_postproc[key].shape[1:],
                            dtype=batch_for_postproc[key].dtype))

        if not self.config["_use_trajectory_view_api"]:
            train_batch = UsageTrackingDict(
                dict({
                    SampleBatch.CUR_OBS: self._obs_input,
                }, **self._loss_input_dict))
            if self._obs_include_prev_action_reward:
                train_batch.update({
                    SampleBatch.PREV_ACTIONS:
                    self._prev_action_input,
                    SampleBatch.PREV_REWARDS:
                    self._prev_reward_input,
                })

            for k, v in postprocessed_batch.items():
                if k in train_batch:
                    continue
                elif v.dtype == np.object:
                    continue  # can't handle arbitrary objects in TF
                elif k == "seq_lens" or k.startswith("state_in_"):
                    continue
                shape = (None, ) + v.shape[1:]
                dtype = np.float32 if v.dtype == np.float64 else v.dtype
                placeholder = tf1.placeholder(dtype, shape=shape, name=k)
                train_batch[k] = placeholder

            for i, si in enumerate(self._state_inputs):
                train_batch["state_in_{}".format(i)] = si
        else:
            train_batch = UsageTrackingDict(self._input_dict)

        if self._state_inputs:
            train_batch["seq_lens"] = self._seq_lens

        if log_once("loss_init"):
            logger.debug(
                "Initializing loss function with dummy input:\n\n{}\n".format(
                    summarize(train_batch)))

        self._loss_input_dict.update({k: v for k, v in train_batch.items()})
        loss = self._do_loss_init(train_batch)

        all_accessed_keys = \
            train_batch.accessed_keys | batch_for_postproc.accessed_keys | \
            batch_for_postproc.added_keys | set(
                self.model.inference_view_requirements.keys())

        TFPolicy._initialize_loss(
            self, loss,
            [(k, v) for k, v in train_batch.items() if k in all_accessed_keys])

        if "is_training" in self._loss_input_dict:
            del self._loss_input_dict["is_training"]

        # Call the grads stats fn.
        # TODO: (sven) rename to simply stats_fn to match eager and torch.
        if self._grad_stats_fn:
            self._stats_fetches.update(
                self._grad_stats_fn(self, train_batch, self._grads))

        # Add new columns automatically to view-reqs.
        if self.config["_use_trajectory_view_api"] and \
                auto_remove_unneeded_view_reqs:
            # Add those needed for postprocessing and training.
            all_accessed_keys = train_batch.accessed_keys | \
                                batch_for_postproc.accessed_keys
            # Tag those only needed for post-processing.
            for key in batch_for_postproc.accessed_keys:
                if key not in train_batch.accessed_keys:
                    self.view_requirements[key].used_for_training = False
                    if key in self._loss_input_dict:
                        del self._loss_input_dict[key]
            # Remove those not needed at all (leave those that are needed
            # by Sampler to properly execute sample collection).
            # Also always leave DONES and REWARDS, no matter what.
            for key in list(self.view_requirements.keys()):
                if key not in all_accessed_keys and key not in [
                    SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
                    SampleBatch.UNROLL_ID, SampleBatch.DONES,
                    SampleBatch.REWARDS] and \
                        key not in self.model.inference_view_requirements:
                    # If user deleted this key manually in postprocessing
                    # fn, warn about it and do not remove from
                    # view-requirements.
                    if key in batch_for_postproc.deleted_keys:
                        logger.warning(
                            "SampleBatch key '{}' was deleted manually in "
                            "postprocessing function! RLlib will "
                            "automatically remove non-used items from the "
                            "data stream. Remove the `del` from your "
                            "postprocessing function.".format(key))
                    else:
                        del self.view_requirements[key]
                    if key in self._loss_input_dict:
                        del self._loss_input_dict[key]
            # Add those data_cols (again) that are missing and have
            # dependencies by view_cols.
            for key in list(self.view_requirements.keys()):
                vr = self.view_requirements[key]
                if (vr.data_col is not None
                        and vr.data_col not in self.view_requirements):
                    used_for_training = \
                        vr.data_col in train_batch.accessed_keys
                    self.view_requirements[vr.data_col] = ViewRequirement(
                        space=vr.space, used_for_training=used_for_training)

        self._loss_input_dict_no_rnn = {
            k: v
            for k, v in self._loss_input_dict.items()
            if (v not in self._state_inputs and v != self._seq_lens)
        }

        # Initialize again after loss init.
        self._sess.run(tf1.global_variables_initializer())
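The pruning block near the end keeps a view requirement only if some phase (forward pass, postprocessing, loss) accessed it, or if the sampler always needs it. A schematic of that filter with hypothetical names (the real code additionally consults self.model.inference_view_requirements and warns about keys the user deleted):

# Columns the sampler always needs, matching the SampleBatch constants
# used above ("eps_id", "agent_index", "unroll_id", "dones", "rewards").
PROTECTED = {"eps_id", "agent_index", "unroll_id", "dones", "rewards"}

def prune_view_requirements(view_reqs, all_accessed_keys):
    # Keep a column if something read it, or if sampling depends on it.
    return {
        key: vr
        for key, vr in view_reqs.items()
        if key in all_accessed_keys or key in PROTECTED
    }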
Example #5
 def _lazy_numpy_dict(self, postprocessed_batch):
     train_batch = UsageTrackingDict(postprocessed_batch)
     train_batch.set_get_interceptor(convert_to_non_tf_type)
     return train_batch
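Here the interceptor goes the other way: convert_to_non_tf_type turns (eager) TF tensors back into numpy values on access. A hedged sketch of such a converter (a simplified stand-in, not RLlib's implementation):

def to_numpy(value):
    # Recursively convert anything exposing .numpy() (e.g. eager TF
    # tensors) back to plain numpy; leave other values untouched.
    if isinstance(value, dict):
        return {k: to_numpy(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return type(value)(to_numpy(v) for v in value)
    return value.numpy() if hasattr(value, "numpy") else value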
Example #6
    def _initialize_loss(self):
        def fake_array(tensor):
            shape = tensor.shape.as_list()
            shape = [s if s is not None else 1 for s in shape]
            return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)

        dummy_batch = {
            SampleBatch.CUR_OBS:
            fake_array(self._obs_input),
            SampleBatch.NEXT_OBS:
            fake_array(self._obs_input),
            SampleBatch.DONES:
            np.array([False], dtype=np.bool),
            SampleBatch.ACTIONS:
            fake_array(ModelCatalog.get_action_placeholder(self.action_space)),
            SampleBatch.REWARDS:
            np.array([0], dtype=np.float32),
        }
        if self._obs_include_prev_action_reward:
            dummy_batch.update({
                SampleBatch.PREV_ACTIONS:
                fake_array(self._prev_action_input),
                SampleBatch.PREV_REWARDS:
                fake_array(self._prev_reward_input),
            })
        state_init = self.get_initial_state()
        state_batches = []
        for i, h in enumerate(state_init):
            dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0)
            dummy_batch["state_out_{}".format(i)] = np.expand_dims(h, 0)
            state_batches.append(np.expand_dims(h, 0))
        if state_init:
            dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)
        for k, v in self.extra_compute_action_fetches().items():
            dummy_batch[k] = fake_array(v)

        # postprocessing might depend on variable init, so run it first here
        self._sess.run(tf.global_variables_initializer())

        postprocessed_batch = self.postprocess_trajectory(
            SampleBatch(dummy_batch))

        # model forward pass for the loss (needed after postprocess to
        # overwrite any tensor state from that call)
        self.model(self._input_dict, self._state_in, self._seq_lens)

        if self._obs_include_prev_action_reward:
            train_batch = UsageTrackingDict({
                SampleBatch.PREV_ACTIONS:
                self._prev_action_input,
                SampleBatch.PREV_REWARDS:
                self._prev_reward_input,
                SampleBatch.CUR_OBS:
                self._obs_input,
            })
            loss_inputs = [
                (SampleBatch.PREV_ACTIONS, self._prev_action_input),
                (SampleBatch.PREV_REWARDS, self._prev_reward_input),
                (SampleBatch.CUR_OBS, self._obs_input),
            ]
        else:
            train_batch = UsageTrackingDict({
                SampleBatch.CUR_OBS:
                self._obs_input,
            })
            loss_inputs = [
                (SampleBatch.CUR_OBS, self._obs_input),
            ]

        for k, v in postprocessed_batch.items():
            if k in train_batch:
                continue
            elif v.dtype == np.object:
                continue  # can't handle arbitrary objects in TF
            elif k == "seq_lens" or k.startswith("state_in_"):
                continue
            shape = (None, ) + v.shape[1:]
            dtype = np.float32 if v.dtype == np.float64 else v.dtype
            placeholder = tf.placeholder(dtype, shape=shape, name=k)
            train_batch[k] = placeholder

        for i, si in enumerate(self._state_in):
            train_batch["state_in_{}".format(i)] = si
        train_batch["seq_lens"] = self._seq_lens

        if log_once("loss_init"):
            logger.debug(
                "Initializing loss function with dummy input:\n\n{}\n".format(
                    summarize(train_batch)))

        self._loss_input_dict = train_batch
        loss = self._do_loss_init(train_batch)
        for k in sorted(train_batch.accessed_keys):
            if k != "seq_lens" and not k.startswith("state_in_"):
                loss_inputs.append((k, train_batch[k]))

        TFPolicy._initialize_loss(self, loss, loss_inputs)
        if self._grad_stats_fn:
            self._stats_fetches.update(
                self._grad_stats_fn(self, train_batch, self._grads))
        self._sess.run(tf.global_variables_initializer())
Example #7
def make_batch(obs_space, action_space, batch_size=4):
    batch = UsageTrackingDict(
        fake_batch(obs_space, action_space, batch_size=batch_size))
    batch.set_get_interceptor(partial(convert_to_tensor, device="cpu"))
    return batch
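In tests, the returned dict hands out CPU torch tensors on access while still recording what the code under test touched. A usage sketch (the spaces are illustrative; fake_batch and convert_to_tensor are assumed from the surrounding test utilities):

import gym
import numpy as np

obs_space = gym.spaces.Box(-1.0, 1.0, shape=(4,), dtype=np.float32)
action_space = gym.spaces.Discrete(2)

batch = make_batch(obs_space, action_space, batch_size=4)
obs = batch["obs"]           # converted to a torch.Tensor on access
print(obs.device)            # cpu
print(batch.accessed_keys)   # {'obs'}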
Example #8
    def _initialize_loss(self):
        def fake_array(tensor):
            shape = tensor.shape.as_list()
            shape[0] = 1
            return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)

        dummy_batch = {
            SampleBatch.CUR_OBS:
            fake_array(self._obs_input),
            SampleBatch.NEXT_OBS:
            fake_array(self._obs_input),
            SampleBatch.DONES:
            np.array([False], dtype=np.bool),
            SampleBatch.ACTIONS:
            fake_array(ModelCatalog.get_action_placeholder(self.action_space)),
            SampleBatch.REWARDS:
            np.array([0], dtype=np.float32),
        }
        if self._obs_include_prev_action_reward:
            dummy_batch.update({
                SampleBatch.PREV_ACTIONS:
                fake_array(self._prev_action_input),
                SampleBatch.PREV_REWARDS:
                fake_array(self._prev_reward_input),
            })
        state_init = self.get_initial_state()
        for i, h in enumerate(state_init):
            dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0)
            dummy_batch["state_out_{}".format(i)] = np.expand_dims(h, 0)
        if state_init:
            dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)
        for k, v in self.extra_compute_action_fetches().items():
            dummy_batch[k] = fake_array(v)

        # postprocessing might depend on variable init, so run it first here
        self._sess.run(tf.global_variables_initializer())
        postprocessed_batch = self.postprocess_trajectory(
            SampleBatch(dummy_batch))

        if self._obs_include_prev_action_reward:
            batch_tensors = UsageTrackingDict({
                SampleBatch.PREV_ACTIONS:
                self._prev_action_input,
                SampleBatch.PREV_REWARDS:
                self._prev_reward_input,
                SampleBatch.CUR_OBS:
                self._obs_input,
            })
            loss_inputs = [
                (SampleBatch.PREV_ACTIONS, self._prev_action_input),
                (SampleBatch.PREV_REWARDS, self._prev_reward_input),
                (SampleBatch.CUR_OBS, self._obs_input),
            ]
        else:
            batch_tensors = UsageTrackingDict({
                SampleBatch.CUR_OBS:
                self._obs_input,
            })
            loss_inputs = [
                (SampleBatch.CUR_OBS, self._obs_input),
            ]

        for k, v in postprocessed_batch.items():
            if k in batch_tensors:
                continue
            elif v.dtype == np.object:
                continue  # can't handle arbitrary objects in TF
            shape = (None, ) + v.shape[1:]
            dtype = np.float32 if v.dtype == np.float64 else v.dtype
            placeholder = tf.placeholder(dtype, shape=shape, name=k)
            batch_tensors[k] = placeholder

        if log_once("loss_init"):
            logger.info(
                "Initializing loss function with dummy input:\n\n{}\n".format(
                    summarize(batch_tensors)))

        loss = self._do_loss_init(batch_tensors)
        for k in sorted(batch_tensors.accessed_keys):
            loss_inputs.append((k, batch_tensors[k]))

        # XXX experimental support for automatically eagerifying the loss.
        # The main limitation right now is that TF doesn't support mixing eager
        # and non-eager tensors, so losses that read non-eager tensors through
        # `policy` need to use `policy.convert_to_eager(tensor)`.
        if self.config["use_eager"]:
            if not self.model:
                raise ValueError("eager not implemented in this case")
            graph_tensors = list(self._needs_eager_conversion)

            def gen_loss(model_outputs, *args):
                # fill in the batch tensor dict with eager tensors
                eager_inputs = dict(
                    zip([k for (k, v) in loss_inputs],
                        args[:len(loss_inputs)]))
                # fill in the eager versions of all accessed graph tensors
                self._eager_tensors = dict(
                    zip(graph_tensors, args[len(loss_inputs):]))
                # patch the action dist to use eager mode tensors
                self.action_dist.inputs = model_outputs
                return self._loss_fn(self, eager_inputs)

            # TODO(ekl) also handle the stats funcs
            loss = tf.py_function(
                gen_loss,
                # cast works around TypeError: Cannot convert provided value
                # to EagerTensor. Provided value: 0.0 Requested dtype: int64
                [self.model.outputs] +
                [tf.cast(v, tf.float32) for (k, v) in loss_inputs] +
                [tf.cast(t, tf.float32) for t in graph_tensors],
                tf.float32)

        TFPolicy._initialize_loss(self, loss, loss_inputs)
        if self._grad_stats_fn:
            self._stats_fetches.update(self._grad_stats_fn(self, self._grads))
        self._sess.run(tf.global_variables_initializer())
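The experimental eager path funnels graph tensors through tf.py_function so a define-by-run loss can read them as EagerTensors. The bridging trick in isolation (a minimal sketch assuming eager-capable TF, not the RLlib wiring):

import tensorflow as tf

def eager_loss(x, y):
    # Executes eagerly even when invoked from graph code; x and y
    # arrive here as EagerTensors.
    return tf.reduce_mean(x * y)

x = tf.constant([1.0, 2.0])
y = tf.constant([3.0, 4.0])
loss = tf.py_function(eager_loss, [x, y], tf.float32)
print(float(loss))  # 5.5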
Example #9
 def _lazy_numpy_dict(self, postprocessed_batch):
     train_batch = UsageTrackingDict(postprocessed_batch)
     train_batch.set_get_interceptor(
         functools.partial(convert_to_non_torch_type))
     return train_batch
Example #10
 def _lazy_tensor_dict(self, postprocessed_batch):
     train_batch = UsageTrackingDict(postprocessed_batch)
     train_batch.set_get_interceptor(
         functools.partial(convert_to_torch_tensor, device=self.device))
     return train_batch
Example #11
 def _lazy_tensor_dict(self, postprocessed_batch):
     batch_tensors = UsageTrackingDict(postprocessed_batch)
     batch_tensors.set_get_interceptor(
         lambda arr: torch.from_numpy(arr).to(self.device))
     return batch_tensors
Example #12
 def _lazy_tensor_dict(self, batch):
     if not isinstance(batch, UsageTrackingDict):
         batch = UsageTrackingDict(batch)
     batch.set_get_interceptor(
         functools.partial(convert_to_torch_tensor, device=self.device))
     return batch
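Unlike the earlier variants, this one guards against double-wrapping: an already-wrapped batch only gets its interceptor refreshed. The guard in isolation (a simplified stand-in mirroring the isinstance check, not RLlib's actual classes):

class TrackingDict(dict):
    # Minimal stand-in with the same wrapping semantics.
    def set_get_interceptor(self, fn):
        self._interceptor = fn

def lazy_tensor_dict(batch, interceptor=lambda x: x):
    if not isinstance(batch, TrackingDict):
        batch = TrackingDict(batch)
    batch.set_get_interceptor(interceptor)
    return batch

b1 = lazy_tensor_dict({"obs": [1, 2, 3]})
b2 = lazy_tensor_dict(b1)  # already wrapped: reused, not re-wrapped
assert b1 is b2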
Example #13
    def _initialize_loss(self):
        def fake_array(tensor):
            shape = tensor.shape.as_list()
            shape = [s if s is not None else 1 for s in shape]
            return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)

        dummy_batch = {
            SampleBatch.CUR_OBS:
            fake_array(self._obs_input),
            SampleBatch.NEXT_OBS:
            fake_array(self._obs_input),
            SampleBatch.DONES:
            np.array([False], dtype=np.bool),
            SampleBatch.ACTIONS:
            fake_array(ModelCatalog.get_action_placeholder(self.action_space)),
            SampleBatch.REWARDS:
            np.array([0], dtype=np.float32),
        }

        # Add dummy things PENGZHENGHAO
        # for name, val in self.model.mask_placeholder_dict.items():
        #     shape = val.shape.as_list()
        #     shape = [1] + [s if s is not None else 1 for s in shape]
        #     dummy_batch[name] = \
        #         np.zeros(shape, dtype=val.dtype.as_numpy_dtype)

        if self._obs_include_prev_action_reward:
            dummy_batch.update({
                SampleBatch.PREV_ACTIONS:
                fake_array(self._prev_action_input),
                SampleBatch.PREV_REWARDS:
                fake_array(self._prev_reward_input),
            })
        state_init = self.get_initial_state()
        state_batches = []
        for i, h in enumerate(state_init):
            dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0)
            dummy_batch["state_out_{}".format(i)] = np.expand_dims(h, 0)
            state_batches.append(np.expand_dims(h, 0))
        if state_init:
            dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)

        for k, v in self.extra_compute_action_fetches().items():
            dummy_batch[k] = fake_array(v)

        # postprocessing might depend on variable init, so run it first here
        self._sess.run(tf.global_variables_initializer())

        postprocessed_batch = self.postprocess_trajectory(
            SampleBatch(dummy_batch))

        # model forward pass for the loss (needed after postprocess to
        # overwrite any tensor state from that call)
        self.model(self._input_dict, self._state_in, self._seq_lens)

        if self._obs_include_prev_action_reward:
            train_batch = UsageTrackingDict({
                SampleBatch.PREV_ACTIONS:
                self._prev_action_input,
                SampleBatch.PREV_REWARDS:
                self._prev_reward_input,
                SampleBatch.CUR_OBS:
                self._obs_input,
            })
            loss_inputs = [
                (SampleBatch.PREV_ACTIONS, self._prev_action_input),
                (SampleBatch.PREV_REWARDS, self._prev_reward_input),
                (SampleBatch.CUR_OBS, self._obs_input),
            ]
        else:
            train_batch = UsageTrackingDict({
                SampleBatch.CUR_OBS:
                self._obs_input,
            })
            loss_inputs = [
                (SampleBatch.CUR_OBS, self._obs_input),
            ]

        # When using the mask, the keys of postprocessed_batch are:
        # dict_keys(['obs', 'new_obs', 'dones', 'actions', 'rewards',
        # 'fc_1_mask', 'fc_2_mask', 'prev_actions', 'prev_rewards',
        # 'action_prob', 'action_logp', 'vf_preds', 'behaviour_logits',
        # 'layer0', 'layer1', 'advantages', 'value_targets'])
        # (In the no-mask experiment the recorded keys came out identical.)
        for k, v in postprocessed_batch.items():
            if k in train_batch:
                continue
            elif v.dtype == np.object:
                continue  # can't handle arbitrary objects in TF
            elif k == "seq_lens" or k.startswith("state_in_"):
                continue
            shape = (None, ) + v.shape[1:]
            dtype = np.float32 if v.dtype == np.float64 else v.dtype
            placeholder = tf.placeholder(dtype, shape=shape, name=k)
            train_batch[k] = placeholder

        # When using the mask, the train_batch at this point contains 17
        # entries:
        # <class 'list'>: ['prev_actions', 'prev_rewards', 'obs', 'new_obs',
        # 'dones', 'actions', 'rewards', 'fc_1_mask', 'fc_2_mask',
        # 'action_prob', 'action_logp', 'vf_preds', 'behaviour_logits',
        # 'layer0', 'layer1', 'advantages', 'value_targets']
        for i, si in enumerate(self._state_in):
            train_batch["state_in_{}".format(i)] = si
        train_batch["seq_lens"] = self._seq_lens

        if log_once("loss_init"):
            logger.debug(
                "Initializing loss function with dummy input:\n\n{}\n".format(
                    summarize(train_batch)))

        self._loss_input_dict = train_batch
        # At this point, accessed_keys is:
        # {'obs', 'prev_rewards', 'value_targets', 'behaviour_logits',
        # 'prev_actions', 'advantages', 'action_logp', 'actions',
        # 'vf_preds', 'accessed_keys', 'intercepted_values'}

        # In the no-mask experiment, however, accessed_keys is only:
        # {'intercepted_values', 'accessed_keys'}

        loss = self._do_loss_init(train_batch)
        # After the line above, accessed_keys is:
        # {'advantages', 'action_logp', 'behaviour_logits', 'prev_rewards',
        # 'prev_actions', 'vf_preds', 'actions', 'value_targets', 'obs'}

        # In the no-mask experiment the line above yields the same set,
        # just enumerated in a different order:
        # {'action_logp', 'prev_actions', 'behaviour_logits',
        # 'value_targets', 'obs', 'prev_rewards', 'advantages', 'vf_preds',
        # 'actions'}

        # At this point, loss_inputs already contains: prev_actions,
        # prev_rewards, obs.
        for k in sorted(train_batch.accessed_keys):
            # sorted train_batch.accessed_keys: <class 'list'>: [
            # 'action_logp', 'actions', 'advantages', 'behaviour_logits',
            # 'obs', 'prev_actions', 'prev_rewards', 'value_targets',
            # 'vf_preds']
            if k != "seq_lens" and not k.startswith("state_in_"):
                loss_inputs.append((k, train_batch[k]))

        # PENGZHENGHAO
        # for name, ph in self.model.mask_placeholder_dict.items():
        #     loss_inputs.append((name, ph))

        TFPolicy._initialize_loss(self, loss, loss_inputs)
        if self._grad_stats_fn:
            self._stats_fetches.update(
                self._grad_stats_fn(self, train_batch, self._grads))
        self._sess.run(tf.global_variables_initializer())