def test_get_from_first_device(self):
  """Checks that get_from_first_device returns each leaf's zeroth shard.

  The sharded inputs are sized from the actual local device count, so the
  test runs on any number of local devices. With hard-coded lengths 16 and 8,
  the reshape to [num_devices, 4] / [num_devices, 2] only succeeded with
  exactly four devices.
  """
  num_devices = jax.local_device_count()
  sharded = {
      'a': jax.device_put_sharded(
          list(jnp.arange(4 * num_devices).reshape([num_devices, 4])),
          jax.local_devices()),
      'b': jax.device_put_sharded(
          list(jnp.arange(2 * num_devices).reshape([num_devices, 2])),
          jax.local_devices(),
      ),
  }
  # Row 0 of each reshaped range is what is placed on the first device.
  want = {
      'a': jnp.arange(4),
      'b': jnp.arange(2),
  }

  # Get zeroth device content as DeviceArray.
  # NOTE(review): `jax.xla.DeviceArray` no longer exists in newer JAX releases
  # (replaced by `jax.Array`); update this check when upgrading JAX.
  device_arrays = utils.get_from_first_device(sharded, as_numpy=False)
  jax.tree_util.tree_map(
      lambda x: self.assertIsInstance(x, jax.xla.DeviceArray), device_arrays)
  jax.tree_util.tree_map(np.testing.assert_array_equal, want, device_arrays)

  # Get the zeroth device content as numpy arrays.
  numpy_arrays = utils.get_from_first_device(sharded, as_numpy=True)
  jax.tree_util.tree_map(
      lambda x: self.assertIsInstance(x, np.ndarray), numpy_arrays)
  jax.tree_util.tree_map(np.testing.assert_array_equal, want, numpy_arrays)
def step(self):
  """Runs one SGD step on a prefetched batch, updates priorities and logs."""
  split = next(self._prefetched_iterator)
  # The split_sample function given to utils.sharded_prefetch decides which
  # part of each item stays on the host and which part is prefetched onto
  # the device. Here the host part holds only the replay keys, and the device
  # part is the full original sample.
  replay_keys = split.host
  batch = split.device

  step_start = time.time()
  self._state, new_priorities, step_metrics = self._sgd_step(
      self._state, batch)
  # Only the first replica's metrics are kept for logging.
  step_metrics = utils.get_from_first_device(step_metrics)
  counts = self._counter.increment(
      steps=1, time_elapsed=time.time() - step_start)

  # Priorities are pushed back to replay only when a replay client exists.
  if self._replay_client:
    self._async_priority_updater.put((replay_keys, new_priorities))

  self._logger.write({**step_metrics, **counts})
def step(self):
  """Does a step of SGD and logs the results."""
  batch = next(self._prefetched_iterator)

  tic = time.time()
  self._state, sgd_results = self._sgd_step(self._state, batch)
  # Results are reported from the first replica only.
  sgd_results = utils.get_from_first_device(sgd_results)
  elapsed = time.time() - tic
  counts = self._counter.increment(steps=1, time_elapsed=elapsed)

  # Merge step results with the updated counts and hand them to the logger.
  self._logger.write({**sgd_results, **counts})
def step(self):
  """Runs one SGD step on the next batch of transitions and logs metrics."""
  batch = next(self._prefetching_iterator)
  self._state, step_metrics = self._sgd_step(self._state, batch)
  step_metrics = utils.get_from_first_device(step_metrics)

  # Wall-clock time since the previous recorded timestamp. This is 0 while
  # self._timestamp is falsy (e.g. before any timestamp has been stored).
  now = time.time()
  if self._timestamp:
    walltime = now - self._timestamp
  else:
    walltime = 0
  self._timestamp = now

  counts = self._counter.increment(steps=1, walltime=walltime)
  self._logger.write({**step_metrics, **counts})
def step(self):
  """Does a step of SGD and logs the results."""
  batch = next(self._iterator)

  t0 = time.time()
  self._state, sgd_results = self._sgd_step(self._state, batch)
  # NOTE: only the first replica's results are logged. Nothing is averaged
  # (pmean) across devices, so logged values are a noisy estimate.
  sgd_results = utils.get_from_first_device(sgd_results)
  counts = self._counter.increment(steps=1, time_elapsed=time.time() - t0)

  # The logger decides whether this write is actually emitted.
  self._logger.write({**sgd_results, **counts})
def get_variables(self, names: Sequence[str]) -> List[networks_lib.Params]:
  """Returns the first replica's parameters as a single-element list.

  `names` is not consulted: the same one-element list is returned for any
  request. The parameters are fetched with `as_numpy=False`.
  """
  del names  # Unused.
  first_replica_params = utils.get_from_first_device(
      self._state.params, as_numpy=False)
  return [first_replica_params]
def get_variables(self, names: List[str]) -> List[networks_lib.Params]:
  """Returns the first replica's parameters as a single-element list.

  `names` is not consulted: the same one-element list is returned for any
  request. The parameters are fetched with get_from_first_device's default
  conversion settings.
  """
  del names  # Unused.
  first_replica_params = utils.get_from_first_device(self._state.params)
  return [first_replica_params]
def get_variables(self, names: List[str]) -> List[networks_lib.Params]:
  """Returns the requested variables, in the order given by `names`.

  Only 'policy' (the first replica's policy params) is available; any other
  name raises KeyError. The policy params are fetched even when they are not
  requested.
  """
  available = {'policy': utils.get_from_first_device(self._state.policy_params)}
  return [available[name] for name in names]
def test_get_from_first_device_fails_if_sda_not_provided(self):
  """get_from_first_device must raise ValueError for a plain numpy leaf.

  The input is a host numpy array rather than a sharded device array
  ("sda" in the test name).
  """
  unsharded = {'a': np.arange(jax.local_device_count())}
  with self.assertRaises(ValueError):
    utils.get_from_first_device(unsharded)