    def step(self):
        with self.update_weights_timer:
            if self.workers.remote_workers():
                weights = ray.put(self.workers.local_worker().get_weights())
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights)

        with self.sample_timer:
            samples = []
            while sum(s.count for s in samples) < self.train_batch_size:
                if self.workers.remote_workers():
                    samples.extend(
                        ray_get_and_free([
                            e.sample.remote()
                            for e in self.workers.remote_workers()
                        ]))
                else:
                    samples.append(self.workers.local_worker().sample())
            samples = SampleBatch.concat_samples(samples)
            self.sample_timer.push_units_processed(samples.count)

        # Unfortunate to have to hack it like this, but not sure how else to do it.
        # Setting the phase to zeros results in policy optimization, and to ones results in aux optimization.
        # These have to be added prior to the policy sgd.
        samples["phase"] = np.zeros(samples.count)

        with self.grad_timer:
            fetches = do_minibatch_sgd(samples, self.policies,
                                       self.workers.local_worker(),
                                       self.num_sgd_iter,
                                       self.sgd_minibatch_size,
                                       self.standardize_fields)
        self.grad_timer.push_units_processed(samples.count)

        if len(fetches) == 1 and DEFAULT_POLICY_ID in fetches:
            self.learner_stats = fetches[DEFAULT_POLICY_ID]
        else:
            self.learner_stats = fetches

        self.num_steps_sampled += samples.count
        self.num_steps_trained += samples.count

        if self.num_steps_sampled > self.aux_loss_start_after_num_steps:
            # Add samples to the memory to be provided to the aux loss.
            self._remove_unnecessary_data(samples)
            self.memory.append(samples)

            # Optionally run the aux optimization.
            if len(self.memory) >= self.aux_loss_every_k:
                samples = SampleBatch.concat_samples(self.memory)
                self._add_policy_logits(samples)
                # Ones indicate aux phase.
                samples["phase"] = np.ones_like(samples["phase"])
                do_minibatch_sgd(samples, self.policies,
                                 self.workers.local_worker(),
                                 self.aux_loss_num_sgd_iter,
                                 self.sgd_minibatch_size, [])
                self.memory = []

        return self.learner_stats
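The "phase" flag written above only takes effect if the policy's loss function reads it back out of the train batch. Below is a minimal sketch of that idea; it is an assumption rather than part of the source, the helper names combined_loss, ppo_loss and aux_loss are hypothetical, and a real loss would operate on framework tensors instead of plain NumPy arrays.

import numpy as np

def combined_loss(train_batch, ppo_loss, aux_loss):
    # "phase" is all zeros during policy optimization and all ones during the
    # auxiliary phase, so exactly one of the two terms is active per SGD pass.
    phase = float(np.mean(train_batch["phase"]))
    return (1.0 - phase) * ppo_loss(train_batch) + phase * aux_loss(train_batch)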
Example #2
    def sample_and_learn(self, expected_batch_size, num_sgd_iter,
                         sgd_minibatch_size, standardize_fields):
        """Sample and batch and learn on it.

        This is typically used in combination with distributed allreduce.

        Arguments:
            expected_batch_size (int): Expected number of samples to learn on.
            num_sgd_iter (int): Number of SGD iterations.
            sgd_minibatch_size (int): SGD minibatch size.
            standardize_fields (list): List of sample fields to normalize.

        Returns:
            info: dictionary of extra metadata from learn_on_batch().
            count: number of samples learned on.
        """
        batch = self.sample()
        assert batch.count == expected_batch_size, \
            ("Batch size possibly out of sync between workers, expected:",
             expected_batch_size, "got:", batch.count)
        logger.info("Executing distributed minibatch SGD "
                    "with epoch size {}, minibatch size {}".format(
                        batch.count, sgd_minibatch_size))
        info = do_minibatch_sgd(batch, self.policy_map, self, num_sgd_iter,
                                sgd_minibatch_size, standardize_fields)
        return info, batch.count
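For context, here is a hedged sketch of how a caller might drive sample_and_learn across remote workers in the distributed-allreduce setting the docstring mentions. The driver function, the standardize field "advantages", and the workers handle (assumed to be an RLlib WorkerSet of remote RolloutWorker actors) are assumptions, not part of the source.

import ray

def distributed_sgd_round(workers, expected_batch_size, num_sgd_iter,
                          sgd_minibatch_size):
    # One synchronous round: every remote worker samples and learns locally;
    # gradient averaging is assumed to happen inside sample_and_learn via the
    # workers' allreduce setup.
    results = ray.get([
        w.sample_and_learn.remote(expected_batch_size, num_sgd_iter,
                                  sgd_minibatch_size, ["advantages"])
        for w in workers.remote_workers()
    ])
    # Each result is an (info, count) tuple, matching the return value above.
    infos = [info for info, _ in results]
    total_count = sum(count for _, count in results)
    return infos, total_count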
Example #3
    def __call__(self,
                 batch: SampleBatchType) -> (SampleBatchType, List[dict]):
        _check_sample_batch_type(batch)
        metrics = _get_shared_metrics()
        learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
        with learn_timer:
            if self.num_sgd_iter > 1 or self.sgd_minibatch_size > 0:
                lw = self.workers.local_worker()
                info = do_minibatch_sgd(
                    batch, {pid: lw.get_policy(pid)
                            for pid in self.policies}, lw, self.num_sgd_iter,
                    self.sgd_minibatch_size, [])
                # TODO(ekl) shouldn't be returning learner stats directly here
                # TODO(sven): Skips `custom_metrics` key from on_learn_on_batch
                #  callback (shouldn't).
                metrics.info[LEARNER_INFO] = info
            else:
                info = self.workers.local_worker().learn_on_batch(batch)
                metrics.info[LEARNER_INFO] = extract_stats(
                    info, LEARNER_STATS_KEY)
                metrics.info["custom_metrics"] = extract_stats(
                    info, "custom_metrics")
            learn_timer.push_units_processed(batch.count)
        metrics.counters[STEPS_TRAINED_COUNTER] += batch.count
        # Update weights - after learning on the local worker - on all remote
        # workers.
        if self.workers.remote_workers():
            with metrics.timers[WORKER_UPDATE_TIMER]:
                weights = ray.put(self.workers.local_worker().get_weights(
                    self.policies))
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights, _get_global_vars())
        # Also update global vars of the local worker.
        self.workers.local_worker().set_global_vars(_get_global_vars())
        return batch, info
Example #4
    def __call__(self,
                 batch: SampleBatchType) -> (SampleBatchType, List[dict]):
        _check_sample_batch_type(batch)
        metrics = _get_shared_metrics()
        learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
        with learn_timer:
            if self.num_sgd_iter > 1 or self.sgd_minibatch_size > 0:
                w = self.workers.local_worker()
                info = do_minibatch_sgd(
                    batch, {p: w.get_policy(p)
                            for p in self.policies}, w, self.num_sgd_iter,
                    self.sgd_minibatch_size, [])
                # TODO(ekl) shouldn't be returning learner stats directly here
                metrics.info[LEARNER_INFO] = info
            else:
                info = self.workers.local_worker().learn_on_batch(batch)
                metrics.info[LEARNER_INFO] = get_learner_stats(info)
            learn_timer.push_units_processed(batch.count)
        metrics.counters[STEPS_TRAINED_COUNTER] += batch.count
        if self.workers.remote_workers():
            with metrics.timers[WORKER_UPDATE_TIMER]:
                weights = ray.put(self.workers.local_worker().get_weights(
                    self.policies))
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights, _get_global_vars())
        # Also update global vars of the local worker.
        self.workers.local_worker().set_global_vars(_get_global_vars())
        return batch, info
Example #5
    def _sample_and_train_torch_distributed(worker: RolloutWorker):
        # This function is applied remotely on each rollout worker.
        config = worker.policy_config

        # Generate a sample.
        start = time.perf_counter()
        batch = worker.sample()
        sample_time = time.perf_counter() - start
        expected_batch_size = (config["rollout_fragment_length"] *
                               config["num_envs_per_worker"])
        assert batch.count == expected_batch_size, (
            "Batch size possibly out of sync between workers, expected:",
            expected_batch_size,
            "got:",
            batch.count,
        )

        # Perform n minibatch SGD update(s) on the worker itself.
        start = time.perf_counter()
        info = do_minibatch_sgd(
            batch,
            worker.policy_map,
            worker,
            config["num_sgd_iter"],
            config["sgd_minibatch_size"],
            [Postprocessing.ADVANTAGES],
        )
        learn_on_batch_time = time.perf_counter() - start
        return {
            "info": info,
            "env_steps": batch.env_steps(),
            "agent_steps": batch.agent_steps(),
            "sample_time": sample_time,
            "learn_on_batch_time": learn_on_batch_time,
        }
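The comment at the top notes that this function is applied remotely on each rollout worker. One plausible dispatch loop is sketched below; it is an assumption, not from the source, and it relies on RolloutWorker.apply forwarding the worker itself as the function's first argument.

import ray

def run_distributed_round(workers):
    # Run the sample-and-train function on every remote rollout worker in
    # parallel and collect the per-worker metrics dicts it returns.
    results = ray.get([
        w.apply.remote(_sample_and_train_torch_distributed)
        for w in workers.remote_workers()
    ])
    total_env_steps = sum(r["env_steps"] for r in results)
    return results, total_env_steps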
Example #6
def train_one_step(trainer, train_batch, policies_to_train=None) -> Dict:
    config = trainer.config
    workers = trainer.workers
    local_worker = workers.local_worker()
    num_sgd_iter = config.get("num_sgd_iter", 1)
    sgd_minibatch_size = config.get("sgd_minibatch_size", 0)

    learn_timer = trainer._timers[LEARN_ON_BATCH_TIMER]
    with learn_timer:
        # Subsample minibatches (size=`sgd_minibatch_size`) from the
        # train batch and loop through train batch `num_sgd_iter` times.
        if num_sgd_iter > 1 or sgd_minibatch_size > 0:
            info = do_minibatch_sgd(
                train_batch,
                {
                    pid: local_worker.get_policy(pid)
                    for pid in policies_to_train
                    or local_worker.get_policies_to_train(train_batch)
                },
                local_worker,
                num_sgd_iter,
                sgd_minibatch_size,
                [],
            )
        # Single update step using train batch.
        else:
            info = local_worker.learn_on_batch(train_batch)

    learn_timer.push_units_processed(train_batch.count)
    trainer._counters[NUM_ENV_STEPS_TRAINED] += train_batch.count
    trainer._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()

    return info
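The comment inside the timer block above describes what do_minibatch_sgd does with the train batch. As a rough illustration of that pattern (a simplified assumption, not the actual RLlib implementation; minibatch_sgd_sketch and learn_fn are hypothetical names, and train_batch is assumed to be a SampleBatch with shuffle() and slice()):

def minibatch_sgd_sketch(train_batch, learn_fn, num_sgd_iter,
                         sgd_minibatch_size):
    # Loop over the whole train batch `num_sgd_iter` times; on each pass,
    # shuffle it and hand contiguous minibatches of `sgd_minibatch_size`
    # rows to the learner (an abstract `learn_fn` callable here).
    info = {}
    for _ in range(num_sgd_iter):
        train_batch.shuffle()
        for start in range(0, train_batch.count, sgd_minibatch_size):
            minibatch = train_batch.slice(start, start + sgd_minibatch_size)
            info = learn_fn(minibatch)
    # Stats from the most recent minibatch; the real helper also handles
    # multi-agent batches and field standardization.
    return info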
Example #7
    def train_torch_distributed_allreduce(batch):
        expected_batch_size = (
            config["rollout_fragment_length"] * config["num_envs_per_worker"]
        )
        this_worker = get_global_worker()
        assert batch.count == expected_batch_size, (
            "Batch size possibly out of sync between workers, expected:",
            expected_batch_size,
            "got:",
            batch.count,
        )
        logger.info(
            "Executing distributed minibatch SGD "
            "with epoch size {}, minibatch size {}".format(
                batch.count, config["sgd_minibatch_size"]
            )
        )
        info = do_minibatch_sgd(
            batch,
            this_worker.policy_map,
            this_worker,
            config["num_sgd_iter"],
            config["sgd_minibatch_size"],
            ["advantages"],
        )
        return info, batch.count
Example #8
def train_one_step(trainer, train_batch) -> Dict:
    config = trainer.config
    workers = trainer.workers
    local_worker = workers.local_worker()
    policies = local_worker.policies_to_train
    num_sgd_iter = config.get("sgd_num_iter", 1)
    sgd_minibatch_size = config.get("sgd_minibatch_size", 0)

    learn_timer = trainer._timers[LEARN_ON_BATCH_TIMER]
    with learn_timer:
        # Subsample minibatches (size=`sgd_minibatch_size`) from the
        # train batch and loop through train batch `num_sgd_iter` times.
        if num_sgd_iter > 1 or sgd_minibatch_size > 0:
            info = do_minibatch_sgd(
                train_batch,
                {pid: local_worker.get_policy(pid)
                 for pid in policies}, local_worker, num_sgd_iter,
                sgd_minibatch_size, [])
        # Single update step using train batch.
        else:
            info = local_worker.learn_on_batch(train_batch)

    learn_timer.push_units_processed(train_batch.count)
    trainer._counters[NUM_ENV_STEPS_TRAINED] += train_batch.count
    trainer._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()

    # Update weights - after learning on the local worker - on all remote
    # workers.
    if workers.remote_workers():
        with trainer._timers[WORKER_UPDATE_TIMER]:
            weights = ray.put(workers.local_worker().get_weights(policies))
            for e in workers.remote_workers():
                e.set_weights.remote(weights)
    return info
Example #9
def train_one_step(algorithm, train_batch, policies_to_train=None) -> Dict:
    """Function that improves the all policies in `train_batch` on the local worker.

    Examples:
        >>> from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
        >>> algo = [...] # doctest: +SKIP
        >>> train_batch = synchronous_parallel_sample(algo.workers) # doctest: +SKIP
        >>> # This trains the policy on one batch.
        >>> results = train_one_step(algo, train_batch) # doctest: +SKIP
        {"default_policy": ...}

    Updates the NUM_ENV_STEPS_TRAINED and NUM_AGENT_STEPS_TRAINED counters as well as
    the LEARN_ON_BATCH_TIMER timer of the `algorithm` object.
    """

    config = algorithm.config
    workers = algorithm.workers
    local_worker = workers.local_worker()
    num_sgd_iter = config.get("num_sgd_iter", 1)
    sgd_minibatch_size = config.get("sgd_minibatch_size", 0)

    learn_timer = algorithm._timers[LEARN_ON_BATCH_TIMER]
    with learn_timer:
        # Subsample minibatches (size=`sgd_minibatch_size`) from the
        # train batch and loop through train batch `num_sgd_iter` times.
        if num_sgd_iter > 1 or sgd_minibatch_size > 0:
            info = do_minibatch_sgd(
                train_batch,
                {
                    pid: local_worker.get_policy(pid)
                    for pid in policies_to_train
                    or local_worker.get_policies_to_train(train_batch)
                },
                local_worker,
                num_sgd_iter,
                sgd_minibatch_size,
                [],
            )
        # Single update step using train batch.
        else:
            info = local_worker.learn_on_batch(train_batch)

    learn_timer.push_units_processed(train_batch.count)
    algorithm._counters[NUM_ENV_STEPS_TRAINED] += train_batch.count
    algorithm._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()

    if algorithm.reward_estimators:
        info[DEFAULT_POLICY_ID]["off_policy_estimation"] = {}
        for name, estimator in algorithm.reward_estimators.items():
            info[DEFAULT_POLICY_ID]["off_policy_estimation"][
                name] = estimator.train(train_batch)
    return info
Example #10
    def __call__(self,
                 batch: SampleBatchType) -> (SampleBatchType, List[dict]):
        _check_sample_batch_type(batch)
        metrics = _get_shared_metrics()
        learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
        lw = self.local_worker
        with learn_timer:
            # Subsample minibatches (size=`sgd_minibatch_size`) from the
            # train batch and loop through train batch `num_sgd_iter` times.
            if self.num_sgd_iter > 1 or self.sgd_minibatch_size > 0:
                learner_info = do_minibatch_sgd(
                    batch,
                    {
                        pid: lw.get_policy(pid)
                        for pid in self.policies
                        or lw.get_policies_to_train(batch)
                    },
                    lw,
                    self.num_sgd_iter,
                    self.sgd_minibatch_size,
                    [],
                )
            # Single update step using train batch.
            else:
                learner_info = lw.learn_on_batch(batch)

            metrics.info[LEARNER_INFO] = learner_info
            learn_timer.push_units_processed(batch.count)
        metrics.counters[STEPS_TRAINED_COUNTER] += batch.count
        metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = batch.count
        if isinstance(batch, MultiAgentBatch):
            metrics.counters[AGENT_STEPS_TRAINED_COUNTER] += batch.agent_steps(
            )
        # Update weights - after learning on the local worker - on all remote
        # workers.
        if self.workers.remote_workers():
            with metrics.timers[WORKER_UPDATE_TIMER]:
                weights = ray.put(
                    lw.get_weights(self.policies
                                   or lw.get_policies_to_train(batch)))
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights, _get_global_vars())
        # Also update global vars of the local worker.
        lw.set_global_vars(_get_global_vars())
        return batch, learner_info
Example #11
    def train_on_batch(samples):
        if isinstance(samples, SampleBatch):
            samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples},
                                      samples.count)

        worker = get_global_worker()

        if not hasattr(worker, 'num_iterations_trained'):
            worker.num_iterations_trained = 0

        if num_sgd_iter > 1:
            info = do_minibatch_sgd(samples, {
                pid: worker.get_policy(pid)
                for pid in worker.policies_to_train
            }, worker, num_sgd_iter, sgd_minibatch_size, [])
        else:
            info = worker.learn_on_batch(samples)

        worker.num_iterations_trained += 1
        info['num_iterations_trained'] = worker.num_iterations_trained

        return info, samples.count, num_sgd_iter
    def step(self):
        with self.update_weights_timer:
            if self.workers.remote_workers():
                weights = ray.put(self.workers.local_worker().get_weights())
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights)

        with self.sample_timer:
            samples = []
            while sum(s.count for s in samples) < self.train_batch_size:
                if self.workers.remote_workers():
                    samples.extend(
                        ray.get([
                            e.sample.remote()
                            for e in self.workers.remote_workers()
                        ]))
                else:
                    samples.append(self.workers.local_worker().sample())
            samples = SampleBatch.concat_samples(samples)
            self.sample_timer.push_units_processed(samples.count)

        with self.grad_timer:
            fetches = do_minibatch_sgd(samples, self.policies,
                                       self.workers.local_worker(),
                                       self.num_sgd_iter,
                                       self.sgd_minibatch_size,
                                       self.standardize_fields)
        self.grad_timer.push_units_processed(samples.count)

        if len(fetches) == 1 and DEFAULT_POLICY_ID in fetches:
            self.learner_stats = fetches[DEFAULT_POLICY_ID]
        else:
            self.learner_stats = fetches
        self.num_steps_sampled += samples.count
        self.num_steps_trained += samples.count
        return self.learner_stats
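Across the examples above, do_minibatch_sgd is always called with the same six positional arguments. The stub below summarizes that call shape for reference; the parameter roles are inferred from the call sites and docstrings on this page rather than copied from the library source.

def do_minibatch_sgd(samples, policies, local_worker, num_sgd_iter,
                     sgd_minibatch_size, standardize_fields):
    # samples: SampleBatch (or MultiAgentBatch) to train on.
    # policies: dict mapping policy ID to the Policy object to update.
    # local_worker: the RolloutWorker that owns those policies.
    # num_sgd_iter: number of passes over `samples`.
    # sgd_minibatch_size: rows per minibatch within each pass.
    # standardize_fields: sample columns (e.g. "advantages") to normalize.
    # Returns a dict of per-policy learner stats (the "info"/"fetches" above).
    ...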