def reverse_and_grad(self, output, ct, weights=(), state=(), new_state=(),
                     **kwargs):
  """Reconstructs the layer input and backpropagates `ct` in one pass.

  Args:
    output: value passed as the input of `self.compute_residual`.
    ct: cotangents; must hold exactly `self.n_preserve + 1` entries.
    weights: per-sublayer weights. `weights[0]` is given to
      `self.compute_residual`; `weights[-1]` is given to `self.subtract_top`
      and must contain no leaves.
    state: per-sublayer state, indexed like `weights`.
    new_state: accepted for interface compatibility; not read here.
    **kwargs: forwarded to both sublayers. An optional `rng` entry is removed
      first and split into `self._n_layers` keys.

  Returns:
    A pair `(reconstructed_x, (x_ct, weights_ct))` where `weights_ct` is the
    two-element list `[residual_weights_ct, weights[-1]]`.
  """
  key = kwargs.pop('rng', None)
  if key is None:
    layer_rngs = (None,) * self._n_layers
  else:
    layer_rngs = random.split(key, self._n_layers)

  def residual_fn(x, residual_weights):
    return self.compute_residual(
        x, weights=residual_weights, state=state[0], rng=layer_rngs[0],
        **kwargs)

  assert len(ct) == self.n_preserve + 1
  # The leading cotangent is duplicated so that the cotangent tuple lines up
  # with the output of `compute_residual`.
  residual_ct = (ct[0], ct[0]) + ct[1:]

  stack_with_residual, pullback = jax.vjp(residual_fn, output, weights[0])
  reconstructed_x = self.subtract_top(
      stack_with_residual, weights=weights[-1], state=state[-1],
      rng=layer_rngs[-1], **kwargs)

  x_ct, residual_weights_ct = pullback(residual_ct)

  # The last sublayer has no weights, so its (empty) weights double as their
  # own cotangent.
  assert not jax.tree_util.tree_leaves(weights[-1])
  weights_ct = [residual_weights_ct, weights[-1]]
  return reconstructed_x, (x_ct, weights_ct)
def run_policy(
    policy_and_value_net_apply,
    observations,
    lengths,
    weights,
    state,
    rng,
    action_space,
):
  """Runs the policy network and keeps the outputs at the last real step.

  Args:
    policy_and_value_net_apply: callable invoked as
      `f((observations, actions), weights=..., state=..., rng=...)` and
      returning `(log_probs, value_preds)`.
    observations: array whose first two dimensions are (B, T + 1).
    lengths: per-example lengths; `lengths - 1` is used as the time index.
    weights: weights forwarded to the network.
    state: state forwarded to the network and returned unchanged.
    rng: random key; split once, one half is given to the network.
    action_space: object exposing `shape` and `dtype` of a single action.

  Returns:
    `(log_probs, value_preds, state, rng)`, where the first two are indexed
    at `[arange(B), lengths - 1]` and `rng` is the unused half of the split.
  """
  batch_size, n_timesteps = observations.shape[:2]

  # TODO(pkozakowski): Pass the actual actions here, to enable autoregressive
  # action sampling.
  placeholder_shape = (batch_size, n_timesteps - 1) + action_space.shape
  placeholder_actions = onp.zeros(placeholder_shape, dtype=action_space.dtype)

  rng, apply_rng = trax_random.split(rng)
  log_probs, value_preds = policy_and_value_net_apply(
      (observations, placeholder_actions),
      weights=weights,
      state=state,
      rng=apply_rng,
  )

  # Only the entries for the last actual time-step of each example are
  # wanted, hence indexing with `lengths - 1`.
  last_step = lengths - 1
  selected_log_probs = log_probs[np.arange(batch_size), last_step]
  selected_values = value_preds[np.arange(batch_size), last_step]
  return (selected_log_probs, selected_values, state, rng)
def single_update(i, opt_state, batch, state, rng):
  """Computes gradients and applies one optimizer update.

  Args:
    i: step index handed to `optimizer.tree_update`.
    opt_state: triple `(weights, slots, opt_params)`.
    batch: batch passed to `model_and_loss_call`.
    state: model state passed to `model_and_loss_call`.
    rng: sequence whose first element is the random key to split.

  Returns:
    `(new_opt_state, new_state, [next_rng])`, where `new_state` is the
    auxiliary output of the gradient function and `next_rng` is the half of
    the split that was not consumed.
  """
  weights, slots, opt_params = opt_state
  loss_rng, next_rng = jax_random.split(rng[0])
  # has_aux=True: the differentiated function also returns the new state.
  grads, new_state = math.grad(model_and_loss_call, has_aux=True)(
      weights, batch, state, loss_rng)
  new_opt_state = optimizer.tree_update(i, grads, weights, slots, opt_params)
  return new_opt_state, new_state, [next_rng]
