Example #1
    def step(self):
        with self.update_weights_timer:
            if self.workers.remote_workers():
                weights = ray.put(self.workers.local_worker().get_weights())
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights)

        with self.sample_timer:
            samples = []
            while sum(s.count for s in samples) < self.train_batch_size:
                if self.workers.remote_workers():
                    samples.extend(
                        ray_get_and_free([
                            e.sample.remote()
                            for e in self.workers.remote_workers()
                        ]))
                else:
                    samples.append(self.workers.local_worker().sample())
            samples = SampleBatch.concat_samples(samples)
            self.sample_timer.push_units_processed(samples.count)

        # Handle everything as if multiagent
        if isinstance(samples, SampleBatch):
            samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples},
                                      samples.count)

        fetches = {}
        with self.grad_timer:
            for policy_id, policy in self.policies.items():
                if policy_id not in samples.policy_batches:
                    continue

                batch = samples.policy_batches[policy_id]
                for field in self.standardize_fields:
                    value = batch[field]
                    standardized = (value - value.mean()) / max(
                        1e-4, value.std())
                    batch[field] = standardized

                for i in range(self.num_sgd_iter):
                    iter_extra_fetches = defaultdict(list)
                    for minibatch in self._minibatches(batch):
                        batch_fetches = (
                            self.workers.local_worker().learn_on_batch(
                                MultiAgentBatch({policy_id: minibatch},
                                                minibatch.count)))[policy_id]
                        for k, v in batch_fetches[LEARNER_STATS_KEY].items():
                            iter_extra_fetches[k].append(v)
                    logger.debug("{} {}".format(i,
                                                _averaged(iter_extra_fetches)))
                fetches[policy_id] = _averaged(iter_extra_fetches)

        self.grad_timer.push_units_processed(samples.count)
        if len(fetches) == 1 and DEFAULT_POLICY_ID in fetches:
            self.learner_stats = fetches[DEFAULT_POLICY_ID]
        else:
            self.learner_stats = fetches
        self.num_steps_sampled += samples.count
        self.num_steps_trained += samples.count
        return self.learner_stats
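The `standardize_fields` loop in `step()` above normalizes each configured field (typically advantages) before running SGD. Below is a minimal numpy sketch of that normalization; the `standardize` helper name is mine, the example applies the same formula inline.

import numpy as np

def standardize(value):
    # Zero-center the field and scale by its std; the 1e-4 floor avoids
    # dividing by a near-zero std for almost-constant fields.
    return (value - value.mean()) / max(1e-4, value.std())

advantages = np.array([1.0, 2.0, 3.0, 4.0])
print(standardize(advantages))  # roughly zero mean, unit scale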
Example #2
    def replay(self) -> SampleBatchType:
        """If this buffer was given a fake batch, return it, otherwise return
        a MultiAgentBatch with samples.
        """
        if self._fake_batch:
            fake_batch = SampleBatch(self._fake_batch)
            return MultiAgentBatch({
                DEFAULT_POLICY_ID: fake_batch
            }, fake_batch.count)

        if self.num_added < self.replay_starts:
            return None
        with self.replay_timer:
            # Lockstep mode: Sample an equal number of steps from all
            # policies at the same time.
            if self.replay_mode == "lockstep":
                return self.replay_buffers[_ALL_POLICIES].sample(
                    self.replay_batch_size, beta=self.prioritized_replay_beta)
            else:
                samples = {}
                for policy_id, replay_buffer in self.replay_buffers.items():
                    samples[policy_id] = replay_buffer.sample(
                        self.replay_batch_size,
                        beta=self.prioritized_replay_beta)
                return MultiAgentBatch(samples, self.replay_batch_size)
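The non-lockstep branch of `replay()` above assembles one `SampleBatch` per policy and wraps them in a `MultiAgentBatch`. A small sketch of that wrapping, assuming RLlib's usual `ray.rllib.policy.sample_batch` import path:

from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch

replay_batch_size = 2
samples = {
    "policy_0": SampleBatch({"obs": [[0.1], [0.2]], "rewards": [0.0, 1.0]}),
    "policy_1": SampleBatch({"obs": [[0.3], [0.4]], "rewards": [1.0, 0.0]}),
}
# Same construction as the return statement above: one batch per policy,
# env_steps equal to the requested replay batch size.
batch = MultiAgentBatch(samples, replay_batch_size)
assert set(batch.policy_batches) == {"policy_0", "policy_1"}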
Example #3
    def replay(self):
        if self._fake_batch:
            fake_batch = SampleBatch(self._fake_batch)
            return MultiAgentBatch({DEFAULT_POLICY_ID: fake_batch},
                                   fake_batch.count)

        if self.num_added < self.replay_starts:
            return None

        with self.replay_timer:
            samples = {}
            idxes = None
            for policy_id, replay_buffer in self.replay_buffers.items():
                if self.multiagent_sync_replay:
                    if idxes is None:
                        idxes = replay_buffer.sample_idxes(
                            self.replay_batch_size)
                else:
                    idxes = replay_buffer.sample_idxes(self.replay_batch_size)
                (obses_t, actions, rewards, obses_tp1, dones, weights,
                 batch_indexes) = replay_buffer.sample_with_idxes(
                     idxes, beta=self.prioritized_replay_beta)
                samples[policy_id] = SampleBatch({
                    "obs": obses_t,
                    "actions": actions,
                    "rewards": rewards,
                    "new_obs": obses_tp1,
                    "dones": dones,
                    "weights": weights,
                    "batch_indexes": batch_indexes
                })
            return MultiAgentBatch(samples, self.replay_batch_size)
Example #4
def do_minibatch_sgd(samples, policies, local_worker, num_sgd_iter,
                     sgd_minibatch_size, standardize_fields):
    """Execute minibatch SGD.

    Args:
        samples (SampleBatch): Batch of samples to optimize.
        policies (dict): Dictionary of policies to optimize.
        local_worker (RolloutWorker): Master rollout worker instance.
        num_sgd_iter (int): Number of epochs of optimization to take.
        sgd_minibatch_size (int): Size of minibatches to use for optimization.
        standardize_fields (list): List of sample field names that should be
            normalized prior to optimization.

    Returns:
        averaged info fetches over the last SGD epoch taken.
    """
    if isinstance(samples, SampleBatch):
        samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples}, samples.count)

    # Use LearnerInfoBuilder as a unified way to build the final
    # results dict from `learn_on_loaded_batch` call(s).
    # This makes sure results dicts always have the same structure
    # no matter the setup (multi-GPU, multi-agent, minibatch SGD,
    # tf vs torch).
    learner_info_builder = LearnerInfoBuilder(num_devices=1)
    for policy_id, policy in policies.items():
        if policy_id not in samples.policy_batches:
            continue

        batch = samples.policy_batches[policy_id]
        for field in standardize_fields:
            batch[field] = standardized(batch[field])

        # Check that `sgd_minibatch_size` is not smaller than `max_seq_len`;
        # otherwise, indexing errors will occur during SGD when using an RNN
        # or attention model.
        if policy.is_recurrent() and \
           policy.config["model"]["max_seq_len"] > sgd_minibatch_size:
            raise ValueError("`sgd_minibatch_size` ({}) cannot be smaller than"
                             "`max_seq_len` ({}).".format(
                                 sgd_minibatch_size,
                                 policy.config["model"]["max_seq_len"]))

        for i in range(num_sgd_iter):
            for minibatch in minibatches(batch, sgd_minibatch_size):
                results = (local_worker.learn_on_batch(
                    MultiAgentBatch({policy_id: minibatch},
                                    minibatch.count)))[policy_id]
                learner_info_builder.add_learn_on_batch_results(
                    results, policy_id)

    learner_info = learner_info_builder.finalize()
    return learner_info
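The inner loops above run `num_sgd_iter` epochs over `sgd_minibatch_size`-sized slices of each policy's batch. Below is a stand-in sketch of that minibatching pattern over a plain column dict; it is not RLlib's `minibatches()` helper, only an illustration of the slicing.

import numpy as np

def iter_minibatches(columns, minibatch_size):
    # Yield successive fixed-size row slices of a dict of equal-length columns.
    num_rows = len(next(iter(columns.values())))
    for start in range(0, num_rows, minibatch_size):
        yield {k: v[start:start + minibatch_size] for k, v in columns.items()}

batch = {"obs": np.arange(8).reshape(8, 1), "advantages": np.random.randn(8)}
for epoch in range(2):                            # num_sgd_iter
    for minibatch in iter_minibatches(batch, 4):  # sgd_minibatch_size
        pass  # a learn_on_batch() call would consume `minibatch` here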
Example #5
 def add_batch(self, batch):
     # Make a copy so the replay buffer doesn't pin plasma memory.
     batch = batch.copy()
     # Handle everything as if multiagent
     if isinstance(batch, SampleBatch):
         batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count)
     with self.add_batch_timer:
         if self.replay_mode == "lockstep":
             for s in batch.timeslices(self.replay_sequence_length):
                 self.replay_buffers[_ALL_POLICIES].add(s)
         else:
             for policy_id, b in batch.policy_batches.items():
                 for s in b.timeslices(self.replay_sequence_length):
                     self.replay_buffers[policy_id].add(s)
     self.num_added += batch.count
Example #6
        def _add_multi_agent_batch_to_buffer(
            buffer, num_policies, num_batches=5, seq_lens=False, **kwargs
        ):
            def _generate_data(policy_id):
                batch = SampleBatch(
                    {
                        SampleBatch.T: [0, 1],
                        SampleBatch.ACTIONS: 2 * [np.random.choice([0, 1])],
                        SampleBatch.REWARDS: 2 * [np.random.rand()],
                        SampleBatch.OBS: 2 * [np.random.random((4,))],
                        SampleBatch.NEXT_OBS: 2 * [np.random.random((4,))],
                        SampleBatch.DONES: [False, True],
                        SampleBatch.EPS_ID: 2 * [self.batch_id],
                        SampleBatch.AGENT_INDEX: 2 * [0],
                        SampleBatch.SEQ_LENS: [2],
                        "batch_id": 2 * [self.batch_id],
                        "policy_id": 2 * [policy_id],
                    }
                )
                if not seq_lens:
                    del batch[SampleBatch.SEQ_LENS]
                self.batch_id += 1
                return batch

            for i in range(num_batches):
                # Generate a few policy batches.
                policy_batches = {
                    idx: _generate_data(idx) for idx in range(num_policies)
                }
                batch = MultiAgentBatch(policy_batches, num_batches * 2)
                buffer.add(batch, **kwargs)
Example #7
def _from_json(batch: str) -> SampleBatchType:
    if isinstance(batch, bytes):  # smart_open S3 doesn't respect "r"
        batch = batch.decode("utf-8")
    data = json.loads(batch)

    if "type" in data:
        data_type = data.pop("type")
    else:
        raise ValueError("JSON record missing 'type' field")

    if data_type == "SampleBatch":
        for k, v in data.items():
            data[k] = unpack_if_needed(v)
        return SampleBatch(data)
    elif data_type == "MultiAgentBatch":
        policy_batches = {}
        for policy_id, policy_batch in data["policy_batches"].items():
            inner = {}
            for k, v in policy_batch.items():
                inner[k] = unpack_if_needed(v)
            policy_batches[policy_id] = SampleBatch(inner)
        return MultiAgentBatch(policy_batches, data["count"])
    else:
        raise ValueError(
            "Type field must be one of ['SampleBatch', 'MultiAgentBatch']",
            data_type)
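For reference, a minimal JSON record of the shape the "SampleBatch" branch of `_from_json` above accepts. The columns here are plain lists, so the `unpack_if_needed` step should pass them through unchanged; this is a sketch, not the exact format RLlib's writers produce for packed columns.

import json

record = json.dumps({
    "type": "SampleBatch",
    "obs": [[0.1, 0.2], [0.3, 0.4]],
    "actions": [0, 1],
    "rewards": [1.0, 0.0],
})
# batch = _from_json(record)  # -> SampleBatch with the three columns above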
Example #8
    def gen_replay(timeout):
        while True:
            samples = {}
            idxes = None
            for policy_id, replay_buffer in replay_buffers.items():
                if synchronize_sampling:
                    if idxes is None:
                        idxes = replay_buffer.sample_idxes(train_batch_size)
                else:
                    idxes = replay_buffer.sample_idxes(train_batch_size)

                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    metrics = LocalIterator.get_metrics()
                    num_steps_trained = metrics.counters[STEPS_TRAINED_COUNTER]
                    (obses_t, actions, rewards, obses_tp1, dones, weights,
                     batch_indexes) = replay_buffer.sample_with_idxes(
                         idxes,
                         beta=prioritized_replay_beta.value(num_steps_trained))
                else:
                    (obses_t, actions, rewards, obses_tp1,
                     dones) = replay_buffer.sample_with_idxes(idxes)
                    weights = np.ones_like(rewards)
                    batch_indexes = -np.ones_like(rewards)
                samples[policy_id] = SampleBatch({
                    "obs": obses_t,
                    "actions": actions,
                    "rewards": rewards,
                    "new_obs": obses_tp1,
                    "dones": dones,
                    "weights": weights,
                    "batch_indexes": batch_indexes
                })
            yield MultiAgentBatch(samples, train_batch_size)
Example #9
def _collect_joint_dataset(trainer, worker, sample_size):
    joint_obs = []
    if hasattr(trainer.optimizer, "replay_buffers"):
        # If we are using MADDPG, it uses ReplayOptimizer, which has this
        # attribute.
        for policy_id, replay_buffer in \
                trainer.optimizer.replay_buffers.items():
            obs = replay_buffer.sample(sample_size)[0]
            joint_obs.append(obs)
    else:
        # If we are using individual PPO, there is no replay buffer, so we
        # roll out here to collect the observations.

        # Keep sampling until enough data has been collected.
        tmp_batch = worker.sample()
        count_dict = {k: v.count for k, v in tmp_batch.policy_batches.items()}
        for k in worker.policy_map.keys():
            if k not in count_dict:
                count_dict[k] = 0
        samples = [tmp_batch]
        while any(c < sample_size for c in count_dict.values()):
            tmp_batch = worker.sample()
            for k, v in tmp_batch.policy_batches.items():
                assert k in count_dict, count_dict
                count_dict[k] += v.count
            samples.append(tmp_batch)
        multi_agent_batch = MultiAgentBatch.concat_samples(samples)
        for pid, batch in multi_agent_batch.policy_batches.items():
            batch.shuffle()
            assert batch.count >= sample_size, (batch, batch.count, [
                b.count for b in multi_agent_batch.policy_batches.values()
            ])
            joint_obs.append(batch.slice(0, sample_size)['obs'])
    joint_obs = np.concatenate(joint_obs)
    return joint_obs
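The PPO branch above concatenates several `MultiAgentBatch`es, then shuffles and trims each policy's batch. A short sketch of the concat-and-slice part, assuming RLlib's batch types as used in the example:

from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch

b1 = MultiAgentBatch({"p0": SampleBatch({"obs": [[0.0], [0.1]]})}, 2)
b2 = MultiAgentBatch({"p0": SampleBatch({"obs": [[0.2], [0.3]]})}, 2)
merged = MultiAgentBatch.concat_samples([b1, b2])
# Trim policy "p0" to its first 3 rows and pull out the observations.
obs = merged.policy_batches["p0"].slice(0, 3)["obs"]
assert len(obs) == 3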
Example #10
    def build_and_reset(
            self,
            episode: Optional[MultiAgentEpisode] = None) -> MultiAgentBatch:
        """Returns the accumulated sample batches for each policy.

        Any unprocessed rows will be first postprocessed with a policy
        postprocessor. The internal state of this builder will be reset.

        Args:
            episode (Optional[MultiAgentEpisode]): The Episode object that
                holds this MultiAgentBatchBuilder object or None.

        Returns:
            MultiAgentBatch: Returns the accumulated sample batches for each
                policy.
        """

        self.postprocess_batch_so_far(episode)
        policy_batches = {}
        for policy_id, builder in self.policy_builders.items():
            if builder.count > 0:
                policy_batches[policy_id] = builder.build_and_reset()
        old_count = self.count
        self.count = 0
        return MultiAgentBatch.wrap_as_needed(policy_batches, old_count)
Example #11
def before_learn_on_batch(multi_agent_batch: MultiAgentBatch, policies,
                          train_batch_size):
    samples = {}

    # Modify keys.
    for pid, p in policies.items():
        i = p.agent_idx
        keys = multi_agent_batch.policy_batches[pid].data.keys()
        keys = ["_".join([k, str(i)]) for k in keys]
        samples.update(
            dict(zip(keys,
                     multi_agent_batch.policy_batches[pid].data.values())))

    # Make ops and feed_dict to get "new_obs" from target action sampler.
    new_obs_ph_n = [p.new_obs_ph for p in policies.values()]
    new_obs_n = list()
    for k, v in samples.items():
        if "new_obs" in k:
            new_obs_n.append(v)

    # target_act_sampler_n = [p.target_act_sampler for p in policies.values()]
    feed_dict = dict(zip(new_obs_ph_n, new_obs_n))
    new_act_n = [
        p.sess.run(p.target_act_sampler, feed_dict) for p in policies.values()
    ]
    samples.update(
        {"new_actions_%d" % i: new_act
         for i, new_act in enumerate(new_act_n)})

    # Share samples among agents.
    policy_batches = {pid: SampleBatch(samples) for pid in policies.keys()}
    return MultiAgentBatch(policy_batches, train_batch_size)
Example #12
        def on_sample_end(self, *, worker: "RolloutWorker", samples: SampleBatch, **kwargs):
            super().on_sample_end(worker=worker, samples=samples, **kwargs)
            assert isinstance(samples, MultiAgentBatch)

            for policy_samples in samples.policy_batches.values():
                if "action_prob" in policy_samples.data:
                    del policy_samples.data["action_prob"]
                if "action_logp" in policy_samples.data:
                    del policy_samples.data["action_logp"]

            for average_policy_id, br_policy_id in [("average_policy_0", "best_response_0"),
                                                    ("average_policy_1", "best_response_1")]:
                for policy_id, policy_samples in samples.policy_batches.items():
                    if policy_id == br_policy_id:
                        store_to_avg_policy_buffer(MultiAgentBatch(policy_batches={
                            average_policy_id: policy_samples
                        }, env_steps=policy_samples.count))
                if average_policy_id in samples.policy_batches:
                    if br_policy_id in samples.policy_batches:
                        all_policies_samples = samples.policy_batches[br_policy_id].concat(
                            other=samples.policy_batches[average_policy_id])
                    else:
                        all_policies_samples = samples.policy_batches[average_policy_id]
                    del samples.policy_batches[average_policy_id]
                    samples.policy_batches[br_policy_id] = all_policies_samples
Example #13
def from_json_data(json_data: Any, worker: Optional["RolloutWorker"]):
    # Try to infer the SampleBatchType (SampleBatch or MultiAgentBatch).
    if "type" in json_data:
        data_type = json_data.pop("type")
    else:
        raise ValueError("JSON record missing 'type' field")

    if data_type == "SampleBatch":
        if worker is not None and len(worker.policy_map) != 1:
            raise ValueError(
                "Found single-agent SampleBatch in input file, but our "
                "PolicyMap contains more than 1 policy!")
        for k, v in json_data.items():
            json_data[k] = unpack_if_needed(v)
        if worker is not None:
            policy = next(iter(worker.policy_map.values()))
            json_data = _adjust_obs_actions_for_policy(json_data, policy)
        return SampleBatch(json_data)
    elif data_type == "MultiAgentBatch":
        policy_batches = {}
        for policy_id, policy_batch in json_data["policy_batches"].items():
            inner = {}
            for k, v in policy_batch.items():
                inner[k] = unpack_if_needed(v)
            if worker is not None:
                policy = worker.policy_map[policy_id]
                inner = _adjust_obs_actions_for_policy(inner, policy)
            policy_batches[policy_id] = SampleBatch(inner)
        return MultiAgentBatch(policy_batches, json_data["count"])
    else:
        raise ValueError(
            "Type field must be one of ['SampleBatch', 'MultiAgentBatch']",
            data_type)
Example #14
 def add_batch(self, batch: SampleBatchType) -> None:
     # Make a copy so the replay buffer doesn't pin plasma memory.
     batch = batch.copy()
     # Handle everything as if multiagent
     if isinstance(batch, SampleBatch):
         batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count)
     with self.add_batch_timer:
         # Lockstep mode: Store under _ALL_POLICIES key (we will always
         # only sample from all policies at the same time).
         if self.replay_mode == "lockstep":
             # Note that prioritization is not supported in this mode.
             for s in batch.timeslices(self.replay_sequence_length):
                 self.replay_buffers[_ALL_POLICIES].add(s, weight=None)
         else:
             for policy_id, sample_batch in batch.policy_batches.items():
                 if self.replay_sequence_length == 1:
                     timeslices = sample_batch.timeslices(1)
                 else:
                     timeslices = timeslice_along_seq_lens_with_overlap(
                         sample_batch=sample_batch,
                         zero_pad_max_seq_len=self.replay_sequence_length,
                         pre_overlap=self.replay_burn_in,
                         zero_init_states=self.replay_zero_init_states,
                     )
                 for time_slice in timeslices:
                     # If SampleBatch has prio-replay weights, average
                     # over these to use as a weight for the entire
                     # sequence.
                     if "weights" in time_slice:
                         weight = np.mean(time_slice["weights"])
                     else:
                         weight = None
                     self.replay_buffers[policy_id].add(time_slice,
                                                        weight=weight)
     self.num_added += batch.count
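When a stored time slice carries per-timestep prioritized-replay weights, the code above collapses them into a single priority for the whole sequence by averaging. A tiny sketch of that reduction; passing `weight=None` typically lets a prioritized buffer fall back to its current max priority.

import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch

time_slice = SampleBatch({"obs": [[0.0], [0.1]], "weights": [0.8, 1.2]})
weight = np.mean(time_slice["weights"]) if "weights" in time_slice else None
# weight == 1.0 for this slice; it is what gets passed to
# replay_buffer.add(time_slice, weight=weight).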
Example #15
    def _add_multi_agent_batch_to_buffer(
        self, buffer, num_policies, num_batches=5, **kwargs
    ):
        def _generate_data(policy_id):
            batch = SampleBatch(
                {
                    SampleBatch.T: [0],
                    SampleBatch.ACTIONS: [np.random.choice([0, 1])],
                    SampleBatch.REWARDS: [np.random.rand()],
                    SampleBatch.OBS: [np.random.random((4,))],
                    SampleBatch.NEXT_OBS: [np.random.random((4,))],
                    SampleBatch.DONES: [np.random.choice([False, True])],
                    SampleBatch.EPS_ID: [self.batch_id],
                    SampleBatch.AGENT_INDEX: [self.batch_id],
                    "batch_id": [self.batch_id],
                    "policy_id": [policy_id],
                }
            )
            return batch

        for i in range(num_batches):
            # Generate a few policy batches.
            policy_batches = {idx: _generate_data(idx) for idx in range(num_policies)}
            self.batch_id += 1
            batch = MultiAgentBatch(policy_batches, 1)
            buffer.add(batch, **kwargs)
Example #16
    def __call__(self, batch: MultiAgentBatch) -> List[SampleBatchType]:
        _check_sample_batch_type(batch)
        batch_count = batch.policy_batches[self.policy_id_to_count_for].count
        if self.drop_samples_for_other_agents:
            batch = MultiAgentBatch(policy_batches={
                self.policy_id_to_count_for:
                batch.policy_batches[self.policy_id_to_count_for]
            },
                                    env_steps=batch.policy_batches[
                                        self.policy_id_to_count_for].count)

        self.buffer.append(batch)
        self.count += batch_count

        if self.count >= self.min_batch_size:
            if self.count > self.min_batch_size * 2:
                logger.info("Collected more training samples than expected "
                            "(actual={}, expected={}). ".format(
                                self.count, self.min_batch_size) +
                            "This may be because you have many workers or "
                            "long episodes in 'complete_episodes' batch mode.")
            out = SampleBatch.concat_samples(self.buffer)
            timer = _get_shared_metrics().timers[SAMPLE_TIMER]
            timer.push(time.perf_counter() - self.batch_start_time)
            timer.push_units_processed(self.count)
            self.batch_start_time = None
            self.buffer = []
            self.count = 0
            return [out]
        return []
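The `__call__` above buffers incoming batches and only emits once at least `min_batch_size` env steps have accumulated. A self-contained sketch of that accumulate-then-concat pattern, assuming RLlib's `SampleBatch`:

from ray.rllib.policy.sample_batch import SampleBatch

min_batch_size = 4
buffer, count = [], 0
for incoming in [SampleBatch({"obs": [[0.0], [0.1]]}),
                 SampleBatch({"obs": [[0.2], [0.3]]})]:
    buffer.append(incoming)
    count += incoming.count
    if count >= min_batch_size:
        out = SampleBatch.concat_samples(buffer)  # out.count == 4 here
        buffer, count = [], 0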
Example #17
    def _optimize(self):
        if self._fake_batch:
            fake_batch = SampleBatch(self._fake_batch)
            samples = MultiAgentBatch({
                DEFAULT_POLICY_ID: fake_batch
            }, fake_batch.count)
        else:
            samples = self._replay()

        with self.grad_timer:
            if self.before_learn_on_batch:
                samples = self.before_learn_on_batch(
                    samples,
                    self.workers.local_worker().policy_map,
                    self.train_batch_size)
            info_dict = self.workers.local_worker().learn_on_batch(samples)
            for policy_id, info in info_dict.items():
                self.learner_stats[policy_id] = get_learner_stats(info)
                replay_buffer = self.replay_buffers[policy_id]
                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    # TODO(sven): This is currently structured differently for
                    #  torch/tf. Clean up these results/info dicts across
                    #  policies (note: fixing this in torch_policy.py will
                    #  break e.g. DDPPO!).
                    td_error = info.get("td_error",
                                        info["learner_stats"].get("td_error"))
                    new_priorities = (
                        np.abs(td_error) + self.prioritized_replay_eps)
                    replay_buffer.update_priorities(
                        samples.policy_batches[policy_id]["batch_indexes"],
                        new_priorities)
            self.grad_timer.push_units_processed(samples.count)

        self.num_steps_trained += samples.count
Example #18
    def replay(self, policy_id: Optional[PolicyID] = None) -> SampleBatchType:
        """If this buffer was given a fake batch, return it, otherwise return
        a MultiAgentBatch with samples.
        """
        if self._fake_batch:
            if not isinstance(self._fake_batch, MultiAgentBatch):
                self._fake_batch = SampleBatch(
                    self._fake_batch).as_multi_agent()
            return self._fake_batch

        if self.num_added < self.replay_starts:
            return None
        with self.replay_timer:
            # Lockstep mode: Sample an equal number of steps from all
            # policies at the same time.
            if self.replay_mode == "lockstep":
                assert (
                    policy_id is None
                ), "`policy_id` specifier not allowed in `locksetp` mode!"
                return self.replay_buffers[_ALL_POLICIES].sample(
                    self.replay_batch_size, beta=self.prioritized_replay_beta)
            elif policy_id is not None:
                return self.replay_buffers[policy_id].sample(
                    self.replay_batch_size, beta=self.prioritized_replay_beta)
            else:
                samples = {}
                for policy_id, replay_buffer in self.replay_buffers.items():
                    samples[policy_id] = replay_buffer.sample(
                        self.replay_batch_size,
                        beta=self.prioritized_replay_beta)
                return MultiAgentBatch(samples, self.replay_batch_size)
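The fake-batch branch above relies on `SampleBatch.as_multi_agent()`, which wraps a single-agent batch under `DEFAULT_POLICY_ID`. A minimal sketch, assuming the usual RLlib import path:

from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch

fake = SampleBatch({"obs": [[0.0], [0.1]], "rewards": [0.0, 1.0]})
ma = fake.as_multi_agent()
assert DEFAULT_POLICY_ID in ma.policy_batches and ma.count == fake.count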
Example #19
    def _replay(self):
        samples = {}
        idxes = None
        with self.replay_timer:
            for policy_id, replay_buffer in self.replay_buffers.items():
                if self.synchronize_sampling:
                    if idxes is None:
                        idxes = replay_buffer.sample_idxes(
                            self.train_batch_size)
                else:
                    idxes = replay_buffer.sample_idxes(self.train_batch_size)

                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    (obses_t, actions, rewards, obses_tp1, dones, weights,
                     batch_indexes) = replay_buffer.sample_with_idxes(
                         idxes,
                         beta=self.prioritized_replay_beta.value(
                             self.num_steps_trained))
                else:
                    (obses_t, actions, rewards, obses_tp1,
                     dones) = replay_buffer.sample_with_idxes(idxes)
                    weights = np.ones_like(rewards)
                    batch_indexes = -np.ones_like(rewards)
                samples[policy_id] = SampleBatch({
                    "obs": obses_t,
                    "actions": actions,
                    "rewards": rewards,
                    "new_obs": obses_tp1,
                    "dones": dones,
                    "weights": weights,
                    "batch_indexes": batch_indexes
                })
        return MultiAgentBatch(samples, self.train_batch_size)
Example #20
def before_learn_on_batch(multi_agent_batch, policies, train_batch_size):
    samples = {}

    # Modify keys.
    for pid, p in policies.items():
        i = p.config["agent_id"]
        keys = multi_agent_batch.policy_batches[pid].keys()
        keys = ["_".join([k, str(i)]) for k in keys]
        samples.update(
            dict(zip(keys, multi_agent_batch.policy_batches[pid].values())))

    # Make ops and feed_dict to get "new_obs" from target action sampler.
    new_obs_ph_n = [p.new_obs_ph for p in policies.values()]
    new_obs_n = list()
    for k, v in samples.items():
        if "new_obs" in k:
            new_obs_n.append(v)

    for i, p in enumerate(policies.values()):
        feed_dict = {new_obs_ph_n[i]: new_obs_n[i]}
        new_act = p.get_session().run(p.target_act_sampler, feed_dict)
        samples.update({"new_actions_%d" % i: new_act})

    # Share samples among agents.
    policy_batches = {pid: SampleBatch(samples) for pid in policies.keys()}
    return MultiAgentBatch(policy_batches, train_batch_size)
Example #21
    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.remote_evaluators:
                batch = SampleBatch.concat_samples(
                    ray_get_and_free(
                        [e.sample.remote() for e in self.remote_evaluators]))
            else:
                batch = self.local_evaluator.sample()

            # Handle everything as if multiagent
            if isinstance(batch, SampleBatch):
                batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch},
                                        batch.count)

            for policy_id, s in batch.policy_batches.items():
                for row in s.rows():
                    self.replay_buffers[policy_id].add(
                        pack_if_needed(row["obs"]),
                        row["actions"],
                        row["rewards"],
                        pack_if_needed(row["new_obs"]),
                        row["dones"],
                        weight=None)

        if self.num_steps_sampled >= self.replay_starts:
            self._optimize()

        self.num_steps_sampled += batch.count
Example #22
    def sample(self,
               num_items: int,
               policy_id: Optional[PolicyID] = None,
               **kwargs) -> Optional[SampleBatchType]:
        """Samples a MultiAgentBatch of `num_items` per one policy's buffer.

        If less than `num_items` records are in the policy's buffer,
        some samples in the results may be repeated to fulfil the batch size
        `num_items` request. Returns an empty batch if there are no items in
        the buffer.

        Args:
            num_items: Number of items to sample from a policy's buffer.
            policy_id: ID of the policy that created the experiences we sample.
                If none is given, sample from all policies.

        Returns:
            Concatenated MultiAgentBatch of items.
            **kwargs: Forward compatibility kwargs.
        """
        # Merge kwargs, overwriting standard call arguments
        kwargs = merge_dicts_with_warning(self.underlying_buffer_call_args,
                                          kwargs)

        if self._num_added < self.replay_starts:
            return MultiAgentBatch({}, 0)
        with self.replay_timer:
            # Lockstep mode: Sample an equal number of steps from all
            # policies at the same time.
            if self.replay_mode == ReplayMode.LOCKSTEP:
                assert (
                    policy_id is None
                ), "`policy_id` specifier not allowed in `lockstep` mode!"
                # In lockstep mode we sample MultiAgentBatches
                return self.replay_buffers[_ALL_POLICIES].sample(
                    num_items, **kwargs)
            elif policy_id is not None:
                sample = self.replay_buffers[policy_id].sample(
                    num_items, **kwargs)
                return MultiAgentBatch({policy_id: sample}, sample.count)
            else:
                samples = {}
                for policy_id, replay_buffer in self.replay_buffers.items():
                    samples[policy_id] = replay_buffer.sample(
                        num_items, **kwargs)
                return MultiAgentBatch(samples,
                                       sum(s.count for s in samples.values()))
Example #23
def do_minibatch_sgd(samples, policies, local_worker, num_sgd_iter,
                     sgd_minibatch_size, standardize_fields):
    """Execute minibatch SGD.

    Args:
        samples (SampleBatch): Batch of samples to optimize.
        policies (dict): Dictionary of policies to optimize.
        local_worker (RolloutWorker): Master rollout worker instance.
        num_sgd_iter (int): Number of epochs of optimization to take.
        sgd_minibatch_size (int): Size of minibatches to use for optimization.
        standardize_fields (list): List of sample field names that should be
            normalized prior to optimization.

    Returns:
        averaged info fetches over the last SGD epoch taken.
    """
    if isinstance(samples, SampleBatch):
        samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples}, samples.count)

    fetches = defaultdict(dict)
    for policy_id in policies.keys():
        if policy_id not in samples.policy_batches:
            continue

        batch = samples.policy_batches[policy_id]
        for field in standardize_fields:
            batch[field] = standardized(batch[field])

        learner_stats = defaultdict(list)
        model_stats = defaultdict(list)
        custom_callbacks_stats = defaultdict(list)

        for i in range(num_sgd_iter):
            for minibatch in minibatches(batch, sgd_minibatch_size):
                batch_fetches = (local_worker.learn_on_batch(
                    MultiAgentBatch({policy_id: minibatch},
                                    minibatch.count)))[policy_id]
                for k, v in batch_fetches.get(LEARNER_STATS_KEY, {}).items():
                    learner_stats[k].append(v)
                for k, v in batch_fetches.get("model", {}).items():
                    model_stats[k].append(v)
                for k, v in batch_fetches.get("custom_metrics", {}).items():
                    custom_callbacks_stats[k].append(v)
        fetches[policy_id][LEARNER_STATS_KEY] = averaged(learner_stats)
        fetches[policy_id]["model"] = averaged(model_stats)
        fetches[policy_id]["custom_metrics"] = averaged(custom_callbacks_stats)
    return fetches
Example #24
    def __call__(self, samples: SampleBatchType) -> SampleBatchType:
        _check_sample_batch_type(samples)

        if isinstance(samples, MultiAgentBatch):
            if self.local_worker:
                samples = MultiAgentBatch({
                    pid: batch
                    for pid, batch in samples.policy_batches.items()
                    if self.local_worker.is_policy_to_train(pid, batch)
                }, samples.count)
            else:
                samples = MultiAgentBatch({
                    k: v
                    for k, v in samples.policy_batches.items()
                    if k in self.policy_ids
                }, samples.count)

        return samples
Example #25
 def add_batch(self, batch):
     # Handle everything as if multiagent
     if isinstance(batch, SampleBatch):
         batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count)
     self.buffer.append(batch)
     self.cur_size += batch.count
     self.num_added += batch.count
     while self.cur_size > self.buffer_size:
         self.cur_size -= self.buffer.pop(0).count
Example #26
def do_minibatch_sgd(samples, policies, local_worker, num_sgd_iter,
                     sgd_minibatch_size, standardize_fields):
    """Execute minibatch SGD.

    Args:
        samples (SampleBatch): Batch of samples to optimize.
        policies (dict): Dictionary of policies to optimize.
        local_worker (RolloutWorker): Master rollout worker instance.
        num_sgd_iter (int): Number of epochs of optimization to take.
        sgd_minibatch_size (int): Size of minibatches to use for optimization.
        standardize_fields (list): List of sample field names that should be
            normalized prior to optimization.

    Returns:
        averaged info fetches over the last SGD epoch taken.
    """
    if isinstance(samples, SampleBatch):
        samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples}, samples.count)

    # Use LearnerInfoBuilder as a unified way to build the final
    # results dict from `learn_on_loaded_batch` call(s).
    # This makes sure results dicts always have the same structure
    # no matter the setup (multi-GPU, multi-agent, minibatch SGD,
    # tf vs torch).
    learner_info_builder = LearnerInfoBuilder(num_devices=1)
    for policy_id in policies.keys():
        if policy_id not in samples.policy_batches:
            continue

        batch = samples.policy_batches[policy_id]
        for field in standardize_fields:
            batch[field] = standardized(batch[field])

        for i in range(num_sgd_iter):
            for minibatch in minibatches(batch, sgd_minibatch_size):
                results = (local_worker.learn_on_batch(
                    MultiAgentBatch({policy_id: minibatch},
                                    minibatch.count)))[policy_id]
                learner_info_builder.add_learn_on_batch_results(
                    results, policy_id)

    learner_info = learner_info_builder.finalize()
    return learner_info
Example #27
        def mix_batches(_policy_id):
            """Mixes old with new samples.

            Tries to mix according to self.replay_ratio on average.
            If not enough new samples are available, mixes in fewer old
            samples to retain self.replay_ratio on average.
            """

            def round_up_or_down(value, ratio):
                """Returns an integer averaging to value*ratio."""
                product = value * ratio
                ceil_prob = product % 1
                if random.uniform(0, 1) < ceil_prob:
                    return int(np.ceil(product))
                else:
                    return int(np.floor(product))

            max_num_new = round_up_or_down(num_items, 1 - self.replay_ratio)
            # If num_items * (1 - self.replay_ratio) is not a whole number,
            # round_up_or_down() takes one more new sample with probability
            # equal to the fractional part.

            _buffer = self.replay_buffers[_policy_id]
            output_batches = self.last_added_batches[_policy_id][:max_num_new]
            self.last_added_batches[_policy_id] = self.last_added_batches[_policy_id][
                max_num_new:
            ]

            # No replay desired
            if self.replay_ratio == 0.0:
                return SampleBatch.concat_samples(output_batches)
            # Only replay desired
            elif self.replay_ratio == 1.0:
                return _buffer.sample(num_items, **kwargs)

            num_new = len(output_batches)

            if np.isclose(num_new, num_items * (1 - self.replay_ratio)):
                # The optimal case, we can mix in a round number of old
                # samples on average
                num_old = num_items - max_num_new
            else:
                # We never want to return more elements than num_items
                num_old = min(
                    num_items - max_num_new,
                    round_up_or_down(
                        num_new, self.replay_ratio / (1 - self.replay_ratio)
                    ),
                )

            output_batches.append(_buffer.sample(num_old, **kwargs))
            # Depending on the implementation of underlying buffers, samples
            # might be SampleBatches
            output_batches = [batch.as_multi_agent() for batch in output_batches]
            return MultiAgentBatch.concat_samples(output_batches)
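The `round_up_or_down` helper above rounds stochastically so that, on average, the number of new samples matches `num_items * (1 - replay_ratio)`. A standalone sketch demonstrating that averaging property:

import random
import numpy as np

def round_up_or_down(value, ratio):
    # Round value * ratio up with probability equal to its fractional part,
    # down otherwise, so the result averages to value * ratio over many draws.
    product = value * ratio
    ceil_prob = product % 1
    if random.uniform(0, 1) < ceil_prob:
        return int(np.ceil(product))
    return int(np.floor(product))

draws = [round_up_or_down(10, 0.25) for _ in range(10000)]
print(np.mean(draws))  # close to 2.5 on average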
Example #28
 def add_batch(self, batch):
     # Make a copy so the replay buffer doesn't pin plasma memory.
     batch = batch.copy()
     # Handle everything as if multiagent
     if isinstance(batch, SampleBatch):
         batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count)
     with self.add_batch_timer:
         if self.replay_mode == "lockstep":
             # Note that prioritization is not supported in this mode.
             for s in batch.timeslices(self.replay_sequence_length):
                 self.replay_buffers[_ALL_POLICIES].add(s, weight=None)
         else:
             for policy_id, b in batch.policy_batches.items():
                 for s in b.timeslices(self.replay_sequence_length):
                     if "weights" in s:
                         weight = np.mean(s["weights"])
                     else:
                         weight = None
                     self.replay_buffers[policy_id].add(s, weight=weight)
     self.num_added += batch.count
Example #29
 def add_batch(self, batch):
     # Handle everything as if multiagent
     if isinstance(batch, SampleBatch):
         batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count)
     with self.add_batch_timer:
         for policy_id, s in batch.policy_batches.items():
             for row in s.rows():
                 self.replay_buffers[policy_id].add(
                     row["obs"], row["actions"], row["rewards"],
                     row["new_obs"], row["dones"], row["weights"])
     self.num_added += batch.count
Example #30
    def replay(self):
        if self._fake_batch:
            fake_batch = SampleBatch(self._fake_batch)
            return MultiAgentBatch({DEFAULT_POLICY_ID: fake_batch},
                                   fake_batch.count)

        if self.num_added < self.replay_starts:
            return None

        with self.replay_timer:
            if self.replay_mode == "lockstep":
                return self.replay_buffers[_ALL_POLICIES].sample(
                    self.replay_batch_size, beta=self.prioritized_replay_beta)
            else:
                samples = {}
                for policy_id, replay_buffer in self.replay_buffers.items():
                    samples[policy_id] = replay_buffer.sample(
                        self.replay_batch_size,
                        beta=self.prioritized_replay_beta)
                return MultiAgentBatch(samples, self.replay_batch_size)