Example #1
    def __init__(self, args, config):

        DefaultCallbacks.__init__(self)

        self.args = args
        self.config = config

        self.outdir = args.writer_dir
        self.model = args.model_name
        self.framework = config['framework']
        self.no_frame = args.no_frame
        # self.epochs_trained = self._extract_train_epochs()
        self.epochs_trained = args.epochs_trained
        self.use_hickle = args.use_hickle
        self.write_compressed = _write_hkl if self.use_hickle else _write_pkl
        self.read_compressed = _read_hkl if self.use_hickle else _read_pkl
        self.ext = "hkl" if self.use_hickle else "pkl"

        # This needs to be -1 because of weird issues with
        # the parallel processing of rollouts
        self.episode_num = -1
        self.step_num = 0
        self.reward_total = 0
        self.last_done = False
        self.episode = {}
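
The compression helpers referenced above (`_write_pkl`, `_read_pkl`, `_write_hkl`, `_read_hkl`) are not shown in this snippet. A minimal sketch of what they might look like, assuming plain pickle and hickle as the serialization backends (signatures are illustrative, not the original implementations):

import pickle

import hickle


def _write_pkl(data, file_path):
    # Serialize `data` to `file_path` with pickle.
    with open(file_path, "wb") as f:
        pickle.dump(data, f)


def _read_pkl(file_path):
    # Load a pickled object from `file_path`.
    with open(file_path, "rb") as f:
        return pickle.load(f)


def _write_hkl(data, file_path):
    # Serialize `data` to an HDF5-backed hickle file.
    hickle.dump(data, file_path)


def _read_hkl(file_path):
    # Load an object from an HDF5-backed hickle file.
    return hickle.load(file_path)

Example #2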
def generate_fake_multiagent_batch(env, policies):
    multi_agent_batch_builder = MultiAgentSampleBatchBuilder(
        policy_map={
            player: policy for player, policy in zip(env.players_ids, policies)
        },
        clip_rewards=False,
        callbacks=DefaultCallbacks(),
    )
    fake_actions = generate_fake_discrete_actions(env)
    env.reset()
    observations, rewards, done, info = env.step(fake_actions)
    for player_id in env.players_ids:
        step_player_values = {
            "eps_id": 0,
            "obs": observations[player_id],
            "new_obs": observations[player_id],
            "actions": fake_actions[player_id],
            "prev_actions": fake_actions[player_id],
            "rewards": rewards[player_id],
            "prev_rewards": rewards[player_id],
            "dones": True,
        }
        multi_agent_batch_builder.add_values(
            agent_id=player_id, policy_id=player_id, **step_player_values
        )
    multiagent_batch = multi_agent_batch_builder.build_and_reset()
    return multiagent_batch
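
A minimal usage sketch for the helper above, assuming `env` is an already-constructed two-player environment exposing `players_ids` and `policies` is a matching list of RLlib policies (both names are placeholders):

# Build a one-step fake MultiAgentBatch and inspect its per-policy sub-batches.
fake_batch = generate_fake_multiagent_batch(env, policies)
for player_id, sample_batch in fake_batch.policy_batches.items():
    print(player_id, sample_batch["rewards"])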
Example #3
def get_fake_training_batch_for_ppo_in_ipd(policy):
    policy_id = "fake_player"
    players = {policy_id: policy}

    multi_agent_batch_builder = MultiAgentSampleBatchBuilder(
        policy_map={
            player_id: player
            for player_id, player in players.items()
        },
        clip_rewards=False,
        callbacks=DefaultCallbacks(),
    )

    n_steps_in_epi = 20
    for step_n in range(n_steps_in_epi):
        step_player_values = {
            SampleBatch.EPS_ID: 0,
            SampleBatch.OBS: 0,
            SampleBatch.NEXT_OBS: 0,
            SampleBatch.ACTIONS: 0,
            SampleBatch.REWARDS: random.randint(0, 10),
            SampleBatch.PREV_REWARDS: random.randint(0, 10),
            SampleBatch.VF_PREDS: random.randint(0, 10),
            SampleBatch.DONES: step_n == n_steps_in_epi - 1,
            SampleBatch.ACTION_DIST_INPUTS: [random.random(),
                                             random.random()],
            SampleBatch.ACTION_LOGP: random.random(),
        }
        multi_agent_batch_builder.add_values(agent_id=policy_id,
                                             policy_id=policy_id,
                                             **step_player_values)

    multiagent_batch = multi_agent_batch_builder.build_and_reset()
    return multiagent_batch.policy_batches[policy_id]
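
A minimal sketch of how such a fake batch might be used, e.g. to smoke-test a single update step of an already-built RLlib PPO policy (`ppo_policy` is a placeholder name):

# One fake 20-step episode, fed through a single learn_on_batch() update.
train_batch = get_fake_training_batch_for_ppo_in_ipd(ppo_policy)
learn_stats = ppo_policy.learn_on_batch(train_batch)
print(learn_stats)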
Example #4
File: policy.py  Project: alipay/ray
    def __init__(
        self,
        observation_space: gym.Space,
        action_space: gym.Space,
        config: TrainerConfigDict,
    ):
        """Initializes a Policy instance.

        Args:
            observation_space: Observation space of the policy.
            action_space: Action space of the policy.
            config: A complete Trainer/Policy config dict. For the default
                config keys and values, see rllib/trainer/trainer.py.
        """
        self.observation_space: gym.Space = observation_space
        self.action_space: gym.Space = action_space
        # The base struct of the observation/action spaces.
        # E.g. action-space = gym.spaces.Dict({"a": Discrete(2)}) ->
        # action_space_struct = {"a": Discrete(2)}
        self.observation_space_struct = get_base_struct_from_space(
            observation_space)
        self.action_space_struct = get_base_struct_from_space(action_space)

        self.config: TrainerConfigDict = config
        self.framework = self.config.get("framework")
        # Create the callbacks object to use for handling custom callbacks.
        if self.config.get("callbacks"):
            self.callbacks: "DefaultCallbacks" = self.config.get("callbacks")()
        else:
            from ray.rllib.agents.callbacks import DefaultCallbacks

            self.callbacks: "DefaultCallbacks" = DefaultCallbacks()

        # The global timestep, broadcast down from time to time from the
        # local worker to all remote workers.
        self.global_timestep: int = 0

        # The action distribution class to use for action sampling, if any.
        # Child classes may set this.
        self.dist_class: Optional[Type] = None

        # Maximal view requirements dict for `learn_on_batch()` and
        # `compute_actions` calls.
        # View requirements will be automatically filtered out later based
        # on the postprocessing and loss functions to ensure optimal data
        # collection and transfer performance.
        view_reqs = self._get_default_view_requirements()
        if not hasattr(self, "view_requirements"):
            self.view_requirements = view_reqs
        else:
            for k, v in view_reqs.items():
                if k not in self.view_requirements:
                    self.view_requirements[k] = v
        # Whether the Model's initial state (method) has been added
        # automatically based on the given view requirements of the model.
        self._model_init_state_automatically_added = False
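
The merge above means a subclass may pre-populate `self.view_requirements` before calling this constructor; only the missing default entries are then filled in. A minimal sketch of that pattern (hypothetical subclass, remaining abstract methods omitted):

from ray.rllib.policy.view_requirement import ViewRequirement


class MyCustomPolicy(Policy):
    def __init__(self, observation_space, action_space, config):
        # Entries set here survive the merge below; defaults are only added
        # for keys that are not yet present.
        self.view_requirements = {
            "my_custom_field": ViewRequirement(),
        }
        super().__init__(observation_space, action_space, config)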
Example #5
File: policy.py  Project: smorad/ray
    def __init__(
        self,
        observation_space: gym.Space,
        action_space: gym.Space,
        config: TrainerConfigDict,
    ):
        """Initializes a Policy instance.

        Args:
            observation_space: Observation space of the policy.
            action_space: Action space of the policy.
            config: A complete Trainer/Policy config dict. For the default
                config keys and values, see rllib/trainer/trainer.py.
        """
        self.observation_space: gym.Space = observation_space
        self.action_space: gym.Space = action_space
        # The base struct of the observation/action spaces.
        # E.g. action-space = gym.spaces.Dict({"a": Discrete(2)}) ->
        # action_space_struct = {"a": Discrete(2)}
        self.observation_space_struct = get_base_struct_from_space(
            observation_space)
        self.action_space_struct = get_base_struct_from_space(action_space)

        self.config: TrainerConfigDict = config
        self.framework = self.config.get("framework")
        # Create the callbacks object to use for handling custom callbacks.
        if self.config.get("callbacks"):
            self.callbacks: "DefaultCallbacks" = self.config.get("callbacks")()
        else:
            from ray.rllib.agents.callbacks import DefaultCallbacks

            self.callbacks: "DefaultCallbacks" = DefaultCallbacks()

        # The global timestep, broadcast down from time to time from the
        # local worker to all remote workers.
        self.global_timestep: int = 0

        # The action distribution class to use for action sampling, if any.
        # Child classes may set this.
        self.dist_class: Optional[Type] = None

        # Initialize view requirements.
        self.init_view_requirements()

        # Whether the Model's initial state (method) has been added
        # automatically based on the given view requirements of the model.
        self._model_init_state_automatically_added = False
Example #6
File: policy.py  Project: yncxcw/ray
    def __init__(self, observation_space: gym.spaces.Space,
                 action_space: gym.spaces.Space, config: TrainerConfigDict):
        """Initialize the graph.

        This is the standard constructor for policies. The policy
        class you pass into RolloutWorker will be constructed with
        these arguments.

        Args:
            observation_space (gym.spaces.Space): Observation space of the
                policy.
            action_space (gym.spaces.Space): Action space of the policy.
            config (TrainerConfigDict): Policy-specific configuration data.
        """
        self.observation_space = observation_space
        self.action_space = action_space
        self.action_space_struct = get_base_struct_from_space(action_space)
        self.config = config
        if self.config.get("callbacks"):
            self.callbacks: "DefaultCallbacks" = self.config.get("callbacks")()
        else:
            from ray.rllib.agents.callbacks import DefaultCallbacks
            self.callbacks: "DefaultCallbacks" = DefaultCallbacks()
        # The global timestep, broadcast down from time to time from the
        # driver.
        self.global_timestep = 0
        # The action distribution class to use for action sampling, if any.
        # Child classes may set this.
        self.dist_class = None
        # Maximal view requirements dict for `learn_on_batch()` and
        # `compute_actions` calls.
        # View requirements will be automatically filtered out later based
        # on the postprocessing and loss functions to ensure optimal data
        # collection and transfer performance.
        view_reqs = self._get_default_view_requirements()
        if not hasattr(self, "view_requirements"):
            self.view_requirements = view_reqs
        else:
            for k, v in view_reqs.items():
                if k not in self.view_requirements:
                    self.view_requirements[k] = v
        self._model_init_state_automatically_added = False
Example #7
class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker):
    """Common experience collection class.

    This class wraps a policy instance and an environment class to
    collect experiences from the environment. You can create many replicas of
    this class as Ray actors to scale RL training.

    This class supports vectorized and multi-agent policy evaluation (e.g.,
    VectorEnv, MultiAgentEnv, etc.)

    Examples:
        >>> # Create a rollout worker and use it to collect experiences.
        >>> worker = RolloutWorker(
        ...   env_creator=lambda _: gym.make("CartPole-v0"),
        ...   policy=PGTFPolicy)
        >>> print(worker.sample())
        SampleBatch({
            "obs": [[...]], "actions": [[...]], "rewards": [[...]],
            "dones": [[...]], "new_obs": [[...]]})

        >>> # Creating a multi-agent rollout worker
        >>> worker = RolloutWorker(
        ...   env_creator=lambda _: MultiAgentTrafficGrid(num_cars=25),
        ...   policies={
        ...       # Use an ensemble of two policies for car agents
        ...       "car_policy1":
        ...         (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.99}),
        ...       "car_policy2":
        ...         (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.95}),
        ...       # Use a single shared policy for all traffic lights
        ...       "traffic_light_policy":
        ...         (PGTFPolicy, Box(...), Discrete(...), {}),
        ...   },
        ...   policy_mapping_fn=lambda agent_id:
        ...     random.choice(["car_policy1", "car_policy2"])
        ...     if agent_id.startswith("car_") else "traffic_light_policy")
        >>> print(worker.sample())
        MultiAgentBatch({
            "car_policy1": SampleBatch(...),
            "car_policy2": SampleBatch(...),
            "traffic_light_policy": SampleBatch(...)})
    """

    @DeveloperAPI
    @classmethod
    def as_remote(cls,
                  num_cpus=None,
                  num_gpus=None,
                  memory=None,
                  object_store_memory=None,
                  resources=None):
        return ray.remote(
            num_cpus=num_cpus,
            num_gpus=num_gpus,
            memory=memory,
            object_store_memory=object_store_memory,
            resources=resources)(cls)

    @DeveloperAPI
    def __init__(self,
                 env_creator,
                 policy,
                 policy_mapping_fn=None,
                 policies_to_train=None,
                 tf_session_creator=None,
                 rollout_fragment_length=100,
                 batch_mode="truncate_episodes",
                 episode_horizon=None,
                 preprocessor_pref="deepmind",
                 sample_async=False,
                 compress_observations=False,
                 num_envs=1,
                 observation_fn=None,
                 observation_filter="NoFilter",
                 clip_rewards=None,
                 clip_actions=True,
                 env_config=None,
                 model_config=None,
                 policy_config=None,
                 worker_index=0,
                 num_workers=0,
                 monitor_path=None,
                 log_dir=None,
                 log_level=None,
                 callbacks=None,
                 input_creator=lambda ioctx: ioctx.default_sampler_input(),
                 input_evaluation=frozenset([]),
                 output_creator=lambda ioctx: NoopOutput(),
                 remote_worker_envs=False,
                 remote_env_batch_wait_ms=0,
                 soft_horizon=False,
                 no_done_at_end=False,
                 seed=None,
                 extra_python_environs=None,
                 fake_sampler=False):
        """Initialize a rollout worker.

        Arguments:
            env_creator (func): Function that returns a gym.Env given an
                EnvContext wrapped configuration.
            policy (class|dict): Either a class implementing
                Policy, or a dictionary of policy id strings to
                (Policy, obs_space, action_space, config) tuples. If a
                dict is specified, then we are in multi-agent mode and a
                policy_mapping_fn should also be set.
            policy_mapping_fn (func): A function that maps agent ids to
                policy ids in multi-agent mode. This function will be called
                each time a new agent appears in an episode, to bind that agent
                to a policy for the duration of the episode.
            policies_to_train (list): Optional whitelist of policies to train,
                or None for all policies.
            tf_session_creator (func): A function that returns a TF session.
                This is optional and only useful with TFPolicy.
            rollout_fragment_length (int): The target number of env transitions
                to include in each sample batch returned from this worker.
            batch_mode (str): One of the following batch modes:
                "truncate_episodes": Each call to sample() will return a batch
                    of at most `rollout_fragment_length * num_envs` in size.
                    The batch will be exactly
                    `rollout_fragment_length * num_envs` in size if
                    postprocessing does not change batch sizes. Episodes may be
                    truncated in order to meet this size requirement.
                "complete_episodes": Each call to sample() will return a batch
                    of at least `rollout_fragment_length * num_envs` in size.
                    Episodes will not be truncated, but multiple episodes may
                    be packed within one batch to meet the batch size. Note
                    that when `num_envs > 1`, episode steps will be buffered
                    until the episode completes, and hence batches may contain
                    significant amounts of off-policy data.
            episode_horizon (int): Horizon at which to stop episodes, if set.
            preprocessor_pref (str): Whether to prefer RLlib preprocessors
                ("rllib") or deepmind ("deepmind") when applicable.
            sample_async (bool): Whether to compute samples asynchronously in
                the background, which improves throughput but can cause samples
                to be slightly off-policy.
            compress_observations (bool): If true, compress the observations.
                They can be decompressed with rllib/utils/compression.
            num_envs (int): If more than one, will create multiple envs
                and vectorize the computation of actions. This has no effect
                if the env already implements VectorEnv.
            observation_fn (ObservationFunction): Optional multi-agent
                observation function.
            observation_filter (str): Name of observation filter to use.
            clip_rewards (bool): Whether to clip rewards to [-1, 1] prior to
                experience postprocessing. Setting to None means clip for Atari
                only.
            clip_actions (bool): Whether to clip action values to the range
                specified by the policy action space.
            env_config (dict): Config to pass to the env creator.
            model_config (dict): Config to use when creating the policy model.
            policy_config (dict): Config to pass to the policy. In the
                multi-agent case, this config will be merged with the
                per-policy configs specified by `policy`.
            worker_index (int): For remote workers, this should be set to a
                non-zero and unique value. This index is passed to created envs
                through EnvContext so that envs can be configured per worker.
            num_workers (int): For remote workers, how many workers altogether
                have been created?
            monitor_path (str): Write out episode stats and videos to this
                directory if specified.
            log_dir (str): Directory where logs can be placed.
            log_level (str): Set the root log level on creation.
            callbacks (DefaultCallbacks): Custom training callbacks.
            input_creator (func): Function that returns an InputReader object
                for loading previously generated experiences.
            input_evaluation (list): How to evaluate the policy performance.
                This only makes sense to set when the input is reading offline
                data. The possible values include:
                  - "is": the step-wise importance sampling estimator.
                  - "wis": the weighted step-wise is estimator.
                  - "simulation": run the environment in the background, but
                    use this data for evaluation only and never for learning.
            output_creator (func): Function that returns an OutputWriter object
                for saving generated experiences.
            remote_worker_envs (bool): If using num_envs > 1, whether to create
                those new envs in remote processes instead of in the current
                process. This adds overheads, but can make sense if your envs
                are expensive to step / reset.
            remote_env_batch_wait_ms (float): Timeout that remote workers
                are waiting when polling environments. 0 (continue when at
                least one env is ready) is a reasonable default, but optimal
                value could be obtained by measuring your environment
                step / reset and model inference perf.
            soft_horizon (bool): Calculate rewards but don't reset the
                environment when the horizon is hit.
            no_done_at_end (bool): Ignore the done=True at the end of the
                episode and instead record done=False.
            seed (int): Set the seed of both np and tf to this value to
                ensure each remote worker has unique exploration behavior.
            extra_python_environs (dict): Extra python environment variables
                to be set.
            fake_sampler (bool): Use a fake (inf speed) sampler for testing.
        """
        self._original_kwargs = locals().copy()
        del self._original_kwargs["self"]

        global _global_worker
        _global_worker = self

        # set extra environs first
        if extra_python_environs:
            for key, value in extra_python_environs.items():
                os.environ[key] = str(value)

        def gen_rollouts():
            while True:
                yield self.sample()

        ParallelIteratorWorker.__init__(self, gen_rollouts, False)

        policy_config = policy_config or {}
        if (tf and policy_config.get("eager")
                and not policy_config.get("no_eager_on_workers")
                # This eager check is necessary for certain all-framework tests
                # that use tf's eager_mode() context generator.
                and not tf.executing_eagerly()):
            tf.enable_eager_execution()

        if log_level:
            logging.getLogger("ray.rllib").setLevel(log_level)

        if worker_index > 1:
            disable_log_once_globally()  # only need 1 worker to log
        elif log_level == "DEBUG":
            enable_periodic_logging()

        env_context = EnvContext(env_config or {}, worker_index)
        self.policy_config = policy_config
        if callbacks:
            self.callbacks = callbacks()
        else:
            from ray.rllib.agents.callbacks import DefaultCallbacks
            self.callbacks = DefaultCallbacks()
        self.worker_index = worker_index
        self.num_workers = num_workers
        model_config = model_config or {}
        policy_mapping_fn = (policy_mapping_fn
                             or (lambda agent_id: DEFAULT_POLICY_ID))
        if not callable(policy_mapping_fn):
            raise ValueError("Policy mapping function not callable?")
        self.env_creator = env_creator
        self.rollout_fragment_length = rollout_fragment_length * num_envs
        self.batch_mode = batch_mode
        self.compress_observations = compress_observations
        self.preprocessing_enabled = True
        self.last_batch = None
        self.global_vars = None
        self.fake_sampler = fake_sampler

        self.env = _validate_env(env_creator(env_context))
        if isinstance(self.env, MultiAgentEnv) or \
                isinstance(self.env, BaseEnv):

            def wrap(env):
                return env  # we can't auto-wrap these env types
        elif is_atari(self.env) and \
                not model_config.get("custom_preprocessor") and \
                preprocessor_pref == "deepmind":

            # Deepmind wrappers already handle all preprocessing
            self.preprocessing_enabled = False

            if clip_rewards is None:
                clip_rewards = True

            def wrap(env):
                env = wrap_deepmind(
                    env,
                    dim=model_config.get("dim"),
                    framestack=model_config.get("framestack"))
                if monitor_path:
                    from gym import wrappers
                    env = wrappers.Monitor(env, monitor_path, resume=True)
                return env
        else:

            def wrap(env):
                if monitor_path:
                    from gym import wrappers
                    env = wrappers.Monitor(env, monitor_path, resume=True)
                return env

        self.env = wrap(self.env)

        def make_env(vector_index):
            return wrap(
                env_creator(
                    env_context.copy_with_overrides(
                        vector_index=vector_index, remote=remote_worker_envs)))

        self.tf_sess = None
        policy_dict = _validate_and_canonicalize(policy, self.env)
        self.policies_to_train = policies_to_train or list(policy_dict.keys())
        # set numpy and python seed
        if seed is not None:
            np.random.seed(seed)
            random.seed(seed)
            if not hasattr(self.env, "seed"):
                raise ValueError("Env doesn't support env.seed(): {}".format(
                    self.env))
            self.env.seed(seed)
            try:
                assert torch is not None
                torch.manual_seed(seed)
            except AssertionError:
                logger.info("Could not seed torch")
        if _has_tensorflow_graph(policy_dict) and not (tf and
                                                       tf.executing_eagerly()):
            if not tf:
                raise ImportError("Could not import tensorflow")
            with tf.Graph().as_default():
                if tf_session_creator:
                    self.tf_sess = tf_session_creator()
                else:
                    self.tf_sess = tf.Session(
                        config=tf.ConfigProto(
                            gpu_options=tf.GPUOptions(allow_growth=True)))
                with self.tf_sess.as_default():
                    # set graph-level seed
                    if seed is not None:
                        tf.set_random_seed(seed)
                    self.policy_map, self.preprocessors = \
                        self._build_policy_map(policy_dict, policy_config)
            if (ray.is_initialized()
                    and ray.worker._mode() != ray.worker.LOCAL_MODE):
                if not ray.get_gpu_ids():
                    logger.debug(
                        "Creating policy evaluation worker {}".format(
                            worker_index) +
                        " on CPU (please ignore any CUDA init errors)")
                elif not tf.test.is_gpu_available():
                    raise RuntimeError(
                        "GPUs were assigned to this worker by Ray, but "
                        "TensorFlow reports GPU acceleration is disabled. "
                        "This could be due to a bad CUDA or TF installation.")
        else:
            self.policy_map, self.preprocessors = self._build_policy_map(
                policy_dict, policy_config)

        self.multiagent = set(self.policy_map.keys()) != {DEFAULT_POLICY_ID}
        if self.multiagent:
            if not ((isinstance(self.env, MultiAgentEnv)
                     or isinstance(self.env, ExternalMultiAgentEnv))
                    or isinstance(self.env, BaseEnv)):
                raise ValueError(
                    "Have multiple policies {}, but the env ".format(
                        self.policy_map) +
                    "{} is not a subclass of BaseEnv, MultiAgentEnv or "
                    "ExternalMultiAgentEnv?".format(self.env))

        self.filters = {
            policy_id: get_filter(observation_filter,
                                  policy.observation_space.shape)
            for (policy_id, policy) in self.policy_map.items()
        }
        if self.worker_index == 0:
            logger.info("Built filter map: {}".format(self.filters))

        # Always use vector env for consistency even if num_envs = 1
        self.async_env = BaseEnv.to_base_env(
            self.env,
            make_env=make_env,
            num_envs=num_envs,
            remote_envs=remote_worker_envs,
            remote_env_batch_wait_ms=remote_env_batch_wait_ms)
        self.num_envs = num_envs

        if self.batch_mode == "truncate_episodes":
            pack_episodes = True
        elif self.batch_mode == "complete_episodes":
            rollout_fragment_length = float("inf")  # never cut episodes
            pack_episodes = False  # sampler will return 1 episode per poll
        else:
            raise ValueError("Unsupported batch mode: {}".format(
                self.batch_mode))

        self.io_context = IOContext(log_dir, policy_config, worker_index, self)
        self.reward_estimators = []
        for method in input_evaluation:
            if method == "simulation":
                logger.warning(
                    "Requested 'simulation' input evaluation method: "
                    "will discard all sampler outputs and keep only metrics.")
                sample_async = True
            elif method == "is":
                ise = ImportanceSamplingEstimator.create(self.io_context)
                self.reward_estimators.append(ise)
            elif method == "wis":
                wise = WeightedImportanceSamplingEstimator.create(
                    self.io_context)
                self.reward_estimators.append(wise)
            else:
                raise ValueError(
                    "Unknown evaluation method: {}".format(method))

        if sample_async:
            self.sampler = AsyncSampler(
                self,
                self.async_env,
                self.policy_map,
                policy_mapping_fn,
                self.preprocessors,
                self.filters,
                clip_rewards,
                rollout_fragment_length,
                self.callbacks,
                horizon=episode_horizon,
                pack=pack_episodes,
                tf_sess=self.tf_sess,
                clip_actions=clip_actions,
                blackhole_outputs="simulation" in input_evaluation,
                soft_horizon=soft_horizon,
                no_done_at_end=no_done_at_end,
                observation_fn=observation_fn)
            self.sampler.start()
        else:
            self.sampler = SyncSampler(
                self,
                self.async_env,
                self.policy_map,
                policy_mapping_fn,
                self.preprocessors,
                self.filters,
                clip_rewards,
                rollout_fragment_length,
                self.callbacks,
                horizon=episode_horizon,
                pack=pack_episodes,
                tf_sess=self.tf_sess,
                clip_actions=clip_actions,
                soft_horizon=soft_horizon,
                no_done_at_end=no_done_at_end,
                observation_fn=observation_fn)

        self.input_reader = input_creator(self.io_context)
        assert isinstance(self.input_reader, InputReader), self.input_reader
        self.output_writer = output_creator(self.io_context)
        assert isinstance(self.output_writer, OutputWriter), self.output_writer

        logger.debug(
            "Created rollout worker with env {} ({}), policies {}".format(
                self.async_env, self.env, self.policy_map))

    @override(EvaluatorInterface)
    def sample(self):
        """Evaluate the current policies and return a batch of experiences.

        Returns:
            SampleBatch|MultiAgentBatch from evaluating the current policies.
        """

        if self.fake_sampler and self.last_batch is not None:
            return self.last_batch

        if log_once("sample_start"):
            logger.info("Generating sample batch of size {}".format(
                self.rollout_fragment_length))

        batches = [self.input_reader.next()]
        steps_so_far = batches[0].count

        # In truncate_episodes mode, never pull more than 1 batch per env.
        # This avoids over-running the target batch size.
        if self.batch_mode == "truncate_episodes":
            max_batches = self.num_envs
        else:
            max_batches = float("inf")

        while (steps_so_far < self.rollout_fragment_length
               and len(batches) < max_batches):
            batch = self.input_reader.next()
            steps_so_far += batch.count
            batches.append(batch)
        batch = batches[0].concat_samples(batches)

        self.callbacks.on_sample_end(worker=self, samples=batch)

        # Always do writes prior to compression for consistency and to allow
        # for better compression inside the writer.
        self.output_writer.write(batch)

        # Do off-policy estimation if needed
        if self.reward_estimators:
            for sub_batch in batch.split_by_episode():
                for estimator in self.reward_estimators:
                    estimator.process(sub_batch)

        if log_once("sample_end"):
            logger.info("Completed sample batch:\n\n{}\n".format(
                summarize(batch)))

        if self.compress_observations == "bulk":
            batch.compress(bulk=True)
        elif self.compress_observations:
            batch.compress()

        if self.fake_sampler:
            self.last_batch = batch
        return batch

    @DeveloperAPI
    @ray.method(num_return_vals=2)
    def sample_with_count(self):
        """Same as sample() but returns the count as a separate future."""
        batch = self.sample()
        return batch, batch.count

    @override(EvaluatorInterface)
    def get_weights(self, policies=None):
        if policies is None:
            policies = self.policy_map.keys()
        return {
            pid: policy.get_weights()
            for pid, policy in self.policy_map.items() if pid in policies
        }

    @override(EvaluatorInterface)
    def set_weights(self, weights, global_vars=None):
        for pid, w in weights.items():
            self.policy_map[pid].set_weights(w)
        if global_vars:
            self.set_global_vars(global_vars)

    @override(EvaluatorInterface)
    def compute_gradients(self, samples):
        if log_once("compute_gradients"):
            logger.info("Compute gradients on:\n\n{}\n".format(
                summarize(samples)))
        if isinstance(samples, MultiAgentBatch):
            grad_out, info_out = {}, {}
            if self.tf_sess is not None:
                builder = TFRunBuilder(self.tf_sess, "compute_gradients")
                for pid, batch in samples.policy_batches.items():
                    if pid not in self.policies_to_train:
                        continue
                    grad_out[pid], info_out[pid] = (
                        self.policy_map[pid]._build_compute_gradients(
                            builder, batch))
                grad_out = {k: builder.get(v) for k, v in grad_out.items()}
                info_out = {k: builder.get(v) for k, v in info_out.items()}
            else:
                for pid, batch in samples.policy_batches.items():
                    if pid not in self.policies_to_train:
                        continue
                    grad_out[pid], info_out[pid] = (
                        self.policy_map[pid].compute_gradients(batch))
        else:
            grad_out, info_out = (
                self.policy_map[DEFAULT_POLICY_ID].compute_gradients(samples))
        info_out["batch_count"] = samples.count
        if log_once("grad_out"):
            logger.info("Compute grad info:\n\n{}\n".format(
                summarize(info_out)))
        return grad_out, info_out

    @override(EvaluatorInterface)
    def apply_gradients(self, grads):
        if log_once("apply_gradients"):
            logger.info("Apply gradients:\n\n{}\n".format(summarize(grads)))
        if isinstance(grads, dict):
            if self.tf_sess is not None:
                builder = TFRunBuilder(self.tf_sess, "apply_gradients")
                outputs = {
                    pid: self.policy_map[pid]._build_apply_gradients(
                        builder, grad)
                    for pid, grad in grads.items()
                }
                return {k: builder.get(v) for k, v in outputs.items()}
            else:
                return {
                    pid: self.policy_map[pid].apply_gradients(g)
                    for pid, g in grads.items()
                }
        else:
            return self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads)

    @override(EvaluatorInterface)
    def learn_on_batch(self, samples):
        if log_once("learn_on_batch"):
            logger.info(
                "Training on concatenated sample batches:\n\n{}\n".format(
                    summarize(samples)))
        if isinstance(samples, MultiAgentBatch):
            info_out = {}
            to_fetch = {}
            if self.tf_sess is not None:
                builder = TFRunBuilder(self.tf_sess, "learn_on_batch")
            else:
                builder = None
            for pid, batch in samples.policy_batches.items():
                if pid not in self.policies_to_train:
                    continue
                policy = self.policy_map[pid]
                if builder and hasattr(policy, "_build_learn_on_batch"):
                    to_fetch[pid] = policy._build_learn_on_batch(
                        builder, batch)
                else:
                    info_out[pid] = policy.learn_on_batch(batch)
            info_out.update({k: builder.get(v) for k, v in to_fetch.items()})
        else:
            info_out = self.policy_map[DEFAULT_POLICY_ID].learn_on_batch(
                samples)
        if log_once("learn_out"):
            logger.debug("Training out:\n\n{}\n".format(summarize(info_out)))
        return info_out

    def sample_and_learn(self, expected_batch_size, num_sgd_iter,
                         sgd_minibatch_size, standardize_fields):
        """Sample and batch and learn on it.

        This is typically used in combination with distributed allreduce.

        Arguments:
            expected_batch_size (int): Expected number of samples to learn on.
            num_sgd_iter (int): Number of SGD iterations.
            sgd_minibatch_size (int): SGD minibatch size.
            standardize_fields (list): List of sample fields to normalize.

        Returns:
            info: dictionary of extra metadata from learn_on_batch().
            count: number of samples learned on.
        """
        batch = self.sample()
        assert batch.count == expected_batch_size, \
            ("Batch size possibly out of sync between workers, expected:",
             expected_batch_size, "got:", batch.count)
        logger.info("Executing distributed minibatch SGD "
                    "with epoch size {}, minibatch size {}".format(
                        batch.count, sgd_minibatch_size))
        info = do_minibatch_sgd(batch, self.policy_map, self, num_sgd_iter,
                                sgd_minibatch_size, standardize_fields)
        return info, batch.count

    @DeveloperAPI
    def get_metrics(self):
        """Returns a list of new RolloutMetric objects from evaluation."""

        out = self.sampler.get_metrics()
        for m in self.reward_estimators:
            out.extend(m.get_metrics())
        return out

    @DeveloperAPI
    def foreach_env(self, func):
        """Apply the given function to each underlying env instance."""

        envs = self.async_env.get_unwrapped()
        if not envs:
            return [func(self.async_env)]
        else:
            return [func(e) for e in envs]

    @DeveloperAPI
    def get_policy(self, policy_id=DEFAULT_POLICY_ID):
        """Return policy for the specified id, or None.

        Arguments:
            policy_id (str): id of policy to return.
        """

        return self.policy_map.get(policy_id)

    @DeveloperAPI
    def for_policy(self, func, policy_id=DEFAULT_POLICY_ID):
        """Apply the given function to the specified policy."""

        return func(self.policy_map[policy_id])

    @DeveloperAPI
    def foreach_policy(self, func):
        """Apply the given function to each (policy, policy_id) tuple."""

        return [func(policy, pid) for pid, policy in self.policy_map.items()]

    @DeveloperAPI
    def foreach_trainable_policy(self, func):
        """
        Applies the given function to each (policy, policy_id) tuple, which
        can be found in `self.policies_to_train`.

        Args:
            func (callable): A function - taking a Policy and its ID - that is
                called on all Policies within `self.policies_to_train`.

        Returns:
            List[any]: The list of n return values of all
                `func([policy], [ID])`-calls.
        """
        return [
            func(policy, pid) for pid, policy in self.policy_map.items()
            if pid in self.policies_to_train
        ]

    @DeveloperAPI
    def sync_filters(self, new_filters):
        """Changes self's filter to given and rebases any accumulated delta.

        Args:
            new_filters (dict): Filters with new state to update local copy.
        """
        assert all(k in new_filters for k in self.filters)
        for k in self.filters:
            self.filters[k].sync(new_filters[k])

    @DeveloperAPI
    def get_filters(self, flush_after=False):
        """Returns a snapshot of filters.

        Args:
            flush_after (bool): Clears the filter buffer state.

        Returns:
            return_filters (dict): Dict for serializable filters
        """
        return_filters = {}
        for k, f in self.filters.items():
            return_filters[k] = f.as_serializable()
            if flush_after:
                f.clear_buffer()
        return return_filters

    @DeveloperAPI
    def save(self):
        filters = self.get_filters(flush_after=True)
        state = {
            pid: self.policy_map[pid].get_state()
            for pid in self.policy_map
        }
        return pickle.dumps({"filters": filters, "state": state})

    @DeveloperAPI
    def restore(self, objs):
        objs = pickle.loads(objs)
        self.sync_filters(objs["filters"])
        for pid, state in objs["state"].items():
            self.policy_map[pid].set_state(state)

    @DeveloperAPI
    def set_global_vars(self, global_vars):
        self.foreach_policy(lambda p, _: p.on_global_var_update(global_vars))
        self.global_vars = global_vars

    @DeveloperAPI
    def get_global_vars(self):
        return self.global_vars

    @DeveloperAPI
    def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID):
        self.policy_map[policy_id].export_model(export_dir)

    @DeveloperAPI
    def import_policy_model_from_h5(self,
                                    import_file,
                                    policy_id=DEFAULT_POLICY_ID):
        self.policy_map[policy_id].import_model_from_h5(import_file)

    @DeveloperAPI
    def export_policy_checkpoint(self,
                                 export_dir,
                                 filename_prefix="model",
                                 policy_id=DEFAULT_POLICY_ID):
        self.policy_map[policy_id].export_checkpoint(export_dir,
                                                     filename_prefix)

    @DeveloperAPI
    def stop(self):
        self.async_env.stop()

    @DeveloperAPI
    def creation_args(self):
        """Returns the args used to create this worker."""
        return self._original_kwargs

    def _build_policy_map(self, policy_dict, policy_config):
        policy_map = {}
        preprocessors = {}
        for name, (cls, obs_space, act_space,
                   conf) in sorted(policy_dict.items()):
            logger.debug("Creating policy for {}".format(name))
            merged_conf = merge_dicts(policy_config, conf)
            merged_conf["num_workers"] = self.num_workers
            merged_conf["worker_index"] = self.worker_index
            if self.preprocessing_enabled:
                preprocessor = ModelCatalog.get_preprocessor_for_space(
                    obs_space, merged_conf.get("model"))
                preprocessors[name] = preprocessor
                obs_space = preprocessor.observation_space
            else:
                preprocessors[name] = NoPreprocessor(obs_space)
            if isinstance(obs_space, gym.spaces.Dict) or \
                    isinstance(obs_space, gym.spaces.Tuple):
                raise ValueError(
                    "Found raw Tuple|Dict space as input to policy. "
                    "Please preprocess these observations with a "
                    "Tuple|DictFlatteningPreprocessor.")
            if tf and tf.executing_eagerly():
                if hasattr(cls, "as_eager"):
                    cls = cls.as_eager()
                    if policy_config["eager_tracing"]:
                        cls = cls.with_tracing()
                elif not issubclass(cls, TFPolicy):
                    pass  # could be some other type of policy
                else:
                    raise ValueError("This policy does not support eager "
                                     "execution: {}".format(cls))
            if tf:
                with tf.variable_scope(name):
                    policy_map[name] = cls(obs_space, act_space, merged_conf)
            else:
                policy_map[name] = cls(obs_space, act_space, merged_conf)
        if self.worker_index == 0:
            logger.info("Built policy map: {}".format(policy_map))
            logger.info("Built preprocessor map: {}".format(preprocessors))
        return policy_map, preprocessors

    def setup_torch_data_parallel(self, url, world_rank, world_size, backend):
        """Join a torch process group for distributed SGD."""

        logger.info("Joining process group, url={}, world_rank={}, "
                    "world_size={}, backend={}".format(url, world_rank,
                                                       world_size, backend))
        torch.distributed.init_process_group(
            backend=backend,
            init_method=url,
            rank=world_rank,
            world_size=world_size)

        for pid, policy in self.policy_map.items():
            if not isinstance(policy, TorchPolicy):
                raise ValueError(
                    "This policy does not support torch distributed", policy)
            policy.distributed_world_size = world_size

    def get_node_ip(self):
        """Returns the IP address of the current node."""
        return ray.services.get_node_ip_address()

    def find_free_port(self):
        """Finds a free port on the current node."""
        from ray.util.sgd import utils
        return utils.find_free_port()

    def __del__(self):
        if hasattr(self, "sampler") and isinstance(self.sampler, AsyncSampler):
            self.sampler.shutdown = True
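
A minimal sketch of broadcasting weights from a local worker to remote replicas with the `get_weights()` / `set_weights()` methods above, assuming `local_worker` is a RolloutWorker and `remote_workers` is a list of actor handles created via `RolloutWorker.as_remote(...).remote(...)` (both names are placeholders):

# Push the local worker's current policy weights to every remote replica.
weights = local_worker.get_weights()
for remote_worker in remote_workers:
    remote_worker.set_weights.remote(weights)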
Example #8
    def __init__(self,
                 env_creator,
                 policy,
                 policy_mapping_fn=None,
                 policies_to_train=None,
                 tf_session_creator=None,
                 rollout_fragment_length=100,
                 batch_mode="truncate_episodes",
                 episode_horizon=None,
                 preprocessor_pref="deepmind",
                 sample_async=False,
                 compress_observations=False,
                 num_envs=1,
                 observation_fn=None,
                 observation_filter="NoFilter",
                 clip_rewards=None,
                 clip_actions=True,
                 env_config=None,
                 model_config=None,
                 policy_config=None,
                 worker_index=0,
                 num_workers=0,
                 monitor_path=None,
                 log_dir=None,
                 log_level=None,
                 callbacks=None,
                 input_creator=lambda ioctx: ioctx.default_sampler_input(),
                 input_evaluation=frozenset([]),
                 output_creator=lambda ioctx: NoopOutput(),
                 remote_worker_envs=False,
                 remote_env_batch_wait_ms=0,
                 soft_horizon=False,
                 no_done_at_end=False,
                 seed=None,
                 extra_python_environs=None,
                 fake_sampler=False):
        """Initialize a rollout worker.

        Arguments:
            env_creator (func): Function that returns a gym.Env given an
                EnvContext wrapped configuration.
            policy (class|dict): Either a class implementing
                Policy, or a dictionary of policy id strings to
                (Policy, obs_space, action_space, config) tuples. If a
                dict is specified, then we are in multi-agent mode and a
                policy_mapping_fn should also be set.
            policy_mapping_fn (func): A function that maps agent ids to
                policy ids in multi-agent mode. This function will be called
                each time a new agent appears in an episode, to bind that agent
                to a policy for the duration of the episode.
            policies_to_train (list): Optional whitelist of policies to train,
                or None for all policies.
            tf_session_creator (func): A function that returns a TF session.
                This is optional and only useful with TFPolicy.
            rollout_fragment_length (int): The target number of env transitions
                to include in each sample batch returned from this worker.
            batch_mode (str): One of the following batch modes:
                "truncate_episodes": Each call to sample() will return a batch
                    of at most `rollout_fragment_length * num_envs` in size.
                    The batch will be exactly
                    `rollout_fragment_length * num_envs` in size if
                    postprocessing does not change batch sizes. Episodes may be
                    truncated in order to meet this size requirement.
                "complete_episodes": Each call to sample() will return a batch
                    of at least `rollout_fragment_length * num_envs` in size.
                    Episodes will not be truncated, but multiple episodes may
                    be packed within one batch to meet the batch size. Note
                    that when `num_envs > 1`, episode steps will be buffered
                    until the episode completes, and hence batches may contain
                    significant amounts of off-policy data.
            episode_horizon (int): Horizon at which to stop episodes, if set.
            preprocessor_pref (str): Whether to prefer RLlib preprocessors
                ("rllib") or deepmind ("deepmind") when applicable.
            sample_async (bool): Whether to compute samples asynchronously in
                the background, which improves throughput but can cause samples
                to be slightly off-policy.
            compress_observations (bool): If true, compress the observations.
                They can be decompressed with rllib/utils/compression.
            num_envs (int): If more than one, will create multiple envs
                and vectorize the computation of actions. This has no effect
                if the env already implements VectorEnv.
            observation_fn (ObservationFunction): Optional multi-agent
                observation function.
            observation_filter (str): Name of observation filter to use.
            clip_rewards (bool): Whether to clip rewards to [-1, 1] prior to
                experience postprocessing. Setting to None means clip for Atari
                only.
            clip_actions (bool): Whether to clip action values to the range
                specified by the policy action space.
            env_config (dict): Config to pass to the env creator.
            model_config (dict): Config to use when creating the policy model.
            policy_config (dict): Config to pass to the policy. In the
                multi-agent case, this config will be merged with the
                per-policy configs specified by `policy`.
            worker_index (int): For remote workers, this should be set to a
                non-zero and unique value. This index is passed to created envs
                through EnvContext so that envs can be configured per worker.
            num_workers (int): For remote workers, how many workers altogether
                have been created?
            monitor_path (str): Write out episode stats and videos to this
                directory if specified.
            log_dir (str): Directory where logs can be placed.
            log_level (str): Set the root log level on creation.
            callbacks (DefaultCallbacks): Custom training callbacks.
            input_creator (func): Function that returns an InputReader object
                for loading previously generated experiences.
            input_evaluation (list): How to evaluate the policy performance.
                This only makes sense to set when the input is reading offline
                data. The possible values include:
                  - "is": the step-wise importance sampling estimator.
                  - "wis": the weighted step-wise is estimator.
                  - "simulation": run the environment in the background, but
                    use this data for evaluation only and never for learning.
            output_creator (func): Function that returns an OutputWriter object
                for saving generated experiences.
            remote_worker_envs (bool): If using num_envs > 1, whether to create
                those new envs in remote processes instead of in the current
                process. This adds overheads, but can make sense if your envs
                are expensive to step / reset.
            remote_env_batch_wait_ms (float): Timeout that remote workers
                are waiting when polling environments. 0 (continue when at
                least one env is ready) is a reasonable default, but optimal
                value could be obtained by measuring your environment
                step / reset and model inference perf.
            soft_horizon (bool): Calculate rewards but don't reset the
                environment when the horizon is hit.
            no_done_at_end (bool): Ignore the done=True at the end of the
                episode and instead record done=False.
            seed (int): Set the seed of both np and tf to this value to
                ensure each remote worker has unique exploration behavior.
            extra_python_environs (dict): Extra python environment variables
                to be set.
            fake_sampler (bool): Use a fake (inf speed) sampler for testing.
        """
        self._original_kwargs = locals().copy()
        del self._original_kwargs["self"]

        global _global_worker
        _global_worker = self

        # set extra environs first
        if extra_python_environs:
            for key, value in extra_python_environs.items():
                os.environ[key] = str(value)

        def gen_rollouts():
            while True:
                yield self.sample()

        ParallelIteratorWorker.__init__(self, gen_rollouts, False)

        policy_config = policy_config or {}
        if (tf and policy_config.get("eager")
                and not policy_config.get("no_eager_on_workers")
                # This eager check is necessary for certain all-framework tests
                # that use tf's eager_mode() context generator.
                and not tf.executing_eagerly()):
            tf.enable_eager_execution()

        if log_level:
            logging.getLogger("ray.rllib").setLevel(log_level)

        if worker_index > 1:
            disable_log_once_globally()  # only need 1 worker to log
        elif log_level == "DEBUG":
            enable_periodic_logging()

        env_context = EnvContext(env_config or {}, worker_index)
        self.policy_config = policy_config
        if callbacks:
            self.callbacks = callbacks()
        else:
            from ray.rllib.agents.callbacks import DefaultCallbacks
            self.callbacks = DefaultCallbacks()
        self.worker_index = worker_index
        self.num_workers = num_workers
        model_config = model_config or {}
        policy_mapping_fn = (policy_mapping_fn
                             or (lambda agent_id: DEFAULT_POLICY_ID))
        if not callable(policy_mapping_fn):
            raise ValueError("Policy mapping function not callable?")
        self.env_creator = env_creator
        self.rollout_fragment_length = rollout_fragment_length * num_envs
        self.batch_mode = batch_mode
        self.compress_observations = compress_observations
        self.preprocessing_enabled = True
        self.last_batch = None
        self.global_vars = None
        self.fake_sampler = fake_sampler

        self.env = _validate_env(env_creator(env_context))
        if isinstance(self.env, MultiAgentEnv) or \
                isinstance(self.env, BaseEnv):

            def wrap(env):
                return env  # we can't auto-wrap these env types
        elif is_atari(self.env) and \
                not model_config.get("custom_preprocessor") and \
                preprocessor_pref == "deepmind":

            # Deepmind wrappers already handle all preprocessing
            self.preprocessing_enabled = False

            if clip_rewards is None:
                clip_rewards = True

            def wrap(env):
                env = wrap_deepmind(
                    env,
                    dim=model_config.get("dim"),
                    framestack=model_config.get("framestack"))
                if monitor_path:
                    from gym import wrappers
                    env = wrappers.Monitor(env, monitor_path, resume=True)
                return env
        else:

            def wrap(env):
                if monitor_path:
                    from gym import wrappers
                    env = wrappers.Monitor(env, monitor_path, resume=True)
                return env

        self.env = wrap(self.env)

        def make_env(vector_index):
            return wrap(
                env_creator(
                    env_context.copy_with_overrides(
                        vector_index=vector_index, remote=remote_worker_envs)))

        self.tf_sess = None
        policy_dict = _validate_and_canonicalize(policy, self.env)
        self.policies_to_train = policies_to_train or list(policy_dict.keys())
        # set numpy and python seed
        if seed is not None:
            np.random.seed(seed)
            random.seed(seed)
            if not hasattr(self.env, "seed"):
                raise ValueError("Env doesn't support env.seed(): {}".format(
                    self.env))
            self.env.seed(seed)
            try:
                assert torch is not None
                torch.manual_seed(seed)
            except AssertionError:
                logger.info("Could not seed torch")
        if _has_tensorflow_graph(policy_dict) and not (tf and
                                                       tf.executing_eagerly()):
            if not tf:
                raise ImportError("Could not import tensorflow")
            with tf.Graph().as_default():
                if tf_session_creator:
                    self.tf_sess = tf_session_creator()
                else:
                    self.tf_sess = tf.Session(
                        config=tf.ConfigProto(
                            gpu_options=tf.GPUOptions(allow_growth=True)))
                with self.tf_sess.as_default():
                    # set graph-level seed
                    if seed is not None:
                        tf.set_random_seed(seed)
                    self.policy_map, self.preprocessors = \
                        self._build_policy_map(policy_dict, policy_config)
            if (ray.is_initialized()
                    and ray.worker._mode() != ray.worker.LOCAL_MODE):
                if not ray.get_gpu_ids():
                    logger.debug(
                        "Creating policy evaluation worker {}".format(
                            worker_index) +
                        " on CPU (please ignore any CUDA init errors)")
                elif not tf.test.is_gpu_available():
                    raise RuntimeError(
                        "GPUs were assigned to this worker by Ray, but "
                        "TensorFlow reports GPU acceleration is disabled. "
                        "This could be due to a bad CUDA or TF installation.")
        else:
            self.policy_map, self.preprocessors = self._build_policy_map(
                policy_dict, policy_config)

        self.multiagent = set(self.policy_map.keys()) != {DEFAULT_POLICY_ID}
        if self.multiagent:
            if not ((isinstance(self.env, MultiAgentEnv)
                     or isinstance(self.env, ExternalMultiAgentEnv))
                    or isinstance(self.env, BaseEnv)):
                raise ValueError(
                    "Have multiple policies {}, but the env ".format(
                        self.policy_map) +
                    "{} is not a subclass of BaseEnv, MultiAgentEnv or "
                    "ExternalMultiAgentEnv?".format(self.env))

        self.filters = {
            policy_id: get_filter(observation_filter,
                                  policy.observation_space.shape)
            for (policy_id, policy) in self.policy_map.items()
        }
        if self.worker_index == 0:
            logger.info("Built filter map: {}".format(self.filters))

        # Always use vector env for consistency even if num_envs = 1
        self.async_env = BaseEnv.to_base_env(
            self.env,
            make_env=make_env,
            num_envs=num_envs,
            remote_envs=remote_worker_envs,
            remote_env_batch_wait_ms=remote_env_batch_wait_ms)
        self.num_envs = num_envs

        if self.batch_mode == "truncate_episodes":
            pack_episodes = True
        elif self.batch_mode == "complete_episodes":
            rollout_fragment_length = float("inf")  # never cut episodes
            pack_episodes = False  # sampler will return 1 episode per poll
        else:
            raise ValueError("Unsupported batch mode: {}".format(
                self.batch_mode))

        self.io_context = IOContext(log_dir, policy_config, worker_index, self)
        self.reward_estimators = []
        for method in input_evaluation:
            if method == "simulation":
                logger.warning(
                    "Requested 'simulation' input evaluation method: "
                    "will discard all sampler outputs and keep only metrics.")
                sample_async = True
            elif method == "is":
                ise = ImportanceSamplingEstimator.create(self.io_context)
                self.reward_estimators.append(ise)
            elif method == "wis":
                wise = WeightedImportanceSamplingEstimator.create(
                    self.io_context)
                self.reward_estimators.append(wise)
            else:
                raise ValueError(
                    "Unknown evaluation method: {}".format(method))

        if sample_async:
            self.sampler = AsyncSampler(
                self,
                self.async_env,
                self.policy_map,
                policy_mapping_fn,
                self.preprocessors,
                self.filters,
                clip_rewards,
                rollout_fragment_length,
                self.callbacks,
                horizon=episode_horizon,
                pack=pack_episodes,
                tf_sess=self.tf_sess,
                clip_actions=clip_actions,
                blackhole_outputs="simulation" in input_evaluation,
                soft_horizon=soft_horizon,
                no_done_at_end=no_done_at_end,
                observation_fn=observation_fn)
            self.sampler.start()
        else:
            self.sampler = SyncSampler(
                self,
                self.async_env,
                self.policy_map,
                policy_mapping_fn,
                self.preprocessors,
                self.filters,
                clip_rewards,
                rollout_fragment_length,
                self.callbacks,
                horizon=episode_horizon,
                pack=pack_episodes,
                tf_sess=self.tf_sess,
                clip_actions=clip_actions,
                soft_horizon=soft_horizon,
                no_done_at_end=no_done_at_end,
                observation_fn=observation_fn)

        self.input_reader = input_creator(self.io_context)
        assert isinstance(self.input_reader, InputReader), self.input_reader
        self.output_writer = output_creator(self.io_context)
        assert isinstance(self.output_writer, OutputWriter), self.output_writer

        logger.debug(
            "Created rollout worker with env {} ({}), policies {}".format(
                self.async_env, self.env, self.policy_map))
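The constructor above accepts a callbacks class (not an instance) and simply instantiates it via `self.callbacks = callbacks()`, falling back to `DefaultCallbacks` otherwise. Below is a minimal, hypothetical sketch of a custom subclass that could be passed this way; the exact keyword signature of the hooks varies slightly across RLlib versions, and the metric name is made up for illustration.

from ray.rllib.agents.callbacks import DefaultCallbacks


class EpisodeLengthCallbacks(DefaultCallbacks):
    """Hypothetical subclass recording a custom per-episode metric."""

    def on_episode_end(self, *, worker, base_env, policies, episode, **kwargs):
        # custom_metrics is aggregated by RLlib into the training results
        # under the "custom_metrics" key.
        episode.custom_metrics["episode_length"] = episode.length


# Passed as a class, matching `self.callbacks = callbacks()` above:
# RolloutWorker(..., callbacks=EpisodeLengthCallbacks)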
Example #9
    def __init__(
            self,
            *,
            env_creator: Callable[[EnvContext], EnvType],
            validate_env: Optional[Callable[[EnvType, EnvContext],
                                            None]] = None,
            policy_spec: Union[type, Dict[
                str, Tuple[Optional[type], gym.Space, gym.Space,
                           PartialTrainerConfigDict]]] = None,
            policy_mapping_fn: Optional[Callable[[AgentID], PolicyID]] = None,
            policies_to_train: Optional[List[PolicyID]] = None,
            tf_session_creator: Optional[Callable[[], "tf1.Session"]] = None,
            rollout_fragment_length: int = 100,
            batch_mode: str = "truncate_episodes",
            episode_horizon: int = None,
            preprocessor_pref: str = "deepmind",
            sample_async: bool = False,
            compress_observations: bool = False,
            num_envs: int = 1,
            observation_fn: "ObservationFunction" = None,
            observation_filter: str = "NoFilter",
            clip_rewards: bool = None,
            clip_actions: bool = True,
            env_config: EnvConfigDict = None,
            model_config: ModelConfigDict = None,
            policy_config: TrainerConfigDict = None,
            worker_index: int = 0,
            num_workers: int = 0,
            monitor_path: str = None,
            log_dir: str = None,
            log_level: str = None,
            callbacks: Type["DefaultCallbacks"] = None,
            input_creator: Callable[[
                IOContext
            ], InputReader] = lambda ioctx: ioctx.default_sampler_input(),
            input_evaluation: List[str] = frozenset([]),
            output_creator: Callable[
                [IOContext], OutputWriter] = lambda ioctx: NoopOutput(),
            remote_worker_envs: bool = False,
            remote_env_batch_wait_ms: int = 0,
            soft_horizon: bool = False,
            no_done_at_end: bool = False,
            seed: int = None,
            extra_python_environs: dict = None,
            fake_sampler: bool = False,
            spaces: Optional[Dict[PolicyID, Tuple[gym.spaces.Space,
                                                  gym.spaces.Space]]] = None,
            policy: Union[type, Dict[
                str, Tuple[Optional[type], gym.Space, gym.Space,
                           PartialTrainerConfigDict]]] = None,
    ):
        """Initialize a rollout worker.

        Args:
            env_creator (Callable[[EnvContext], EnvType]): Function that
                returns a gym.Env given an EnvContext wrapped configuration.
            validate_env (Optional[Callable[[EnvType, EnvContext], None]]):
                Optional callable to validate the generated environment (only
                on worker=0).
            policy_spec (Union[type, Dict[str, Tuple[Type[Policy], gym.Space,
                gym.Space, PartialTrainerConfigDict]]]): Either a Policy class
                or a dict of policy id strings to
                (Policy class, obs_space, action_space, config)-tuples. If a
                dict is specified, then we are in multi-agent mode and a
                policy_mapping_fn can also be set (if not, will map all agents
                to DEFAULT_POLICY_ID).
            policy_mapping_fn (Optional[Callable[[AgentID], PolicyID]]): A
                callable that maps agent ids to policy ids in multi-agent mode.
                This function will be called each time a new agent appears in
                an episode, to bind that agent to a policy for the duration of
                the episode. If not provided, will map all agents to
                DEFAULT_POLICY_ID.
            policies_to_train (Optional[List[PolicyID]]): Optional list of
                policies to train, or None for all policies.
            tf_session_creator (Optional[Callable[[], tf1.Session]]): A
                function that returns a TF session. This is optional and only
                useful with TFPolicy.
            rollout_fragment_length (int): The target number of env transitions
                to include in each sample batch returned from this worker.
            batch_mode (str): One of the following batch modes:
                "truncate_episodes": Each call to sample() will return a batch
                    of at most `rollout_fragment_length * num_envs` in size.
                    The batch will be exactly
                    `rollout_fragment_length * num_envs` in size if
                    postprocessing does not change batch sizes. Episodes may be
                    truncated in order to meet this size requirement.
                "complete_episodes": Each call to sample() will return a batch
                    of at least `rollout_fragment_length * num_envs` in size.
                    Episodes will not be truncated, but multiple episodes may
                    be packed within one batch to meet the batch size. Note
                    that when `num_envs > 1`, episode steps will be buffered
                    until the episode completes, and hence batches may contain
                    significant amounts of off-policy data.
            episode_horizon (int): Horizon at which to stop episodes, even if
                the environment has not returned done=True.
            preprocessor_pref (str): Whether to prefer RLlib preprocessors
                ("rllib") or deepmind ("deepmind") when applicable.
            sample_async (bool): Whether to compute samples asynchronously in
                the background, which improves throughput but can cause samples
                to be slightly off-policy.
            compress_observations (bool): If true, compress the observations.
                They can be decompressed with rllib/utils/compression.
            num_envs (int): If more than one, will create multiple envs
                and vectorize the computation of actions. This has no effect
                if the env already implements VectorEnv.
            observation_fn (ObservationFunction): Optional multi-agent
                observation function.
            observation_filter (str): Name of observation filter to use.
            clip_rewards (bool): Whether to clip rewards to [-1, 1] prior to
                experience postprocessing. Setting to None means clip for Atari
                only.
            clip_actions (bool): Whether to clip action values to the range
                specified by the policy action space.
            env_config (EnvConfigDict): Config to pass to the env creator.
            model_config (ModelConfigDict): Config to use when creating the
                policy model.
            policy_config (TrainerConfigDict): Config to pass to the policy.
                In the multi-agent case, this config will be merged with the
                per-policy configs specified by `policy_spec`.
            worker_index (int): For remote workers, this should be set to a
                non-zero and unique value. This index is passed to created envs
                through EnvContext so that envs can be configured per worker.
            num_workers (int): For remote workers, the total number of workers
                that have been created.
            monitor_path (str): Write out episode stats and videos to this
                directory if specified.
            log_dir (str): Directory where logs can be placed.
            log_level (str): Set the root log level on creation.
            callbacks (DefaultCallbacks): Custom training callbacks.
            input_creator (Callable[[IOContext], InputReader]): Function that
                returns an InputReader object for loading previously generated
                experiences.
            input_evaluation (List[str]): How to evaluate the policy
                performance. This only makes sense to set when the input is
                reading offline data. The possible values include:
                  - "is": the step-wise importance sampling estimator.
                  - "wis": the weighted step-wise is estimator.
                  - "simulation": run the environment in the background, but
                    use this data for evaluation only and never for learning.
            output_creator (Callable[[IOContext], OutputWriter]): Function that
                returns an OutputWriter object for saving generated
                experiences.
            remote_worker_envs (bool): If using num_envs > 1, whether to create
                those new envs in remote processes instead of in the current
                process. This adds overheads, but can make sense if your envs
                can take much time to step / reset (e.g., for StarCraft).
            remote_env_batch_wait_ms (float): Timeout that remote workers
                are waiting when polling environments. 0 (continue when at
                least one env is ready) is a reasonable default, but optimal
                value could be obtained by measuring your environment
                step / reset and model inference perf.
            soft_horizon (bool): Calculate rewards but don't reset the
                environment when the horizon is hit.
            no_done_at_end (bool): Ignore the done=True at the end of the
                episode and instead record done=False.
            seed (int): Set the seed of both np and tf to this value to
                ensure each remote worker has unique exploration behavior.
            extra_python_environs (dict): Extra Python environment variables
                to set for this worker process.
            fake_sampler (bool): Use a fake (inf speed) sampler for testing.
            spaces (Optional[Dict[PolicyID, Tuple[gym.spaces.Space,
                gym.spaces.Space]]]): An optional space dict mapping policy IDs
                to (obs_space, action_space)-tuples. This is used in case no
                Env is created on this RolloutWorker.
            policy: Obsoleted arg. Use `policy_spec` instead.
        """
        # Deprecated arg.
        if policy is not None:
            deprecation_warning("policy", "policy_spec", error=False)
            policy_spec = policy
        assert policy_spec is not None, "Must provide `policy_spec` when " \
                                        "creating RolloutWorker!"

        self._original_kwargs: dict = locals().copy()
        del self._original_kwargs["self"]

        global _global_worker
        _global_worker = self

        # set extra environs first
        if extra_python_environs:
            for key, value in extra_python_environs.items():
                os.environ[key] = str(value)

        def gen_rollouts():
            while True:
                yield self.sample()

        ParallelIteratorWorker.__init__(self, gen_rollouts, False)

        policy_config: TrainerConfigDict = policy_config or {}
        if (tf1 and policy_config.get("framework") in ["tf2", "tfe"]
                # This eager check is necessary for certain all-framework tests
                # that use tf's eager_mode() context generator.
                and not tf1.executing_eagerly()):
            tf1.enable_eager_execution()

        if log_level:
            logging.getLogger("ray.rllib").setLevel(log_level)

        if worker_index > 1:
            disable_log_once_globally()  # only need 1 worker to log
        elif log_level == "DEBUG":
            enable_periodic_logging()

        env_context = EnvContext(env_config or {}, worker_index)
        self.env_context = env_context
        self.policy_config: TrainerConfigDict = policy_config
        if callbacks:
            self.callbacks: "DefaultCallbacks" = callbacks()
        else:
            from ray.rllib.agents.callbacks import DefaultCallbacks
            self.callbacks: "DefaultCallbacks" = DefaultCallbacks()
        self.worker_index: int = worker_index
        self.num_workers: int = num_workers
        model_config: ModelConfigDict = model_config or {}
        policy_mapping_fn = (policy_mapping_fn
                             or (lambda agent_id: DEFAULT_POLICY_ID))
        if not callable(policy_mapping_fn):
            raise ValueError("Policy mapping function not callable?")
        self.env_creator: Callable[[EnvContext], EnvType] = env_creator
        self.rollout_fragment_length: int = rollout_fragment_length * num_envs
        self.batch_mode: str = batch_mode
        self.compress_observations: bool = compress_observations
        self.preprocessing_enabled: bool = True
        self.last_batch: SampleBatchType = None
        self.global_vars: dict = None
        self.fake_sampler: bool = fake_sampler

        # No Env will be used in this particular worker (not needed).
        if worker_index == 0 and num_workers > 0 and \
                policy_config["create_env_on_driver"] is False:
            self.env = None
        # Create an env for this worker.
        else:
            self.env = _validate_env(env_creator(env_context))
            if validate_env is not None:
                validate_env(self.env, self.env_context)

            if isinstance(self.env, (BaseEnv, MultiAgentEnv)):

                def wrap(env):
                    return env  # we can't auto-wrap these env types

            elif is_atari(self.env) and \
                    not model_config.get("custom_preprocessor") and \
                    preprocessor_pref == "deepmind":

                # Deepmind wrappers already handle all preprocessing.
                self.preprocessing_enabled = False

                # If clip_rewards not explicitly set to False, switch it
                # on here (clip between -1.0 and 1.0).
                if clip_rewards is None:
                    clip_rewards = True

                def wrap(env):
                    env = wrap_deepmind(
                        env,
                        dim=model_config.get("dim"),
                        framestack=model_config.get("framestack"))
                    if monitor_path:
                        from gym import wrappers
                        env = wrappers.Monitor(env, monitor_path, resume=True)
                    return env
            else:

                def wrap(env):
                    if monitor_path:
                        from gym import wrappers
                        env = wrappers.Monitor(env, monitor_path, resume=True)
                    return env

            self.env: EnvType = wrap(self.env)

        def make_env(vector_index):
            return wrap(
                env_creator(
                    env_context.copy_with_overrides(
                        worker_index=worker_index,
                        vector_index=vector_index,
                        remote=remote_worker_envs)))

        self.make_env_fn = make_env

        self.tf_sess = None
        policy_dict = _validate_and_canonicalize(
            policy_spec, self.env, spaces=spaces)
        self.policies_to_train: List[PolicyID] = policies_to_train or list(
            policy_dict.keys())
        self.policy_map: Dict[PolicyID, Policy] = None
        self.preprocessors: Dict[PolicyID, Preprocessor] = None

        # set numpy and python seed
        if seed is not None:
            np.random.seed(seed)
            random.seed(seed)
            if not hasattr(self.env, "seed"):
                logger.info("Env doesn't support env.seed(): {}".format(
                    self.env))
            else:
                self.env.seed(seed)
            try:
                assert torch is not None
                torch.manual_seed(seed)
            except AssertionError:
                logger.info("Could not seed torch")
        if _has_tensorflow_graph(policy_dict) and not (
                tf1 and tf1.executing_eagerly()):
            if not tf1:
                raise ImportError("Could not import tensorflow")
            with tf1.Graph().as_default():
                if tf_session_creator:
                    self.tf_sess = tf_session_creator()
                else:
                    self.tf_sess = tf1.Session(
                        config=tf1.ConfigProto(
                            gpu_options=tf1.GPUOptions(allow_growth=True)))
                with self.tf_sess.as_default():
                    # set graph-level seed
                    if seed is not None:
                        tf1.set_random_seed(seed)
                    self.policy_map, self.preprocessors = \
                        self._build_policy_map(policy_dict, policy_config)
        else:
            self.policy_map, self.preprocessors = self._build_policy_map(
                policy_dict, policy_config)

        if (ray.is_initialized()
                and ray.worker._mode() != ray.worker.LOCAL_MODE):
            # Check available number of GPUs
            if not ray.get_gpu_ids():
                logger.debug("Creating policy evaluation worker {}".format(
                    worker_index) +
                             " on CPU (please ignore any CUDA init errors)")
            elif (policy_config["framework"] in ["tf2", "tf", "tfe"] and
                  not tf.config.experimental.list_physical_devices("GPU")) or \
                    (policy_config["framework"] == "torch" and
                     not torch.cuda.is_available()):
                raise RuntimeError(
                    "GPUs were assigned to this worker by Ray, but "
                    "your DL framework ({}) reports GPU acceleration is "
                    "disabled. This could be due to a bad CUDA- or {} "
                    "installation.".format(policy_config["framework"],
                                           policy_config["framework"]))

        self.multiagent: bool = set(
            self.policy_map.keys()) != {DEFAULT_POLICY_ID}
        if self.multiagent and self.env is not None:
            if not ((isinstance(self.env, MultiAgentEnv)
                     or isinstance(self.env, ExternalMultiAgentEnv))
                    or isinstance(self.env, BaseEnv)):
                raise ValueError(
                    "Have multiple policies {}, but the env ".format(
                        self.policy_map) +
                    "{} is not a subclass of BaseEnv, MultiAgentEnv or "
                    "ExternalMultiAgentEnv?".format(self.env))

        self.filters: Dict[PolicyID, Filter] = {
            policy_id: get_filter(observation_filter,
                                  policy.observation_space.shape)
            for (policy_id, policy) in self.policy_map.items()
        }
        if self.worker_index == 0:
            logger.info("Built filter map: {}".format(self.filters))

        self.num_envs: int = num_envs

        if self.env is None:
            self.async_env = None
        elif "custom_vector_env" in policy_config:
            custom_vec_wrapper = policy_config["custom_vector_env"]
            self.async_env = custom_vec_wrapper(self.env)
        else:
            # Always use vector env for consistency even if num_envs = 1.
            self.async_env: BaseEnv = BaseEnv.to_base_env(
                self.env,
                make_env=make_env,
                num_envs=num_envs,
                remote_envs=remote_worker_envs,
                remote_env_batch_wait_ms=remote_env_batch_wait_ms)

        # `truncate_episodes`: Allow a batch to contain more than one episode
        # (fragments) and always make the batch `rollout_fragment_length`
        # long.
        if self.batch_mode == "truncate_episodes":
            pack = True
        # `complete_episodes`: Never cut episodes and sampler will return
        # exactly one (complete) episode per poll.
        elif self.batch_mode == "complete_episodes":
            rollout_fragment_length = float("inf")
            pack = False
        else:
            raise ValueError("Unsupported batch mode: {}".format(
                self.batch_mode))

        self.io_context: IOContext = IOContext(log_dir, policy_config,
                                               worker_index, self)
        self.reward_estimators: List[OffPolicyEstimator] = []
        for method in input_evaluation:
            if method == "simulation":
                logger.warning(
                    "Requested 'simulation' input evaluation method: "
                    "will discard all sampler outputs and keep only metrics.")
                sample_async = True
            elif method == "is":
                ise = ImportanceSamplingEstimator.create(self.io_context)
                self.reward_estimators.append(ise)
            elif method == "wis":
                wise = WeightedImportanceSamplingEstimator.create(
                    self.io_context)
                self.reward_estimators.append(wise)
            else:
                raise ValueError(
                    "Unknown evaluation method: {}".format(method))

        if self.env is None:
            self.sampler = None
        elif sample_async:
            self.sampler = AsyncSampler(
                worker=self,
                env=self.async_env,
                policies=self.policy_map,
                policy_mapping_fn=policy_mapping_fn,
                preprocessors=self.preprocessors,
                obs_filters=self.filters,
                clip_rewards=clip_rewards,
                rollout_fragment_length=rollout_fragment_length,
                callbacks=self.callbacks,
                horizon=episode_horizon,
                multiple_episodes_in_batch=pack,
                tf_sess=self.tf_sess,
                clip_actions=clip_actions,
                blackhole_outputs="simulation" in input_evaluation,
                soft_horizon=soft_horizon,
                no_done_at_end=no_done_at_end,
                observation_fn=observation_fn,
                _use_trajectory_view_api=policy_config.get(
                    "_use_trajectory_view_api", False))
            # Start the Sampler thread.
            self.sampler.start()
        else:
            self.sampler = SyncSampler(
                worker=self,
                env=self.async_env,
                policies=self.policy_map,
                policy_mapping_fn=policy_mapping_fn,
                preprocessors=self.preprocessors,
                obs_filters=self.filters,
                clip_rewards=clip_rewards,
                rollout_fragment_length=rollout_fragment_length,
                callbacks=self.callbacks,
                horizon=episode_horizon,
                multiple_episodes_in_batch=pack,
                tf_sess=self.tf_sess,
                clip_actions=clip_actions,
                soft_horizon=soft_horizon,
                no_done_at_end=no_done_at_end,
                observation_fn=observation_fn,
                _use_trajectory_view_api=policy_config.get(
                    "_use_trajectory_view_api", False))

        self.input_reader: InputReader = input_creator(self.io_context)
        self.output_writer: OutputWriter = output_creator(self.io_context)

        logger.debug(
            "Created rollout worker with env {} ({}), policies {}".format(
                self.async_env, self.env, self.policy_map))
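For reference, a minimal usage sketch of this newer constructor, which takes `policy_spec` instead of the deprecated `policy` argument. It assumes `gym` and RLlib's PG policy class are importable and is only meant to show the call shape, not a recommended training setup.

import gym

from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy
from ray.rllib.evaluation.rollout_worker import RolloutWorker

# Build a single local worker and collect one batch of experience.
worker = RolloutWorker(
    env_creator=lambda env_ctx: gym.make("CartPole-v0"),
    policy_spec=PGTFPolicy)
print(worker.sample())  # SampleBatch of roughly rollout_fragment_length steps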
Example #10
File: trainer.py  Project: zommiommy/ray
 def make_callbacks():
     # Deprecation warning will be logged by DefaultCallbacks.
     return DefaultCallbacks(
         legacy_callbacks_dict=legacy_callbacks_dict)
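`legacy_callbacks_dict` here refers to the old dict-based callbacks API, in which hook names were mapped to plain functions receiving a single `info` dict. A hedged sketch of what such a dict might look like (the reward print is purely illustrative):

# Old-style callbacks dict that DefaultCallbacks wraps for backward
# compatibility (hence the deprecation warning mentioned above).
legacy_callbacks_dict = {
    "on_episode_end": lambda info: print(
        "episode reward:", info["episode"].total_reward),
}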
Example #11
    def __init__(
        self,
        env,
        batch_size,
        trace_length,
        grid_size,
        exploiter_base_lr,
        exploiter_decay_lr_in_n_epi,
        exploiter_stop_training_after_n_epi,
        train_exploiter_n_times_per_epi,
    ):

        self.stop_training_after_n_epi = exploiter_stop_training_after_n_epi
        self.train_exploiter_n_times_per_epi = train_exploiter_n_times_per_epi

        # with tf.variable_scope(f"dqn_exploiter"):
        # Create the dqn policy for the exploiter
        dqn_config = copy.deepcopy(DEFAULT_CONFIG)
        dqn_config.update({
            "prioritized_replay": False,
            "double_q": True,
            "buffer_size": 50000,
            "dueling": False,
            "learning_starts": min(
                int((batch_size - 1) * (trace_length - 1)), 64),
            "model": {
                "dim": grid_size,
                # [out_channels, [kernel, kernel], stride]
                "conv_filters": [[16, [3, 3], 1], [32, [3, 3], 1]],
                # "fcnet_hiddens": [self.env.NUM_ACTIONS],
                "max_seq_len": trace_length,
                # Hidden layer sizes of the fully connected net.
                "fcnet_hiddens": [64],
                # Nonlinearity for the fully connected net (tanh, relu).
                "fcnet_activation": "relu",
            },
            # Update the replay buffer with this many samples at once. Note
            # that this setting applies per-worker if num_workers > 1.
            "rollout_fragment_length": 1,
            # Size of a batch sampled from the replay buffer for training.
            # Note that if async_updates is set, then each worker returns
            # gradients for a batch of this size.
            "train_batch_size": min(int(batch_size * trace_length), 64),
            "explore": False,
            "grad_clip": 1,
            "gamma": 0.5,
            "lr": exploiter_base_lr,
            # Learning rate schedule: warm up, then decay.
            "lr_schedule": [
                (0, exploiter_base_lr / 1000),
                (100, exploiter_base_lr),
                (exploiter_decay_lr_in_n_epi, exploiter_base_lr / 1e9),
            ],
            "sgd_momentum": 0.9,
        })
        print("dqn_config", dqn_config)

        self.local_replay_buffer = LocalReplayBuffer(
            num_shards=1,
            learning_starts=dqn_config["learning_starts"],
            buffer_size=dqn_config["buffer_size"],
            replay_batch_size=dqn_config["train_batch_size"],
            replay_mode=dqn_config["multiagent"]["replay_mode"],
            replay_sequence_length=dqn_config["replay_sequence_length"],
        )

        # self.dqn_exploiter = DQNTFPolicy(obs_space=self.env.OBSERVATION_SPACE,
        #                                  action_space=self.env.ACTION_SPACE,
        #                                  config=dqn_config)

        def sgd_optimizer_dqn(policy, config) -> "torch.optim.Optimizer":
            return torch.optim.SGD(
                policy.q_func_vars,
                lr=policy.cur_lr,
                momentum=config["sgd_momentum"],
            )

        MyDQNTorchPolicy = DQNTorchPolicy.with_updates(
            optimizer_fn=sgd_optimizer_dqn)
        self.dqn_policy = MyDQNTorchPolicy(
            obs_space=env.OBSERVATION_SPACE,
            action_space=env.ACTION_SPACE,
            config=dqn_config,
        )

        self.multi_agent_batch_builders = [
            MultiAgentSampleBatchBuilder(
                policy_map={"player_blue": self.dqn_policy},
                clip_rewards=False,
                callbacks=DefaultCallbacks(),
            )
            # for _ in range(self.batch_size)
        ]
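The snippet above only constructs the pieces (DQN policy, replay buffer, batch builder). Below is a rough, hypothetical sketch of how a single training step could be wired together from them, assuming the batch builder has already been filled with one episode for "player_blue":

    def train_exploiter_once(self):
        # Hypothetical helper, not part of the original snippet.
        batch = self.multi_agent_batch_builders[0].build_and_reset()
        self.local_replay_buffer.add_batch(batch)
        replay_batch = self.local_replay_buffer.replay()
        # replay() returns None until `learning_starts` samples were added.
        if replay_batch is not None:
            for _ in range(self.train_exploiter_n_times_per_epi):
                self.dqn_policy.learn_on_batch(
                    replay_batch.policy_batches["player_blue"])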
Example #12
    def _init_algo(self, fear, greed, n_players, use_simple_agents,
                   action_flip_prob, max_reward_strength, value_fn_variant,
                   cost_param, with_redistribution, n_planning_eps, env_config,
                   n_units, n_episodes, env, lr, gamma, weight_decay,
                   loss_mul_planner, mean_theta, with_planner, std_theta,
                   add_state_grad, planner_momentum, planner_clip_norm,
                   entropy_coeff, seed, normalize_planner,
                   no_weights_decay_planner, planner_std_theta_mul,
                   use_adam_optimizer, use_softmax_hot, report_every_n,
                   momentum, weight_decay_pl_mul, square_cost,
                   normalize_against_v, use_v_pl, normalize_against_vp,
                   normalize_vp_separated, use_rllib_polcy, **kwargs):

        if not use_simple_agents:
            speed_ratio = 5.0
            lr = lr / speed_ratio
            loss_mul_planner = loss_mul_planner * speed_ratio**2 / 2 / 2
            cost_param = cost_param * 1.5
            if n_units == 64:
                lr = lr / 8

        print("args not used:", kwargs)
        convert_a_to_one_hot = not use_simple_agents

        np.random.seed(seed)
        tf.set_random_seed(seed)
        random.seed(seed)

        if env == "FearGreedMatrix":
            env = define_greed_fear_matrix_game(fear=fear,
                                                greed=greed)(env_config)
        elif env == "CoinGame":
            env = CoinGame(env_config)

        env.seed(seed=seed)

        agents = create_population(env,
                                   n_players,
                                   use_simple_agents=use_simple_agents,
                                   n_units=n_units,
                                   lr=lr,
                                   gamma=gamma,
                                   weight_decay=weight_decay,
                                   mean_theta=mean_theta,
                                   std_theta=std_theta,
                                   entropy_coeff=entropy_coeff,
                                   use_adam_optimizer=use_adam_optimizer,
                                   momentum=momentum,
                                   use_rllib_polcy=use_rllib_polcy)
        np.random.seed(seed + 1)
        tf.set_random_seed(seed + 1)
        random.seed(seed + 1)

        if with_planner:
            std_theta = std_theta * planner_std_theta_mul
            weight_decay = weight_decay * weight_decay_pl_mul
            if no_weights_decay_planner:
                weight_decay = 0.0
            planning_agent = Planning_Agent(
                env,
                agents,
                learning_rate=lr,
                max_reward_strength=max_reward_strength,
                cost_param=cost_param,
                with_redistribution=with_redistribution,
                value_fn_variant=value_fn_variant,
                n_units=n_units,
                weight_decay=weight_decay,
                convert_a_to_one_hot=convert_a_to_one_hot,
                loss_mul_planner=loss_mul_planner,
                mean_theta=mean_theta,
                std_theta=std_theta,
                planner_clip_norm=planner_clip_norm,
                normalize_planner=normalize_planner,
                add_state_grad=add_state_grad,
                planner_momentum=planner_momentum,
                use_adam_optimizer=use_adam_optimizer,
                use_softmax_hot=use_softmax_hot,
                square_cost=square_cost,
                normalize_against_v=normalize_against_v,
                use_v_pl=use_v_pl,
                normalize_against_vp=normalize_against_vp,
                normalize_vp_separated=normalize_vp_separated)
        else:
            planning_agent = None

        self.epi_n = 0
        self.players = agents
        self.env = env
        self.action_flip_prob = action_flip_prob
        self.planning_agent = planning_agent
        self.with_redistribution = with_redistribution
        self.n_planning_eps = n_planning_eps
        self.player_ids = env.players_ids
        self.n_players = n_players
        self.n_episodes = n_episodes
        self.max_reward_strength = max_reward_strength
        self.cost_param = cost_param
        self.value_fn_variant = value_fn_variant
        self.fear = fear
        self.greed = greed
        self.report_every_n = report_every_n
        self.normalize_vp_separated = normalize_vp_separated
        self.use_rllib_polcy = use_rllib_polcy

        self.avg_planning_rewards_per_round = []
        self.episode_reward = []
        self.training_epi_avg_reward = []

        if self.use_rllib_polcy:
            self.multi_agent_batch_builder = MultiAgentSampleBatchBuilder(
                policy_map={
                    idx: player
                    for idx, player in enumerate(self.players)
                },
                clip_rewards=False,
                callbacks=DefaultCallbacks())
Example #13
def mycallback():
    return DefaultCallbacks()
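Since the rollout worker only does `self.callbacks = callbacks()`, any zero-argument callable returning a `DefaultCallbacks` instance works here, whether it is a factory function like the one above or a subclass. A small, hypothetical sketch of the equivalent class-based form:

from ray.rllib.agents.callbacks import DefaultCallbacks


class NoopCallbacks(DefaultCallbacks):
    """Equivalent, hypothetical class-based form of the factory above."""
    pass


# Either form can be handed to the constructors shown earlier:
#   RolloutWorker(..., callbacks=mycallback)
#   RolloutWorker(..., callbacks=NoopCallbacks)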