Code Example #1
import dm_env
import numpy as np
import reverb
import tensorflow as tf


def transition_dataset(environment: dm_env.Environment) -> tf.data.Dataset:
    """Fake dataset of Reverb N-step transition samples.

    Args:
      environment: Used to create a fake transition by looking at the
        observation, action, discount and reward specs.

    Returns:
      tf.data.Dataset that produces the same fake N-step transition
      ReplaySample object indefinitely.
    """

    observation = environment.observation_spec().generate_value()
    action = environment.action_spec().generate_value()
    reward = environment.reward_spec().generate_value()
    discount = environment.discount_spec().generate_value()
    data = (observation, action, reward, discount, observation)

    key = np.array(0, np.uint64)
    probability = np.array(1.0, np.float64)
    table_size = np.array(1, np.int64)
    priority = np.array(1.0, np.float64)
    info = reverb.SampleInfo(key=key,
                             probability=probability,
                             table_size=table_size,
                             priority=priority)
    sample = reverb.ReplaySample(info=info, data=data)

    return tf.data.Dataset.from_tensors(sample).repeat()
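
A minimal usage sketch for the function above, assuming `environment` is any dm_env.Environment (for example one of acme's testing fakes) and TF2 eager execution:

# Usage sketch (assumptions: TF2 eager mode, an existing dm_env.Environment `environment`).
dataset = transition_dataset(environment)
batched = dataset.batch(2)        # batch the endlessly repeated fake sample
sample = next(iter(batched))      # reverb.ReplaySample with batched info and data
print(sample.info.key.shape)      # (2,)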
Code Example #2
File: specs.py  Project: zerocurve/acme
def make_environment_spec(environment: dm_env.Environment) -> EnvironmentSpec:
  """Returns an `EnvironmentSpec` describing values used by an environment."""
  return EnvironmentSpec(
      observations=environment.observation_spec(),
      actions=environment.action_spec(),
      rewards=environment.reward_spec(),
      discounts=environment.discount_spec())
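
A short usage sketch; `fakes.DiscreteEnvironment` and its arguments are assumed from acme's testing helpers, but any dm_env.Environment works the same way:

# Usage sketch (assumes acme's testing fakes; substitute any dm_env.Environment).
from acme import specs
from acme.testing import fakes

environment = fakes.DiscreteEnvironment(num_actions=3, num_observations=5)
spec = specs.make_environment_spec(environment)
print(spec.actions)       # the environment's action spec
print(spec.observations)  # the environment's observation spec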
Code Example #3
    def assert_env_reset(
        wrapped_env: dm_env.Environment,
        dm_env_timestep: dm_env.TimeStep,
        env_spec: EnvSpec,
    ) -> None:
        if env_spec.env_type == EnvType.Parallel:
            rewards_spec = wrapped_env.reward_spec()
            expected_rewards = {
                agent: convert_np_type(rewards_spec[agent].dtype, 0)
                for agent in wrapped_env.agents
            }

            discount_spec = wrapped_env.discount_spec()
            expected_discounts = {
                agent: convert_np_type(discount_spec[agent].dtype, 1)
                for agent in wrapped_env.agents
            }

            assert Helpers.compare_dicts(
                dm_env_timestep.reward,
                expected_rewards,
            ), "Failed to reset reward."
            assert Helpers.compare_dicts(
                dm_env_timestep.discount,
                expected_discounts,
            ), "Failed to reset discount."

        elif env_spec.env_type == EnvType.Sequential:
            for agent in wrapped_env.agents:
                rewards_spec = wrapped_env.reward_spec()
                expected_reward = convert_np_type(rewards_spec[agent].dtype, 0)

                discount_spec = wrapped_env.discount_spec()
                expected_discount = convert_np_type(discount_spec[agent].dtype,
                                                    1)

                assert dm_env_timestep.reward == expected_reward and type(
                    dm_env_timestep.reward) == type(
                        expected_reward), "Failed to reset reward."
                assert dm_env_timestep.discount == expected_discount and type(
                    dm_env_timestep.discount) == type(
                        expected_discount), "Failed to reset discount."
Code Example #4
File: specs.py  Project: NetColby/DNRL
    def _make_ma_environment_spec(
            self,
            environment: dm_env.Environment) -> Dict[str, EnvironmentSpec]:
        """Returns an `EnvironmentSpec` describing values used by an
        environment for each agent."""
        specs = {}
        observation_specs = environment.observation_spec()
        action_specs = environment.action_spec()
        reward_specs = environment.reward_spec()
        discount_specs = environment.discount_spec()
        self.extra_specs = environment.extra_spec()
        for agent in environment.possible_agents:
            specs[agent] = EnvironmentSpec(
                observations=observation_specs[agent],
                actions=action_specs[agent],
                rewards=reward_specs[agent],
                discounts=discount_specs[agent],
            )
        return specs
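
A hedged usage sketch: `spec_factory` and the surrounding class are hypothetical names; the method only assumes `environment` exposes `possible_agents` and per-agent spec dictionaries:

# Usage sketch (`spec_factory` is a hypothetical instance of the class defining the
# method above; `environment` is a multi-agent environment with per-agent specs).
ma_specs = spec_factory._make_ma_environment_spec(environment)
for agent, agent_spec in ma_specs.items():
    print(agent, agent_spec.actions, agent_spec.rewards)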
Code Example #5
from typing import Callable, Iterator

import dm_env
import tensorflow as tf

from acme import types  # assumed source of the Transition namedtuple used below


def transition_iterator(
    environment: dm_env.Environment
) -> Callable[[int], Iterator[types.Transition]]:
    """Fake dataset of Reverb N-step transition samples.

    Args:
      environment: Used to create a fake transition by looking at the
        observation, action, discount and reward specs.

    Returns:
      A callable that given a batch_size returns an iterator with demonstrations.
    """

    observation = environment.observation_spec().generate_value()
    action = environment.action_spec().generate_value()
    reward = environment.reward_spec().generate_value()
    discount = environment.discount_spec().generate_value()
    data = types.Transition(observation, action, reward, discount, observation)

    dataset = tf.data.Dataset.from_tensors(data).repeat()

    return lambda batch_size: dataset.batch(batch_size).as_numpy_iterator()
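
A quick usage sketch for the returned callable, again assuming an existing dm_env.Environment named `environment`:

# Usage sketch: the returned callable maps a batch size to a numpy iterator.
make_iterator = transition_iterator(environment)
iterator = make_iterator(2)          # batch size of 2
batch = next(iterator)               # types.Transition of numpy arrays
assert batch.reward.shape[0] == 2    # leading dimension equals the batch size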
Code Example #6
import dm_env
import reverb
import tensorflow as tf
import tree

from acme import types  # assumed source of the Transition namedtuple used below


def transition_dataset(environment: dm_env.Environment) -> tf.data.Dataset:
    """Fake dataset of Reverb N-step transition samples.

    Args:
      environment: Used to create a fake transition by looking at the
        observation, action, discount and reward specs.

    Returns:
      tf.data.Dataset that produces the same fake N-step transition
      ReplaySample object indefinitely.
    """

    observation = environment.observation_spec().generate_value()
    action = environment.action_spec().generate_value()
    reward = environment.reward_spec().generate_value()
    discount = environment.discount_spec().generate_value()
    data = types.Transition(observation, action, reward, discount, observation)

    info = tree.map_structure(
        lambda tf_dtype: tf.ones([], tf_dtype.as_numpy_dtype),
        reverb.SampleInfo.tf_dtypes())
    sample = reverb.ReplaySample(info=info, data=data)

    return tf.data.Dataset.from_tensors(sample).repeat()