def evaluate(self, eval_steps):
  """Runs evaluation on the train-eval and eval streams and logs the results.

  Both streams are truncated to `eval_steps` items and evaluated with the
  same weights, state and random key. Afterwards every entry of
  `self.nontrainable_params` is appended to the history.

  Args:
    eval_steps: number of items to take from each stream.
  """
  _, eval_rng = jax_random.split(self._rngs[0])
  # TODO(lukaszkaiser): both model state and parameters by default include
  # the loss layer. Currently, we access the pure-model parameters by just
  # indexing, [0] here. But we should make it more explicit in a better API.
  eval_weights = (self._opt_state[0][0], self._metrics_weights)
  eval_state = (self._model_state[0], self._metrics_state)

  def _evaluate_and_log(stream, summary_writer, tag):
    # Evaluate at most `eval_steps` items of `stream` and log under `tag`.
    batches = itertools.islice(stream, eval_steps)
    round_metrics, _ = evaluation_round(
        batches, self._metrics, self._jit_eval,
        eval_weights, eval_state, eval_rng)
    log_metrics(round_metrics, summary_writer, tag,
                self._step, history=self._history)

  step_log(self._step, 'Evaluation')
  _evaluate_and_log(self._train_eval_stream, self._train_sw, 'train')
  _evaluate_and_log(self._eval_stream, self._eval_sw, 'eval')
  step_log(self._step, 'Finished evaluation')

  # Save the optimizer weights in the history
  for param_name, param_value in self.nontrainable_params.items():
    metric_key = 'training/{}'.format(param_name)
    self._history.append('train', metric_key, self._step, param_value)
def evaluation_round(self, inputs_stream, weights, state, rng):
  """Averages the metrics from `self._jit_eval` over a stream of inputs.

  Args:
    inputs_stream: iterable of inputs to evaluate on; it is fully consumed.
    weights: weights passed unchanged to every `self._jit_eval` call.
    state: state passed unchanged to every `self._jit_eval` call.
    rng: random number generator key; it is split once per input so that
      each call receives a fresh sub-key.

  Returns:
    metrics: dict from metric name (taken from `self._metrics`, in order) to
      the metric value averaged over the number of inputs. Empty when
      `inputs_stream` yields nothing.
    state: the `state` argument, returned unchanged (the state produced by
      `self._jit_eval` is discarded).
  """
  totals = collections.defaultdict(float)
  count = 0
  for inp in inputs_stream:
    count += 1
    rng, subrng = jax_random.split(rng)
    metric_values, _ = self._jit_eval(inp, weights, state, subrng)
    try:
      metric_values = list(metric_values)
    except TypeError:
      # A single non-iterable scalar is treated as one metric.
      metric_values = [float(metric_values)]
    for name, value in zip(self._metrics, metric_values):
      totals[name] += value
  # `totals` is empty when count == 0, so no division by zero can occur.
  return {name: total / count for name, total in totals.items()}, state
def evaluation_round(inputs_stream, metric_names, eval_fn, weights, state, rng):
  """Averages the metrics returned by `eval_fn` over a stream of inputs.

  Args:
    inputs_stream: iterable of inputs to evaluate on; it is fully consumed.
    metric_names: list of strings, the order in which eval_fn returns metrics.
    eval_fn: callable invoked as `eval_fn(inp, weights, state, rng)`. It must
      return a pair whose first element is either an iterable of metric
      values or a single scalar; the second element is ignored.
    weights: weights passed unchanged to every `eval_fn` call.
    state: state passed unchanged to every `eval_fn` call.
    rng: random number generator key; it is split once per input so that
      each call receives a fresh sub-key.

  Returns:
    metrics: dict from metric name to metric value averaged over the number
      of inputs. Empty when `inputs_stream` yields nothing.
    state: the `state` argument, returned unchanged (the state produced by
      `eval_fn` is discarded).
  """
  totals = collections.defaultdict(float)
  count = 0
  for inp in inputs_stream:
    count += 1
    rng, subrng = jax_random.split(rng)
    metric_values, _ = eval_fn(inp, weights, state, subrng)
    try:
      metric_values = list(metric_values)
    except TypeError:
      # A single non-iterable scalar is treated as one metric.
      metric_values = [float(metric_values)]
    for name, value in zip(metric_names, metric_values):
      totals[name] += value
  # `totals` is empty when count == 0, so no division by zero can occur.
  return {name: total / count for name, total in totals.items()}, state
