Example #1
    def replay(self, policy_id: PolicyID = DEFAULT_POLICY_ID) -> \
            Optional[SampleBatchType]:
        buffer = self.replay_buffers[policy_id]
        # Return None, if:
        # - Buffer empty or
        # - `replay_ratio` < 1.0 (new samples required in returned batch)
        #   and no new samples to mix with replayed ones.
        if len(buffer) == 0 or (len(self.last_added_batches[policy_id]) == 0
                                and self.replay_ratio < 1.0):
            return None

        # Mix buffer's last added batches with older replayed batches.
        with self.replay_timer:
            output_batches = self.last_added_batches[policy_id]
            self.last_added_batches[policy_id] = []

            # No replay desired -> Return here.
            if self.replay_ratio == 0.0:
                return SampleBatch.concat_samples(output_batches)
            # Only replay desired -> Return a (replayed) sample from the
            # buffer.
            elif self.replay_ratio == 1.0:
                return buffer.replay()

            # Replay ratio = old / [old + new]
            # Replay proportion: old / new
            num_new = len(output_batches)
            replay_proportion = self.replay_proportion
            while random.random() < num_new * replay_proportion:
                replay_proportion -= 1
                output_batches.append(buffer.replay())
            return SampleBatch.concat_samples(output_batches)
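A minimal usage sketch for the method above. This is hedged: the names `mixin_buffer`, `new_batch` and the `add_batch()` call are assumptions about the surrounding MixIn-style replay buffer class, which is not shown in this snippet.

# Assumed: `mixin_buffer` is an instance of the MixIn-style buffer that owns
# the `replay()` method above and exposes some `add_batch()` method.
mixin_buffer.add_batch(new_batch)   # becomes part of `last_added_batches`
mixed = mixin_buffer.replay()       # new batch(es) plus replayed old ones,
                                    # concatenated via SampleBatch.concat_samples()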
Example #2
    def test_concat_max_seq_len(self):
        """Tests that SampleBatch.concat_samples() handles max_seq_len correctly."""
        s1 = SampleBatch({
            "a": np.array([1, 2, 3]),
            "b": {
                "c": np.array([4, 5, 6])
            },
            SampleBatch.SEQ_LENS: [1, 2]
        })
        s2 = SampleBatch({
            "a": np.array([2, 3, 4]),
            "b": {
                "c": np.array([5, 6, 7])
            },
            SampleBatch.SEQ_LENS: [3]
        })

        s3 = SampleBatch({
            "a": np.array([2, 3, 4]),
            "b": {
                "c": np.array([5, 6, 7])
            },
        })

        concatd = SampleBatch.concat_samples([s1, s2])
        check(concatd.max_seq_len, s2.max_seq_len)

        with self.assertRaises(ValueError):
            SampleBatch.concat_samples([s1, s2, s3])
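For reference, a small sketch of the behavior the test above exercises (assuming `SampleBatch` from `ray.rllib.policy.sample_batch`): concatenating batches that all carry SEQ_LENS merges the sequence lengths and keeps the largest max_seq_len, while mixing in a batch without SEQ_LENS raises a ValueError.

import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch

s1 = SampleBatch({"a": np.array([1, 2, 3]), SampleBatch.SEQ_LENS: [1, 2]})
s2 = SampleBatch({"a": np.array([2, 3, 4]), SampleBatch.SEQ_LENS: [3]})
merged = SampleBatch.concat_samples([s1, s2])
# merged now holds 6 rows with seq_lens [1, 2, 3] and max_seq_len == 3
# (this is what the check against s2.max_seq_len above verifies).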
Example #3
    def step(self):
        with self.update_weights_timer:
            if self.workers.remote_workers():
                weights = ray.put(self.workers.local_worker().get_weights())
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights)

        with self.sample_timer:
            samples = []
            while sum(s.count for s in samples) < self.train_batch_size:
                if self.workers.remote_workers():
                    samples.extend(
                        ray_get_and_free([
                            e.sample.remote()
                            for e in self.workers.remote_workers()
                        ]))
                else:
                    samples.append(self.workers.local_worker().sample())
            samples = SampleBatch.concat_samples(samples)
            self.sample_timer.push_units_processed(samples.count)

        # Unfortunate to have to hack it like this, but not sure how else to do it.
        # Setting the phase to zeros results in policy optimization, and to ones results in aux optimization.
        # These have to be added prior to the policy sgd.
        samples["phase"] = np.zeros(samples.count)

        with self.grad_timer:
            fetches = do_minibatch_sgd(samples, self.policies,
                                       self.workers.local_worker(),
                                       self.num_sgd_iter,
                                       self.sgd_minibatch_size,
                                       self.standardize_fields)
        self.grad_timer.push_units_processed(samples.count)

        if len(fetches) == 1 and DEFAULT_POLICY_ID in fetches:
            self.learner_stats = fetches[DEFAULT_POLICY_ID]
        else:
            self.learner_stats = fetches

        self.num_steps_sampled += samples.count
        self.num_steps_trained += samples.count

        if self.num_steps_sampled > self.aux_loss_start_after_num_steps:
            # Add samples to the memory to be provided to the aux loss.
            self._remove_unnecessary_data(samples)
            self.memory.append(samples)

            # Optionally run the aux optimization.
            if len(self.memory) >= self.aux_loss_every_k:
                samples = SampleBatch.concat_samples(self.memory)
                self._add_policy_logits(samples)
                # Ones indicate aux phase.
                samples["phase"] = np.ones_like(samples["phase"])
                do_minibatch_sgd(samples, self.policies,
                                 self.workers.local_worker(),
                                 self.aux_loss_num_sgd_iter,
                                 self.sgd_minibatch_size, [])
                self.memory = []

        return self.learner_stats
Example #4
    def step(self):
        with self.update_weights_timer:
            if self.workers.remote_workers():
                weights = ray.put(self.workers.local_worker().get_weights())
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights)

        with self.sample_timer:
            samples = []
            while sum(s.count for s in samples) < self.train_batch_size:
                if self.workers.remote_workers():
                    samples.extend(
                        ray_get_and_free([
                            e.sample.remote()
                            for e in self.workers.remote_workers()
                        ]))
                else:
                    samples.append(self.workers.local_worker().sample())
            samples = SampleBatch.concat_samples(samples)
            self.sample_timer.push_units_processed(samples.count)

        # Handle everything as if multiagent
        if isinstance(samples, SampleBatch):
            samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples},
                                      samples.count)

        fetches = {}
        with self.grad_timer:
            for policy_id, policy in self.policies.items():
                if policy_id not in samples.policy_batches:
                    continue

                batch = samples.policy_batches[policy_id]
                for field in self.standardize_fields:
                    value = batch[field]
                    standardized = (value - value.mean()) / max(
                        1e-4, value.std())
                    batch[field] = standardized

                for i in range(self.num_sgd_iter):
                    iter_extra_fetches = defaultdict(list)
                    for minibatch in self._minibatches(batch):
                        batch_fetches = (
                            self.workers.local_worker().learn_on_batch(
                                MultiAgentBatch({policy_id: minibatch},
                                                minibatch.count)))[policy_id]
                        for k, v in batch_fetches[LEARNER_STATS_KEY].items():
                            iter_extra_fetches[k].append(v)
                    logger.debug("{} {}".format(i,
                                                _averaged(iter_extra_fetches)))
                fetches[policy_id] = _averaged(iter_extra_fetches)

        self.grad_timer.push_units_processed(samples.count)
        if len(fetches) == 1 and DEFAULT_POLICY_ID in fetches:
            self.learner_stats = fetches[DEFAULT_POLICY_ID]
        else:
            self.learner_stats = fetches
        self.num_steps_sampled += samples.count
        self.num_steps_trained += samples.count
        return self.learner_stats
Example #5
    def estimate(
        self,
        batch: SampleBatchType,
    ) -> OffPolicyEstimate:
        self.check_can_estimate_for(batch)
        estimates = []
        # Split data into train and test batches
        for train_episodes, test_episodes in train_test_split(
                batch,
                self.train_test_split_val,
                self.k,
        ):

            # Train Q-function
            if train_episodes:
                # Reinitialize model
                self.model.reset()
                train_batch = SampleBatch.concat_samples(train_episodes)
                losses = self.train(train_batch)
                self.losses.append(losses)

            # Calculate doubly robust OPE estimates
            for episode in test_episodes:
                rewards, old_prob = episode["rewards"], episode["action_prob"]
                new_prob = np.exp(self.action_log_likelihood(episode))

                v_old = 0.0
                v_new = 0.0
                q_values = self.model.estimate_q(episode[SampleBatch.OBS],
                                                 episode[SampleBatch.ACTIONS])
                q_values = convert_to_numpy(q_values)

                all_actions = np.zeros(
                    [episode.count, self.policy.action_space.n])
                all_actions[:] = np.arange(self.policy.action_space.n)
                # Two transposes required for torch.distributions to work
                tmp_episode = episode.copy()
                tmp_episode[SampleBatch.ACTIONS] = all_actions.T
                action_probs = np.exp(
                    self.action_log_likelihood(tmp_episode)).T
                v_values = self.model.estimate_v(episode[SampleBatch.OBS],
                                                 action_probs)
                v_values = convert_to_numpy(v_values)

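                # Per-step doubly robust recursion (importance ratio
                # rho_t = new_prob[t] / old_prob[t]):
                #     v_new(t) = V(s_t) + rho_t * (r_t + gamma * v_new(t+1) - Q(s_t, a_t))
                # while `v_old` simply accumulates the discounted return of
                # the behavior policy.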
                for t in reversed(range(episode.count)):
                    v_old = rewards[t] + self.gamma * v_old
                    v_new = v_values[t] + (new_prob[t] / old_prob[t]) * (
                        rewards[t] + self.gamma * v_new - q_values[t])
                v_new = v_new.item()

                estimates.append(
                    OffPolicyEstimate(
                        self.name,
                        {
                            "v_old": v_old,
                            "v_new": v_new,
                            "v_gain": v_new / max(1e-8, v_old),
                        },
                    ))
        return estimates
Example #6
        def inner_adaptation_steps(itr):
            buf = []
            split = []
            metrics = {}
            for samples in itr:

                # Processing Samples (Standardize Advantages)
                split_lst = []
                for sample in samples:
                    sample["advantages"] = standardized(sample["advantages"])
                    split_lst.append(sample.count)

                buf.extend(samples)
                split.append(split_lst)

                adapt_iter = len(split) - 1
                metrics = post_process_metrics(adapt_iter, workers, metrics)
                if len(split) > inner_steps:
                    out = SampleBatch.concat_samples(buf)
                    out["split"] = np.array(split)
                    buf = []
                    split = []

                    # Reporting Adaptation Rew Diff
                    ep_rew_pre = metrics["episode_reward_mean"]
                    ep_rew_post = metrics["episode_reward_mean_adapt_" +
                                          str(inner_steps)]
                    metrics["adaptation_delta"] = ep_rew_post - ep_rew_pre
                    yield out, metrics
                    metrics = {}
                else:
                    inner_adaptation(workers, samples)
Example #7
def collect_samples(agents, sample_batch_size, num_envs_per_worker,
                    train_batch_size):
    """Collects at least train_batch_size samples, never discarding any."""

    num_timesteps_so_far = 0
    trajectories = []
    agent_dict = {}

    for agent in agents:
        fut_sample = agent.sample.remote()
        agent_dict[fut_sample] = agent

    while agent_dict:
        [fut_sample], _ = ray.wait(list(agent_dict))
        agent = agent_dict.pop(fut_sample)
        next_sample = ray_get_and_free(fut_sample)
        num_timesteps_so_far += next_sample.count
        trajectories.append(next_sample)

        # Only launch more tasks if we don't already have enough pending
        pending = len(agent_dict) * sample_batch_size * num_envs_per_worker
        if num_timesteps_so_far + pending < train_batch_size:
            fut_sample2 = agent.sample.remote()
            agent_dict[fut_sample2] = agent

    return SampleBatch.concat_samples(trajectories)
Example #8
    def __call__(self, batch: SampleBatchType) -> List[SampleBatchType]:
        _check_sample_batch_type(batch)
        self.buffer.append(batch)

        if self.count_steps_by == "env_steps":
            self.count += batch.count
        else:
            assert isinstance(batch, MultiAgentBatch), \
                "`count_steps_by=agent_steps` only allowed in multi-agent " \
                "environments!"
            self.count += batch.agent_steps()

        if self.count >= self.min_batch_size:
            if self.count > self.min_batch_size * 2:
                logger.info("Collected more training samples than expected "
                            "(actual={}, expected={}). ".format(
                                self.count, self.min_batch_size) +
                            "This may be because you have many workers or "
                            "long episodes in 'complete_episodes' batch mode.")
            out = SampleBatch.concat_samples(self.buffer)
            timer = _get_shared_metrics().timers[SAMPLE_TIMER]
            timer.push(time.perf_counter() - self.batch_start_time)
            timer.push_units_processed(self.count)
            self.batch_start_time = None
            self.buffer = []
            self.count = 0
            return [out]
        return []
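A hedged usage sketch: assuming the `__call__` above belongs to RLlib's `ConcatBatches` operator (from `ray.rllib.execution.rollout_ops`), it is typically chained onto a rollouts iterator roughly as below. `config` and `workers` are placeholders, and `ParallelRollouts` is shown in Example #19 further down.

train_batches = ParallelRollouts(workers, mode="bulk_sync") \
    .combine(
        ConcatBatches(
            min_batch_size=config["train_batch_size"],
            count_steps_by=config["multiagent"]["count_steps_by"],
        ))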
Example #9
    def test_sequence_size(self):
        # Seq-len=1.
        buffer = PrioritizedReplayBuffer(
            capacity=100, alpha=0.1, storage_unit="fragments"
        )
        for _ in range(200):
            buffer.add(self._generate_data())
        assert len(buffer._storage) == 100, len(buffer._storage)
        assert buffer.stats()["added_count"] == 200, buffer.stats()
        # Test get_state/set_state.
        state = buffer.get_state()
        new_memory = PrioritizedReplayBuffer(capacity=100, alpha=0.1)
        new_memory.set_state(state)
        assert len(new_memory._storage) == 100, len(new_memory._storage)
        assert new_memory.stats()["added_count"] == 200, new_memory.stats()

        # Seq-len=5.
        buffer = PrioritizedReplayBuffer(
            capacity=100, alpha=0.1, storage_unit="fragments"
        )
        for _ in range(40):
            buffer.add(
                SampleBatch.concat_samples([self._generate_data() for _ in range(5)])
            )
        assert len(buffer._storage) == 20, len(buffer._storage)
        assert buffer.stats()["added_count"] == 200, buffer.stats()
        # Test get_state/set_state.
        state = buffer.get_state()
        new_memory = PrioritizedReplayBuffer(capacity=100, alpha=0.1)
        new_memory.set_state(state)
        assert len(new_memory._storage) == 20, len(new_memory._storage)
        assert new_memory.stats()["added_count"] == 200, new_memory.stats()
Example #10
    def postprocess_with_HER(policy,
                             sample_batch,
                             _other_agent_batches=None,
                             _episode=None):
        """Postprocess the sampled batch, injecting additional trajectories
        with modified (hindsight) goal conditions.
        """

        # Hindsight Experience Replay trajectory augmentation
        if type(sample_batch) is SampleBatch:
            # init list of new trajectories
            augmented_trajs = [sample_batch]
            # init HER sampling strategy
            her_sampler = SamplingStrategy(policy, sample_batch)
            # sample n new trajectories using sampling strategy
            for i in range(policy.config['num_her_traj']):
                augmented_trajs.append(her_sampler.sample_trajectory())
            # concatenate sampled trajectories
            sample_batch = SampleBatch.concat_samples(augmented_trajs)

        # RLlib Original DQN postprocess_fn Implementation
        sample_batch = postprocess_nstep_and_prio(policy, sample_batch,
                                                  _other_agent_batches,
                                                  _episode)

        return sample_batch
Example #11
    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.remote_evaluators:
                batch = SampleBatch.concat_samples(
                    ray_get_and_free(
                        [e.sample.remote() for e in self.remote_evaluators]))
            else:
                batch = self.local_evaluator.sample()

            # Handle everything as if multiagent
            if isinstance(batch, SampleBatch):
                batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch},
                                        batch.count)

            for policy_id, s in batch.policy_batches.items():
                for row in s.rows():
                    self.replay_buffers[policy_id].add(
                        pack_if_needed(row["obs"]),
                        row["actions"],
                        row["rewards"],
                        pack_if_needed(row["new_obs"]),
                        row["dones"],
                        weight=None)

        if self.num_steps_sampled >= self.replay_starts:
            self._optimize()

        self.num_steps_sampled += batch.count
Example #12
def sample_min_n_steps_from_buffer(
        replay_buffer: ReplayBuffer, min_steps: int,
        count_by_agent_steps: bool) -> Optional[SampleBatchType]:
    """Samples a minimum of n timesteps from a given replay buffer.

    This utility method is primarily used by the QMIX algorithm and helps with
    sampling a given number of timesteps from a buffer that stores samples in
    units of sequences or complete episodes. It samples batches from the
    replay buffer until the total number of timesteps reaches `min_steps`.

    Args:
        replay_buffer: The replay buffer to sample from.
        min_steps: The minimum number of timesteps to sample.
        count_by_agent_steps: Whether to count agent steps or env steps.

    Returns:
        A concatenated SampleBatch or MultiAgentBatch with samples from the
        buffer.
    """
    train_batch_size = 0
    train_batches = []
    while train_batch_size < min_steps:
        batch = replay_buffer.sample(num_items=1)
        batch_len = batch.agent_steps(
        ) if count_by_agent_steps else batch.env_steps()
        if batch_len == 0:
            # Replay has not started, so we can't accumulate timesteps here
            return batch
        train_batches.append(batch)
        train_batch_size += batch_len
    # All batch types are the same type, hence we can use any concat_samples()
    train_batch = SampleBatch.concat_samples(train_batches)
    return train_batch
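A minimal usage sketch for the helper above. `my_buffer` is a placeholder (not defined here) for any RLlib ReplayBuffer-like object whose `sample(num_items=1)` returns a SampleBatch or MultiAgentBatch.

train_batch = sample_min_n_steps_from_buffer(
    replay_buffer=my_buffer, min_steps=1000, count_by_agent_steps=False)
# train_batch holds at least 1000 env steps, unless replay has not started yet.
print(train_batch.env_steps())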
Example #13
    def test_sequence_size(self):
        # Seq-len=1.
        memory = PrioritizedReplayBuffer(capacity=100, alpha=0.1)
        for _ in range(200):
            memory.add(self._generate_data(), weight=None)
        assert len(memory._storage) == 100, len(memory._storage)
        assert memory.stats()["added_count"] == 200, memory.stats()
        # Test get_state/set_state.
        state = memory.get_state()
        new_memory = PrioritizedReplayBuffer(capacity=100, alpha=0.1)
        new_memory.set_state(state)
        assert len(new_memory._storage) == 100, len(new_memory._storage)
        assert new_memory.stats()["added_count"] == 200, new_memory.stats()

        # Seq-len=5.
        memory = PrioritizedReplayBuffer(capacity=100, alpha=0.1)
        for _ in range(40):
            memory.add(
                SampleBatch.concat_samples(
                    [self._generate_data() for _ in range(5)]),
                weight=None,
            )
        assert len(memory._storage) == 20, len(memory._storage)
        assert memory.stats()["added_count"] == 200, memory.stats()
        # Test get_state/set_state.
        state = memory.get_state()
        new_memory = PrioritizedReplayBuffer(capacity=100, alpha=0.1)
        new_memory.set_state(state)
        assert len(new_memory._storage) == 20, len(new_memory._storage)
        assert new_memory.stats()["added_count"] == 200, new_memory.stats()
Example #14
    def _add_sample_batch_to_buffer(self,
                                    buffer,
                                    batch_size,
                                    num_batches=5,
                                    **kwargs):
        self.eps_id = 0

        def _generate_data():
            self.eps_id += 1
            return SampleBatch({
                SampleBatch.T: [0, 1],
                SampleBatch.ACTIONS:
                2 * [np.random.choice([0, 1])],
                SampleBatch.REWARDS:
                2 * [np.random.rand()],
                SampleBatch.OBS:
                2 * [np.random.random((4, ))],
                SampleBatch.NEXT_OBS:
                2 * [np.random.random((4, ))],
                SampleBatch.DONES:
                2 * [np.random.choice([False, True])],
                SampleBatch.EPS_ID:
                2 * [self.eps_id],
                SampleBatch.AGENT_INDEX:
                2 * [0],
                "batch_id":
                2 * [self.batch_id],
            })

        for i in range(num_batches):
            data = [_generate_data() for _ in range(batch_size)]
            self.batch_id += 1
            batch = SampleBatch.concat_samples(data)
            buffer.add(batch, **kwargs)
Example #15
    def postprocess_with_HER(policy,
                             sample_batch,
                             _other_agent_batches=None,
                             _episode=None):
        """Postprocess the sampled batch, injecting additional trajectories
        with modified (hindsight) goal conditions.
        """
        import numpy as np
        # Hindsight Experience Replay trajectory augmentation
        if (type(sample_batch) is SampleBatch) and (
                policy.config['use_HER']) and (sample_batch['obs'].shape[0] >
                                               0):
            # init list of new trajectories
            augmented_trajs = [sample_batch]
            # init HER sampling strategy
            her_sampler = SamplingStrategy(policy, sample_batch)
            # sample n new trajectories using sampling strategy
            for i in range(policy.config['num_HER_traj']):
                augmented_trajs.append(her_sampler.sample_trajectory())
            # concatenate sampled trajectories
            sample_batch = SampleBatch.concat_samples(augmented_trajs)

        # Original postprocess_fn Implementation
        sample_batch = postprocess_fn(policy, sample_batch,
                                      _other_agent_batches, _episode)
        # code.interact(local=locals())
        return sample_batch
Example #16
    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            samples = []
            while sum(s.count for s in samples) < self.train_batch_size:
                if self.remote_evaluators:
                    samples.extend(
                        ray_get_and_free([
                            e.sample.remote() for e in self.remote_evaluators
                        ]))
                else:
                    samples.append(self.local_evaluator.sample())
            samples = SampleBatch.concat_samples(samples)
            self.sample_timer.push_units_processed(samples.count)

        with self.grad_timer:
            for i in range(self.num_sgd_iter):
                fetches = self.local_evaluator.learn_on_batch(samples)
                self.learner_stats = get_learner_stats(fetches)
                if self.num_sgd_iter > 1:
                    logger.debug("{} {}".format(i, fetches))
            self.grad_timer.push_units_processed(samples.count)

        self.num_steps_sampled += samples.count
        self.num_steps_trained += samples.count
        return self.learner_stats
Example #17
        def inner_adaptation_steps(itr):
            buf = []
            split = []
            metrics = {}
            for samples in itr:
                print("Collecting Samples, Inner Adaptation {}".format(
                    len(split)))
                # Processing Samples (Standardize Advantages)
                samples, split_lst = post_process_samples(samples, config)

                buf.extend(samples)
                split.append(split_lst)

                adapt_iter = len(split) - 1
                prefix = "DynaTrajInner_" + str(adapt_iter)
                metrics = post_process_metrics(prefix, workers, metrics)

                if len(split) > num_inner_steps:
                    out = SampleBatch.concat_samples(buf)
                    out["split"] = np.array(split)
                    buf = []
                    split = []

                    yield out, metrics
                    metrics = {}
                else:
                    inner_adaptation(workers, samples)
Example #18
def collect_samples_straggler_mitigation(agents, train_batch_size):
    """Collects at least train_batch_size samples.

    This is the legacy behavior as of 0.6, and launches extra sample tasks to
    potentially improve performance but can result in many wasted samples.
    """

    num_timesteps_so_far = 0
    trajectories = []
    agent_dict = {}

    for agent in agents:
        fut_sample = agent.sample.remote()
        agent_dict[fut_sample] = agent

    while num_timesteps_so_far < train_batch_size:
        # TODO(pcm): Make wait support arbitrary iterators and remove the
        # conversion to list here.
        [fut_sample], _ = ray.wait(list(agent_dict))
        agent = agent_dict.pop(fut_sample)
        # Start task with next trajectory and record it in the dictionary.
        fut_sample2 = agent.sample.remote()
        agent_dict[fut_sample2] = agent

        next_sample = ray_get_and_free(fut_sample)
        num_timesteps_so_far += next_sample.count
        trajectories.append(next_sample)

    logger.info("Discarding {} sample tasks".format(len(agent_dict)))
    return SampleBatch.concat_samples(trajectories)
Example #19
def ParallelRollouts(workers: WorkerSet,
                     mode="bulk_sync") -> LocalIterator[SampleBatch]:
    """Operator to collect experiences in parallel from rollout workers.

    If there are no remote workers, experiences will be collected serially from
    the local worker instance instead.

    Arguments:
        workers (WorkerSet): set of rollout workers to use.
        mode (str): One of {'async', 'bulk_sync'}.
            - In 'async' mode, batches are returned as soon as they are
              computed by rollout workers with no order guarantees.
            - In 'bulk_sync' mode, we collect one batch from each worker
              and concatenate them together into a large batch to return.

    Returns:
        A local iterator over experiences collected in parallel.

    Examples:
        >>> rollouts = ParallelRollouts(workers, mode="async")
        >>> batch = next(rollouts)
        >>> print(batch.count)
        50  # config.sample_batch_size

        >>> rollouts = ParallelRollouts(workers, mode="bulk_sync")
        >>> batch = next(rollouts)
        >>> print(batch.count)
        200  # config.sample_batch_size * config.num_workers

    Updates the STEPS_SAMPLED_COUNTER counter in the local iterator context.
    """
    def report_timesteps(batch):
        metrics = LocalIterator.get_metrics()
        metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count
        return batch

    if not workers.remote_workers():
        # Handle the serial sampling case.
        def sampler(_):
            while True:
                yield workers.local_worker().sample()

        return (LocalIterator(sampler,
                              MetricsContext()).for_each(report_timesteps))

    # Create a parallel iterator over generated experiences.
    rollouts = from_actors(workers.remote_workers())

    if mode == "bulk_sync":
        return rollouts \
            .batch_across_shards() \
            .for_each(lambda batches: SampleBatch.concat_samples(batches)) \
            .for_each(report_timesteps)
    elif mode == "async":
        return rollouts.gather_async().for_each(report_timesteps)
    else:
        raise ValueError(
            "mode must be one of 'bulk_sync', 'async', got '{}'".format(mode))
Example #20
    def _sgd_step(self):
        samples = [random.choice(self.replay_buffer)]
        while sum(s.count for s in samples) < self.train_batch_size:
            samples.append(random.choice(self.replay_buffer))
        samples = SampleBatch.concat_samples(samples)
        info_dict = self.workers.local_worker().learn_on_batch(samples)
        for policy_id, info in info_dict.items():
            self.learner_stats[policy_id] = get_learner_stats(info)
        self.num_steps_trained += samples.count
        return info_dict
Example #21
    def test_concat(self):
        b1 = SampleBatch({"a": np.array([1, 2, 3]), "b": np.array([4, 5, 6])})
        b2 = SampleBatch({"a": np.array([1]), "b": np.array([4])})
        b3 = SampleBatch({"a": np.array([1]), "b": np.array([5])})
        b12 = b1.concat(b2)
        self.assertEqual(b12["a"].tolist(), [1, 2, 3, 1])
        self.assertEqual(b12["b"].tolist(), [4, 5, 6, 4])
        b = SampleBatch.concat_samples([b1, b2, b3])
        self.assertEqual(b["a"].tolist(), [1, 2, 3, 1, 1])
        self.assertEqual(b["b"].tolist(), [4, 5, 6, 4, 5])
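The same kind of concatenation is available for multi-agent data; a small sketch, assuming `as_multi_agent()` and `MultiAgentBatch.concat_samples()` from `ray.rllib.policy.sample_batch` (both are also used in Example #22 below):

import numpy as np
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch

ma1 = SampleBatch({"a": np.array([1, 2])}).as_multi_agent()
ma2 = SampleBatch({"a": np.array([3])}).as_multi_agent()
ma = MultiAgentBatch.concat_samples([ma1, ma2])
# ma.env_steps() -> 3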
Example #22
        def mix_batches(_policy_id):
            """Mixes old with new samples.

            Tries to mix according to self.replay_ratio on average.
            If not enough new samples are available, mixes in less old samples
            to retain self.replay_ratio on average.
            """

            def round_up_or_down(value, ratio):
                """Returns an integer averaging to value*ratio."""
                product = value * ratio
                ceil_prob = product % 1
                if random.uniform(0, 1) < ceil_prob:
                    return int(np.ceil(product))
                else:
                    return int(np.floor(product))

            max_num_new = round_up_or_down(num_items, 1 - self.replay_ratio)
            # If num_items * (1 - self.replay_ratio) is not an integer, we take
            # one more new sample with a probability equal to the fractional
            # part (see round_up_or_down above).

            _buffer = self.replay_buffers[_policy_id]
            output_batches = self.last_added_batches[_policy_id][:max_num_new]
            self.last_added_batches[_policy_id] = self.last_added_batches[_policy_id][
                max_num_new:
            ]

            # No replay desired
            if self.replay_ratio == 0.0:
                return SampleBatch.concat_samples(output_batches)
            # Only replay desired
            elif self.replay_ratio == 1.0:
                return _buffer.sample(num_items, **kwargs)

            num_new = len(output_batches)

            if np.isclose(num_new, num_items * (1 - self.replay_ratio)):
                # The optimal case, we can mix in a round number of old
                # samples on average
                num_old = num_items - max_num_new
            else:
                # We never want to return more elements than num_items
                num_old = min(
                    num_items - max_num_new,
                    round_up_or_down(
                        num_new, self.replay_ratio / (1 - self.replay_ratio)
                    ),
                )

            output_batches.append(_buffer.sample(num_old, **kwargs))
            # Depending on the implementation of underlying buffers, samples
            # might be SampleBatches
            output_batches = [batch.as_multi_agent() for batch in output_batches]
            return MultiAgentBatch.concat_samples(output_batches)
Example #23
    def replay(
        self, policy_id: PolicyID = DEFAULT_POLICY_ID
    ) -> Optional[SampleBatchType]:
        if self.replay_mode == ReplayMode.LOCKSTEP and policy_id != _ALL_POLICIES:
            raise ValueError(
                "Trying to sample from single policy's buffer in lockstep "
                "mode. In lockstep mode, all policies' experiences are "
                "sampled from a single replay buffer which is accessed "
                "with the policy id `{}`".format(_ALL_POLICIES)
            )

        buffer = self.replay_buffers[policy_id]
        # Return None, if:
        # - Buffer empty or
        # - `replay_ratio` < 1.0 (new samples required in returned batch)
        #   and no new samples to mix with replayed ones.
        if len(buffer) == 0 or (
            len(self.last_added_batches[policy_id]) == 0 and self.replay_ratio < 1.0
        ):
            return None

        # Mix buffer's last added batches with older replayed batches.
        with self.replay_timer:
            output_batches = self.last_added_batches[policy_id]
            self.last_added_batches[policy_id] = []

            # No replay desired -> Return here.
            if self.replay_ratio == 0.0:
                return SampleBatch.concat_samples(output_batches)
            # Only replay desired -> Return a (replayed) sample from the
            # buffer.
            elif self.replay_ratio == 1.0:
                return buffer.replay()

            # Replay ratio = old / [old + new]
            # Replay proportion: old / new
            num_new = len(output_batches)
            replay_proportion = self.replay_proportion
            while random.random() < num_new * replay_proportion:
                replay_proportion -= 1
                output_batches.append(buffer.replay())
            return SampleBatch.concat_samples(output_batches)
Example #24
    def _optimize(self):
        samples = [random.choice(self.replay_buffer)]
        while sum(s.count for s in samples) < self.train_batch_size:
            samples.append(random.choice(self.replay_buffer))
        samples = SampleBatch.concat_samples(samples)
        with self.grad_timer:
            info_dict = self.local_evaluator.learn_on_batch(samples)
            for policy_id, info in info_dict.items():
                self.learner_stats[policy_id] = get_learner_stats(info)
            self.grad_timer.push_units_processed(samples.count)
        self.num_steps_trained += samples.count
        return info_dict
Example #25
    def __call__(self, batch: SampleBatch) -> List[SampleBatch]:
        if not isinstance(batch, SampleBatch):
            raise ValueError("Expected type SampleBatch, got {}: {}".format(
                type(batch), batch))
        self.buffer.append(batch)
        self.count += batch.count
        if self.count >= self.min_batch_size:
            out = SampleBatch.concat_samples(self.buffer)
            self.buffer = []
            self.count = 0
            return [out]
        return []
Example #26
        def mix_batches(_policy_id):
            _buffer = self.replay_buffers[policy_id]
            output_batches = self.last_added_batches[_policy_id]
            self.last_added_batches[_policy_id] = []

            # No replay desired
            if self.replay_ratio == 0.0:
                return SampleBatch.concat_samples(output_batches)
            # Only replay desired
            elif self.replay_ratio == 1.0:
                return _buffer.sample(num_items,
                                      beta=self.prioritized_replay_beta)

            # Replay ratio = old / [old + new]
            # Replay proportion: old / new
            num_new = len(output_batches)
            replay_proportion = self.replay_proportion
            while random.random() < num_new * replay_proportion:
                replay_proportion -= 1
                output_batches.append(_buffer.sample(num_items))
            return SampleBatch.concat_samples(output_batches)
Example #27
    def training_step(self) -> ResultDict:
        """Samples, stores/replays experiences, trains, and syncs weights.

        Returns:
            The results dict from executing the training iteration.
        """

        # Sample n MultiAgentBatches from n workers.
        new_sample_batches = synchronous_parallel_sample(
            worker_set=self.workers, concat=False)

        for batch in new_sample_batches:
            # Update sampling step counters.
            self._counters[NUM_ENV_STEPS_SAMPLED] += batch.env_steps()
            self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps()
            # Store new samples in the replay buffer
            # Use deprecated add_batch() to support old replay buffers for now
            if self.local_replay_buffer is not None:
                self.local_replay_buffer.add(batch)

        if self.local_replay_buffer is not None:
            train_batch = self.local_replay_buffer.sample(
                self.config["train_batch_size"])
        else:
            train_batch = SampleBatch.concat_samples(new_sample_batches)

        # Learn on the training batch.
        # Use simple optimizer (only for multi-agent or tf-eager; all other
        # cases should use the multi-GPU optimizer, even if only using 1 GPU)
        train_results = {}
        if train_batch is not None:
            if self.config.get("simple_optimizer") is True:
                train_results = train_one_step(self, train_batch)
            else:
                train_results = multi_gpu_train_one_step(self, train_batch)

        # TODO: Move training steps counter update outside of `train_one_step()` method.
        # # Update train step counters.
        # self._counters[NUM_ENV_STEPS_TRAINED] += train_batch.env_steps()
        # self._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()

        # Update weights and global_vars - after learning on the local worker - on all
        # remote workers.
        global_vars = {
            "timestep": self._counters[NUM_ENV_STEPS_SAMPLED],
        }
        with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
            self.workers.sync_weights(global_vars=global_vars)

        # Return all collected metrics for the iteration.
        return train_results
Example #28
def get_cross_policy_object(multi_agent_batch, self_optimizer):
    """Adds contents to the cross-policy object, which is passed to each policy."""
    config = self_optimizer.workers._remote_config

    if not config["use_joint_dataset"]:
        joint_obs = SampleBatch.concat_samples(
            list(multi_agent_batch.policy_batches.values()))[
                SampleBatch.CUR_OBS]
    else:
        sample_size = config.get("joint_dataset_sample_batch_size")
        assert sample_size is not None, "You should specify the value of: " \
                                        "joint_dataset_sample_batch_size " \
                                        "in config!"
        samples = [multi_agent_batch]
        count_dict = {
            k: v.count
            for k, v in multi_agent_batch.policy_batches.items()
        }
        for k in self_optimizer.workers.local_worker().policy_map.keys():
            if k not in count_dict:
                count_dict[k] = 0

        while any([v < sample_size for v in count_dict.values()]):
            tmp_batch = self_optimizer.workers.local_worker().sample()
            samples.append(tmp_batch)
            for k, v in tmp_batch.policy_batches.items():
                assert k in count_dict, count_dict
                count_dict[k] += v.count
        multi_agent_batch = MultiAgentBatch.concat_samples(samples)

        joint_obs = []
        pid_list = []
        for pid, batch in multi_agent_batch.policy_batches.items():
            batch.shuffle()
            assert batch.count >= sample_size, batch
            joint_obs.append(batch.slice(0, sample_size)['obs'])
            pid_list.append(pid)
        joint_obs = np.concatenate(joint_obs)

    def _replay(policy, pid):
        act, _, infos = policy.compute_actions(joint_obs)
        return pid, act, infos

    # NOTE: Each policy replays (recomputes actions for) the joint observation itself.
    ret = {
        pid: act
        for pid, act, infos in
        self_optimizer.workers.local_worker().foreach_policy(_replay)
    }
    return {JOINT_OBS: joint_obs, PEER_ACTION: ret}
Example #29
    def __call__(self, batch: SampleBatchType) -> List[SampleBatchType]:
        _check_sample_batch_type(batch)
        self.buffer.append(batch)
        self.count += batch.count
        if self.count >= self.min_batch_size:
            out = SampleBatch.concat_samples(self.buffer)
            timer = LocalIterator.get_metrics().timers[SAMPLE_TIMER]
            timer.push(time.perf_counter() - self.batch_start_time)
            timer.push_units_processed(self.count)
            self.batch_start_time = None
            self.buffer = []
            self.count = 0
            return [out]
        return []
Example #30
def training_workflow(config, reporter):
    # Setup policy and policy evaluation actors
    env = gym.make("CartPole-v0")
    policy = CustomPolicy(env.observation_space, env.action_space, {})
    workers = [
        RolloutWorker.as_remote().remote(
            env_creator=lambda c: gym.make("CartPole-v0"), policy=CustomPolicy)
        for _ in range(config["num_workers"])
    ]

    for _ in range(config["num_iters"]):
        # Broadcast weights to the policy evaluation workers
        weights = ray.put({DEFAULT_POLICY_ID: policy.get_weights()})
        for w in workers:
            w.set_weights.remote(weights)

        # Gather a batch of samples
        T1 = SampleBatch.concat_samples(
            ray.get([w.sample.remote() for w in workers]))

        # Update the remote policy replicas with a new value
        new_value = policy.w * 2.0
        for w in workers:
            w.for_policy.remote(lambda p: p.update_some_value(new_value))

        # Gather another batch of samples
        T2 = SampleBatch.concat_samples(
            ray.get([w.sample.remote() for w in workers]))

        # Improve the policy using the T1 batch
        policy.learn_on_batch(T1)

        # Do some arbitrary updates based on the T2 batch
        policy.update_some_value(sum(T2["rewards"]))

        reporter(**collect_metrics(remote_workers=workers))
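A hedged sketch of how such a workflow function is typically launched, assuming the legacy Tune function API that passes `(config, reporter)`; the config values below are placeholders.

if __name__ == "__main__":
    import ray
    from ray import tune

    ray.init()
    tune.run(
        training_workflow,
        config={"num_workers": 2, "num_iters": 20},
    )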