예제 #1
0
    def test_make_dataset_nested_specs(self):
        """Checks make_dataset against an environment spec with dict observations."""
        nested_observations = {
            'obs_1': specs.Array((3, 64, 64), 'uint8'),
            'obs_2': specs.Array((10, ), 'int32'),
        }
        bounded_action = specs.BoundedArray((), 'float32',
                                            minimum=-1., maximum=1.)
        scalar_reward = specs.Array((), 'float32')
        bounded_discount = specs.BoundedArray((), 'float32',
                                              minimum=0., maximum=1.)
        environment_spec = specs.EnvironmentSpec(
            observations=nested_observations,
            actions=bounded_action,
            rewards=scalar_reward,
            discounts=bounded_discount)

        dataset = reverb_dataset.make_dataset(
            client=self.tf_client, environment_spec=environment_spec)

        # The dataset's data field is compared against the environment spec
        # flattened into a tuple of its fields.
        specs_match = _check_specs(tuple(environment_spec),
                                   dataset.element_spec.data)
        self.assertTrue(specs_match)
예제 #2
0
파일: test_utils.py 프로젝트: weileze/acme
def _numeric_to_spec(x: Union[float, int, np.ndarray]):
  """Builds a `specs.Array` describing the shape and dtype of a numeric value.

  Args:
    x: A NumPy array, a NumPy scalar (e.g. `np.float32(1.)`), or a Python
      `float`/`int` (`bool` is accepted too, as an `int` subclass).

  Returns:
    A `specs.Array` with the shape and dtype of `x`. Python scalars get shape
    `()` and their Python type as dtype.

  Raises:
    ValueError: If `x` is not one of the supported numeric types.
  """
  # NumPy scalars such as np.float32 / np.int32 / np.bool_ are not subclasses
  # of Python float/int, so they used to fall through to the ValueError.
  # Arrays and NumPy scalars both expose `.shape` and `.dtype`.
  if isinstance(x, (np.ndarray, np.generic)):
    return specs.Array(shape=x.shape, dtype=x.dtype)
  elif isinstance(x, (float, int)):
    return specs.Array(shape=(), dtype=type(x))
  else:
    raise ValueError(f'Unsupported numeric: {type(x)}')
예제 #3
0
  def test_pmap_update_nested(self):
    """Nested stats updated under pmap normalize data to zero mean, unit std."""
    n_devices = jax.local_device_count()
    state = running_statistics.init_state({
        'a': specs.Array((5,), jnp.float32),
        'b': specs.Array((2,), jnp.float32)
    })

    # The leading axis is the device axis; each device sees a (3, feature) batch.
    a_values = jnp.arange(15 * n_devices, dtype=jnp.float32)
    b_values = jnp.arange(6 * n_devices, dtype=jnp.float32)
    x = {
        'a': a_values.reshape(n_devices, 3, 5),
        'b': b_values.reshape(n_devices, 3, 2),
    }

    state = jax.device_put_replicated(state, jax.local_devices())
    axis_name = 'i'
    pmapped_update = jax.pmap(
        functools.partial(update_and_validate, pmap_axis_name=axis_name),
        axis_name)
    # Feed the same batch twice.
    for _ in range(2):
      state = pmapped_update(state, x)
    normalized = jax.pmap(running_statistics.normalize)(x, state)

    means = tree.map_structure(lambda leaf: jnp.mean(leaf, axis=(0, 1)),
                               normalized)
    stds = tree.map_structure(lambda leaf: jnp.std(leaf, axis=(0, 1)),
                              normalized)
    for leaf in tree.flatten(means):
      self.assert_allclose(leaf, jnp.zeros_like(leaf))
    for leaf in tree.flatten(stds):
      self.assert_allclose(leaf, jnp.ones_like(leaf))