def single_update(i, opt_state, batch, state, rng):
  """Performs one optimizer step; `rng` is a one-element sequence of keys.

  Args:
    i: step index forwarded to `optimizer.tree_update`.
    opt_state: triple `(weights, slots, opt_params)`.
    batch: batch forwarded to the gradient function.
    state: model state forwarded to the gradient function.
    rng: sequence whose first element is the key to split.

  Returns:
    A triple of the `optimizer.tree_update` result, the state returned by
    the gradient function, and a one-element list with the unused sub-key.
  """
  weights, slots, opt_params = opt_state
  step_rng, carry_rng = jax_random.split(rng[0])
  # has_aux=True: the differentiated call returns (grads, aux), aux = state.
  compute_grads = backend.grad(model_and_loss_call, has_aux=True)
  grads, new_state = compute_grads(weights, batch, state, step_rng)
  new_opt_state = optimizer.tree_update(i, grads, weights, slots, opt_params)
  return new_opt_state, new_state, [carry_rng]
def _consume_act(self, actions, predict_fn, rng): act_repr = self._action_serializer.serialize(actions) for (i, subrng) in enumerate(jax_random.split(rng, self._act_repr_length)): # Run the network to update the inference buffers, but ignore the result. predict_fn(self._last_symbols, rng=subrng) self._last_symbols = act_repr[:, i:(i + 1)]
def predict(x, weights, state, rng):
  """Predict function jited and parallelized as requested.

  The input is reshaped per device, one key is produced per device, and the
  combined result is averaged over its leading axis.
  """
  per_device_x = backend.reshape_by_device(x, n_devices)
  per_device_rng = np.stack(jax_random.split(rng, n_devices))
  outputs = model_predict(per_device_x, weights, state, per_device_rng)
  res, state = backend.combine_devices(outputs)

  def _mean_over_first_axis(y):
    return np.mean(y, axis=0)

  return layers.nested_map(_mean_over_first_axis, res), state
def mapped_update(i, opt_state, batch, state, rng):
  """Multi-device update step: gradients are psum-ed over the 'batch' axis.

  Returns a triple of the `optimizer.tree_update` result, the state returned
  by the gradient function, and the unused sub-key.
  """
  # We assume all tensors have the first dimension = n_devices.
  weights, slots, opt_params = opt_state
  step_rng, carry_rng = jax_random.split(rng)
  compute_grads = backend.grad(model_and_loss_call, has_aux=True)
  grads, new_state = compute_grads(weights, batch, state, step_rng)

  # NOTE(review): the gradients are summed, not averaged, across the 'batch'
  # axis -- confirm the learning rate is meant to account for that.
  def _sum_over_batch_axis(g):
    return backend.psum(g, 'batch')

  grads = jax.tree_util.tree_map(_sum_over_batch_axis, grads)
  new_opt_state = optimizer.tree_update(i, grads, weights, slots, opt_params)
  return new_opt_state, new_state, carry_rng
def initialize_environments(self, batch_size=1, **kwargs):
  """Initializes the environments.

  Resets the per-environment step counters to zero, fills the last
  observations with NaN, zeroes the last symbols, delegates to the parent
  class and finally initializes the model state with a fresh sub-key.

  Args:
    batch_size: number of environments in the batch.
    **kwargs: forwarded to the parent `initialize_environments`.
  """
  symbols_shape = (batch_size, 1)
  observations_shape = (batch_size,) + self._observation_space.shape

  self._steps = np.zeros(batch_size, dtype=np.int32)
  self._last_observations = np.full(observations_shape, np.nan)
  self._last_symbols = np.zeros(symbols_shape, dtype=np.int32)

  parent = super(SerializedSequenceSimulatedEnvProblem, self)
  parent.initialize_environments(batch_size=batch_size, **kwargs)

  init_rng, self._rng = jax_random.split(self._rng)
  init_result = self._model_initialize(
      input_shapes=symbols_shape, input_dtype=np.int32, rng=init_rng)
  # Only the second element (the model state) of the result is kept.
  _, self._init_model_state = init_result
