def step(self):
    """Run one SGD update on the next prefetched batch and log the results.

    Pulls a batch from ``self._prefetched_iterator``, applies
    ``self._sgd_step`` to the current state (replacing ``self._state``),
    keeps the first replica's results, increments the step counter with
    the elapsed time of the update, and writes the merged metrics to
    ``self._logger``.
    """
    samples = next(self._prefetched_iterator)

    # Time only the update itself, not the batch fetch above.
    # perf_counter() is monotonic, so the interval cannot go negative or
    # jump when the system wall clock is adjusted (unlike time.time()).
    # NOTE(review): if _sgd_step dispatches asynchronously, this interval
    # may not cover the full device compute time — confirm if it matters.
    start = time.perf_counter()
    self._state, results = self._sgd_step(self._state, samples)

    # Results come back per replica; report only the first replica's values.
    results = utils.first_replica(results)

    # Update our counts (including this step's elapsed time) and record them.
    counts = self._counter.increment(
        steps=1, time_elapsed=time.perf_counter() - start)

    # Merge SGD results with counter values and write them out; on a key
    # collision the counter's value wins because it is unpacked last.
    self._logger.write({**results, **counts})
def get_variables(self, names: List[str]) -> List[hk.Params]:
    """Return the learner's parameters, taken from the first replica.

    Args:
      names: Requested variable names. Ignored: the full parameter set is
        always returned as the sole element of the list, whatever is asked.

    Returns:
      A one-element list holding the first replica of ``self._state.params``.
    """
    del names  # Unused; see docstring.
    first_params = utils.first_replica(self._state.params)
    return [first_params]