Example #1
    def compute_gradients(self, samples):
        if log_once("compute_gradients"):
            logger.info("Compute gradients on:\n\n{}\n".format(
                summarize(samples)))
        if isinstance(samples, MultiAgentBatch):
            grad_out, info_out = {}, {}
            if self.tf_sess is not None:
                builder = TFRunBuilder(self.tf_sess, "compute_gradients")
                for pid, batch in samples.policy_batches.items():
                    if pid not in self.policies_to_train:
                        continue
                    grad_out[pid], info_out[pid] = (
                        self.policy_map[pid]._build_compute_gradients(
                            builder, batch))
                grad_out = {k: builder.get(v) for k, v in grad_out.items()}
                info_out = {k: builder.get(v) for k, v in info_out.items()}
            else:
                for pid, batch in samples.policy_batches.items():
                    if pid not in self.policies_to_train:
                        continue
                    grad_out[pid], info_out[pid] = (
                        self.policy_map[pid].compute_gradients(batch))
        else:
            grad_out, info_out = (
                self.policy_map[DEFAULT_POLICY_ID].compute_gradients(samples))
        info_out["batch_count"] = samples.count
        if log_once("grad_out"):
            logger.info("Compute grad info:\n\n{}\n".format(
                summarize(info_out)))
        return grad_out, info_out
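Every example on this page gates its verbose logging through log_once. Below is a minimal sketch of such a one-shot helper, written as an assumption about its behavior for illustration, not the exact ray.rllib.utils.debug implementation.

import logging

logger = logging.getLogger(__name__)

_logged_keys = set()


def log_once(key):
    """Return True only the first time `key` is seen in this process."""
    if key in _logged_keys:
        return False
    _logged_keys.add(key)
    return True


# Used exactly like in the snippets on this page: the verbose dump fires once.
if log_once("compute_gradients"):
    logger.info("Compute gradients on: ...")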
Example #2
    def learn_on_batch(self, samples):
        if log_once("learn_on_batch"):
            logger.info(
                "Training on concatenated sample batches:\n\n{}\n".format(
                    summarize(samples)))
        if isinstance(samples, MultiAgentBatch):
            info_out = {}
            to_fetch = {}
            if self.tf_sess is not None:
                builder = TFRunBuilder(self.tf_sess, "learn_on_batch")
            else:
                builder = None
            for pid, batch in samples.policy_batches.items():
                if pid not in self.policies_to_train:
                    continue
                policy = self.policy_map[pid]
                if builder and hasattr(policy, "_build_learn_on_batch"):
                    to_fetch[pid] = policy._build_learn_on_batch(
                        builder, batch)
                else:
                    info_out[pid] = policy.learn_on_batch(batch)
            info_out.update({k: builder.get(v) for k, v in to_fetch.items()})
        else:
            info_out = self.policy_map[DEFAULT_POLICY_ID].learn_on_batch(
                samples)
        if log_once("learn_out"):
            logger.info("Training output:\n\n{}\n".format(summarize(info_out)))
        return info_out
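For context, a hedged sketch of how a caller might drive learn_on_batch() together with sample() (shown in Example #4 below); the worker object and the iteration count are assumptions, not part of the snippets above.

# Minimal local training loop sketch; `worker` is assumed to expose the
# sample() and learn_on_batch() methods shown in these examples.
num_iterations = 10  # illustrative value
for _ in range(num_iterations):
    batch = worker.sample()              # SampleBatch or MultiAgentBatch
    info = worker.learn_on_batch(batch)  # per-policy dict in the multi-agent case
    print("training info:", info)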
Example #3
    def learn_on_batch(self, samples):
        if log_once("learn_on_batch"):
            logger.info(
                "Training on concatenated sample batches:\n\n{}\n".format(
                    summarize(samples)))
        if isinstance(samples, MultiAgentBatch):
            info_out = {}
            if self.tf_sess is not None:
                builder = TFRunBuilder(self.tf_sess, "learn_on_batch")
                for pid, batch in samples.policy_batches.items():
                    if pid not in self.policies_to_train:
                        continue
                    info_out[pid], _ = (
                        self.policy_map[pid]._build_learn_on_batch(
                            builder, batch))
                info_out = {k: builder.get(v) for k, v in info_out.items()}
            else:
                for pid, batch in samples.policy_batches.items():
                    if pid not in self.policies_to_train:
                        continue
                    info_out[pid], _ = (
                        self.policy_map[pid].learn_on_batch(batch))
        else:
            info_out, _ = (
                self.policy_map[DEFAULT_POLICY_ID].learn_on_batch(samples))
        if log_once("learn_out"):
            logger.info("Training output:\n\n{}\n".format(summarize(info_out)))
        return info_out
Example #4
    def sample(self):
        """Evaluate the current policies and return a batch of experiences.

        Returns:
            SampleBatch|MultiAgentBatch from evaluating the current policies.
        """

        if self._fake_sampler and self.last_batch is not None:
            return self.last_batch

        if log_once("sample_start"):
            logger.info("Generating sample batch of size {}".format(
                self.sample_batch_size))

        batches = [self.input_reader.next()]
        steps_so_far = batches[0].count

        # In truncate_episodes mode, never pull more than 1 batch per env.
        # This avoids over-running the target batch size.
        if self.batch_mode == "truncate_episodes":
            max_batches = self.num_envs
        else:
            max_batches = float("inf")

        while steps_so_far < self.sample_batch_size and len(
                batches) < max_batches:
            batch = self.input_reader.next()
            steps_so_far += batch.count
            batches.append(batch)
        batch = batches[0].concat_samples(batches)

        if self.callbacks.get("on_sample_end"):
            self.callbacks["on_sample_end"]({
                "evaluator": self,
                "samples": batch
            })

        # Always do writes prior to compression for consistency and to allow
        # for better compression inside the writer.
        self.output_writer.write(batch)

        # Do off-policy estimation if needed
        if self.reward_estimators:
            for sub_batch in batch.split_by_episode():
                for estimator in self.reward_estimators:
                    estimator.process(sub_batch)

        if log_once("sample_end"):
            logger.info("Completed sample batch:\n\n{}\n".format(
                summarize(batch)))

        if self.compress_observations == "bulk":
            batch.compress(bulk=True)
        elif self.compress_observations:
            batch.compress()

        if self._fake_sampler:
            self.last_batch = batch
        return batch
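The accumulation loop above can be illustrated in isolation. The toy below stands in for real SampleBatch objects with plain step counts; all names are illustrative.

def accumulate(next_count, target_size, max_batches=float("inf")):
    """Toy version of the loop in sample(): keep pulling batches until the
    target step count is reached or the batch cap is hit.

    `next_count` is a callable returning the step count of the next batch.
    """
    counts = [next_count()]
    steps_so_far = counts[0]
    while steps_so_far < target_size and len(counts) < max_batches:
        count = next_count()
        steps_so_far += count
        counts.append(count)
    return counts


# With 3 envs in truncate_episodes mode, at most 3 batches are pulled:
print(accumulate(lambda: 50, target_size=200, max_batches=3))  # [50, 50, 50]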
Example #5
def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
    """Call compute actions on observation batches to get next actions.

    Returns:
        eval_results: dict of policy to compute_action() outputs.
    """

    eval_results = {}

    if tf_sess:
        builder = TFRunBuilder(tf_sess, "policy_eval")
        pending_fetches = {}
    else:
        builder = None

    if log_once("compute_actions_input"):
        logger.info("Inputs to compute_actions():\n\n{}\n".format(
            summarize(to_eval)))

    for policy_id, eval_data in to_eval.items():
        rnn_in = [t.rnn_state for t in eval_data]
        policy = _get_or_raise(policies, policy_id)
        if builder and (policy.compute_actions.__code__ is
                        TFPolicy.compute_actions.__code__):
            rnn_in_cols = _to_column_format(rnn_in)
            # TODO(ekl): how can we make info batch available to TF code?
            # TODO(sven): Return a dict from _build_compute_actions;
            # otherwise it becomes increasingly unclear what is where in
            # the return tuple.
            pending_fetches[policy_id] = policy._build_compute_actions(
                builder,
                obs_batch=[t.obs for t in eval_data],
                state_batches=rnn_in_cols,
                prev_action_batch=[t.prev_action for t in eval_data],
                prev_reward_batch=[t.prev_reward for t in eval_data],
                timestep=policy.global_timestep)
        else:
            # TODO(sven): Does this work for LSTM torch?
            rnn_in_cols = [
                np.stack([row[i] for row in rnn_in])
                for i in range(len(rnn_in[0]))
            ]
            eval_results[policy_id] = policy.compute_actions(
                [t.obs for t in eval_data],
                state_batches=rnn_in_cols,
                prev_action_batch=[t.prev_action for t in eval_data],
                prev_reward_batch=[t.prev_reward for t in eval_data],
                info_batch=[t.info for t in eval_data],
                episodes=[active_episodes[t.env_id] for t in eval_data],
                timestep=policy.global_timestep)
    if builder:
        for k, v in pending_fetches.items():
            eval_results[k] = builder.get(v)

    if log_once("compute_actions_result"):
        logger.info("Outputs of compute_actions():\n\n{}\n".format(
            summarize(eval_results)))

    return eval_results
Example #6
def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
    """Call compute actions on observation batches to get next actions.

    Returns:
        eval_results: dict of policy to compute_action() outputs.
    """

    eval_results = {}

    if tf_sess:
        builder = TFRunBuilder(tf_sess, "policy_eval")
        pending_fetches = {}
    else:
        builder = None

    if log_once("compute_actions_input"):
        logger.info("Inputs to compute_actions():\n\n{}\n".format(
            summarize(to_eval)))

    for policy_id, eval_data in to_eval.items():
        rnn_in_cols = _to_column_format([t.rnn_state for t in eval_data])
        policy = _get_or_raise(policies, policy_id)
        if builder and (policy.compute_actions.__code__ is
                        TFPolicy.compute_actions.__code__):
            # TODO(ekl): how can we make info batch available to TF code?
            # print('First: ' + str(eval_data))
            pending_fetches[policy_id] = policy._build_compute_actions(
                builder, [t.obs for t in eval_data],
                [t.neighbor_obs for t in eval_data],
                rnn_in_cols,
                prev_action_batch=[t.prev_action for t in eval_data],
                prev_reward_batch=[t.prev_reward for t in eval_data])
        else:
            # print('Second: ' + str(eval_data))
            eval_results[policy_id] = policy.compute_actions(
                [t.obs for t in eval_data],
                [t.neighbor_obs for t in eval_data],
                rnn_in_cols,
                prev_action_batch=[t.prev_action for t in eval_data],
                prev_reward_batch=[t.prev_reward for t in eval_data],
                info_batch=[t.info for t in eval_data],
                episodes=[active_episodes[t.env_id] for t in eval_data])
    if builder:
        for k, v in pending_fetches.items():
            eval_results[k] = builder.get(v)

    if log_once("compute_actions_result"):
        logger.info("Outputs of compute_actions():\n\n{}\n".format(
            summarize(eval_results)))

    return eval_results
Example #7
        def value_function(self):
            assert self.cur_instance, "must call forward first"

            with tf.variable_scope(self.variable_scope):
                with tf.variable_scope("value_function", reuse=tf.AUTO_REUSE):
                    # Simple case: sharing the feature layer
                    if self.model_config["vf_share_layers"]:
                        return tf.reshape(
                            linear(self.cur_instance.last_layer, 1,
                                   "value_function", normc_initializer(1.0)),
                            [-1])

                    # Create a new separate model with no RNN state, etc.
                    branch_model_config = self.model_config.copy()
                    branch_model_config["free_log_std"] = False
                    if branch_model_config["use_lstm"]:
                        branch_model_config["use_lstm"] = False
                        if log_once("vf_warn"):
                            logger.warning(
                                "It is not recommended to use a LSTM model "
                                "with vf_share_layers=False (consider setting "
                                "it to True). If you want to not share "
                                "layers, you can implement a custom LSTM "
                                "model that overrides the value_function() "
                                "method.")
                    branch_instance = self.legacy_model_cls(
                        self.cur_instance.input_dict,
                        self.obs_space,
                        self.action_space,
                        1,
                        branch_model_config,
                        state_in=None,
                        seq_lens=None)
                    return tf.reshape(branch_instance.outputs, [-1])
Example #8
def run_timeline(sess, ops, debug_name, feed_dict={}, timeline_dir=None):
    if timeline_dir:
        from tensorflow.python.client import timeline

        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        run_metadata = tf.RunMetadata()
        start = time.time()
        fetches = sess.run(ops,
                           options=run_options,
                           run_metadata=run_metadata,
                           feed_dict=feed_dict)
        trace = timeline.Timeline(step_stats=run_metadata.step_stats)
        global _count
        outf = os.path.join(
            timeline_dir,
            "timeline-{}-{}-{}.json".format(debug_name, os.getpid(),
                                            _count % 10))
        _count += 1
        with open(outf, "w") as trace_file:
            trace_file.write(trace.generate_chrome_trace_format())
        logger.info("Wrote tf timeline ({} s) to {}".format(
            time.time() - start, os.path.abspath(outf)))
    else:
        if log_once("tf_timeline"):
            logger.info(
                "Executing TF run without tracing. To dump TF timeline traces "
                "to disk, set the TF_TIMELINE_DIR environment variable.")
        fetches = sess.run(ops, feed_dict=feed_dict)
    return fetches
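A hedged usage sketch for run_timeline(): the session, op, and feed values are placeholders; only the TF_TIMELINE_DIR variable comes from the log message above.

import os

# Assumed to exist: a tf.Session `sess` and an op `train_op` to fetch.
timeline_dir = os.environ.get("TF_TIMELINE_DIR")  # None disables tracing
fetches = run_timeline(
    sess,
    train_op,
    "train_step",
    feed_dict={},                # fill in placeholder values as needed
    timeline_dir=timeline_dir)   # writes timeline-*.json files when set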
Example #9
    def learn_on_batch(self, samples):
        if log_once("learn_on_batch"):
            logger.info(
                "Training on concatenated sample batches:\n\n{}\n".format(
                    summarize(samples)))
        if isinstance(samples, MultiAgentBatch):
            info_out = {}
            to_fetch = {}
            if self.tf_sess is not None:
                builder = TFRunBuilder(self.tf_sess, "learn_on_batch")
            else:
                builder = None
            for pid, batch in samples.policy_batches.items():
                if pid not in self.policies_to_train:
                    continue
                # NOTE: build a dict of observations for the neighboring
                # agents (the current agent is included as well).
                neighbor_batch_dic = {}
                neighbor_pid_list = [
                    pid_ for pid_ in
                    self.traffic_light_node_dict[self.inter_num_2_id(
                        int(pid.split('_')[1]))]['adjacency_row']
                    if pid_ is not None
                ]
                for neighbor_pid in neighbor_pid_list:
                    neighbor_batch_dic['policy_{}'.format(
                        neighbor_pid)] = samples.policy_batches[
                            'policy_{}'.format(neighbor_pid)]
                # neighbor_batch_dic[pid] = samples.policy_batches[pid]
                # ------------------------------------------------------------------

                policy = self.policy_map[pid]
                if builder and hasattr(policy, "_build_learn_on_batch"):
                    to_fetch[pid] = policy._build_learn_on_batch(
                        builder, batch, neighbor_batch_dic)
                else:
                    info_out[pid] = policy.learn_on_batch(
                        batch, neighbor_batch_dic)
            info_out.update({k: builder.get(v) for k, v in to_fetch.items()})
        else:
            info_out = self.policy_map[DEFAULT_POLICY_ID].learn_on_batch(
                samples)
        if log_once("learn_out"):
            logger.debug("Training out:\n\n{}\n".format(summarize(info_out)))
        return info_out
Example #10
    def postprocess_batch_so_far(self, episode):
        """Apply policy postprocessors to any unprocessed rows.

        This pushes the postprocessed per-agent batches onto the per-policy
        builders, clearing per-agent state.

        Arguments:
            episode: current MultiAgentEpisode object or None
        """

        # Materialize the batches so far
        pre_batches = {}
        for agent_id, builder in self.agent_builders.items():
            pre_batches[agent_id] = (
                self.policy_map[self.agent_to_policy[agent_id]],
                builder.build_and_reset())

        # Apply postprocessor
        post_batches = {}
        if self.clip_rewards:
            for _, (_, pre_batch) in pre_batches.items():
                pre_batch["rewards"] = np.sign(pre_batch["rewards"])
        for agent_id, (_, pre_batch) in pre_batches.items():
            other_batches = pre_batches.copy()
            del other_batches[agent_id]
            policy = self.policy_map[self.agent_to_policy[agent_id]]
            if any(pre_batch["dones"][:-1]) or len(set(
                    pre_batch["eps_id"])) > 1:
                raise ValueError(
                    "Batches sent to postprocessing must only contain steps "
                    "from a single trajectory.", pre_batch)
            post_batches[agent_id] = policy.postprocess_trajectory(
                pre_batch, other_batches, episode)

        if log_once("after_post"):
            logger.info(
                "Trajectory fragment after postprocess_trajectory():\n\n{}\n".
                format(summarize(post_batches)))

        # Append into policy batches and reset
        for agent_id, post_batch in sorted(post_batches.items()):
            self.policy_builders[self.agent_to_policy[agent_id]].add_batch(
                post_batch)
            if self.postp_callback:
                self.postp_callback({
                    "episode": episode,
                    "agent_id": agent_id,
                    "pre_batch": pre_batches[agent_id],
                    "post_batch": post_batch,
                    "all_pre_batches": pre_batches,
                })

        self.agent_builders.clear()
        self.agent_to_policy.clear()
Example #11
    def learn_on_batch(self, samples):
        if log_once("learn_on_batch"):
            logger.info(
                "Training on concatenated sample batches:\n\n{}\n".format(
                    summarize(samples)))
        if isinstance(samples, MultiAgentBatch):
            info_out = {}
            to_fetch = {}
            if self.tf_sess is not None:
                builder = TFRunBuilder(self.tf_sess, "learn_on_batch")
            else:
                builder = None
            for pid, batch in samples.policy_batches.items():
                if pid not in self.policies_to_train:
                    continue
                policy = self.policy_map[pid]
                if builder and hasattr(policy, "_build_learn_on_batch"):
                    to_fetch[pid] = policy._build_learn_on_batch(
                        builder, batch)
                else:
                    info_out[pid] = policy.learn_on_batch(batch)
            info_out.update({k: builder.get(v) for k, v in to_fetch.items()})
        else:
            learn_on_batch_outputs = self.policy_map[
                DEFAULT_POLICY_ID].learn_on_batch(samples)

            if isinstance(learn_on_batch_outputs, tuple):
                info_out, beholder_arrays = learn_on_batch_outputs
            else:
                info_out, beholder_arrays = learn_on_batch_outputs, {}

            if self.policy_config["evaluation_config"]["beholder"]:
                with self.tf_sess.graph.as_default():
                    if self._beholder is None:
                        self._beholder = Beholder(self.io_context.log_dir)
                    self._beholder.update(self.tf_sess,
                                          arrays=beholder_arrays or None)

        if log_once("learn_out"):
            logger.info("Training output:\n\n{}\n".format(summarize(info_out)))
        return info_out
Example #12
    def _get_loss_inputs_dict(self, batch):
        feed_dict = {}
        if self._batch_divisibility_req > 1:
            meets_divisibility_reqs = (
                len(batch[SampleBatch.CUR_OBS]) % self._batch_divisibility_req
                == 0
                and max(batch[SampleBatch.AGENT_INDEX]) == 0)  # not multiagent
        else:
            meets_divisibility_reqs = True

        # Simple case: not RNN nor do we need to pad
        if not self._state_inputs and meets_divisibility_reqs:
            for k, ph in self._loss_inputs:
                feed_dict[ph] = batch[k]
            return feed_dict

        if self._state_inputs:
            max_seq_len = self._max_seq_len
            dynamic_max = True
        else:
            max_seq_len = self._batch_divisibility_req
            dynamic_max = False

        # RNN or multi-agent case
        feature_keys = [k for k, v in self._loss_inputs]
        state_keys = [
            "state_in_{}".format(i) for i in range(len(self._state_inputs))
        ]
        feature_sequences, initial_states, seq_lens = chop_into_sequences(
            batch[SampleBatch.EPS_ID],
            batch[SampleBatch.UNROLL_ID],
            batch[SampleBatch.AGENT_INDEX], [batch[k] for k in feature_keys],
            [batch[k] for k in state_keys],
            max_seq_len,
            dynamic_max=dynamic_max)
        for k, v in zip(feature_keys, feature_sequences):
            feed_dict[self._loss_input_dict[k]] = v
        for k, v in zip(state_keys, initial_states):
            feed_dict[self._loss_input_dict[k]] = v
        feed_dict[self._seq_lens] = seq_lens

        if log_once("rnn_feed_dict"):
            logger.info("Padded input for RNN:\n\n{}\n".format(
                summarize({
                    "features": feature_sequences,
                    "initial_states": initial_states,
                    "seq_lens": seq_lens,
                    "max_seq_len": max_seq_len,
                })))
        return feed_dict
Example #13
    def _initialize_loss(self, loss, loss_inputs):
        self._loss_inputs = loss_inputs
        self._loss_input_dict = dict(self._loss_inputs)
        for i, ph in enumerate(self._state_inputs):
            self._loss_input_dict["state_in_{}".format(i)] = ph

        if self.model:
            self._loss = self.model.custom_loss(loss, self._loss_input_dict)
            self._stats_fetches.update({
                "model":
                self.model.metrics() if isinstance(self.model, ModelV2) else
                self.model.custom_stats()
            })
        else:
            self._loss = loss

        self._optimizer = self.optimizer()
        self._grads_and_vars = [
            (g, v) for (g, v) in self.gradients(self._optimizer, self._loss)
            if g is not None
        ]
        self._grads = [g for (g, v) in self._grads_and_vars]

        # TODO(sven/ekl): Deprecate support for v1 models.
        if hasattr(self, "model") and isinstance(self.model, ModelV2):
            self._variables = ray.experimental.tf_utils.TensorFlowVariables(
                [], self._sess, self.variables())
        else:
            self._variables = ray.experimental.tf_utils.TensorFlowVariables(
                self._loss, self._sess)

        # gather update ops for any batch norm layers
        if not self._update_ops:
            self._update_ops = tf.get_collection(
                tf.GraphKeys.UPDATE_OPS, scope=tf.get_variable_scope().name)
        if self._update_ops:
            logger.info("Update ops to run on apply gradient: {}".format(
                self._update_ops))
        with tf.control_dependencies(self._update_ops):
            self._apply_op = self.build_apply_op(self._optimizer,
                                                 self._grads_and_vars)

        if log_once("loss_used"):
            logger.debug(
                "These tensors were used in the loss_fn:\n\n{}\n".format(
                    summarize(self._loss_input_dict)))

        self._sess.run(tf.global_variables_initializer())
Example #14
        def _compute_gradients(self, samples):
            """Computes and returns grads as eager tensors."""

            self._is_training = True

            with tf.GradientTape(persistent=gradients_fn is not None) as tape:
                # TODO: set seq len and state-in properly
                state_in = []
                for i in range(self.num_state_tensors()):
                    state_in.append(samples["state_in_{}".format(i)])
                self._state_in = state_in

                self._seq_lens = None
                if len(state_in) > 0:
                    self._seq_lens = tf.ones(
                        samples[SampleBatch.CUR_OBS].shape[0], dtype=tf.int32)
                    samples["seq_lens"] = self._seq_lens

                model_out, _ = self.model(samples, self._state_in,
                                          self._seq_lens)
                loss = loss_fn(self, self.model, self.dist_class, samples)

            variables = self.model.trainable_variables()

            if gradients_fn:

                class OptimizerWrapper:
                    def __init__(self, tape):
                        self.tape = tape

                    def compute_gradients(self, loss, var_list):
                        return list(
                            zip(self.tape.gradient(loss, var_list), var_list))

                grads_and_vars = gradients_fn(self, OptimizerWrapper(tape),
                                              loss)
            else:
                grads_and_vars = list(
                    zip(tape.gradient(loss, variables), variables))

            if log_once("grad_vars"):
                for _, v in grads_and_vars:
                    logger.info("Optimizing variable {}".format(v.name))

            grads = [g for g, v in grads_and_vars]
            stats = self._stats(self, samples, grads)
            return grads_and_vars, stats
Example #15
    def apply_gradients(self, grads):
        if log_once("apply_gradients"):
            logger.info("Apply gradients:\n\n{}\n".format(summarize(grads)))
        if isinstance(grads, dict):
            if self.tf_sess is not None:
                builder = TFRunBuilder(self.tf_sess, "apply_gradients")
                outputs = {
                    pid: self.policy_map[pid]._build_apply_gradients(
                        builder, grad)
                    for pid, grad in grads.items()
                }
                return {k: builder.get(v) for k, v in outputs.items()}
            else:
                return {
                    pid: self.policy_map[pid].apply_gradients(g)
                    for pid, g in grads.items()
                }
        else:
            return self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads)
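Examples #1 and #15 are typically used as a pair; a minimal sketch follows, where the worker and batch objects are assumed to exist.

# Compute gradients on a batch, then apply them on the same worker.
grads, grad_info = worker.compute_gradients(batch)
worker.apply_gradients(grads)

# In the multi-agent case both calls operate on dicts keyed by policy id,
# i.e. `grads` is {policy_id: gradients, ...}.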
Example #16
        def _compute_gradients(self, samples):
            """Computes and returns grads as eager tensors."""

            self._is_training = True

            samples = {
                k: tf.convert_to_tensor(v)
                for k, v in samples.items() if v.dtype != np.object
            }

            with tf.GradientTape(persistent=gradients_fn is not None) as tape:
                # TODO: set seq len and state in properly
                self._seq_lens = tf.ones(len(samples[SampleBatch.CUR_OBS]))
                self._state_in = []
                model_out, _ = self.model(samples, self._state_in,
                                          self._seq_lens)
                loss = loss_fn(self, self.model, self.dist_class, samples)

            variables = self.model.trainable_variables()

            if gradients_fn:

                class OptimizerWrapper(object):
                    def __init__(self, tape):
                        self.tape = tape

                    def compute_gradients(self, loss, var_list):
                        return list(
                            zip(self.tape.gradient(loss, var_list), var_list))

                grads_and_vars = gradients_fn(self, OptimizerWrapper(tape),
                                              loss)
            else:
                grads_and_vars = list(
                    zip(tape.gradient(loss, variables), variables))

            if log_once("grad_vars"):
                for _, v in grads_and_vars:
                    logger.info("Optimizing variable {}".format(v.name))

            grads = [g for g, v in grads_and_vars]
            stats = self._stats(self, samples, grads)
            return grads_and_vars, stats
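The OptimizerWrapper trick above (exposing a tf.GradientTape behind an optimizer-like compute_gradients(loss, var_list) interface) can be reproduced standalone. This is a small TF 2.x eager sketch, not the library code.

import tensorflow as tf


class TapeOptimizerWrapper:
    """Expose a GradientTape through an optimizer-like interface."""

    def __init__(self, tape):
        self.tape = tape

    def compute_gradients(self, loss, var_list):
        return list(zip(self.tape.gradient(loss, var_list), var_list))


w = tf.Variable(3.0)
with tf.GradientTape() as tape:
    loss = w * w  # trainable variables are watched automatically
grads_and_vars = TapeOptimizerWrapper(tape).compute_gradients(loss, [w])
print(grads_and_vars)  # [(tf.Tensor 6.0, <tf.Variable ... 3.0>)]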
Example #17
    def _debug_vars(self):
        if log_once("grad_vars"):
            for _, v in self._grads_and_vars:
                logger.info("Optimizing variable {}".format(v))
Example #18
def _process_observations(base_env, policies, batch_builder_pool,
                          active_episodes, unfiltered_obs, rewards, dones,
                          infos, off_policy_actions, horizon, preprocessors,
                          obs_filters, unroll_length, pack, callbacks):
    """Record new data from the environment and prepare for policy evaluation.

    Returns:
        active_envs: set of non-terminated env ids
        to_eval: map of policy_id to list of agent PolicyEvalData
        outputs: list of metrics and samples to return from the sampler
    """

    active_envs = set()
    to_eval = defaultdict(list)
    outputs = []

    # For each environment
    for env_id, agent_obs in unfiltered_obs.items():
        new_episode = env_id not in active_episodes
        episode = active_episodes[env_id]
        if not new_episode:
            episode.length += 1
            episode.batch_builder.count += 1
            episode._add_agent_rewards(rewards[env_id])

        if (episode.batch_builder.total() > max(1000, unroll_length * 10)
                and log_once("large_batch_warning")):
            logger.warning(
                "More than {} observations for {} env steps ".format(
                    episode.batch_builder.total(),
                    episode.batch_builder.count) + "are buffered in "
                "the sampler. If this is more than you expected, check that "
                "that you set a horizon on your environment correctly. Note "
                "that in multi-agent environments, `sample_batch_size` sets "
                "the batch size based on environment steps, not the steps of "
                "individual agents, which can result in unexpectedly large "
                "batches.")

        # Check episode termination conditions
        if dones[env_id]["__all__"] or episode.length >= horizon:
            all_done = True
            atari_metrics = _fetch_atari_metrics(base_env)
            if atari_metrics is not None:
                for m in atari_metrics:
                    outputs.append(
                        m._replace(custom_metrics=episode.custom_metrics))
            else:
                outputs.append(
                    RolloutMetrics(episode.length, episode.total_reward,
                                   dict(episode.agent_rewards),
                                   episode.custom_metrics))
        else:
            all_done = False
            active_envs.add(env_id)

        # For each agent in the environment
        for agent_id, raw_obs in agent_obs.items():
            policy_id = episode.policy_for(agent_id)
            prep_obs = _get_or_raise(preprocessors,
                                     policy_id).transform(raw_obs)
            if log_once("prep_obs"):
                logger.info("Preprocessed obs: {}".format(summarize(prep_obs)))

            filtered_obs = _get_or_raise(obs_filters, policy_id)(prep_obs)
            if log_once("filtered_obs"):
                logger.info("Filtered obs: {}".format(summarize(filtered_obs)))

            agent_done = bool(all_done or dones[env_id].get(agent_id))
            if not agent_done:
                to_eval[policy_id].append(
                    PolicyEvalData(env_id, agent_id, filtered_obs,
                                   infos[env_id].get(agent_id, {}),
                                   episode.rnn_state_for(agent_id),
                                   episode.last_action_for(agent_id),
                                   rewards[env_id][agent_id] or 0.0))

            last_observation = episode.last_observation_for(agent_id)
            episode._set_last_observation(agent_id, filtered_obs)
            episode._set_last_raw_obs(agent_id, raw_obs)
            episode._set_last_info(agent_id, infos[env_id].get(agent_id, {}))

            # Record transition info if applicable
            if (last_observation is not None and infos[env_id].get(
                    agent_id, {}).get("training_enabled", True)):
                episode.batch_builder.add_values(
                    agent_id,
                    policy_id,
                    t=episode.length - 1,
                    eps_id=episode.episode_id,
                    agent_index=episode._agent_index(agent_id),
                    obs=last_observation,
                    actions=episode.last_action_for(agent_id),
                    rewards=rewards[env_id][agent_id],
                    prev_actions=episode.prev_action_for(agent_id),
                    prev_rewards=episode.prev_reward_for(agent_id),
                    dones=agent_done,
                    infos=infos[env_id].get(agent_id, {}),
                    new_obs=filtered_obs,
                    **episode.last_pi_info_for(agent_id))

        # Invoke the step callback after the step is logged to the episode
        if callbacks.get("on_episode_step"):
            callbacks["on_episode_step"]({"env": base_env, "episode": episode})

        # Cut the batch if we're not packing multiple episodes into one,
        # or if we've exceeded the requested batch size.
        if episode.batch_builder.has_pending_data():
            if dones[env_id]["__all__"]:
                episode.batch_builder.check_missing_dones()
            if (all_done and not pack) or \
                    episode.batch_builder.count >= unroll_length:
                outputs.append(episode.batch_builder.build_and_reset(episode))
            elif all_done:
                # Make sure postprocessor stays within one episode
                episode.batch_builder.postprocess_batch_so_far(episode)

        if all_done:
            # Handle episode termination
            batch_builder_pool.append(episode.batch_builder)
            if callbacks.get("on_episode_end"):
                callbacks["on_episode_end"]({
                    "env": base_env,
                    "policy": policies,
                    "episode": episode
                })
            del active_episodes[env_id]
            resetted_obs = base_env.try_reset(env_id)
            if resetted_obs is None:
                # Reset not supported, drop this env from the ready list
                if horizon != float("inf"):
                    raise ValueError(
                        "Setting episode horizon requires reset() support "
                        "from the environment.")
            else:
                # Creates a new episode
                episode = active_episodes[env_id]
                for agent_id, raw_obs in resetted_obs.items():
                    policy_id = episode.policy_for(agent_id)
                    policy = _get_or_raise(policies, policy_id)
                    prep_obs = _get_or_raise(preprocessors,
                                             policy_id).transform(raw_obs)
                    filtered_obs = _get_or_raise(obs_filters,
                                                 policy_id)(prep_obs)
                    episode._set_last_observation(agent_id, filtered_obs)
                    to_eval[policy_id].append(
                        PolicyEvalData(
                            env_id, agent_id, filtered_obs,
                            episode.last_info_for(agent_id) or {},
                            episode.rnn_state_for(agent_id),
                            np.zeros_like(
                                _flatten_action(policy.action_space.sample())),
                            0.0))

    return active_envs, to_eval, outputs
Example #19
def _env_runner(base_env,
                extra_batch_callback,
                policies,
                policy_mapping_fn,
                unroll_length,
                horizon,
                preprocessors,
                obs_filters,
                clip_rewards,
                clip_actions,
                pack,
                callbacks,
                tf_sess=None):
    """This implements the common experience collection logic.

    Args:
        base_env (BaseEnv): env implementing BaseEnv.
        extra_batch_callback (fn): function to send extra batch data to.
        policies (dict): Map of policy ids to PolicyGraph instances.
        policy_mapping_fn (func): Function that maps agent ids to policy ids.
            This is called when an agent first enters the environment. The
            agent is then "bound" to the returned policy for the episode.
        unroll_length (int): Number of episode steps before `SampleBatch` is
            yielded. Set to infinity to yield complete episodes.
        horizon (int): Horizon of the episode.
        preprocessors (dict): Map of policy id to preprocessor for the
            observations prior to filtering.
        obs_filters (dict): Map of policy id to filter used to process
            observations for the policy.
        clip_rewards (bool): Whether to clip rewards before postprocessing.
        pack (bool): Whether to pack multiple episodes into each batch. This
            guarantees batches will be exactly `unroll_length` in size.
        clip_actions (bool): Whether to clip actions to the space range.
        callbacks (dict): User callbacks to run on episode events.
        tf_sess (Session|None): Optional tensorflow session to use for batching
            TF policy evaluations.

    Yields:
        rollout (SampleBatch): Object containing state, action, reward,
            terminal condition, and other fields as dictated by `policy`.
    """

    try:
        if not horizon:
            horizon = (base_env.get_unwrapped()[0].spec.max_episode_steps)
    except Exception:
        logger.debug("no episode horizon specified, assuming inf")
    if not horizon:
        horizon = float("inf")

    # Pool of batch builders, which can be shared across episodes to pack
    # trajectory data.
    batch_builder_pool = []

    def get_batch_builder():
        if batch_builder_pool:
            return batch_builder_pool.pop()
        else:
            return MultiAgentSampleBatchBuilder(policies, clip_rewards)

    def new_episode():
        episode = MultiAgentEpisode(policies, policy_mapping_fn,
                                    get_batch_builder, extra_batch_callback)
        if callbacks.get("on_episode_start"):
            callbacks["on_episode_start"]({
                "env": base_env,
                "policy": policies,
                "episode": episode,
            })
        return episode

    active_episodes = defaultdict(new_episode)

    while True:
        # Get observations from all ready agents
        unfiltered_obs, rewards, dones, infos, off_policy_actions = \
            base_env.poll()

        if log_once("env_returns"):
            logger.info("Raw obs from env: {}".format(
                summarize(unfiltered_obs)))
            logger.info("Info return from env: {}".format(summarize(infos)))

        # Process observations and prepare for policy evaluation
        active_envs, to_eval, outputs = _process_observations(
            base_env, policies, batch_builder_pool, active_episodes,
            unfiltered_obs, rewards, dones, infos, off_policy_actions, horizon,
            preprocessors, obs_filters, unroll_length, pack, callbacks)
        for o in outputs:
            yield o

        # Do batched policy eval
        eval_results = _do_policy_eval(tf_sess, to_eval, policies,
                                       active_episodes)

        # Process results and update episode state
        actions_to_send = _process_policy_eval_results(
            to_eval, eval_results, active_episodes, active_envs,
            off_policy_actions, policies, clip_actions)

        # Return computed actions to ready envs. We also send to envs that have
        # taken off-policy actions; those envs are free to ignore the action.
        base_env.send_actions(actions_to_send)
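A hedged sketch of consuming the _env_runner generator; all arguments are assumed to be set up elsewhere, and the handle() helper is illustrative.

# The generator yields SampleBatch objects (and rollout metrics) forever.
rollouts = _env_runner(base_env, extra_batch_callback, policies,
                       policy_mapping_fn, unroll_length, horizon,
                       preprocessors, obs_filters, clip_rewards, clip_actions,
                       pack, callbacks, tf_sess=None)
for sample_batch in rollouts:
    handle(sample_batch)  # e.g. enqueue for training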
Example #20
    def _get_loss_inputs_dict(self, batch, shuffle):
        """Return a feed dict from a batch.

        Arguments:
            batch (SampleBatch): batch of data to derive inputs from
            shuffle (bool): whether to shuffle batch sequences. Shuffle may
                be done in-place. This only makes sense if you're further
                applying minibatch SGD after getting the outputs.

        Returns:
            feed dict of data
        """

        feed_dict = {}
        if self._batch_divisibility_req > 1:
            meets_divisibility_reqs = (
                len(batch[SampleBatch.CUR_OBS]) % self._batch_divisibility_req
                == 0
                and max(batch[SampleBatch.AGENT_INDEX]) == 0)  # not multiagent
        else:
            meets_divisibility_reqs = True

        # Simple case: not RNN nor do we need to pad
        if not self._state_inputs and meets_divisibility_reqs:
            if shuffle:
                batch.shuffle()
            for k, ph in self._loss_inputs:
                feed_dict[ph] = batch[k]
            return feed_dict

        if self._state_inputs:
            max_seq_len = self._max_seq_len
            dynamic_max = True
        else:
            max_seq_len = self._batch_divisibility_req
            dynamic_max = False

        # RNN or multi-agent case
        feature_keys = [k for k, v in self._loss_inputs]
        state_keys = [
            "state_in_{}".format(i) for i in range(len(self._state_inputs))
        ]
        feature_sequences, initial_states, seq_lens = chop_into_sequences(
            batch[SampleBatch.EPS_ID],
            batch[SampleBatch.UNROLL_ID],
            batch[SampleBatch.AGENT_INDEX], [batch[k] for k in feature_keys],
            [batch[k] for k in state_keys],
            max_seq_len,
            dynamic_max=dynamic_max,
            shuffle=shuffle)
        for k, v in zip(feature_keys, feature_sequences):
            feed_dict[self._loss_input_dict[k]] = v
        for k, v in zip(state_keys, initial_states):
            feed_dict[self._loss_input_dict[k]] = v
        feed_dict[self._seq_lens] = seq_lens

        if log_once("rnn_feed_dict"):
            logger.info("Padded input for RNN:\n\n{}\n".format(
                summarize({
                    "features": feature_sequences,
                    "initial_states": initial_states,
                    "seq_lens": seq_lens,
                    "max_seq_len": max_seq_len,
                })))
        return feed_dict
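The RNN branch relies on chop_into_sequences() to split rows into episode sequences padded to max_seq_len. The toy below only illustrates the padding idea on a single flat feature array; it is not the library function.

import numpy as np


def pad_to_sequences(features, max_seq_len):
    """Split a flat feature array into fixed-length chunks, zero-padding the
    last chunk, and report the true sequence lengths."""
    padded, seq_lens = [], []
    for start in range(0, len(features), max_seq_len):
        chunk = features[start:start + max_seq_len]
        seq_lens.append(len(chunk))
        padded.append(np.pad(chunk, (0, max_seq_len - len(chunk))))
    return np.concatenate(padded), np.array(seq_lens)


feats, lens = pad_to_sequences(np.arange(1, 8), max_seq_len=3)
print(feats)  # [1 2 3 4 5 6 7 0 0]
print(lens)   # [3 3 1]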
Example #21
    def _initialize_loss(self):
        def fake_array(tensor):
            shape = tensor.shape.as_list()
            shape[0] = 1
            return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)

        dummy_batch = {
            SampleBatch.CUR_OBS:
            fake_array(self._obs_input),
            SampleBatch.NEXT_OBS:
            fake_array(self._obs_input),
            SampleBatch.DONES:
            np.array([False], dtype=np.bool),
            SampleBatch.ACTIONS:
            fake_array(ModelCatalog.get_action_placeholder(self.action_space)),
            SampleBatch.REWARDS:
            np.array([0], dtype=np.float32),
        }
        if self._obs_include_prev_action_reward:
            dummy_batch.update({
                SampleBatch.PREV_ACTIONS:
                fake_array(self._prev_action_input),
                SampleBatch.PREV_REWARDS:
                fake_array(self._prev_reward_input),
            })
        state_init = self.get_initial_state()
        for i, h in enumerate(state_init):
            dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0)
            dummy_batch["state_out_{}".format(i)] = np.expand_dims(h, 0)
        if state_init:
            dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)
        for k, v in self.extra_compute_action_fetches().items():
            dummy_batch[k] = fake_array(v)

        # postprocessing might depend on variable init, so run it first here
        self._sess.run(tf.global_variables_initializer())
        postprocessed_batch = self.postprocess_trajectory(
            SampleBatch(dummy_batch))

        if self._obs_include_prev_action_reward:
            batch_tensors = UsageTrackingDict({
                SampleBatch.PREV_ACTIONS:
                self._prev_action_input,
                SampleBatch.PREV_REWARDS:
                self._prev_reward_input,
                SampleBatch.CUR_OBS:
                self._obs_input,
            })
            loss_inputs = [
                (SampleBatch.PREV_ACTIONS, self._prev_action_input),
                (SampleBatch.PREV_REWARDS, self._prev_reward_input),
                (SampleBatch.CUR_OBS, self._obs_input),
            ]
        else:
            batch_tensors = UsageTrackingDict({
                SampleBatch.CUR_OBS:
                self._obs_input,
            })
            loss_inputs = [
                (SampleBatch.CUR_OBS, self._obs_input),
            ]

        for k, v in postprocessed_batch.items():
            if k in batch_tensors:
                continue
            elif v.dtype == np.object:
                continue  # can't handle arbitrary objects in TF
            shape = (None, ) + v.shape[1:]
            dtype = np.float32 if v.dtype == np.float64 else v.dtype
            placeholder = tf.placeholder(dtype, shape=shape, name=k)
            batch_tensors[k] = placeholder

        if log_once("loss_init"):
            logger.info(
                "Initializing loss function with dummy input:\n\n{}\n".format(
                    summarize(batch_tensors)))

        loss = self._do_loss_init(batch_tensors)
        for k in sorted(batch_tensors.accessed_keys):
            loss_inputs.append((k, batch_tensors[k]))

        TFPolicy._initialize_loss(self, loss, loss_inputs)
        if self._grad_stats_fn:
            self._stats_fetches.update(self._grad_stats_fn(self, self._grads))
        self._sess.run(tf.global_variables_initializer())
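The fake_array() helper above builds one-row dummy inputs from placeholder shapes; here is a hedged standalone check of that idea (TF 1.x-style graph mode, with an illustrative placeholder).

import numpy as np
import tensorflow as tf

obs_ph = tf.placeholder(tf.float32, shape=[None, 4], name="obs")
shape = obs_ph.shape.as_list()
shape[0] = 1  # replace the unknown batch dimension with a single dummy row
dummy = np.zeros(shape, dtype=obs_ph.dtype.as_numpy_dtype)
print(dummy.shape)  # (1, 4)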
Example #22
    def compute_gradients(self, samples):
        if log_once("compute_gradients"):
            logger.info("Compute gradients on:\n\n{}\n".format(
                summarize(samples)))
        if isinstance(samples, MultiAgentBatch):
            grad_out, info_out = {}, {}
            if self.tf_sess is not None:
                builder = TFRunBuilder(self.tf_sess, "compute_gradients")
                for pid, batch in samples.policy_batches.items():
                    if pid not in self.policies_to_train:
                        continue
                    # NOTE: build a dict of observations for the neighboring
                    # agents (the current agent is included as well).
                    neighbor_batch_dic = {}
                    neighbor_pid_list = [
                        pid_ for pid_ in
                        self.traffic_light_node_dict[self.inter_num_2_id(
                            int(pid.split('_')[1]))]['adjacency_row']
                        if pid_ is not None
                    ]
                    for neighbor_pid in neighbor_pid_list:
                        neighbor_batch_dic['policy_{}'.format(
                            neighbor_pid)] = samples.policy_batches[
                                'policy_{}'.format(neighbor_pid)]
                    # neighbor_batch_dic[pid] = samples.policy_batches[pid]
                    # ------------------------------------------------------------------

                    grad_out[pid], info_out[pid] = (
                        self.policy_map[pid]._build_compute_gradients(
                            builder, batch, neighbor_batch_dic))
                grad_out = {k: builder.get(v) for k, v in grad_out.items()}
                info_out = {k: builder.get(v) for k, v in info_out.items()}
            else:
                for pid, batch in samples.policy_batches.items():
                    if pid not in self.policies_to_train:
                        continue
                    # NOTE: build a dict of observations for the neighboring
                    # agents (the current agent is included as well).
                    neighbor_batch_dic = {}
                    neighbor_pid_list = [
                        pid_ for pid_ in
                        self.traffic_light_node_dict[self.inter_num_2_id(
                            int(pid.split('_')[1]))]['adjacency_row']
                        if pid_ is not None
                    ]
                    for neighbor_pid in neighbor_pid_list:
                        neighbor_batch_dic['policy_{}'.format(
                            neighbor_pid)] = samples.policy_batches[
                                'policy_{}'.format(neighbor_pid)]
                    # neighbor_batch_dic[pid] = samples.policy_batches[pid]
                    # ------------------------------------------------------------------
                    grad_out[pid], info_out[pid] = (
                        self.policy_map[pid].compute_gradients(
                            batch, neighbor_batch_dic))
        else:
            grad_out, info_out = (
                self.policy_map[DEFAULT_POLICY_ID].compute_gradients(samples))
        info_out["batch_count"] = samples.count
        if log_once("grad_out"):
            logger.info("Compute grad info:\n\n{}\n".format(
                summarize(info_out)))
        return grad_out, info_out
Example #23
    def _initialize_loss(self):
        def fake_array(tensor):
            shape = tensor.shape.as_list()
            shape = [s if s is not None else 1 for s in shape]
            return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)

        dummy_batch = {
            SampleBatch.CUR_OBS:
            fake_array(self._obs_input),
            SampleBatch.NEXT_OBS:
            fake_array(self._obs_input),
            SampleBatch.DONES:
            np.array([False], dtype=np.bool),
            SampleBatch.ACTIONS:
            fake_array(ModelCatalog.get_action_placeholder(self.action_space)),
            SampleBatch.REWARDS:
            np.array([0], dtype=np.float32),
        }

        # Add dummy things PENGZHENGHAO
        # for name, val in self.model.mask_placeholder_dict.items():
        #     shape = val.shape.as_list()
        #     shape = [1] + [s if s is not None else 1 for s in shape]
        #     dummy_batch[name] = \
        #         np.zeros(shape, dtype=val.dtype.as_numpy_dtype)

        if self._obs_include_prev_action_reward:
            dummy_batch.update({
                SampleBatch.PREV_ACTIONS:
                fake_array(self._prev_action_input),
                SampleBatch.PREV_REWARDS:
                fake_array(self._prev_reward_input),
            })
        state_init = self.get_initial_state()
        state_batches = []
        for i, h in enumerate(state_init):
            dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0)
            dummy_batch["state_out_{}".format(i)] = np.expand_dims(h, 0)
            state_batches.append(np.expand_dims(h, 0))
        if state_init:
            dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)

        for k, v in self.extra_compute_action_fetches().items():
            dummy_batch[k] = fake_array(v)

        # postprocessing might depend on variable init, so run it first here
        self._sess.run(tf.global_variables_initializer())

        postprocessed_batch = self.postprocess_trajectory(
            SampleBatch(dummy_batch))

        # model forward pass for the loss (needed after postprocess to
        # overwrite any tensor state from that call)
        self.model(self._input_dict, self._state_in, self._seq_lens)

        if self._obs_include_prev_action_reward:
            train_batch = UsageTrackingDict({
                SampleBatch.PREV_ACTIONS:
                self._prev_action_input,
                SampleBatch.PREV_REWARDS:
                self._prev_reward_input,
                SampleBatch.CUR_OBS:
                self._obs_input,
            })
            loss_inputs = [
                (SampleBatch.PREV_ACTIONS, self._prev_action_input),
                (SampleBatch.PREV_REWARDS, self._prev_reward_input),
                (SampleBatch.CUR_OBS, self._obs_input),
            ]
        else:
            train_batch = UsageTrackingDict({
                SampleBatch.CUR_OBS:
                self._obs_input,
            })
            loss_inputs = [
                (SampleBatch.CUR_OBS, self._obs_input),
            ]

        # When using the mask, the key of postprocessed_batch is :
        # dict_keys(['obs', 'new_obs', 'dones', 'actions', 'rewards',
        # 'fc_1_mask', 'fc_2_mask', 'prev_actions', 'prev_rewards',
        # 'action_prob', 'action_logp', 'vf_preds', 'behaviour_logits',
        # 'layer0', 'layer1', 'advantages', 'value_targets'])

        # When not using the mask, the keys are:
        # dict_keys(['obs', 'new_obs', 'dones', 'actions', 'rewards',
        # 'fc_1_mask', 'fc_2_mask', 'prev_actions', 'prev_rewards',
        # 'action_prob', 'action_logp', 'vf_preds', 'behaviour_logits',
        # 'layer0', 'layer1', 'advantages', 'value_targets'])
        for k, v in postprocessed_batch.items():
            if k in train_batch:
                continue
            elif v.dtype == np.object:
                continue  # can't handle arbitrary objects in TF
            elif k == "seq_lens" or k.startswith("state_in_"):
                continue
            shape = (None, ) + v.shape[1:]
            dtype = np.float32 if v.dtype == np.float64 else v.dtype
            placeholder = tf.placeholder(dtype, shape=shape, name=k)
            train_batch[k] = placeholder

        # When using the mask. At this time, the train_batch contain 17
        # element.
        # <class 'list'>: ['prev_actions', 'prev_rewards', 'obs', 'new_obs',
        # 'dones', 'actions', 'rewards', 'fc_1_mask', 'fc_2_mask',
        # 'action_prob', 'action_logp', 'vf_preds', 'behaviour_logits',
        # 'layer0', 'layer1', 'advantages', 'value_targets']
        for i, si in enumerate(self._state_in):
            train_batch["state_in_{}".format(i)] = si
        train_batch["seq_lens"] = self._seq_lens

        if log_once("loss_init"):
            logger.debug(
                "Initializing loss function with dummy input:\n\n{}\n".format(
                    summarize(train_batch)))

        self._loss_input_dict = train_batch
        # At this time, the accessed_keys: <class 'set'>:
        # {'obs', 'prev_rewards', 'value_targets', 'behaviour_logits',
        # 'prev_actions', 'advantages', 'action_logp', 'actions',
        # 'vf_preds', 'accessed_keys', 'intercepted_values'}

        # However, in the no-mask exp, current accessed_keys:
        # <class 'set'>: {'intercepted_values', 'accessed_keys'}

        loss = self._do_loss_init(train_batch)
        # after the above line, the accessed_keys: <class 'set'>:
        # {'advantages', 'action_logp', 'behaviour_logits', 'prev_rewards',
        # 'prev_actions', 'vf_preds', 'actions', 'value_targets', 'obs'}

        # However, in the no-mask exp, the above line leads to the same set,
        # just in a different order.
        # {'action_logp', 'prev_actions', 'behaviour_logits',
        # 'value_targets', 'obs', 'prev_rewards', 'advantages', 'vf_preds',
        # 'actions'}

        # at this time, the loss input already has: prev_actions,
        # prev_rewards, obs
        for k in sorted(train_batch.accessed_keys):
            # sorted train_batch.accessed_keys: <class 'list'>: [
            # 'action_logp', 'actions', 'advantages', 'behaviour_logits',
            # 'obs', 'prev_actions', 'prev_rewards', 'value_targets',
            # 'vf_preds']
            if k != "seq_lens" and not k.startswith("state_in_"):
                loss_inputs.append((k, train_batch[k]))

        # PENGZHENGHAO
        # for name, ph in self.model.mask_placeholder_dict.items():
        #     loss_inputs.append((name, ph))

        TFPolicy._initialize_loss(self, loss, loss_inputs)
        if self._grad_stats_fn:
            self._stats_fetches.update(
                self._grad_stats_fn(self, train_batch, self._grads))
        self._sess.run(tf.global_variables_initializer())
Example #24
    def _initialize_loss(self):
        def fake_array(tensor):
            shape = tensor.shape.as_list()
            shape = [s if s is not None else 1 for s in shape]
            return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)

        dummy_batch = {
            SampleBatch.CUR_OBS: fake_array(self._obs_input),
            SampleBatch.NEXT_OBS: fake_array(self._obs_input),
            SampleBatch.DONES: np.array([False], dtype=np.bool),
            SampleBatch.ACTIONS: fake_array(
                ModelCatalog.get_action_placeholder(self.action_space)),
            SampleBatch.REWARDS: np.array([0], dtype=np.float32),
        }
        if self._obs_include_prev_action_reward:
            dummy_batch.update({
                SampleBatch.PREV_ACTIONS: fake_array(self._prev_action_input),
                SampleBatch.PREV_REWARDS: fake_array(self._prev_reward_input),
            })
        state_init = self.get_initial_state()
        state_batches = []
        for i, h in enumerate(state_init):
            dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0)
            dummy_batch["state_out_{}".format(i)] = np.expand_dims(h, 0)
            state_batches.append(np.expand_dims(h, 0))
        if state_init:
            dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)
        for k, v in self.extra_compute_action_fetches().items():
            dummy_batch[k] = fake_array(v)

        # postprocessing might depend on variable init, so run it first here
        self._sess.run(tf.global_variables_initializer())

        postprocessed_batch = self.postprocess_trajectory(
            SampleBatch(dummy_batch))

        # model forward pass for the loss (needed after postprocess to
        # overwrite any tensor state from that call)
        self.model(self._input_dict, self._state_in, self._seq_lens)

        if self._obs_include_prev_action_reward:
            train_batch = UsageTrackingDict({
                SampleBatch.PREV_ACTIONS: self._prev_action_input,
                SampleBatch.PREV_REWARDS: self._prev_reward_input,
                SampleBatch.CUR_OBS: self._obs_input,
            })
            loss_inputs = [
                (SampleBatch.PREV_ACTIONS, self._prev_action_input),
                (SampleBatch.PREV_REWARDS, self._prev_reward_input),
                (SampleBatch.CUR_OBS, self._obs_input),
            ]
        else:
            train_batch = UsageTrackingDict({
                SampleBatch.CUR_OBS: self._obs_input,
            })
            loss_inputs = [
                (SampleBatch.CUR_OBS, self._obs_input),
            ]

        for k, v in postprocessed_batch.items():
            if k in train_batch:
                continue
            elif v.dtype == np.object:
                continue  # can't handle arbitrary objects in TF
            elif k == "seq_lens" or k.startswith("state_in_"):
                continue
            shape = (None, ) + v.shape[1:]
            dtype = np.float32 if v.dtype == np.float64 else v.dtype
            placeholder = tf.placeholder(dtype, shape=shape, name=k)
            train_batch[k] = placeholder

        for i, si in enumerate(self._state_in):
            train_batch["state_in_{}".format(i)] = si
        train_batch["seq_lens"] = self._seq_lens

        if log_once("loss_init"):
            logger.debug(
                "Initializing loss function with dummy input:\n\n{}\n".format(
                    summarize(train_batch)))

        self._loss_input_dict = train_batch
        loss = self._do_loss_init(train_batch)
        for k in sorted(train_batch.accessed_keys):
            if k != "seq_lens" and not k.startswith("state_in_"):
                loss_inputs.append((k, train_batch[k]))

        TFPolicy._initialize_loss(self, loss, loss_inputs)
        if self._grad_stats_fn:
            self._stats_fetches.update(
                self._grad_stats_fn(self, train_batch, self._grads))
        self._sess.run(tf.global_variables_initializer())
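
The placeholder-creation loop above has two details worth isolating: the batch dimension is left free as None, and float64 arrays are downcast to float32 placeholders. A hedged sketch of just that loop, assuming TF1-style graph mode (tf.compat.v1 on newer TensorFlow) and made-up batch keys:

import numpy as np
import tensorflow as tf

def placeholders_from_batch(batch):
    # Sketch: one placeholder per key, free batch dim, float64 -> float32.
    phs = {}
    for k, v in batch.items():
        dtype = np.float32 if v.dtype == np.float64 else v.dtype
        phs[k] = tf.placeholder(dtype, shape=(None, ) + v.shape[1:], name=k)
    return phs

phs = placeholders_from_batch({
    "advantages": np.zeros((1, ), dtype=np.float64),  # becomes a float32 placeholder
    "obs": np.zeros((1, 4), dtype=np.float32),
})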
Пример #25
0
def _process_observations(base_env, policies, batch_builder_pool,
                          active_episodes, unfiltered_obs, rewards, dones,
                          infos, off_policy_actions, horizon, preprocessors,
                          obs_filters, unroll_length, pack, callbacks,
                          soft_horizon, no_done_at_end):
    """Record new data from the environment and prepare for policy evaluation.

    Returns:
        active_envs: set of non-terminated env ids
        to_eval: map of policy_id to list of agent PolicyEvalData
        outputs: list of metrics and samples to return from the sampler
    """
    global i
    global tmp_dic
    global traffic_light_node_dict
    i += 1

    def inter_num_2_id(num):
        return list(tmp_dic.keys())[list(tmp_dic.values()).index(num)]

    def read_traffic_light_node_dict():
        path_to_read = os.path.join(record_dir, 'traffic_light_node_dict.conf')
        with open(path_to_read, 'r') as f:
            traffic_light_node_dict = eval(f.read())
            print("Read traffic_light_node_dict")
            return traffic_light_node_dict

    if i <= 1:
        # Read the neighbor information in from the config file here.
        record_dir = base_env.envs[0].record_dir
        traffic_light_node_dict = base_env.envs[0].traffic_light_node_dict
        tmp_dic = traffic_light_node_dict['intersection_1_1'][
            'inter_id_to_index']

    active_envs = set()
    to_eval = defaultdict(list)
    outputs = []

    # For each environment
    for env_id, agent_obs in unfiltered_obs.items():
        new_episode = env_id not in active_episodes
        episode = active_episodes[env_id]
        if not new_episode:
            episode.length += 1
            episode.batch_builder.count += 1
            episode._add_agent_rewards(rewards[env_id])

        if (episode.batch_builder.total() > max(1000, unroll_length * 10)
                and log_once("large_batch_warning")):
            logger.warning(
                "More than {} observations for {} env steps ".format(
                    episode.batch_builder.total(),
                    episode.batch_builder.count) + "are buffered in "
                "the sampler. If this is more than you expected, check that "
                "that you set a horizon on your environment correctly. Note "
                "that in multi-agent environments, `sample_batch_size` sets "
                "the batch size based on environment steps, not the steps of "
                "individual agents, which can result in unexpectedly large "
                "batches.")

        # Check episode termination conditions
        if dones[env_id]["__all__"] or episode.length >= horizon:
            hit_horizon = (episode.length >= horizon
                           and not dones[env_id]["__all__"])
            all_done = True
            atari_metrics = _fetch_atari_metrics(base_env)
            if atari_metrics is not None:
                for m in atari_metrics:
                    outputs.append(
                        m._replace(custom_metrics=episode.custom_metrics))
            else:
                outputs.append(
                    RolloutMetrics(episode.length, episode.total_reward,
                                   dict(episode.agent_rewards),
                                   episode.custom_metrics, {}))
        else:
            hit_horizon = False
            all_done = False
            active_envs.add(env_id)

        # For each agent in the environment
        for agent_id, raw_obs in agent_obs.items():
            policy_id = episode.policy_for(agent_id)  # eg: "policy_0"
            # print(policy_id)
            prep_obs = _get_or_raise(preprocessors,
                                     policy_id).transform(raw_obs)
            if log_once("prep_obs"):
                logger.info("Preprocessed obs: {}".format(summarize(prep_obs)))

            filtered_obs = _get_or_raise(obs_filters, policy_id)(prep_obs)
            '''
            For Attention:
            The real-time Q eval runs here, so the neighbor_obs values must be
            passed to the Q eval network.
            '''
            # Use the road-network relations in traffic_light_node_dict to find
            # the neighbors of the current policy_id and store them in
            # "policy_0" form.
            neighbor_pid_list = [
                'policy_{}'.format(pid_)
                for pid_ in traffic_light_node_dict[inter_num_2_id(
                    int(policy_id.split('_')[1]))]['adjacency_row']
                if pid_ is not None
            ]
            # print(neighbor_pid_list)
            neighbor_obs = []
            neighbor_obs.append([])

            # Shape (1, 5, 15): only this form can be fed into the neighbor_obs
            # placeholder of shape (batch, 5, 15).
            i = 0
            for neighbor_id in neighbor_pid_list:
                neighbor_prep_obs = _get_or_raise(
                    preprocessors, neighbor_id).transform(raw_obs)
                neighbor_filtered_obs = _get_or_raise(
                    obs_filters, neighbor_id)(neighbor_prep_obs)
                neighbor_obs[0].append(neighbor_filtered_obs)
                i += 1
            neighbor_obs = np.array(neighbor_obs).reshape(
                (len(neighbor_pid_list), len(raw_obs)))  # (5, 29)

            # ------------------------------------------------------------------
            if log_once("filtered_obs"):
                logger.info("Filtered obs: {}".format(summarize(filtered_obs)))

            agent_done = bool(all_done or dones[env_id].get(agent_id))
            if not agent_done:
                to_eval[policy_id].append(
                    PolicyEvalData(env_id, agent_id, filtered_obs,
                                   neighbor_obs,
                                   infos[env_id].get(agent_id, {}),
                                   episode.rnn_state_for(agent_id),
                                   episode.last_action_for(agent_id),
                                   rewards[env_id][agent_id] or 0.0))

            last_observation = episode.last_observation_for(agent_id)
            episode._set_last_observation(agent_id, filtered_obs)
            episode._set_last_raw_obs(agent_id, raw_obs)
            episode._set_last_info(agent_id, infos[env_id].get(agent_id, {}))

            # Record transition info if applicable
            if (last_observation is not None and infos[env_id].get(
                    agent_id, {}).get("training_enabled", True)):
                episode.batch_builder.add_values(
                    agent_id,
                    policy_id,
                    t=episode.length - 1,
                    eps_id=episode.episode_id,
                    agent_index=episode._agent_index(agent_id),
                    obs=last_observation,
                    actions=episode.last_action_for(agent_id),
                    rewards=rewards[env_id][agent_id],
                    prev_actions=episode.prev_action_for(agent_id),
                    prev_rewards=episode.prev_reward_for(agent_id),
                    dones=(False if
                           (no_done_at_end or
                            (hit_horizon and soft_horizon)) else agent_done),
                    infos=infos[env_id].get(agent_id, {}),
                    new_obs=filtered_obs,
                    **episode.last_pi_info_for(agent_id))

        # Invoke the step callback after the step is logged to the episode
        if callbacks.get("on_episode_step"):
            callbacks["on_episode_step"]({"env": base_env, "episode": episode})

        # Cut the batch if we're not packing multiple episodes into one,
        # or if we've exceeded the requested batch size.
        if episode.batch_builder.has_pending_data():
            if dones[env_id]["__all__"] and not no_done_at_end:
                episode.batch_builder.check_missing_dones()
            if (all_done and not pack) or \
                    episode.batch_builder.count >= unroll_length:
                outputs.append(episode.batch_builder.build_and_reset(episode))
            elif all_done:
                # Make sure postprocessor stays within one episode
                episode.batch_builder.postprocess_batch_so_far(episode)

        if all_done:
            # Handle episode termination
            batch_builder_pool.append(episode.batch_builder)
            if callbacks.get("on_episode_end"):
                callbacks["on_episode_end"]({
                    "env": base_env,
                    "policy": policies,
                    "episode": episode
                })
            if hit_horizon and soft_horizon:
                episode.soft_reset()
                resetted_obs = agent_obs
            else:
                del active_episodes[env_id]
                resetted_obs = base_env.try_reset(env_id)
            if resetted_obs is None:
                # Reset not supported, drop this env from the ready list
                if horizon != float("inf"):
                    raise ValueError(
                        "Setting episode horizon requires reset() support "
                        "from the environment.")
            elif resetted_obs != ASYNC_RESET_RETURN:
                # Creates a new episode if this is not async return
                # If reset is async, we will get its result in some future poll
                episode = active_episodes[env_id]
                for agent_id, raw_obs in resetted_obs.items():
                    policy_id = episode.policy_for(agent_id)  # eg: "policy_0"
                    policy = _get_or_raise(policies, policy_id)
                    prep_obs = _get_or_raise(preprocessors,
                                             policy_id).transform(raw_obs)
                    filtered_obs = _get_or_raise(obs_filters,
                                                 policy_id)(prep_obs)
                    # print('policy_id' + str(policy_id))
                    # print('filtered_obs' + str(filtered_obs))
                    '''
                    For Attention:
                    The episode has terminated here, so a new episode is
                    created. The real-time Q eval runs here as well, so the
                    neighbor_obs values must be passed to the Q eval network.
                    '''
                    # Use the road-network relations in traffic_light_node_dict
                    # to find the neighbors of the current policy_id and store
                    # them in "policy_0" form.
                    neighbor_pid_list = [
                        'policy_{}'.format(pid_)
                        for pid_ in traffic_light_node_dict[inter_num_2_id(
                            int(policy_id.split('_')[1]))]['adjacency_row']
                        if pid_ is not None
                    ]
                    # print(neighbor_pid_list)
                    neighbor_obs = []
                    neighbor_obs.append([])

                    # Shape (1, 5, 29): only this form can be fed into the
                    # neighbor_obs placeholder of shape (batch, 5, 17).
                    i = 0
                    for neighbor_id in neighbor_pid_list:
                        neighbor_prep_obs = _get_or_raise(
                            preprocessors, neighbor_id).transform(raw_obs)
                        neighbor_filtered_obs = _get_or_raise(
                            obs_filters, neighbor_id)(neighbor_prep_obs)
                        neighbor_obs[0].append(neighbor_filtered_obs)
                        i += 1
                    neighbor_obs = np.squeeze(np.array(neighbor_obs))

                    # ------------------------------------------------------------------
                    episode._set_last_observation(agent_id, filtered_obs)
                    to_eval[policy_id].append(
                        PolicyEvalData(
                            env_id, agent_id, filtered_obs, neighbor_obs,
                            episode.last_info_for(agent_id) or {},
                            episode.rnn_state_for(agent_id),
                            np.zeros_like(
                                _flatten_action(policy.action_space.sample())),
                            0.0))

    return active_envs, to_eval, outputs
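
The neighbor handling above amounts to stacking the filtered observations of the neighboring intersections into one array whose first axis is the neighbor index, matching the neighbor_obs placeholder. A NumPy-only sketch with assumed sizes and an identity stand-in for the real preprocessors/filters:

import numpy as np

obs_dim = 29
neighbor_pid_list = ["policy_1", "policy_2", "policy_3", "policy_4", "policy_5"]
raw_obs = np.random.rand(obs_dim).astype(np.float32)

def filtered(policy_id, obs):
    return obs  # stand-in for preprocessors[pid].transform + obs_filters[pid]

neighbor_obs = np.stack(
    [filtered(pid, raw_obs) for pid in neighbor_pid_list])  # shape (5, 29)
assert neighbor_obs.shape == (len(neighbor_pid_list), obs_dim)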
Пример #26
0
def _process_observations(base_env, policies, policies_to_train, dead_policies,
                          policy_config, observation_filter, tf_sess,
                          batch_builder_pool, active_episodes, unfiltered_obs,
                          rewards, dones, infos, off_policy_actions, horizon,
                          preprocessors, obs_filters, unroll_length, pack,
                          callbacks, soft_horizon, no_done_at_end):
    #===MOD===
    """Record new data from the environment and prepare for policy evaluation.

    Returns:
        active_envs: set of non-terminated env ids
        to_eval: map of policy_id to list of agent PolicyEvalData
        outputs: list of metrics and samples to return from the sampler
    """

    active_envs = set()
    to_eval = defaultdict(list)
    outputs = []

    # For each environment
    for env_id, agent_obs in unfiltered_obs.items():
        new_episode = env_id not in active_episodes
        episode = active_episodes[env_id]
        if not new_episode:
            episode.length += 1
            episode.batch_builder.count += 1
            episode._add_agent_rewards(rewards[env_id])

        if (episode.batch_builder.total() > max(1000, unroll_length * 10)
                and log_once("large_batch_warning")):
            logger.warning(
                "More than {} observations for {} env steps ".format(
                    episode.batch_builder.total(),
                    episode.batch_builder.count) + "are buffered in "
                "the sampler. If this is more than you expected, check that "
                "that you set a horizon on your environment correctly. Note "
                "that in multi-agent environments, `sample_batch_size` sets "
                "the batch size based on environment steps, not the steps of "
                "individual agents, which can result in unexpectedly large "
                "batches.")

        # Check episode termination conditions
        if dones[env_id]["__all__"] or episode.length >= horizon:
            # DEBUG
            # print("Trying to terminate.")
            # print("Dones of __all__ is set:", dones[env_id]["__all__"])
            # print("Horizon hit:", episode.length >= horizon)
            hit_horizon = (episode.length >= horizon
                           and not dones[env_id]["__all__"])
            all_done = True
            atari_metrics = _fetch_atari_metrics(base_env)
            if atari_metrics is not None:
                for m in atari_metrics:
                    outputs.append(
                        m._replace(custom_metrics=episode.custom_metrics))
            else:
                outputs.append(
                    RolloutMetrics(episode.length, episode.total_reward,
                                   dict(episode.agent_rewards),
                                   episode.custom_metrics, {}))
        else:
            hit_horizon = False
            all_done = False
            active_envs.add(env_id)

        #===MOD===
        additional_builders_ids = set()
        #===MOD===

        # For each agent in the environment
        for agent_id, raw_obs in agent_obs.items():

            #===MOD===
            policy_id, policy_constructor_tuple = episode.policy_for(agent_id)
            pols_tuple = generate_policies(
                policy_id,
                policy_constructor_tuple,
                policies,
                policies_to_train,
                dead_policies,
                policy_config,
                preprocessors,
                obs_filters,
                observation_filter,
                tf_sess,
            )
            policies, preprocessors, obs_filters, policies_to_train, dead_policies = pols_tuple
            #===MOD===

            prep_obs = _get_or_raise(preprocessors,
                                     policy_id).transform(raw_obs)
            if log_once("prep_obs"):
                logger.info("Preprocessed obs: {}".format(summarize(prep_obs)))

            filtered_obs = _get_or_raise(obs_filters, policy_id)(prep_obs)
            if log_once("filtered_obs"):
                logger.info("Filtered obs: {}".format(summarize(filtered_obs)))

            agent_done = bool(all_done or dones[env_id].get(agent_id))
            if not agent_done:
                to_eval[policy_id].append(
                    PolicyEvalData(env_id, agent_id, filtered_obs,
                                   infos[env_id].get(agent_id, {}),
                                   episode.rnn_state_for(agent_id),
                                   episode.last_action_for(agent_id),
                                   rewards[env_id][agent_id] or 0.0))

            last_observation = episode.last_observation_for(agent_id)
            episode._set_last_observation(agent_id, filtered_obs)
            episode._set_last_raw_obs(agent_id, raw_obs)
            episode._set_last_info(agent_id, infos[env_id].get(agent_id, {}))

            # Record transition info if applicable
            if (last_observation is not None and infos[env_id].get(
                    agent_id, {}).get("training_enabled", True)):
                #===MOD===
                additional_builders_ids.add(agent_id)
                #===MOD===
                episode.batch_builder.add_values(
                    agent_id,
                    policy_id,
                    t=episode.length - 1,
                    eps_id=episode.episode_id,
                    agent_index=episode._agent_index(agent_id),
                    obs=last_observation,
                    actions=episode.last_action_for(agent_id),
                    rewards=rewards[env_id][agent_id],
                    prev_actions=episode.prev_action_for(agent_id),
                    prev_rewards=episode.prev_reward_for(agent_id),
                    dones=(False if
                           (no_done_at_end or
                            (hit_horizon and soft_horizon)) else agent_done),
                    infos=infos[env_id].get(agent_id, {}),
                    new_obs=filtered_obs,
                    **episode.last_pi_info_for(agent_id))

            #===MOD===
            if agent_done:
                # Does it make sense to remove agent id from `agent_builders`?
                dead_policies.add(policy_id)
                print("Removing agent id from agent builders: %s" %
                      str(agent_id))
                episode.batch_builder.agent_builders.pop(agent_id)
                if policy_id in to_eval:
                    to_eval.pop(policy_id)
                    # print("Popping policy id from toeval.")
            #===MOD===

        start = time.time()

        #===MOD===
        print("sampler.py: ids added to agent builders:\t",
              additional_builders_ids)
        # Update ``self.policy_map`` in ``MultiAgentSampleBatchBuilder``.
        # TODO: policies is not being pruned in this file.
        episode.batch_builder.policy_map = policies
        print("sampler.py: policies: \t", policies.keys())
        #===MOD===

        # Invoke the step callback after the step is logged to the episode
        if callbacks.get("on_episode_step"):
            callbacks["on_episode_step"]({"env": base_env, "episode": episode})

        # Cut the batch if we're not packing multiple episodes into one,
        # or if we've exceeded the requested batch size.
        if episode.batch_builder.has_pending_data():
            if dones[env_id]["__all__"] and not no_done_at_end:
                episode.batch_builder.check_missing_dones()
            if (all_done and not pack) or \
                    episode.batch_builder.count >= unroll_length:
                outputs.append(episode.batch_builder.build_and_reset(episode))
            elif all_done:
                # Make sure postprocessor stays within one episode
                # KEYERROR
                episode.batch_builder.postprocess_batch_so_far(episode)

        if all_done:
            # Handle episode termination
            batch_builder_pool.append(episode.batch_builder)
            if callbacks.get("on_episode_end"):
                callbacks["on_episode_end"]({
                    "env": base_env,
                    "policy": policies,
                    "episode": episode
                })
            if hit_horizon and soft_horizon:
                episode.soft_reset()
                resetted_obs = agent_obs
            else:
                del active_episodes[env_id]
                resetted_obs = base_env.try_reset(env_id)
            if resetted_obs is None:
                # Reset not supported, drop this env from the ready list
                if horizon != float("inf"):
                    raise ValueError(
                        "Setting episode horizon requires reset() support "
                        "from the environment.")
            elif resetted_obs != ASYNC_RESET_RETURN:
                # print("Executing new epsiode non-async return.")
                time.sleep(1)
                raise NotImplementedError(
                    "Multiple episodes not supported by design.")
                # Creates a new episode if this is not async return
                # If reset is async, we will get its result in some future poll
                episode = active_episodes[env_id]
                for agent_id, raw_obs in resetted_obs.items():

                    #===MOD===
                    policy_id, policy_constructor_tuple = episode.policy_for(
                        agent_id)
                    # with tf_sess.as_default():
                    pols_tuple = generate_policies(
                        policy_id,
                        policy_constructor_tuple,
                        policies,
                        policies_to_train,
                        dead_policies,
                        policy_config,
                        preprocessors,
                        obs_filters,
                        observation_filter,
                        tf_sess,
                    )
                    policies, preprocessors, obs_filters, policies_to_train, dead_policies = pols_tuple
                    #===MOD===

                    policy = _get_or_raise(policies, policy_id)
                    prep_obs = _get_or_raise(preprocessors,
                                             policy_id).transform(raw_obs)
                    filtered_obs = _get_or_raise(obs_filters,
                                                 policy_id)(prep_obs)
                    episode._set_last_observation(agent_id, filtered_obs)
                    to_eval[policy_id].append(
                        PolicyEvalData(
                            env_id, agent_id, filtered_obs,
                            episode.last_info_for(agent_id) or {},
                            episode.rnn_state_for(agent_id),
                            np.zeros_like(
                                _flatten_action(policy.action_space.sample())),
                            0.0))

        #===MOD===
        pols_tuple = (policies, preprocessors, obs_filters, policies_to_train,
                      dead_policies)
        #===MOD===
    #===MOD===
    return active_envs, to_eval, outputs, pols_tuple
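
The ===MOD=== blocks above are mostly bookkeeping: when an agent finishes, its policy is marked dead, its per-agent builder is dropped, and any pending eval work for that policy is cancelled. A minimal stand-alone sketch of that bookkeeping with plain dicts and sets (not the real builder classes):

from collections import defaultdict

to_eval = defaultdict(list, {"policy_3": ["pending_eval_data"]})
agent_builders = {"agent_3": "builder_3", "agent_7": "builder_7"}
dead_policies = set()

agent_id, policy_id, agent_done = "agent_3", "policy_3", True
if agent_done:
    dead_policies.add(policy_id)          # never train this policy again
    agent_builders.pop(agent_id, None)    # stop collecting its samples
    to_eval.pop(policy_id, None)          # and skip it in this step's eval
assert policy_id in dead_policies and policy_id not in to_eval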
Пример #27
0
    def _get_loss_inputs_dict(self, batch, neighbor_batch_dic, shuffle):
        """Return a feed dict from a batch.

        Arguments:
            batch (SampleBatch): batch of data to derive inputs from
            neighbor_batch_dic (dict of SampleBatch): batches of data for the
                neighbors of the main policy
            shuffle (bool): whether to shuffle batch sequences. Shuffle may
                be done in-place. This only makes sense if you're further
                applying minibatch SGD after getting the outputs.

        Returns:
            feed dict of data
        """

        feed_dict = {}
        if self._batch_divisibility_req > 1:
            meets_divisibility_reqs = (
                len(batch[SampleBatch.CUR_OBS]) % self._batch_divisibility_req
                == 0
                and max(batch[SampleBatch.AGENT_INDEX]) == 0)  # not multiagent
        else:
            meets_divisibility_reqs = True

        neighbor_list = [None] * 5
        neighbor_count = 0
        for k in neighbor_batch_dic:
            neighbor_list[neighbor_count] = k
            neighbor_count += 1

        tmp_dic = {}
        # Simple case: not RNN nor do we need to pad
        if not self._state_inputs and meets_divisibility_reqs:
            if shuffle:
                batch.shuffle()
            for k, ph in self._loss_inputs:
                '''
                For Attention:
                This is the replay-buffer path for the Q target: assemble the
                neighbor_obs information from the SampleBatch here.
                '''
                if 'neighbor_obs' in k:
                    tmp_dic[ph] = {}
                    feed_dict[ph] = []
                    # neighbor_id = neighbor_list[int(k.split('_')[2])]
                    # if neighbor_id is None:
                    #     continue
                    # else:
                    for neighbor_id in neighbor_list:
                        tmp_dic[ph][neighbor_id] = neighbor_batch_dic[
                            neighbor_id]['obs']
                    neighbor_id = neighbor_list[0]
                    # [neighbor, batch, feature] -> [batch, neighbor, feature]
                    for batch_item in range(len(tmp_dic[ph][neighbor_id])):
                        feed_dict[ph].append([])
                        for neighbor_id in neighbor_list:
                            feed_dict[ph][batch_item].append(
                                tmp_dic[ph][neighbor_id][batch_item])
                    feed_dict[ph] = np.array(feed_dict[ph])
                    # feed_dict[ph] = Lambda(lambda x: K.permute_dimensions(x, (0, 2, 1)))(tmp_dic[ph])
                # ----------------------------------------------------------------
                else:
                    feed_dict[ph] = batch[k]
            return feed_dict

        if self._state_inputs:
            max_seq_len = self._max_seq_len
            dynamic_max = True
        else:
            max_seq_len = self._batch_divisibility_req
            dynamic_max = False

        # RNN or multi-agent case
        feature_keys = [k for k, v in self._loss_inputs]
        state_keys = [
            "state_in_{}".format(i) for i in range(len(self._state_inputs))
        ]
        feature_sequences, initial_states, seq_lens = chop_into_sequences(
            batch[SampleBatch.EPS_ID],
            batch[SampleBatch.UNROLL_ID],
            batch[SampleBatch.AGENT_INDEX], [batch[k] for k in feature_keys],
            [batch[k] for k in state_keys],
            max_seq_len,
            dynamic_max=dynamic_max,
            shuffle=shuffle)
        for k, v in zip(feature_keys, feature_sequences):
            feed_dict[self._loss_input_dict[k]] = v
        for k, v in zip(state_keys, initial_states):
            feed_dict[self._loss_input_dict[k]] = v
        feed_dict[self._seq_lens] = seq_lens

        if log_once("rnn_feed_dict"):
            logger.info("Padded input for RNN:\n\n{}\n".format(
                summarize({
                    "features": feature_sequences,
                    "initial_states": initial_states,
                    "seq_lens": seq_lens,
                    "max_seq_len": max_seq_len,
                })))
        return feed_dict
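
The nested loops that fill feed_dict[ph] for the neighbor_obs placeholder are a [neighbor, batch, feature] to [batch, neighbor, feature] transpose. An equivalent NumPy sketch, assuming every neighbor batch has the same length and using made-up policy ids:

import numpy as np

neighbor_batches = {
    "policy_1": np.zeros((4, 29), dtype=np.float32),  # (batch, feature)
    "policy_2": np.ones((4, 29), dtype=np.float32),
}
neighbor_list = list(neighbor_batches)

stacked = np.stack([neighbor_batches[n] for n in neighbor_list])  # (neighbor, batch, feature)
feed = np.transpose(stacked, (1, 0, 2))                           # (batch, neighbor, feature)
assert feed.shape == (4, len(neighbor_list), 29)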
Пример #28
0
    def _initialize_loss(self):
        def fake_array(tensor):
            shape = tensor.shape.as_list()
            shape[0] = 1
            return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)

        dummy_batch = {
            SampleBatch.CUR_OBS:
            fake_array(self._obs_input),
            SampleBatch.NEXT_OBS:
            fake_array(self._obs_input),
            SampleBatch.DONES:
            np.array([False], dtype=np.bool),
            SampleBatch.ACTIONS:
            fake_array(ModelCatalog.get_action_placeholder(self.action_space)),
            SampleBatch.REWARDS:
            np.array([0], dtype=np.float32),
        }
        if self._obs_include_prev_action_reward:
            dummy_batch.update({
                SampleBatch.PREV_ACTIONS:
                fake_array(self._prev_action_input),
                SampleBatch.PREV_REWARDS:
                fake_array(self._prev_reward_input),
            })
        state_init = self.get_initial_state()
        for i, h in enumerate(state_init):
            dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0)
            dummy_batch["state_out_{}".format(i)] = np.expand_dims(h, 0)
        if state_init:
            dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)
        for k, v in self.extra_compute_action_fetches().items():
            dummy_batch[k] = fake_array(v)

        # postprocessing might depend on variable init, so run it first here
        self._sess.run(tf.global_variables_initializer())
        postprocessed_batch = self.postprocess_trajectory(
            SampleBatch(dummy_batch))

        if self._obs_include_prev_action_reward:
            batch_tensors = UsageTrackingDict({
                SampleBatch.PREV_ACTIONS:
                self._prev_action_input,
                SampleBatch.PREV_REWARDS:
                self._prev_reward_input,
                SampleBatch.CUR_OBS:
                self._obs_input,
            })
            loss_inputs = [
                (SampleBatch.PREV_ACTIONS, self._prev_action_input),
                (SampleBatch.PREV_REWARDS, self._prev_reward_input),
                (SampleBatch.CUR_OBS, self._obs_input),
            ]
        else:
            batch_tensors = UsageTrackingDict({
                SampleBatch.CUR_OBS:
                self._obs_input,
            })
            loss_inputs = [
                (SampleBatch.CUR_OBS, self._obs_input),
            ]

        for k, v in postprocessed_batch.items():
            if k in batch_tensors:
                continue
            elif v.dtype == np.object:
                continue  # can't handle arbitrary objects in TF
            shape = (None, ) + v.shape[1:]
            dtype = np.float32 if v.dtype == np.float64 else v.dtype
            placeholder = tf.placeholder(dtype, shape=shape, name=k)
            batch_tensors[k] = placeholder

        if log_once("loss_init"):
            logger.info(
                "Initializing loss function with dummy input:\n\n{}\n".format(
                    summarize(batch_tensors)))

        loss = self._do_loss_init(batch_tensors)
        for k in sorted(batch_tensors.accessed_keys):
            loss_inputs.append((k, batch_tensors[k]))

        # XXX experimental support for automatically eagerifying the loss.
        # The main limitation right now is that TF doesn't support mixing eager
        # and non-eager tensors, so losses that read non-eager tensors through
        # `policy` need to use `policy.convert_to_eager(tensor)`.
        if self.config["use_eager"]:
            if not self.model:
                raise ValueError("eager not implemented in this case")
            graph_tensors = list(self._needs_eager_conversion)

            def gen_loss(model_outputs, *args):
                # fill in the batch tensor dict with eager tensors
                eager_inputs = dict(
                    zip([k for (k, v) in loss_inputs],
                        args[:len(loss_inputs)]))
                # fill in the eager versions of all accessed graph tensors
                self._eager_tensors = dict(
                    zip(graph_tensors, args[len(loss_inputs):]))
                # patch the action dist to use eager mode tensors
                self.action_dist.inputs = model_outputs
                return self._loss_fn(self, eager_inputs)

            # TODO(ekl) also handle the stats funcs
            loss = tf.py_function(
                gen_loss,
                # cast works around TypeError: Cannot convert provided value
                # to EagerTensor. Provided value: 0.0 Requested dtype: int64
                [self.model.outputs] +
                [tf.cast(v, tf.float32) for (k, v) in loss_inputs] +
                [tf.cast(t, tf.float32) for t in graph_tensors],
                tf.float32)

        TFPolicy._initialize_loss(self, loss, loss_inputs)
        if self._grad_stats_fn:
            self._stats_fetches.update(self._grad_stats_fn(self, self._grads))
        self._sess.run(tf.global_variables_initializer())
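
The experimental eager path above leans on tf.py_function, which runs a Python/eager function inside the graph once its inputs are cast to dtypes it can hand over. A tiny, hedged illustration of that wrapping pattern (not the policy's actual loss):

import tensorflow as tf

def eager_loss(advantages, logp):
    # Runs eagerly even when invoked from graph mode via tf.py_function.
    return -tf.reduce_mean(advantages * logp)

advantages = tf.constant([1.0, -0.5, 2.0])
logp = tf.constant([-0.1, -0.3, -0.2])
loss = tf.py_function(
    eager_loss,
    [tf.cast(advantages, tf.float32), tf.cast(logp, tf.float32)],
    tf.float32)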
Пример #29
0
    def load_data(self, sess, inputs, state_inputs):
        """Bulk loads the specified inputs into device memory.

        The shape of the inputs must conform to the shapes of the input
        placeholders this optimizer was constructed with.

        The data is split equally across all the devices. If the data is not
        evenly divisible by the batch size, excess data will be discarded.

        Args:
            sess: TensorFlow session.
            inputs: List of arrays matching the input placeholders, of shape
                [BATCH_SIZE, ...].
            state_inputs: List of RNN input arrays. These arrays have size
                [BATCH_SIZE / MAX_SEQ_LEN, ...].

        Returns:
            The number of tuples loaded per device.
        """

        if log_once("load_data"):
            logger.info(
                "Training on concatenated sample batches:\n\n{}\n".format(
                    summarize({
                        "placeholders": self.loss_inputs,
                        "inputs": inputs,
                        "state_inputs": state_inputs
                    })))

        feed_dict = {}
        assert len(self.loss_inputs) == len(inputs + state_inputs), \
            (self.loss_inputs, inputs, state_inputs)

        # Let's suppose we have the following input data, and 2 devices:
        # 1 2 3 4 5 6 7                              <- state inputs shape
        # A A A B B B C C C D D D E E E F F F G G G  <- inputs shape
        # The data is truncated and split across devices as follows:
        # |---| seq len = 3
        # |---------------------------------| seq batch size = 6 seqs
        # |----------------| per device batch size = 9 tuples

        if len(state_inputs) > 0:
            smallest_array = state_inputs[0]
            seq_len = len(inputs[0]) // len(state_inputs[0])
            self._loaded_max_seq_len = seq_len
        else:
            smallest_array = inputs[0]
            self._loaded_max_seq_len = 1

        sequences_per_minibatch = (self.max_per_device_batch_size //
                                   self._loaded_max_seq_len *
                                   len(self.devices))
        if sequences_per_minibatch < 1:
            logger.warn(
                ("Target minibatch size is {}, however the rollout sequence "
                 "length is {}, hence the minibatch size will be raised to "
                 "{}.").format(self.max_per_device_batch_size,
                               self._loaded_max_seq_len,
                               self._loaded_max_seq_len * len(self.devices)))
            sequences_per_minibatch = 1

        if len(smallest_array) < sequences_per_minibatch:
            # Dynamically shrink the batch size if insufficient data
            sequences_per_minibatch = make_divisible_by(
                len(smallest_array), len(self.devices))

        if log_once("data_slicing"):
            logger.info(
                ("Divided {} rollout sequences, each of length {}, among "
                 "{} devices.").format(len(smallest_array),
                                       self._loaded_max_seq_len,
                                       len(self.devices)))

        if sequences_per_minibatch < len(self.devices):
            raise ValueError(
                "Must load at least 1 tuple sequence per device. Try "
                "increasing `sgd_minibatch_size` or reducing `max_seq_len` "
                "to ensure that at least one sequence fits per device.")
        self._loaded_per_device_batch_size = (sequences_per_minibatch //
                                              len(self.devices) *
                                              self._loaded_max_seq_len)

        if len(state_inputs) > 0:
            # First truncate the RNN state arrays to sequences_per_minibatch.
            state_inputs = [
                make_divisible_by(arr, sequences_per_minibatch)
                for arr in state_inputs
            ]
            # Then truncate the data inputs to match
            inputs = [arr[:len(state_inputs[0]) * seq_len] for arr in inputs]
            assert len(state_inputs[0]) * seq_len == len(inputs[0]), \
                (len(state_inputs[0]), sequences_per_minibatch, seq_len,
                 len(inputs[0]))
            for ph, arr in zip(self.loss_inputs, inputs + state_inputs):
                feed_dict[ph] = arr
            truncated_len = len(inputs[0])
        else:
            for ph, arr in zip(self.loss_inputs, inputs + state_inputs):
                truncated_arr = make_divisible_by(arr, sequences_per_minibatch)
                feed_dict[ph] = truncated_arr
                truncated_len = len(truncated_arr)

        sess.run([t.init_op for t in self._towers], feed_dict=feed_dict)

        self.num_tuples_loaded = truncated_len
        tuples_per_device = truncated_len // len(self.devices)
        assert tuples_per_device > 0, "No data loaded?"
        assert tuples_per_device % self._loaded_per_device_batch_size == 0
        return tuples_per_device
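
The docstring and ASCII diagram above boil down to integer arithmetic: truncate the sequences to a multiple of the device count, drop the leftover tuples, then divide. A short NumPy sketch reproducing the diagram's numbers; make_divisible_by here is an assumed stand-in for the helper used above:

import numpy as np

def make_divisible_by(arr, n):
    return arr[:len(arr) - len(arr) % n]  # drop the remainder rows

num_devices = 2
seq_len = 3
max_per_device = 9            # target tuples per device, as in the diagram
state_inputs = np.arange(7)   # one entry per sequence (7 seqs in the diagram)
inputs = np.arange(21)        # one entry per tuple (21 tuples in the diagram)

sequences_per_minibatch = max_per_device // seq_len * num_devices        # 3 * 2 = 6
state_inputs = make_divisible_by(state_inputs, sequences_per_minibatch)  # 7 -> 6 seqs
inputs = inputs[:len(state_inputs) * seq_len]                            # 21 -> 18 tuples
tuples_per_device = len(inputs) // num_devices                           # 9, as in the diagram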
Пример #30
0
def _env_runner(base_env, extra_batch_callback, policies, policies_to_train,
                policy_config, observation_filter, policy_mapping_fn,
                unroll_length, horizon, preprocessors, obs_filters,
                clip_rewards, clip_actions, pack, callbacks, tf_sess,
                perf_stats, soft_horizon, no_done_at_end):
    #===MOD===
    """This implements the common experience collection logic.

    Args:
        base_env (BaseEnv): env implementing BaseEnv.
        extra_batch_callback (fn): function to send extra batch data to.
        policies (dict): Map of policy ids to Policy instances.
        policy_mapping_fn (func): Function that maps agent ids to policy ids.
            This is called when an agent first enters the environment. The
            agent is then "bound" to the returned policy for the episode.
        unroll_length (int): Number of episode steps before `SampleBatch` is
            yielded. Set to infinity to yield complete episodes.
        horizon (int): Horizon of the episode.
        preprocessors (dict): Map of policy id to preprocessor for the
            observations prior to filtering.
        obs_filters (dict): Map of policy id to filter used to process
            observations for the policy.
        clip_rewards (bool): Whether to clip rewards before postprocessing.
        pack (bool): Whether to pack multiple episodes into each batch. This
            guarantees batches will be exactly `unroll_length` in size.
        clip_actions (bool): Whether to clip actions to the space range.
        callbacks (dict): User callbacks to run on episode events.
        tf_sess (Session|None): Optional tensorflow session to use for batching
            TF policy evaluations.
        perf_stats (PerfStats): Record perf stats into this object.
        soft_horizon (bool): Calculate rewards but don't reset the
            environment when the horizon is hit.
        no_done_at_end (bool): Ignore the done=True at the end of the episode
            and instead record done=False.

    Yields:
        rollout (SampleBatch): Object containing state, action, reward,
            terminal condition, and other fields as dictated by `policy`.
    """

    try:
        if not horizon:
            horizon = (base_env.get_unwrapped()[0].spec.max_episode_steps)
    except Exception:
        logger.debug("no episode horizon specified, assuming inf")
    if not horizon:
        horizon = float("inf")

    # Pool of batch builders, which can be shared across episodes to pack
    # trajectory data.
    batch_builder_pool = []

    def get_batch_builder():
        if batch_builder_pool:
            return batch_builder_pool.pop()
        else:
            #===MOD===
            # We use the below.
            #===MOD===
            return MultiAgentSampleBatchBuilder(
                policies, clip_rewards, callbacks.get("on_postprocess_traj"))

    def new_episode():
        episode = MultiAgentEpisode(policies, policy_mapping_fn,
                                    get_batch_builder, extra_batch_callback)
        if callbacks.get("on_episode_start"):
            callbacks["on_episode_start"]({
                "env": base_env,
                "policy": policies,
                "episode": episode,
            })
        return episode

    active_episodes = defaultdict(new_episode)

    #===MOD===
    dead_policies = set()
    #===MOD===
    while True:
        perf_stats.iters += 1
        t0 = time.time()
        # Get observations from all ready agents
        unfiltered_obs, rewards, dones, infos, off_policy_actions = \
            base_env.poll()
        perf_stats.env_wait_time += time.time() - t0

        if log_once("env_returns"):
            logger.info("Raw obs from env: {}".format(
                summarize(unfiltered_obs)))
            logger.info("Info return from env: {}".format(summarize(infos)))

        # Process observations and prepare for policy evaluation
        t1 = time.time()
        #===MOD===
        # DEBUG
        # print("Iterations:", perf_stats.iters)
        """
        Added arguments:
            policies_to_train,
            dead_policies,
            policy_config,
            observation_filter,
            tf_sess,
        """
        active_envs, to_eval, outputs, pols_tuple = _process_observations(
            base_env, policies, policies_to_train, dead_policies,
            policy_config, observation_filter, tf_sess, batch_builder_pool,
            active_episodes, unfiltered_obs, rewards, dones, infos,
            off_policy_actions, horizon, preprocessors, obs_filters,
            unroll_length, pack, callbacks, soft_horizon, no_done_at_end)
        policies, preprocessors, obs_filters, policies_to_train, dead_policies = pols_tuple
        #===MOD===
        perf_stats.processing_time += time.time() - t1
        for o in outputs:
            yield o

        # Do batched policy eval
        t2 = time.time()
        #===MOD===
        for policy_id in dead_policies:
            assert policy_id not in to_eval
        # _do_policy_eval(tf_sess, ... -> ..._eval(None, ...
        eval_results = _do_policy_eval(None, to_eval, policies,
                                       active_episodes)
        #===MOD===
        perf_stats.inference_time += time.time() - t2
        # DEBUG
        # print("sampler.py: t2: %fs" % (time.time() - t2))

        # Process results and update episode state
        t3 = time.time()
        actions_to_send = _process_policy_eval_results(to_eval, eval_results,
                                                       active_episodes,
                                                       active_envs,
                                                       off_policy_actions,
                                                       policies, clip_actions)
        perf_stats.processing_time += time.time() - t3

        # Return computed actions to ready envs. We also send to envs that have
        # taken off-policy actions; those envs are free to ignore the action.
        t4 = time.time()
        base_env.send_actions(actions_to_send)
        perf_stats.env_wait_time += time.time() - t4