def significance_weights(mask):
    """Returns flattened per-symbol weights `mask * decay ** significance`.

    `serializer.significance_map` holds one significance value per symbol of
    a single representation, e.g. [0, 1, 2]. It is laid out twice back to
    back ([0, 1, 2, 0, 1, 2]) so that it lines up with the last axis of
    `mask`, which must be exactly twice as long as the map.

    Args:
        mask: rank-3 array whose last axis has length
            `2 * significance_map.shape[0]`; presumably laid out as
            (batch, length, 2 * repr) -- TODO confirm against the caller.

    Returns:
        Array of shape (mask.shape[0], mask.shape[1] * mask.shape[2]).
    """
    # (repr,) e.g. significance = [0, 1, 2]
    significance = serializer.significance_map
    assert significance.shape[0] * 2 == mask.shape[2]
    assert significance.ndim == 1 and mask.ndim == 3

    # (2 * repr,) e.g. significance = [0, 1, 2, 0, 1, 2]
    significance = jnp.concatenate([significance, significance], axis=0)
    assert significance.shape[0] == mask.shape[2]

    # Broadcasting lines the (2 * repr,) exponents up with the last axis of
    # mask, so there is no need to materialise a mask-shaped copy with
    # repeat/swapaxes first; the product has exactly mask.shape.
    sig_weights = mask * decay ** significance

    batch_size = sig_weights.shape[0]
    mask_size = sig_weights.shape[1] * sig_weights.shape[2]
    # Row-major reshape: for every batch row, the weights of position 0 come
    # first, then those of position 1, and so on.
    # TODO(henrykm): Make sure that the reshape works in the desired way
    return jnp.reshape(sig_weights, (batch_size, mask_size))
