def per_device_rngs(rng):
  """Builds one stacked rng per layer from a single rng.

  Kept as a standalone function so that it can be JIT-compiled, which avoids
  fragmenting memory.

  Args:
    rng: The rng to split.

  Returns:
    A list with `self._n_layers` entries. Entry `k` stacks the k-th layer key
    of every device.
  """
  device_keys = fastmath.random.split(rng, self._n_devices)
  layer_keys_by_device = []
  for device_key in device_keys:
    layer_keys_by_device.append(
        fastmath.random.split(device_key, self._n_layers))
  return [
      jnp.stack([keys[layer] for keys in layer_keys_by_device])
      for layer in range(self._n_layers)
  ]
def predict(x, weights, state, rng):
  """Predict function JIT-compiled and parallelized as requested.

  Args:
    x: Model input. It is reshaped per device before the call.
    weights: Model weights passed through to `model_predict`.
    state: Model state passed through to `model_predict`.
    rng: An rng, split into `n_devices` keys (one per device).

  Returns:
    A pair `(res, state)`. When `do_mean` is set, every leaf of `res` is
    averaged along axis 0.
  """
  sharded_x = reshape_by_device(x, n_devices)
  device_rngs = jnp.stack(fastmath.random.split(rng, n_devices))
  res, state = _combine_devices(
      model_predict(sharded_x, weights, state, device_rngs))
  if not do_mean:
    return res, state
  averaged = fastmath.nested_map(lambda leaf: jnp.mean(leaf, axis=0), res)
  return averaged, state
def forward(self, inputs):
  """Returns the input activations, with added positional information.

  Outside 'predict' mode, the positional embeddings for the first
  `inputs.shape[1]` positions are added to `inputs`, with optional dropout on
  the embeddings.
  - In 'train' mode with `start_from_zero_prob < 1`, the embeddings are read
    starting at a random offset drawn with `randint(0, max_offset_to_add)`.
  - With probability `start_from_zero_prob`, that offset is forced to 0.

  In 'predict' mode, `self.state` tracks the current decoding position(s). It
  is advanced by the number of positions consumed in this call.

  Args:
    inputs: Activations. Axis 1 is the sequence axis; in predict mode axis 0
      is also iterated per element.

  Returns:
    `inputs` plus the selected positional embeddings.

  Raises:
    ValueError: In 'predict' mode when dropout is non-zero.
  """
  if self._mode == 'predict':
    return self._forward_predict(inputs) if False else _predict_forward(
        self, inputs)

  x = inputs
  symbol_size = jnp.shape(x)[1]
  # Key for the dropout mask. If the random-offset path below also needs
  # randomness, all keys are split from self.rng so that no key is consumed
  # twice (reusing a split key would correlate the offset with the mask).
  dropout_rng = self.rng
  if self._mode != 'train' or self._start_from_zero_prob >= 1.0:
    px = self.weights[:, :symbol_size, :]
  else:
    rng1, rng2, dropout_rng = fastmath.random.split(self.rng, 3)
    start = fastmath.random.randint(rng1, (), 0, self._max_offset_to_add)
    start_from_zero = fastmath.random.uniform(
        rng2, (), jnp.float32, 0, 1)
    start = jnp.where(start_from_zero < self._start_from_zero_prob,
                      jnp.zeros((), dtype=jnp.int32), start)
    px = fastmath.dynamic_slice_in_dim(self.weights, start, symbol_size,
                                       axis=1)
  if self._dropout == 0:
    return x + px
  noise_shape = list(px.shape)
  for dim in self._dropout_broadcast_dims:
    noise_shape[dim] = 1  # Share the dropout mask along broadcast dims.
  keep_prob = 1.0 - self._dropout
  keep = fastmath.random.bernoulli(dropout_rng, keep_prob,
                                   tuple(noise_shape))
  # Rescale kept entries so the expected value of px is unchanged.
  multiplier = keep.astype(x.dtype) / keep_prob
  return x + px * multiplier


