Example #1
 def test_stochastic_rngs(self):
   rng = random.PRNGKey(0)
   with nn.stochastic(rng):
     r1 = nn.make_rng()
     r2 = nn.make_rng()
   self.assertTrue(onp.all(r1 == random.fold_in(rng, 1)))
   self.assertTrue(onp.all(r2 == random.fold_in(rng, 2)))
Example #2
    def test_train_one_step(self):
        batch = train.get_batch(128)
        rng = random.PRNGKey(0)

        with nn.stochastic(rng):
            model = train.create_model(nn.make_rng())
            optimizer = train.create_optimizer(model, 0.003)
            optimizer, train_metrics = train.train_step(
                optimizer, batch, nn.make_rng())

        self.assertLessEqual(train_metrics['loss'], 5)
        self.assertGreaterEqual(train_metrics['accuracy'], 0)
Example #3
def train_model():
    """Train for a fixed number of steps and decode during training."""
    with nn.stochastic(jax.random.PRNGKey(0)):
        model = create_model(nn.make_rng())
        optimizer = create_optimizer(model, FLAGS.learning_rate)
        for step in range(FLAGS.num_train_steps):
            batch = get_batch(FLAGS.batch_size)
            optimizer, metrics = train_step(optimizer, batch, nn.make_rng())
            if step % FLAGS.decode_frequency == 0:
                logging.info('train step: %d, loss: %.4f, accuracy: %.2f',
                             step, metrics['loss'], metrics['accuracy'] * 100)
                decode_batch(optimizer.target, 5)
        return optimizer.target
Example #4
def select_patches_perturbed_topk(flatten_scores,
                                  sigma,
                                  *,
                                  k,
                                  num_samples=1000):
    """Select patches using a differentiable top-k based on perturbation.

  Uses https://q-berthet.github.io/papers/BerBloTeb20.pdf,
  see off_the_grid.lib.ops.perturbed_topk for more info.

  Args:
    flatten_scores: The flattened scores of shape (batch, num_patches).
    sigma: Standard deviation of the noise.
    k: The number of patches to extract.
    num_samples: Number of noisy inputs used to compute the output expectation.

  Returns:
    Indicator vectors of the selected patches (batch, num_patches, k).
  """
    batch_size = flatten_scores.shape[0]

    batch_topk_fn = jax.vmap(
        functools.partial(perturbed_topk.perturbed_sorted_topk_indicators,
                          num_samples=num_samples,
                          sigma=sigma,
                          k=k))

    rng_keys = jax.random.split(nn.make_rng(), batch_size)
    indicators = batch_topk_fn(flatten_scores, rng_keys)
    topk_indicators_flatten = einops.rearrange(indicators, "b k d -> b d k")
    return topk_indicators_flatten
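
The perturbed top-k referenced above (Berthet et al.) builds a differentiable top-k by averaging hard top-k indicator matrices computed on noise-perturbed copies of the scores. A minimal forward-pass sketch for a single score vector, assuming a hypothetical perturbed_topk_indicators_sketch rather than the actual off_the_grid.lib.ops.perturbed_topk op (which also defines a custom gradient):

import jax
import jax.numpy as jnp

def perturbed_topk_indicators_sketch(scores, rng, *, k, sigma, num_samples=1000):
    # scores: (num_patches,). Add Gaussian noise, take the hard (sorted) top-k
    # of each noisy copy, and average the one-hot indicators over the samples.
    noise = sigma * jax.random.normal(rng, (num_samples,) + scores.shape)
    noisy_scores = scores[None, :] + noise                 # (num_samples, num_patches)
    topk_idx = jnp.argsort(-noisy_scores, axis=-1)[:, :k]  # (num_samples, k)
    onehot = jax.nn.one_hot(topk_idx, scores.shape[-1])    # (num_samples, k, num_patches)
    return onehot.mean(axis=0)                             # (k, num_patches)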
Example #5
 def get_drop_pattern(self, x, layer_drop_p):
   if nn.is_stochastic() and layer_drop_p:
     rng = nn.make_rng()
     shape = (x.shape[0],) + (1,) * (x.ndim - 1)
     return jax.random.bernoulli(rng, layer_drop_p, shape).astype("float32")
   else:
     return 0.0
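
For context, a self-contained sketch (hypothetical helper, not part of the module above) of how such a per-sample drop pattern is typically applied to a residual branch for stochastic depth:

import jax
import jax.numpy as jnp

def apply_layer_drop(rng, x, branch_out, layer_drop_p):
    # Drop the whole residual branch for a random subset of samples; the mask
    # shape matches get_drop_pattern above: one value per example, broadcast
    # over all remaining dimensions.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    drop = jax.random.bernoulli(rng, layer_drop_p, shape).astype(x.dtype)
    return x + branch_out * (1.0 - drop)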
Example #6
def create_model(key, input_shape):
    def inducing_loc_init(key, shape):
        return jnp.linspace(-1.5, 1.5, FLAGS.num_inducing_points)[:,
                                                                  jnp.newaxis]

    kwargs = {}
    for i in range(1, FLAGS.num_layers + 1):
        kwargs['kernel_fn_{}_kwargs'.format(i)] = {
            'amplitude_init': lambda key, shape: jnp.ones(shape),
            'length_scale_init': lambda key, shape: jnp.ones(shape)
        }
        kwargs['inducing_var_{}_kwargs'.format(i)] = {
            'fixed_locations': False,
            'whiten': FLAGS.whiten,
            'inducing_locations_init': inducing_loc_init
        }

    model_def = DeepGPModel.partial(**kwargs)

    with nn.stochastic(key):
        _, params = model_def.init_by_shape(key, [
            (input_shape, jnp.float64),
        ], nn.make_rng(), **kwargs)

        return nn.Model(model_def, params)
Example #7
File: train.py Project: us/flax
  def apply(self, inputs, eos_id=1, hidden_size=512):
    # inputs.shape = (batch_size, seq_length, vocab_size).
    batch_size = inputs.shape[0]

    lstm_cell = nn.LSTMCell.partial(name='lstm')
    init_lstm_state = nn.LSTMCell.initialize_carry(
        nn.make_rng(),
        (batch_size,),
        hidden_size)

    def encode_step_fn(carry, x):
      lstm_state, is_eos = carry
      new_lstm_state, y = lstm_cell(lstm_state, x)
      # Pass forward the previous state if EOS has already been reached.
      def select_carried_state(new_state, old_state):
        return jnp.where(is_eos[:, np.newaxis], old_state, new_state)
      # LSTM state is a tuple (c, h).
      carried_lstm_state = tuple(
          select_carried_state(*s) for s in zip(new_lstm_state, lstm_state))
      # Update `is_eos`.
      is_eos = jnp.logical_or(is_eos, x[:, eos_id])
      return (carried_lstm_state, is_eos), y

    (final_state, _), _ = jax_utils.scan_in_dim(
        encode_step_fn,
        init=(init_lstm_state, jnp.zeros(batch_size, dtype=jnp.bool_)),
        xs=inputs,
        axis=1)
    return final_state
Example #8
    def apply(self, init_state, inputs, teacher_force=False):
        # inputs.shape = (batch_size, seq_length, vocab_size).
        vocab_size = inputs.shape[2]
        lstm_cell = nn.LSTMCell.shared(name='lstm')
        projection = nn.Dense.shared(features=vocab_size, name='projection')

        def decode_step_fn(carry, x):
            rng, lstm_state, last_prediction = carry
            carry_rng, categorical_rng = jax.random.split(rng, 2)
            if not teacher_force:
                x = last_prediction
            lstm_state, y = lstm_cell(lstm_state, x)
            logits = projection(y)
            predicted_tokens = jax.random.categorical(categorical_rng, logits)
            prediction = onehot(predicted_tokens, vocab_size)
            return (carry_rng, lstm_state, prediction), (logits, prediction)

        init_carry = (nn.make_rng(), init_state, inputs[:, 0])

        if self.is_initializing():
            # initialize parameters before scan
            decode_step_fn(init_carry, inputs[:, 0])

        _, (logits, predictions) = jax_utils.scan_in_dim(
            decode_step_fn,
            init=init_carry,  # rng, lstm_state, last_pred
            xs=inputs,
            axis=1)
        return logits, predictions
Example #9
def word_dropout(inputs: jnp.ndarray, rate: float, unk_idx: int,
                 deterministic: bool = False):
  """Replaces a fraction (rate) of inputs with <unk>."""
  if deterministic or rate == 0.:
    return inputs

  mask = jax.random.bernoulli(nn.make_rng(), p=rate, shape=inputs.shape)
  return jnp.where(mask, jnp.array([unk_idx]), inputs)
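
A minimal usage sketch: word_dropout draws its mask from nn.make_rng(), so it has to run inside an nn.stochastic scope (or inside a module apply that provides one). This assumes a Flax version where the pre-Linen flax.nn API used throughout these examples is still importable.

import jax
import jax.numpy as jnp
from flax import nn  # pre-Linen Flax API, as in the examples above

tokens = jnp.array([[5, 8, 2, 9, 3]])
with nn.stochastic(jax.random.PRNGKey(0)):
    dropped = word_dropout(tokens, rate=0.3, unk_idx=1)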
Example #10
def create_model():
    """Creates a seq2seq model."""
    vocab_size = CTABLE.vocab_size
    _, initial_params = Seq2seq.partial(eos_id=CTABLE.eos_id).init_by_shape(
        nn.make_rng(), [((1, get_max_input_len(), vocab_size), jnp.float32),
                        ((1, get_max_output_len(), vocab_size), jnp.float32)])
    model = nn.Model(Seq2seq, initial_params)
    return model
Example #11
File: train.py Project: us/flax
def decode_batch(model, batch_size):
  """Decode and log results for a batch."""
  batch = get_batch(batch_size)
  inputs, outputs = batch['query'], batch['answer'][:, 1:]
  inferred = decode(model, inputs, nn.make_rng())
  questions = decode_onehot(inputs)
  infers = decode_onehot(inferred)
  goldens = decode_onehot(outputs)
  for question, inferred, golden in zip(questions, infers, goldens):
    log_decode(question, inferred, golden)
Example #12
    def apply(self,
              *args,
              wrapped_module,
              num_heads=1,
              num_parallel_heads=None,
              use_python_loop=False,
              **kwargs):
        # Re-use the same rng key across all examples and heads. This will result in
        # broadcasted dropout, which saves memory.
        # TODO(kitaev): options to swap broadcasted RNG on/off
        rng = nn.make_rng() if nn.is_stochastic() else None

        def init_single_head(init_rng, args, kwargs):
            if rng is None:
                _, head_params = wrapped_module.init(init_rng, *args, **kwargs)
            else:
                with nn.stochastic(rng):
                    _, head_params = wrapped_module.init(
                        init_rng, *args, **kwargs)
            return head_params

        def init_wrapped_module(rng, unused_shape):
            single_example_args = jax.tree_map(lambda x: x[:1], args)
            return multihead.chunked_multihead_map(
                init_single_head,
                in_has_batch_dim=(False, True, False),
                in_has_head_dim=(True, False, False),
                out_has_batch_dim=False,
                out_has_head_dim=True,
                use_python_loop=True,
            )(jax.random.split(rng, num_heads), single_example_args, kwargs)

        # TODO(kitaev): The original intent was to have this be a transparent module
        # but for some reason naming this parameter '0' and inheriting from
        # nn.base.TransparentModule is not enough to stop this parameter name from
        # explicitly showing up in the parameter tree.
        params = self.param('attn', None, init_wrapped_module)

        def run_single_example_and_head(params, args, kwargs):
            if rng is None:
                return wrapped_module.call(params, *args, **kwargs)
            else:
                with nn.stochastic(rng):
                    return wrapped_module.call(params, *args, **kwargs)

        return multihead.chunked_multihead_map(
            run_single_example_and_head,
            in_has_batch_dim=(False, True, False),
            in_has_head_dim=(True, False, False),
            out_has_batch_dim=True,
            out_has_head_dim=False,
            num_parallel_heads=num_parallel_heads,
            use_python_loop=use_python_loop,
        )(params, args, kwargs)
Example #13
def train(train_ds):
    rng = random.PRNGKey(0)

    with nn.stochastic(rng):
        model = create_model(rng, train_ds['index_points'].shape)
        optimizer = create_optimizer(model, FLAGS.learning_rate, FLAGS.beta1)

        key = nn.make_rng()

        for epoch in range(1, FLAGS.num_epochs + 1):
            key = random.split(key, FLAGS.num_samples + 1)
            key, sample_key = (key[0], key[1:])
            optimizer, metrics = train_epoch(optimizer, train_ds, epoch,
                                             sample_key)

    return optimizer
Example #14
def drop_path(x: jnp.ndarray, drop_rate: float = 0., rng=None) -> jnp.ndarray:
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.

    """
    if drop_rate == 0.:
        return x
    keep_prob = 1. - drop_rate
    if rng is None:
        rng = make_rng()
    mask = random.bernoulli(key=rng, p=keep_prob, shape=(x.shape[0], 1, 1, 1))
    mask = jnp.broadcast_to(mask, x.shape)
    return lax.select(mask, x / keep_prob, jnp.zeros_like(x))
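
A short usage sketch for the drop_path above; passing an explicit rng avoids relying on an ambient stochastic scope, and the input shape is just an illustrative NHWC batch:

import jax
import jax.numpy as jnp

x = jnp.ones((8, 14, 14, 256))             # (batch, height, width, channels)
rng = jax.random.PRNGKey(0)
y = drop_path(x, drop_rate=0.1, rng=rng)   # each sample is zeroed with prob 0.1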
Example #15
def drop_path(x, drop_prob: float = 0., rng=None):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.

    """
    # FIXME not tested
    if drop_prob == 0.:
        return x
    keep_prob = 1 - drop_prob
    if rng is None:
        rng = make_rng('dropout')
    random_tensor = keep_prob + random.bernoulli(
        key=rng, p=keep_prob, shape=(x.shape[0], 1, 1, 1))
    random_tensor = jnp.floor(random_tensor)  # binarize to a per-sample keep mask
    output = x / keep_prob * random_tensor
    return output
Example #16
  def apply_param_gradient(self, step, hyper_params, param, state, grad):
    del step
    assert hyper_params.learning_rate is not None, "no learning rate provided."
    if hyper_params.weight_decay != 0:
      raise NotImplementedError("Weight decay not supported")

    noise = jax.random.normal(
        key=nn.make_rng(), shape=param.shape, dtype=param.dtype)

    momentum = state.momentum
    h = hyper_params.step_size
    gamma = hyper_params.friction
    t = hyper_params.temperature
    n = hyper_params.train_size

    new_momentum = (
        (1 - h * gamma) * momentum - h * n * grad +
        jnp.sqrt(2 * gamma * h * t) * jnp.sqrt(state.preconditioner) * noise)

    new_param = param + h * (1. / state.preconditioner) * new_momentum
    new_state = _SymEulerSGMCMCParamState(new_momentum, state.preconditioner)
    return new_param, new_state
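
Written out, the update that apply_param_gradient implements, with step size h, friction gamma, temperature T, train size n, diagonal preconditioner M = state.preconditioner, and Gaussian noise \xi, is:

m_{t+1} = (1 - h\gamma)\, m_t - h\, n\, \nabla L(\theta_t) + \sqrt{2\gamma h T}\, \sqrt{M}\, \xi
\theta_{t+1} = \theta_t + h\, M^{-1}\, m_{t+1}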
Example #17
    def apply(self, x, config, num_classes, train=True):
        """Creates a model definition."""
        b, c = x.shape[0], x.shape[3]
        k = config.k
        sigma = config.ptopk_sigma
        num_samples = config.ptopk_num_samples

        sigma *= self.state("sigma_mutiplier",
                            shape=(),
                            initializer=nn.initializers.ones).value

        stats = {"x": x, "sigma": sigma}

        feature_extractor = models.ResNet50.shared(train=train,
                                                   name="ResNet_0")

        rpn_feature = feature_extractor(x)
        rpn_scores, rpn_stats = ProposalNet(jax.lax.stop_gradient(rpn_feature),
                                            communication=Communication(
                                                config.communication),
                                            train=train)
        stats.update(rpn_stats)

        # rpn_scores are a list of score images. We keep track of the structure
        # because it is used in the aggregation step later on.
        rpn_scores_shapes = [s.shape for s in rpn_scores]
        rpn_scores_flat = jnp.concatenate(
            [jnp.reshape(s, [b, -1]) for s in rpn_scores], axis=1)
        top_k_indicators = sample_patches.select_patches_perturbed_topk(
            rpn_scores_flat, k=k, sigma=sigma, num_samples=num_samples)
        top_k_indicators = jnp.transpose(top_k_indicators, [0, 2, 1])
        offset = 0
        weights = []
        for sh in rpn_scores_shapes:
            cur = top_k_indicators[:, :, offset:offset + sh[1] * sh[2]]
            cur = jnp.reshape(cur, [b, k, sh[1], sh[2]])
            weights.append(cur)
            offset += sh[1] * sh[2]
        chex.assert_equal(offset, top_k_indicators.shape[-1])

        part_imgs = weighted_anchor_aggregator(x, weights)
        chex.assert_shape(part_imgs, (b * k, 224, 224, c))
        stats["part_imgs"] = jnp.reshape(part_imgs, [b, k * 224, 224, c])

        part_features = feature_extractor(part_imgs)
        part_features = jnp.mean(part_features,
                                 axis=[1, 2])  # GAP the spatial dims

        part_features = nn.dropout(  # features from parts
            jnp.reshape(part_features, [b * k, 2048]),
            0.5,
            deterministic=not train,
            rng=nn.make_rng())
        features = nn.dropout(  # features from whole image
            jnp.reshape(jnp.mean(rpn_feature, axis=[1, 2]), [b, -1]),
            0.5,
            deterministic=not train,
            rng=nn.make_rng())

        # Mean pool all part features, add it to features and predict logits.
        concat_out = jnp.mean(jnp.reshape(part_features, [b, k, 2048]),
                              axis=1) + features
        concat_logits = nn.Dense(concat_out, num_classes)
        raw_logits = nn.Dense(features, num_classes)
        part_logits = jnp.reshape(nn.Dense(part_features, num_classes),
                                  [b, k, -1])

        all_logits = {
            "raw_logits": raw_logits,
            "concat_logits": concat_logits,
            "part_logits": part_logits,
        }
        # add entropy into it for entropy regularization.
        stats["rpn_scores_entropy"] = jax.scipy.special.entr(
            jax.nn.softmax(stats["raw_scores"])).sum(axis=1).mean(axis=0)
        return all_logits, stats
Example #18
 def init_param_state(self, param):
   # TODO(basv): do we want to init momentum randomly?
   return _SymEulerSGMCMCParamState(
       jax.random.normal(nn.make_rng(), param.shape, param.dtype),
       jnp.ones_like(param))
Example #19
    def apply(self,
              x,
              *,
              patch_size,
              k,
              downscale,
              scorer_has_se,
              normalization_str="identity",
              selection_method,
              selection_method_kwargs=None,
              selection_method_inference=None,
              patch_dropout=0.,
              hard_topk_probability=0.,
              random_patch_probability=0.,
              use_iterative_extraction,
              append_position_to_input,
              feature_network,
              aggregation_method,
              aggregation_method_kwargs=None,
              train):
        """Process a high resolution image by selecting a subset of useful patches.

    This model processes the input as follows:
    1. Compute scores per patch on a downscaled version of the input.
    2. Select "important" patches using sampling or top-k methods.
    3. Extract the patches from the high-resolution image.
    4. Compute representation vector for each patch with a feature network.
    5. Aggregate the patch representation to obtain an image representation.

    Args:
      x: Input tensor of shape (batch, height, width, channels).
      patch_size: Size of the (squared) patches to extract.
      k: Number of patches to extract per image.
      downscale: Downscale multiplier for the input of the scorer network.
      scorer_has_se: Whether scorer network has Squeeze-excite layers.
      normalization_str: String specifying the normalization of the scores.
      selection_method: Method that selects which patches should be extracted,
        based on their scores. Either returns indices (hard selection) or
        indicators vectors (which could yield interpolated patches).
      selection_method_kwargs: Keyword args for the selection_method.
      selection_method_inference: Selection method used at inference.
      patch_dropout: Probability to replace a patch by 0 values.
      hard_topk_probability: Probability to use the true topk on the scores to
        select the patches. This operation has no gradient so scorer's weights
        won't be trained.
      random_patch_probability: Probability to replace each patch by a random
        patch in the image during training.
      use_iterative_extraction: If True, uses a for loop instead of patch
        indexing for memory efficiency.
      append_position_to_input: Append normalized (height, width) position to
        the channels of the input.
      feature_network: Network to be applied on each patch individually to
        obtain patch representation vectors.
      aggregation_method: Method to aggregate the representations of the k
        patches of each image to obtain the image representation.
      aggregation_method_kwargs: Keyword arguments for aggregation_method.
      train: If the model is being trained. Disable dropout otherwise.

    Returns:
      A representation vector for each image in the batch.
    """
        selection_method = SelectionMethod(selection_method)
        aggregation_method = AggregationMethod(aggregation_method)
        if selection_method_inference:
            selection_method_inference = SelectionMethod(
                selection_method_inference)

        selection_method_kwargs = selection_method_kwargs or {}
        aggregation_method_kwargs = aggregation_method_kwargs or {}

        stats = {}

        # Compute new dimension of the scoring image.
        b, h, w, c = x.shape
        scoring_shape = (b, h // downscale, w // downscale, c)

        # === Compute the scores with a small CNN.
        if selection_method == SelectionMethod.RANDOM:
            scores_h, scores_w = Scorer.compute_output_size(
                h // downscale, w // downscale)
            num_patches = scores_h * scores_w
        else:
            # Downscale input to run scorer on.
            scoring_x = jax.image.resize(x, scoring_shape, method="bilinear")
            scores = Scorer(scoring_x,
                            use_squeeze_excite=scorer_has_se,
                            name="scorer")
            flatten_scores = einops.rearrange(scores, "b h w -> b (h w)")
            num_patches = flatten_scores.shape[-1]
            scores_h, scores_w = scores.shape[1:3]

            # Compute entropy before normalization
            prob_scores = jax.nn.softmax(flatten_scores)
            stats["entropy_before_normalization"] = jax.scipy.special.entr(
                prob_scores).sum(axis=1).mean(axis=0)

            # Normalize the flatten scores
            normalization_fn = create_normalization_fn(normalization_str)
            flatten_scores = normalization_fn(flatten_scores)
            scores = flatten_scores.reshape(scores.shape)
            stats["scores"] = scores[Ellipsis, None]

        # Concatenate height and width position to the input channels.
        if append_position_to_input:
            coords = utils.create_grid([h, w], value_range=(0., 1.))
            x = jnp.concatenate(
                [x, coords[jnp.newaxis, Ellipsis].repeat(b, axis=0)], axis=-1)
            c += 2

        # Overwrite the selection method at inference
        if selection_method_inference and not train:
            selection_method = selection_method_inference

        # === Patch selection

        # Select the patches by sampling or top-k. Some methods return the indices
        # of the selected patches; other methods return indicator vectors.
        extract_by_indices = selection_method in [
            SelectionMethod.HARD_TOPK, SelectionMethod.RANDOM
        ]
        if selection_method is SelectionMethod.SINKHORN_TOPK:
            indicators = select_patches_sinkhorn_topk(
                flatten_scores, k=k, **selection_method_kwargs)
        elif selection_method is SelectionMethod.PERTURBED_TOPK:
            sigma = selection_method_kwargs["sigma"]
            num_samples = selection_method_kwargs["num_samples"]
            sigma *= self.state("sigma_mutiplier",
                                shape=(),
                                initializer=nn.initializers.ones).value
            stats["sigma"] = sigma
            indicators = select_patches_perturbed_topk(flatten_scores,
                                                       k=k,
                                                       sigma=sigma,
                                                       num_samples=num_samples)
        elif selection_method is SelectionMethod.HARD_TOPK:
            indices = select_patches_hard_topk(flatten_scores, k=k)
        elif selection_method is SelectionMethod.RANDOM:
            batch_random_indices_fn = jax.vmap(
                functools.partial(jax.random.choice,
                                  a=num_patches,
                                  shape=(k, ),
                                  replace=False))
            indices = batch_random_indices_fn(
                jax.random.split(nn.make_rng(), b))

        # Compute scores entropy for regularization
        if selection_method not in [SelectionMethod.RANDOM]:
            prob_scores = flatten_scores
            # Normalize the scores if not already done.
            if "softmax" not in normalization_str:
                prob_scores = jax.nn.softmax(prob_scores)
            stats["entropy"] = jax.scipy.special.entr(prob_scores).sum(
                axis=1).mean(axis=0)

        # Randomly use hard topk at training.
        if (train and hard_topk_probability > 0 and selection_method
                not in [SelectionMethod.HARD_TOPK, SelectionMethod.RANDOM]):
            true_indices = select_patches_hard_topk(flatten_scores, k=k)
            random_values = jax.random.uniform(nn.make_rng(), (b, ))
            use_hard = random_values < hard_topk_probability
            if extract_by_indices:
                indices = jnp.where(use_hard[:, None], true_indices, indices)
            else:
                true_indicators = make_indicators(true_indices, num_patches)
                indicators = jnp.where(use_hard[:, None, None],
                                       true_indicators, indicators)

        # Sample some random patches during training with random_patch_probability.
        if (train and random_patch_probability > 0
                and selection_method is not SelectionMethod.RANDOM):
            single_random_patches = functools.partial(jax.random.choice,
                                                      a=num_patches,
                                                      shape=(k, ),
                                                      replace=False)
            random_indices = jax.vmap(single_random_patches)(jax.random.split(
                nn.make_rng(), b))
            random_values = jax.random.uniform(nn.make_rng(), (b, k))
            use_random = random_values < random_patch_probability
            if extract_by_indices:
                indices = jnp.where(use_random, random_indices, indices)
            else:
                random_indicators = make_indicators(random_indices,
                                                    num_patches)
                indicators = jnp.where(use_random[:, None, :],
                                       random_indicators, indicators)

        # === Patch extraction
        if extract_by_indices:
            patches = extract_patches_from_indices(x,
                                                   indices,
                                                   patch_size=patch_size,
                                                   grid_shape=(scores_h,
                                                               scores_w))
            indicators = make_indicators(indices, num_patches)
        else:
            patches = extract_patches_from_indicators(
                x,
                indicators,
                patch_size,
                grid_shape=(scores_h, scores_w),
                iterative=use_iterative_extraction,
                patch_dropout=patch_dropout,
                train=train)

        chex.assert_shape(patches, (b, k, patch_size, patch_size, c))

        stats["extracted_patches"] = einops.rearrange(
            patches, "b k i j c -> b i (k j) c")
        # Remove position channels for plotting.
        if append_position_to_input:
            stats["extracted_patches"] = (
                stats["extracted_patches"][Ellipsis, :-2])

        # === Compute patch features
        flatten_patches = einops.rearrange(patches, "b k i j c -> (b k) i j c")
        representations = feature_network(flatten_patches, train=train)
        if representations.ndim > 2:
            collapse_axis = tuple(range(1, representations.ndim - 1))
            representations = representations.mean(axis=collapse_axis)
        representations = einops.rearrange(representations,
                                           "(b k) d -> b k d",
                                           k=k)

        stats["patch_representations"] = representations

        # === Aggregate the k patches

        # - for sampling we are forced to take an expectation
        # - for topk we have multiple options: mean, max, transformer.
        if aggregation_method is AggregationMethod.TRANSFORMER:
            patch_pos_encoding = nn.Dense(einops.rearrange(
                indicators, "b d k -> b k d"),
                                          features=representations.shape[-1])

            chex.assert_equal_shape([representations, patch_pos_encoding])
            representations += patch_pos_encoding
            representations = transformer.Transformer(
                representations,
                **aggregation_method_kwargs,
                is_training=train)

        elif aggregation_method is AggregationMethod.MEANPOOLING:
            representations = representations.mean(axis=1)
        elif aggregation_method is AggregationMethod.MAXPOOLING:
            representations = representations.max(axis=1)
        elif aggregation_method is AggregationMethod.SUM_LAYERNORM:
            representations = representations.sum(axis=1)
            representations = nn.LayerNorm(representations)

        representations = nn.Dense(representations,
                                   features=representations.shape[-1],
                                   name="classification_dense1")
        representations = nn.swish(representations)

        return representations, stats
Example #20
 def apply(self):
   return nn.make_rng()
Example #21
 def test_decode_batch(self):
     with nn.stochastic(random.PRNGKey(0)):
         model = train.create_model(nn.make_rng())
         train.decode_batch(model, 5)
Example #22
 def test_make_rng_requires_stochastic(self):
   with self.assertRaises(ValueError):
     nn.make_rng()
Example #23
def dot_product_attention(query,
                          key,
                          value,
                          dtype=jnp.float32,
                          bias=None,
                          axis=None,
                          broadcast_dropout=True,
                          dropout_rng=None,
                          dropout_rate=0.,
                          deterministic=False,
                          precision=None):
    """Computes dot-product attention given query, key, and value.

  This is the core function for applying attention based on
  https://arxiv.org/abs/1706.03762. It calculates the attention weights given
  query and key and combines the values using the attention weights. This
  function supports multi-dimensional inputs. This version is modified to
  move the softmax division after the dot product.


  Args:
    query: queries for calculating attention with shape of `[batch_size, dim1,
      dim2, ..., dimN, num_heads, mem_channels]`.
    key: keys for calculating attention with shape of `[batch_size, dim1, dim2,
      ..., dimN, num_heads, mem_channels]`.
    value: values to be used in attention with shape of `[batch_size, dim1,
      dim2,..., dimN, num_heads, value_channels]`.
    dtype: the dtype of the computation (default: float32)
    bias: bias for the attention weights. This can be used for incorporating
      autoregressive mask, padding mask, proximity bias.
    axis: axes over which the attention is applied.
    broadcast_dropout: bool: use a broadcasted dropout along batch dims.
    dropout_rng: JAX PRNGKey: to be used for dropout
    dropout_rate: dropout rate
    deterministic: bool, deterministic or not (to apply dropout)
    precision: numerical precision of the computation see `jax.lax.Precision`
      for details.

  Returns:
    Output of shape `[bs, dim1, dim2, ..., dimN, num_heads, value_channels]`.
  """
    assert key.shape[:-1] == value.shape[:-1]
    assert (query.shape[0:1] == key.shape[0:1]
            and query.shape[-1] == key.shape[-1])

    if axis is None:
        axis = tuple(range(1, key.ndim - 2))
    if not isinstance(axis, Iterable):
        axis = (axis, )
    assert key.ndim == query.ndim
    assert key.ndim == value.ndim
    for ax in axis:
        if not (query.ndim >= 3 and 1 <= ax < query.ndim - 2):
            raise ValueError('Attention axis must be between the batch '
                             'axis and the last-two axes.')
    depth = query.shape[-1]
    n = key.ndim
    # batch_dims is  <bs, <non-attention dims>, num_heads>
    batch_dims = tuple(np.delete(range(n), axis + (n - 1, )))
    # q & k -> (bs, <non-attention dims>, num_heads, <attention dims>, channels)
    qk_perm = batch_dims + axis + (n - 1, )
    key = key.transpose(qk_perm)
    query = query.transpose(qk_perm)
    # v -> (bs, <non-attention dims>, num_heads, channels, <attention dims>)
    v_perm = batch_dims + (n - 1, ) + axis
    value = value.transpose(v_perm)

    query = query / jnp.sqrt(depth).astype(dtype)
    batch_dims_t = tuple(range(len(batch_dims)))
    attn_weights = lax.dot_general(query,
                                   key, (((n - 1, ), (n - 1, )),
                                         (batch_dims_t, batch_dims_t)),
                                   precision=precision)

    # apply attention bias: masking, dropout, proximity bias, etc.
    if bias is not None:
        attn_weights = attn_weights + bias

    # normalize the attention weights
    norm_dims = tuple(range(attn_weights.ndim - len(axis), attn_weights.ndim))
    decoding = attn_weights.shape[-2] != 256
    if decoding:
        attn_weights = lax.exp(attn_weights - jax.scipy.special.logsumexp(
            attn_weights, axis=norm_dims, keepdims=True))
    else:
        # move the division by the softmax denominator to after the dot product
        attn_weights = jnp.exp(attn_weights - lax.stop_gradient(
            jnp.max(attn_weights, axis=norm_dims, keepdims=True)))
        softmax_denominator = jnp.sum(attn_weights,
                                      axis=norm_dims,
                                      keepdims=False)
    attn_weights = attn_weights.astype(dtype)

    # apply dropout
    if not deterministic and dropout_rate > 0.:
        if dropout_rng is None:
            dropout_rng = nn.make_rng()
        keep_prob = jax.lax.tie_in(attn_weights, 1.0 - dropout_rate)
        if broadcast_dropout:
            # dropout is broadcast across the batch+head+non-attention dimension
            dropout_dims = attn_weights.shape[-(2 * len(axis)):]
            dropout_shape = (tuple([1] * len(batch_dims_t)) + dropout_dims)
            keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)
        else:
            keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)
        multiplier = (keep.astype(attn_weights.dtype) /
                      jnp.asarray(keep_prob, dtype=dtype))
        attn_weights = attn_weights * multiplier

    # compute the new values given the attention weights
    wv_contracting_dims = (norm_dims, range(value.ndim - len(axis),
                                            value.ndim))
    y = lax.dot_general(attn_weights,
                        value,
                        (wv_contracting_dims, (batch_dims_t, batch_dims_t)),
                        precision=precision)
    if not decoding:
        # divide by the denominator of the attention softmax now, when the array is
        # O(N*H) rather than O(N^2)
        y = y / jnp.expand_dims(softmax_denominator, -1)

    # back to (bs, dim1, dim2, ..., dimN, num_heads, channels)
    perm_inv = _invert_perm(qk_perm)
    y = y.transpose(perm_inv)
    return y
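
The "move the softmax division after the dot product" variant above relies on a simple identity: dividing the weighted sum of values by the softmax denominator afterwards equals normalizing the weights first, but the division then acts on an O(N*H) array instead of the O(N^2) weight matrix. A tiny self-contained check with toy shapes (not the attention code above):

import jax.numpy as jnp

logits = jnp.array([[1.0, 2.0, 0.5]])      # (1, 3) unnormalized attention logits
values = jnp.array([[1.0], [2.0], [3.0]])  # (3, 1) values
weights = jnp.exp(logits - logits.max(axis=-1, keepdims=True))
late = (weights @ values) / weights.sum(axis=-1, keepdims=True)
early = (weights / weights.sum(axis=-1, keepdims=True)) @ values
assert jnp.allclose(late, early)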
Example #24
def self_attention(inputs,
                   variable_dictionary,
                   num_heads: int,
                   qkv_features: int = None,
                   padding_mask: List[bool] = None,
                   dropout_rate: float = 0.,
                   deterministic: bool = False,
                   precision: Precision = None,
                   kernel_init: List[float] = nn.linear.default_kernel_init,
                   bias_init: List[float] = nn.initializers.zeros,
                   dtype: jnp.dtype = jnp.float32,
                   bias: bool = True):
    """Applies Multi-head self-attention on the input data.

  Args:
    inputs: input data of shape `[bs, dim1, dim2, ..., dimN, features]`.
    variable_dictionary: Parameter dictionary.
    num_heads: number of attention heads. Features (i.e. inputs.shape[-1])
      should be divisible by the number of heads.
    qkv_features: dimension of the key, query, and value.
    padding_mask: boolean specifying tokens that are pad token.
    dropout_rate: dropout rate
    deterministic: bool, deterministic or not (to apply dropout)
    precision: numerical precision of the computation see `jax.lax.Precision`
      for details.
    kernel_init: initializer for the kernel of the Dense layers.
    bias_init: initializer for the bias of the Dense layers.
    dtype: datatype for the activations, jnp.bfloat16 or jnp.float32
    bias: bool: whether pointwise QKVO dense transforms use bias.

  Returns:
    output of shape `[bs, dim1, dim2, ..., dimN, features//num_heads]`.
  """

    features = inputs.shape[-1]
    qkv_features = qkv_features or features

    assert qkv_features % num_heads == 0, (
        'Memory dimension must be divisible by number of heads.')
    head_dim = qkv_features // num_heads
    inputs = inputs.astype(dtype)
    if FLAGS.use_einsum:
        dense_module = Dense3D
    else:
        dense_module = attention.DenseGeneral

    query = dense_module.call(variable_dictionary['query'],
                              inputs,
                              axis=-1,
                              features=(num_heads, head_dim),
                              kernel_init=kernel_init,
                              bias_init=bias_init,
                              bias=bias,
                              precision=precision,
                              dtype=dtype,
                              name='query')
    query = jnp.multiply(query, 1.0 / math.sqrt(float(head_dim)))
    key = dense_module.call(variable_dictionary['key'],
                            inputs,
                            axis=-1,
                            features=(num_heads, head_dim),
                            kernel_init=kernel_init,
                            bias_init=bias_init,
                            bias=bias,
                            precision=precision,
                            dtype=dtype,
                            name='key')
    value = dense_module.call(variable_dictionary['value'],
                              inputs,
                              axis=-1,
                              features=(num_heads, head_dim),
                              kernel_init=kernel_init,
                              bias_init=bias_init,
                              bias=bias,
                              precision=precision,
                              dtype=dtype,
                              name='value')

    assert query.dtype == dtype
    assert key.dtype == dtype
    assert value.dtype == dtype
    # get raw attention scores from dot product between key and query
    #   B = batch size (number of sequences)
    #   F = `from_tensor` sequence length
    #   T = `to_tensor` sequence length
    #   N = `num_heads`
    #   H = `head_dim` (qkv_features // num_heads)
    attention_scores = jnp.einsum('BTNH,BFNH->BNFT', key, query)
    assert attention_scores.dtype == dtype

    assert attention_scores.dtype == dtype
    # create attention masks
    if padding_mask is not None:
        assert padding_mask.dtype == bool, ('Mask should have bool type.')
        attention_mask = jnp.expand_dims(padding_mask, axis=1)
        adder = (1.0 - attention_mask) * NEG_INFINITY
        attention_scores += adder.astype(dtype)
    assert attention_scores.dtype == dtype

    attention_scores = attention_scores - lax.stop_gradient(
        jnp.max(attention_scores, axis=-1, keepdims=True))
    attention_scores = jnp.exp(attention_scores)
    attention_sum = jnp.sum(attention_scores, axis=-1, keepdims=True)

    keep_prob = 1 - dropout_rate
    if not deterministic:
        keep_mask = jax.random.bernoulli(nn.make_rng(), keep_prob,
                                         attention_scores.shape).astype(dtype)
        assert keep_mask.dtype == dtype
        attention_probs = jnp.multiply(keep_mask, attention_scores)
    else:
        attention_probs = attention_scores

    assert attention_probs.dtype == dtype

    attention_probs = jnp.einsum('BNFT,BTNH->BFNH', attention_probs, value)
    assert attention_probs.dtype == dtype
    attention_probs = attention_probs / jnp.transpose(attention_sum,
                                                      [0, 2, 1, 3])

    # split mask and scaling ops in dropout
    # move the scaling from dropout to here to save same mul ops
    # TODO(yuemmawang) automate this optimization in xla
    if not deterministic:
        scale = 1 / keep_prob
        if dtype == jnp.bfloat16:
            scale = jnp.bfloat16(scale)
        attention_probs = jnp.multiply(attention_probs, scale)
    assert attention_probs.dtype == dtype

    return attention_probs
def lsh_attention_single_head(query,
                              value,
                              n_buckets,
                              n_hashes,
                              causal_mask=True,
                              length_norm=False):
    """Applies LSH attention on a single head and a single batch.

  Args:
    query: query tensor of shape [qlength, dims].
    value: value tensor of shape [vlength, dims].
    n_buckets: integer, number of buckets.
    n_hashes: integer, number of hashes.
    causal_mask: boolean, to use causal mask or not.
    length_norm: boolean, whether to length-normalize the keys (k) or not.
  Returns:
    output tensor of shape [qlength, dims]
  """

    qdim, vdim = query.shape[-1], value.shape[-1]
    chunk_size = n_hashes * n_buckets

    seqlen = query.shape[0]

    with nn.stochastic(jax.random.PRNGKey(0)):
        rng = nn.make_rng()

    buckets = hash_vectors(query,
                           rng,
                           num_buckets=n_buckets,
                           num_hashes=n_hashes)
    # buckets should have shape (n_hashes * seqlen,)
    assert buckets.shape[-1] == n_hashes * seqlen

    total_hashes = n_hashes

    # create sort and unsort
    ticker = jax.lax.tie_in(query, jnp.arange(n_hashes * seqlen))
    buckets_and_t = seqlen * buckets + (ticker % seqlen)
    buckets_and_t = jax.lax.stop_gradient(buckets_and_t)
    # ticker = jnp.tile(jnp.reshape(ticker, [1, -1]), [batch_size, 1])
    sbuckets_and_t, sticker = jax.lax.sort_key_val(buckets_and_t,
                                                   ticker,
                                                   dimension=-1)
    _, undo_sort = jax.lax.sort_key_val(sticker, ticker, dimension=-1)
    sbuckets_and_t = jax.lax.stop_gradient(sbuckets_and_t)
    sticker = jax.lax.stop_gradient(sticker)
    undo_sort = jax.lax.stop_gradient(undo_sort)

    st = (sticker % seqlen)

    sqk = jnp.take(query, st, axis=0)
    sv = jnp.take(value, st, axis=0)

    bkv_t = jnp.reshape(st, (chunk_size, -1))
    bqk = jnp.reshape(sqk, (chunk_size, -1, qdim))
    bv = jnp.reshape(sv, (chunk_size, -1, vdim))
    bq = bqk
    bk = bqk

    if length_norm:
        bk = length_normalized(bk)

    # get previous chunks
    bk = look_one_back(bk)
    bv = look_one_back(bv)
    bkv_t = look_one_back(bkv_t)

    # compute dot product attention
    dots = jnp.einsum('hie,hje->hij', bq, bk) * (qdim**0.5)

    if causal_mask:
        # apply causal mask
        # TODO(yitay): This is not working yet
        # We don't need causal reformer for any task YET.
        pass

    dots_logsumexp = logsumexp(dots, axis=-1, keepdims=True)
    slogits = jnp.reshape(dots_logsumexp, [-1])
    dots = jnp.exp(dots - dots_logsumexp)

    x = jnp.matmul(dots, bv)
    x = jnp.reshape(x, [-1, qdim])

    # Unsort
    o = permute_via_gather(x, undo_sort, sticker, axis=0)
    logits = permute_via_sort(slogits, sticker, undo_sort, axis=0)
    logits = jnp.reshape(logits, [total_hashes, seqlen, 1])
    probs = jnp.exp(logits - logsumexp(logits, axis=0, keepdims=True))
    o = jnp.reshape(o, [n_hashes, seqlen, qdim])
    out = jnp.sum(o * probs, axis=0)
    out = jnp.reshape(out, [seqlen, qdim])

    return out
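
hash_vectors is not shown in this example; below is a hedged sketch of the Reformer-style angular LSH it presumably performs (random rotations, then bucketing by argmax over the rotated vectors and their negation). This is an assumption for illustration, not the actual implementation; num_buckets must be even here.

import jax
import jax.numpy as jnp

def hash_vectors_sketch(vecs, rng, num_buckets, num_hashes):
    # vecs: (seqlen, dim). One random rotation per hash round.
    rotations = jax.random.normal(
        rng, (vecs.shape[-1], num_hashes, num_buckets // 2))
    rotated = jnp.einsum('sd,dhb->hsb', vecs, rotations)
    rotated = jnp.concatenate([rotated, -rotated], axis=-1)  # (num_hashes, seqlen, num_buckets)
    buckets = jnp.argmax(rotated, axis=-1)                   # (num_hashes, seqlen)
    # Offset each hash round so buckets from different rounds never collide,
    # then flatten to (num_hashes * seqlen,) as the caller above expects.
    offsets = jnp.arange(num_hashes)[:, None] * num_buckets
    return jnp.reshape(buckets + offsets, (-1,))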