Example #1
def test_writing():
    config = {'root_dir': '../data/wiki_test'}
    env = Writing(config)

    class TestAgent(object):
        def __init__(self):
            self.msg_queue = [1]

        def push_messages(self, msg):
            self.msg_queue.append(msg)

        def pull_messages(self):
            return self.msg_queue

    ray.init()
    agent = ray.remote(TestAgent).remote()
    agent_id = env.register_agent(agent)

    test_msg = 'Hi there'
    res = env.send_messages(agent_id, test_msg)
    ray_get_and_free(res)
    assert ray.get(agent.pull_messages.remote()).pop() == test_msg

    obs = env.reset()
    print(obs)
    results = env.step('I want to send a message')
    print(env.write)
    print(results)
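Every example in this listing calls the ray_get_and_free helper. For reference, below is a minimal sketch of such a helper, assuming an older Ray release that still exposes ray.internal.free; the constants MAX_FREE_QUEUE_SIZE and FREE_DELAY_S and the exact batching policy are illustrative assumptions, not taken from the examples themselves.

import time

import ray

# Illustrative tuning constants (assumptions, not from the examples).
MAX_FREE_QUEUE_SIZE = 100
FREE_DELAY_S = 10.0

_last_free_time = 0.0
_to_free = []


def ray_get_and_free(object_ids):
    """Fetch object_ids with ray.get, then queue them for deletion.

    Freeing is batched so that object store memory is reclaimed without
    issuing one free() call per fetch.
    """
    global _last_free_time
    global _to_free

    result = ray.get(object_ids)
    if not isinstance(object_ids, list):
        object_ids = [object_ids]
    _to_free.extend(object_ids)

    # Batch calls to free() to reduce overhead.
    now = time.time()
    if (len(_to_free) > MAX_FREE_QUEUE_SIZE
            or now - _last_free_time > FREE_DELAY_S):
        ray.internal.free(_to_free)
        _to_free = []
        _last_free_time = now

    return result

Delaying and batching the free() calls keeps object store pressure low when many small sample batches are fetched per training iteration, which is the dominant pattern in the step() methods below.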
Example #2
        def _do_live_policy_checkpoint(trainer, training_iteration):
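            # Save the train policy's weights to a local checkpoint, upload it
            # to the storage bucket, and register the new key with the live
            # payoff table tracker.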
            local_train_policy = trainer.workers.local_worker(
            ).policy_map[TRAIN_POLICY]
            checkpoints_dir = os.path.join(experiment_save_dir,
                                           "policy_checkpoints")
            checkpoint_name = f"policy_{trainer.claimed_policy_num}_{datetime_str()}_iter_{training_iteration}.dill"
            checkpoint_save_path = os.path.join(checkpoints_dir,
                                                checkpoint_name)
            local_train_policy.save_model_weights(
                save_file_path=checkpoint_save_path,
                remove_scope_prefix=TRAIN_POLICY)
            policy_key = os.path.join(base_experiment_name,
                                      full_experiment_name,
                                      "policy_checkpoints", checkpoint_name)
            storage_client = connect_storage_client()
            upload_file(storage_client=storage_client,
                        bucket_name=BUCKET_NAME,
                        object_key=policy_key,
                        local_source_path=checkpoint_save_path)

            locks_checkpoint_name = f"dch_population_checkpoint_{datetime_str()}"

            ray_get_and_free(
                trainer.live_table_tracker.set_latest_key_for_claimed_policy.
                remote(
                    new_key=policy_key,
                    request_locks_checkpoint_with_name=locks_checkpoint_name))
Example #3
        def claim_new_active_policy_after_trainer_init_callback(trainer):
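            # Set the warmup entropy target on every worker, connect to storage
            # and the manager, claim a policy number, write an initial
            # checkpoint, and fetch the first live payoff table.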
            def set_train_policy_warmup_target_entropy_proportion(worker):
                worker.policy_map[TRAIN_POLICY].set_target_entropy_proportion(
                    PIPELINE_WARMUP_ENTROPY_TARGET_PROPORTION)

            trainer.workers.foreach_worker(
                set_train_policy_warmup_target_entropy_proportion)

            trainer.storage_client = connect_storage_client()

            logger.info("Initializing trainer manager interface")
            trainer.manager_interface = LearnerManagerInterface(
                server_host=MANAGER_SERVER_HOST,
                port=MANAGER_PORT,
                worker_id=full_experiment_name,
                storage_client=trainer.storage_client,
                minio_bucket_name=BUCKET_NAME)

            trainer.live_table_tracker = LivePolicyPayoffTracker.remote(
                minio_endpoint=MINIO_ENDPOINT,
                minio_access_key=MINIO_ACCESS_KEY,
                minio_secret_key=MINIO_SECRET_KEY,
                minio_bucket=BUCKET_NAME,
                manager_host=MANAGER_SERVER_HOST,
                manager_port=MANAGER_PORT,
                lock_server_host=LOCK_SERVER_HOST,
                lock_server_port=LOCK_SERVER_PORT,
                worker_id=full_experiment_name,
                policy_class_name=TRAIN_POLICY_CLASS.__name__,
                policy_config_key=TRAIN_POLICY_MODEL_CONFIG_KEY,
                provide_payoff_barrier_sync=
                not PIPELINE_LIVE_PAYOFF_TABLE_CALC_IS_ASYNCHRONOUS)
            trainer.claimed_policy_num = ray_get_and_free(
                trainer.live_table_tracker.get_claimed_policy_num.remote())
            trainer.are_all_lower_policies_finished = False
            trainer.payoff_table_needs_update_started = False
            trainer.payoff_table = None
            _do_live_policy_checkpoint(trainer=trainer, training_iteration=0)

            if not PIPELINE_LIVE_PAYOFF_TABLE_CALC_IS_ASYNCHRONOUS:
                # wait for all other learners to also reach this point before continuing
                ray_get_and_free(trainer.live_table_tracker.
                                 wait_at_barrier_for_other_learners.remote())

            trainer.new_payoff_table_promise = trainer.live_table_tracker.get_live_payoff_table_dill_pickled.remote(
                first_wait_for_n_seconds=2)
            _process_new_live_payoff_table_result_if_ready(
                trainer=trainer, block_until_result_is_ready=True)

            if INIT_FROM_POPULATION:
                init_train_policy_weights_from_static_policy_distribution_after_trainer_init_callback(
                    trainer=trainer)
            else:
                print(
                    colored(
                        f"Policy {trainer.claimed_policy_num}: (Initializing train policy to random)",
                        "white"))
Example #4
    def _step(self):
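        # Route completed sample batches to the replay actors, sync stale
        # worker weights, feed replayed batches to the learner, and propagate
        # updated priorities back to the replay actors.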
        sample_timesteps, train_timesteps = 0, 0
        weights = None

        with self.timers["sample_processing"]:
            completed = list(self.sample_tasks.completed())
            counts = ray_get_and_free([c[1][1] for c in completed])
            for i, (ev, (sample_batch, count)) in enumerate(completed):
                sample_timesteps += counts[i]

                # Send the data to the replay buffer
                random.choice(
                    self.replay_actors).add_batch.remote(sample_batch)

                # Update weights if needed
                self.steps_since_update[ev] += counts[i]
                if self.steps_since_update[ev] >= self.max_weight_sync_delay:
                    # Note that it's important to pull new weights once
                    # updated to avoid excessive correlation between actors
                    if weights is None or self.learner.weights_updated:
                        self.learner.weights_updated = False
                        with self.timers["put_weights"]:
                            weights = ray.put(
                                self.workers.local_worker().get_weights())
                    ev.set_weights.remote(weights)
                    self.num_weight_syncs += 1
                    self.steps_since_update[ev] = 0

                # Kick off another sample request
                self.sample_tasks.add(ev, ev.sample_with_count.remote())

                # added for dynamic experience replay
                if self.dynamic_experience_replay and random.random() < 0.0001:
                    random.choice(self.replay_actors).update_demos.remote()

        with self.timers["replay_processing"]:
            for ra, replay in self.replay_tasks.completed():
                self.replay_tasks.add(ra, ra.replay.remote())
                if self.learner.inqueue.full():
                    self.num_samples_dropped += 1
                else:
                    with self.timers["get_samples"]:
                        samples = ray_get_and_free(replay)
                    # Defensive copy against plasma crashes, see #2610 #3452
                    self.learner.inqueue.put((ra, samples and samples.copy()))

        with self.timers["update_priorities"]:
            while not self.learner.outqueue.empty():
                ra, prio_dict, count = self.learner.outqueue.get()
                ra.update_priorities.remote(prio_dict)
                train_timesteps += count

        return sample_timesteps, train_timesteps
Example #5
    def _try_recover(self):
        """Try to identify and blacklist any unhealthy workers.

        This method is called after an unexpected remote error is encountered
        from a worker. It issues check requests to all current workers and
        blacklists any that respond with error. If no healthy workers remain,
        an error is raised.
        """

        if (not self._has_policy_optimizer()
                and not hasattr(self, "execution_plan")):
            raise NotImplementedError(
                "Recovery is not supported for this algorithm")
        if self._has_policy_optimizer():
            workers = self.optimizer.workers
        else:
            assert hasattr(self, "execution_plan")
            workers = self.workers

        logger.info("Health checking all workers...")
        checks = []
        for ev in workers.remote_workers():
            _, obj_id = ev.sample_with_count.remote()
            checks.append(obj_id)

        healthy_workers = []
        for i, obj_id in enumerate(checks):
            w = workers.remote_workers()[i]
            try:
                ray_get_and_free(obj_id)
                healthy_workers.append(w)
                logger.info("Worker {} looks healthy".format(i + 1))
            except RayError:
                logger.exception("Blacklisting worker {}".format(i + 1))
                try:
                    w.__ray_terminate__.remote()
                except Exception:
                    logger.exception("Error terminating unhealthy worker")

        if len(healthy_workers) < 1:
            raise RuntimeError(
                "Not enough healthy workers remain to continue.")

        if self._has_policy_optimizer():
            self.optimizer.reset(healthy_workers)
        else:
            assert hasattr(self, "execution_plan")
            logger.warning("Recreating execution plan after failure")
            workers.reset(healthy_workers)
            self.train_exec_impl = self.execution_plan(workers, self.config)
Example #6
def collect_samples(agents, sample_batch_size, num_envs_per_worker,
                    train_batch_size):
    """Collects at least train_batch_size samples, never discarding any."""

    num_timesteps_so_far = 0
    trajectories = []
    agent_dict = {}

    for agent in agents:
        fut_sample = agent.sample.remote()
        agent_dict[fut_sample] = agent

    while agent_dict:
        [fut_sample], _ = ray.wait(list(agent_dict))
        agent = agent_dict.pop(fut_sample)
        next_sample = ray_get_and_free(fut_sample)
        num_timesteps_so_far += next_sample.count
        trajectories.append(next_sample)

        # Only launch more tasks if we don't already have enough pending
        pending = len(agent_dict) * sample_batch_size * num_envs_per_worker
        if num_timesteps_so_far + pending < train_batch_size:
            fut_sample2 = agent.sample.remote()
            agent_dict[fut_sample2] = agent

    return SampleBatch.concat_samples(trajectories)
Example #7
    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            samples = []
            while sum(s.count for s in samples) < self.train_batch_size:
                if self.remote_evaluators:
                    samples.extend(
                        ray_get_and_free([
                            e.sample.remote() for e in self.remote_evaluators
                        ]))
                else:
                    samples.append(self.local_evaluator.sample())
            samples = SampleBatch.concat_samples(samples)
            self.sample_timer.push_units_processed(samples.count)

        with self.grad_timer:
            for i in range(self.num_sgd_iter):
                fetches = self.local_evaluator.learn_on_batch(samples)
                self.learner_stats = get_learner_stats(fetches)
                if self.num_sgd_iter > 1:
                    logger.debug("{} {}".format(i, fetches))
            self.grad_timer.push_units_processed(samples.count)

        self.num_steps_sampled += samples.count
        self.num_steps_trained += samples.count
        return self.learner_stats
Example #8
    def foreach_worker(self, func):
        """Apply the given function to each worker instance."""

        local_result = [func(self.local_worker())]
        remote_results = ray_get_and_free(
            [w.apply.remote(func) for w in self.remote_workers()])
        return local_result + remote_results
Example #9
    def step(self):
        with self.update_weights_timer:
            if self.workers.remote_workers():
                weights = ray.put(self.workers.local_worker().get_weights())
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights)

        with self.sample_timer:
            samples = []
            while sum(s.count for s in samples) < self.train_batch_size:
                if self.workers.remote_workers():
                    samples.extend(
                        ray_get_and_free([
                            e.sample.remote()
                            for e in self.workers.remote_workers()
                        ]))
                else:
                    samples.append(self.workers.local_worker().sample())
            samples = SampleBatch.concat_samples(samples)
            self.sample_timer.push_units_processed(samples.count)

        # Handle everything as if multiagent
        if isinstance(samples, SampleBatch):
            samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples},
                                      samples.count)

        fetches = {}
        with self.grad_timer:
            for policy_id, policy in self.policies.items():
                if policy_id not in samples.policy_batches:
                    continue

                batch = samples.policy_batches[policy_id]
                for field in self.standardize_fields:
                    value = batch[field]
                    standardized = (value - value.mean()) / max(
                        1e-4, value.std())
                    batch[field] = standardized

                for i in range(self.num_sgd_iter):
                    iter_extra_fetches = defaultdict(list)
                    for minibatch in self._minibatches(batch):
                        batch_fetches = (
                            self.workers.local_worker().learn_on_batch(
                                MultiAgentBatch({policy_id: minibatch},
                                                minibatch.count)))[policy_id]
                        for k, v in batch_fetches[LEARNER_STATS_KEY].items():
                            iter_extra_fetches[k].append(v)
                    logger.debug("{} {}".format(i,
                                                _averaged(iter_extra_fetches)))
                fetches[policy_id] = _averaged(iter_extra_fetches)

        self.grad_timer.push_units_processed(samples.count)
        if len(fetches) == 1 and DEFAULT_POLICY_ID in fetches:
            self.learner_stats = fetches[DEFAULT_POLICY_ID]
        else:
            self.learner_stats = fetches
        self.num_steps_sampled += samples.count
        self.num_steps_trained += samples.count
        return self.learner_stats
Example #10
    def step(self):
        with self.update_weights_timer:
            if self.workers.remote_workers():
                weights = ray.put(self.workers.local_worker().get_weights())
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights)

        with self.sample_timer:
            samples = []
            while sum(s.count for s in samples) < self.train_batch_size:
                if self.workers.remote_workers():
                    samples.extend(
                        ray_get_and_free([
                            e.sample.remote()
                            for e in self.workers.remote_workers()
                        ]))
                else:
                    samples.append(self.workers.local_worker().sample())
            samples = SampleBatch.concat_samples(samples)
            self.sample_timer.push_units_processed(samples.count)

        # Unfortunate to have to hack it like this, but not sure how else to do it.
        # Setting the phase to zeros results in policy optimization, and to ones results in aux optimization.
        # These have to be added prior to the policy sgd.
        samples["phase"] = np.zeros(samples.count)

        with self.grad_timer:
            fetches = do_minibatch_sgd(samples, self.policies,
                                       self.workers.local_worker(),
                                       self.num_sgd_iter,
                                       self.sgd_minibatch_size,
                                       self.standardize_fields)
        self.grad_timer.push_units_processed(samples.count)

        if len(fetches) == 1 and DEFAULT_POLICY_ID in fetches:
            self.learner_stats = fetches[DEFAULT_POLICY_ID]
        else:
            self.learner_stats = fetches

        self.num_steps_sampled += samples.count
        self.num_steps_trained += samples.count

        if self.num_steps_sampled > self.aux_loss_start_after_num_steps:
            # Add samples to the memory to be provided to the aux loss.
            self._remove_unnecessary_data(samples)
            self.memory.append(samples)

            # Optionally run the aux optimization.
            if len(self.memory) >= self.aux_loss_every_k:
                samples = SampleBatch.concat_samples(self.memory)
                self._add_policy_logits(samples)
                # Ones indicate aux phase.
                samples["phase"] = np.ones_like(samples["phase"])
                do_minibatch_sgd(samples, self.policies,
                                 self.workers.local_worker(),
                                 self.aux_loss_num_sgd_iter,
                                 self.sgd_minibatch_size, [])
                self.memory = []

        return self.learner_stats
Example #11
    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.remote_evaluators:
                batches = ray_get_and_free(
                    [e.sample.remote() for e in self.remote_evaluators])
            else:
                batches = [self.local_evaluator.sample()]

            # Handle everything as if multiagent
            tmp = []
            for batch in batches:
                if isinstance(batch, SampleBatch):
                    batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch},
                                            batch.count)
                tmp.append(batch)
            batches = tmp

            for batch in batches:
                self.replay_buffer.append(batch)
                self.num_steps_sampled += batch.count
                self.buffer_size += batch.count
                while self.buffer_size > self.max_buffer_size:
                    evicted = self.replay_buffer.pop(0)
                    self.buffer_size -= evicted.count

        if self.num_steps_sampled >= self.replay_starts:
            return self._optimize()
        else:
            return {}
Example #12
def collect_samples_straggler_mitigation(agents, train_batch_size):
    """Collects at least train_batch_size samples.

    This is the legacy behavior as of 0.6, and launches extra sample tasks to
    potentially improve performance but can result in many wasted samples.
    """

    num_timesteps_so_far = 0
    trajectories = []
    agent_dict = {}

    for agent in agents:
        fut_sample = agent.sample.remote()
        agent_dict[fut_sample] = agent

    while num_timesteps_so_far < train_batch_size:
        # TODO(pcm): Make wait support arbitrary iterators and remove the
        # conversion to list here.
        [fut_sample], _ = ray.wait(list(agent_dict))
        agent = agent_dict.pop(fut_sample)
        # Start task with next trajectory and record it in the dictionary.
        fut_sample2 = agent.sample.remote()
        agent_dict[fut_sample2] = agent

        next_sample = ray_get_and_free(fut_sample)
        num_timesteps_so_far += next_sample.count
        trajectories.append(next_sample)

    logger.info("Discarding {} sample tasks".format(len(agent_dict)))
    return SampleBatch.concat_samples(trajectories)
Example #13
    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.remote_evaluators:
                batch = SampleBatch.concat_samples(
                    ray_get_and_free(
                        [e.sample.remote() for e in self.remote_evaluators]))
            else:
                batch = self.local_evaluator.sample()

            # Handle everything as if multiagent
            if isinstance(batch, SampleBatch):
                batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch},
                                        batch.count)

            for policy_id, s in batch.policy_batches.items():
                for row in s.rows():
                    self.replay_buffers[policy_id].add(
                        pack_if_needed(row["obs"]),
                        row["actions"],
                        row["rewards"],
                        pack_if_needed(row["new_obs"]),
                        row["dones"],
                        weight=None)

        if self.num_steps_sampled >= self.replay_starts:
            self._optimize()

        self.num_steps_sampled += batch.count
Example #14
 def stats(self):
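     # Combine stats from the first replay shard with timing, queue, and
     # throughput metrics.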
     replay_stats = ray_get_and_free(self.replay_actors[0].stats.remote(
         self.debug))
     timing = {
         "{}_time_ms".format(k): round(1000 * self.timers[k].mean, 3)
         for k in self.timers
     }
     timing["learner_grad_time_ms"] = round(
         1000 * self.learner.grad_timer.mean, 3)
     timing["learner_dequeue_time_ms"] = round(
         1000 * self.learner.queue_timer.mean, 3)
     stats = {
         "sample_throughput": round(self.timers["sample"].mean_throughput,
                                    3),
         "train_throughput": round(self.timers["train"].mean_throughput, 3),
         "num_weight_syncs": self.num_weight_syncs,
         "num_samples_dropped": self.num_samples_dropped,
         "learner_queue": self.learner.learner_queue_size.stats(),
         "replay_shard_0": replay_stats,
     }
     debug_stats = {
         "timing_breakdown": timing,
         "pending_sample_tasks": self.sample_tasks.count,
         "pending_replay_tasks": self.replay_tasks.count,
     }
     if self.debug:
         stats.update(debug_stats)
     if self.learner.stats:
         stats["learner"] = self.learner.stats
     return dict(PolicyOptimizer.stats(self), **stats)
Example #15
    def foreach_evaluator(self, func):
        """Apply the given function to each evaluator instance."""

        local_result = [func(self.local_evaluator)]
        remote_results = ray_get_and_free(
            [ev.apply.remote(func) for ev in self.remote_evaluators])
        return local_result + remote_results
Example #16
def collect_episodes(local_worker=None,
                     remote_workers=[],
                     timeout_seconds=180):
    """Gathers new episodes metrics tuples from the given evaluators."""

    if remote_workers:
        pending = [
            a.apply.remote(lambda ev: ev.get_metrics()) for a in remote_workers
        ]
        collected, _ = ray.wait(pending,
                                num_returns=len(pending),
                                timeout=timeout_seconds * 1.0)
        num_metric_batches_dropped = len(pending) - len(collected)
        if pending and len(collected) == 0:
            raise ValueError(
                "Timed out waiting for metrics from workers. You can "
                "configure this timeout with `collect_metrics_timeout`.")
        metric_lists = ray_get_and_free(collected)
    else:
        metric_lists = []
        num_metric_batches_dropped = 0

    if local_worker:
        metric_lists.append(local_worker.get_metrics())
    episodes = []
    for metrics in metric_lists:
        episodes.extend(metrics)
    return episodes, num_metric_batches_dropped
Example #17
def collect_episodes(local_worker=None,
                     remote_workers=[],
                     to_be_collected=[],
                     timeout_seconds=180):
    """Gathers new episodes metrics tuples from the given evaluators."""

    if remote_workers:
        pending = [
            a.apply.remote(lambda ev: ev.get_metrics()) for a in remote_workers
        ] + to_be_collected
        collected, to_be_collected = ray.wait(pending,
                                              num_returns=len(pending),
                                              timeout=timeout_seconds * 1.0)
        if pending and len(collected) == 0:
            logger.warning(
                "WARNING: collected no metrics in {} seconds".format(
                    timeout_seconds))
        metric_lists = ray_get_and_free(collected)
    else:
        metric_lists = []

    if local_worker:
        metric_lists.append(local_worker.get_metrics())
    episodes = []
    for metrics in metric_lists:
        episodes.extend(metrics)
    return episodes, to_be_collected
Example #18
 def iter_train_batches(self):
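     # Drain completed aggregator tasks: yield each train batch, push the
     # latest broadcast weights back to the aggregator, and queue its next
     # get_train_batches call.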
     assert self.initialized, "Must call init() before using this class."
     for agg, batches in self.agg_tasks.completed_prefetch():
         for b in ray_get_and_free(batches):
             self.num_sent_since_broadcast += 1
             yield b
         agg.set_weights.remote(self.broadcasted_weights)
         self.agg_tasks.add(agg, agg.get_train_batches.remote())
         self.num_batches_processed += 1
Example #19
        def checkpoint_and_set_static_policy_distribution_on_train_result_callback(
                params):
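            # Periodically checkpoint the train policy and refresh the live
            # payoff table / policy selection probabilities.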
            trainer = params['trainer']
            result = params['result']
            result['psro_policy_num'] = trainer.claimed_policy_num
            evo_update(params)

            if not hasattr(trainer, 'next_refresh_steps'):
                trainer.next_refresh_steps = CHECKPOINT_AND_REFRESH_LIVE_TABLE_EVERY_N_STEPS

            if not hasattr(trainer, 'new_payoff_table_promise'):
                trainer.new_payoff_table_promise = None

            if result['timesteps_total'] >= trainer.next_refresh_steps:
                trainer.next_refresh_steps = max(
                    trainer.next_refresh_steps +
                    CHECKPOINT_AND_REFRESH_LIVE_TABLE_EVERY_N_STEPS,
                    result['timesteps_total'] + 1)
                # do checkpoint
                _do_live_policy_checkpoint(
                    trainer=trainer,
                    training_iteration=result['training_iteration'])

                if not PIPELINE_LIVE_PAYOFF_TABLE_CALC_IS_ASYNCHRONOUS:
                    # wait for all other learners to also reach this point before continuing
                    ray_get_and_free(
                        trainer.live_table_tracker.
                        wait_at_barrier_for_other_learners.remote())

                if not trainer.are_all_lower_policies_finished:
                    trainer.payoff_table_needs_update_started = True

            # refresh payoff table/selection probs
            if trainer.payoff_table_needs_update_started and trainer.new_payoff_table_promise is None:

                trainer.new_payoff_table_promise = trainer.live_table_tracker.get_live_payoff_table_dill_pickled.remote(
                    first_wait_for_n_seconds=2)
                trainer.payoff_table_needs_update_started = False

            if trainer.new_payoff_table_promise is not None:
                _process_new_live_payoff_table_result_if_ready(
                    trainer=trainer,
                    block_until_result_is_ready=
                    not PIPELINE_LIVE_PAYOFF_TABLE_CALC_IS_ASYNCHRONOUS)
Example #20
def main():
    ray.init()
    config = es.DEFAULT_CONFIG.copy()
    config['env_config'] = {'root_dir': './data/wiki_test'}
    config['num_workers'] = args.num_workers
    config['episodes_per_batch'] = args.num_episodes
    config['train_batch_size'] = args.num_rollouts
    env = Writing
    trainer = es.ESTrainer.remote(config, env)

    # Can optionally call trainer.restore(path) to load a checkpoint.

    for i in range(1000):
        # Perform one iteration of training with ES
        result = ray_get_and_free(trainer.train.remote())
        print(pretty_print(result))

        if i % 100 == 0:
            checkpoint = ray_get_and_free(trainer.save.remote())
            print("checkpoint saved at", checkpoint)
Example #21
    def foreach_worker_with_index(self, func):
        """Apply the given function to each worker instance.

        The index will be passed as the second arg to the given function.
        """
        local_result = [func(self.local_worker(), 0)]
        remote_results = ray_get_and_free([
            w.apply.remote(func, i + 1)
            for i, w in enumerate(self.remote_workers())
        ])
        return local_result + remote_results
Example #22
    def next(self):
        """Return the next batch of experiences read.

        Returns:
            SampleBatch or MultiAgentBatch read.
        """
        batches = []
        for dp in self.data_processors:
            batches.append(ray_get_and_free(dp.next.remote()))
        batch = MultiAgentBatch.concat_samples(samples=batches)

        return batch
Example #23
    def foreach_evaluator_with_index(self, func):
        """Apply the given function to each evaluator instance.

        The index will be passed as the second arg to the given function.
        """

        local_result = [func(self.local_evaluator, 0)]
        remote_results = ray_get_and_free([
            ev.apply.remote(func, i + 1)
            for i, ev in enumerate(self.remote_evaluators)
        ])
        return local_result + remote_results
Example #24
    def _try_recover(self):
        """Try to identify and blacklist any unhealthy workers.

        This method is called after an unexpected remote error is encountered
        from a worker. It issues check requests to all current workers and
        blacklists any that respond with error. If no healthy workers remain,
        an error is raised.
        """

        if not self._has_policy_optimizer():
            raise NotImplementedError(
                "Recovery is not supported for this algorithm")

        logger.info("Health checking all workers...")
        checks = []
        for ev in self.optimizer.remote_evaluators:
            _, obj_id = ev.sample_with_count.remote()
            checks.append(obj_id)

        healthy_evaluators = []
        for i, obj_id in enumerate(checks):
            ev = self.optimizer.remote_evaluators[i]
            try:
                ray_get_and_free(obj_id)
                healthy_evaluators.append(ev)
                logger.info("Worker {} looks healthy".format(i + 1))
            except RayError:
                logger.exception("Blacklisting worker {}".format(i + 1))
                try:
                    ev.__ray_terminate__.remote()
                except Exception:
                    logger.exception("Error terminating unhealthy worker")

        if len(healthy_evaluators) < 1:
            raise RuntimeError(
                "Not enough healthy workers remain to continue.")

        self.optimizer.reset(healthy_evaluators)
Example #25
    def step(self):
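        # Broadcast the latest weights, append new sample batches to a bounded
        # replay buffer, and run num_sgd_iter SGD passes once replay_starts
        # samples have been collected.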
        with self.update_weights_timer:
            if self.workers.remote_workers():
                weights = ray.put(self.workers.local_worker().get_weights())
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.workers.remote_workers():
                batches = ray_get_and_free(
                    [e.sample.remote() for e in self.workers.remote_workers()])
            else:
                batches = [self.workers.local_worker().sample()]

            # Handle everything as if multiagent
            tmp = []
            for batch in batches:
                if isinstance(batch, SampleBatch):
                    batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch},
                                            batch.count)
                tmp.append(batch)
            batches = tmp

            for batch in batches:
                if batch.count > self.max_buffer_size:
                    raise ValueError(
                        "The size of a single sample batch exceeds the replay "
                        "buffer size ({} > {})".format(batch.count,
                                                       self.max_buffer_size))
                self.replay_buffer.append(batch)
                self.num_steps_sampled += batch.count
                self.buffer_size += batch.count
                while self.buffer_size > self.max_buffer_size:
                    evicted = self.replay_buffer.pop(0)
                    self.buffer_size -= evicted.count

        if self.num_steps_sampled >= self.replay_starts and self.num_steps_sampled % self.train_every == 0:
            iter_extra_fetches = defaultdict(list)
            with self.grad_timer:
                for i in range(self.num_sgd_iter):
                    batch_fetches = self._sgd_step()
                    for k, v in batch_fetches.items():
                        iter_extra_fetches[k].append(v)
            self.grad_timer.push_units_processed(self.train_batch_size *
                                                 self.num_sgd_iter)
            return averaged(iter_extra_fetches)
        else:
            self.grad_timer = TimerStat()
            self.learner_stats = {}
            return {}
Example #26
    def _augment_with_replay(self, sample_futures):
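        # Yield (worker, sample_batch) pairs for the given futures and, once
        # enough replay batches are buffered, interleave replayed batches
        # (with None as the worker) according to replay_proportion.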
        def can_replay():
            num_needed = int(
                np.ceil(self.train_batch_size / self.sample_batch_size))
            return len(self.replay_batches) > num_needed

        for ev, sample_batch in sample_futures:
            sample_batch = ray_get_and_free(sample_batch)
            yield ev, sample_batch

            if can_replay():
                f = self.replay_proportion
                while random.random() < f:
                    f -= 1
                    replay_batch = random.choice(self.replay_batches)
                    self.num_replayed += replay_batch.count
                    yield None, replay_batch
Example #27
    def foreach_policy(self, func):
        """Apply the given function to each worker's (policy, policy_id) tuple.

        Args:
            func (callable): A function - taking a Policy and its ID - that is
                called on all workers' Policies.

        Returns:
            List[any]: The list of return values of func over all workers'
                policies.
        """
        local_results = self.local_worker().foreach_policy(func)
        remote_results = []
        for worker in self.remote_workers():
            res = ray_get_and_free(
                worker.apply.remote(lambda w: w.foreach_policy(func)))
            remote_results.extend(res)
        return local_results + remote_results
Example #28
    def synchronize(local_filters, remotes, update_remote=True):
        """Aggregates all filters from remote evaluators.

        Local copy is updated and then broadcasted to all remote evaluators.

        Args:
            local_filters (dict): Filters to be synchronized.
            remotes (list): Remote evaluators with filters.
            update_remote (bool): Whether to push updates to remote filters.
        """
        remote_filters = ray_get_and_free(
            [r.get_filters.remote(flush_after=True) for r in remotes])
        for rf in remote_filters:
            for k in local_filters:
                local_filters[k].apply_changes(rf[k], with_buffer=False)
        if update_remote:
            copies = {k: v.as_serializable() for k, v in local_filters.items()}
            remote_copy = ray.put(copies)
            [r.sync_filters.remote(remote_copy) for r in remotes]
Example #29
    def foreach_trainable_policy(self, func):
        """Apply `func` to all workers' Policies iff in `policies_to_train`.

        Args:
            func (callable): A function - taking a Policy and its ID - that is
                called on all workers' Policies in `worker.policies_to_train`.

        Returns:
            List[any]: The list of n return values of all
                `func([trainable policy], [ID])`-calls.
        """
        local_results = self.local_worker().foreach_trainable_policy(func)
        remote_results = []
        for worker in self.remote_workers():
            res = ray_get_and_free(
                worker.apply.remote(
                    lambda w: w.foreach_trainable_policy(func)))
            remote_results.extend(res)
        return local_results + remote_results
Example #30
    def step(self):
        with self.update_weights_timer:
            if self.workers.remote_workers():
                weights = ray.put(self.workers.local_worker().get_weights())
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.workers.remote_workers():
                batch = SampleBatch.concat_samples(
                    ray_get_and_free([
                        e.sample.remote()
                        for e in self.workers.remote_workers()
                    ]))
            else:
                batch = self.workers.local_worker().sample()

            # Handle everything as if multiagent
            if isinstance(batch, SampleBatch):
                batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch},
                                        batch.count)

            for policy_id, s in batch.policy_batches.items():
                for row in s.rows():
                    self.replay_buffers[policy_id].add(
                        pack_if_needed(row["obs"]),
                        row["actions"],
                        row["rewards"],
                        pack_if_needed(row["new_obs"]),
                        row["dones"],
                        None,
                        row["action_logp"],
                        # row["diversity_advantages"],
                        row["diversity_rewards"],
                        # row["diversity_value_targets"],
                        # row["my_logits"],
                        row["prev_actions"],
                        row["prev_rewards"])

        if self.num_steps_sampled >= self.replay_starts:
            self._optimize()

        self.num_steps_sampled += batch.count