def _predict_obs(self, predict_fn, rng):
  """Samples an observation representation symbol by symbol and decodes it.

  Each step feeds `self._last_symbols` to `predict_fn`, samples new symbols
  with `utils.gumbel_sample`, and stores their first column in the
  representation. The filled representation is deserialized and returned.
  """
  n_envs = self._steps.shape[0]
  n_symbols = self._obs_repr_length
  sampled = np.zeros((n_envs, n_symbols), dtype=np.int32)
  step_rngs = jax_random.split(rng, n_symbols)
  for column, step_rng in enumerate(step_rngs):
    step_log_probs = predict_fn(self._last_symbols, rng=step_rng)
    self._last_symbols = utils.gumbel_sample(step_log_probs)
    sampled[:, column] = self._last_symbols[:, 0]
  return self._obs_serializer.deserialize(sampled)
def test_computes(self):
  """FrameStackMLP maps (batch, frames, obs) input to (batch, frames, out)."""
  batch, timesteps, obs_dim = 2, 2, 3
  frames = timesteps + 1
  hidden_size = (4, 4)
  output_size = 6

  model = atari_cnn.FrameStackMLP(
      hidden_sizes=hidden_size, output_size=output_size)
  _, init_key = jax_random.split(jax_random.get_prng(0))
  _, _ = model.initialize_once((1, 1, obs_dim), onp.float32, init_key)

  n_elements = batch * frames * obs_dim
  x = onp.arange(n_elements).reshape(batch, frames, obs_dim)
  y = model(x)

  self.assertEqual((batch, frames, output_size), y.shape)
def _reset(self, indices): """Resets environments at the given indices. Args: indices: list of indices of underlying envs to call reset on. Returns: np.ndarray of batched observations from the reset envs. """ history = next(self._history_stream) (subrng, self._rng) = jax_random.split(self._rng) return self._reset_model(self._predict_fn, indices, history, subrng)
def test_computes(self):
  """AtariCnn maps (batch, frames) + image input to (batch, frames, out)."""
  batch, timesteps, obs_shape = 2, 2, (28, 28, 3)
  frames = timesteps + 1
  hidden_size = (4, 4)
  output_size = 6

  model = atari_cnn.AtariCnn(
      hidden_sizes=hidden_size, output_size=output_size)
  _, init_key = jax_random.split(jax_random.get_prng(0))
  _, _ = model.initialize_once((1, 1) + obs_shape, onp.float32, init_key)

  n_elements = batch * frames * functools.reduce(op.mul, obs_shape)
  x = onp.arange(n_elements).reshape((batch, frames) + obs_shape)
  y = model(x)

  self.assertEqual((batch, frames, output_size), y.shape)
def _step(self, actions): """Takes a step in all environments. Args: actions: (np.ndarray) with first dimension equal to the batch size. Returns: a tuple of batched raw observations, raw rewards, dones and infos. """ # Predict the next observation. (subrng, self._rng) = jax_random.split(self._rng) (observation, reward, done) = self._step_model(self._predict_fn, actions, subrng) return (observation, reward, done, {})
def mapped_update(i, opt_state, batch, state, rng):
  """Multi-device update step: gradients are averaged over the 'batch' axis.

  Returns a triple of the `optimizer.tree_update` result, the state returned
  by the gradient function, and the unused sub-key.
  """
  # We assume all tensors have the first dimension = n_devices.
  weights, slots, opt_params = opt_state
  step_rng, carry_rng = jax_random.split(rng)
  compute_grads = backend.grad(model_and_loss_call, has_aux=True)
  grads, new_state = compute_grads(weights, batch, state, step_rng)

  # We do a psum(1.0) here instead of `n_devices` since `n_devices` is just
  # the number of devices on this host machine, however psum goes over all
  # devices of all hosts (ex: a TPU pod) and we need to be averaging over all
  # of them.
  def _mean_over_batch_axis(g):
    return backend.psum(g, 'batch') / backend.psum(1.0, 'batch')

  grads = jax.tree_util.tree_map(_mean_over_batch_axis, grads)
  new_opt_state = optimizer.tree_update(i, grads, weights, slots, opt_params)
  return new_opt_state, new_state, carry_rng
def single_compute_loss(opt_state, batch, state, rng):
  """Computes the loss once; `rng` is a one-element sequence of keys.

  Returns the loss value, the state returned by `loss_fn`, and a one-element
  list holding the unused sub-key.
  """
  call_rng, carry_rng = jax_random.split(rng[0])
  weights = opt_state[0]
  loss_val, new_state = loss_fn(weights, batch, predict_fn, state, call_rng)
  return loss_val, new_state, [carry_rng]