def _predict_forward(self, inputs):
  """Predict-mode path of `forward`: adds embeddings at the stored position."""
  if self._dropout != 0:
    raise ValueError(f'In predict mode, but dropout rate '
                     f'({self._dropout}) is not zero.')
  # State in this class is only used for fast inference. In that case,
  # the model is called with consecutive elements position-by-position.
  # This positional encoding layer needs to store the index of the current
  # position then and increment it on each call -- that's how state is used
  # and updated below.
  state = self.state
  if inputs.shape[1] == 1:
    self.state = state + 1
    return inputs + jnp.expand_dims(self.weights[0, state, :], 1)
  emb = []
  for i in range(inputs.shape[0]):
    emb.append(
        fastmath.dynamic_slice_in_dim(self.weights[0], state[i],
                                      inputs.shape[1], axis=0))
  self.state = state + inputs.shape[1]
  res = inputs + jnp.stack(emb, 0)
  return res
def _per_device_rngs(self, rng):
  """Derives per-layer rngs, each one stacked across devices.

  The rng is split by device first and then by layer, which matches the
  default trainer.

  Args:
    rng: The rng to split.

  Returns:
    A list of length `self._n_layers`. Element `k` stacks the k-th layer key
    from every device.
  """
  device_keys = fastmath.random.split(rng, self._n_devices)
  keys_per_device = [
      fastmath.random.split(device_key, self._n_layers)
      for device_key in device_keys
  ]
  stacked = []
  for layer_idx in range(self._n_layers):
    stacked.append(jnp.stack([keys[layer_idx] for keys in keys_per_device]))
  return stacked
def one_step(self, batch, rng, step=0, learning_rate=None):
  """Runs one optimization step on the accelerated loss layer.

  This updates the loss layer's weights and state and the optimizer slots.

  Args:
    batch: Batch of data to use for optimization.
    rng: Random number generator to use for running this step.
    step: Which step of the training are we running.
    learning_rate: If given, overrides the learning rate in the optimizer
      params (replicated across devices).

  Returns:
    Tuple (loss, stats) with new values from one step of training, where
    stats are current optimizer statistics.
  """
  n_devices = self._n_devices
  if learning_rate is not None:
    self._opt_params['learning_rate'] = tl.for_n_devices(
        learning_rate, n_devices)

  if n_devices > 1:
    # Split the batch dimension across devices (batch -> batch // n_devices),
    # as opposed to for_n_devices, which replicates along a new device axis.
    batch = tl.reshape_by_device(batch, n_devices)
    # Every device also gets its own rng.
    rng = jnp.stack(fastmath.random.split(rng, n_devices))

  layer = self._accelerated_loss_layer
  weights = layer.weights
  state = layer.state

  if logging.vlog_is_on(1) and (step & (step - 1)) == 0:
    # Debug dump only at step 0 and at power-of-two steps.
    logging.info('step[%d]', step)
    logging.info('opt_params[%s]', self._opt_params)
    logging.info('slots[%s]', self._slots)
    logging.info('weights[%s]', weights)
    logging.info('state[%s]', state)

  # NOTE: stats is a replicated dictionary of key to jnp arrays.
  (new_weights, new_slots), new_state, stats = self._accelerated_update_fn(
      (weights, self._slots), step, self._opt_params, batch, state, rng)

  if logging.vlog_is_on(1) and (step & (step - 1)) == 0:
    logging.info('updated weights[%s]', new_weights)
    logging.info('stats[%s]', stats)

  layer.weights = new_weights
  layer.state = new_state
  self._slots = new_slots
  self._optimizer.slots = self._unreplicate(self._slots)
  return stats['loss'], stats
