示例#1
0
    def test_get_from_first_device(self):
        """Checks that `utils.get_from_first_device` yields the first shard.

        Two arrays are sharded across all local devices, one row per device.
        The helper is then expected to return row zero of each array, either
        as device arrays (`as_numpy=False`) or as numpy arrays
        (`as_numpy=True`).
        """
        num_devices = jax.local_device_count()
        devices = jax.local_devices()

        # Size the arrays from the device count instead of hard-coding 16 and
        # 8, which only reshape cleanly when there are exactly 4 local devices.
        # Row zero is still arange(4) / arange(2) for any device count.
        sharded = {
            'a':
            jax.device_put_sharded(
                list(jnp.arange(num_devices * 4).reshape([num_devices, 4])),
                devices),
            'b':
            jax.device_put_sharded(
                list(jnp.arange(num_devices * 2).reshape([num_devices, 2])),
                devices,
            ),
        }

        want = {
            'a': jnp.arange(4),
            'b': jnp.arange(2),
        }

        # Get zeroth device content as DeviceArray.
        device_arrays = utils.get_from_first_device(sharded, as_numpy=False)
        jax.tree_map(lambda x: self.assertIsInstance(x, jax.xla.DeviceArray),
                     device_arrays)
        jax.tree_map(np.testing.assert_array_equal, want, device_arrays)

        # Get the zeroth device content as numpy arrays.
        numpy_arrays = utils.get_from_first_device(sharded, as_numpy=True)
        jax.tree_map(lambda x: self.assertIsInstance(x, np.ndarray),
                     numpy_arrays)
        jax.tree_map(np.testing.assert_array_equal, want, numpy_arrays)
示例#2
0
    def step(self):
        """Runs one SGD step, queues priority updates and writes logs.

        The prefetched item exposes a `host` part and a `device` part. Per the
        original author's note, the `split_sample` function given to
        `utils.sharded_prefetch` decides which parts stay on the host and which
        are prefetched on-device; here `host` holds only the replay keys and
        `device` holds the full prefetched sample.
        """
        split = next(self._prefetched_iterator)
        replay_keys = split.host
        device_samples = split.device

        # Time the SGD update together with the metric transfer below.
        start = time.time()
        new_state, priorities, metrics = self._sgd_step(
            self._state, device_samples)
        self._state = new_state

        # Only the first replica's metrics are reported.
        metrics = utils.get_from_first_device(metrics)

        # Bump the step counter and record the elapsed time.
        elapsed = time.time() - start
        counts = self._counter.increment(steps=1, time_elapsed=elapsed)

        # Priorities are pushed back only when a replay client is configured.
        if self._replay_client:
            update = (replay_keys, priorities)
            self._async_priority_updater.put(update)

        # Attempt to write the merged metrics and counts.
        log_entry = {**metrics, **counts}
        self._logger.write(log_entry)
示例#3
0
  def step(self):
    """Performs a single SGD update and writes the resulting logs."""
    batch = next(self._prefetched_iterator)

    # Time the SGD update together with the result transfer below.
    start = time.time()
    new_state, results = self._sgd_step(self._state, batch)
    self._state = new_state

    # Only the first replica's results are reported.
    results = utils.get_from_first_device(results)

    # Bump the step counter and record the elapsed time.
    elapsed = time.time() - start
    counts = self._counter.increment(steps=1, time_elapsed=elapsed)

    # Attempt to write the merged results and counts.
    log_entry = {**results, **counts}
    self._logger.write(log_entry)
示例#4
0
文件: learning.py 项目: deepmind/acme
    def step(self):
        """Runs one SGD step on a batch of transitions and writes logs.

        Wall time is measured between consecutive calls: the first call (when
        `self._timestamp` is falsy) reports an elapsed time of 0.
        """
        # Pull the next batch and apply the update.
        batch = next(self._prefetching_iterator)
        new_state, metrics = self._sgd_step(self._state, batch)
        self._state = new_state
        metrics = utils.get_from_first_device(metrics)

        # Work out how long it has been since the previous step.
        now = time.time()
        if self._timestamp:
            elapsed_time = now - self._timestamp
        else:
            elapsed_time = 0
        self._timestamp = now

        # Bump the step counter and accumulate wall time.
        counts = self._counter.increment(steps=1, walltime=elapsed_time)

        # Attempt to write the merged metrics and counts.
        log_entry = {**metrics, **counts}
        self._logger.write(log_entry)
示例#5
0
    def step(self):
        """Performs a single SGD update and writes the resulting logs."""
        batch = next(self._iterator)

        # Time the SGD update together with the result transfer below.
        start = time.time()
        new_state, results = self._sgd_step(self._state, batch)
        self._state = new_state

        # Only the first replica's results are reported.
        # NOTE: per the original author, this makes the logged values a noisy
        # estimate because they are not pmean-ed over all devices.
        results = utils.get_from_first_device(results)

        # Bump the step counter and record the elapsed time.
        elapsed = time.time() - start
        counts = self._counter.increment(steps=1, time_elapsed=elapsed)

        # Attempt to write the merged results and counts.
        log_entry = {**results, **counts}
        self._logger.write(log_entry)
示例#6
0
 def get_variables(self, names: Sequence[str]) -> List[networks_lib.Params]:
   """Returns a one-element list with the first replica's parameters.

   `names` is not consulted; the same parameters are returned regardless.
   The parameters are requested with `as_numpy=False`.
   """
   first_replica_params = utils.get_from_first_device(
       self._state.params, as_numpy=False)
   return [first_replica_params]
示例#7
0
 def get_variables(self, names: List[str]) -> List[networks_lib.Params]:
     """Returns a one-element list with the first replica's parameters.

     `names` is not consulted; the same parameters are returned regardless.
     """
     first_replica_params = utils.get_from_first_device(self._state.params)
     return [first_replica_params]
示例#8
0
文件: learning.py 项目: deepmind/acme
 def get_variables(self, names: List[str]) -> List[networks_lib.Params]:
     """Looks up the requested variable collections by name.

     Only 'policy' is available; it maps to the first replica's policy
     parameters. Results are returned in the order of `names`, and any other
     name raises KeyError from the dictionary lookup.
     """
     # Fetched up front, even if `names` turns out to be empty.
     policy_params = utils.get_from_first_device(self._state.policy_params)
     available = {'policy': policy_params}

     selected = []
     for name in names:
         selected.append(available[name])
     return selected
示例#9
0
 def test_get_from_first_device_fails_if_sda_not_provided(self):
     """Expects ValueError when the tree holds a plain numpy array."""
     with self.assertRaises(ValueError):
         # One entry per local device, but built with numpy on the host.
         host_tree = {'a': np.arange(jax.local_device_count())}
         utils.get_from_first_device(host_tree)