def test_stochastic_rngs(self):
  rng = random.PRNGKey(0)
  with nn.stochastic(rng):
    r1 = nn.make_rng()
    r2 = nn.make_rng()
    self.assertTrue(onp.all(r1 == random.fold_in(rng, 1)))
    self.assertTrue(onp.all(r2 == random.fold_in(rng, 2)))

def test_train_one_step(self):
  batch = train.get_batch(128)
  rng = random.PRNGKey(0)
  with nn.stochastic(rng):
    model = train.create_model(nn.make_rng())
    optimizer = train.create_optimizer(model, 0.003)
    optimizer, train_metrics = train.train_step(
        optimizer, batch, nn.make_rng())
  self.assertLessEqual(train_metrics['loss'], 5)
  self.assertGreaterEqual(train_metrics['accuracy'], 0)

def train_model():
  """Train for a fixed number of steps and decode during training."""
  with nn.stochastic(jax.random.PRNGKey(0)):
    model = create_model(nn.make_rng())
    optimizer = create_optimizer(model, FLAGS.learning_rate)
    for step in range(FLAGS.num_train_steps):
      batch = get_batch(FLAGS.batch_size)
      optimizer, metrics = train_step(optimizer, batch, nn.make_rng())
      if step % FLAGS.decode_frequency == 0:
        logging.info('train step: %d, loss: %.4f, accuracy: %.2f',
                     step, metrics['loss'], metrics['accuracy'] * 100)
        decode_batch(optimizer.target, 5)
  return optimizer.target

def select_patches_perturbed_topk(flatten_scores,
                                  sigma,
                                  *,
                                  k,
                                  num_samples=1000):
  """Select patches using a differentiable top-k based on perturbation.

  Uses https://q-berthet.github.io/papers/BerBloTeb20.pdf,
  see off_the_grid.lib.ops.perturbed_topk for more info.

  Args:
    flatten_scores: The flattened scores of shape (batch, num_patches).
    sigma: Standard deviation of the noise.
    k: The number of patches to extract.
    num_samples: Number of noisy inputs used to compute the output
      expectation.

  Returns:
    Indicator vectors of the selected patches (batch, num_patches, k).
  """
  batch_size = flatten_scores.shape[0]

  batch_topk_fn = jax.vmap(
      functools.partial(perturbed_topk.perturbed_sorted_topk_indicators,
                        num_samples=num_samples,
                        sigma=sigma,
                        k=k))

  rng_keys = jax.random.split(nn.make_rng(), batch_size)
  indicators = batch_topk_fn(flatten_scores, rng_keys)
  topk_indicators_flatten = einops.rearrange(indicators, "b k d -> b d k")
  return topk_indicators_flatten

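# The function above delegates to off_the_grid.lib.ops.perturbed_topk. As a
# rough illustration of the idea from Berthet et al. (2020), the forward pass
# amounts to averaging hard top-k indicator vectors over noisy copies of the
# scores. The sketch below is a simplified stand-in for a single example
# (forward pass only; the real implementation also defines a custom,
# lower-variance backward pass) and is not the library code itself.
import functools

import jax
import jax.numpy as jnp


def perturbed_topk_indicators_sketch(scores, rng, k, sigma, num_samples=100):
  """Soft top-k indicators for one example; `scores` has shape (num_patches,)."""
  n = scores.shape[-1]
  noise = sigma * jax.random.normal(rng, (num_samples, n))
  noisy_scores = scores[None, :] + noise                   # (num_samples, n)
  topk_idx = jnp.argsort(-noisy_scores, axis=-1)[:, :k]    # (num_samples, k)
  onehots = jax.nn.one_hot(topk_idx, n)                    # (num_samples, k, n)
  # Expectation over the noise samples; each row sums to 1.
  return onehots.mean(axis=0)                              # (k, n)
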
def get_drop_pattern(self, x, layer_drop_p):
  if nn.is_stochastic() and layer_drop_p:
    rng = nn.make_rng()
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    return jax.random.bernoulli(rng, layer_drop_p, shape).astype("float32")
  else:
    return 0.0

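# A minimal usage sketch of the drop pattern above (names here are
# hypothetical, not from the original module): a stochastic-depth residual
# block zeroes the residual branch for the samples whose drop pattern is 1
# and keeps the skip connection untouched.
import jax
import jax.numpy as jnp


def stochastic_depth_residual_sketch(x, branch_out, rng, layer_drop_p=0.1):
  """x, branch_out: (batch, ..., features); returns the residual sum."""
  shape = (x.shape[0],) + (1,) * (x.ndim - 1)
  drop_pattern = jax.random.bernoulli(rng, layer_drop_p, shape).astype(x.dtype)
  return x + branch_out * (1.0 - drop_pattern)
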
def create_model(key, input_shape):
  def inducing_loc_init(key, shape):
    return jnp.linspace(-1.5, 1.5, FLAGS.num_inducing_points)[:, jnp.newaxis]

  kwargs = {}
  for i in range(1, FLAGS.num_layers + 1):
    kwargs['kernel_fn_{}_kwargs'.format(i)] = {
        'amplitude_init': lambda key, shape: jnp.ones(shape),
        'length_scale_init': lambda key, shape: jnp.ones(shape)
    }
    kwargs['inducing_var_{}_kwargs'.format(i)] = {
        'fixed_locations': False,
        'whiten': FLAGS.whiten,
        'inducing_locations_init': inducing_loc_init
    }

  model_def = DeepGPModel.partial(**kwargs)

  with nn.stochastic(key):
    _, params = model_def.init_by_shape(
        key, [(input_shape, jnp.float64)], nn.make_rng(), **kwargs)

  return nn.Model(model_def, params)

def apply(self, inputs, eos_id=1, hidden_size=512):
  # inputs.shape = (batch_size, seq_length, vocab_size).
  batch_size = inputs.shape[0]

  lstm_cell = nn.LSTMCell.partial(name='lstm')
  init_lstm_state = nn.LSTMCell.initialize_carry(
      nn.make_rng(), (batch_size,), hidden_size)

  def encode_step_fn(carry, x):
    lstm_state, is_eos = carry
    new_lstm_state, y = lstm_cell(lstm_state, x)

    # Pass forward the previous state if EOS has already been reached.
    def select_carried_state(new_state, old_state):
      return jnp.where(is_eos[:, np.newaxis], old_state, new_state)

    # LSTM state is a tuple (c, h).
    carried_lstm_state = tuple(
        select_carried_state(*s) for s in zip(new_lstm_state, lstm_state))
    # Update `is_eos`.
    is_eos = jnp.logical_or(is_eos, x[:, eos_id])
    return (carried_lstm_state, is_eos), y

  (final_state, _), _ = jax_utils.scan_in_dim(
      encode_step_fn,
      init=(init_lstm_state, jnp.zeros(batch_size, dtype=np.bool)),
      xs=inputs,
      axis=1)
  return final_state

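# A tiny, self-contained illustration of the EOS-carrying trick used in
# encode_step_fn above: once `is_eos` is True for a sequence, jnp.where keeps
# the old state and ignores further updates for that sequence. The arrays
# below are made up for the example.
import jax.numpy as jnp

is_eos = jnp.array([True, False])
old_state = jnp.array([[1., 1.], [2., 2.]])
new_state = jnp.array([[9., 9.], [3., 3.]])
carried = jnp.where(is_eos[:, None], old_state, new_state)
# carried == [[1., 1.], [3., 3.]]: the finished sequence keeps its old state.
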
def apply(self, init_state, inputs, teacher_force=False):
  # inputs.shape = (batch_size, seq_length, vocab_size).
  vocab_size = inputs.shape[2]

  lstm_cell = nn.LSTMCell.shared(name='lstm')
  projection = nn.Dense.shared(features=vocab_size, name='projection')

  def decode_step_fn(carry, x):
    rng, lstm_state, last_prediction = carry
    carry_rng, categorical_rng = jax.random.split(rng, 2)
    if not teacher_force:
      x = last_prediction
    lstm_state, y = lstm_cell(lstm_state, x)
    logits = projection(y)
    predicted_tokens = jax.random.categorical(categorical_rng, logits)
    prediction = onehot(predicted_tokens, vocab_size)
    return (carry_rng, lstm_state, prediction), (logits, prediction)

  init_carry = (nn.make_rng(), init_state, inputs[:, 0])

  if self.is_initializing():
    # Initialize parameters before scan.
    decode_step_fn(init_carry, inputs[:, 0])

  _, (logits, predictions) = jax_utils.scan_in_dim(
      decode_step_fn,
      init=init_carry,  # rng, lstm_state, last_pred
      xs=inputs,
      axis=1)
  return logits, predictions

def word_dropout(inputs: jnp.ndarray,
                 rate: float,
                 unk_idx: int,
                 deterministic: bool = False):
  """Replaces a fraction (rate) of inputs with <unk>."""
  if deterministic or rate == 0.:
    return inputs
  mask = jax.random.bernoulli(nn.make_rng(), p=rate, shape=inputs.shape)
  return jnp.where(mask, jnp.array([unk_idx]), inputs)

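# A minimal usage sketch for word_dropout above. It assumes the pre-Linen
# flax.nn API used throughout these snippets (nn.stochastic / nn.make_rng);
# the import path below matches old Flax releases and may differ in newer
# ones. The token ids are made up for the example.
import jax
import jax.numpy as jnp
from flax import nn

with nn.stochastic(jax.random.PRNGKey(0)):
  token_ids = jnp.array([[5, 8, 13, 2]])
  noisy_ids = word_dropout(token_ids, rate=0.1, unk_idx=1)
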
def create_model():
  """Creates a seq2seq model."""
  vocab_size = CTABLE.vocab_size
  _, initial_params = Seq2seq.partial(eos_id=CTABLE.eos_id).init_by_shape(
      nn.make_rng(),
      [((1, get_max_input_len(), vocab_size), jnp.float32),
       ((1, get_max_output_len(), vocab_size), jnp.float32)])
  model = nn.Model(Seq2seq, initial_params)
  return model

def decode_batch(model, batch_size):
  """Decode and log results for a batch."""
  batch = get_batch(batch_size)
  inputs, outputs = batch['query'], batch['answer'][:, 1:]
  inferred = decode(model, inputs, nn.make_rng())
  questions = decode_onehot(inputs)
  infers = decode_onehot(inferred)
  goldens = decode_onehot(outputs)
  for question, inferred, golden in zip(questions, infers, goldens):
    log_decode(question, inferred, golden)

def apply(self,
          *args,
          wrapped_module,
          num_heads=1,
          num_parallel_heads=None,
          use_python_loop=False,
          **kwargs):
  # Re-use the same rng key across all examples and heads. This will result
  # in broadcasted dropout, which saves memory.
  # TODO(kitaev): options to swap broadcasted RNG on/off
  rng = nn.make_rng() if nn.is_stochastic() else None

  def init_single_head(init_rng, args, kwargs):
    if rng is None:
      _, head_params = wrapped_module.init(init_rng, *args, **kwargs)
    else:
      with nn.stochastic(rng):
        _, head_params = wrapped_module.init(init_rng, *args, **kwargs)
    return head_params

  def init_wrapped_module(rng, unused_shape):
    single_example_args = jax.tree_map(lambda x: x[:1], args)
    return multihead.chunked_multihead_map(
        init_single_head,
        in_has_batch_dim=(False, True, False),
        in_has_head_dim=(True, False, False),
        out_has_batch_dim=False,
        out_has_head_dim=True,
        use_python_loop=True,
    )(jax.random.split(rng, num_heads), single_example_args, kwargs)

  # TODO(kitaev): The original intent was to have this be a transparent module
  # but for some reason naming this parameter '0' and inheriting from
  # nn.base.TransparentModule is not enough to stop this parameter name from
  # explicitly showing up in the parameter tree.
  params = self.param('attn', None, init_wrapped_module)

  def run_single_example_and_head(params, args, kwargs):
    if rng is None:
      return wrapped_module.call(params, *args, **kwargs)
    else:
      with nn.stochastic(rng):
        return wrapped_module.call(params, *args, **kwargs)

  return multihead.chunked_multihead_map(
      run_single_example_and_head,
      in_has_batch_dim=(False, True, False),
      in_has_head_dim=(True, False, False),
      out_has_batch_dim=True,
      out_has_head_dim=False,
      num_parallel_heads=num_parallel_heads,
      use_python_loop=use_python_loop,
  )(params, args, kwargs)

def train(train_ds):
  rng = random.PRNGKey(0)

  with nn.stochastic(rng):
    model = create_model(rng, train_ds['index_points'].shape)
    optimizer = create_optimizer(model, FLAGS.learning_rate, FLAGS.beta1)
    key = nn.make_rng()
    for epoch in range(1, FLAGS.num_epochs + 1):
      key = random.split(key, FLAGS.num_samples + 1)
      key, sample_key = (key[0], key[1:])
      optimizer, metrics = train_epoch(optimizer, train_ds, epoch, sample_key)

  return optimizer

def drop_path(x: jnp.array, drop_rate: float = 0., rng=None) -> jnp.array:
  """Drop paths (Stochastic Depth) per sample, applied in the main path of
  residual blocks.

  This is the same as the DropConnect impl I created for EfficientNet, etc.
  networks, however, the original name is misleading as 'Drop Connect' is a
  different form of dropout in a separate paper... See discussion:
  https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
  I've opted for changing the layer and argument names to 'drop path' rather
  than mix DropConnect as a layer name and use 'survival rate' as the
  argument.
  """
  if drop_rate == 0.:
    return x
  keep_prob = 1. - drop_rate
  if rng is None:
    rng = make_rng()
  mask = random.bernoulli(key=rng, p=keep_prob, shape=(x.shape[0], 1, 1, 1))
  mask = jnp.broadcast_to(mask, x.shape)
  return lax.select(mask, x / keep_prob, jnp.zeros_like(x))

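# A minimal usage sketch for drop_path above: the residual branch is dropped
# per sample and the kept samples are rescaled by 1 / keep_prob. The block
# body here is a hypothetical stand-in, not taken from the original code.
import jax
import jax.numpy as jnp


def residual_block_sketch(x, rng):
  """Hypothetical residual block; x has shape (batch, height, width, channels)."""
  branch = x * 2.0  # stand-in for the real residual branch computation
  return x + drop_path(branch, drop_rate=0.1, rng=rng)
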
def drop_path(x, drop_prob: float = 0., rng=None):
  """Drop paths (Stochastic Depth) per sample, applied in the main path of
  residual blocks.

  This is the same as the DropConnect impl I created for EfficientNet, etc.
  networks, however, the original name is misleading as 'Drop Connect' is a
  different form of dropout in a separate paper... See discussion:
  https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
  I've opted for changing the layer and argument names to 'drop path' rather
  than mix DropConnect as a layer name and use 'survival rate' as the
  argument.
  """
  # FIXME not tested
  if drop_prob == 0.:
    return x
  keep_prob = 1 - drop_prob
  if rng is None:
    rng = make_rng('dropout')
  # Binary keep mask per sample, broadcast over the remaining dimensions.
  random_tensor = random.bernoulli(
      key=rng, p=keep_prob, shape=(x.shape[0], 1, 1, 1))
  # Scale the kept activations so the expected value is unchanged.
  output = jnp.where(random_tensor, x / keep_prob, jnp.zeros_like(x))
  return output

def apply_param_gradient(self, step, hyper_params, param, state, grad):
  del step
  assert hyper_params.learning_rate is not None, "no learning rate provided."
  if hyper_params.weight_decay != 0:
    raise NotImplementedError("Weight decay not supported")

  noise = jax.random.normal(
      key=nn.make_rng(), shape=param.shape, dtype=param.dtype)

  momentum = state.momentum
  h = hyper_params.step_size
  gamma = hyper_params.friction
  t = hyper_params.temperature
  n = hyper_params.train_size

  new_momentum = (
      (1 - h * gamma) * momentum - h * n * grad +
      jnp.sqrt(2 * gamma * h * t) * jnp.sqrt(state.preconditioner) * noise)

  new_param = param + h * (1. / state.preconditioner) * new_momentum
  new_state = _SymEulerSGMCMCParamState(new_momentum, state.preconditioner)
  return new_param, new_state

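# Reading the update above as equations (h: step size, gamma: friction,
# t: temperature, n: train set size, M: state.preconditioner, xi: standard
# normal noise), it is a symplectic-Euler-style SGMCMC step:
#   m_new     = (1 - h * gamma) * m - h * n * grad
#               + sqrt(2 * gamma * h * t) * sqrt(M) * xi
#   theta_new = theta + h * (1 / M) * m_new
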
def apply(self, x, config, num_classes, train=True):
  """Creates a model definition."""
  b, c = x.shape[0], x.shape[3]
  k = config.k
  sigma = config.ptopk_sigma
  num_samples = config.ptopk_num_samples

  sigma *= self.state("sigma_mutiplier",
                      shape=(),
                      initializer=nn.initializers.ones).value

  stats = {"x": x, "sigma": sigma}

  feature_extractor = models.ResNet50.shared(train=train, name="ResNet_0")

  rpn_feature = feature_extractor(x)
  rpn_scores, rpn_stats = ProposalNet(
      jax.lax.stop_gradient(rpn_feature),
      communication=Communication(config.communication),
      train=train)
  stats.update(rpn_stats)

  # rpn_scores are a list of score images. We keep track of the structure
  # because it is used in the aggregation step later-on.
  rpn_scores_shapes = [s.shape for s in rpn_scores]
  rpn_scores_flat = jnp.concatenate(
      [jnp.reshape(s, [b, -1]) for s in rpn_scores], axis=1)

  top_k_indicators = sample_patches.select_patches_perturbed_topk(
      rpn_scores_flat, k=k, sigma=sigma, num_samples=num_samples)
  top_k_indicators = jnp.transpose(top_k_indicators, [0, 2, 1])

  offset = 0
  weights = []
  for sh in rpn_scores_shapes:
    cur = top_k_indicators[:, :, offset:offset + sh[1] * sh[2]]
    cur = jnp.reshape(cur, [b, k, sh[1], sh[2]])
    weights.append(cur)
    offset += sh[1] * sh[2]
  chex.assert_equal(offset, top_k_indicators.shape[-1])

  part_imgs = weighted_anchor_aggregator(x, weights)
  chex.assert_shape(part_imgs, (b * k, 224, 224, c))
  stats["part_imgs"] = jnp.reshape(part_imgs, [b, k * 224, 224, c])

  part_features = feature_extractor(part_imgs)
  part_features = jnp.mean(part_features, axis=[1, 2])  # GAP the spatial dims

  part_features = nn.dropout(  # features from parts
      jnp.reshape(part_features, [b * k, 2048]),
      0.5,
      deterministic=not train,
      rng=nn.make_rng())
  features = nn.dropout(  # features from whole image
      jnp.reshape(jnp.mean(rpn_feature, axis=[1, 2]), [b, -1]),
      0.5,
      deterministic=not train,
      rng=nn.make_rng())

  # Mean pool all part features, add it to features and predict logits.
  concat_out = jnp.mean(
      jnp.reshape(part_features, [b, k, 2048]), axis=1) + features
  concat_logits = nn.Dense(concat_out, num_classes)
  raw_logits = nn.Dense(features, num_classes)
  part_logits = jnp.reshape(nn.Dense(part_features, num_classes), [b, k, -1])

  all_logits = {
      "raw_logits": raw_logits,
      "concat_logits": concat_logits,
      "part_logits": part_logits,
  }
  # Add entropy into it for entropy regularization.
  stats["rpn_scores_entropy"] = jax.scipy.special.entr(
      jax.nn.softmax(stats["raw_scores"])).sum(axis=1).mean(axis=0)

  return all_logits, stats

def init_param_state(self, param):
  # TODO(basv): do we want to init momentum randomly?
  return _SymEulerSGMCMCParamState(
      jax.random.normal(nn.make_rng(), param.shape, param.dtype),
      jnp.ones_like(param))

def apply(self,
          x,
          *,
          patch_size,
          k,
          downscale,
          scorer_has_se,
          normalization_str="identity",
          selection_method,
          selection_method_kwargs=None,
          selection_method_inference=None,
          patch_dropout=0.,
          hard_topk_probability=0.,
          random_patch_probability=0.,
          use_iterative_extraction,
          append_position_to_input,
          feature_network,
          aggregation_method,
          aggregation_method_kwargs=None,
          train):
  """Process a high resolution image by selecting a subset of useful patches.

  This model processes the input as follows:
  1. Compute scores per patch on a downscaled version of the input.
  2. Select "important" patches using sampling or top-k methods.
  3. Extract the patches from the high-resolution image.
  4. Compute representation vector for each patch with a feature network.
  5. Aggregate the patch representations to obtain an image representation.

  Args:
    x: Input tensor of shape (batch, height, width, channels).
    patch_size: Size of the (squared) patches to extract.
    k: Number of patches to extract per image.
    downscale: Downscale multiplier for the input of the scorer network.
    scorer_has_se: Whether scorer network has Squeeze-excite layers.
    normalization_str: String specifying the normalization of the scores.
    selection_method: Method that selects which patches should be extracted,
      based on their scores. Either returns indices (hard selection) or
      indicator vectors (which could yield interpolated patches).
    selection_method_kwargs: Keyword args for the selection_method.
    selection_method_inference: Selection method used at inference.
    patch_dropout: Probability to replace a patch by 0 values.
    hard_topk_probability: Probability to use the true topk on the scores to
      select the patches. This operation has no gradient so scorer's weights
      won't be trained.
    random_patch_probability: Probability to replace each patch by a random
      patch in the image during training.
    use_iterative_extraction: If True, uses a for loop instead of patch
      indexing for memory efficiency.
    append_position_to_input: Append normalized (height, width) position to
      the channels of the input.
    feature_network: Network to be applied on each patch individually to
      obtain patch representation vectors.
    aggregation_method: Method to aggregate the representations of the k
      patches of each image to obtain the image representation.
    aggregation_method_kwargs: Keywords arguments for aggregation_method.
    train: If the model is being trained. Disable dropout otherwise.

  Returns:
    A representation vector for each image in the batch.
  """
  selection_method = SelectionMethod(selection_method)
  aggregation_method = AggregationMethod(aggregation_method)
  if selection_method_inference:
    selection_method_inference = SelectionMethod(selection_method_inference)

  selection_method_kwargs = selection_method_kwargs or {}
  aggregation_method_kwargs = aggregation_method_kwargs or {}

  stats = {}

  # Compute new dimension of the scoring image.
  b, h, w, c = x.shape
  scoring_shape = (b, h // downscale, w // downscale, c)

  # === Compute the scores with a small CNN.
  if selection_method == SelectionMethod.RANDOM:
    scores_h, scores_w = Scorer.compute_output_size(h // downscale,
                                                    w // downscale)
    num_patches = scores_h * scores_w
  else:
    # Downscale input to run scorer on.
    scoring_x = jax.image.resize(x, scoring_shape, method="bilinear")
    scores = Scorer(scoring_x, use_squeeze_excite=scorer_has_se,
                    name="scorer")
    flatten_scores = einops.rearrange(scores, "b h w -> b (h w)")
    num_patches = flatten_scores.shape[-1]
    scores_h, scores_w = scores.shape[1:3]

    # Compute entropy before normalization.
    prob_scores = jax.nn.softmax(flatten_scores)
    stats["entropy_before_normalization"] = jax.scipy.special.entr(
        prob_scores).sum(axis=1).mean(axis=0)

    # Normalize the flattened scores.
    normalization_fn = create_normalization_fn(normalization_str)
    flatten_scores = normalization_fn(flatten_scores)
    scores = flatten_scores.reshape(scores.shape)
    stats["scores"] = scores[Ellipsis, None]

  # Concatenate height and width position to the input channels.
  if append_position_to_input:
    coords = utils.create_grid([h, w], value_range=(0., 1.))
    x = jnp.concatenate(
        [x, coords[jnp.newaxis, Ellipsis].repeat(b, axis=0)], axis=-1)
    c += 2

  # Overwrite the selection method at inference.
  if selection_method_inference and not train:
    selection_method = selection_method_inference

  # === Patch selection

  # Select the patches by sampling or top-k. Some methods return the indices
  # of the selected patches, other methods return indicator vectors.
  extract_by_indices = selection_method in [
      SelectionMethod.HARD_TOPK, SelectionMethod.RANDOM
  ]
  if selection_method is SelectionMethod.SINKHORN_TOPK:
    indicators = select_patches_sinkhorn_topk(
        flatten_scores, k=k, **selection_method_kwargs)
  elif selection_method is SelectionMethod.PERTURBED_TOPK:
    sigma = selection_method_kwargs["sigma"]
    num_samples = selection_method_kwargs["num_samples"]
    sigma *= self.state("sigma_mutiplier",
                        shape=(),
                        initializer=nn.initializers.ones).value
    stats["sigma"] = sigma
    indicators = select_patches_perturbed_topk(
        flatten_scores, k=k, sigma=sigma, num_samples=num_samples)
  elif selection_method is SelectionMethod.HARD_TOPK:
    indices = select_patches_hard_topk(flatten_scores, k=k)
  elif selection_method is SelectionMethod.RANDOM:
    batch_random_indices_fn = jax.vmap(
        functools.partial(
            jax.random.choice, a=num_patches, shape=(k,), replace=False))
    indices = batch_random_indices_fn(jax.random.split(nn.make_rng(), b))

  # Compute scores entropy for regularization.
  if selection_method not in [SelectionMethod.RANDOM]:
    prob_scores = flatten_scores
    # Normalize the scores if it is not already done.
    if "softmax" not in normalization_str:
      prob_scores = jax.nn.softmax(prob_scores)
    stats["entropy"] = jax.scipy.special.entr(prob_scores).sum(
        axis=1).mean(axis=0)

  # Randomly use hard topk at training.
  if (train and hard_topk_probability > 0 and selection_method not in
      [SelectionMethod.HARD_TOPK, SelectionMethod.RANDOM]):
    true_indices = select_patches_hard_topk(flatten_scores, k=k)
    random_values = jax.random.uniform(nn.make_rng(), (b,))
    use_hard = random_values < hard_topk_probability
    if extract_by_indices:
      indices = jnp.where(use_hard[:, None], true_indices, indices)
    else:
      true_indicators = make_indicators(true_indices, num_patches)
      indicators = jnp.where(use_hard[:, None, None], true_indicators,
                             indicators)

  # Sample some random patches during training with random_patch_probability.
  if (train and random_patch_probability > 0 and
      selection_method is not SelectionMethod.RANDOM):
    single_random_patches = functools.partial(
        jax.random.choice, a=num_patches, shape=(k,), replace=False)
    random_indices = jax.vmap(single_random_patches)(
        jax.random.split(nn.make_rng(), b))
    random_values = jax.random.uniform(nn.make_rng(), (b, k))
    use_random = random_values < random_patch_probability
    if extract_by_indices:
      indices = jnp.where(use_random, random_indices, indices)
    else:
      random_indicators = make_indicators(random_indices, num_patches)
      indicators = jnp.where(use_random[:, None, :], random_indicators,
                             indicators)

  # === Patch extraction
  if extract_by_indices:
    patches = extract_patches_from_indices(
        x, indices, patch_size=patch_size, grid_shape=(scores_h, scores_w))
    indicators = make_indicators(indices, num_patches)
  else:
    patches = extract_patches_from_indicators(
        x,
        indicators,
        patch_size,
        grid_shape=(scores_h, scores_w),
        iterative=use_iterative_extraction,
        patch_dropout=patch_dropout,
        train=train)

  chex.assert_shape(patches, (b, k, patch_size, patch_size, c))

  stats["extracted_patches"] = einops.rearrange(
      patches, "b k i j c -> b i (k j) c")
  # Remove position channels for plotting.
  if append_position_to_input:
    stats["extracted_patches"] = stats["extracted_patches"][Ellipsis, :-2]

  # === Compute patch features
  flatten_patches = einops.rearrange(patches, "b k i j c -> (b k) i j c")
  representations = feature_network(flatten_patches, train=train)
  if representations.ndim > 2:
    collapse_axis = tuple(range(1, representations.ndim - 1))
    representations = representations.mean(axis=collapse_axis)
  representations = einops.rearrange(representations, "(b k) d -> b k d", k=k)

  stats["patch_representations"] = representations

  # === Aggregate the k patches
  # - for sampling we are forced to take an expectation
  # - for topk we have multiple options: mean, max, transformer.
  if aggregation_method is AggregationMethod.TRANSFORMER:
    patch_pos_encoding = nn.Dense(
        einops.rearrange(indicators, "b d k -> b k d"),
        features=representations.shape[-1])
    chex.assert_equal_shape([representations, patch_pos_encoding])
    representations += patch_pos_encoding
    representations = transformer.Transformer(
        representations, **aggregation_method_kwargs, is_training=train)
  elif aggregation_method is AggregationMethod.MEANPOOLING:
    representations = representations.mean(axis=1)
  elif aggregation_method is AggregationMethod.MAXPOOLING:
    representations = representations.max(axis=1)
  elif aggregation_method is AggregationMethod.SUM_LAYERNORM:
    representations = representations.sum(axis=1)
    representations = nn.LayerNorm(representations)

  representations = nn.Dense(
      representations,
      features=representations.shape[-1],
      name="classification_dense1")
  representations = nn.swish(representations)

  return representations, stats

def apply(self):
  return nn.make_rng()

def test_decode_batch(self):
  with nn.stochastic(random.PRNGKey(0)):
    model = train.create_model(nn.make_rng())
    train.decode_batch(model, 5)

def test_make_rng_requires_stochastic(self):
  with self.assertRaises(ValueError):
    nn.make_rng()

def dot_product_attention(query,
                          key,
                          value,
                          dtype=jnp.float32,
                          bias=None,
                          axis=None,
                          broadcast_dropout=True,
                          dropout_rng=None,
                          dropout_rate=0.,
                          deterministic=False,
                          precision=None):
  """Computes dot-product attention given query, key, and value.

  This is the core function for applying attention based on
  https://arxiv.org/abs/1706.03762. It calculates the attention weights given
  query and key and combines the values using the attention weights. This
  function supports multi-dimensional inputs.

  This version is modified to move the softmax division after the dot product.

  Args:
    query: queries for calculating attention with shape of `[batch_size,
      dim1, dim2, ..., dimN, num_heads, mem_channels]`.
    key: keys for calculating attention with shape of `[batch_size, dim1,
      dim2, ..., dimN, num_heads, mem_channels]`.
    value: values to be used in attention with shape of `[batch_size, dim1,
      dim2, ..., dimN, num_heads, value_channels]`.
    dtype: the dtype of the computation (default: float32).
    bias: bias for the attention weights. This can be used for incorporating
      autoregressive mask, padding mask, proximity bias.
    axis: axes over which the attention is applied.
    broadcast_dropout: bool: use a broadcasted dropout along batch dims.
    dropout_rng: JAX PRNGKey: to be used for dropout.
    dropout_rate: dropout rate.
    deterministic: bool, deterministic or not (to apply dropout).
    precision: numerical precision of the computation see `jax.lax.Precision`
      for details.

  Returns:
    Output of shape `[bs, dim1, dim2, ..., dimN, num_heads, value_channels]`.
  """
  assert key.shape[:-1] == value.shape[:-1]
  assert (query.shape[0:1] == key.shape[0:1] and
          query.shape[-1] == key.shape[-1])

  if axis is None:
    axis = tuple(range(1, key.ndim - 2))
  if not isinstance(axis, Iterable):
    axis = (axis,)
  assert key.ndim == query.ndim
  assert key.ndim == value.ndim
  for ax in axis:
    if not (query.ndim >= 3 and 1 <= ax < query.ndim - 2):
      raise ValueError('Attention axis must be between the batch '
                       'axis and the last-two axes.')
  depth = query.shape[-1]
  n = key.ndim
  # batch_dims is <bs, <non-attention dims>, num_heads>
  batch_dims = tuple(np.delete(range(n), axis + (n - 1,)))
  # q & k -> (bs, <non-attention dims>, num_heads, <attention dims>, channels)
  qk_perm = batch_dims + axis + (n - 1,)
  key = key.transpose(qk_perm)
  query = query.transpose(qk_perm)
  # v -> (bs, <non-attention dims>, num_heads, channels, <attention dims>)
  v_perm = batch_dims + (n - 1,) + axis
  value = value.transpose(v_perm)

  query = query / jnp.sqrt(depth).astype(dtype)
  batch_dims_t = tuple(range(len(batch_dims)))
  attn_weights = lax.dot_general(
      query,
      key, (((n - 1,), (n - 1,)), (batch_dims_t, batch_dims_t)),
      precision=precision)

  # apply attention bias: masking, dropout, proximity bias, etc.
  if bias is not None:
    attn_weights = attn_weights + bias

  # normalize the attention weights
  norm_dims = tuple(range(attn_weights.ndim - len(axis), attn_weights.ndim))
  decoding = attn_weights.shape[-2] != 256
  if decoding:
    attn_weights = lax.exp(attn_weights - jax.scipy.special.logsumexp(
        attn_weights, axis=norm_dims, keepdims=True))
  else:
    # move the division by the softmax denominator to after the dot product
    attn_weights = jnp.exp(attn_weights - lax.stop_gradient(
        jnp.max(attn_weights, axis=norm_dims, keepdims=True)))
    softmax_denominator = jnp.sum(attn_weights, axis=norm_dims, keepdims=False)
  attn_weights = attn_weights.astype(dtype)

  # apply dropout
  if not deterministic and dropout_rate > 0.:
    if dropout_rng is None:
      dropout_rng = nn.make_rng()
    keep_prob = jax.lax.tie_in(attn_weights, 1.0 - dropout_rate)
    if broadcast_dropout:
      # dropout is broadcast across the batch+head+non-attention dimension
      dropout_dims = attn_weights.shape[-(2 * len(axis)):]
      dropout_shape = (tuple([1] * len(batch_dims_t)) + dropout_dims)
      keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)
    else:
      keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)
    multiplier = (keep.astype(attn_weights.dtype) /
                  jnp.asarray(keep_prob, dtype=dtype))
    attn_weights = attn_weights * multiplier

  # compute the new values given the attention weights
  wv_contracting_dims = (norm_dims, range(value.ndim - len(axis), value.ndim))
  y = lax.dot_general(
      attn_weights,
      value, (wv_contracting_dims, (batch_dims_t, batch_dims_t)),
      precision=precision)

  if not decoding:
    # divide by the denominator of the attention softmax now, when the array
    # is O(N*H) rather than O(N^2)
    y = y / jnp.expand_dims(softmax_denominator, -1)

  # back to (bs, dim1, dim2, ..., dimN, num_heads, channels)
  perm_inv = _invert_perm(qk_perm)
  y = y.transpose(perm_inv)
  return y

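# A small numerical check of the trick used above (delaying the softmax
# denominator until after the attention-times-value product), written with
# plain jax.numpy and made-up shapes: for a 2D case,
#   softmax(a) @ v == (exp(a - max(a)) @ v) / sum(exp(a - max(a)))
# because the per-row division commutes with the matmul over that row.
import jax
import jax.numpy as jnp

a = jax.random.normal(jax.random.PRNGKey(0), (4, 6))   # attention logits
v = jax.random.normal(jax.random.PRNGKey(1), (6, 8))   # values
reference = jax.nn.softmax(a, axis=-1) @ v
unnormalized = jnp.exp(a - a.max(axis=-1, keepdims=True))
delayed = (unnormalized @ v) / unnormalized.sum(axis=-1, keepdims=True)
assert jnp.allclose(reference, delayed, atol=1e-5)
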
def self_attention(inputs,
                   variable_dictionary,
                   num_heads: int,
                   qkv_features: int = None,
                   padding_mask: List[bool] = None,
                   dropout_rate: float = 0.,
                   deterministic: bool = False,
                   precision: Precision = None,
                   kernel_init: List[float] = nn.linear.default_kernel_init,
                   bias_init: List[float] = nn.initializers.zeros,
                   dtype: jnp.dtype = jnp.float32,
                   bias: bool = True):
  """Applies multi-head self-attention on the input data.

  Args:
    inputs: input data of shape `[bs, dim1, dim2, ..., dimN, features]`.
    variable_dictionary: Parameter dictionary.
    num_heads: number of attention heads. Features (i.e. inputs.shape[-1])
      should be divisible by the number of heads.
    qkv_features: dimension of the key, query, and value.
    padding_mask: boolean specifying tokens that are pad token.
    dropout_rate: dropout rate.
    deterministic: bool, deterministic or not (to apply dropout).
    precision: numerical precision of the computation see `jax.lax.Precision`
      for details.
    kernel_init: initializer for the kernel of the Dense layers.
    bias_init: initializer for the bias of the Dense layers.
    dtype: datatype for the activations, jnp.bfloat16 or jnp.float32.
    bias: bool: whether pointwise QKVO dense transforms use bias.

  Returns:
    output of shape `[bs, dim1, dim2, ..., dimN, features//num_heads]`.
  """
  features = inputs.shape[-1]
  qkv_features = qkv_features or features

  assert qkv_features % num_heads == 0, (
      'Memory dimension must be divisible by number of heads.')
  head_dim = qkv_features // num_heads
  inputs = inputs.astype(dtype)

  if FLAGS.use_einsum:
    dense_module = Dense3D
  else:
    dense_module = attention.DenseGeneral

  query = dense_module.call(variable_dictionary['query'], inputs,
                            axis=-1,
                            features=(num_heads, head_dim),
                            kernel_init=kernel_init,
                            bias_init=bias_init,
                            bias=bias,
                            precision=precision,
                            dtype=dtype,
                            name='query')
  query = jnp.multiply(query, 1.0 / math.sqrt(float(head_dim)))
  key = dense_module.call(variable_dictionary['key'], inputs,
                          axis=-1,
                          features=(num_heads, head_dim),
                          kernel_init=kernel_init,
                          bias_init=bias_init,
                          bias=bias,
                          precision=precision,
                          dtype=dtype,
                          name='key')
  value = dense_module.call(variable_dictionary['value'], inputs,
                            axis=-1,
                            features=(num_heads, head_dim),
                            kernel_init=kernel_init,
                            bias_init=bias_init,
                            bias=bias,
                            precision=precision,
                            dtype=dtype,
                            name='value')

  assert query.dtype == dtype
  assert key.dtype == dtype
  assert value.dtype == dtype

  # Get raw attention scores from the dot product between key and query.
  # B = batch size (number of sequences)
  # F = `from_tensor` sequence length
  # T = `to_tensor` sequence length
  # N = `num_heads`
  # H = `head_dim` (qkv_features // num_heads)
  attention_scores = jnp.einsum('BTNH,BFNH->BNFT', key, query)
  assert attention_scores.dtype == dtype

  # Create attention masks.
  if padding_mask is not None:
    assert padding_mask.dtype == bool, ('Mask should have bool type.')
    attention_mask = jnp.expand_dims(padding_mask, axis=1)
    adder = (1.0 - attention_mask) * NEG_INFINITY
    attention_scores += adder.astype(dtype)
  assert attention_scores.dtype == dtype

  attention_scores = attention_scores - lax.stop_gradient(
      jnp.max(attention_scores, axis=-1, keepdims=True))
  attention_scores = jnp.exp(attention_scores)
  attention_sum = jnp.sum(attention_scores, axis=-1, keepdims=True)

  keep_prob = 1 - dropout_rate
  if not deterministic:
    keep_mask = jax.random.bernoulli(nn.make_rng(), keep_prob,
                                     attention_scores.shape).astype(dtype)
    assert keep_mask.dtype == dtype
    attention_probs = jnp.multiply(keep_mask, attention_scores)
  else:
    attention_probs = attention_scores
  assert attention_probs.dtype == dtype

  attention_probs = jnp.einsum('BNFT,BTNH->BFNH', attention_probs, value)
  assert attention_probs.dtype == dtype
  attention_probs = attention_probs / jnp.transpose(attention_sum,
                                                    [0, 2, 1, 3])

  # Split mask and scaling ops in dropout.
  # Move the scaling from dropout to here to save some mul ops.
  # TODO(yuemmawang) automate this optimization in xla
  if not deterministic:
    scale = 1 / keep_prob
    if dtype == jnp.bfloat16:
      scale = jnp.bfloat16(scale)
    attention_probs = jnp.multiply(attention_probs, scale)
  assert attention_probs.dtype == dtype

  return attention_probs

def lsh_attention_single_head(query, value, n_buckets, n_hashes,
                              causal_mask=True,
                              length_norm=False):
  """Applies LSH attention on a single head and a single batch.

  Args:
    query: query tensor of shape [qlength, dims].
    value: value tensor of shape [vlength, dims].
    n_buckets: integer, number of buckets.
    n_hashes: integer, number of hashes.
    causal_mask: boolean, to use causal mask or not.
    length_norm: boolean, to normalize k or not.

  Returns:
    output tensor of shape [qlength, dims]
  """
  qdim, vdim = query.shape[-1], value.shape[-1]
  chunk_size = n_hashes * n_buckets

  seqlen = query.shape[0]

  with nn.stochastic(jax.random.PRNGKey(0)):
    rng = nn.make_rng()

  buckets = hash_vectors(
      query, rng, num_buckets=n_buckets, num_hashes=n_hashes)
  # buckets should be (seq_len)
  assert buckets.shape[-1] == n_hashes * seqlen

  total_hashes = n_hashes

  # create sort and unsort
  ticker = jax.lax.tie_in(query, jnp.arange(n_hashes * seqlen))
  buckets_and_t = seqlen * buckets + (ticker % seqlen)
  buckets_and_t = jax.lax.stop_gradient(buckets_and_t)
  # ticker = jnp.tile(jnp.reshape(ticker, [1, -1]), [batch_size, 1])
  sbuckets_and_t, sticker = jax.lax.sort_key_val(
      buckets_and_t, ticker, dimension=-1)
  _, undo_sort = jax.lax.sort_key_val(sticker, ticker, dimension=-1)
  sbuckets_and_t = jax.lax.stop_gradient(sbuckets_and_t)
  sticker = jax.lax.stop_gradient(sticker)
  undo_sort = jax.lax.stop_gradient(undo_sort)

  st = (sticker % seqlen)
  sqk = jnp.take(query, st, axis=0)
  sv = jnp.take(value, st, axis=0)

  bkv_t = jnp.reshape(st, (chunk_size, -1))
  bqk = jnp.reshape(sqk, (chunk_size, -1, qdim))
  bv = jnp.reshape(sv, (chunk_size, -1, vdim))
  bq = bqk
  bk = bqk

  if length_norm:
    bk = length_normalized(bk)

  # get previous chunks
  bk = look_one_back(bk)
  bv = look_one_back(bv)
  bkv_t = look_one_back(bkv_t)

  # compute dot product attention
  dots = jnp.einsum('hie,hje->hij', bq, bk) * (qdim**0.5)

  if causal_mask:
    # apply causal mask
    # TODO(yitay): This is not working yet
    # We don't need causal reformer for any task YET.
    pass

  dots_logsumexp = logsumexp(dots, axis=-1, keepdims=True)
  slogits = jnp.reshape(dots_logsumexp, [-1])
  dots = jnp.exp(dots - dots_logsumexp)

  x = jnp.matmul(dots, bv)
  x = jnp.reshape(x, [-1, qdim])

  # Unsort
  o = permute_via_gather(x, undo_sort, sticker, axis=0)
  logits = permute_via_sort(slogits, sticker, undo_sort, axis=0)
  logits = jnp.reshape(logits, [total_hashes, seqlen, 1])
  probs = jnp.exp(logits - logsumexp(logits, axis=0, keepdims=True))
  o = jnp.reshape(o, [n_hashes, seqlen, qdim])
  out = jnp.sum(o * probs, axis=0)
  out = jnp.reshape(out, [seqlen, qdim])
  return out

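# A tiny, self-contained illustration of the sort/unsort bookkeeping used
# above: sort positions by their bucket ids with jax.lax.sort_key_val, then
# recover the original order by gathering with the `undo_sort` permutation.
# The bucket ids and values below are made up for the example.
import jax
import jax.numpy as jnp

buckets = jnp.array([2, 0, 1, 0])
ticker = jnp.arange(buckets.shape[0])
_, sticker = jax.lax.sort_key_val(buckets, ticker, dimension=-1)
_, undo_sort = jax.lax.sort_key_val(sticker, ticker, dimension=-1)

values = jnp.array([10., 11., 12., 13.])
sorted_values = jnp.take(values, sticker, axis=0)    # grouped by bucket
restored = jnp.take(sorted_values, undo_sort, axis=0)
# restored == values: undo_sort inverts the permutation applied by sticker.
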