def one_step(self, batch, rng, step=0, learning_rate=None):
  """Runs one training step, to update model and optimizer parameters.

  Args:
    batch: Batch of labeled training data.
    rng: Single-use random number generator (JAX PRNG key).
    step: Training step number.
    learning_rate: Learning rate for the optimizer; if None, use optimizer's
      default learning rate.

  Returns:
    Tuple of (loss, optimizer_stats), with the newly computed loss and
    updated stats as reported by the optimizer.
  """
  n_devices = self._n_devices
  if learning_rate is not None:
    self._opt_params['learning_rate'] = tl.for_n_devices(
        learning_rate, n_devices)

  if n_devices > 1:
    # Shard the batch (batch_dim -> batch_dim // n_devices) and give every
    # device its own rng.
    batch = tl.reshape_by_device(batch, n_devices)
    rng = jnp.stack(fastmath.random.split(rng, n_devices))

  model = self._accelerated_model_with_loss
  weights, state = model.weights, model.state

  if logging.vlog_is_on(1) and (step & (step - 1)) == 0:
    # Debug dump only at step 0 and at power-of-two steps.
    logging.info('step[%d]', step)
    logging.info('opt_params[%s]', self._opt_params)
    logging.info('slots[%s]', self._slots)
    logging.info('weights[%s]', weights)
    logging.info('state[%s]', state)

  # NOTE: stats is a replicated dictionary of key to jnp arrays.
  (new_weights, new_slots), new_state, stats = self._accelerated_update_fn(
      (weights, self._slots), step, self._opt_params, batch, state, rng)

  if logging.vlog_is_on(1) and (step & (step - 1)) == 0:
    logging.info('updated weights[%s]', new_weights)
    logging.info('stats[%s]', stats)

  model.weights = new_weights
  model.state = new_state
  self._slots = new_slots
  self._optimizer.slots = self._unreplicate(self._slots)
  return stats['loss'], stats
def _run_one_step(self, weights, state, slots, opt_params):
  """Runs one training step on the task's next batch.

  Args:
    weights: Weights from model being trained.
    state: State (non-weight parameters) from model being trained.
    slots: Updatable weights for the optimizer in this training loop.
    opt_params: Dictionary of optimizer (hyper)parameters, e.g. learning
      rate, momentum. Its 'learning_rate' entry is overwritten in place with
      the scheduled rate for the current step.

  Returns:
    Tuple (loss, weights, state, slots, stats) with new values from one step
    of training, where stats are current optimizer statistics.
  """
  step = self.step
  opt_params['learning_rate'] = self._for_n_devices(
      self._task.learning_rate(step))

  # Shard the batch across local devices (batch -> batch // n_devices),
  # unlike _for_n_devices, which replicates along a new device axis.
  batch = self._reshape_by_device(self._task.next_batch())

  rng = self.new_rng()
  if self.n_devices > 1:
    # One independent key per device.
    rng = jnp.stack(jax_random.split(rng, self.n_devices))

  if logging.vlog_is_on(1) and (step & (step - 1)) == 0:
    # Debug dump only at step 0 and at power-of-two steps.
    logging.info('step[%d]', step)
    logging.info('opt_params[%s]', opt_params)
    logging.info('weights[%s]', weights)

  # NOTE: stats is a replicated dictionary of key to jnp arrays.
  (new_weights, new_slots), new_state, stats = self._accelerated_update_fn(
      (weights, slots), step, opt_params, batch, state, rng)

  if logging.vlog_is_on(1) and (step & (step - 1)) == 0:
    logging.info('updated weights[%s]', new_weights)
    logging.info('stats[%s]', stats)

  return stats['loss'], new_weights, new_state, new_slots, stats
def forward(self, inputs):
  """Returns the input activations, with added positional information.

  Outside 'predict' mode, the embeddings for the first `inputs.shape[1]`
  positions are added to `inputs`, with optional dropout on the embeddings.

  In 'predict' mode, `self.state` tracks the current decoding position(s). It
  is advanced by the number of positions consumed in this call.

  Args:
    inputs: Activations. Axis 1 is the sequence axis; in predict mode axis 0
      is also iterated per element.

  Returns:
    `inputs` plus the selected positional embeddings.

  Raises:
    ValueError: In 'predict' mode when dropout is non-zero.
  """
  if self._mode != 'predict':
    x = inputs
    symbol_size = jnp.shape(x)[1]
    px = self.weights[:, :symbol_size, :]
    if self._dropout == 0:
      return x + px
    noise_shape = list(px.shape)
    for dim in self._dropout_broadcast_dims:
      noise_shape[dim] = 1  # Share the dropout mask along broadcast dims.
    keep_prob = 1.0 - self._dropout
    if fastmath.is_backend(fastmath.Backend.JAX):
      # Use a scalar array of x's dtype. No jax.lax.tie_in is needed: it
      # has been a no-op since JAX's omnistaging.
      keep_prob = jnp.full((), keep_prob, dtype=x.dtype)
    keep = fastmath.random.bernoulli(self.rng, keep_prob, tuple(noise_shape))
    # Rescale kept entries so the expected value of px is unchanged.
    multiplier = keep.astype(x.dtype) / keep_prob
    return x + px * multiplier

  if self._dropout != 0:
    raise ValueError(f'In predict mode, but dropout rate '
                     f'({self._dropout}) is not zero.')
  # State in this class is only used for fast inference. In that case,
  # the model is called with consecutive elements position-by-position.
  # This positional encoding layer needs to store the index of the current
  # position then and increment it on each call -- that's how state is used
  # and updated below.
  state = self.state
  if inputs.shape[1] == 1:
    self.state = state + 1
    return inputs + jnp.expand_dims(self.weights[0, state, :], 1)
  emb = []
  for i in range(inputs.shape[0]):
    # Go through fastmath (not jax.lax directly) so that non-JAX backends
    # keep working.
    emb.append(
        fastmath.dynamic_slice_in_dim(self.weights[0], state[i],
                                      inputs.shape[1], axis=0))
  self.state = state + inputs.shape[1]
  return inputs + jnp.stack(emb, 0)