예제 #4
0
  def test_nested_normalize(self):
    """Stats accumulated over several nested batches normalize the last one."""
    state = running_statistics.init_state({
        'a': specs.Array((5,), jnp.float32),
        'b': specs.Array((2,), jnp.float32)
    })

    a_first = jnp.arange(20, dtype=jnp.float32).reshape(2, 2, 5)
    b_first = jnp.arange(8, dtype=jnp.float32).reshape(2, 2, 2)
    batches = [
        {'a': a_first, 'b': b_first},
        # Second batch continues the ranges where the first one stopped.
        {'a': a_first + 20, 'b': b_first + 8},
        {
            'a': jnp.arange(40, dtype=jnp.float32).reshape(4, 2, 5),
            'b': jnp.arange(16, dtype=jnp.float32).reshape(4, 2, 2)
        },
    ]

    for batch in batches:
      state = update_and_validate(state, batch)
    normalized = running_statistics.normalize(batches[-1], state)

    means = tree.map_structure(lambda leaf: jnp.mean(leaf, axis=(0, 1)),
                               normalized)
    stds = tree.map_structure(lambda leaf: jnp.std(leaf, axis=(0, 1)),
                              normalized)
    for leaf in tree.flatten(means):
      self.assert_allclose(leaf, jnp.zeros_like(leaf))
    for leaf in tree.flatten(stds):
      self.assert_allclose(leaf, jnp.ones_like(leaf))
예제 #5
0
    def test_make_dataset_nested_specs(self):
        """Dataset element specs form a Step built from nested env specs."""
        observations = {
            'obs_1': specs.Array((3, 64, 64), 'uint8'),
            'obs_2': specs.Array((10, ), 'int32'),
        }
        environment_spec = specs.EnvironmentSpec(
            observations=observations,
            actions=specs.BoundedArray((), 'float32', minimum=-1.,
                                       maximum=1.),
            rewards=specs.Array((), 'float32'),
            discounts=specs.BoundedArray((), 'float32', minimum=0.,
                                         maximum=1.))

        dataset = reverb_dataset.make_dataset(
            client=self.tf_client, environment_spec=environment_spec)

        # Expected item: the environment fields plus a scalar boolean
        # start_of_episode flag and empty extras.
        expected_spec = adders.Step(
            observation=environment_spec.observations,
            action=environment_spec.actions,
            reward=environment_spec.rewards,
            discount=environment_spec.discounts,
            start_of_episode=specs.Array(shape=(), dtype=bool),
            extras=())
        self.assertTrue(_check_specs(expected_spec, dataset.element_spec.data))
예제 #6
0
def _make_fake_env() -> dm_env.Environment:
  """Returns a fake environment with episodes of length 10.

  Observations are float32 arrays of shape (10, 5), actions are discrete with
  3 values, rewards are float32 scalars and discounts are float32 scalars
  bounded to [0, 1].
  """
  observation_spec = specs.Array(shape=(10, 5), dtype=np.float32)
  action_spec = specs.DiscreteArray(num_values=3)
  reward_spec = specs.Array(shape=(), dtype=np.float32)
  discount_spec = specs.BoundedArray(
      shape=(), dtype=np.float32, minimum=0., maximum=1.)
  environment_spec = specs.EnvironmentSpec(
      observations=observation_spec,
      actions=action_spec,
      rewards=reward_spec,
      discounts=discount_spec,
  )
  return fakes.Environment(environment_spec, episode_length=10)
예제 #7
0
 def observation_spec(self) -> types.Observation:
     """Builds per-agent OLT observation specs from the wrapped environment.

     Returns:
       A dict mapping each agent in `self.possible_agents` to a `types.OLT`.
       Its `observation` and `legal_actions` are float32 `specs.Array`s whose
       shapes come from the wrapped spec's "info_state" and "legal_actions"
       entries, and its `terminal` is a float32 array of shape (1,).
     """
     # The wrapped environment's spec does not depend on the agent, so query
     # it once instead of once per agent.
     env_spec = self._environment.observation_spec()
     observation_specs = {}
     for agent in self.possible_agents:
         observation_specs[agent] = types.OLT(
             observation=specs.Array(env_spec["info_state"], np.float32),
             legal_actions=specs.Array(env_spec["legal_actions"], np.float32),
             terminal=specs.Array((1, ), np.float32),
         )
     return observation_specs