def mapped_compute_loss(opt_state, batch, state, rng):
  """Multi-device variant of the loss computation.

  Returns the loss value, the state returned by `loss_fn`, and the unused
  sub-key.
  """
  # We assume all tensors have the first dimension = n_devices.
  call_rng, carry_rng = jax_random.split(rng)
  weights = opt_state[0]
  loss_val, new_state = loss_fn(weights, batch, predict_fn, state, call_rng)
  return loss_val, new_state, carry_rng
def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
             output_dir=None, random_seed=None, n_devices=None,
             save_steps=None, should_save_checkpoints=True,
             should_write_summaries=True, has_weights=False,
             nontrainable_param_map=None, mask_id=None, metrics=None):
  """Sets up the model, optimizer, metrics and jitted functions.

  Args:
    model: callable invoked as `model(mode='train')` and `model(mode='eval')`
      to build the training and evaluation model instances.
    loss_fn: callable invoked with `has_weights` and `mask_id`; its result is
      used as the loss layer.
    optimizer: callable invoked with `learning_rate=0.0` to build the
      optimizer.
    lr_schedule: stored as-is in `self._lr_schedule`; not used further here.
    inputs: callable invoked with `n_devices`; its result must expose
      `input_shape`, `target_shape`, `input_dtype` and `target_dtype`.
    output_dir: if not None, `self.reset(output_dir)` is called at the end.
    random_seed: seed handed to `get_random_number_generator_and_set_seed`.
      If None and there is more than one host, a seed is derived from the
      current time and the host id.
    n_devices: number of devices; falls back to `backend.device_count()`.
    save_steps: stored as-is; defaults to an empty list.
    should_save_checkpoints: stored as-is.
    should_write_summaries: stored as-is.
    has_weights: forwarded to the loss and metric constructors; when true,
      one float32 entry per target is appended to the init signature.
    nontrainable_param_map: stored as-is; defaults to an empty dict.
    mask_id: forwarded to the loss and metric constructors.
    metrics: dict from metric name to metric constructor; defaults to
      `_METRICS`.

  Raises:
    ValueError: if the backend is 'jax' and `n_devices` differs from
      `backend.device_count()`.
  """
  # Host bookkeeping: only the 'jax' backend is queried for host id/count.
  if backend.get_name() == 'jax':
    self._host_id = jax.host_id()
    self._host_count = jax.host_count()
  else:
    self._host_id = 0
    self._host_count = 1
  self._is_chief = (self._host_id == 0)

  if save_steps is None:
    save_steps = []
  self._save_steps = save_steps
  self._should_save_checkpoints = should_save_checkpoints
  self._should_write_summaries = should_write_summaries
  self._has_weights = has_weights
  self._mask_id = mask_id
  self._metrics_dict = _METRICS if metrics is None else metrics
  # From here on `loss_fn` names the constructed loss, not the constructor.
  loss_fn = loss_fn(has_weights=has_weights, mask_id=mask_id)
  device_count = backend.device_count()
  n_devices = n_devices or device_count
  # TODO(lukaszkaiser): remove this restriction when possible.
  if n_devices != device_count and backend.get_name() == 'jax':
    raise ValueError(
        'JAX cannot work yet with n_devices != all devices: '
        '%d != %d' % (n_devices, device_count))
  self._n_devices = n_devices

  # Simple differential seeding of RNG across hosts by host_id and time.
  if random_seed is None and self._host_count > 1:
    _, random_seed = divmod(
        int(time.time() * 1e6) + int(self._host_id * 1e6), 2**32)
  rng = get_random_number_generator_and_set_seed(random_seed)
  # From here on `inputs` names the constructed inputs object.
  inputs = inputs(n_devices)
  self._inputs = inputs

  # Initialize the learning rate to a dummy value. It will be set in reset().
  opt = optimizer(learning_rate=0.0)

  # Setup the model.
  model_train = model(mode='train')
  model_predict_eval = model(mode='eval')

  # Setup state.
  # `init_rng` is reserved for initialization; the rest is split per device.
  rng, init_rng = jax_random.split(rng)
  self._rngs = np.stack(jax_random.split(rng, n_devices))
  first_shape = inputs.input_shape[0]
  # If the inputs are a tuple/list, add [None] (batch) to each element.
  if isinstance(first_shape, (list, tuple)):
    model_input_shape = tuple(
        tuple([None] + list(shape)) for shape in inputs.input_shape)
    model_target_shape = tuple(
        tuple([None] + list(shape)) for shape in inputs.target_shape)
  else:  # Otherwise just add [None] to the input shape.
    model_input_shape = tuple([None] + list(inputs.input_shape))
    model_target_shape = tuple([None] + list(inputs.target_shape))
  # Change all None to 1 in input and target shape.
  # Note: `x or 1` also maps any other falsy entry (e.g. 0) to 1.
  model_input_shape = backend.nested_map(lambda x: x or 1, model_input_shape)
  model_target_shape = backend.nested_map(lambda x: x or 1,
                                          model_target_shape)

  def new_opt_state_and_model_state(input_shape, input_dtype, target_shape,
                                    target_dtype, rng):
    """Returns optimizer and model states suitable for training a model."""
    # Combine inputs and targets on the stack.
    # A single (non-list) dtype means a single input/target: wrap both the
    # dtype and its shape so they can be concatenated below.
    if not isinstance(input_dtype, (list, tuple)):
      input_dtype = [input_dtype]
      input_shape = [input_shape]
    if not isinstance(target_dtype, (list, tuple)):
      target_dtype = [target_dtype]
      target_shape = [target_shape]
    dtypes = list(input_dtype) + list(target_dtype)
    shapes = list(input_shape) + list(target_shape)
    # With weights, add one float32 entry shaped like each target.
    if self._has_weights:
      shapes += list(target_shape)
      dtypes += [np.float32 for _ in target_dtype]
    input_signature = tuple(
        ShapeDtype(s, d) for (s, d) in zip(shapes, dtypes))
    # We need to create a new model instance and not reuse `model_train` here,
    # because `m.initialize` puts cached parameter values in `m` and hence the
    # next call of `m.initialize` will give wrong results.
    m = tl.Serial(model(mode='train'), loss_fn)
    m._set_rng_recursive(rng)  # pylint: disable=protected-access
    weights, state = m.init(input_signature)
    (slots, opt_params) = opt.tree_init(weights)
    return (OptState(weights, slots, opt_params), state)

  if _is_jit_init():
    # JIT parameter initialization to avoid memory fragmentation
    # Shapes and dtypes (arguments 0-3) are static; only the rng is traced.
    new_opt_state_and_model_state = backend.jit(
        new_opt_state_and_model_state, static_argnums=(0, 1, 2, 3))
  # Zero-argument factory: every call re-initializes with the same init_rng.
  self._new_opt_state_and_model_state = (
      lambda: new_opt_state_and_model_state(  # pylint: disable=g-long-lambda
          model_input_shape, self._inputs.input_dtype,
          model_target_shape, self._inputs.target_dtype, init_rng))

  # Arrange and initialize metrics layers.
  # Metric names are sorted so that their order is deterministic.
  self._metrics = list(sorted(self._metrics_dict.keys()))
  metrics_layers = [
      self._metrics_dict[m](has_weights=self._has_weights,
                            mask_id=self._mask_id)
      for m in self._metrics
  ]
  metrics_in_parallel = tl.Branch(*metrics_layers)
  # TODO(lukaszkaiser): clean this up once layer API stabilizes.
  # For now, we need to initialize metric layers somehow, so here we go.
  # We assume that they do not have any parameters, so this is a dummy.
  dummy_shapes = ((1, 2), (1, ), (1, )) if self._has_weights else ((1, 2), (1, ))
  dummy_signature = tuple(ShapeDtype(s) for s in dummy_shapes)
  metrics_in_parallel._set_rng_recursive(init_rng)  # pylint: disable=protected-access
  m_weights, m_state = metrics_in_parallel.init(dummy_signature)
  self._metrics_weights = self._for_n_devices(m_weights)
  self._metrics_state = self._for_n_devices(m_state)

  # Jit model_predict and update so they're fast.
  self._jit_eval = _jit_predict_fn(model_predict_eval, metrics_in_parallel,
                                   n_devices)
  self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt, n_devices)

  self._model_train = model_train
  self._model_predict_eval = model_predict_eval
  self._loss_fn = loss_fn
  # TODO(pkozakowski): "Learning rate schedules" are currently able to
  # control all optimizer parameters and model state, so let's rename them
  # accordingly.
  self._lr_schedule = lr_schedule

  if nontrainable_param_map is None:
    nontrainable_param_map = {}
  self._nontrainable_param_map = nontrainable_param_map

  # Those fields will be set in reset().
  self._output_dir = None
  self._train_sw = None
  self._eval_sw = None
  self._history = None
  self._lr_fn = None
  self._opt_state = None
  self._step = None
  self._model_state = None

  if output_dir is not None:
    self.reset(output_dir)