def _get_embeddings(self, t):
  """Computes time-bin features for integer positions.

  Positions are grouped into bins of `self._time_bin_length`. Three features
  are produced per position:
  - `1 / (1 + bin index)`;
  - the offset inside the bin divided by the bin length;
  - the bin-index parity as a float.

  Args:
    t: int[...] position (i.e. jnp.arange(..., jnp.int32))

  Returns:
    embeddings: float[..., num_features]
  """
  bin_len = self._time_bin_length
  bin_index, offset_in_bin = divmod(t, bin_len)
  features = [
      1 / (1 + bin_index),
      offset_in_bin / bin_len,
      (bin_index % 2).astype(jnp.float32),
  ]
  embeddings = jnp.stack(features, axis=-1)
  expected_shape = (*t.shape, self.num_features)
  assert embeddings.shape == expected_shape, embeddings.shape
  return embeddings
def threefry_2x32_prange(key, lo: int = 0, hi: int = 2):
  """Splits a key into a stream of random keys.

  This uses the little-endian counter mode: the 64-bit counters
  `[lo + j, 0]` for `j in range(hi - lo)` (low word first, high word zero)
  are passed to `threefry_2x32_prf` together with `key`.

  Args:
    key: uint32[2] the key to split
    lo: the range to start extracting from (inclusive, must be >= 0)
    hi: the range to stop extracting from (exclusive, must be >= lo and
      < 2**32)

  Returns:
    keys: uint32[hi - lo, 2] the split keys

  Raises:
    ValueError: if `key` is not uint32[2], or if not `0 <= lo <= hi`.
    NotImplementedError: if `hi >= 2**32`.
  """
  if not (key.shape == (2,) and key.dtype == jnp.uint32):
    raise ValueError('key must be uint32[2]')
  if not hi < 2**32:
    # You shouldn't really be using more than half the key size anyways.
    raise NotImplementedError('only 32-bit sizes are supported')
  if not 0 <= lo <= hi:
    # A negative `lo` cannot be stored in the uint32 counter word. A `lo`
    # above `hi` would describe a negative number of keys.
    raise ValueError(f'need 0 <= lo <= hi, got lo={lo}, hi={hi}')
  # Create a 64-bit counter: the low word is the index, the high word is 0.
  i_lo = jnp.arange(lo, hi, dtype=jnp.uint32)
  i_hi = jnp.zeros_like(i_lo)
  counters = jnp.stack([i_lo, i_hi], axis=-1)
  return threefry_2x32_prf(key, counters)