def evaluation_round(self, inputs_stream, weights, state, rng):
  """Evaluates the metrics on a stream of inputs and averages them.

  Args:
    inputs_stream: iterable of inputs to evaluate on.
    weights: weights passed to `self._jit_eval`.
    state: state passed to `self._jit_eval`.
    rng: random key; split once per input.

  Returns:
    metrics: dict from metric name (taken from `self._metrics`) to the metric
      value averaged over the number of inputs; empty if the stream yields
      nothing.
    state: the `state` argument, returned unchanged. The state produced by
      `self._jit_eval` is discarded.
  """
  totals = collections.defaultdict(float)
  count = 0
  for inp in inputs_stream:
    count += 1
    rng, subrng = jax_random.split(rng)
    metric_values, _ = self._jit_eval(inp, weights, state, subrng)
    try:
      metric_values = list(metric_values)
    except TypeError:
      # A single non-iterable metric value: treat it as a one-element list.
      metric_values = [float(metric_values)]
    for name, value in zip(self._metrics, metric_values):
      totals[name] += value
  # `dict.items()` is enough here; no compatibility shim is needed to
  # iterate over a dict.
  averaged = {name: total / count for name, total in totals.items()}
  return averaged, state
def _consume_act(self, actions, predict_fn, rng):
  """Feeds the serialized actions through `predict_fn`, one slice at a time.

  Args:
    actions: actions handed to `self._action_serializer.serialize`.
    predict_fn: callable invoked as `predict_fn(symbols, rng=...)`.
    rng: random key, split into `self._act_repr_length` sub-keys.

  Side effects:
    `self._last_symbols` is overwritten after every call with the next
    width-one column slice of the serialized actions.
  """
  act_repr = self._action_serializer.serialize(actions)
  symbol_rngs = jax_random.split(rng, self._act_repr_length)
  for pos, symbol_rng in enumerate(symbol_rngs):
    # The prediction itself is thrown away; the network is run only so that
    # its inference buffers get updated.
    predict_fn(self._last_symbols, rng=symbol_rng)
    self._last_symbols = act_repr[:, pos:pos + 1]
def _predict_obs(self, predict_fn, rng):
  """Samples an observation representation symbol by symbol and decodes it.

  Args:
    predict_fn: callable invoked as `predict_fn(symbols, rng=...)`; its result
      is passed to `utils.gumbel_sample`.
    rng: random key, split into `self._obs_repr_length` sub-keys.

  Returns:
    The result of `self._obs_serializer.deserialize` applied to the int32
    array of sampled symbols, which has shape
    `(self._steps.shape[0], self._obs_repr_length)`.
  """
  n_rows = self._steps.shape[0]
  obs_repr = np.zeros((n_rows, self._obs_repr_length), dtype=np.int32)
  symbol_rngs = jax_random.split(rng, self._obs_repr_length)
  for pos, symbol_rng in enumerate(symbol_rngs):
    # Each sampled symbol becomes the input of the next prediction.
    self._last_symbols = utils.gumbel_sample(
        predict_fn(self._last_symbols, rng=symbol_rng))
    obs_repr[:, pos] = self._last_symbols[:, 0]
  return self._obs_serializer.deserialize(obs_repr)
def _reset(self, indices):
  """Resets environments at the given indices.

  Pulls the next item from `self._history_stream`, advances `self._rng`, and
  delegates to `self._reset_model`.

  Args:
    indices: list of indices of underlying envs to call reset on.

  Returns:
    Whatever `self._reset_model` returns (documented by the original author
    as an np.ndarray of batched observations from the reset envs).
  """
  initial_history = next(self._history_stream)
  # The first half of the split is consumed now, the second half is kept.
  reset_rng, self._rng = jax_random.split(self._rng)
  return self._reset_model(
      self._predict_fn, indices, initial_history, reset_rng)
