def forward(self, inputs):
  """Executes this layer as part of a forward pass through the model.

  Scans `self.sublayer` over `inputs` along `self._axis`. When
  `self._n_carry > 0`, the last `n_carry` items of `inputs` are the initial
  carry. Each step feeds `x + carry` to the sublayer, and the sublayer's last
  `n_carry` outputs become the next carry. A per-step RNG is derived by
  folding the step index into `self.rng`.

  In 'predict' mode with a carry, the final carry is saved as the second
  element of the state. Once that saved carry is non-empty, the next call
  uses it instead of the carry passed in `inputs`.

  Args:
    inputs: Tuple (a list is converted to a tuple) forming the input stack.

  Returns:
    The scanned outputs, followed by the final carry when `n_carry > 0`.
  """
  weights = self.weights[0]
  if isinstance(inputs, list):
    inputs = tuple(inputs)  # so that inputs structure matches outputs
  n_carry = self._n_carry

  def scannable_fn(x, carry_and_state):  # pylint: disable=invalid-name
    carry, state, i = carry_and_state
    x_and_carry = x + carry if n_carry > 0 else x
    rng = fastmath.random.fold_in(self.rng, i)
    res, new_state = self.sublayer.pure_fn(
        x_and_carry, weights, state, rng, use_cache=True)
    if n_carry > 0:
      return (res[:-n_carry], (res[-n_carry:], new_state, i+1))
    else:
      return (res, ([], new_state, i+1))

  if n_carry > 0:
    xs = inputs[:-n_carry]  # Split input stack into inputs and carry.
    xs_carry = inputs[-n_carry:]
    # Resume from the carry saved by the previous predict-mode call, if any.
    # This is a length test rather than `is not ()`: identity comparison with
    # a literal raises a SyntaxWarning and depends on CPython interning the
    # empty tuple.
    if self._mode == 'predict' and len(self._state[1]) > 0:
      xs_carry = self._state[1]
    init = (xs_carry, self.state[0], jnp.array(0, dtype=jnp.int32))
  else:
    xs, init = inputs, ([], self.state[0], jnp.array(0, dtype=jnp.int32))
  ys, (carry, new_state, _) = _scan(scannable_fn, xs, init,
                                    axis=self._axis, remat=self._remat)
  res = ys + carry if n_carry > 0 else ys
  # Only predict mode keeps the carry around for the next call.
  state_carry = carry if self._mode == 'predict' and n_carry > 0 else ()
  self.state = (new_state, state_carry)
  return res  # Put outputs and carry back on stack.
def forward(self, inputs):
  """Runs the sublayer over `inputs` with a scan, threading carry and state.

  Args:
    inputs: Tuple (a list is converted to a tuple) forming the input stack.
      When `self._n_carry > 0`, its last `n_carry` items are the initial
      carry.

  Returns:
    The scanned outputs, followed by the final carry when `n_carry > 0`.
    The sublayer's final state is stored as the 1-tuple `self.state`.
  """
  sublayer_weights = self.weights[0]
  if isinstance(inputs, list):
    inputs = tuple(inputs)  # Keep the input structure aligned with outputs.
  num_carry = self._n_carry
  has_carry = num_carry > 0

  def step_fn(x, loop_state):  # pylint: disable=invalid-name
    carry, sub_state, counter = loop_state
    step_inputs = x + carry if has_carry else x
    step_rng = fastmath.random.fold_in(self.rng, counter)
    out, next_state = self.sublayer.pure_fn(
        step_inputs, sublayer_weights, sub_state, step_rng, use_cache=True)
    if not has_carry:
      return out, ([], next_state, counter + 1)
    return out[:-num_carry], (out[-num_carry:], next_state, counter + 1)

  if has_carry:
    scan_xs = inputs[:-num_carry]  # Peel the carry off the end of the stack.
    scan_init = (inputs[-num_carry:], self.state[0],
                 jnp.array(0, dtype=jnp.int32))
  else:
    scan_xs = inputs
    scan_init = ([], self.state[0], jnp.array(0, dtype=jnp.int32))
  outputs, (final_carry, final_state, _) = _scan(
      step_fn, scan_xs, scan_init, axis=self._axis, remat=self._remat)
  self.state = (final_state,)
  # Outputs and carry go back on the stack together.
  return outputs + final_carry if has_carry else outputs