예제 #8
0
  def test_different_structure_normalize(self):
    """A dict update against a TestNestedSpec-shaped state raises TypeError."""
    state = running_statistics.init_state(
        TestNestedSpec(a=specs.Array((5,), jnp.float32),
                       b=specs.Array((2,), jnp.float32)))

    # Same leaves as the spec, but held in a plain dict, not TestNestedSpec.
    mismatched = dict(
        a=jnp.arange(20, dtype=jnp.float32).reshape(2, 2, 5),
        b=jnp.arange(8, dtype=jnp.float32).reshape(2, 2, 2))

    with self.assertRaises(TypeError):
      update_and_validate(state, mismatched)
예제 #9
0
    def test_normalize_config(self):
        """Only the paths selected by NestStatisticsConfig are normalized."""
        x = jnp.arange(200, dtype=jnp.float32).reshape(20, 2, 5)
        y = jnp.arange(160, dtype=jnp.float32).reshape(20, 2, 4)
        z = {'a': x, 'b': y}
        # Five chunks of 4 rows each, fed to the statistics one at a time.
        z_split = [{
            'a': xx,
            'b': yy
        } for xx, yy in zip(jnp.split(x, 5, axis=0), jnp.split(y, 5, axis=0))]

        update = jax.jit(running_statistics.update,
                         static_argnames=('config', ))

        # Track statistics for the 'a' path only.
        config = running_statistics.NestStatisticsConfig((('a', ), ))
        state = running_statistics.init_state({
            'a': specs.Array((5, ), jnp.float32),
            'b': specs.Array((4, ), jnp.float32),
        })
        # The first iteration also tests initialization from the first element.
        for chunk in z_split:
            state = update(state, chunk, config=config)

        normalized = jax.jit(running_statistics.normalize)(z, state)

        for key, value in normalized.items():
            if key != 'a':
                # Everything outside the config must pass through unchanged.
                assert key == 'b'
                np.testing.assert_array_equal(
                    value,
                    z[key],
                    err_msg=f'z:{z[key]} normalized:{normalized[key]}')
                continue
            mean = jnp.mean(value, axis=(0, 1))
            std = jnp.std(value, axis=(0, 1))
            self.assert_allclose(
                mean,
                jnp.zeros_like(mean),
                err_msg=f'key:{key} mean:{mean} normalized:{normalized[key]}')
            self.assert_allclose(
                std,
                jnp.ones_like(std),
                err_msg=f'key:{key} std:{std} normalized:{normalized[key]}')
예제 #10
0
  def test_multiple_inputs_and_outputs(self):
    """create_variables infers output specs for a 3-input, 2-output function."""

    def transformation(aa, bb, cc):
      all_three = tf.concat([aa, bb, cc], axis=-1)
      last_two = tf.concat([bb, cc], axis=-1)
      return all_three, last_two

    model = tf2_utils.to_sonnet_module(transformation)
    dtype = np.float32
    input_spec = [specs.Array(shape=(size,), dtype=dtype) for size in (2, 3, 4)]
    # Concatenating on the last axis: 2 + 3 + 4 = 9 and 3 + 4 = 7.
    expected_output_spec = (tf.TensorSpec(shape=(9,), dtype=dtype),
                            tf.TensorSpec(shape=(7,), dtype=dtype))

    output_spec = tf2_utils.create_variables(model, input_spec)
    self.assertEqual(model.variables, ())
    self.assertEqual(output_spec, expected_output_spec)
예제 #11
0
  def test_make_dataset_with_variable_length_instances(self):
    """Dataset with variable length instances should have shapes with None."""
    # With convert_zero_size_to_none=True, the zero-sized leading observation
    # dimension is expected to come back as None.
    observation_spec = specs.Array((0, 64, 64), 'uint8')
    environment_spec = specs.EnvironmentSpec(
        observations=observation_spec,
        actions=specs.BoundedArray((), 'float32', minimum=-1., maximum=1.),
        rewards=specs.Array((), 'float32'),
        discounts=specs.BoundedArray((), 'float32', minimum=0., maximum=1.))

    dataset = reverb_dataset.make_dataset(
        server_address=self.server_address,
        environment_spec=environment_spec,
        convert_zero_size_to_none=True)

    observation_shape = dataset.element_spec.data[0].shape.as_list()
    self.assertSequenceEqual(observation_shape, [None, 64, 64])