def reverse(self, output, weights=(), state=(), new_state=(), rng=None):
  """Runs `output` through `self.reverse_layers` in order.

  Args:
    output: value fed to the first reverse layer.
    weights: per-layer weights.
    state: per-layer state.
    new_state: per-layer new state.
    rng: optional random key; when given it is split into `self._n_layers`
      keys, otherwise every layer receives `rng=None`.

  Returns:
    The value produced by the last reverse layer (or `output` itself when
    there is nothing to iterate over).
  """
  if rng is None:
    layer_rngs = (None,) * self._n_layers
  else:
    layer_rngs = random.split(rng, self._n_layers)

  # Note that self.sublayers aligns exactly with self.reverse_layers in
  # terms of parameter and rng usage, so no re-ordering is required.
  per_layer_args = zip(
      self.reverse_layers, weights, state, new_state, layer_rngs)
  x = output
  for layer, layer_weights, layer_state, layer_new_state, layer_rng in (
      per_layer_args):
    x = layer(x, weights=layer_weights, state=layer_state,
              new_state=layer_new_state, rng=layer_rng)
  return x
def forward_and_backward(self, inputs, ct, state, new_state, rng=None):
  """Runs the attention forward pass and its backprop simultaneously.

  Args:
    inputs: sequence whose first three entries are handed to the attention
      mechanism; the remaining entries pass through untouched.
    ct: cotangents; `ct[0]` belongs to the attention output and `ct[1:]` to
      the pass-through entries.
    state: sequence whose first element is the attention state.
    new_state: sequence whose first element is the attention new state.
    rng: optional random key.

  Returns:
    `((out,) + passthrough, qkv_ct + passthrough_ct)`.
  """
  qkv, passthrough = inputs[:3], inputs[3:]
  out_ct, passthrough_ct = ct[0], ct[1:]

  attention_rng = rng
  if attention_rng is not None:
    # Adjust RNG to match the forward pass.
    attention_rng = random.split(attention_rng, self._n_layers)[0]

  out, qkv_ct = self.attention.forward_and_backward(
      qkv, out_ct, state[0], new_state[0], attention_rng)
  outputs = (out,) + passthrough
  inputs_ct = qkv_ct + passthrough_ct
  return outputs, inputs_ct
def _step(self, actions):
  """Takes a step in all environments.

  Advances `self._rng` and delegates to `self._step_model`.

  Args:
    actions: (np.ndarray) with first dimension equal to the batch size.

  Returns:
    a tuple `(observation, reward, done, info)` where the first three come
    from `self._step_model` and `info` is always an empty dict.
  """
  # Predict the next observation with a fresh sub-key; keep the other half.
  step_rng, self._rng = jax_random.split(self._rng)
  observation, reward, done = self._step_model(
      self._predict_fn, actions, step_rng)
  info = {}
  return (observation, reward, done, info)
def mapped_update(i, opt_state, batch, state, rng):
  """Multi-device update step: averages gradients over the 'batch' axis.

  Args:
    i: step index handed to `optimizer.tree_update`.
    opt_state: triple `(weights, slots, opt_params)`.
    batch: batch passed to `model_and_loss_call`.
    state: model state passed to `model_and_loss_call`.
    rng: random key to split.

  Returns:
    `(new_opt_state, new_state, next_rng)`.
  """
  # We assume all tensors have the first dimension = n_devices.
  weights, slots, opt_params = opt_state
  loss_rng, next_rng = jax_random.split(rng)
  grads, new_state = math.grad(model_and_loss_call, has_aux=True)(
      weights, batch, state, loss_rng)

  def cross_device_mean(g):
    # Divide by psum(1.0) rather than `n_devices`: `n_devices` only counts
    # the devices on this host, whereas psum runs over all devices of all
    # hosts (ex: a TPU pod) and the average must be taken over all of them.
    return math.psum(g, 'batch') / math.psum(1.0, 'batch')

  grads = jax.tree_util.tree_map(cross_device_mean, grads)
  new_opt_state = optimizer.tree_update(i, grads, weights, slots, opt_params)
  return new_opt_state, new_state, next_rng