def one_step(self, batch, rng, step=0, learning_rate=None):
  """Updates layers weights/state and optimizers slots by running one step.

  The pass runs in this order:
  1. Each block's standard layer and reversible layers run forward.
  2. The loss layer runs forward and backward with its optimizer update.
  3. The blocks are walked backward (last to first). Each block runs its
     reversible layers backward with updates, then recomputes its standard
     layer forward-and-backward with an update.

  Args:
    batch: Batch of data to use for optimization.
    rng: Random number generator to use for running this step.
    step: Which step of the training are we running.
    learning_rate: Learning rate to use instead of the default one.

  Returns:
    Tuple (loss, stats). The loss is taken from the loss layer's stats.
    `stats` is a dict whose keys are 'layer{i}/<name>'. Index 0 is the first
    block's standard layer; the loss layer gets the highest index.
  """
  # Update the learning rate if needed.
  if learning_rate is not None:
    self._replicated_loss_opt_params['learning_rate'] = tl.for_n_devices(
        learning_rate, self._n_devices)
    for (std_op, rev_ops) in self._replicated_opt_params:
      std_op['learning_rate'] = tl.for_n_devices(
          learning_rate, self._n_devices)
      for op in rev_ops:
        op['learning_rate'] = tl.for_n_devices(
            learning_rate, self._n_devices)

  # Batch needs to be split across the local devices -- the difference
  # between _for_n_devices and _reshape_by_device is that the latter splits
  # the batch dim to batch // n_devices, vs _for_n_devices
  # broadcasts/replicates to n_devices dimension.
  if self._n_devices > 1:
    batch = tl.reshape_by_device(batch, self._n_devices)
    # Give every device its own copy of the step counter.
    step = jnp.repeat(step, self._n_devices)

  # Create separate rng for each device and layer.
  # `rngs` holds self._n_layers keys: a single array on one device, or a
  # list of per-device stacks otherwise. The blocks consume keys in order
  # below and the loss layer takes rngs[-1]. This presumes
  # self._n_layers == sum(1 + len(rev_layers)) + 1 -- TODO confirm where
  # _n_layers is computed.
  if self._n_devices == 1:
    rngs = fastmath.random.split(rng, self._n_layers)
  else:
    # Splitting by device first to be identical with default trainer.
    per_device_rng = fastmath.random.split(rng, self._n_devices)
    per_device_rngs = [
        fastmath.random.split(r, self._n_layers) for r in per_device_rng]
    rngs = [jnp.stack([r[i] for r in per_device_rngs])
            for i in range(self._n_layers)]
  # Group rngs by layer blocks: each block gets one key for its standard
  # layer, followed by one key per reversible layer.
  rng_blocks, rng_i = [], 0
  for _, rev_layers in self._blocks:
    l = len(rev_layers)
    rng_blocks.append((rngs[rng_i], rngs[rng_i + 1: rng_i + l + 1]))
    rng_i += l + 1

  # Run the layers forward upto the loss layer.
  stack = batch
  # Per block: the standard layer's inputs/state and the reversible layers'
  # old/new states, saved for the backward pass.
  block_inputs_states = []
  for i, (std_layer, rev_layers) in enumerate(self._blocks):
    acc_std_layer_fn, acc_rev_layer_fns = self._accelerated_layer_fns[i]
    std_rng, rev_rngs = rng_blocks[i]
    # Run the standard layer.
    stack, std_inputs, std_state = self._run_forward_standard(
        stack, std_layer, acc_std_layer_fn, std_rng)

    # Run the reversible layers and collect old and new states.
    stack, rev_old_states, rev_new_states = self._run_forward_reversible(
        stack, rev_layers, acc_rev_layer_fns, rev_rngs)

    block_inputs_states.append(
        ((std_inputs, std_state), (rev_old_states, rev_new_states)))

  # Run the loss layer forward and backward with optimizer update.
  loss_state = self._replicate(self._loss_layer.state)
  loss_inputs = cb.inputs_from_stack(stack, self._loss_layer.n_in)
  # There is no incoming gradient stack for the loss layer, hence `None`.
  loss_stats, grad_stack = self._run_backward_standard(
      None, step, self._loss_layer, loss_inputs,
      loss_state, self._loss_fbo, rngs[-1], self._loss_opt,
      self._replicated_loss_opt_params)
  stats = [loss_stats]

  # Run the layers backward and run optimizer updates.
  # Blocks are visited from last to first.
  for i in range(len(self._blocks) - 1, -1, -1):
    std_layer, rev_layers = self._blocks[i]
    (std_inputs, std_state), (rev_old_states,
                              rev_new_states) = block_inputs_states[i]
    std_fbo, rev_fbos = self._fbos[i]
    std_opt, rev_opts = self._optimizers[i]
    std_rng, rev_rngs = rng_blocks[i]
    repl_std_opt_params, repl_rev_opts_params = self._replicated_opt_params[i]

    # Run reversible layers backward with optimizer update.
    stack, grad_stack, new_stats = self._run_backward_reversible(
        stack, grad_stack, step, rev_layers, rev_fbos, rev_old_states,
        rev_new_states, rev_rngs, rev_opts, repl_rev_opts_params)
    stats.extend(new_stats)

    # Run the standard layer forward-and-backward pass and optimizer update.
    std_layer_stats, grad_stack = self._run_backward_standard(
        grad_stack, step, std_layer, std_inputs, std_state, std_fbo, std_rng,
        std_opt, repl_std_opt_params)
    stack = cb.outputs_onto_stack(  # Put layer inputs on the stack.
        std_inputs, stack, std_layer.n_out)
    stats.append(std_layer_stats)

  # Join stats from different optimizers into one.
  # Reversing `stats` numbers the first block's standard layer as layer0;
  # the loss layer's stats (stats[0]) get the highest index.
  joint_stats = {}
  for i, stat in enumerate(reversed(stats)):
    for k, v in stat.items():
      joint_stats[f'layer{i}/' + k] = v
  return stats[0]['loss'], joint_stats