def train_step(self, batch):
  """Runs one optimizer update on `batch` and advances the step counter.

  The current learning rate is pushed into the optimizer parameters, and the
  jitted update function is applied. The new weights, slots, model state and
  RNGs are then stored. Training scalars are written when logging is due, and
  `self._step` is incremented last.

  Args:
    batch: Training batch, forwarded unchanged to `self._jit_update_fn`.
  """
  # Refresh the learning rate inside the optimizer parameters (in place).
  lr_update = self._for_n_devices(
      {'learning_rate': np.array(self.learning_rate)})
  opt_state = self._opt_state
  opt_state.opt_params.update(lr_update)

  weights, slots, opt_params = opt_state
  update_result = self._jit_update_fn(
      (weights, slots), self._step, opt_params, batch,
      self._model_state, self._rngs)
  (weights, slots), stat, self._model_state, self._rngs = update_result
  self._opt_state = opt_state._replace(weights=weights, slots=slots)

  if self._should_log_now():
    for metric_name, metric_value in stat.items():
      # TODO(afrozm): value is a scalar, but sometimes JAX is crashing here
      # with a device put array error complaining that it should be an array.
      # Taking the mean also collapses per-device values on multiple devices.
      mean_value = np.mean(np.array(metric_value))
      self._train_sw.scalar('training/' + metric_name, mean_value,
                            step=self._step)
  self._step += 1
def _UpdateRow(x):
  """Slices the part that follows the encoder out of one concatenated row.

  Args:
    x: Triple `(row_ed, row_e, _)`. The original shape note reads
      "(L, H), (L1, H) & (L2, H)". NOTE(review): the number of non-zero
      entries in `row_e` is used as the encoder length. That only equals the
      count of encoder positions if `row_e` is 1-D (e.g. token ids), so
      confirm against the caller.

  Returns:
    The `(L2, H)` slice of `row_ed` that starts at row `len_e` and column 0.
    `L2` and `H` come from the enclosing scope.
  """
  row_ed, row_e, _ = x
  enc_len = jnp.sum(row_e != 0, dtype=jnp.int32)
  # Build every index with the same dtype to avoid int32/int64 mismatches.
  idx_dtype = enc_len.dtype
  start = (enc_len, jnp.array(0, dtype=idx_dtype))
  size = (jnp.array(L2, dtype=idx_dtype), jnp.array(H, dtype=idx_dtype))
  return fastmath.dynamic_slice(row_ed, start, size)
def __init__(self, learning_rate=0.01, clip_grad_norm=None,
             **init_opt_params):
  """Records the optimizer's initial hyperparameters.

  Hyperparameters arrive as keyword arguments. They may change over training
  steps, e.g. through a learning-rate schedule. A subclass that wants gin to
  configure its hyperparameters should override this constructor with
  explicitly named keyword arguments, as `momentum.Momentum.__init__` does.

  Args:
    learning_rate: Learning rate for the optimizer. It can change during
      training by means of a learning rate schedule.
    clip_grad_norm: If given, gradients are rescaled so that the L2 norm of
      the whole gradient tree, treated as one vector, does not exceed this
      value. If None, no clipping happens.
    **init_opt_params: Initial values of any additional optimizer parameters.
  """
  # Store every hyperparameter (learning_rate last) as a JAX array.
  all_params = dict(init_opt_params, learning_rate=learning_rate)
  self._init_opt_params = {
      key: jnp.array(val) for key, val in all_params.items()
  }
  self._slots = None
  # Clipping uses the norm of the whole gradient tree. It is therefore done
  # here for the full tree rather than inside single-slot updates.
  self._clip_grad_norm = clip_grad_norm