def run_policy_all_timesteps(
    policy_and_value_net_apply,
    observations,
    weights,
    state,
    rng,
    action_space,
):
  """Runs the policy network and returns its outputs for every time-step.

  Args:
    policy_and_value_net_apply: callable invoked as
      `f((observations, actions), weights=..., state=..., rng=...)` and
      returning `(log_probs, value_preds)`.
    observations: array whose first two dimensions are (B, T + 1).
    weights: weights forwarded to the network.
    state: state forwarded to the network and returned unchanged.
    rng: random key; split once, one half is given to the network.
    action_space: object exposing `shape` and `dtype` of a single action.

  Returns:
    `(log_probs, value_preds, state, rng)` where `rng` is the unused half of
    the split.
  """
  batch_size, n_timesteps = observations.shape[:2]

  # TODO(pkozakowski): Pass the actual actions here, to enable autoregressive
  # action sampling.
  placeholder_shape = (batch_size, n_timesteps - 1) + action_space.shape
  placeholder_actions = onp.zeros(placeholder_shape, dtype=action_space.dtype)

  rng, apply_rng = trax_random.split(rng)
  log_probs, value_preds = policy_and_value_net_apply(
      (observations, placeholder_actions),
      weights=weights,
      state=state,
      rng=apply_rng,
  )
  return log_probs, value_preds, state, rng
def evaluate(self, n_eval_steps):
  """Evaluates on the train-eval and eval streams and logs the metrics.

  Args:
    n_eval_steps: number of items taken from each stream.

  Side effects:
    Logs through `self.log_step` / `self.log_metrics` and appends every entry
    of `self.nontrainable_params` to `self._history` under the 'train' mode.
  """
  _, rng = jax_random.split(self._rngs[0])
  # TODO(lukaszkaiser): both model state and parameters by default include
  # the loss layer. Currently, we access the pure-model parameters by just
  # indexing, [0] here. But we should make it more explicit in a better API.
  weights = (self._opt_state[0][0], self._metrics_weights)
  state = (self._model_state[0], self._metrics_state)

  def metrics_on(stream):
    # The same weights, state and rng are used for both streams.
    batches = itertools.islice(stream, n_eval_steps)
    metrics, _ = self.evaluation_round(batches, weights, state, rng)
    return metrics

  self.log_step('Evaluation')
  self.log_metrics(metrics_on(self._train_eval_stream), self._train_sw,
                   'train')
  self.log_metrics(metrics_on(self._eval_stream), self._eval_sw, 'eval')
  self.log_step('Finished evaluation')

  # Record the current non-trainable parameter values in the history.
  for name, value in self.nontrainable_params.items():
    metric_name = 'training/{}'.format(name)
    self._history.append('train', metric_name, self._step, value)