def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
             output_dir=None, random_seed=None, n_devices=None,
             checkpoints_at=None, should_save_checkpoints=True,
             should_write_summaries=True,
             metrics=None, checkpoint_highest=None,
             checkpoint_lowest=None):
  """Builds the training and eval models, metrics and jitted step functions.

  Args:
    model: Callable taking `mode=` ('train' or 'eval') and returning a layer.
    loss_fn: Loss layer, composed after the training model.
    optimizer: Callable taking `learning_rate=` and returning an optimizer.
    lr_schedule: Learning-rate schedule; stored for later use.
    inputs: An Inputs instance, or a callable (e.g. bound through gin) that
      returns one.
    output_dir: Output directory passed to `reset()`. If falsy, checkpoint
      saving and summary writing are disabled.
    random_seed: Forwarded to `training.init_host_and_devices`.
    n_devices: Forwarded to `training.init_host_and_devices`.
    checkpoints_at: Checkpoint steps; defaults to an empty list.
    should_save_checkpoints: Whether to save checkpoints. Only honored on the
      chief host.
    should_write_summaries: Whether to write summaries.
    metrics: Mapping from metric name to metric layer; defaults to
      `_DEFAULT_METRICS`.
    checkpoint_highest: Stored as-is.
    checkpoint_lowest: Stored as-is.
  """
  self._is_chief, _, self._n_devices, rng = (
      training.init_host_and_devices(n_devices, random_seed))
  self._should_save_checkpoints = should_save_checkpoints and self._is_chief
  self._checkpoints_at = checkpoints_at if checkpoints_at is not None else []
  self._should_write_summaries = should_write_summaries
  if not output_dir:
    self._should_save_checkpoints = False
    self._should_write_summaries = False
  self._checkpoint_highest = checkpoint_highest
  self._checkpoint_lowest = checkpoint_lowest
  self._metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
  # Inputs is either an Inputs instance or a function that returns it.
  self._inputs = inputs
  if callable(inputs):  # If we pass a function, e.g., through gin, call it.
    self._inputs = inputs()
  # Initialize the learning rate to a dummy value. It will be set in reset().
  opt = optimizer(learning_rate=0.0)

  # Setup the model.
  model_train = model(mode='train')
  model_predict_eval = model(mode='eval')
  self._model_with_loss = tl.Serial(model_train, loss_fn)

  # Setup state.
  rng, init_rng = jax_random.split(rng)
  self._rngs = np.stack(jax_random.split(rng, self._n_devices))
  # Computed once and reused below for the eval model and metrics as well.
  shapes, dtypes = self._inputs.example_shape_dtype
  input_signature = tuple(
      ShapeDtype(s, d) for (s, d) in zip(shapes, dtypes))

  def new_opt_state_and_model_state(rng):
    """Returns optimizer and model states suitable for training a model."""
    weights, state = self._model_with_loss.init(input_signature, rng=rng)
    (slots, opt_params) = opt.tree_init(weights)
    return (OptState(weights, slots, opt_params), state)

  if fastmath.is_backend(fastmath.Backend.JAX):
    # JIT parameter initialization to avoid memory fragmentation
    new_opt_state_and_model_state = (
        fastmath.jit(new_opt_state_and_model_state))
  self._new_opt_state_and_model_state = (
      lambda: new_opt_state_and_model_state(init_rng))

  # Arrange and initialize metrics layers.
  self._metrics = sorted(self._metrics_dict.keys())
  metrics_layers = [self._metrics_dict[m] for m in self._metrics]
  metrics_in_parallel = tl.Branch(*metrics_layers)
  metrics_in_parallel.rng = init_rng
  example_signature = input_signature
  model_predict_eval.init(example_signature)
  self._input_signature = example_signature
  output_signature = model_predict_eval.output_signature(
      example_signature)
  m_weights, m_state = metrics_in_parallel.init(output_signature)
  self._metrics_weights = self._for_n_devices(m_weights)
  self._metrics_state = self._for_n_devices(m_state)

  # Jit model_predict and update so they're fast.
  self._jit_eval = _jit_predict_fn(model_predict_eval, metrics_in_parallel,
                                   self._n_devices)
  self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt,
                                       self._n_devices)

  self._model_train = model_train
  self._model_predict_eval = model_predict_eval
  self._loss_fn = loss_fn
  self._lr_schedule = lr_schedule

  # Those fields will be set in reset().
  self._output_dir = None
  self._train_sw = None
  self._eval_sw = None
  self._history = None
  self._opt_state = None
  self._step = None
  self._model_state = None
  self.reset(output_dir)