예제 #12
0
 def test_none_output(self):
   """A module that returns None gives a None output spec and no variables."""
   model = tf2_utils.to_sonnet_module(lambda x: None)
   output_spec = tf2_utils.create_variables(
       model, [specs.Array(shape=(10,), dtype=np.float32)])
   self.assertEqual(model.variables, ())
   self.assertEqual(output_spec, None)
예제 #13
0
 def test_scalar_output(self):
   """Summing a (10,) input yields a scalar float32 spec and no variables."""
   model = tf2_utils.to_sonnet_module(tf.reduce_sum)
   output_spec = tf2_utils.create_variables(
       model, [specs.Array(shape=(10,), dtype=np.float32)])
   self.assertEqual(model.variables, ())
   self.assertEqual(output_spec, tf.TensorSpec(shape=(), dtype=tf.float32))
예제 #14
0
  def test_rnn_snapshot(self):
    """Test that snapshotter correctly calls saves/restores snapshots on RNNs."""
    lstm = snt.LSTM(10)
    spec = specs.Array([10], dtype=np.float32)
    tf2_utils.create_variables(lstm, [spec])

    # Adding postprocessing around the LSTM, without rerunning
    # create_variables, should still produce a restorable snapshot.
    candidates = (lstm, snt.DeepRNN([lstm, lambda x: x]))

    for original in candidates:
      # Save the network under test.
      snapshotter = tf2_savers.Snapshotter({'net': original},
                                           directory=self.get_tempdir())
      snapshotter.save()

      # Restore it from disk and run both on the same zero input.
      restored = tf.saved_model.load(
          os.path.join(snapshotter.directory, 'net'))
      inputs = tf2_utils.add_batch_dim(tf2_utils.zeros_like(spec))

      with tf.GradientTape() as tape:
        out_orig, state_orig = original(inputs, original.initial_state(1))
        loss_orig = tf.math.reduce_sum(out_orig)
        grads_orig = tape.gradient(loss_orig, original.trainable_variables)

      with tf.GradientTape() as tape:
        out_restored, state_restored = restored(inputs,
                                                restored.initial_state(1))
        loss_restored = tf.math.reduce_sum(out_restored)
        grads_restored = tape.gradient(loss_restored,
                                       restored.trainable_variables)

      assert np.allclose(out_orig, out_restored)
      assert np.allclose(tree.flatten(state_orig), tree.flatten(state_restored))
      assert all(
          tree.map_structure(np.allclose, list(grads_orig),
                             list(grads_restored)))
예제 #15
0
파일: fakes.py 프로젝트: hyunjay/acme
    def __init__(self,
                 *,
                 num_actions: int = 1,
                 num_observations: int = 1,
                 action_dtype=np.int32,
                 obs_dtype=np.int32,
                 obs_shape: Sequence[int] = (),
                 discount_spec: Optional[types.NestedSpec] = None,
                 reward_spec: Optional[types.NestedSpec] = None,
                 **kwargs):
        """Initialize the environment.

        Args:
          num_actions: Number of values of the discrete action spec.
          num_observations: Observations are bounded to
            [0, num_observations - 1].
          action_dtype: dtype of the discrete action spec.
          obs_dtype: Observation dtype; also called to build the observation
            bounds, so it must be a callable scalar type such as `np.int32`.
          obs_shape: Shape of a single observation.
          discount_spec: Discount spec; defaults to a float32 scalar in [0, 1].
          reward_spec: Reward spec; defaults to a float32 scalar.
          **kwargs: Forwarded to the parent constructor.
        """
        reward_spec = (specs.Array((), np.float32)
                       if reward_spec is None else reward_spec)
        discount_spec = (specs.BoundedArray((), np.float32, 0.0, 1.0)
                         if discount_spec is None else discount_spec)

        action_spec = specs.DiscreteArray(num_actions, dtype=action_dtype)
        lowest = obs_dtype(0)
        highest = obs_dtype(num_observations - 1)
        observation_spec = specs.BoundedArray(shape=obs_shape,
                                              dtype=obs_dtype,
                                              minimum=lowest,
                                              maximum=highest)

        environment_spec = specs.EnvironmentSpec(observations=observation_spec,
                                                 actions=action_spec,
                                                 rewards=reward_spec,
                                                 discounts=discount_spec)
        super().__init__(spec=environment_spec, **kwargs)
