Example #1
    def __call__(self, item: Tuple[ModelGradients, int]) -> None:
        if not isinstance(item, tuple) or len(item) != 2:
            raise ValueError(
                "Input must be a tuple of (grad_dict, count), got {}".format(
                    item))
        gradients, count = item
        metrics = _get_shared_metrics()
        metrics.counters[STEPS_TRAINED_COUNTER] += count

        apply_timer = metrics.timers[APPLY_GRADS_TIMER]
        with apply_timer:
            self.workers.local_worker().apply_gradients(gradients)
            apply_timer.push_units_processed(count)

        # Also update global vars of the local worker.
        self.workers.local_worker().set_global_vars(_get_global_vars())

        if self.update_all:
            if self.workers.remote_workers():
                with metrics.timers[WORKER_UPDATE_TIMER]:
                    weights = ray.put(self.workers.local_worker().get_weights(
                        self.policies))
                    for e in self.workers.remote_workers():
                        e.set_weights.remote(weights, _get_global_vars())
        else:
            if metrics.current_actor is None:
                raise ValueError(
                    "Could not find actor to update. When "
                    "update_all=False, `current_actor` must be set "
                    "in the iterator context.")
            with metrics.timers[WORKER_UPDATE_TIMER]:
                weights = self.workers.local_worker().get_weights(
                    self.policies)
                metrics.current_actor.set_weights.remote(
                    weights, _get_global_vars())
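
Example #1 (and most of the snippets that follow) relies on the same bookkeeping pattern: wrap the expensive call in a shared timer, record how many units were processed, and bump a global step counter. Below is a minimal, self-contained sketch of that pattern; `SimpleTimer` and `SharedMetrics` are illustrative stand-ins, not the actual objects returned by `_get_shared_metrics()`.

    import time
    from collections import defaultdict

    class SimpleTimer:
        """Context manager that accumulates wall-clock time and a unit count
        (an illustrative stand-in for the timers used in these examples)."""

        def __init__(self):
            self.total_time = 0.0
            self.units_processed = 0

        def __enter__(self):
            self._start = time.perf_counter()
            return self

        def __exit__(self, *exc):
            self.total_time += time.perf_counter() - self._start

        def push_units_processed(self, n):
            self.units_processed += n

    class SharedMetrics:
        """Illustrative stand-in for the shared metrics object: counters,
        timers, and an info dict shared across operators."""

        def __init__(self):
            self.counters = defaultdict(int)
            self.timers = defaultdict(SimpleTimer)
            self.info = {}

    # Hypothetical usage mirroring Example #1:
    metrics = SharedMetrics()
    with metrics.timers["apply_grads"]:
        # ... apply gradients here ...
        metrics.timers["apply_grads"].push_units_processed(32)
    metrics.counters["steps_trained"] += 32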
Example #2
 def __call__(self,
              batch: SampleBatchType) -> (SampleBatchType, List[dict]):
     _check_sample_batch_type(batch)
     metrics = _get_shared_metrics()
     learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
     with learn_timer:
         if self.num_sgd_iter > 1 or self.sgd_minibatch_size > 0:
             lw = self.workers.local_worker()
             info = do_minibatch_sgd(
                 batch, {pid: lw.get_policy(pid)
                         for pid in self.policies}, lw, self.num_sgd_iter,
                 self.sgd_minibatch_size, [])
             # TODO(ekl) shouldn't be returning learner stats directly here
             # TODO(sven): Skips `custom_metrics` key from on_learn_on_batch
             #  callback (shouldn't).
             metrics.info[LEARNER_INFO] = info
         else:
             info = self.workers.local_worker().learn_on_batch(batch)
             metrics.info[LEARNER_INFO] = extract_stats(
                 info, LEARNER_STATS_KEY)
             metrics.info["custom_metrics"] = extract_stats(
                 info, "custom_metrics")
         learn_timer.push_units_processed(batch.count)
     metrics.counters[STEPS_TRAINED_COUNTER] += batch.count
     # Update weights - after learning on the local worker - on all remote
     # workers.
     if self.workers.remote_workers():
         with metrics.timers[WORKER_UPDATE_TIMER]:
             weights = ray.put(self.workers.local_worker().get_weights(
                 self.policies))
             for e in self.workers.remote_workers():
                 e.set_weights.remote(weights, _get_global_vars())
     # Also update global vars of the local worker.
     self.workers.local_worker().set_global_vars(_get_global_vars())
     return batch, info
Example #3
 def __call__(self,
              batch: SampleBatchType) -> (SampleBatchType, List[dict]):
     _check_sample_batch_type(batch)
     metrics = _get_shared_metrics()
     learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
     with learn_timer:
         if self.num_sgd_iter > 1 or self.sgd_minibatch_size > 0:
             w = self.workers.local_worker()
             info = do_minibatch_sgd(
                 batch, {p: w.get_policy(p)
                         for p in self.policies}, w, self.num_sgd_iter,
                 self.sgd_minibatch_size, [])
             # TODO(ekl) shouldn't be returning learner stats directly here
             metrics.info[LEARNER_INFO] = info
         else:
             info = self.workers.local_worker().learn_on_batch(batch)
             metrics.info[LEARNER_INFO] = get_learner_stats(info)
         learn_timer.push_units_processed(batch.count)
     metrics.counters[STEPS_TRAINED_COUNTER] += batch.count
     if self.workers.remote_workers():
         with metrics.timers[WORKER_UPDATE_TIMER]:
             weights = ray.put(self.workers.local_worker().get_weights(
                 self.policies))
             for e in self.workers.remote_workers():
                 e.set_weights.remote(weights, _get_global_vars())
     # Also update global vars of the local worker.
     self.workers.local_worker().set_global_vars(_get_global_vars())
     return batch, info
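
Examples #2, #3, and #8 (below) choose between two training paths: if `num_sgd_iter > 1` or `sgd_minibatch_size > 0`, the train batch is split into shuffled minibatches and iterated over for several epochs via `do_minibatch_sgd`; otherwise a single `learn_on_batch` call is made. The following is a simplified sketch of the minibatch loop, assuming the batch is a flat dict of equally sized NumPy arrays and `learn_fn` is a hypothetical per-minibatch update function; it illustrates the idea and is not `do_minibatch_sgd` itself.

    import numpy as np

    def minibatch_sgd(batch, learn_fn, num_sgd_iter, minibatch_size):
        """Run `learn_fn` over shuffled minibatches for `num_sgd_iter` epochs.

        batch: dict of equally sized NumPy arrays (a stand-in for a SampleBatch).
        learn_fn: callable taking a minibatch dict and returning a stats dict.
        Returns the stats from the last minibatch update.
        """
        num_rows = len(next(iter(batch.values())))
        stats = {}
        for _ in range(num_sgd_iter):
            permutation = np.random.permutation(num_rows)
            for start in range(0, num_rows, minibatch_size):
                idx = permutation[start:start + minibatch_size]
                minibatch = {k: v[idx] for k, v in batch.items()}
                stats = learn_fn(minibatch)
        return stats

    # Hypothetical usage:
    # stats = minibatch_sgd(train_batch, policy_update_fn,
    #                       num_sgd_iter=10, minibatch_size=128)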
Example #4
 def __call__(self, item):
     actor, batch = item
     self.steps_since_broadcast += 1
     if (self.steps_since_broadcast >= self.broadcast_interval
             and self.learner_thread.weights_updated):
         self.weights = ray.put(self.workers.local_worker().get_weights())
         self.steps_since_broadcast = 0
         self.learner_thread.weights_updated = False
         # Update metrics.
         metrics = _get_shared_metrics()
         metrics.counters["num_weight_broadcasts"] += 1
     actor.set_weights.remote(self.weights, _get_global_vars())
     # Also update global vars of the local worker.
     self.workers.local_worker().set_global_vars(_get_global_vars())
Example #5
    def __call__(self, item):
        actor, (info, samples, training_steps) = item

        metrics = _get_shared_metrics()

        metrics.counters[STEPS_TRAINED_COUNTER] += training_steps
        metrics.counters[STEPS_SAMPLED_COUNTER] += samples

        self.counters[actor] += 1

        metrics.counters[
            f"WorkerIteration/Worker{self.worker_idx[actor]}"] += 1

        global_vars = _get_global_vars()
        self.workers.local_worker().set_global_vars(global_vars)
        actor.set_global_vars.remote(global_vars)

        if self.counters[actor] % self.broadcast_interval == 0:
            metrics.counters["num_weight_broadcasts"] += 1

            with metrics.timers[WORKER_UPDATE_TIMER]:
                for pid, gw in self.global_weights.items():

                    def update_worker(w, alpha):
                        return w.policy_map[pid].easgd_update(gw, alpha)

                    diff = ray.get(
                        actor.apply.remote(update_worker, self.alpha))
                    self.global_weights[pid] = EASGDUpdate.easgd_add(
                        gw, diff, self.alpha)

        return info
Example #6
File: apex.py Project: alipay/ray
 def __call__(self, item: Tuple[ActorHandle, SampleBatchType]):
     actor, batch = item
     self.steps_since_update[actor] += batch.count
     if self.steps_since_update[actor] >= self.max_weight_sync_delay:
         # Note that it's important to pull new weights once
         # updated to avoid excessive correlation between actors.
         if self.weights is None or self.learner_thread.weights_updated:
             self.learner_thread.weights_updated = False
             self.weights = ray.put(
                 self.workers.local_worker().get_weights())
         actor.set_weights.remote(self.weights, _get_global_vars())
         # Also update global vars of the local worker.
         self.workers.local_worker().set_global_vars(_get_global_vars())
         self.steps_since_update[actor] = 0
         # Update metrics.
         metrics = _get_shared_metrics()
         metrics.counters["num_weight_syncs"] += 1
Example #7
 def __call__(self, batch: SampleBatchType) -> List[dict]:
     _check_sample_batch_type(batch)
     metrics = LocalIterator.get_metrics()
     learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
     with learn_timer:
         info = self.workers.local_worker().learn_on_batch(batch)
         learn_timer.push_units_processed(batch.count)
     metrics.counters[STEPS_TRAINED_COUNTER] += batch.count
     metrics.info[LEARNER_INFO] = get_learner_stats(info)
     if self.workers.remote_workers():
         with metrics.timers[WORKER_UPDATE_TIMER]:
             weights = ray.put(self.workers.local_worker().get_weights())
             for e in self.workers.remote_workers():
                 e.set_weights.remote(weights, _get_global_vars())
     # Also update global vars of the local worker.
     self.workers.local_worker().set_global_vars(_get_global_vars())
     return info
Example #8
    def __call__(self,
                 batch: SampleBatchType) -> (SampleBatchType, List[dict]):
        _check_sample_batch_type(batch)
        metrics = _get_shared_metrics()
        learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
        lw = self.local_worker
        with learn_timer:
            # Subsample minibatches (size=`sgd_minibatch_size`) from the
            # train batch and loop through train batch `num_sgd_iter` times.
            if self.num_sgd_iter > 1 or self.sgd_minibatch_size > 0:
                learner_info = do_minibatch_sgd(
                    batch,
                    {
                        pid: lw.get_policy(pid)
                        for pid in self.policies
                        or lw.get_policies_to_train(batch)
                    },
                    lw,
                    self.num_sgd_iter,
                    self.sgd_minibatch_size,
                    [],
                )
            # Single update step using train batch.
            else:
                learner_info = lw.learn_on_batch(batch)

            metrics.info[LEARNER_INFO] = learner_info
            learn_timer.push_units_processed(batch.count)
        metrics.counters[STEPS_TRAINED_COUNTER] += batch.count
        metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = batch.count
        if isinstance(batch, MultiAgentBatch):
            metrics.counters[AGENT_STEPS_TRAINED_COUNTER] += batch.agent_steps(
            )
        # Update weights - after learning on the local worker - on all remote
        # workers.
        if self.workers.remote_workers():
            with metrics.timers[WORKER_UPDATE_TIMER]:
                weights = ray.put(
                    lw.get_weights(self.policies
                                   or lw.get_policies_to_train(batch)))
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights, _get_global_vars())
        # Also update global vars of the local worker.
        lw.set_global_vars(_get_global_vars())
        return batch, learner_info
Example #9
    def __call__(self, item):
        actor, (updates, info, samples, training_steps) = item

        if self.counters[actor] > min(self.counters.values()) + 50: return {}

        lw = self.workers.local_worker()

        metrics = _get_shared_metrics()

        metrics.counters[STEPS_TRAINED_COUNTER] += training_steps
        metrics.counters[STEPS_SAMPLED_COUNTER] += samples

        self.counters[actor] += 1
        metrics.counters[
            f"WorkerIteration/Worker{self.worker_idx[actor]}"] += 1

        global_vars = _get_global_vars()
        lw.set_global_vars(global_vars)
        actor.set_global_vars.remote(global_vars)

        with metrics.timers[WORKER_UPDATE_TIMER]:
            for pid, update in updates.items():

                def sync_update(w, update):
                    w.policy_map[pid].asp_sync_updates(update)

                update, num_significant = update

                if lw != actor: sync_update(lw, update)

                if self.workers.remote_workers():
                    for e in self.workers.remote_workers():
                        if e != actor: e.apply.remote(sync_update, update)

                metrics.counters[
                    "significant_weight_updates"] += num_significant

        return info
Example #10
    def __call__(self,
                 samples: SampleBatchType) -> (SampleBatchType, List[dict]):
        _check_sample_batch_type(samples)

        # Handle everything as if multiagent
        if isinstance(samples, SampleBatch):
            samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples},
                                      samples.count)

        metrics = _get_shared_metrics()
        load_timer = metrics.timers[LOAD_BATCH_TIMER]
        learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
        # Load data into GPUs.
        with load_timer:
            num_loaded_tuples = {}
            for policy_id, batch in samples.policy_batches.items():
                # Not a policy-to-train.
                if policy_id not in self.policies:
                    continue

                # Decompress SampleBatch, in case some columns are compressed.
                batch.decompress_if_needed()

                policy = self.workers.local_worker().get_policy(policy_id)
                policy._debug_vars()
                tuples = policy._get_loss_inputs_dict(
                    batch, shuffle=self.shuffle_sequences)
                data_keys = list(policy._loss_input_dict_no_rnn.values())
                if policy._state_inputs:
                    state_keys = policy._state_inputs + [policy._seq_lens]
                else:
                    state_keys = []
                num_loaded_tuples[policy_id] = (
                    self.optimizers[policy_id].load_data(
                        self.sess, [tuples[k] for k in data_keys],
                        [tuples[k] for k in state_keys]))

        # Execute minibatch SGD on loaded data.
        with learn_timer:
            fetches = {}
            for policy_id, tuples_per_device in num_loaded_tuples.items():
                optimizer = self.optimizers[policy_id]
                num_batches = max(
                    1,
                    int(tuples_per_device) // int(self.per_device_batch_size))
                logger.debug("== sgd epochs for {} ==".format(policy_id))
                for _ in range(self.num_sgd_iter):
                    permutation = np.random.permutation(num_batches)
                    batch_fetches_all_towers = []
                    for batch_index in range(num_batches):
                        batch_fetches = optimizer.optimize(
                            self.sess, permutation[batch_index] *
                            self.per_device_batch_size)

                        batch_fetches_all_towers.append(
                            tree.map_structure_with_path(
                                lambda p, *s: self._all_tower_reduce(p, *s),
                                *(batch_fetches["tower_{}".format(tower_num)]
                                  for tower_num in range(len(self.devices)))))

                # Reduce mean across all minibatch SGD steps (axis=0 to keep
                # all shapes as-is).
                fetches[policy_id] = tree.map_structure(
                    lambda *s: np.nanmean(s, axis=0),
                    *batch_fetches_all_towers)

        load_timer.push_units_processed(samples.count)
        learn_timer.push_units_processed(samples.count)

        metrics.counters[STEPS_TRAINED_COUNTER] += samples.count
        metrics.info[LEARNER_INFO] = fetches
        if self.workers.remote_workers():
            with metrics.timers[WORKER_UPDATE_TIMER]:
                weights = ray.put(self.workers.local_worker().get_weights(
                    self.policies))
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights, _get_global_vars())
        # Also update global vars of the local worker.
        self.workers.local_worker().set_global_vars(_get_global_vars())
        return samples, fetches
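
Examples #10, #11, and #14 (below) finish by averaging per-tower, per-minibatch stats with `np.nanmean` over axis 0, so NaNs reported by individual towers are ignored. Here is a simplified version of that reduction for flat stat dicts of scalars; the real code uses `tree.map_structure`, which also handles arbitrarily nested results.

    import numpy as np

    def average_stats(per_step_stats):
        """Average a list of flat {stat_name: value} dicts, ignoring NaNs.
        Simplified stand-in for the tree.map_structure reduction above."""
        keys = per_step_stats[0].keys()
        return {k: float(np.nanmean([s[k] for s in per_step_stats])) for k in keys}

    # Example: stats from two towers / SGD steps.
    stats = [
        {"policy_loss": 0.50, "vf_loss": 1.5},
        {"policy_loss": 0.25, "vf_loss": float("nan")},
    ]
    print(average_stats(stats))  # {'policy_loss': 0.375, 'vf_loss': 1.5}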
Example #11
    def __call__(self,
                 samples: SampleBatchType) -> (SampleBatchType, List[dict]):
        _check_sample_batch_type(samples)

        # Handle everything as if multiagent
        if isinstance(samples, SampleBatch):
            samples = MultiAgentBatch({
                DEFAULT_POLICY_ID: samples
            }, samples.count)

        metrics = _get_shared_metrics()
        load_timer = metrics.timers[LOAD_BATCH_TIMER]
        learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
        with load_timer:
            # (1) Load data into GPUs.
            num_loaded_tuples = {}
            for policy_id, batch in samples.policy_batches.items():
                if policy_id not in self.policies:
                    continue

                policy = self.workers.local_worker().get_policy(policy_id)
                policy._debug_vars()
                tuples = policy._get_loss_inputs_dict(
                    batch, shuffle=self.shuffle_sequences)
                data_keys = list(policy._loss_input_dict_no_rnn.values())
                if policy._state_inputs:
                    state_keys = policy._state_inputs + [policy._seq_lens]
                else:
                    state_keys = []
                num_loaded_tuples[policy_id] = (
                    self.optimizers[policy_id].load_data(
                        self.sess, [tuples[k] for k in data_keys],
                        [tuples[k] for k in state_keys]))

        with learn_timer:
            # (2) Execute minibatch SGD on loaded data.
            fetches = {}
            for policy_id, tuples_per_device in num_loaded_tuples.items():
                optimizer = self.optimizers[policy_id]
                num_batches = max(
                    1,
                    int(tuples_per_device) // int(self.per_device_batch_size))
                logger.debug("== sgd epochs for {} ==".format(policy_id))
                for i in range(self.num_sgd_iter):
                    iter_extra_fetches = defaultdict(list)
                    permutation = np.random.permutation(num_batches)
                    for batch_index in range(num_batches):
                        batch_fetches = optimizer.optimize(
                            self.sess, permutation[batch_index] *
                            self.per_device_batch_size)
                        for k, v in batch_fetches[LEARNER_STATS_KEY].items():
                            iter_extra_fetches[k].append(v)
                    if logger.getEffectiveLevel() <= logging.DEBUG:
                        avg = averaged(iter_extra_fetches)
                        logger.debug("{} {}".format(i, avg))
                fetches[policy_id] = averaged(iter_extra_fetches, axis=0)

        load_timer.push_units_processed(samples.count)
        learn_timer.push_units_processed(samples.count)

        metrics.counters[STEPS_TRAINED_COUNTER] += samples.count
        metrics.info[LEARNER_INFO] = fetches
        if self.workers.remote_workers():
            with metrics.timers[WORKER_UPDATE_TIMER]:
                weights = ray.put(self.workers.local_worker().get_weights(
                    self.policies))
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights, _get_global_vars())
        # Also update global vars of the local worker.
        self.workers.local_worker().set_global_vars(_get_global_vars())
        return samples, fetches
Example #12
    def __call__(self,
                 samples: SampleBatchType) -> (SampleBatchType, List[dict]):
        _check_sample_batch_type(samples)

        # Handle everything as if multi agent.
        samples = samples.as_multi_agent()

        metrics = _get_shared_metrics()
        load_timer = metrics.timers[LOAD_BATCH_TIMER]
        learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
        # Load data into GPUs.
        with load_timer:
            num_loaded_samples = {}
            for policy_id, batch in samples.policy_batches.items():
                # Not a policy-to-train.
                if not self.local_worker.is_policy_to_train(
                        policy_id, samples):
                    continue

                # Decompress SampleBatch, in case some columns are compressed.
                batch.decompress_if_needed()

                # Load the entire train batch into the Policy's only buffer
                # (idx=0). Policies only have more than one buffer if we are
                # training asynchronously.
                num_loaded_samples[policy_id] = self.local_worker.policy_map[
                    policy_id].load_batch_into_buffer(batch, buffer_index=0)

        # Execute minibatch SGD on loaded data.
        with learn_timer:
            # Use LearnerInfoBuilder as a unified way to build the final
            # results dict from `learn_on_loaded_batch` call(s).
            # This makes sure results dicts always have the same structure
            # no matter the setup (multi-GPU, multi-agent, minibatch SGD,
            # tf vs torch).
            learner_info_builder = LearnerInfoBuilder(
                num_devices=len(self.devices))

            for policy_id, samples_per_device in num_loaded_samples.items():
                policy = self.local_worker.policy_map[policy_id]
                num_batches = max(
                    1,
                    int(samples_per_device) // int(self.per_device_batch_size))
                logger.debug("== sgd epochs for {} ==".format(policy_id))
                for _ in range(self.num_sgd_iter):
                    permutation = np.random.permutation(num_batches)
                    for batch_index in range(num_batches):
                        # Learn on the pre-loaded data in the buffer.
                        # Note: For minibatch SGD, the data is an offset into
                        # the pre-loaded entire train batch.
                        results = policy.learn_on_loaded_batch(
                            permutation[batch_index] *
                            self.per_device_batch_size,
                            buffer_index=0,
                        )

                        learner_info_builder.add_learn_on_batch_results(
                            results, policy_id)

            # Tower reduce and finalize results.
            learner_info = learner_info_builder.finalize()

        load_timer.push_units_processed(samples.count)
        learn_timer.push_units_processed(samples.count)

        metrics.counters[STEPS_TRAINED_COUNTER] += samples.count
        metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = samples.count
        metrics.counters[AGENT_STEPS_TRAINED_COUNTER] += samples.agent_steps()
        metrics.info[LEARNER_INFO] = learner_info

        if self.workers.remote_workers():
            with metrics.timers[WORKER_UPDATE_TIMER]:
                weights = ray.put(self.workers.local_worker().get_weights(
                    self.local_worker.get_policies_to_train()))
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights, _get_global_vars())

        # Also update global vars of the local worker.
        self.workers.local_worker().set_global_vars(_get_global_vars())
        return samples, learner_info
Example #13
 def update_worker_global_vars(item):
     global_vars = _get_global_vars()
     for w in workers.remote_workers():
         w.set_global_vars.remote(global_vars)
     return item
Example #14
    def __call__(self,
                 samples: SampleBatchType) -> (SampleBatchType, List[dict]):
        _check_sample_batch_type(samples)

        # Handle everything as if multi agent.
        if isinstance(samples, SampleBatch):
            samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples},
                                      samples.count)

        metrics = _get_shared_metrics()
        load_timer = metrics.timers[LOAD_BATCH_TIMER]
        learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
        # Load data into GPUs.
        with load_timer:
            num_loaded_tuples = {}
            for policy_id, batch in samples.policy_batches.items():
                # Not a policy-to-train.
                if policy_id not in self.local_worker.policies_to_train:
                    continue

                # Decompress SampleBatch, in case some columns are compressed.
                batch.decompress_if_needed()

                # Load the entire train batch into the Policy's only buffer
                # (idx=0). Policies only have more than one buffer if we are
                # training asynchronously.
                num_loaded_tuples[policy_id] = self.local_worker.policy_map[
                    policy_id].load_batch_into_buffer(batch, buffer_index=0)

        # Execute minibatch SGD on loaded data.
        with learn_timer:
            fetches = {}
            for policy_id, tuples_per_device in num_loaded_tuples.items():
                policy = self.local_worker.policy_map[policy_id]
                num_batches = max(
                    1,
                    int(tuples_per_device) // int(self.per_device_batch_size))
                logger.debug("== sgd epochs for {} ==".format(policy_id))
                batch_fetches_all_towers = []
                for _ in range(self.num_sgd_iter):
                    permutation = np.random.permutation(num_batches)
                    for batch_index in range(num_batches):
                        # Learn on the pre-loaded data in the buffer.
                        # Note: For minibatch SGD, the data is an offset into
                        # the pre-loaded entire train batch.
                        batch_fetches = policy.learn_on_loaded_batch(
                            permutation[batch_index] *
                            self.per_device_batch_size,
                            buffer_index=0)

                        # No towers: Single CPU.
                        if "tower_0" not in batch_fetches:
                            batch_fetches_all_towers.append(batch_fetches)
                        else:
                            batch_fetches_all_towers.append(
                                tree.map_structure_with_path(
                                    lambda p, *s: all_tower_reduce(p, *s),
                                    *(batch_fetches["tower_{}".format(
                                        tower_num)]
                                      for tower_num in range(len(self.devices))
                                      )))

                # Reduce mean across all minibatch SGD steps (axis=0 to keep
                # all shapes as-is).
                fetches[policy_id] = tree.map_structure(
                    lambda *s: np.nanmean(s, axis=0),
                    *batch_fetches_all_towers)

        load_timer.push_units_processed(samples.count)
        learn_timer.push_units_processed(samples.count)

        metrics.counters[STEPS_TRAINED_COUNTER] += samples.count
        metrics.counters[AGENT_STEPS_TRAINED_COUNTER] += samples.agent_steps()
        metrics.info[LEARNER_INFO] = fetches

        if self.workers.remote_workers():
            with metrics.timers[WORKER_UPDATE_TIMER]:
                weights = ray.put(self.workers.local_worker().get_weights(
                    self.local_worker.policies_to_train))
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights, _get_global_vars())

        # Also update global vars of the local worker.
        self.workers.local_worker().set_global_vars(_get_global_vars())
        return samples, fetches
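
All of these snippets are operator objects whose `__call__` is mapped over items flowing out of a rollout iterator; the surrounding execution-plan code is not shown here, but the wiring amounts to a `for_each`-style mapping over a stream of batches (or `(actor, batch)` tuples). A plain-Python generator sketches the shape of that contract; the names in the usage comment are illustrative.

    def for_each(iterator, op):
        """Lazily apply a callable operator to every item of an iterator --
        a plain-Python stand-in for the execution-plan `.for_each(...)` step
        these __call__ methods are written for."""
        for item in iterator:
            yield op(item)

    # Hypothetical usage, where `train_op` is an instance of one of the
    # operators above (e.g. Example #2) and `rollout_batches` yields
    # sample batches:
    # for batch, info in for_each(rollout_batches, train_op):
    #     ...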