def one_step(self, batch, rng, step=0, learning_rate=None):
  """Updates layers weights/state and optimizers slots by running one step.

  The pass runs in this order:
  1. The first layer runs forward.
  2. Every reversible layer runs forward.
  3. The loss layer runs forward and backward with its optimizer update.
  4. The reversible layers are reversed (last to first), each with its
     optimizer update.
  5. The first layer is recomputed forward-and-backward with its update.

  Updated weights, state and slots are copied back via `self._unreplicate`.

  Args:
    batch: Batch of data to use for optimization.
    rng: Random number generator to use for running this step.
    step: Which step of the training are we running.
    learning_rate: Learning rate to use instead of the default one.

  Returns:
    Tuple (loss, stats). The loss comes from the loss layer's stats. `stats`
    is a list of per-layer optimizer stats in this order:
    - the loss layer;
    - the reversible layers, from last to first;
    - the first layer.
  """
  # Update the learning rate if needed.
  if learning_rate is not None:
    for op in self._replicated_opt_params:
      op['learning_rate'] = tl.for_n_devices(learning_rate, self._n_devices)

  # Batch needs to be split across the local devices -- the difference
  # between _for_n_devices and _reshape_by_device is that the latter splits
  # the batch dim to batch // n_devices, vs _for_n_devices
  # broadcasts/replicates to n_devices dimension.
  if self._n_devices > 1:
    batch = tl.reshape_by_device(batch, self._n_devices)
    # Give every device its own copy of the step counter.
    step = jnp.repeat(step, self._n_devices)

  # Separate rng needs to be created for each device.
  # Key layout: rngs[0] is the first layer, rngs[1:-1] the reversible layers
  # (in order), and rngs[-1] the loss layer.
  if self._n_devices == 1:
    rngs = fastmath.random.split(rng, len(self._reversible_layers) + 2)
  else:
    # Splitting by device first to be identical with default trainer.
    per_device_rng = fastmath.random.split(rng, self._n_devices)
    per_device_rngs = [
        fastmath.random.split(r, len(self._reversible_layers) + 2)
        for r in per_device_rng
    ]
    rngs = [
        jnp.stack([r[i] for r in per_device_rngs])
        for i in range(len(self._reversible_layers) + 2)
    ]

  # Run the layers forward upto the loss layer.
  stack = batch

  # Run the first layer.
  # Its inputs, weights and state are kept for the final backward pass.
  first_layer_inputs = _inputs_from_stack(self._first_layer, stack)
  first_layer_weights = self._replicate(self._first_layer.weights)
  first_layer_state = self._replicate(self._first_layer.state)
  outputs, first_layer_new_state = self._accelerated_first_layer_fn(
      first_layer_inputs, first_layer_weights, first_layer_state, rngs[0])
  stack = _outputs_onto_stack(self._first_layer, outputs, stack)

  # Run the reversible layers and collect old and new states.
  old_states, new_states = [], []
  for i, layer in enumerate(self._reversible_layers):
    weights = self._replicate(layer.weights)  # also copies cpu -> accelerator
    state = self._replicate(layer.state)
    old_states.append(state)
    inputs = _inputs_from_stack(layer, stack)
    outputs, new_state = self._accelerated_reversible_layers_fns[i](
        inputs, weights, state, rngs[i + 1])
    stack = _outputs_onto_stack(layer, outputs, stack)
    new_states.append(new_state)

  # Run the loss layer forward and backward with optimizer update.
  # The loss layer uses the last optimizer / opt-params entry.
  loss_weights = self._replicate(self._loss_layer.weights)
  loss_state = self._replicate(self._loss_layer.state)
  loss_inputs = _inputs_from_stack(self._loss_layer, stack)
  loss_slots = self._replicate(self._optimizers[-1].slots)
  new_weights, new_state, new_slots, grad_stack, loss_stats = self._loss_fbo(
      loss_inputs, loss_weights, loss_state, loss_slots,
      self._replicated_opt_params[-1], rngs[-1], step)
  stats = [loss_stats]
  self._loss_layer.weights = self._unreplicate(new_weights)  # acceler. -> cpu
  self._loss_layer.state = self._unreplicate(new_state)
  self._optimizers[-1].slots = self._unreplicate(new_slots)

  # Run reversible layers backward with optimizer update.
  # `counter` walks -2, -3, ... so that, together with index 0 (first layer)
  # and -1 (loss layer), self._optimizers / self._replicated_opt_params are
  # indexed as [first, reversible..., loss].
  counter = -1
  for layer, reverse_and_fbo, old_state, new_state, rng in reversed(
      list(zip(self._reversible_layers, self._reverse_and_fbos,
               old_states, new_states, rngs[1:-1]))):
    counter -= 1
    # We are running backwards and reversing, so we get *outputs* from stack.
    outputs = _inputs_from_stack(layer, stack, layer.n_out)
    grads = _inputs_from_stack(layer, grad_stack, layer.n_out)
    slots = self._replicate(self._optimizers[counter].slots)
    opt_params = self._replicated_opt_params[counter]
    weights = self._replicate(layer.weights)  # cpu -> accelerator
    # Reconstructs the layer's inputs from its outputs and returns updated
    # weights/slots together with the gradients w.r.t. those inputs.
    new_weights, new_slots, inputs, grads, layer_stats = reverse_and_fbo(
        outputs, weights, old_state, new_state,
        slots, opt_params, rng, step, grads)
    layer.weights = self._unreplicate(new_weights)  # accelerator -> cpu
    layer.state = self._unreplicate(new_state)
    self._optimizers[counter].slots = self._unreplicate(new_slots)
    stats.append(layer_stats)
    # Replace this layer's outputs (and output grads) on the stacks with its
    # reconstructed inputs (and input grads).
    stack = _outputs_onto_stack(
        layer, inputs, stack, layer.n_out, layer.n_in)
    grad_stack = _outputs_onto_stack(
        layer, grads, grad_stack, layer.n_out, layer.n_in)

  # Run the first layer forward-and-backward pass and optimizer update.
  grads = _inputs_from_stack(
      self._first_layer, grad_stack, self._first_layer.n_out)
  slots = self._replicate(self._optimizers[0].slots)
  new_weights, new_state, new_slots, first_layer_stats = self._first_fbo(
      first_layer_inputs, first_layer_weights, first_layer_new_state,
      slots, self._replicated_opt_params[0], rngs[0], step, grads)
  stats.append(first_layer_stats)
  self._first_layer.weights = self._unreplicate(new_weights)
  self._first_layer.state = self._unreplicate(new_state)
  self._optimizers[0].slots = self._unreplicate(new_slots)

  return stats[0]['loss'], stats