예제 #16
0
파일: base.py 프로젝트: wzyxwqx/acme
    def signature(cls,
                  environment_spec: specs.EnvironmentSpec,
                  extras_spec: types.NestedSpec = ()):
        """Helper that generates a signature for a Reverb table.

        Signatures are useful for validating data types and shapes; see
        Reverb's documentation for details on how they are used.

        Args:
          environment_spec: A `specs.EnvironmentSpec` whose fields are nested
            structures with leaf nodes that have `.shape` and `.dtype`
            attributes. This should come from the environment that will be
            used to generate the data inserted into the Reverb table.
          extras_spec: A nested structure with leaf nodes that have `.shape`
            and `.dtype` attributes. Its structure (and shapes/dtypes) must
            match the `extras` passed into `ReverbAdder.add`.

        Returns:
          A `Step` whose leaf nodes are `tf.TensorSpec` objects.
        """
        step_fields = dict(
            observation=environment_spec.observations,
            action=environment_spec.actions,
            reward=environment_spec.rewards,
            discount=environment_spec.discounts,
            # Scalar boolean flag marking the first step of an episode.
            start_of_episode=specs.Array(shape=(), dtype=bool),
            extras=extras_spec,
        )
        spec_step = Step(**step_fields)
        return tree.map_structure_with_path(spec_like_to_tensor_spec, spec_step)
예제 #17
0
    def test_make_dataset_with_sequence_length_and_batch_size(self):
        """Batched, sequenced datasets prepend (batch, time) to every spec."""
        sequence_length = 6
        batch_size = 4
        environment_spec = specs.make_environment_spec(
            fakes.ContinuousEnvironment())
        dataset = reverb_dataset.make_dataset(
            client=self.tf_client,
            environment_spec=environment_spec,
            batch_size=batch_size,
            sequence_length=sequence_length)

        leading_dims = (batch_size, sequence_length)

        def make_tensor_spec(spec):
            return tf.TensorSpec(shape=leading_dims + spec.shape,
                                 dtype=spec.dtype)

        batched = tree.map_structure(make_tensor_spec, environment_spec)
        expected_spec = adders.Step(
            observation=batched.observations,
            action=batched.actions,
            reward=batched.rewards,
            discount=batched.discounts,
            start_of_episode=specs.Array(shape=leading_dims, dtype=bool),
            extras=())

        self.assertTrue(_check_specs(expected_spec, dataset.element_spec.data))
예제 #18
0
    def test_snapshot_distribution(self):
        """Test that snapshotter correctly calls saves/restores snapshots."""
        # The network's output is a distribution (it exposes mean/variance).
        original = snt.Sequential([
            networks.LayerNormMLP([10, 10]),
            networks.MultivariateNormalDiagHead(1)
        ])
        spec = specs.Array([10], dtype=np.float32)
        tf2_utils.create_variables(original, [spec])

        # Save the network, then reload it from disk.
        snapshotter = tf2_savers.Snapshotter({'net': original},
                                             directory=self.get_tempdir())
        snapshotter.save()
        restored = tf.saved_model.load(
            os.path.join(snapshotter.directory, 'net'))

        inputs = tf2_utils.add_batch_dim(tf2_utils.zeros_like(spec))

        # Gradients of the same loss must agree for both networks.
        all_grads = []
        for net in (original, restored):
            with tf.GradientTape() as tape:
                dist = net(inputs)
                loss = tf.math.reduce_sum(dist.mean() + dist.variance())
                all_grads.append(tape.gradient(loss, net.trainable_variables))

        grads_original, grads_restored = all_grads
        assert all(
            tree.map_structure(np.allclose, list(grads_original),
                               list(grads_restored)))