def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
             output_dir=None, random_seed=None, n_devices=None,
             checkpoints_at=None, should_save_checkpoints=True,
             should_write_summaries=True, has_weights=False,
             nontrainable_param_map=None, id_to_mask=None,
             metrics=None, checkpoint_highest=None, checkpoint_lowest=None):
  """Builds the models, optimizer, metrics and jitted functions for training.

  Args:
    model: callable invoked as `model(mode='train')` and `model(mode='eval')`
      to build the training and evaluation models.
    loss_fn: callable invoked as `loss_fn(has_weights=..., id_to_mask=...)`;
      the result replaces `loss_fn` for the rest of this method.
    optimizer: callable invoked as `optimizer(learning_rate=0.0)`.
    lr_schedule: stored as `self._lr_schedule`; not otherwise used here.
    inputs: an Inputs instance, or a function returning one (it is called if
      callable).
    output_dir: passed to `self.reset`. If falsy, checkpoint saving and
      summary writing are both switched off.
    random_seed: passed to `self._init_host_and_devices`.
    n_devices: passed to `self._init_host_and_devices`.
    checkpoints_at: stored as `self._checkpoints_at`; defaults to `[]`.
    should_save_checkpoints: checkpoints are saved only if this is true,
      `self._is_chief` is true and `output_dir` is set.
    should_write_summaries: summaries are written only if this is true and
      `output_dir` is set.
    has_weights: forwarded to `loss_fn` and to every metric constructor.
    nontrainable_param_map: stored as `self._nontrainable_param_map`;
      defaults to `{}`.
    id_to_mask: forwarded to `loss_fn` and to every metric constructor.
    metrics: dict from metric name to a metric constructor; defaults to
      `_DEFAULT_METRICS`.
    checkpoint_highest: stored as `self._checkpoint_highest`.
    checkpoint_lowest: stored as `self._checkpoint_lowest`.
  """
  self._is_chief, self._n_devices, rng = (self._init_host_and_devices(
      n_devices, random_seed))
  self._should_save_checkpoints = should_save_checkpoints and self._is_chief
  self._checkpoints_at = checkpoints_at or []
  self._should_write_summaries = should_write_summaries
  # Without an output directory there is nowhere to write to.
  if not output_dir:
    self._should_save_checkpoints = False
    self._should_write_summaries = False
  self._checkpoint_highest = checkpoint_highest
  self._checkpoint_lowest = checkpoint_lowest
  self._has_weights = has_weights
  self._id_to_mask = id_to_mask
  self._metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
  # From here on `loss_fn` names the constructed loss, not the constructor.
  loss_fn = loss_fn(has_weights=has_weights, id_to_mask=id_to_mask)
  # Inputs is either an Inputs instance or a function that returns it.
  self._inputs = inputs
  if callable(
      inputs):  # If we pass a function, e.g., through gin, call it.
    self._inputs = inputs()
  # Initialize the learning rate to a dummy value. It will be set in reset().
  opt = optimizer(learning_rate=0.0)

  # Setup the model.
  model_train = model(mode='train')
  model_predict_eval = model(mode='eval')

  # Setup state.
  # `init_rng` is reserved for initialization; the remaining key is split
  # into one key per device.
  rng, init_rng = jax_random.split(rng)
  self._rngs = np.stack(jax_random.split(rng, self._n_devices))
  # NOTE(review): `model_input_shape` and `model_target_shape` computed below
  # are never read again inside this method - confirm whether they are still
  # needed.
  # If the inputs are a tuple/list, add [None] (batch) to each element.
  if self._inputs.input_shape and isinstance(self._inputs.input_shape[0],
                                             (list, tuple)):
    model_input_shape = tuple(
        tuple([None] + list(shape)) for shape in self._inputs.input_shape)
  else:  # Otherwise just add [None] to the input shape.
    model_input_shape = tuple([None] + list(self._inputs.input_shape))
  # Same for targets.
  if self._inputs.target_shape and isinstance(
      self._inputs.target_shape[0], (list, tuple)):
    model_target_shape = tuple(
        tuple([None] + list(shape)) for shape in self._inputs.target_shape)
  else:
    model_target_shape = tuple([None] + list(self._inputs.target_shape))
  # Change all None to 1 in input and target shape.
  model_input_shape = math.nested_map(lambda x: x or 1, model_input_shape)
  model_target_shape = math.nested_map(lambda x: x or 1, model_target_shape)

  def new_opt_state_and_model_state(shape_dtype, rng):
    """Returns optimizer and model states suitable for training a model."""
    # Combine inputs and targets on the stack.
    shapes, dtypes = shape_dtype
    input_signature = tuple(
        ShapeDtype(s, d) for (s, d) in zip(shapes, dtypes))
    # We need to create a new model instance and not reuse `model_train` here,
    # because `m.initialize` puts cached parameter values in `m` and hence the
    # next call of `m.initialize` will give wrong results.
    m = tl.Serial(model(mode='train'), loss_fn)
    m._set_rng_recursive(rng)  # pylint: disable=protected-access
    weights, state = m.init(input_signature)
    (slots, opt_params) = opt.tree_init(weights)
    return (OptState(weights, slots, opt_params), state)

  if _is_jit_init():
    # JIT parameter initialization to avoid memory fragmentation
    new_opt_state_and_model_state = math.jit(
        new_opt_state_and_model_state, static_argnums=(0, ))
  # Zero-argument factory; it always uses the same `init_rng`.
  self._new_opt_state_and_model_state = (
      lambda: new_opt_state_and_model_state(  # pylint: disable=g-long-lambda
          self._inputs.example_shape_dtype, init_rng))

  # Arrange and initialize metrics layers.
  # Metric names are sorted so that their order is deterministic.
  self._metrics = list(sorted(self._metrics_dict.keys()))
  metrics_layers = [
      self._metrics_dict[m](has_weights=self._has_weights,
                            id_to_mask=self._id_to_mask)
      for m in self._metrics
  ]
  metrics_in_parallel = tl.Branch(*metrics_layers)
  metrics_in_parallel._set_rng_recursive(init_rng)  # pylint: disable=protected-access
  example_signature = tuple(
      ShapeDtype(s, d) for (s, d) in zip(*self._inputs.example_shape_dtype))
  # The eval model is initialized first so that its output signature can be
  # used to initialize the metrics layers.
  model_predict_eval.init(example_signature)
  output_signature = model_predict_eval.output_signature(
      example_signature)
  m_weights, m_state = metrics_in_parallel.init(output_signature)
  self._metrics_weights = self._for_n_devices(m_weights)
  self._metrics_state = self._for_n_devices(m_state)

  # Jit model_predict and update so they're fast.
  self._jit_eval = _jit_predict_fn(model_predict_eval, metrics_in_parallel,
                                   self._n_devices)
  self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt,
                                       self._n_devices)

  self._model_train = model_train
  self._model_predict_eval = model_predict_eval
  self._loss_fn = loss_fn
  # TODO(pkozakowski): "Learning rate schedules" are currently able to
  # control all optimizer parameters and model state, so let's rename them
  # accordingly.
  self._lr_schedule = lr_schedule

  if nontrainable_param_map is None:
    nontrainable_param_map = {}
  self._nontrainable_param_map = nontrainable_param_map

  # Those fields will be set in reset().
  self._output_dir = None
  self._train_sw = None
  self._eval_sw = None
  self._history = None
  self._lr_fn = None
  self._opt_state = None
  self._step = None
  self._model_state = None
  self.reset(output_dir)
