Example #1
    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.remote_evaluators:
                batches = ray.get(
                    [e.sample.remote() for e in self.remote_evaluators])
            else:
                batches = [self.local_evaluator.sample()]

            # Handle everything as if multiagent
            tmp = []
            for batch in batches:
                if isinstance(batch, SampleBatch):
                    batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch},
                                            batch.count)
                tmp.append(batch)
            batches = tmp

            for batch in batches:
                self.replay_buffer.append(batch)
                self.num_steps_sampled += batch.count
                self.buffer_size += batch.count
                while self.buffer_size > self.max_buffer_size:
                    evicted = self.replay_buffer.pop(0)
                    self.buffer_size -= evicted.count

        if self.num_steps_sampled >= self.replay_starts:
            self._optimize()
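The recurring "# Handle everything as if multiagent" step wraps a plain SampleBatch into a MultiAgentBatch keyed by DEFAULT_POLICY_ID, so the rest of each optimizer only deals with the multi-agent case. A minimal standalone sketch of that normalization; the import path is the one used by the RLlib versions these snippets come from and may differ in newer releases:

from ray.rllib.evaluation.sample_batch import (DEFAULT_POLICY_ID,
                                               MultiAgentBatch, SampleBatch)


def as_multi_agent(batch):
    # Wrap a single-policy SampleBatch so downstream code only ever sees
    # MultiAgentBatch; an already multi-agent batch passes through unchanged.
    if isinstance(batch, SampleBatch):
        return MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count)
    return batch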
Example #2
    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.remote_evaluators:
                batch = SampleBatch.concat_samples(
                    ray.get(
                        [e.sample.remote() for e in self.remote_evaluators]))
            else:
                batch = self.local_evaluator.sample()

            # Handle everything as if multiagent
            if isinstance(batch, SampleBatch):
                batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch},
                                        batch.count)

            for policy_id, s in batch.policy_batches.items():
                for row in s.rows():
                    self.replay_buffers[policy_id].add(
                        pack_if_needed(row["obs"]),
                        row["actions"],
                        row["rewards"],
                        pack_if_needed(row["new_obs"]),
                        row["dones"],
                        weight=None)

        if self.num_steps_sampled >= self.replay_starts:
            self._optimize()

        self.num_steps_sampled += batch.count
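Example #2 stores each row's observations with pack_if_needed, which compresses numpy arrays before they go into the replay buffer; unpack_if_needed (used when reading records back, see Example #4) reverses it. A rough round-trip sketch, assuming the helpers live in ray.rllib.utils.compression as in the RLlib versions these snippets come from:

import numpy as np

from ray.rllib.utils.compression import pack_if_needed, unpack_if_needed

obs = np.zeros((84, 84, 4), dtype=np.uint8)
# ndarrays are compressed (LZ4 + base64) when compression is available;
# non-array values pass through unchanged.
packed = pack_if_needed(obs)
# unpack_if_needed detects compressed data and restores the original array.
restored = unpack_if_needed(packed)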
Example #3
    def _replay(self):
        samples = {}
        with self.replay_timer:
            for policy_id, replay_buffer in self.replay_buffers.items():
                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    (obses_t, actions, rewards, obses_tp1, dones, weights,
                     batch_indexes) = replay_buffer.sample(
                         self.train_batch_size,
                         beta=self.prioritized_replay_beta.value(
                             self.num_steps_trained))
                else:
                    (obses_t, actions, rewards, obses_tp1,
                     dones) = replay_buffer.sample(self.train_batch_size)
                    weights = np.ones_like(rewards)
                    batch_indexes = -np.ones_like(rewards)
                samples[policy_id] = SampleBatch({
                    "obs": obses_t,
                    "actions": actions,
                    "rewards": rewards,
                    "new_obs": obses_tp1,
                    "dones": dones,
                    "weights": weights,
                    "batch_indexes": batch_indexes
                })
        return MultiAgentBatch(samples, self.train_batch_size)
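In the prioritized branch above, beta is the importance-sampling exponent, read from a schedule evaluated at the current number of trained steps; the uniform branch falls back to unit weights and dummy batch indexes. A minimal sketch of such a linearly annealed schedule (a hypothetical helper for illustration, not RLlib's own implementation):

class LinearSchedule:
    """Linearly anneals a value from initial_p to final_p over schedule_timesteps."""

    def __init__(self, schedule_timesteps, initial_p, final_p):
        self.schedule_timesteps = schedule_timesteps
        self.initial_p = initial_p
        self.final_p = final_p

    def value(self, t):
        fraction = min(float(t) / self.schedule_timesteps, 1.0)
        return self.initial_p + fraction * (self.final_p - self.initial_p)


# e.g. anneal beta from 0.4 towards 1.0 over the first 100k trained steps
prioritized_replay_beta = LinearSchedule(100000, initial_p=0.4, final_p=1.0)
beta = prioritized_replay_beta.value(50000)  # -> 0.7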
Example #4
def _from_json(batch):
    if isinstance(batch, bytes):  # smart_open S3 doesn't respect "r"
        batch = batch.decode("utf-8")
    data = json.loads(batch)

    if "type" in data:
        data_type = data.pop("type")
    else:
        raise ValueError("JSON record missing 'type' field")

    if data_type == "SampleBatch":
        for k, v in data.items():
            data[k] = unpack_if_needed(v)
        return SampleBatch(data)
    elif data_type == "MultiAgentBatch":
        policy_batches = {}
        for policy_id, policy_batch in data["policy_batches"].items():
            inner = {}
            for k, v in policy_batch.items():
                inner[k] = unpack_if_needed(v)
            policy_batches[policy_id] = SampleBatch(inner)
        return MultiAgentBatch(policy_batches, data["count"])
    else:
        raise ValueError(
            "Type field must be one of ['SampleBatch', 'MultiAgentBatch']",
            data_type)
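Reading back from the _from_json helper above: an input record must carry a "type" field; for "SampleBatch" every remaining key is a (possibly packed) column, while "MultiAgentBatch" nests one such column dict per policy id plus a total "count". Two hypothetical records that would satisfy the parser (field values are made up):

import json

sample_batch_record = json.dumps({
    "type": "SampleBatch",
    "obs": [[0.0, 1.0], [0.5, 0.5]],
    "actions": [0, 1],
    "rewards": [1.0, -1.0],
    "dones": [False, True],
})

multi_agent_record = json.dumps({
    "type": "MultiAgentBatch",
    "policy_batches": {
        # "default_policy" is the usual DEFAULT_POLICY_ID
        "default_policy": {"obs": [[0.0, 1.0]], "actions": [0]},
    },
    "count": 1,
})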
Example #5
    def add_batch(self, batch):
        # Handle everything as if multiagent
        if isinstance(batch, SampleBatch):
            batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count)
        self.buffer.append(batch)
        self.cur_size += batch.count
        self.num_added += batch.count
        while self.cur_size > self.buffer_size:
            self.cur_size -= self.buffer.pop(0).count
Example #6
    def add_batch(self, batch):
        # Handle everything as if multiagent
        if isinstance(batch, SampleBatch):
            batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count)
        with self.add_batch_timer:
            for policy_id, s in batch.policy_batches.items():
                for row in s.rows():
                    self.replay_buffers[policy_id].add(
                        row["obs"], row["actions"], row["rewards"],
                        row["new_obs"], row["dones"], row["weights"])
        self.num_added += batch.count
Example #7
    def build_and_reset(self, episode):
        """Returns the accumulated sample batches for each policy.

        Any unprocessed rows will be first postprocessed with a policy
        postprocessor. The internal state of this builder will be reset.

        Arguments:
            episode: current MultiAgentEpisode object or None
        """

        self.postprocess_batch_so_far(episode)
        policy_batches = {}
        for policy_id, builder in self.policy_builders.items():
            if builder.count > 0:
                policy_batches[policy_id] = builder.build_and_reset()
        old_count = self.count
        self.count = 0
        return MultiAgentBatch.wrap_as_needed(policy_batches, old_count)
Example #8
    def replay(self):
        if self.num_added < self.replay_starts:
            return None

        with self.replay_timer:
            samples = {}
            for policy_id, replay_buffer in self.replay_buffers.items():
                (obses_t, actions, rewards, obses_tp1, dones, weights,
                 batch_indexes) = replay_buffer.sample(
                     self.train_batch_size, beta=self.prioritized_replay_beta)
                samples[policy_id] = SampleBatch({
                    "obs": obses_t,
                    "actions": actions,
                    "rewards": rewards,
                    "new_obs": obses_tp1,
                    "dones": dones,
                    "weights": weights,
                    "batch_indexes": batch_indexes
                })
            return MultiAgentBatch(samples, self.train_batch_size)
Example #9
    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.remote_evaluators:
                batches = ray_get_and_free(
                    [e.sample.remote() for e in self.remote_evaluators])
            else:
                batches = [self.local_evaluator.sample()]

            # Handle everything as if multiagent
            tmp = []
            for batch in batches:
                if isinstance(batch, SampleBatch):
                    batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch},
                                            batch.count)
                tmp.append(batch)
            batches = tmp

            for batch in batches:
                if batch.count > self.max_buffer_size:
                    raise ValueError(
                        "The size of a single sample batch exceeds the replay "
                        "buffer size ({} > {})".format(batch.count,
                                                       self.max_buffer_size))
                self.replay_buffer.append(batch)
                self.num_steps_sampled += batch.count
                self.buffer_size += batch.count
                while self.buffer_size > self.max_buffer_size:
                    evicted = self.replay_buffer.pop(0)
                    self.buffer_size -= evicted.count

        if self.num_steps_sampled >= self.replay_starts:
            return self._optimize()
        else:
            return {}
Example #10
    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.remote_evaluators:
                # TODO(rliaw): remove when refactoring
                from ray.rllib.agents.ppo.rollout import collect_samples
                samples = collect_samples(self.remote_evaluators,
                                          self.train_batch_size)
            else:
                samples = self.local_evaluator.sample()
            # Handle everything as if multiagent
            if isinstance(samples, SampleBatch):
                samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples},
                                          samples.count)

        for _, batch in samples.policy_batches.items():
            for field in self.standardize_fields:
                value = batch[field]
                standardized = (value - value.mean()) / max(1e-4, value.std())
                batch[field] = standardized

        for policy_id, policy in self.policies.items():
            # Important: don't shuffle RNN sequence elements
            if (policy_id in samples.policy_batches
                    and not policy._state_inputs):
                samples.policy_batches[policy_id].shuffle()

        num_loaded_tuples = {}
        with self.load_timer:
            for policy_id, batch in samples.policy_batches.items():
                policy = self.policies[policy_id]
                tuples = policy._get_loss_inputs_dict(batch)
                data_keys = [ph for _, ph in policy._loss_inputs]
                if policy._state_inputs:
                    state_keys = policy._state_inputs + [policy._seq_lens]
                else:
                    state_keys = []
                num_loaded_tuples[policy_id] = (
                    self.optimizers[policy_id].load_data(
                        self.sess, [tuples[k] for k in data_keys],
                        [tuples[k] for k in state_keys]))

        fetches = {}
        with self.grad_timer:
            for policy_id, tuples_per_device in num_loaded_tuples.items():
                optimizer = self.optimizers[policy_id]
                num_batches = (int(tuples_per_device) //
                               int(self.per_device_batch_size))
                logger.debug("== sgd epochs for {} ==".format(policy_id))
                for i in range(self.num_sgd_iter):
                    iter_extra_fetches = defaultdict(list)
                    permutation = np.random.permutation(num_batches)
                    for batch_index in range(num_batches):
                        batch_fetches = optimizer.optimize(
                            self.sess, permutation[batch_index] *
                            self.per_device_batch_size)
                        for k, v in batch_fetches.items():
                            iter_extra_fetches[k].append(v)
                    logger.debug("{} {}".format(i,
                                                _averaged(iter_extra_fetches)))
                fetches[policy_id] = _averaged(iter_extra_fetches)

        self.num_steps_sampled += samples.count
        self.num_steps_trained += samples.count
        return fetches
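The standardize_fields loop in the step() methods of Examples #10 and #11 normalizes selected columns (typically the PPO advantages) to zero mean and unit variance, flooring the standard deviation at 1e-4 to avoid division by zero. The same transform as a standalone helper, assuming NumPy:

import numpy as np


def standardize(column, eps=1e-4):
    # Zero-mean, unit-variance normalization with a floor on the std,
    # mirroring the in-place standardization used in step() above.
    column = np.asarray(column, dtype=np.float32)
    return (column - column.mean()) / max(eps, column.std())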
Example #11
    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.remote_evaluators:
                if self.straggler_mitigation:
                    samples = collect_samples_straggler_mitigation(
                        self.remote_evaluators, self.train_batch_size)
                else:
                    samples = collect_samples(
                        self.remote_evaluators, self.sample_batch_size,
                        self.num_envs_per_worker, self.train_batch_size)
                if samples.count > self.train_batch_size * 2:
                    logger.info(
                        "Collected more training samples than expected "
                        "(actual={}, train_batch_size={}). ".format(
                            samples.count, self.train_batch_size) +
                        "This may be because you have many workers or "
                        "long episodes in 'complete_episodes' batch mode.")
            else:
                samples = []
                while sum(s.count for s in samples) < self.train_batch_size:
                    samples.append(self.local_evaluator.sample())
                samples = SampleBatch.concat_samples(samples)

            # Handle everything as if multiagent
            if isinstance(samples, SampleBatch):
                samples = MultiAgentBatch({
                    DEFAULT_POLICY_ID: samples
                }, samples.count)

        for policy_id, policy in self.policies.items():
            if policy_id not in samples.policy_batches:
                continue

            batch = samples.policy_batches[policy_id]
            for field in self.standardize_fields:
                value = batch[field]
                standardized = (value - value.mean()) / max(1e-4, value.std())
                batch[field] = standardized

            # Important: don't shuffle RNN sequence elements
            if not policy._state_inputs:
                batch.shuffle()

        num_loaded_tuples = {}
        with self.load_timer:
            for policy_id, batch in samples.policy_batches.items():
                if policy_id not in self.policies:
                    continue

                policy = self.policies[policy_id]
                tuples = policy._get_loss_inputs_dict(batch)
                data_keys = [ph for _, ph in policy._loss_inputs]
                if policy._state_inputs:
                    state_keys = policy._state_inputs + [policy._seq_lens]
                else:
                    state_keys = []
                num_loaded_tuples[policy_id] = (
                    self.optimizers[policy_id].load_data(
                        self.sess, [tuples[k] for k in data_keys],
                        [tuples[k] for k in state_keys]))

        fetches = {}
        with self.grad_timer:
            for policy_id, tuples_per_device in num_loaded_tuples.items():
                optimizer = self.optimizers[policy_id]
                num_batches = max(
                    1,
                    int(tuples_per_device) // int(self.per_device_batch_size))
                logger.debug("== sgd epochs for {} ==".format(policy_id))
                for i in range(self.num_sgd_iter):
                    iter_extra_fetches = defaultdict(list)
                    permutation = np.random.permutation(num_batches)
                    for batch_index in range(num_batches):
                        batch_fetches = optimizer.optimize(
                            self.sess, permutation[batch_index] *
                            self.per_device_batch_size)
                        for k, v in batch_fetches[LEARNER_STATS_KEY].items():
                            iter_extra_fetches[k].append(v)
                    logger.debug("{} {}".format(i,
                                                _averaged(iter_extra_fetches)))
                fetches[policy_id] = _averaged(iter_extra_fetches)

        self.num_steps_sampled += samples.count
        self.num_steps_trained += tuples_per_device * len(self.devices)
        self.learner_stats = fetches
        return fetches