예제 #19
0
    def __init__(
        self,
        rank: int,
        config_factory: List[Callable[[], WobConfig]],
        keep_pristine: bool,
        verbose_level: Optional[int] = 0,
    ):
        """Sets up the environment, selects task 0 and derives observation specs.

        Args:
          rank: Stored as `self.rank`.
          config_factory: Zero-argument callables, one per task config; each
            element must be callable.
          keep_pristine: Stored as `self._keep_pristine`.
          verbose_level: Stored as `self._verbose_level`.
        """
        # Environment configuration.
        self.rank = rank

        # Task info. For a multi-task env, this is set in reset_task().
        self._configs = config_factory
        self._keep_pristine = keep_pristine
        assert all(callable(make_config) for make_config in self._configs)

        # Sample at least one valid task config (default seed = 0).
        self.reset_task(task_index=0)

        # Episode bookkeeping.
        self.env_done = False
        self.step_count = 0
        # TODO: Do we need action mask?
        self.action_mask = False

        self._verbose_level = verbose_level

        # Derive one observation spec per key from an initial reset() result.
        first_observation = self.reset()
        observation_specs = {}
        for key, value in first_observation.items():
            observation_specs[key] = specs.Array(shape=value.shape,
                                                 dtype=value.dtype,
                                                 name=key)
        self._observation_specs = observation_specs
예제 #20
0
  def test_init_normalize(self):
    """Normalizing with a freshly initialized state leaves the data unchanged."""
    x = jnp.arange(200, dtype=jnp.float32).reshape(20, 2, 5)
    fresh_state = running_statistics.init_state(
        specs.Array((5,), jnp.float32))
    self.assert_allclose(running_statistics.normalize(x, fresh_state), x)
예제 #21
0
파일: mppi_test.py 프로젝트: deepmind/acme
 def setUp(self):
     super().setUp()
     self.state_dims = 8
     self.action_dims = 4
     # Each parameter group is a length-3 vector of ones.
     self.params = {
         component: jnp.ones((3, ))
         for component in ('world', 'policy', 'value')
     }
     observation_spec = specs.Array(shape=(self.state_dims, ), dtype=float)
     action_spec = specs.Array(shape=(self.action_dims, ), dtype=float)
     reward_spec = specs.Array(shape=(1, ), dtype=float, name='reward')
     discount_spec = specs.BoundedArray(shape=(),
                                        dtype=float,
                                        minimum=0.,
                                        maximum=1.,
                                        name='discount')
     self.env_spec = specs.EnvironmentSpec(observations=observation_spec,
                                           actions=action_spec,
                                           rewards=reward_spec,
                                           discounts=discount_spec)
예제 #22
0
 def observation_spec(self) -> types.NestedSpec:
   """Returns the wrapped observation spec plus a 'trial_remaining_steps' entry.

   Returns:
     A shallow copy of the wrapped environment's observation spec mapping,
     extended with a scalar int32 `specs.Array` named 'trial_remaining_steps'.
   """
   import copy

   # Copy before adding the key. Calling `.update()` on the mapping returned by
   # the wrapped environment would otherwise mutate that environment's own
   # spec in place whenever it hands out a cached/shared dict.
   observation_spec = copy.copy(self._environment.observation_spec())
   observation_spec.update({
     'trial_remaining_steps': specs.Array(
       shape=(),
       dtype=np.int32,
       name='trial_remaining_steps')
   })
   return observation_spec
예제 #23
0
 def test_feedforward(self, recurrent: bool):
   """Linear(42) on a (10,) input creates variables of shape [42], [10, 42]."""
   model = snt.Linear(42)
   if recurrent:
     # Wrapping in DeepRNN must not change the variables that get created.
     model = snt.DeepRNN([model])
   tf2_utils.create_variables(
       model, [specs.Array(shape=(10,), dtype=np.float32)])
   shapes = [variable.shape.as_list() for variable in model.variables]
   self.assertSequenceEqual(shapes, [[42], [10, 42]])
예제 #24
0
  def test_int_not_normalized(self):
    """Integer data passes through update + normalize unchanged."""
    x = jnp.arange(5, dtype=jnp.int32)
    initial_state = running_statistics.init_state(specs.Array((), jnp.int32))
    updated_state = update_and_validate(initial_state, x)
    np.testing.assert_array_equal(
        running_statistics.normalize(x, updated_state), x)