def predict(x, weights, state, rng):
  """Predict function jited and parallelized as requested.

  The input is reshaped per device, one random key is created per device,
  and the combined results are averaged over their first axis.

  Returns:
    `(res, state)` where every leaf of `res` has been reduced with
    `np.mean(..., axis=0)`.
  """
  device_inputs = _reshape_by_device(x, n_devices)
  device_rngs = np.stack(jax_random.split(rng, n_devices))
  device_outputs = model_predict(device_inputs, weights, state, device_rngs)
  res, state = _combine_devices(device_outputs)

  def mean_over_first_axis(y):
    return np.mean(y, axis=0)

  return math.nested_map(mean_over_first_axis, res), state
def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
             output_dir=None, random_seed=None, n_devices=None,
             checkpoints_at=None, should_save_checkpoints=True,
             should_write_summaries=True,
             metrics=None, checkpoint_highest=None, checkpoint_lowest=None):
  """Builds the models, optimizer, metrics and jitted functions for training.

  Args:
    model: callable invoked as `model(mode='train')` and `model(mode='eval')`
      to build the training and evaluation models.
    loss_fn: loss layer; appended to the training model and, when `metrics`
      is not given, registered as the 'loss' metric.
    optimizer: callable invoked as `optimizer(learning_rate=0.0)`.
    lr_schedule: stored as `self._lr_schedule`; not otherwise used here.
    inputs: an Inputs instance, or a function returning one (it is called if
      callable).
    output_dir: passed to `self.reset`. If falsy, checkpoint saving and
      summary writing are both switched off.
    random_seed: passed to `self._init_host_and_devices`.
    n_devices: passed to `self._init_host_and_devices`.
    checkpoints_at: stored as `self._checkpoints_at`; defaults to `[]`.
    should_save_checkpoints: checkpoints are saved only if this is true,
      `self._is_chief` is true and `output_dir` is set.
    should_write_summaries: summaries are written only if this is true and
      `output_dir` is set.
    metrics: dict from metric name to metric layer. When None, a copy of
      `_DEFAULT_METRICS` extended with `'loss': loss_fn` is used.
    checkpoint_highest: stored as `self._checkpoint_highest`.
    checkpoint_lowest: stored as `self._checkpoint_lowest`.
  """
  self._is_chief, self._n_devices, rng = (self._init_host_and_devices(
      n_devices, random_seed))
  self._should_save_checkpoints = should_save_checkpoints and self._is_chief
  self._checkpoints_at = checkpoints_at or []
  self._should_write_summaries = should_write_summaries
  # Without an output directory there is nowhere to write to.
  if not output_dir:
    self._should_save_checkpoints = False
    self._should_write_summaries = False
  self._checkpoint_highest = checkpoint_highest
  self._checkpoint_lowest = checkpoint_lowest
  if metrics is not None:
    self._metrics_dict = metrics
  else:
    # Work on a copy: writing 'loss' straight into the module-level
    # `_DEFAULT_METRICS` would leak this instance's loss into every other
    # user of the defaults.
    self._metrics_dict = dict(_DEFAULT_METRICS)
    self._metrics_dict['loss'] = loss_fn
  # Inputs is either an Inputs instance or a function that returns it.
  self._inputs = inputs
  if callable(
      inputs):  # If we pass a function, e.g., through gin, call it.
    self._inputs = inputs()
  # Initialize the learning rate to a dummy value. It will be set in reset().
  opt = optimizer(learning_rate=0.0)

  # Setup the model.
  model_train = model(mode='train')
  model_predict_eval = model(mode='eval')
  self._model_with_loss = tl.Serial(model_train, loss_fn)

  # Setup state.
  # `init_rng` is reserved for initialization; the remaining key is split
  # into one key per device.
  rng, init_rng = jax_random.split(rng)
  self._rngs = np.stack(jax_random.split(rng, self._n_devices))
  shapes, dtypes = self._inputs.example_shape_dtype
  input_signature = tuple(
      ShapeDtype(s, d) for (s, d) in zip(shapes, dtypes))

  def new_opt_state_and_model_state(rng):
    """Returns optimizer and model states suitable for training a model."""
    weights, state = self._model_with_loss.init(input_signature, rng=rng)
    (slots, opt_params) = opt.tree_init(weights)
    return (OptState(weights, slots, opt_params), state)

  if math.backend_name() == 'jax':
    # JIT parameter initialization to avoid memory fragmentation
    new_opt_state_and_model_state = math.jit(
        new_opt_state_and_model_state)
  # Zero-argument factory; it always uses the same `init_rng`.
  self._new_opt_state_and_model_state = (
      lambda: new_opt_state_and_model_state(init_rng))

  # Arrange and initialize metrics layers.
  # Metric names are sorted so that their order is deterministic.
  self._metrics = list(sorted(self._metrics_dict.keys()))
  metrics_layers = [self._metrics_dict[m] for m in self._metrics]
  metrics_in_parallel = tl.Branch(*metrics_layers)
  metrics_in_parallel.rng = init_rng
  example_signature = tuple(
      ShapeDtype(s, d) for (s, d) in zip(*self._inputs.example_shape_dtype))
  # The eval model is initialized first so that its output signature can be
  # used to initialize the metrics layers.
  model_predict_eval.init(example_signature)
  self._input_signature = example_signature
  output_signature = model_predict_eval.output_signature(
      example_signature)
  m_weights, m_state = metrics_in_parallel.init(output_signature)
  self._metrics_weights = self._for_n_devices(m_weights)
  self._metrics_state = self._for_n_devices(m_state)

  # Jit model_predict and update so they're fast.
  self._jit_eval = _jit_predict_fn(model_predict_eval, metrics_in_parallel,
                                   self._n_devices)
  self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt,
                                       self._n_devices)

  self._model_train = model_train
  self._model_predict_eval = model_predict_eval
  self._loss_fn = loss_fn
  # TODO(pkozakowski): "Learning rate schedules" are currently able to
  # control all optimizer parameters and model state, so let's rename them
  # accordingly.
  self._lr_schedule = lr_schedule

  # Those fields will be set in reset().
  self._output_dir = None
  self._train_sw = None
  self._eval_sw = None
  self._history = None
  self._lr_fn = None
  self._opt_state = None
  self._step = None
  self._model_state = None
  self.reset(output_dir)