def init_weights_and_state(self, input_signature):
  """Initializes the sinusoidal positional-encoding vectors.

  Builds a `[self._max_len, d_feature]` table. Even columns hold `sin` and
  odd columns hold `cos` of `position * div_term`. `d_feature` is
  `self._d_feature` if set, otherwise the last input dimension. The table is
  cast to bfloat16 when `self._use_bfloat16` is set.

  When `self._d_feature` is set, the weights are the pair `(table, ff)`.
  `ff` is a Glorot-uniform matrix of shape
  `(d_feature, input_signature.shape[-1])`. Otherwise the weights are the
  table alone. In 'predict' mode the state is a scalar int32 zero.

  Args:
    input_signature: :py:class:`ShapeDtype` instance characterizing the input
        this layer should compute on.
  """
  d_feature = input_signature.shape[-1]
  if self._d_feature is not None:
    d_feature = self._d_feature
  pe = np.zeros((self._max_len, d_feature), dtype=np.float32)
  position = np.arange(0, self._max_len)[:, np.newaxis]
  div_term = np.exp(
      np.arange(0, d_feature, 2) * -(np.log(10000.0) / d_feature))
  pe[:, 0::2] = np.sin(position * div_term)
  # There are d_feature // 2 odd columns. That is one fewer than
  # len(div_term) when d_feature is odd, so truncate to avoid a broadcast
  # error on odd sizes. For even sizes this is a no-op.
  pe[:, 1::2] = np.cos(position * div_term[:d_feature // 2])
  # [self._max_len, d_feature]
  if self._use_bfloat16:
    pe = pe.astype(jnp.bfloat16)
  w = jnp.array(pe)  # Trainable parameters, initialized above.
  if self._d_feature is not None:
    ff = init.GlorotUniformInitializer()(
        (d_feature, input_signature.shape[-1]), self.rng)
    self.weights = w, ff
  else:
    self.weights = w
  if self._mode == 'predict':
    self.state = jnp.zeros((), dtype=jnp.int32)
def _average_multidevice_gradients(gradients, adasum=False):
  """Averages gradients over all the devices across different hosts.

  Args:
    gradients: Gradient tree of this device.
    adasum: If False, return the plain mean over all devices on the 'batch'
      axis. If True, return the psum of per-device gradients. Each gradient
      is first scaled by `1 - (<g, sum> - <g, g>) / (n * <g, g>)`, an
      approximation of Adasum.

  Returns:
    The combined gradient tree.
  """
  gradients_psum = fastmath.psum(gradients, 'batch')  # sum over all devices
  n = fastmath.psum(jnp.array(1.0), 'batch')  # number of devices on all hosts
  if not adasum:
    return fastmath.nested_map(lambda g: g / n, gradients_psum)
  # This implements an approximation of the Adasum algorithm from the following
  # paper: https://arxiv.org/pdf/2006.02924.pdf
  # Since implementing halving and averaging half-by-half is tricky, we first
  # average all hosts, so we use the sum as a point of comparison for gradients.
  # So for 2 devices, this algorithm is the same as in the paper, but with more
  # devices it does a different kind of averaging. It still has the property
  # that orthogonal gradients will result in a sum while identical ones will
  # be averaged, as postulated in the paper.
  adasum_nominator = fastmath.nested_map_multiarg(
      lambda g, q: jnp.vdot(g, q),  # pylint: disable=unnecessary-lambda
      gradients, gradients_psum)
  grad_norm = fastmath.nested_map(lambda g: jnp.vdot(g, g), gradients)

  # If all devices have identical gradients, then the nominator is equal
  # to n * grad_norm; if they're orthogonal, then nominator = grad_norm.
  def _scale(g, nominator, g_norm):
    # An all-zero gradient on some device gives g_norm == 0. Dividing by it
    # yields 0 * NaN = NaN, which the final psum spreads to every device.
    # In that case g is zero, so any finite divisor gives the correct zero
    # contribution.
    safe_norm = jnp.where(g_norm == 0, jnp.ones_like(g_norm), g_norm)
    return g * (1 - (nominator - g_norm) / (n * safe_norm))

  scaled_grads = fastmath.nested_map_multiarg(
      _scale, gradients, adasum_nominator, grad_norm)
  return fastmath.psum(scaled_grads, 'batch')
def init_weights_and_state(self, input_signature):
  """Creates the predict-mode cache. Does nothing in any other mode.

  In 'predict' mode the state becomes `(cache, jnp.array(0))`. `cache` is
  all zeros with shape `(batch_size, 2 * self._total_kv_pooling, d_feature)`
  and the input's dtype.

  Args:
    input_signature: Signature whose `as_tuple()` gives `(shape, dtype)`,
      where `shape` unpacks as `(batch_size, _, d_feature)`.
  """
  if self._mode != 'predict':
    return
  (batch_size, _, d_feature), dtype = input_signature.as_tuple()
  cache_shape = (batch_size, 2 * self._total_kv_pooling, d_feature)
  self.state = jnp.zeros(cache_shape, dtype=dtype), jnp.array(0)
def __init__(self, learning_rate=0.01, clip_grad_norm=None, **init_opt_params):
  """Stores the optimizer's starting hyperparameters.

  Hyperparameters are passed as keyword arguments and may be changed between
  training steps, e.g. by a learning-rate schedule. Subclasses that want gin
  to configure their hyperparameters should override this constructor with
  explicitly named keyword arguments, as `momentum.Momentum.__init__` does.

  Args:
    learning_rate: The initial learning rate. It is stored under the
      'learning_rate' key.
    clip_grad_norm: Optional limit used for gradient clipping, kept in
      `self._clip_grad_norm`. Clipping applies to the norm of the whole
      gradient tree, not per slot.
    **init_opt_params: Initial values of any additional optimizer parameters.
  """
  init_opt_params['learning_rate'] = learning_rate
  initial_params = {}
  for param_name, param_value in init_opt_params.items():
    initial_params[param_name] = jnp.array(param_value)
  self._init_opt_params = initial_params
  self._slots = None  # Slots start unset.
  # Whole-tree clipping is handled by this class, not by single-slot updates.
  self._clip_grad_norm = clip_grad_norm
def _average_multidevice_gradients(gradients, adasum=False):
  """Averages gradients over all the devices across different hosts.

  `n` is the global device count divided by `base.N_WEIGHTS_SHARDS`. With
  `adasum`, gradients are first merged pairwise over `lg` rounds via
  `jax.lax.ppermute` and `_adasum_merge` (defined elsewhere). The gradients
  are then summed over the 'batch' axis and divided by `n`. When there are
  several weight shards, the sum runs only within groups of devices whose
  index is congruent modulo `base.N_WEIGHTS_SHARDS`.

  NOTE(review): in round `lg_i`, a device in the lower half of its
  `2 * shift` block is paired with `i + shift`. If `n` is not a power of two,
  that index can reach `n`, which makes `perm` invalid. Also, `range(20)`
  caps `lg` at 19. Confirm that `n` is always a power of two here.
  """
  n = fastmath.global_device_count() // base.N_WEIGHTS_SHARDS
  if adasum:
    # This implements a version of the Adasum algorithm from the following
    # paper: https://arxiv.org/pdf/2006.02924.pdf
    # lg is the largest i (below 20) with 2**i <= n, i.e. floor(log2(n)).
    lg = max([i for i in range(20) if 2**i <= n])
    for lg_i in range(lg):
      shift = 2**lg_i
      # Pair device i with i + shift (lower half of its block) or i - shift
      # (upper half), so each device exchanges gradients with one partner.
      perm = []
      for i in range(n):
        block_i = i % (2 * shift)  # we do blocks of 2*shift size
        if block_i < shift:
          perm.append((i, i + shift))
        else:
          perm.append((i, i - shift))
      perm_grad = jax.lax.ppermute(gradients, perm=perm, axis_name='batch')
      gradients = fastmath.nested_map_multiarg(_adasum_merge, gradients, perm_grad)
  if base.N_WEIGHTS_SHARDS > 1:  # only sum gradients from matching shards
    # groups[d] = [d, N + d, 2N + d, ...] with N = base.N_WEIGHTS_SHARDS.
    groups = [[base.N_WEIGHTS_SHARDS * i + d for i in range(int(n))]
              for d in range(base.N_WEIGHTS_SHARDS)]
    gradients_psum = fastmath.psum(gradients, 'batch', axis_index_groups=groups)
  else:
    gradients_psum = fastmath.psum(gradients, 'batch')  # sum all gradients
  n = jnp.array(n, dtype=jnp.float32)
  return fastmath.nested_map(lambda g: g / n, gradients_psum)
def one_hot(x, n_categories, dtype=jnp.float32):  # pylint: disable=invalid-name
  """Makes a one-hot array (n+1 dims) from an int-categorical array (n dims).

  Args:
    x: Integer array of category ids, of any shape.
    n_categories: Size of the new trailing axis.
    dtype: dtype of the result.

  Returns:
    Array of shape `x.shape + (n_categories,)`. It is 1 where the trailing
    index equals the corresponding value of `x` and 0 elsewhere, so ids
    outside `[0, n_categories)` give an all-zero row.
  """
  indices_less_than_n = jnp.arange(n_categories)
  if fastmath.is_backend(fastmath.Backend.JAX):
    # Work around a jax broadcasting issue in older JAX versions. `tie_in`
    # became a no-op under omnistaging and is absent from newer `jax.lax`, so
    # call it only where it still exists instead of raising AttributeError.
    tie_in = getattr(jax.lax, 'tie_in', None)
    if tie_in is not None:
      indices_less_than_n = tie_in(x, indices_less_than_n)
  return jnp.array(x[..., jnp.newaxis] == indices_less_than_n, dtype)
def _n_weights_per_core(weights):  # pylint: disable=invalid-name
  """Returns the sum of `weights`, normalized per core in multi-device runs.

  Losses and gradients are averaged over devices. When per-device weights
  differ (e.g. token counts of differently sized sentence batches), each
  weight must count the same on every device. This function therefore
  returns the weight sum over all cores divided by the number of cores.

  Args:
    weights: Tensor of arbitrary shape.

  Returns:
    A scalar. It is `jnp.sum(weights)` with fewer than 2 devices, or when
    the cross-device `psum` raises `NameError`/`ValueError` (e.g. outside
    pmap during init). Otherwise it is the 'batch'-axis sum of weights
    divided by the number of devices on that axis.
  """
  local_total = jnp.sum(weights)
  if fastmath.device_count() < 2:
    return local_total
  try:
    n_devices_total = fastmath.psum(jnp.array(1.0), 'batch')
    return fastmath.psum(local_total, 'batch') / n_devices_total
  except (NameError, ValueError):
    # Not running under pmap (e.g. on init): use the local sum.
    return local_total
def _multi_device_put(x, devices=None):
  """Memory efficient multi-device replication / broadcast in JAX.

  JAX uses a sharded array class that holds a list of device buffers on
  separate devices for use with pmap'd computations. Sharded arrays are
  explicitly used to eliminate unnecessary inter-device transfer of memory
  buffers between use in pmap'd computations.

  Args:
    x: jax array or numpy ndarray to be replicated.
    devices: a jax.devices() list or subset thereof of devices to replicate
      onto. Should match the list passed to any pmaps ingesting the replicated
      array. Defaults to `jax.local_devices()` when empty or None.

  Returns:
    A sharded array with dtype = x.dtype and shape = (n_devices,) + x.shape
    that's backed by replicated device buffers on each device.
  """
  # Normalize to a JAX array. `jnp.asarray` converts numpy arrays and lazy
  # constants, and returns existing JAX arrays without a copy. The previous
  # `type(x) != jax.xla.DeviceArray` test relied on a class that newer JAX
  # versions no longer provide.
  x = jnp.asarray(x)
  if not devices:
    devices = jax.local_devices()
  # `jax.device_put_sharded` is the public name of the former
  # `jax.api.device_put_sharded`; the `jax.api` module no longer exists.
  return jax.device_put_sharded(len(devices) * [x], devices)
def test_autoregressive_sample_transformerlm(self):
  """Samples from a tiny predict-mode TransformerLM in three configurations."""
  lm = models.TransformerLM(10, d_model=32, d_ff=64, n_layers=1,
                            n_heads=2, mode='predict')

  # Batch of one with eos_id=-1: the sample must reach max_length exactly.
  lm.init(shapes.ShapeDtype((1, 1), dtype=jnp.int32))
  single = trainer_lib.autoregressive_sample(
      lm, batch_size=1, eos_id=-1, max_length=10)
  self.assertEqual(single.shape[0], 1)
  self.assertEqual(single.shape[1], 10)

  # Batch of two with the default eos_id: only an upper bound on length is
  # checked. The model is initialized with a per-device batch size
  # (presumably the batch is split across devices).
  per_device = 2 // fastmath.device_count()
  lm.init(shapes.ShapeDtype((per_device, 1), dtype=jnp.int32))
  pair = trainer_lib.autoregressive_sample(lm, batch_size=2, max_length=10)
  self.assertEqual(pair.shape[0], 2)
  self.assertLess(pair.shape[1], 11)

  # With a prefix, the sample must start with the prefix tokens.
  lm.init(shapes.ShapeDtype((1, 1), dtype=jnp.int32))
  prefix = jnp.array([[1, 2, 3]])
  prefixed = trainer_lib.autoregressive_sample(
      lm, eos_id=-1, max_length=10, batch_size=1, prefix=prefix)
  self.assertEqual(prefixed.shape[0], 1)
  for pos, token in enumerate([1, 2, 3]):
    self.assertEqual(int(prefixed[0][pos]), token)
def _fast_inference_init_state(input_signature, buffer_length,
                               predict_mask=None):
  """Returns an initial state for causal attention layer fast inference.

  Args:
    input_signature: Sequence of signatures. Element 0 supplies the batch
      size (`shape[0]`). Elements 1 and 2 supply the feature size
      (`shape[-1]`) and dtype of the key and value buffers.
    buffer_length: Length of the key/value buffers.
    predict_mask: If not None, a boolean all-False mask of shape
      `(buffer_length,)` is prepended to the state.

  Returns:
    `(keys, values, index)`, or `(mask, keys, values, index)` when
    `predict_mask` is given. The buffers are zero-filled and have shape
    `(batch, buffer_length, d_feature)`; `index` is `jnp.array(0)`.
  """
  batch_size = input_signature[0].shape[0]

  def empty_buffer(signature):
    shape, dtype = signature.as_tuple()
    return jnp.zeros((batch_size, buffer_length, shape[-1]), dtype=dtype)

  keys = empty_buffer(input_signature[1])
  values = empty_buffer(input_signature[2])
  if predict_mask is None:
    return (keys, values, jnp.array(0))
  mask_for_predict = jnp.zeros((buffer_length,)) != 0
  return (mask_for_predict, keys, values, jnp.array(0))
def _UpdateRow(x):
  """Writes one decoder row right after the real tokens of one encoder row.

  Args:
    x: Triple `(row_e, row_d, row_mask_e)` with shapes `(L1, H)`, `(L2, H)`
      and `(L1,)` according to the original shape notes.

  Returns:
    An `(L1 + L2, H)` array. It is `row_e` followed by `L2` zero rows, with
    `row_d` written in starting at row `sum(row_mask_e)`.
  """
  row_e, row_d, row_mask_e = x
  padded = jnp.concatenate([row_e, jnp.zeros_like(row_d)], axis=0)
  # The decoder part begins just after the last real encoder token/vector.
  start = jnp.sum(row_mask_e, dtype=jnp.int32)
  col0 = jnp.array(0, dtype=start.dtype)  # avoid int32/int64 mismatch
  return fastmath.dynamic_update_slice(padded, row_d, (start, col0))
def _average_multidevice_gradients(gradients):
  """Averages gradients over every device on every host.

  `fastmath.psum` over the 'batch' axis is used both to sum the gradients and
  to count the devices. The usual local device count covers only this host.

  Args:
    gradients: This device's gradient tree.

  Returns:
    The gradient tree averaged over all devices.
  """
  summed = fastmath.psum(gradients, 'batch')
  devices_on_all_hosts = fastmath.psum(jnp.array(1.0), 'batch')
  return fastmath.nested_map(lambda g: g / devices_on_all_hosts, summed)
def single_device_update_fn(
    weights_and_slots, step, opt_params, batch, state, rng):
  """Performs one forward/backward pass and optimizer update on one device.

  Returns:
    `((weights, slots), state, stats)`, where `stats` comes from the
    optimizer's `tree_update` and additionally holds the loss under 'loss'.
  """
  step = jnp.array(step, dtype=jnp.int32)  # The TFNP backend needs this.
  weights, slots = weights_and_slots
  (loss, new_state), grads = forward_and_backward_fn(
      batch, weights, state, rng)
  new_weights, new_slots, stats = optimizer.tree_update(
      step, grads, weights, slots, opt_params)
  stats['loss'] = loss
  return (new_weights, new_slots), new_state, stats
def _fast_inference_init_state(self, input_signature):
  """Returns an initial state for causal attention layer fast inference.

  Args:
    input_signature: Sequence whose element 0 gives the batch size
      (`shape[0]`). Elements 0 and 1 give the feature size and dtype of the
      key and value buffers.

  Returns:
    `(keys, values, jnp.array(0))`. The buffers are zero-filled and have
    shape `(batch, self._max_len, d_feature)`.
  """
  batch_size = input_signature[0].shape[0]
  buffers = []
  for signature in (input_signature[0], input_signature[1]):
    shape, dtype = signature.as_tuple()
    buffers.append(
        jnp.zeros((batch_size, self._max_len, shape[-1]), dtype=dtype))
  keys, values = buffers
  return keys, values, jnp.array(0)
def init_weights_and_state(self, input_signature):
  """Initializes this layer: no weights, and one zero-valued state entry.

  The weights are an empty tuple. The previous docstring described a single
  `w` tensor, which this layer does not create. The state is a dict mapping
  this layer's name (`self._name`) to the float scalar `0.`.

  Args:
    input_signature: `ShapeDtype` instance characterizing the input this layer
        should compute on. Unused.
  """
  del input_signature  # Unused.
  self.weights = ()
  self.state = {self._name: jnp.array(0.)}
def _fast_inference_init_state(input_signature, buffer_length):
  """Returns an initial state for causal attention layer fast inference.

  Args:
    input_signature: Sequence of signatures. Element 0 gives the batch size.
      Elements 1 and 2 give the feature size and dtype of the key and value
      buffers.
    buffer_length: Length of the key/value buffers and of the mask.

  Returns:
    `(keys, values, mask, jnp.array(0))`. The buffers are zero-filled with
    shape `(batch, buffer_length, d_feature)`. The mask is zero-filled with
    shape `(batch, 1, buffer_length)` and the default dtype.
  """
  batch_size = input_signature[0].shape[0]
  key_shape, key_dtype = input_signature[1].as_tuple()
  keys = jnp.zeros((batch_size, buffer_length, key_shape[-1]), dtype=key_dtype)
  value_shape, value_dtype = input_signature[2].as_tuple()
  values = jnp.zeros((batch_size, buffer_length, value_shape[-1]),
                     dtype=value_dtype)
  mask = jnp.zeros((batch_size, 1, buffer_length))
  return (keys, values, mask, jnp.array(0))
def init_weights_and_state(self, input_signature):
  """Initializes the sinusoidal positional-encoding weights and predict state.

  The weights are a `[1, self._max_len, d_feature]` array, where `d_feature`
  is the last input dimension. Even feature columns hold `sin` and odd
  columns hold `cos` of `position * div_term`. In 'predict' mode the state
  is an int32 zero vector of length `input_signature.shape[0]`.

  Args:
    input_signature: Object whose `shape` gives the input shape.
  """
  d_feature = input_signature.shape[-1]
  pe = np.zeros((self._max_len, d_feature), dtype=np.float32)
  position = np.arange(0, self._max_len)[:, np.newaxis]
  div_term = np.exp(
      np.arange(0, d_feature, 2) * -(np.log(10000.0) / d_feature))
  pe[:, 0::2] = np.sin(position * div_term)
  # There are d_feature // 2 odd columns. That is one fewer than
  # len(div_term) when d_feature is odd, so truncate to avoid a broadcast
  # error on odd sizes. For even sizes this is a no-op.
  pe[:, 1::2] = np.cos(position * div_term[:d_feature // 2])
  pe = pe[np.newaxis, :, :]  # [1, self._max_len, d_feature]
  self.weights = jnp.array(pe)  # Trainable parameters, initialized above.
  if self._mode == 'predict':
    batch_size = input_signature.shape[0]
    self.state = jnp.zeros((batch_size,), dtype=jnp.int32)
def next_symbol(cur_output_tokens, model, d_model):
  """Returns the next symbol for a given sentence.

  The tokens are zero-padded to the smallest power of two strictly greater
  than their count. They are fed to `model` as an `(inputs, targets)` pair
  with a batch dimension of 1, and the index of the largest output value at
  position `len(cur_output_tokens)` is returned.

  Args:
    cur_output_tokens (list): tokenized sentence with EOS and PAD tokens at
      the end.
    model (trax.layers.combinators.Serial): The transformer model. Called
      with a tuple of two `(1, padded_length)` arrays, it must return a pair
      whose first element is indexable as `[0, position, :]`.
    d_model (int): Model dimension; the padded length may not exceed
      `8 * d_model`.

  Returns:
    int: tokenized symbol.

  Raises:
    AssertionError: If the padded length exceeds `8 * d_model`.
  """
  token_length = len(cur_output_tokens)
  padded_length = 2**int(np.ceil(np.log2(token_length + 1)))
  # The original `assert(cond, msg)` asserted a non-empty tuple, which is
  # always true, so this check never ran. Assumes 8-bit chars.
  assert padded_length <= d_model * 8, (
      'the d_features is 512 char, or 4096 bit maximum')
  padded = cur_output_tokens + [0] * (padded_length - token_length)
  # Don't replace this 'None'! This is a way of setting the batch dim.
  padded_with_batch = np.array(padded)[None, :]
  output, _ = model((jnp.array(padded_with_batch),
                     jnp.array(padded_with_batch)))
  log_probs = output[0, token_length, :]
  return int(np.argmax(log_probs))
def _fast_inference_init_state(self, input_signature):
  """Returns an initial state for causal attention layer fast inference.

  The buffer length is `self._chunk_len` when set, else `self._max_len`.

  Args:
    input_signature: Sequence whose element 0 gives the batch size.
      Elements 0 and 1 give the feature size and dtype of the key and value
      buffers.

  Returns:
    `(keys, values, jnp.array(0))` with zero-filled buffers of shape
    `(batch, n_tokens, d_feature)`.
  """
  batch_size = input_signature[0].shape[0]
  if self._chunk_len is None:
    n_tokens = self._max_len
  else:
    n_tokens = self._chunk_len

  def zero_buffer(signature):
    shape, dtype = signature.as_tuple()
    return jnp.zeros((batch_size, n_tokens, shape[-1]), dtype=dtype)

  return (zero_buffer(input_signature[0]),
          zero_buffer(input_signature[1]),
          jnp.array(0))
def __init__(self, mode=None, learn_epsilon=False, init_epsilon=1e-6,
             init_learnt_epsilon=1e-4):
  """Configures how epsilon is formed.

  As the original notes describe, with `learn_epsilon` the epsilon is
  `init_epsilon + |learnt_value|`, with the learnt value starting at
  `init_learnt_epsilon`. Otherwise epsilon is just `init_epsilon`.

  Args:
    mode: Ignored.
    learn_epsilon: Whether part of epsilon is learnt.
    init_epsilon: Fixed part of epsilon; must be positive.
    init_learnt_epsilon: Initial learnt part of epsilon; must be positive.
  """
  super().__init__()
  del mode
  # NOTE: I (afrozm) haven't been able to train with `learn_epsilon = True`.
  self._learn_epsilon = learn_epsilon
  # TODO(jonni): Replace asserts with ValueError.
  assert init_epsilon > 0
  assert init_learnt_epsilon > 0
  as_f32 = lambda value: jnp.array(value, dtype=jnp.float32)
  self._init_epsilon = as_f32(init_epsilon)
  self._init_learnt_epsilon = as_f32(init_learnt_epsilon)
def predict(num_chars, prefix):
  """Samples up to `num_chars` characters that continue `prefix`.

  Characters are handled as `ord` codes. On each step, the codes so far are
  right-padded with 0 to `len(prefix) + num_chars` and run through `model`
  (taken from the enclosing scope) with a batch axis of 1. The next code is
  drawn with `gumbel_sample` from the output at the current length. Drawing
  code 1 (EOS) stops sampling, and EOS is not added to the result.

  Returns:
    `prefix` followed by the sampled characters.
  """
  codes = [ord(ch) for ch in prefix]
  chars = list(prefix)
  total_len = len(prefix) + num_chars
  for _ in range(num_chars):
    padded = np.array(codes + [0] * (total_len - len(codes)))
    logits = model(padded[None, :])  # Add batch dim.
    sampled = gumbel_sample(logits[0, len(codes)])
    codes.append(int(sampled))
    if codes[-1] == 1:
      break  # EOS
    chars.append(chr(int(sampled)))
  return "".join(chars)
def init_weights_and_state(self, input_signature):
  """Initializes every sublayer in order, threading abstract shapes through.

  Each sublayer is initialized with `use_cache=True` on its slice of the
  abstract stack. Its abstract outputs are pushed back onto the stack for the
  next sublayer. The per-sublayer results (which, per the original naming,
  may be cache markers) are collected into `self.weights`. The state becomes
  `(jnp.array(0, dtype=jnp.int32), per-sublayer states)`.
  """
  # The stack, inputs and outputs are abstract (shapes and dtypes); the
  # collected weights and states are actual values.
  stack = input_signature
  collected_weights, collected_states = [], []
  for layer in self.sublayers:
    layer_inputs = _inputs_from_stack(layer, stack)
    layer_weights, layer_state = layer.init(layer_inputs, use_cache=True)
    layer_outputs, _ = layer._forward_abstract(layer_inputs)  # pylint: disable=protected-access
    stack = _outputs_onto_stack(layer, layer_outputs, stack)
    collected_weights.append(layer_weights)
    collected_states.append(layer_state)
  self.state = (jnp.array(0, dtype=jnp.int32), collected_states)
  self.weights = collected_weights
def mapped_update(weights_and_slots, i, opt_params, batch, state, rng):
  """Multi-device training update: gradients, cross-host mean, optimizer step.

  Tensors are assumed to have a leading dimension of n_devices. The RNG is
  split: one half drives the gradient computation and the other is returned
  for the next step.

  Returns:
    `((new_weights, new_slots), stats, state, subrng)`.
  """
  weights, slots = weights_and_slots
  rng, subrng = jax_random.split(rng)
  loss_grad = fastmath.grad(model_and_loss_call, has_aux=True)
  grads, state = loss_grad(weights, batch, state, rng)

  def cross_host_mean(g):
    # psum(1.0) counts the devices of all hosts (e.g. a TPU pod), unlike the
    # local n_devices, so this averages over every device.
    return fastmath.psum(g, 'batch') / fastmath.psum(np.array(1.0), 'batch')

  grads = jax.tree_util.tree_map(cross_host_mean, grads)
  new_weights, new_slots, stats = optimizer.tree_update(
      i, grads, weights, slots, opt_params)
  return (new_weights, new_slots), stats, state, subrng
def _multi_device_update_fn(
    weights_and_slots, step, opt_params, batch, state, rng):
  """Multi-device update: forward/backward, cross-host gradient mean, step.

  Tensors are assumed to have a leading dimension of n_devices.

  Returns:
    `((weights, slots), state, stats)`, with the loss stored in
    `stats['loss']`.
  """
  weights, slots = weights_and_slots
  (loss, state), gradients = forward_and_backward_fn(
      batch, weights, state, rng)
  # Sum over every device on every host. The local n_devices would only
  # count this host, so the device total also comes from psum.
  summed = fastmath.psum(gradients, 'batch')
  n_devices_total = fastmath.psum(jnp.array(1.0), 'batch')
  averaged = jax.tree_util.tree_map(lambda g: g / n_devices_total, summed)
  weights, slots, stats = optimizer.tree_update(
      step, averaged, weights, slots, opt_params)
  stats['loss'] = loss
  return (weights, slots), state, stats
def _make_weights_and_state_same_across_hosts(weights_and_state):
  """Makes train and eval model's weights and state the same across hosts.

  Each leaf is replaced by its mean over the 'devices' axis, cast back to
  the leaf's own dtype.

  Args:
    weights_and_state: Tree of weights and state, assumed to be replicated
      already (leading axis = number of local devices).

  Returns:
    The tree with every leaf averaged across all devices of all hosts. The
    psum removes the leading device axis.
  """
  # This is the total number of devices across all hosts.
  n_devices_total = fastmath.psum(jnp.array(1.0), 'devices')
  # This sums up the weights and state across all devices.
  weights_and_state = fastmath.psum(weights_and_state, 'devices')
  # Dividing by the float32 device count would promote integer leaves (e.g.
  # step counters) and bfloat16 leaves to float32. Cast each mean back to
  # its leaf's dtype so the dtypes are unchanged.
  return jax.tree_util.tree_map(
      lambda ws: (ws / n_devices_total).astype(ws.dtype), weights_and_state)