예제 #25
0
def _broadcast_specs(*args: acme_specs.Array) -> acme_specs.Array:
    """Like np.broadcast, but for specs.Array.

    Args:
      *args: one or more specs.Array instances.

    Returns:
      A specs.Array with the broadcasted shape and the promoted dtype
      (`np.result_type`) of the specs in *args.

    Raises:
      ValueError: If the shapes cannot be broadcast together, or if no specs
        are given (np.result_type needs at least one dtype).
    """
    # Broadcast the shapes directly instead of materializing a full value per
    # spec via generate_value(). The resulting shape is the same, nothing is
    # allocated, and unlike np.broadcast, np.broadcast_shapes has no cap on
    # the number of arguments.
    shape = np.broadcast_shapes(*(a.shape for a in args))
    dtype = np.result_type(*(a.dtype for a in args))
    return acme_specs.Array(shape=shape, dtype=dtype)
예제 #26
0
  def test_validation(self):
    """update_and_validate raises on inputs that do not fit the (1, 2, 3) spec."""
    state = running_statistics.init_state(specs.Array((1, 2, 3), jnp.float32))

    rejected_inputs = (
        jnp.arange(12, dtype=jnp.float32).reshape(2, 2, 3),
        jnp.arange(3, dtype=jnp.float32).reshape(1, 1, 3),
    )
    for x in rejected_inputs:
      with self.assertRaises(AssertionError):
        update_and_validate(state, x)
예제 #27
0
  def test_output_spec_feedforward(self, recurrent: bool):
    """Output spec of Linear(42), optionally wrapped in a DeepRNN."""
    linear_spec = tf.TensorSpec(shape=(42,), dtype=tf.float32)
    model = snt.Linear(42)
    if recurrent:
      # DeepRNN returns an (output, state) pair; the state here is empty.
      model = snt.DeepRNN([model])
      expected_spec = (linear_spec, ())
    else:
      expected_spec = linear_spec

    output_spec = tf2_utils.create_variables(
        model, [specs.Array(shape=(10,), dtype=np.float32)])
    self.assertEqual(output_spec, expected_spec)
예제 #28
0
 def test_multiple_outputs(self):
   """PolicyValueHead(42) returns a (42,) spec and a scalar spec."""
   model = PolicyValueHead(42)
   output_spec = tf2_utils.create_variables(
       model, [specs.Array(shape=(10,), dtype=np.float32)])
   shapes = [variable.shape.as_list() for variable in model.variables]
   self.assertSequenceEqual(shapes, [[42], [10, 42], [1], [10, 1]])
   self.assertSequenceEqual(
       output_spec,
       (tf.TensorSpec(shape=(42,), dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.float32)))
예제 #29
0
 def observation_spec(self) -> types.Observation:
     """Returns a dict mapping each agent to its OLT observation spec.

     `observation` and `legal_actions` are converted from the "observation"
     and "action_mask" entries of `self.observation_space`. `terminal` is a
     float32 array of shape (1,).
     """
     specs_by_agent = {}
     for agent in self._possible_agents:
         specs_by_agent[agent] = types.OLT(
             observation=_convert_to_spec(
                 self.observation_space["observation"]),
             legal_actions=_convert_to_spec(
                 self.observation_space["action_mask"]),
             terminal=specs.Array((1, ), np.float32),
         )
     return specs_by_agent
예제 #30
0
 def observation_spec(self) -> types.Observation:
     """Returns a dict mapping each agent to its OLT observation spec.

     Each agent's `observation` comes from its entry in the wrapped env's
     `observation_spaces`, `legal_actions` from its `action_spaces` entry,
     and `terminal` is a float32 array of shape (1,).
     """
     env = self._environment
     return {
         agent: types.OLT(
             observation=_convert_to_spec(env.observation_spaces[agent]),
             legal_actions=_convert_to_spec(env.action_spaces[agent]),
             terminal=specs.Array((1, ), np.float32),
         )
         for agent in self.possible_agents
     }