def single_compute_loss(opt_state, batch, state, rng):
  """Computes the loss for one batch.

  Args:
    opt_state: sequence whose first element holds the weights.
    batch: batch passed to `loss_fn`.
    state: model state passed to `loss_fn`.
    rng: sequence whose first element is the random key to split.

  Returns:
    `(loss_val, new_state, [next_rng])`, where `next_rng` is the half of the
    split that was not given to `loss_fn`.
  """
  loss_rng, next_rng = jax_random.split(rng[0])
  weights = opt_state[0]
  loss_val, new_state = loss_fn(weights, batch, predict_fn, state, loss_rng)
  return loss_val, new_state, [next_rng]
def mapped_compute_loss(opt_state, batch, state, rng):
  """Multi-device version of the loss computation.

  Args:
    opt_state: sequence whose first element holds the weights.
    batch: batch passed to `loss_fn`.
    state: model state passed to `loss_fn`.
    rng: random key to split.

  Returns:
    `(loss_val, new_state, next_rng)`, where `next_rng` is the half of the
    split that was not given to `loss_fn`.
  """
  # We assume all tensors have the first dimension = n_devices.
  loss_rng, next_rng = jax_random.split(rng)
  weights = opt_state[0]
  loss_val, new_state = loss_fn(weights, batch, predict_fn, state, loss_rng)
  return loss_val, new_state, next_rng
