Example #1
 def extra_compute_action_feed_dict(self):
     if extra_action_feed_fn:
         return extra_action_feed_fn(self)
     else:
         return TFPolicy.extra_compute_action_feed_dict(self)
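`extra_compute_action_feed_dict` is the TFPolicy hook for feeding extra placeholders into the action-computation session call. A minimal sketch of a direct override, assuming a hypothetical exploration placeholder created elsewhere in the policy (`self._epsilon_ph` and `self.cur_epsilon` are illustrative names, not RLlib API):

    def extra_compute_action_feed_dict(self):
        # Feed the current epsilon value into a placeholder created in __init__.
        # Both names below are hypothetical and shown only for illustration.
        return {self._epsilon_ph: self.cur_epsilon}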
Example #2
 def extra_compute_grad_feed_dict(self):
     if extra_learn_feed_fn:
         return extra_learn_feed_fn(self)
     else:
         return TFPolicy.extra_compute_grad_feed_dict(self)
Example #3
 def gradients(self, optimizer, loss):
     if gradients_fn:
         return gradients_fn(self, optimizer, loss)
     else:
         return TFPolicy.gradients(self, optimizer, loss)
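A common use of the `gradients_fn` hook is to clip gradients before they are applied. A minimal sketch, assuming the policy config carries a `grad_clip` value (that key name is an assumption here):

    import tensorflow as tf

    # Sketch: compute gradients via the optimizer, then clip by global norm.
    # `grad_clip` is an assumed config key used only for illustration.
    def clipped_gradients(policy, optimizer, loss):
        grads_and_vars = [(g, v) for g, v in optimizer.compute_gradients(loss)
                          if g is not None]
        grads, variables = zip(*grads_and_vars)
        clipped_grads, _ = tf.clip_by_global_norm(
            list(grads), policy.config["grad_clip"])
        return list(zip(clipped_grads, variables))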
Example #4
 def extra_compute_action_fetches(self):
     return dict(TFPolicy.extra_compute_action_fetches(self),
                 **self._extra_action_fetches)
Example #5
    def _initialize_loss(self):
        def fake_array(tensor):
            shape = tensor.shape.as_list()
            shape = [s if s is not None else 1 for s in shape]
            return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)

        dummy_batch = {
            SampleBatch.CUR_OBS:
            fake_array(self._obs_input),
            SampleBatch.NEXT_OBS:
            fake_array(self._obs_input),
            SampleBatch.DONES:
            np.array([False], dtype=np.bool),
            SampleBatch.ACTIONS:
            fake_array(ModelCatalog.get_action_placeholder(self.action_space)),
            SampleBatch.REWARDS:
            np.array([0], dtype=np.float32),
        }
        if self._obs_include_prev_action_reward:
            dummy_batch.update({
                SampleBatch.PREV_ACTIONS:
                fake_array(self._prev_action_input),
                SampleBatch.PREV_REWARDS:
                fake_array(self._prev_reward_input),
            })
        state_init = self.get_initial_state()
        state_batches = []
        for i, h in enumerate(state_init):
            dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0)
            dummy_batch["state_out_{}".format(i)] = np.expand_dims(h, 0)
            state_batches.append(np.expand_dims(h, 0))
        if state_init:
            dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)
        for k, v in self.extra_compute_action_fetches().items():
            dummy_batch[k] = fake_array(v)

        # postprocessing might depend on variable init, so run it first here
        self._sess.run(tf.global_variables_initializer())

        postprocessed_batch = self.postprocess_trajectory(
            SampleBatch(dummy_batch))

        # model forward pass for the loss (needed after postprocess to
        # overwrite any tensor state from that call)
        self.model(self._input_dict, self._state_in, self._seq_lens)

        if self._obs_include_prev_action_reward:
            train_batch = UsageTrackingDict({
                SampleBatch.PREV_ACTIONS:
                self._prev_action_input,
                SampleBatch.PREV_REWARDS:
                self._prev_reward_input,
                SampleBatch.CUR_OBS:
                self._obs_input,
            })
            loss_inputs = [
                (SampleBatch.PREV_ACTIONS, self._prev_action_input),
                (SampleBatch.PREV_REWARDS, self._prev_reward_input),
                (SampleBatch.CUR_OBS, self._obs_input),
            ]
        else:
            train_batch = UsageTrackingDict({
                SampleBatch.CUR_OBS:
                self._obs_input,
            })
            loss_inputs = [
                (SampleBatch.CUR_OBS, self._obs_input),
            ]

        for k, v in postprocessed_batch.items():
            if k in train_batch:
                continue
            elif v.dtype == np.object:
                continue  # can't handle arbitrary objects in TF
            elif k == "seq_lens" or k.startswith("state_in_"):
                continue
            shape = (None, ) + v.shape[1:]
            dtype = np.float32 if v.dtype == np.float64 else v.dtype
            placeholder = tf.placeholder(dtype, shape=shape, name=k)
            train_batch[k] = placeholder

        for i, si in enumerate(self._state_in):
            train_batch["state_in_{}".format(i)] = si
        train_batch["seq_lens"] = self._seq_lens

        if log_once("loss_init"):
            logger.debug(
                "Initializing loss function with dummy input:\n\n{}\n".format(
                    summarize(train_batch)))

        self._loss_input_dict = train_batch
        loss = self._do_loss_init(train_batch)
        for k in sorted(train_batch.accessed_keys):
            if k != "seq_lens" and not k.startswith("state_in_"):
                loss_inputs.append((k, train_batch[k]))

        TFPolicy._initialize_loss(self, loss, loss_inputs)
        if self._grad_stats_fn:
            self._stats_fetches.update(
                self._grad_stats_fn(self, train_batch, self._grads))
        self._sess.run(tf.global_variables_initializer())
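The `UsageTrackingDict` above exists so that, after `_do_loss_init` runs, only the batch columns the loss actually read are registered as loss inputs. A minimal sketch of that access-tracking idea (not RLlib's actual implementation):

    # Sketch of the idea behind UsageTrackingDict: remember every key the loss
    # function reads so unused columns can be pruned from the loss inputs.
    class AccessTrackingDict(dict):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.accessed_keys = set()

        def __getitem__(self, key):
            self.accessed_keys.add(key)
            return super().__getitem__(key)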
Example #6
 def optimizer(self):
     if optimizer_fn:
         return optimizer_fn(self, self.config)
     else:
         return TFPolicy.optimizer(self)
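An `optimizer_fn` usually just builds a TF1 optimizer from the policy config; a minimal sketch, assuming the conventional `lr` config key. Hooks like this one and `gradients_fn` above are the user-supplied functions that RLlib's `build_tf_policy` template appears to dispatch to in these generated methods.

    import tensorflow as tf

    # Sketch: construct the training optimizer from the policy config.
    # `lr` is the conventional RLlib learning-rate key.
    def adam_optimizer(policy, config):
        return tf.train.AdamOptimizer(learning_rate=config["lr"])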
Example #7
    def __init__(
        self,
        obs_space,
        action_space,
        config,
        loss_fn,
        stats_fn=None,
        grad_stats_fn=None,
        before_loss_init=None,
        make_model=None,
        action_sampler_fn=None,
        existing_inputs=None,
        existing_model=None,
        get_batch_divisibility_req=None,
        obs_include_prev_action_reward=True,
    ):
        """Initialize a dynamic TF policy.

        Arguments:
            obs_space (gym.Space): Observation space of the policy.
            action_space (gym.Space): Action space of the policy.
            config (dict): Policy-specific configuration data.
            loss_fn (func): function that returns a loss tensor given the
                policy graph and a dict of experience tensor placeholders
            stats_fn (func): optional function that returns a dict of
                TF fetches given the policy and batch input tensors
            grad_stats_fn (func): optional function that returns a dict of
                TF fetches given the policy and loss gradient tensors
            before_loss_init (func): optional function to run prior to loss
                init that takes the same arguments as __init__
            make_model (func): optional function that returns a ModelV2 object
                given (policy, obs_space, action_space, config).
                All policy variables should be created in this function. If not
                specified, a default model will be created.
            action_sampler_fn (func): optional function that returns a
                tuple of action and action logp tensors given
                (policy, model, input_dict, obs_space, action_space, config).
                If not specified, a default action distribution will be used.
            existing_inputs (OrderedDict): when copying a policy, this
                specifies an existing dict of placeholders to use instead of
                defining new ones
            existing_model (ModelV2): when copying a policy, this specifies
                an existing model to clone and share weights with
            get_batch_divisibility_req (func): optional function that returns
                the divisibility requirement for sample batches
            obs_include_prev_action_reward (bool): whether to include the
                previous action and reward in the model input

        Attributes:
            config: config of the policy
            model: model instance, if any
        """
        start = time.time()

        self.config = config
        self._loss_fn = loss_fn
        self._stats_fn = stats_fn
        self._grad_stats_fn = grad_stats_fn
        self._obs_include_prev_action_reward = obs_include_prev_action_reward

        # Setup standard placeholders
        prev_actions = None
        prev_rewards = None
        if existing_inputs is not None:
            obs = existing_inputs[SampleBatch.CUR_OBS]
            if self._obs_include_prev_action_reward:
                prev_actions = existing_inputs[SampleBatch.PREV_ACTIONS]
                prev_rewards = existing_inputs[SampleBatch.PREV_REWARDS]
        else:
            obs = tf.placeholder(tf.float32,
                                 shape=[None] + list(obs_space.shape),
                                 name="observation")
            if self._obs_include_prev_action_reward:
                prev_actions = ModelCatalog.get_action_placeholder(
                    action_space)
                prev_rewards = tf.placeholder(tf.float32, [None],
                                              name="prev_reward")

        self._input_dict = {
            SampleBatch.CUR_OBS: obs,
            SampleBatch.PREV_ACTIONS: prev_actions,
            SampleBatch.PREV_REWARDS: prev_rewards,
            "is_training": self._get_is_training_placeholder(),
        }
        self._seq_lens = tf.placeholder(dtype=tf.int32,
                                        shape=[None],
                                        name="seq_lens")

        # Setup model
        if action_sampler_fn:
            if not make_model:
                raise ValueError(
                    "make_model is required if action_sampler_fn is given")
            self._dist_class = None
        else:
            self._dist_class, logit_dim = ModelCatalog.get_action_dist(
                action_space, self.config["model"])

        if existing_model:
            # DEBUG
            # print("\t\tdynamic_tf_policy.py: Using existing model.")
            self.model = existing_model
        elif make_model:
            # DEBUG
            # print("\t\tdynamic_tf_policy.py: Using `make_model` function.")
            self.model = make_model(self, obs_space, action_space, config)
        else:
            # DEBUG
            # print("\t\tdynamic_tf_policy.py: Grabbing model from catalog.")
            start = time.time()

            self.model = ModelCatalog.get_model_v2(obs_space,
                                                   action_space,
                                                   logit_dim,
                                                   self.config["model"],
                                                   framework="tf")
            # DEBUG
            # print("\t\tdynamic_tf_policy.py: Done grabbing model from catalog: %fs" % (time.time() - start))

        if existing_inputs:
            self._state_in = [
                v for k, v in existing_inputs.items()
                if k.startswith("state_in_")
            ]
            if self._state_in:
                self._seq_lens = existing_inputs["seq_lens"]
        else:
            self._state_in = [
                tf.placeholder(shape=(None, ) + s.shape, dtype=s.dtype)
                for s in self.model.get_initial_state()
            ]

        model_out, self._state_out = self.model(self._input_dict,
                                                self._state_in, self._seq_lens)

        # Setup action sampler
        if action_sampler_fn:
            action_sampler, action_logp = action_sampler_fn(
                self, self.model, self._input_dict, obs_space, action_space,
                config)
        else:
            action_dist = self._dist_class(model_out, self.model)
            action_sampler = action_dist.sample()
            action_logp = action_dist.sampled_action_logp()

        # Phase 1 init
        print("\t\t dynamic_tf_policy.py: default sesh:",
              tf.get_default_session())
        sess = tf.get_default_session() or tf.Session()
        if get_batch_divisibility_req:
            batch_divisibility_req = get_batch_divisibility_req(self)
        else:
            batch_divisibility_req = 1

        # DEBUG
        # print("\t\tdynamic_tf_policy.py: Until TFPolicy call: %fs" % (time.time() - start))
        # print("\t\tdynamic_tf_policy.py: Calling TFPolicy init.")
        TFPolicy.__init__(
            self,
            obs_space,
            action_space,
            sess,
            obs_input=obs,
            action_sampler=action_sampler,
            action_logp=action_logp,
            loss=None,  # dynamically initialized on run
            loss_inputs=[],
            model=self.model,
            state_inputs=self._state_in,
            state_outputs=self._state_out,
            prev_action_input=prev_actions,
            prev_reward_input=prev_rewards,
            seq_lens=self._seq_lens,
            max_seq_len=config["model"]["max_seq_len"],
            batch_divisibility_req=batch_divisibility_req,
        )
        # DEBUG
        # print("\t\tdynamic_tf_policy.py: Done calling TFPolicy init.")
        print("\t\tdynamic_tf_policy.py: Sess passed to TFPolicy init: %s" %
              str(sess))

        # Phase 2 init
        # DEBUG
        # print("\t\tdynamic_tf_policy.py: Starting beforelossinit init.")
        before_loss_init(self, obs_space, action_space, config)
        # print("\t\tdynamic_tf_policy.py: Done beforelossinit init.")
        if not existing_inputs:
            # print("\t\tdynamic_tf_policy.py: Init loss.")
            start = time.time()
            self._initialize_loss()
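Per the docstring, `loss_fn` receives the policy and a dict of experience tensor placeholders and returns a loss tensor; the exact signature differs across RLlib versions, so the sketch below only follows the description given above and is purely illustrative:

    import tensorflow as tf
    from ray.rllib.policy.sample_batch import SampleBatch

    # Illustrative loss_fn sketch: build a scalar loss from the experience
    # placeholders. Real losses would use the model and action-distribution
    # outputs; maximizing mean reward here is only a placeholder objective.
    def dummy_loss(policy, batch_tensors):
        return -tf.reduce_mean(batch_tensors[SampleBatch.REWARDS])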
Example #8
File: ddpg_policy.py  Project: zqxyz73/ray
    def __init__(self, observation_space, action_space, config):
        config = dict(ray.rllib.agents.ddpg.ddpg.DEFAULT_CONFIG, **config)
        if not isinstance(action_space, Box):
            raise UnsupportedSpaceException(
                "Action space {} is not supported for DDPG.".format(
                    action_space))
        if len(action_space.shape) > 1:
            raise UnsupportedSpaceException(
                "Action space has multiple dimensions "
                "{}. ".format(action_space.shape) +
                "Consider reshaping this into a single dimension, "
                "using a Tuple action space, or the multi-agent API.")

        self.config = config

        # Create global step for counting the number of update operations.
        self.global_step = tf.train.get_or_create_global_step()
        # Create sampling timestep placeholder.
        timestep = tf.placeholder(tf.int32, (), name="timestep")

        # use separate optimizers for actor & critic
        self._actor_optimizer = tf.train.AdamOptimizer(
            learning_rate=self.config["actor_lr"])
        self._critic_optimizer = tf.train.AdamOptimizer(
            learning_rate=self.config["critic_lr"])

        # Observation inputs.
        self.cur_observations = tf.placeholder(tf.float32,
                                               shape=(None, ) +
                                               observation_space.shape,
                                               name="cur_obs")

        with tf.variable_scope(POLICY_SCOPE) as scope:
            policy_out, self.policy_model = self._build_policy_network(
                self.cur_observations, observation_space, action_space)
            self.policy_vars = scope_vars(scope.name)

        # Noise vars for P network except for layer normalization vars
        if self.config["parameter_noise"]:
            self._build_parameter_noise([
                var for var in self.policy_vars if "LayerNorm" not in var.name
            ])

        # Create exploration component.
        self.exploration = self._create_exploration(action_space, config)
        explore = tf.placeholder_with_default(True, (), name="is_exploring")
        # Action outputs
        with tf.variable_scope(ACTION_SCOPE):
            self.output_actions, _ = self.exploration.get_exploration_action(
                policy_out, Deterministic, self.policy_model, timestep,
                explore)

        # Replay inputs
        self.obs_t = tf.placeholder(tf.float32,
                                    shape=(None, ) + observation_space.shape,
                                    name="observation")
        self.act_t = tf.placeholder(tf.float32,
                                    shape=(None, ) + action_space.shape,
                                    name="action")
        self.rew_t = tf.placeholder(tf.float32, [None], name="reward")
        self.obs_tp1 = tf.placeholder(tf.float32,
                                      shape=(None, ) + observation_space.shape)
        self.done_mask = tf.placeholder(tf.float32, [None], name="done")
        self.importance_weights = tf.placeholder(tf.float32, [None],
                                                 name="weight")

        # policy network evaluation
        with tf.variable_scope(POLICY_SCOPE, reuse=True) as scope:
            prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
            self.policy_t, _ = self._build_policy_network(
                self.obs_t, observation_space, action_space)
            policy_batchnorm_update_ops = list(
                set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) -
                prev_update_ops)

        # target policy network evaluation
        with tf.variable_scope(POLICY_TARGET_SCOPE) as scope:
            policy_tp1, _ = self._build_policy_network(self.obs_tp1,
                                                       observation_space,
                                                       action_space)
            target_policy_vars = scope_vars(scope.name)

        # Action outputs
        with tf.variable_scope(ACTION_SCOPE, reuse=True):
            if config["smooth_target_policy"]:
                target_noise_clip = self.config["target_noise_clip"]
                clipped_normal_sample = tf.clip_by_value(
                    tf.random_normal(tf.shape(policy_tp1),
                                     stddev=self.config["target_noise"]),
                    -target_noise_clip, target_noise_clip)
                policy_tp1_smoothed = tf.clip_by_value(
                    policy_tp1 + clipped_normal_sample,
                    action_space.low * tf.ones_like(policy_tp1),
                    action_space.high * tf.ones_like(policy_tp1))
            else:
                # no smoothing, just use deterministic actions
                policy_tp1_smoothed = policy_tp1

        # q network evaluation
        prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
        with tf.variable_scope(Q_SCOPE) as scope:
            # Q-values for given actions & observations in the current state
            q_t, self.q_model = self._build_q_network(self.obs_t,
                                                      observation_space,
                                                      action_space, self.act_t)
            self.q_func_vars = scope_vars(scope.name)
        self.stats = {
            "mean_q": tf.reduce_mean(q_t),
            "max_q": tf.reduce_max(q_t),
            "min_q": tf.reduce_min(q_t),
        }
        with tf.variable_scope(Q_SCOPE, reuse=True):
            # Q-values for current policy (no noise) in given current state
            q_t_det_policy, _ = self._build_q_network(self.obs_t,
                                                      observation_space,
                                                      action_space,
                                                      self.policy_t)
        if self.config["twin_q"]:
            with tf.variable_scope(TWIN_Q_SCOPE) as scope:
                twin_q_t, self.twin_q_model = self._build_q_network(
                    self.obs_t, observation_space, action_space, self.act_t)
                self.twin_q_func_vars = scope_vars(scope.name)
        q_batchnorm_update_ops = list(
            set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) - prev_update_ops)

        # target q network evaluation
        with tf.variable_scope(Q_TARGET_SCOPE) as scope:
            q_tp1, _ = self._build_q_network(self.obs_tp1, observation_space,
                                             action_space, policy_tp1_smoothed)
            target_q_func_vars = scope_vars(scope.name)
        if self.config["twin_q"]:
            with tf.variable_scope(TWIN_Q_TARGET_SCOPE) as scope:
                twin_q_tp1, _ = self._build_q_network(self.obs_tp1,
                                                      observation_space,
                                                      action_space,
                                                      policy_tp1_smoothed)
                twin_target_q_func_vars = scope_vars(scope.name)

        if self.config["twin_q"]:
            self.critic_loss, self.actor_loss, self.td_error \
                = self._build_actor_critic_loss(
                    q_t, q_tp1, q_t_det_policy, twin_q_t=twin_q_t,
                    twin_q_tp1=twin_q_tp1)
        else:
            self.critic_loss, self.actor_loss, self.td_error \
                = self._build_actor_critic_loss(
                    q_t, q_tp1, q_t_det_policy)

        if config["l2_reg"] is not None:
            for var in self.policy_vars:
                if "bias" not in var.name:
                    self.actor_loss += (config["l2_reg"] * tf.nn.l2_loss(var))
            for var in self.q_func_vars:
                if "bias" not in var.name:
                    self.critic_loss += (config["l2_reg"] * tf.nn.l2_loss(var))
            if self.config["twin_q"]:
                for var in self.twin_q_func_vars:
                    if "bias" not in var.name:
                        self.critic_loss += (config["l2_reg"] *
                                             tf.nn.l2_loss(var))

        # update_target_fn will be called periodically to copy Q network to
        # target Q network
        self.tau_value = config.get("tau")
        self.tau = tf.placeholder(tf.float32, (), name="tau")
        update_target_expr = []
        for var, var_target in zip(
                sorted(self.q_func_vars, key=lambda v: v.name),
                sorted(target_q_func_vars, key=lambda v: v.name)):
            update_target_expr.append(
                var_target.assign(self.tau * var +
                                  (1.0 - self.tau) * var_target))
        if self.config["twin_q"]:
            for var, var_target in zip(
                    sorted(self.twin_q_func_vars, key=lambda v: v.name),
                    sorted(twin_target_q_func_vars, key=lambda v: v.name)):
                update_target_expr.append(
                    var_target.assign(self.tau * var +
                                      (1.0 - self.tau) * var_target))
        for var, var_target in zip(
                sorted(self.policy_vars, key=lambda v: v.name),
                sorted(target_policy_vars, key=lambda v: v.name)):
            update_target_expr.append(
                var_target.assign(self.tau * var +
                                  (1.0 - self.tau) * var_target))
        self.update_target_expr = tf.group(*update_target_expr)

        self.sess = tf.get_default_session()
        self.loss_inputs = [
            (SampleBatch.CUR_OBS, self.obs_t),
            (SampleBatch.ACTIONS, self.act_t),
            (SampleBatch.REWARDS, self.rew_t),
            (SampleBatch.NEXT_OBS, self.obs_tp1),
            (SampleBatch.DONES, self.done_mask),
            (PRIO_WEIGHTS, self.importance_weights),
        ]
        input_dict = dict(self.loss_inputs)

        if self.config["use_state_preprocessor"]:
            # Model self-supervised losses
            self.actor_loss = self.policy_model.custom_loss(
                self.actor_loss, input_dict)
            self.critic_loss = self.q_model.custom_loss(
                self.critic_loss, input_dict)
            if self.config["twin_q"]:
                self.critic_loss = self.twin_q_model.custom_loss(
                    self.critic_loss, input_dict)

        TFPolicy.__init__(self,
                          observation_space,
                          action_space,
                          self.config,
                          self.sess,
                          obs_input=self.cur_observations,
                          sampled_action=self.output_actions,
                          loss=self.actor_loss + self.critic_loss,
                          loss_inputs=self.loss_inputs,
                          update_ops=q_batchnorm_update_ops +
                          policy_batchnorm_update_ops,
                          explore=explore,
                          timestep=timestep)
        self.sess.run(tf.global_variables_initializer())

        # Note that this encompasses both the policy and Q-value networks and
        # their corresponding target networks
        self.variables = ray.experimental.tf_utils.TensorFlowVariables(
            tf.group(q_t_det_policy, q_tp1, self._actor_optimizer.variables(),
                     self._critic_optimizer.variables()), self.sess)

        # Hard initial update
        self.update_target(tau=1.0)
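The target-network update built above is a Polyak (soft) update: each target variable is moved toward its online counterpart by a factor `tau`, so the final `self.update_target(tau=1.0)` amounts to a hard copy. A standalone sketch of the same pattern:

    import tensorflow as tf

    # Sketch of the soft target update used above:
    #   target <- tau * online + (1 - tau) * target   (tau=1.0 => hard copy)
    def build_soft_update_op(online_vars, target_vars, tau):
        ops = [
            target.assign(tau * online + (1.0 - tau) * target)
            for online, target in zip(
                sorted(online_vars, key=lambda v: v.name),
                sorted(target_vars, key=lambda v: v.name))
        ]
        return tf.group(*ops)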
Example #9
    def _initialize_loss_from_dummy_batch(
            self,
            auto_remove_unneeded_view_reqs: bool = True,
            stats_fn=None) -> None:

        # Create the optimizer/exploration optimizer here. Some initialization
        # steps (e.g. exploration postprocessing) may need this.
        self._optimizer = self.optimizer()

        # Test calls depend on variable init, so initialize model first.
        self._sess.run(tf1.global_variables_initializer())

        logger.info("Testing `compute_actions` w/ dummy batch.")
        actions, state_outs, extra_fetches = \
            self.compute_actions_from_input_dict(
                self._dummy_batch, explore=False, timestep=0)
        for key, value in extra_fetches.items():
            self._dummy_batch[key] = value
            self._input_dict[key] = get_placeholder(value=value, name=key)
            if key not in self.view_requirements:
                logger.info("Adding extra-action-fetch `{}` to "
                            "view-reqs.".format(key))
                self.view_requirements[key] = \
                    ViewRequirement(space=gym.spaces.Box(
                        -1.0, 1.0, shape=value.shape[1:],
                        dtype=value.dtype))
        dummy_batch = self._dummy_batch

        logger.info("Testing `postprocess_trajectory` w/ dummy batch.")
        self.exploration.postprocess_trajectory(self, dummy_batch, self._sess)
        _ = self.postprocess_trajectory(dummy_batch)
        # Add new columns automatically to (loss) input_dict.
        for key in dummy_batch.added_keys:
            if key not in self._input_dict:
                self._input_dict[key] = get_placeholder(value=dummy_batch[key],
                                                        name=key)
            if key not in self.view_requirements:
                self.view_requirements[key] = \
                    ViewRequirement(space=gym.spaces.Box(
                        -1.0, 1.0, shape=dummy_batch[key].shape[1:],
                        dtype=dummy_batch[key].dtype))

        train_batch = SampleBatch(
            dict(self._input_dict, **self._loss_input_dict))

        if self._state_inputs:
            train_batch["seq_lens"] = self._seq_lens
            self._loss_input_dict.update({"seq_lens": train_batch["seq_lens"]})

        self._loss_input_dict.update({k: v for k, v in train_batch.items()})

        if log_once("loss_init"):
            logger.debug(
                "Initializing loss function with dummy input:\n\n{}\n".format(
                    summarize(train_batch)))

        loss = self._do_loss_init(train_batch)

        all_accessed_keys = \
            train_batch.accessed_keys | dummy_batch.accessed_keys | \
            dummy_batch.added_keys | set(
                self.model.view_requirements.keys())

        TFPolicy._initialize_loss(
            self, loss,
            [(k, v)
             for k, v in train_batch.items() if k in all_accessed_keys] +
            ([("seq_lens",
               train_batch["seq_lens"])] if "seq_lens" in train_batch else []))

        if "is_training" in self._loss_input_dict:
            del self._loss_input_dict["is_training"]

        # Call the grads stats fn.
        # TODO: (sven) rename to simply stats_fn to match eager and torch.
        if self._grad_stats_fn:
            self._stats_fetches.update(
                self._grad_stats_fn(self, train_batch, self._grads))

        # Add new columns automatically to view-reqs.
        if auto_remove_unneeded_view_reqs:
            # Add those needed for postprocessing and training.
            all_accessed_keys = train_batch.accessed_keys | \
                                dummy_batch.accessed_keys
            # Tag those only needed for post-processing (with some exceptions).
            for key in dummy_batch.accessed_keys:
                if key not in train_batch.accessed_keys and \
                        key not in self.model.view_requirements and \
                        key not in [
                            SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
                            SampleBatch.UNROLL_ID, SampleBatch.DONES,
                            SampleBatch.REWARDS, SampleBatch.INFOS]:
                    if key in self.view_requirements:
                        self.view_requirements[key].used_for_training = False
                    if key in self._loss_input_dict:
                        del self._loss_input_dict[key]
            # Remove those not needed at all (leave those that are needed
            # by Sampler to properly execute sample collection).
            # Also always leave DONES, REWARDS, and INFOS, no matter what.
            for key in list(self.view_requirements.keys()):
                if key not in all_accessed_keys and key not in [
                    SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
                    SampleBatch.UNROLL_ID, SampleBatch.DONES,
                    SampleBatch.REWARDS, SampleBatch.INFOS] and \
                        key not in self.model.view_requirements:
                    # If user deleted this key manually in postprocessing
                    # fn, warn about it and do not remove from
                    # view-requirements.
                    if key in dummy_batch.deleted_keys:
                        logger.warning(
                            "SampleBatch key '{}' was deleted manually in "
                            "postprocessing function! RLlib will "
                            "automatically remove non-used items from the "
                            "data stream. Remove the `del` from your "
                            "postprocessing function.".format(key))
                    else:
                        del self.view_requirements[key]
                    if key in self._loss_input_dict:
                        del self._loss_input_dict[key]
            # Add those data_cols (again) that are missing and have
            # dependencies by view_cols.
            for key in list(self.view_requirements.keys()):
                vr = self.view_requirements[key]
                if (vr.data_col is not None
                        and vr.data_col not in self.view_requirements):
                    used_for_training = \
                        vr.data_col in train_batch.accessed_keys
                    self.view_requirements[vr.data_col] = ViewRequirement(
                        space=vr.space, used_for_training=used_for_training)

        self._loss_input_dict_no_rnn = {
            k: v
            for k, v in self._loss_input_dict.items()
            if (v not in self._state_inputs and v != self._seq_lens)
        }

        # Initialize again after loss init.
        self._sess.run(tf1.global_variables_initializer())
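Each extra action fetch or postprocessing column above gets a `ViewRequirement` whose space is inferred from the dummy value. A minimal sketch of that inference step, mirroring the `gym.spaces.Box` construction used in the code (the nominal -1/1 bounds match what the code assumes):

    import gym
    import numpy as np

    # Sketch: derive a Box space for a new batch column from a sample value,
    # dropping the leading batch dimension, as done for view requirements above.
    def space_from_batch_value(value: np.ndarray) -> gym.spaces.Box:
        return gym.spaces.Box(-1.0, 1.0, shape=value.shape[1:],
                              dtype=value.dtype)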
Example #10
 def set_state(self, state):
     TFPolicy.set_state(self, state[0])
     self.set_epsilon(state[1])
Example #11
 def get_state(self):
     return [TFPolicy.get_state(self), self.cur_epsilon]
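Taken together, Examples #10 and #11 round-trip the extra exploration state: `get_state` appends the current epsilon to the base `TFPolicy` state, and `set_state` restores both parts in order. A short usage sketch (`policy` is any policy object implementing this pair):

    # Sketch: save and later restore a policy's weights plus its epsilon value
    # via the get_state/set_state pair shown above.
    def roundtrip_policy_state(policy):
        saved = policy.get_state()   # [TFPolicy state, cur_epsilon]
        # ... e.g. after rebuilding the policy graph elsewhere ...
        policy.set_state(saved)      # restores weights and epsilon together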