def _funnel_mask(self, batch_size, keys_len, queries_len, funnel_factor,
                 is_upsampling):
    """Builds the attention mask used by funnel attention.

    Outside of predict mode the result is a lower-triangular (causal) mask
    that stops a token from attending to the positions that follow it. When
    funnel_factor differs from 1 every element of the triangle is repeated
    funnel_factor times: along the last axis for downsampling (one token
    attends to funnel_factor tokens) and along the second-to-last axis for
    upsampling (funnel_factor tokens attend to a single token).

    In predict mode exactly one query is allowed; the mask then marks the
    positions up to `self.state // self._total_kv_pooling` and `self.state`
    is advanced by `self._n_raw_tokens_generated`.

    Args:
        batch_size: batch size.
        keys_len: keys length.
        queries_len: queries length.
        funnel_factor: funnel factor.
        is_upsampling: upsampling if set to True.

    Returns:
        Funnel mask.
    """
    if self._mode == 'predict':
        # We cannot generate more than one token because it contradicts
        # all autoregressive properties
        assert queries_len == 1
        positions = jnp.arange(self._max_len)
        step_mask = positions <= (self.state // self._total_kv_pooling)
        step_mask = jnp.reshape(step_mask, (1, 1, 1, self._max_len))
        step_mask = jnp.repeat(step_mask, batch_size, axis=0)
        self.state += self._n_raw_tokens_generated
        return step_mask

    resampling = funnel_factor != 1
    # Only the upsampling case sizes the triangle by the keys.
    side = keys_len if (resampling and is_upsampling) else queries_len
    causal = jnp.tril(jnp.ones((side, side), dtype=jnp.bool_))
    if resampling:
        stretch_axis = -2 if is_upsampling else -1
        causal = jnp.repeat(causal, funnel_factor, axis=stretch_axis)
    return jnp.repeat(causal[None, None, :, :], batch_size, axis=0)
def _run_value_model(self, obs, use_eval_model=True):
    """Runs value model.

    Args:
        obs: array of observations; axis 0 is the batch axis.
        use_eval_model: if True, weights/state/rng are read from
            `self._value_eval_model`; otherwise from `self._value_trainer`,
            falling back to the eval model when the trainer attributes are
            not available yet.

    Returns:
        The first `obs.shape[0]` outputs of `self._value_eval_jit`, multiplied
        by `self._value_network_scale`.
    """
    n_devices = fastmath.device_count()
    if use_eval_model:
        weights = tl.for_n_devices(self._value_eval_model.weights, n_devices)
        state = tl.for_n_devices(self._value_eval_model.state, n_devices)
        rng = self._value_eval_model.rng
    else:
        # TODO(henrykm): this strangely looking solution address the problem
        # that value_batches_stream calls _run_value_model _once_ before
        # the trainer is initialized.
        try:
            weights = tl.for_n_devices(
                self._value_trainer.model_weights, n_devices)
            state = tl.for_n_devices(
                self._value_trainer.model_state, n_devices)
            rng = self._value_trainer._rng  # pylint: disable=protected-access
        except AttributeError:
            weights = tl.for_n_devices(
                self._value_eval_model.weights, n_devices)
            state = tl.for_n_devices(self._value_eval_model.state, n_devices)
            rng = self._value_eval_model.rng

    # A batch smaller than the device count cannot be split across devices
    # ("ValueError: Number of devices (8) does not evenly divide batch size
    # (1)."), so pad it up to exactly n_devices rows.
    obs_batch = obs.shape[0]
    if n_devices > obs_batch:
        # Tile whole copies of the batch rather than repeating each row in
        # place: jnp.repeat(..., axis=0) would yield rows 0, 0, ..., 1, 1, ...
        # and the [:obs_batch] slice below would then return the value of
        # row 0 several times. Trimming to n_devices keeps the padded batch
        # divisible even when n_devices is not a multiple of obs_batch.
        n_copies = -(-n_devices // obs_batch)  # Ceiling division.
        obs = jnp.concatenate([obs] * n_copies, axis=0)[:n_devices]
    values, _ = self._value_eval_jit(obs, weights, state, rng)
    # Drop the values computed for the padding rows.
    values = values[:obs_batch]
    values *= self._value_network_scale
    return values
def significance_weights(mask):
    """Returns per-symbol weights `mask * decay ** significance`.

    `serializer.significance_map` holds one significance value per symbol of
    a representation, e.g. [0, 1, 2]; it must be as long as the last axis of
    `mask`.

    Args:
        mask: rank-3 array whose last axis has length
            `significance_map.shape[0]`; presumably laid out as
            (batch, length, repr) -- TODO confirm against the caller.

    Returns:
        Array with the same shape as `mask`.
    """
    # (repr,) e.g. significance = [0, 1, 2]
    significance = serializer.significance_map
    assert significance.shape[0] == mask.shape[2]
    assert significance.ndim == 1 and mask.ndim == 3

    # Broadcasting lines the (repr,) exponents up with the last axis of mask,
    # so there is no need to materialise a mask-shaped copy with
    # repeat/swapaxes first; the product has exactly mask.shape.
    sig_weights = mask * decay ** significance
    return sig_weights
def representation_mask(mask):
    """Expands a mask so that it has one entry per representation symbol.

    Every axis after the first two is reduced with a maximum, then a new
    trailing axis of length `serializer.representation_length` is added by
    repeating each entry.

    Args:
        mask: array with at least two axes.

    Returns:
        Array of shape
        (mask.shape[0], mask.shape[1], serializer.representation_length).
    """
    # Collapse everything past the first two axes: (d0, d1, ...) -> (d0, d1).
    trailing_axes = tuple(range(2, mask.ndim))
    collapsed = jnp.amax(mask, axis=trailing_axes)
    # (d0, d1) -> (d0, d1, 1) -> (d0, d1, representation_length)
    expanded = collapsed[..., jnp.newaxis]
    return jnp.repeat(
        expanded, repeats=serializer.representation_length, axis=2)
def _funnel_mask(batch_size, keys_len, queries_len, funnel_factor,
                 is_upsampling):
    """Builds the causal mask used by funnel attention.

    The result is a lower-triangular mask that stops a token from attending
    to the positions that follow it. When funnel_factor differs from 1 every
    element of the triangle is repeated funnel_factor times: along the last
    axis for downsampling (after the funnel layer one token attends to
    funnel_factor different tokens) and along the second-to-last axis for
    upsampling (funnel_factor tokens attend to a single token from before
    the upsampling).

    Args:
        batch_size: batch size.
        keys_len: keys length.
        queries_len: queries length.
        funnel_factor: funnel factor.
        is_upsampling: True or False.

    Returns:
        funnel mask.
    """
    resampling = funnel_factor != 1
    # Only the upsampling case sizes the triangle by the keys.
    side = keys_len if (resampling and is_upsampling) else queries_len
    causal = jnp.tril(jnp.ones((side, side), dtype=jnp.bool_))
    if resampling:
        stretch_axis = -2 if is_upsampling else -1
        causal = jnp.repeat(causal, funnel_factor, axis=stretch_axis)
    return jnp.repeat(causal[None, None, :, :], batch_size, axis=0)
def _run_value_model(self, obs):
    """Runs value model.

    Args:
        obs: array of observations; axis 0 is the batch axis.

    Returns:
        The first `obs.shape[0]` outputs of `self._value_eval_jit`, multiplied
        by `self._value_network_scale`.
    """
    n_devices = fastmath.device_count()
    weights = tl.for_n_devices(self._value_eval_model.weights, n_devices)
    state = tl.for_n_devices(self._value_eval_model.state, n_devices)
    rng = self._value_eval_model.rng

    # A batch smaller than the device count cannot be split across devices
    # ("ValueError: Number of devices (8) does not evenly divide batch size
    # (1)."), so pad it up to exactly n_devices rows.
    obs_batch = obs.shape[0]
    if n_devices > obs_batch:
        # Tile whole copies of the batch rather than repeating each row in
        # place: jnp.repeat(..., axis=0) would yield rows 0, 0, ..., 1, 1, ...
        # and the [:obs_batch] slice below would then return the value of
        # row 0 several times. Trimming to n_devices keeps the padded batch
        # divisible even when n_devices is not a multiple of obs_batch.
        n_copies = -(-n_devices // obs_batch)  # Ceiling division.
        obs = jnp.concatenate([obs] * n_copies, axis=0)[:n_devices]
    values, _ = self._value_eval_jit(obs, weights, state, rng)
    # Drop the values computed for the padding rows.
    values = values[:obs_batch]
    values *= self._value_network_scale
    return values
def one_step(self, batch, rng, step=0, learning_rate=None):
    """Updates layers weights/state and optimizers slots by running one step.

    The step runs every block forward (a standard layer followed by its
    reversible layers), runs the loss layer forward and backward, then walks
    the blocks in reverse order running the backward pass and optimizer
    update of each.

    Args:
        batch: Batch of data to use for optimization.
        rng: Random number generator to use for running this step.
        step: Which step of the training are we running.
        learning_rate: Learning rate to use instead of the default one.

    Returns:
        Tuple (loss, stats) with new values from one step of training,
        where stats are all optimizer statistics, keyed as
        'layer<i>/<name>'.
    """
    n_devices = self._n_devices
    n_layers = self._n_layers

    # Update the learning rate if needed.
    if learning_rate is not None:
        def replicated_lr():
            return tl.for_n_devices(learning_rate, n_devices)

        self._replicated_loss_opt_params['learning_rate'] = replicated_lr()
        for std_params, rev_params_list in self._replicated_opt_params:
            std_params['learning_rate'] = replicated_lr()
            for rev_params in rev_params_list:
                rev_params['learning_rate'] = replicated_lr()

    # Batch needs to be split across the local devices -- the difference
    # between _for_n_devices and _reshape_by_device is that the latter splits
    # the batch dim to batch // n_devices, vs _for_n_devices
    # broadcasts/replicates to n_devices dimension.
    if n_devices > 1:
        batch = tl.reshape_by_device(batch, n_devices)
        step = jnp.repeat(step, n_devices)

    # Create separate rng for each device and layer.
    if n_devices == 1:
        rngs = fastmath.random.split(rng, n_layers)
    else:
        # Splitting by device first to be identical with default trainer.
        device_rngs = fastmath.random.split(rng, n_devices)
        layer_rngs_per_device = [
            fastmath.random.split(device_rng, n_layers)
            for device_rng in device_rngs
        ]
        rngs = [
            jnp.stack([device[layer_idx] for device in layer_rngs_per_device])
            for layer_idx in range(n_layers)
        ]

    # Group rngs by layer blocks: one rng for the standard layer followed by
    # one rng per reversible layer of that block.
    rng_blocks = []
    offset = 0
    for _, rev_layers in self._blocks:
        n_rev = len(rev_layers)
        block_rngs = (rngs[offset], rngs[offset + 1: offset + n_rev + 1])
        rng_blocks.append(block_rngs)
        offset += n_rev + 1

    # Run the layers forward upto the loss layer.
    stack = batch
    block_inputs_states = []
    for i, (std_layer, rev_layers) in enumerate(self._blocks):
        acc_std_layer_fn, acc_rev_layer_fns = self._accelerated_layer_fns[i]
        std_rng, rev_rngs = rng_blocks[i]
        # Run the standard layer.
        stack, std_inputs, std_state = self._run_forward_standard(
            stack, std_layer, acc_std_layer_fn, std_rng)
        # Run the reversible layers and collect old and new states.
        stack, rev_old_states, rev_new_states = self._run_forward_reversible(
            stack, rev_layers, acc_rev_layer_fns, rev_rngs)
        block_inputs_states.append(
            ((std_inputs, std_state), (rev_old_states, rev_new_states)))

    # Run the loss layer forward and backward with optimizer update.
    loss_state = self._replicate(self._loss_layer.state)
    loss_inputs = cb.inputs_from_stack(stack, self._loss_layer.n_in)
    loss_stats, grad_stack = self._run_backward_standard(
        None, step, self._loss_layer, loss_inputs, loss_state,
        self._loss_fbo, rngs[-1], self._loss_opt,
        self._replicated_loss_opt_params)
    stats = [loss_stats]

    # Run the layers backward and run optimizer updates.
    for i in reversed(range(len(self._blocks))):
        std_layer, rev_layers = self._blocks[i]
        std_saved, rev_saved = block_inputs_states[i]
        std_inputs, std_state = std_saved
        rev_old_states, rev_new_states = rev_saved
        std_fbo, rev_fbos = self._fbos[i]
        std_opt, rev_opts = self._optimizers[i]
        std_rng, rev_rngs = rng_blocks[i]
        repl_std_opt_params, repl_rev_opts_params = (
            self._replicated_opt_params[i])

        # Run reversible layers backward with optimizer update.
        stack, grad_stack, new_stats = self._run_backward_reversible(
            stack, grad_stack, step, rev_layers, rev_fbos, rev_old_states,
            rev_new_states, rev_rngs, rev_opts, repl_rev_opts_params)
        stats.extend(new_stats)

        # Run the standard layer forward-and-backward pass and optimizer
        # update.
        std_layer_stats, grad_stack = self._run_backward_standard(
            grad_stack, step, std_layer, std_inputs, std_state, std_fbo,
            std_rng, std_opt, repl_std_opt_params)
        # Put layer inputs on the stack.
        stack = cb.outputs_onto_stack(std_inputs, stack, std_layer.n_out)
        stats.append(std_layer_stats)

    # Join stats from different optimizers into one. Stats were collected
    # from the loss layer backwards, so they are numbered in reverse.
    joint_stats = {}
    for i, layer_stats in enumerate(reversed(stats)):
        joint_stats.update(
            {f'layer{i}/' + name: value
             for name, value in layer_stats.items()})
    return stats[0]['loss'], joint_stats
def multi_device_update_fn(
        weights_and_slots, step, opt_params, batch, state, rng):
    """Calls `_multi_device_update_fn` with `step` repeated per device.

    Every argument other than `step` is passed through unchanged.
    """
    # Need to replicate step to n_devices leading dimension.
    per_device_step = jnp.repeat(step, n_devices)
    return _multi_device_update_fn(
        weights_and_slots, per_device_step, opt_params, batch, state, rng)
def update(weights_and_slots, i, opt_params, batch, state, rng):
    """Calls `mapped_update` with the step `i` repeated `n_devices` times.

    Every argument other than `i` is passed through unchanged.
    """
    per_device_step = np.repeat(i, n_devices)
    return mapped_update(
        weights_and_slots, per_device_step, opt_params, batch, state, rng)
def one_step(self, batch, rng, step=0, learning_rate=None):
    """Updates layers weights/state and optimizers slots by running one step.

    The step runs the first layer and all reversible layers forward, runs
    the loss layer forward and backward, then walks the reversible layers in
    reverse order and finally the first layer, applying the optimizer update
    of each layer as it goes. Updated weights, states and slots are written
    back onto the layers and optimizers.

    Args:
        batch: Batch of data to use for optimization.
        rng: Random number generator to use for running this step.
        step: Which step of the training are we running.
        learning_rate: Learning rate to use instead of the default one.

    Returns:
        Tuple (loss, stats) with new values from one step of training,
        where stats are all optimizer statistics.
    """
    n_devices = self._n_devices

    # Update the learning rate if needed.
    if learning_rate is not None:
        for params in self._replicated_opt_params:
            params['learning_rate'] = tl.for_n_devices(
                learning_rate, n_devices)

    # Batch needs to be split across the local devices -- the difference
    # between _for_n_devices and _reshape_by_device is that the latter splits
    # the batch dim to batch // n_devices, vs _for_n_devices
    # broadcasts/replicates to n_devices dimension.
    if n_devices > 1:
        batch = tl.reshape_by_device(batch, n_devices)
        step = jnp.repeat(step, n_devices)

    # Separate rng needs to be created for each device: one for the first
    # layer, one per reversible layer and one for the loss layer.
    n_rngs = len(self._reversible_layers) + 2
    if n_devices == 1:
        rngs = fastmath.random.split(rng, n_rngs)
    else:
        # Splitting by device first to be identical with default trainer.
        device_rngs = fastmath.random.split(rng, n_devices)
        layer_rngs_per_device = [
            fastmath.random.split(device_rng, n_rngs)
            for device_rng in device_rngs
        ]
        rngs = [
            jnp.stack([device[idx] for device in layer_rngs_per_device])
            for idx in range(n_rngs)
        ]

    # Run the layers forward upto the loss layer.
    stack = batch

    # Run the first layer.
    first_layer = self._first_layer
    first_layer_inputs = _inputs_from_stack(first_layer, stack)
    first_layer_weights = self._replicate(first_layer.weights)
    first_layer_state = self._replicate(first_layer.state)
    outputs, first_layer_new_state = self._accelerated_first_layer_fn(
        first_layer_inputs, first_layer_weights, first_layer_state, rngs[0])
    stack = _outputs_onto_stack(first_layer, outputs, stack)

    # Run the reversible layers and collect old and new states.
    old_states = []
    new_states = []
    for i, layer in enumerate(self._reversible_layers):
        weights = self._replicate(layer.weights)  # also cpu -> accelerator
        state = self._replicate(layer.state)
        old_states.append(state)
        inputs = _inputs_from_stack(layer, stack)
        layer_fn = self._accelerated_reversible_layers_fns[i]
        outputs, new_state = layer_fn(inputs, weights, state, rngs[i + 1])
        stack = _outputs_onto_stack(layer, outputs, stack)
        new_states.append(new_state)

    # Run the loss layer forward and backward with optimizer update.
    loss_layer = self._loss_layer
    loss_weights = self._replicate(loss_layer.weights)
    loss_state = self._replicate(loss_layer.state)
    loss_inputs = _inputs_from_stack(loss_layer, stack)
    loss_slots = self._replicate(self._optimizers[-1].slots)
    new_weights, new_state, new_slots, grad_stack, loss_stats = (
        self._loss_fbo(
            loss_inputs, loss_weights, loss_state, loss_slots,
            self._replicated_opt_params[-1], rngs[-1], step))
    stats = [loss_stats]
    loss_layer.weights = self._unreplicate(new_weights)  # acceler. -> cpu
    loss_layer.state = self._unreplicate(new_state)
    self._optimizers[-1].slots = self._unreplicate(new_slots)

    # Run reversible layers backward with optimizer update.
    backward_plan = list(zip(
        self._reversible_layers, self._reverse_and_fbos,
        old_states, new_states, rngs[1:-1]))
    for n_done, plan in enumerate(reversed(backward_plan)):
        layer, reverse_and_fbo, old_state, new_state, layer_rng = plan
        # Index -1 belongs to the loss layer, so the last reversible layer
        # uses -2, the one before it -3, and so on.
        opt_index = -(n_done + 2)
        # We are running backwards and reversing, so we get *outputs* from
        # stack.
        outputs = _inputs_from_stack(layer, stack, layer.n_out)
        grads = _inputs_from_stack(layer, grad_stack, layer.n_out)
        slots = self._replicate(self._optimizers[opt_index].slots)
        opt_params = self._replicated_opt_params[opt_index]
        weights = self._replicate(layer.weights)  # cpu -> accelerator
        new_weights, new_slots, inputs, grads, layer_stats = reverse_and_fbo(
            outputs, weights, old_state, new_state, slots, opt_params,
            layer_rng, step, grads)
        layer.weights = self._unreplicate(new_weights)  # accelerator -> cpu
        layer.state = self._unreplicate(new_state)
        self._optimizers[opt_index].slots = self._unreplicate(new_slots)
        stats.append(layer_stats)
        stack = _outputs_onto_stack(
            layer, inputs, stack, layer.n_out, layer.n_in)
        grad_stack = _outputs_onto_stack(
            layer, grads, grad_stack, layer.n_out, layer.n_in)

    # Run the first layer forward-and-backward pass and optimizer update.
    grads = _inputs_from_stack(first_layer, grad_stack, first_layer.n_out)
    slots = self._replicate(self._optimizers[0].slots)
    new_weights, new_state, new_slots, first_layer_stats = self._first_fbo(
        first_layer_inputs, first_layer_weights, first_layer_new_state,
        slots, self._replicated_opt_params[0], rngs[0], step, grads)
    stats.append(first_layer_stats)
    first_layer.weights = self._unreplicate(new_weights)
    first_layer.state = self._unreplicate(new_state)
    self._optimizers[0].slots = self._unreplicate(new_slots)

    return stats[0]['loss'], stats
def NaiveUpsampling(shorten_factor, d_model, *args, **kwargs):  # pylint: disable = unused-argument
    """Returns a 'Repeat' layer that repeats its input along axis 1.

    Args:
        shorten_factor: how many times each element along axis 1 is repeated.
        d_model: unused.
        *args: unused.
        **kwargs: unused.

    Returns:
        The result of `core.Fn` wrapping the repeat function.
    """
    def repeat_along_axis_1(x):
        return jnp.repeat(x, shorten_factor, axis=1)

    return core.Fn('Repeat', repeat_along_axis_1)