def reverse_and_grad(self, output, ct, weights=(), state=(), new_state=(),
                     **kwargs):
  """Reconstructs the layer input and backpropagates `ct` in one pass.

  The sublayers are indexed as 0 = pre_attention, 1 = attention,
  2 = post_attention, -1 = subtract_top in `weights`, `state` and the
  per-layer rngs.

  Args:
    output: value passed as the input of `self.pre_attention`.
    ct: cotangents; must hold exactly two entries.
    weights: per-sublayer weights; entries 1 and -1 must contain no leaves.
    state: per-sublayer state.
    new_state: per-sublayer new state; only entry 1 is read (by attention).
    **kwargs: forwarded to the sublayers. An optional `rng` entry is removed
      first and split into `self._n_layers` keys.

  Returns:
    A pair `(reconstructed_x, (x_ct, weights_ct))` where `weights_ct` is a
    four-element list ordered pre_attention, attention, post_attention,
    subtract_top.
  """
  rng = kwargs.pop('rng', None)
  rngs = (None,) * self._n_layers
  if rng is not None:
    rngs = random.split(rng, self._n_layers)

  # Forward pass through self.pre_attention, while preparing for
  # later backprop.
  def call_pre_attention(x, weights):
    res = self.pre_attention(x, weights=weights, state=state[0],
                             rng=rngs[0], **kwargs)
    return res
  stack, pre_attention_vjpfun = jax.vjp(call_pre_attention, output,
                                        weights[0])

  # Backprop through adding the residual
  assert len(ct) == 2
  # `saved_ct` keeps this three-entry cotangent for the weight backprop
  # further down, because `ct` itself is overwritten in between.
  ct = saved_ct = (ct[0], ct[0], ct[1])

  # Backprop through self.post_attention with respect to the inputs only
  def call_post_attention(x):
    res = self.post_attention(x, weights=weights[2], state=state[2],
                              rng=rngs[2], **kwargs)
    return res
  # Note: these are *not* the actual inputs to self.post_attention.
  # If self.post_attention is not linear, we will get incorrect gradients.
  dummy_inputs = (stack[-3], stack[-2], stack[-1])
  _, post_attention_vjpfun = jax.vjp(call_post_attention, dummy_inputs)
  (ct,) = post_attention_vjpfun(ct)

  # Simultaneous forward pass and backprop through the attention mechanism
  stack, ct = self.attention.forward_and_backward(
      stack, ct, rng=rngs[1], state=state[1], new_state=new_state[1],
      **kwargs)
  assert not jax.tree_util.tree_leaves(weights[1])
  attention_weights_ct = weights[1]  # This is valid when weights is empty.

  # Backprop through self.pre_attention
  x_ct, pre_attention_weights_ct = pre_attention_vjpfun(ct)

  # Forward pass for self.post_attention, and backprop with respect to the
  # parameters only
  def call_post_attention2(weights):
    # `stack` is read when jax.vjp calls this function, i.e. while it still
    # holds the attention output (before the rebinding on the next statement).
    res = self.post_attention(stack, weights=weights, state=state[2],
                              rng=rngs[2], **kwargs)
    return res
  stack, post_attention_vjpfun = jax.vjp(call_post_attention2, weights[2])
  (post_attention_weights_ct,) = post_attention_vjpfun(saved_ct)

  # Forward pass through subtracting the residual
  reconstructed_x = self.subtract_top(
      stack, weights=weights[-1], state=state[-1], rng=rngs[-1], **kwargs)

  # The last sublayer has no weights, so its (empty) weights double as their
  # own cotangent.
  assert not jax.tree_util.tree_leaves(weights[-1])
  add_top_weights_ct = weights[-1]

  weights_ct = [
      pre_attention_weights_ct,
      attention_weights_ct,
      post_attention_weights_ct,
      add_top_weights_ct,
  ]
  return reconstructed_x, (x_ct, weights_ct)