Example #1
    def align_batches(self, x, y, x_labels, y_labels, supervised=True):
        """Computes alignment between two mini batches.

    In MultiEnvDomainMappingClassification, this calls the alignment function
    (label-based when supervised, random otherwise).

    Args:
      x: jnp array; Batch of representations with shape '[bs, feature_size]'.
      y: jnp array; Batch of representations with shape '[bs, feature_size]'.
      x_labels: jnp array; labels of x with shape '[bs, 1]'.
      y_labels: jnp array; labels of y with shape '[bs, 1]'.
      supervised: bool; If False, y_labels cannot be used and the method falls
        back to random alignment; otherwise it performs label-based alignment
        (tries to align examples that have similar labels).

    Returns:
      aligned indexes of x, aligned indexes of y.
    """
        del y
        # Get aligned example pairs.
        if supervised:
            rng = nn.make_rng()
            new_rngs = jax.random.split(rng, len(x_labels))
            aligned_pairs_idx = domain_mapping_utils.align_examples(
                new_rngs, x_labels, jnp.arange(len(x_labels)), y_labels)
        else:
            number_of_examples = len(x)
            rng = nn.make_rng()
            matching_matrix = jnp.eye(number_of_examples)
            matching_matrix = jax.random.permutation(rng, matching_matrix)

            aligned_pairs_idx = jnp.arange(len(x)), jnp.argmax(matching_matrix,
                                                               axis=-1)

        return aligned_pairs_idx
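
For intuition, the unsupervised branch above can be reproduced in isolation with plain JAX (a minimal sketch, not the library code): permuting the rows of an identity matrix gives a random one-to-one matching, and an argmax over each row recovers the partner index.

import jax
import jax.numpy as jnp

def random_alignment(rng, batch_size):
    # Permute the rows of an identity matrix to get a random one-to-one matching.
    matching_matrix = jax.random.permutation(rng, jnp.eye(batch_size))
    # Row i has a single 1 in the column that example i is matched with.
    return jnp.arange(batch_size), jnp.argmax(matching_matrix, axis=-1)

x_idx, y_idx = random_alignment(jax.random.PRNGKey(0), batch_size=8)
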
Example #2
    def align_batches(self, x, y, x_labels, y_labels):
        """Computes optimal transport between two batches with Sinkhorn algorithm.

    This calls a sinkhorn solver in dual (log) space with a finite number
    of iterations and uses the dual unregularized transport cost as the OT cost.

    Args:
      x: jnp array; Batch of representations with shape '[bs, feature_size]'.
      y: jnp array; Batch of representations with shape '[bs, feature_size]'.
      x_labels: jnp array; labels of x with shape '[bs, 1]'.
      y_labels: jnp array; labels of y with shape '[bs, 1]'.

    Returns:
      matching: The coupling (matching) matrix aligning x and y.
    """

        epsilon = self.task_params.get('sinkhorn_eps', 0.1)
        num_iters = self.task_params.get('sinkhorn_iters', 50)
        label_weight = self.task_params.get('ot_label_cost', 0.)
        l2_weight = self.task_params.get('ot_l2_cost', 0.)
        noise_weight = self.task_params.get('ot_noise_cost', 1.0)
        x = x.reshape((x.shape[0], -1))
        y = y.reshape((y.shape[0], -1))

        # Solve sinkhorn in log space.
        num_x = x.shape[0]
        num_y = y.shape[0]

        x = x.reshape((num_x, -1))
        y = y.reshape((num_y, -1))

        # Marginal of rows (a) and columns (b)
        a = jnp.ones(shape=(num_x, ), dtype=x.dtype)
        b = jnp.ones(shape=(num_y, ), dtype=y.dtype)

        # TODO(samiraabnar): Check range of l2 cost?
        cost = domain_mapping_utils.pairwise_l2(x, y)

        # Adjust cost such that representations with different labels
        # get assigned a very high cost.
        same_labels = domain_mapping_utils.pairwise_equality_1d(
            x_labels, y_labels)
        adjusted_cost = (1 - same_labels) * label_weight + l2_weight * cost

        # Add noise to the cost (a single uniform scalar added to every entry).
        adjusted_cost += noise_weight * jax.random.uniform(
            nn.make_rng(), minval=0, maxval=1.0)
        _, matching, _ = domain_mapping_utils.sinkhorn_dual_solver(
            a, b, adjusted_cost, epsilon, num_iters)
        matching = domain_mapping_utils.round_coupling(matching, a, b)
        if self.task_params.get('interpolation_mode', 'hard') == 'hard':
            matching = domain_mapping_utils.sample_best_permutation(
                nn.make_rng(), coupling=matching, cost=adjusted_cost)

        return matching
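
For reference, here is a generic log-domain Sinkhorn iteration on a pairwise squared-L2 cost with uniform marginals. This is a self-contained sketch of the underlying algorithm only; it does not reproduce the actual domain_mapping_utils.sinkhorn_dual_solver, its dual-cost output, or the rounding step.

import jax
import jax.numpy as jnp

def pairwise_sq_l2(x, y):
    # [num_x, num_y] matrix of squared Euclidean distances.
    return jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)

def sinkhorn_log(cost, epsilon=0.1, num_iters=50):
    log_a = jnp.full((cost.shape[0],), -jnp.log(cost.shape[0]))  # uniform row marginal
    log_b = jnp.full((cost.shape[1],), -jnp.log(cost.shape[1]))  # uniform column marginal
    f = jnp.zeros(cost.shape[0])
    g = jnp.zeros(cost.shape[1])
    for _ in range(num_iters):
        # Block-coordinate updates of the dual potentials in log space.
        f = epsilon * (log_a - jax.scipy.special.logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (log_b - jax.scipy.special.logsumexp((f[:, None] - cost) / epsilon, axis=0))
    # Entropic coupling implied by the dual potentials.
    return jnp.exp((f[:, None] + g[None, :] - cost) / epsilon)

x = jax.random.normal(jax.random.PRNGKey(0), (4, 16))
y = jax.random.normal(jax.random.PRNGKey(1), (4, 16))
coupling = sinkhorn_log(pairwise_sq_l2(x, y))
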
Example #3
def select_patches_perturbed_topk(flatten_scores,
                                  sigma,
                                  *,
                                  k,
                                  num_samples=1000):
    """Select patches using a differentiable top-k based on perturbation.

  Uses https://q-berthet.github.io/papers/BerBloTeb20.pdf,
  see off_the_grid.lib.ops.perturbed_topk for more info.

  Args:
    flatten_scores: The flattened scores of shape (batch, num_patches).
    sigma: Standard deviation of the noise.
    k: The number of patches to extract.
    num_samples: Number of noisy inputs used to compute the output expectation.

  Returns:
    Indicator vectors of the selected patches (batch, num_patches, k).
  """
    batch_size = flatten_scores.shape[0]

    batch_topk_fn = jax.vmap(
        functools.partial(perturbed_topk.perturbed_sorted_topk_indicators,
                          num_samples=num_samples,
                          sigma=sigma,
                          k=k))

    rng_keys = jax.random.split(nn.make_rng(), batch_size)
    indicators = batch_topk_fn(flatten_scores, rng_keys)
    topk_indicators_flatten = einops.rearrange(indicators, "b k d -> b d k")
    return topk_indicators_flatten
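
The core of the perturbed top-k trick can be sketched for a single score vector as follows (Gaussian noise and a plain argsort; the real perturbed_topk.perturbed_sorted_topk_indicators also defines custom gradients, so treat this only as an illustration).

import jax
import jax.numpy as jnp

def perturbed_topk_indicators_sketch(rng, scores, k, sigma=0.05, num_samples=1000):
    # scores: [num_patches] -> soft indicators of shape [num_patches, k].
    noise = jax.random.normal(rng, (num_samples,) + scores.shape)
    noisy_scores = scores[None, :] + sigma * noise
    # Top-k indices (sorted by decreasing score) of every noisy copy.
    topk_idx = jnp.argsort(noisy_scores, axis=-1)[:, ::-1][:, :k]
    # One-hot indicators [num_samples, k, num_patches], averaged over samples.
    onehot = jax.nn.one_hot(topk_idx, scores.shape[-1])
    return jnp.mean(onehot, axis=0).T

indicators = perturbed_topk_indicators_sketch(jax.random.PRNGKey(0), jnp.arange(6.0), k=2)
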
  def maybe_inter_env_interpolation(self, batch, env_ids, flax_model,
                                    interpolate_fn, sampled_layer, sampled_reps,
                                    selected_env_reps, train_state):
    if len(env_ids) > 1 and self.hparams.get('inter_env_interpolation', True):
      # We call the alignment method of the task class:
      aligned_pairs = self.task.get_env_aligned_pairs_idx(
          selected_env_reps, batch, env_ids)
      pair_keys, alignments = zip(*aligned_pairs.items())

      # Convert alignments, which holds the aligned index pairs, into matching matrices.
      alignments = jnp.asarray(alignments)
      num_env_pairs = alignments.shape[0]
      batch_size = alignments.shape[2]
      matching_matrix = jnp.zeros(
          shape=(num_env_pairs, batch_size, batch_size), dtype=jnp.float32)
      matching_matrix = matching_matrix.at[:, alignments[:, 0],
                                           alignments[:, 1]].set(1.0)

      # Convert pair keys to pair ids (indices in the env_ids list).
      pair_ids = [(env_ids.index(int(x[0])), env_ids.index(int(x[1])))
                  for x in pair_keys]

      # Get sampled layer activations and group them by env pair.
      paired_reps = jnp.array([
          (sampled_reps[envs[0]], sampled_reps[envs[1]]) for envs in pair_ids
      ])

      # Set alpha and beta for sampling lambda:
      beta_params = pipeline_utils.get_weight_param(self.hparams, 'inter_beta',
                                                    1.0)
      alpha_params = pipeline_utils.get_weight_param(self.hparams,
                                                     'inter_alpha', 1.0)
      beta = pipeline_utils.scheduler(train_state.global_step, beta_params)
      alpha = pipeline_utils.scheduler(train_state.global_step, alpha_params)

      # Get interpolated reps for each env pair:
      inter_reps, sample_lambdas = interpolate_fn(
          jax.random.split(nn.make_rng(), len(paired_reps[:, 0])),
          matching_matrix, paired_reps[:, 0], paired_reps[:, 1],
          self.hparams.get('num_of_lambdas_samples_for_inter_mixup',
                           1), alpha, beta, -1)

      # Get interpolated batches for each env pair:
      interpolated_batches = self.get_interpolated_batches(
          batch, inter_reps, pair_ids, sample_lambdas,
          self.hparams.get('intra_interpolation_method',
                           'plain_convex_combination'))

      if self.hparams.get('stop_grad_for_inter_mixup', True):
        interpolated_batches = jax.lax.stop_gradient(interpolated_batches)

      # Compute logits for the interpolated states:
      _, interpolated_logits, _, train_state = self.stateful_forward_pass(
          flax_model, train_state, interpolated_batches, sampled_layer)

      return (interpolated_batches, interpolated_logits, sample_lambdas,
              train_state)

    return None, None, 0, train_state
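
The index-pairs-to-matching-matrix conversion used above is easier to see for a single environment pair (a small sketch; the method itself performs a batched indexed update over all pairs at once).

import jax.numpy as jnp

def pairs_to_matching_matrix(x_idx, y_idx, batch_size):
    # m[i, j] = 1 iff example i of the first env is aligned with example j of the second.
    m = jnp.zeros((batch_size, batch_size))
    return m.at[x_idx, y_idx].set(1.0)

m = pairs_to_matching_matrix(jnp.array([0, 1, 2, 3]), jnp.array([2, 0, 3, 1]), batch_size=4)
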
  def stateless_forward_pass(self,
                             flax_model,
                             train_state,
                             batch,
                             input_key='input'):
    (all_env_logits, all_env_reps, selected_env_reps,
     _) = self.forward_pass(flax_model, train_state, batch, nn.make_rng(),
                            input_key)
    return all_env_logits, all_env_reps, selected_env_reps
Example #6
def get_self_matching_matrix(batch,
                             reps,
                             mode='random',
                             label_cost=1.0,
                             l2_cost=1.0):
    """Align examples in a batch.

  Args:
    batch: list(dict); Batch of examples (with inputs, and label keys).
    reps: list(jnp array); List of representations of a selected layer for each
      batch.
    mode: str; Determines alignment method.
    label_cost: float; Weight of label cost when Sinkhorn matching is used.
    l2_cost: float; Weight of l2 cost when Sinkhorn matching is used.

  Returns:
    Matching matrix with shape `[num_batches, batch_size, batch_size]`.
  """
    if mode == 'random':
        number_of_examples = batch['inputs'].shape[0]
        rng = nn.make_rng()
        matching_matrix = jnp.eye(number_of_examples)
        matching_matrix = jax.random.permutation(rng, matching_matrix)
    elif mode == 'sinkhorn':
        epsilon = 0.1
        num_iters = 100

        reps = reps.reshape((reps.shape[0], -1))
        x = y = reps
        x_labels = y_labels = batch['label']

        # Solve sinkhorn in log space.
        num_x = x.shape[0]
        num_y = y.shape[0]

        # Marginal of rows (a) and columns (b)
        a = jnp.ones(shape=(num_x, ), dtype=x.dtype)
        b = jnp.ones(shape=(num_y, ), dtype=y.dtype)
        cost = domain_mapping_utils.pairwise_l2(x, y)
        cost += jnp.eye(num_x) * jnp.max(cost) * 10

        # Adjust cost such that representations with different labels
        # get assigned a very high cost.
        same_labels = domain_mapping_utils.pairwise_equality_1d(
            x_labels, y_labels)

        adjusted_cost = (1 - same_labels) * label_cost + l2_cost * cost
        _, matching, _ = domain_mapping_utils.sinkhorn_dual_solver(
            a, b, adjusted_cost, epsilon, num_iters)

        matching_matrix = domain_mapping_utils.round_coupling(
            matching, jnp.ones((matching.shape[0], )),
            jnp.ones((matching.shape[1], )))
    else:
        raise ValueError(
            '%s mode for self matching alignment is not supported.' % mode)
    return matching_matrix
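
A note on the 'sinkhorn' branch above: since x and y are the same batch, a large penalty is added on the diagonal of the cost matrix so the coupling avoids matching an example with itself. A tiny sketch of that adjustment:

import jax.numpy as jnp

cost = jnp.array([[0.0, 1.0],
                  [1.0, 0.0]])
# Make self-matching (the diagonal) prohibitively expensive.
cost = cost + jnp.eye(cost.shape[0]) * jnp.max(cost) * 10
# The cheapest assignment now pairs example 0 with 1 and vice versa.
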
  def stateful_forward_pass(self,
                            flax_model,
                            train_state,
                            batch,
                            input_key='input',
                            train=True):
    (env_logits, all_env_reps, selected_env_reps,
     new_model_state) = self.forward_pass(flax_model, train_state, batch,
                                          nn.make_rng(), input_key, train)
    # Model state, e.g. batch statistics, is averaged over all environments
    # because we use vmapped_flax_module_train.
    new_model_state = jax.tree_util.tree_map(
        functools.partial(jnp.mean, axis=0), new_model_state)
    # Update the model state already, since there is going to be another forward
    # pass.
    train_state = train_state.replace(model_state=new_model_state)
    return all_env_reps, env_logits, selected_env_reps, train_state
  def maybe_intra_env_interpolation(self, batch, env_ids, flax_model,
                                    interpolate_fn, sampled_layer, sampled_reps,
                                    train_state):
    if self.hparams.get('intra_env_interpolation', True):
      # Set alpha and beta for sampling lambda:
      beta_params = pipeline_utils.get_weight_param(self.hparams, 'beta', 1.0)
      alpha_params = pipeline_utils.get_weight_param(self.hparams, 'alpha', 1.0)
      step = train_state.global_step
      beta = pipeline_utils.scheduler(step, beta_params)
      alpha = pipeline_utils.scheduler(step, alpha_params)

      # This is just a random matching (similar to the Manifold Mixup paper).
      self_aligned_matching_matrix, self_pair_ids = self.get_intra_env_matchings(
          batch, sampled_reps, env_ids)

      # Compute interpolated representations of sampled layer:
      same_env_inter_reps, sample_lambdas = interpolate_fn(
          jax.random.split(nn.make_rng(), len(sampled_reps)),
          self_aligned_matching_matrix, sampled_reps, sampled_reps,
          self.hparams.get('num_of_lambdas_samples_for_mixup',
                           1), alpha, beta, -1)

      # Get interpolated batches (interpolated inputs, labels, and weights)
      same_env_interpolated_batches = self.get_interpolated_batches(
          batch, same_env_inter_reps, self_pair_ids, sample_lambdas,
          self.hparams.get('intra_interpolation_method',
                           'plain_convex_combination'))

      if self.hparams.get('stop_grad_for_intra_mixup', True):
        same_env_interpolated_batches = jax.lax.stop_gradient(
            same_env_interpolated_batches)

      # Compute logits for the interpolated states:
      (_, same_env_interpolated_logits, _,
       train_state) = self.stateful_forward_pass(flax_model, train_state,
                                                 same_env_interpolated_batches,
                                                 sampled_layer)

      return (same_env_interpolated_batches, same_env_interpolated_logits,
              sample_lambdas, train_state)

    return None, None, 0, train_state
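
The interpolation itself is a manifold-mixup-style convex combination with Beta-distributed coefficients. The sketch below shows the idea for the 'plain_convex_combination' case; interpolate_fn and its exact signature are not reproduced here.

import jax
import jax.numpy as jnp

def convex_mixup(rng, reps_a, reps_b, alpha=1.0, beta=1.0):
    # One lambda per example, sampled from Beta(alpha, beta).
    lam = jax.random.beta(rng, alpha, beta, shape=(reps_a.shape[0],))
    lam_b = lam.reshape((-1,) + (1,) * (reps_a.ndim - 1))
    return lam_b * reps_a + (1.0 - lam_b) * reps_b, lam

reps = jax.random.normal(jax.random.PRNGKey(0), (8, 32))
mixed, sample_lambdas = convex_mixup(jax.random.PRNGKey(1), reps, reps[::-1])
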
Example #9
    def setup_transformers(self, hidden_reps_dim):
        """Sets up linear transformers for the auxiliary loss.

    Args:
      hidden_reps_dim: int; Dimensionality of the representational space (size
        of the representations used for computing the domain mapping loss).
    """
        transformer_class = self.get_transformer_module(hidden_reps_dim)
        self.state_transformers = {}
        env_keys = list(map(int, self.dataset.splits.train.keys()))
        # Get list of all possible environment pairs (this includes
        # different permutations).
        env_pairs = list(itertools.permutations(env_keys, 2))

        rng = nn.make_rng()
        for env_pair in env_pairs:
            rng, params_rng = jax.random.split(rng)
            _, init_params = transformer_class.init_by_shape(
                params_rng, [((1, hidden_reps_dim), jnp.float32)])
            self.state_transformers[env_pair] = nn.Model(
                transformer_class, init_params)
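
The pairing and RNG bookkeeping in setup_transformers, stripped of the Flax module initialization, amounts to this (a plain Python/JAX sketch with made-up environment keys):

import itertools
import jax

env_keys = [0, 1, 2]
# Ordered pairs, i.e. both (a, b) and (b, a).
env_pairs = list(itertools.permutations(env_keys, 2))

rng = jax.random.PRNGKey(0)
pair_rngs = {}
for env_pair in env_pairs:
    rng, params_rng = jax.random.split(rng)
    pair_rngs[env_pair] = params_rng  # one init key per (source, target) pair
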
    def maybe_reset_train_state(self):
        optimizer = jax_utils.unreplicate(self.train_state.optimizer)

        if self.hparams.get('reinitilize_params_at_each_step', False):
            # Read the module definition before discarding the old parameters.
            module = optimizer.target.module
            del optimizer.target
            (flax_model, _, _) = pipeline_utils.create_flax_module(
                module, self.task.dataset.meta_data['input_shape'],
                self.hparams, nn.make_rng(),
                self.task.dataset.meta_data.get('input_dtype', jnp.float32))
        else:
            flax_model = optimizer.target

        # Reset optimizer
        if self.hparams.get('reinitialize_optimizer_at_each_step', False):
            optimizer = optimizers.get_optimizer(
                self.hparams).create(flax_model)
        else:
            optimizer = optimizer.replace(target=flax_model)

        optimizer = jax_utils.replicate(optimizer)
        self.train_state = self.train_state.replace(optimizer=optimizer)
    def apply_param_gradient(self, step, hyper_params, param, state, grad):
        del step
        assert hyper_params.learning_rate is not None, "no learning rate provided."
        if hyper_params.weight_decay != 0:
            raise NotImplementedError("Weight decay not supported")

        noise = jax.random.normal(key=nn.make_rng(),
                                  shape=param.shape,
                                  dtype=param.dtype)

        momentum = state.momentum
        h = hyper_params.step_size
        gamma = hyper_params.friction
        t = hyper_params.temperature
        n = hyper_params.train_size

        new_momentum = ((1 - h * gamma) * momentum - h * n * grad +
                        jnp.sqrt(2 * gamma * h * t) *
                        jnp.sqrt(state.preconditioner) * noise)

        new_param = param + h * (1. / state.preconditioner) * new_momentum
        new_state = _SymEulerSGMCMCParamState(new_momentum,
                                              state.preconditioner)
        return new_param, new_state
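
The update above is a symplectic-Euler-style SG-MCMC step (the state class name suggests underdamped Langevin dynamics). The same rule written as a stand-alone function; the hyper-parameter values below are arbitrary and only the update equations mirror the code above.

import jax
import jax.numpy as jnp

def sym_euler_sgmcmc_step(rng, param, momentum, grad, preconditioner,
                          h=1e-4, gamma=1.0, temperature=1.0, train_size=50000):
    noise = jax.random.normal(rng, param.shape, param.dtype)
    new_momentum = ((1 - h * gamma) * momentum - h * train_size * grad
                    + jnp.sqrt(2 * gamma * h * temperature)
                    * jnp.sqrt(preconditioner) * noise)
    new_param = param + h * (1.0 / preconditioner) * new_momentum
    return new_param, new_momentum

param, momentum = sym_euler_sgmcmc_step(
    jax.random.PRNGKey(0), jnp.zeros((3,)), jnp.zeros((3,)), jnp.ones((3,)), jnp.ones((3,)))
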
Example #12
    def apply(self, x, config, num_classes, train=True):
        """Creates a model definition."""
        b, c = x.shape[0], x.shape[3]
        k = config.k
        sigma = config.ptopk_sigma
        num_samples = config.ptopk_num_samples

        sigma *= self.state("sigma_mutiplier",
                            shape=(),
                            initializer=nn.initializers.ones).value

        stats = {"x": x, "sigma": sigma}

        feature_extractor = models.ResNet50.shared(train=train,
                                                   name="ResNet_0")

        rpn_feature = feature_extractor(x)
        rpn_scores, rpn_stats = ProposalNet(jax.lax.stop_gradient(rpn_feature),
                                            communication=Communication(
                                                config.communication),
                                            train=train)
        stats.update(rpn_stats)

        # rpn_scores are a list of score images. We keep track of the structure
        # because it is used in the aggregation step later on.
        rpn_scores_shapes = [s.shape for s in rpn_scores]
        rpn_scores_flat = jnp.concatenate(
            [jnp.reshape(s, [b, -1]) for s in rpn_scores], axis=1)
        top_k_indicators = sample_patches.select_patches_perturbed_topk(
            rpn_scores_flat, k=k, sigma=sigma, num_samples=num_samples)
        top_k_indicators = jnp.transpose(top_k_indicators, [0, 2, 1])
        offset = 0
        weights = []
        for sh in rpn_scores_shapes:
            cur = top_k_indicators[:, :, offset:offset + sh[1] * sh[2]]
            cur = jnp.reshape(cur, [b, k, sh[1], sh[2]])
            weights.append(cur)
            offset += sh[1] * sh[2]
        chex.assert_equal(offset, top_k_indicators.shape[-1])

        part_imgs = weighted_anchor_aggregator(x, weights)
        chex.assert_shape(part_imgs, (b * k, 224, 224, c))
        stats["part_imgs"] = jnp.reshape(part_imgs, [b, k * 224, 224, c])

        part_features = feature_extractor(part_imgs)
        part_features = jnp.mean(part_features,
                                 axis=[1, 2])  # GAP the spatial dims

        part_features = nn.dropout(  # features from parts
            jnp.reshape(part_features, [b * k, 2048]),
            0.5,
            deterministic=not train,
            rng=nn.make_rng())
        features = nn.dropout(  # features from whole image
            jnp.reshape(jnp.mean(rpn_feature, axis=[1, 2]), [b, -1]),
            0.5,
            deterministic=not train,
            rng=nn.make_rng())

        # Mean pool all part features, add it to features and predict logits.
        concat_out = jnp.mean(jnp.reshape(part_features, [b, k, 2048]),
                              axis=1) + features
        concat_logits = nn.Dense(concat_out, num_classes)
        raw_logits = nn.Dense(features, num_classes)
        part_logits = jnp.reshape(nn.Dense(part_features, num_classes),
                                  [b, k, -1])

        all_logits = {
            "raw_logits": raw_logits,
            "concat_logits": concat_logits,
            "part_logits": part_logits,
        }
        # Add the entropy of the scores for entropy regularization.
        stats["rpn_scores_entropy"] = jax.scipy.special.entr(
            jax.nn.softmax(stats["raw_scores"])).sum(axis=1).mean(axis=0)
        return all_logits, stats
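
The entropy statistic at the end of the model is the mean (over the batch) of the per-image entropy of the softmaxed scores; the same computation in isolation (random scores used as stand-ins):

import jax
import jax.numpy as jnp

scores = jax.random.normal(jax.random.PRNGKey(0), (4, 10))  # [batch, num_scores]
probs = jax.nn.softmax(scores, axis=-1)
# entr(p) = -p * log(p) elementwise; sum over scores, mean over the batch.
entropy = jax.scipy.special.entr(probs).sum(axis=1).mean(axis=0)
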
    def maybe_gradual_interpolation(
            self, batch, unlabeled_batch, env_ids, unlabeled_env_ids,
            flax_model, interpolate_fn, sampled_layer, selected_env_reps,
            selected_unlabeled_env_reps, sampled_reps, sampled_unlabeled_reps,
            logits, unlabled_logits, train_state, teacher_train_state):

        # Compute alignment based on the selected reps.
        aligned_pairs = self.task.get_bipartite_env_aligned_pairs_idx(
            selected_env_reps, batch, env_ids, selected_unlabeled_env_reps,
            unlabeled_batch, unlabeled_env_ids)
        pair_keys, matching_matrix = zip(*aligned_pairs.items())
        matching_matrix = jnp.array(matching_matrix)

        # Convert pair keys to pair ids (indices in the env_ids list).
        pair_ids = [(env_ids.index(int(x[0])),
                     unlabeled_env_ids.index(int(x[1]))) for x in pair_keys]

        # Get sampled layer activations and group them by env pair.
        paired_reps = jnp.array([(sampled_reps[envs[0]],
                                  sampled_unlabeled_reps[envs[1]])
                                 for envs in pair_ids])

        # Set alpha and beta for sampling lambda:
        beta_params = pipeline_utils.get_weight_param(self.hparams,
                                                      'unlabeled_beta', 1.0)
        alpha_params = pipeline_utils.get_weight_param(self.hparams,
                                                       'unlabeled_alpha', 1.0)
        step = train_state.global_step
        beta = pipeline_utils.scheduler(step, beta_params)
        alpha = pipeline_utils.scheduler(step, alpha_params)
        if self.hparams.get('unlabeled_lambda_params', None):
            lambda_params = pipeline_utils.get_weight_param(
                self.hparams, 'unlabeled_lambda', .0)
            lmbda = pipeline_utils.scheduler(step, lambda_params)
        else:
            lmbda = -1
        # Get interpolated reps for each env pair:
        inter_reps, sample_lambdas = interpolate_fn(
            jax.random.split(nn.make_rng(), len(paired_reps[:, 0])),
            matching_matrix, paired_reps[:, 0], paired_reps[:, 1],
            self.hparams.get('num_of_lambda_samples_for_inter_mixup',
                             1), alpha, beta, lmbda)

        # Get interpolated batches for each env pair:
        interpolated_batches = self.get_interpolated_batches(
            batch, inter_reps, pair_ids, sample_lambdas,
            self.hparams.get('interpolation_method',
                             'plain_convex_combination'))
        if self.hparams.get('stop_gradient_for_interpolations', False):
            interpolated_batches = jax.lax.stop_gradient(interpolated_batches)

        if self.hparams.get('interpolated_labels'):
            # Get logits for the interpolated states by interpolating pseudo labels
            # on source and target.
            if self.hparams.get(
                    'interpolation_method') == 'plain_convex_combination':
                teacher_interpolated_logits = jax.vmap(
                    tensor_util.convex_interpolate)(logits, unlabled_logits,
                                                    sample_lambdas)
            else:
                teacher_interpolated_logits = logits
        else:
            # Get logits for the interpolated states from the teacher.
            teacher_interpolated_logits, _, _, _ = self.forward_pass(
                teacher_train_state.optimizer.target, teacher_train_state,
                interpolated_batches, nn.make_rng(), sampled_layer)

        # Do we want to propagate the gradients to the teacher?
        if self.hparams.get('stop_gradient_for_teacher', True):
            teacher_interpolated_logits = jax.lax.stop_gradient(
                teacher_interpolated_logits)

        for i in range(len(interpolated_batches)):
            (interpolated_batches[i]['label'],
             interpolated_batches[i]['weights']
             ) = pipeline_utils.logit_transformer(
                 logits=teacher_interpolated_logits[i],
                 temp=self.hparams.get('label_temp') or 1.0,
                 confidence_quantile_threshold=self.hparams.get(
                     'confidence_quantile_threshold', 0.1),
                 self_supervised_label_transformation=self.hparams.get(
                     'self_supervised_label_transformation', 'sharp'),
                 logit_indices=None)

        # Compute logits for the interpolated states:
        (_, interpolated_logits, _,
         train_state) = self.stateful_forward_pass(flax_model, train_state,
                                                   interpolated_batches,
                                                   sampled_layer)

        return (interpolated_batches, interpolated_logits, sample_lambdas,
                alpha, beta, train_state)
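
pipeline_utils.logit_transformer turns teacher logits into sharpened pseudo labels plus confidence weights. The helper below is a hypothetical, much-simplified stand-in for that idea (temperature sharpening and a quantile-based confidence cutoff); it is not the actual implementation.

import jax
import jax.numpy as jnp

def sharpen_pseudo_labels(logits, temp=1.0, confidence_quantile_threshold=0.1):
    # Temperature-sharpened soft labels.
    probs = jax.nn.softmax(logits / temp, axis=-1)
    # Keep only examples whose max confidence exceeds a batch quantile.
    confidence = probs.max(axis=-1)
    threshold = jnp.quantile(confidence, confidence_quantile_threshold)
    weights = (confidence >= threshold).astype(jnp.float32)
    return probs, weights

labels, weights = sharpen_pseudo_labels(
    jax.random.normal(jax.random.PRNGKey(0), (8, 10)), temp=0.5)
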
Example #14
    def apply(self,
              x,
              *,
              patch_size,
              k,
              downscale,
              scorer_has_se,
              normalization_str="identity",
              selection_method,
              selection_method_kwargs=None,
              selection_method_inference=None,
              patch_dropout=0.,
              hard_topk_probability=0.,
              random_patch_probability=0.,
              use_iterative_extraction,
              append_position_to_input,
              feature_network,
              aggregation_method,
              aggregation_method_kwargs=None,
              train):
        """Process a high resolution image by selecting a subset of useful patches.

    This model processes the input as follows:
    1. Compute scores per patch on a downscaled version of the input.
    2. Select "important" patches using sampling or top-k methods.
    3. Extract the patches from the high-resolution image.
    4. Compute representation vector for each patch with a feature network.
    5. Aggregate the patch representation to obtain an image representation.

    Args:
      x: Input tensor of shape (batch, height, width, channels).
      patch_size: Size of the (squared) patches to extract.
      k: Number of patches to extract per image.
      downscale: Downscale multiplier for the input of the scorer network.
      scorer_has_se: Whether scorer network has Squeeze-excite layers.
      normalization_str: String specifying the normalization of the scores.
      selection_method: Method that selects which patches should be extracted,
        based on their scores. Either returns indices (hard selection) or
        indicators vectors (which could yield interpolated patches).
      selection_method_kwargs: Keyword args for the selection_method.
      selection_method_inference: Selection method used at inference.
      patch_dropout: Probability to replace a patch by 0 values.
      hard_topk_probability: Probability to use the true topk on the scores to
        select the patches. This operation has no gradient so scorer's weights
        won't be trained.
      random_patch_probability: Probability to replace each patch by a random
        patch in the image during training.
      use_iterative_extraction: If True, uses a for loop instead of patch
        indexing for memory efficiency.
      append_position_to_input: Append normalized (height, width) position to
        the channels of the input.
      feature_network: Network to be applied on each patch individually to
        obtain patch representation vectors.
      aggregation_method: Method to aggregate the representations of the k
        patches of each image to obtain the image representation.
      aggregation_method_kwargs: Keyword arguments for aggregation_method.
      train: Whether the model is being trained. Dropout is disabled otherwise.

    Returns:
      A representation vector for each image in the batch.
    """
        selection_method = SelectionMethod(selection_method)
        aggregation_method = AggregationMethod(aggregation_method)
        if selection_method_inference:
            selection_method_inference = SelectionMethod(
                selection_method_inference)

        selection_method_kwargs = selection_method_kwargs or {}
        aggregation_method_kwargs = aggregation_method_kwargs or {}

        stats = {}

        # Compute new dimension of the scoring image.
        b, h, w, c = x.shape
        scoring_shape = (b, h // downscale, w // downscale, c)

        # === Compute the scores with a small CNN.
        if selection_method == SelectionMethod.RANDOM:
            scores_h, scores_w = Scorer.compute_output_size(
                h // downscale, w // downscale)
            num_patches = scores_h * scores_w
        else:
            # Downscale input to run scorer on.
            scoring_x = jax.image.resize(x, scoring_shape, method="bilinear")
            scores = Scorer(scoring_x,
                            use_squeeze_excite=scorer_has_se,
                            name="scorer")
            flatten_scores = einops.rearrange(scores, "b h w -> b (h w)")
            num_patches = flatten_scores.shape[-1]
            scores_h, scores_w = scores.shape[1:3]

            # Compute entropy before normalization
            prob_scores = jax.nn.softmax(flatten_scores)
            stats["entropy_before_normalization"] = jax.scipy.special.entr(
                prob_scores).sum(axis=1).mean(axis=0)

            # Normalize the flattened scores
            normalization_fn = create_normalization_fn(normalization_str)
            flatten_scores = normalization_fn(flatten_scores)
            scores = flatten_scores.reshape(scores.shape)
            stats["scores"] = scores[Ellipsis, None]

        # Concatenate height and width position to the input channels.
        if append_position_to_input:
            coords = utils.create_grid([h, w], value_range=(0., 1.))
            x = jnp.concatenate(
                [x, coords[jnp.newaxis, Ellipsis].repeat(b, axis=0)], axis=-1)
            c += 2

        # Overwrite the selection method at inference
        if selection_method_inference and not train:
            selection_method = selection_method_inference

        # === Patch selection

        # Select the patches by sampling or top-k. Some methods return the indices
        # of the selected patches; other methods return indicator vectors.
        extract_by_indices = selection_method in [
            SelectionMethod.HARD_TOPK, SelectionMethod.RANDOM
        ]
        if selection_method is SelectionMethod.SINKHORN_TOPK:
            indicators = select_patches_sinkhorn_topk(
                flatten_scores, k=k, **selection_method_kwargs)
        elif selection_method is SelectionMethod.PERTURBED_TOPK:
            sigma = selection_method_kwargs["sigma"]
            num_samples = selection_method_kwargs["num_samples"]
            sigma *= self.state("sigma_mutiplier",
                                shape=(),
                                initializer=nn.initializers.ones).value
            stats["sigma"] = sigma
            indicators = select_patches_perturbed_topk(flatten_scores,
                                                       k=k,
                                                       sigma=sigma,
                                                       num_samples=num_samples)
        elif selection_method is SelectionMethod.HARD_TOPK:
            indices = select_patches_hard_topk(flatten_scores, k=k)
        elif selection_method is SelectionMethod.RANDOM:
            batch_random_indices_fn = jax.vmap(
                functools.partial(jax.random.choice,
                                  a=num_patches,
                                  shape=(k, ),
                                  replace=False))
            indices = batch_random_indices_fn(
                jax.random.split(nn.make_rng(), b))

        # Compute scores entropy for regularization
        if selection_method not in [SelectionMethod.RANDOM]:
            prob_scores = flatten_scores
            # Normalize the scores if that was not already done.
            if "softmax" not in normalization_str:
                prob_scores = jax.nn.softmax(prob_scores)
            stats["entropy"] = jax.scipy.special.entr(prob_scores).sum(
                axis=1).mean(axis=0)

        # Randomly use hard topk at training.
        if (train and hard_topk_probability > 0 and selection_method
                not in [SelectionMethod.HARD_TOPK, SelectionMethod.RANDOM]):
            true_indices = select_patches_hard_topk(flatten_scores, k=k)
            random_values = jax.random.uniform(nn.make_rng(), (b, ))
            use_hard = random_values < hard_topk_probability
            if extract_by_indices:
                indices = jnp.where(use_hard[:, None], true_indices, indices)
            else:
                true_indicators = make_indicators(true_indices, num_patches)
                indicators = jnp.where(use_hard[:, None, None],
                                       true_indicators, indicators)

        # Sample some random patches during training with random_patch_probability.
        if (train and random_patch_probability > 0
                and selection_method is not SelectionMethod.RANDOM):
            single_random_patches = functools.partial(jax.random.choice,
                                                      a=num_patches,
                                                      shape=(k, ),
                                                      replace=False)
            random_indices = jax.vmap(single_random_patches)(jax.random.split(
                nn.make_rng(), b))
            random_values = jax.random.uniform(nn.make_rng(), (b, k))
            use_random = random_values < random_patch_probability
            if extract_by_indices:
                indices = jnp.where(use_random, random_indices, indices)
            else:
                random_indicators = make_indicators(random_indices,
                                                    num_patches)
                indicators = jnp.where(use_random[:, None, :],
                                       random_indicators, indicators)

        # === Patch extraction
        if extract_by_indices:
            patches = extract_patches_from_indices(x,
                                                   indices,
                                                   patch_size=patch_size,
                                                   grid_shape=(scores_h,
                                                               scores_w))
            indicators = make_indicators(indices, num_patches)
        else:
            patches = extract_patches_from_indicators(
                x,
                indicators,
                patch_size,
                grid_shape=(scores_h, scores_w),
                iterative=use_iterative_extraction,
                patch_dropout=patch_dropout,
                train=train)

        chex.assert_shape(patches, (b, k, patch_size, patch_size, c))

        stats["extracted_patches"] = einops.rearrange(
            patches, "b k i j c -> b i (k j) c")
        # Remove position channels for plotting.
        if append_position_to_input:
            stats["extracted_patches"] = (
                stats["extracted_patches"][Ellipsis, :-2])

        # === Compute patch features
        flatten_patches = einops.rearrange(patches, "b k i j c -> (b k) i j c")
        representations = feature_network(flatten_patches, train=train)
        if representations.ndim > 2:
            collapse_axis = tuple(range(1, representations.ndim - 1))
            representations = representations.mean(axis=collapse_axis)
        representations = einops.rearrange(representations,
                                           "(b k) d -> b k d",
                                           k=k)

        stats["patch_representations"] = representations

        # === Aggregate the k patches

        # - for sampling we are forced to take an expectation
        # - for topk we have multiple options: mean, max, transformer.
        if aggregation_method is AggregationMethod.TRANSFORMER:
            patch_pos_encoding = nn.Dense(einops.rearrange(
                indicators, "b d k -> b k d"),
                                          features=representations.shape[-1])

            chex.assert_equal_shape([representations, patch_pos_encoding])
            representations += patch_pos_encoding
            representations = transformer.Transformer(
                representations,
                **aggregation_method_kwargs,
                is_training=train)

        elif aggregation_method is AggregationMethod.MEANPOOLING:
            representations = representations.mean(axis=1)
        elif aggregation_method is AggregationMethod.MAXPOOLING:
            representations = representations.max(axis=1)
        elif aggregation_method is AggregationMethod.SUM_LAYERNORM:
            representations = representations.sum(axis=1)
            representations = nn.LayerNorm(representations)

        representations = nn.Dense(representations,
                                   features=representations.shape[-1],
                                   name="classification_dense1")
        representations = nn.swish(representations)

        return representations, stats
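
make_indicators (used when mixing hard top-k or random indices into the soft path) can be approximated with a transposed one-hot encoding; a sketch with assumed shapes, not the library function:

import jax
import jax.numpy as jnp

def make_indicators_sketch(indices, num_patches):
    # indices: [batch, k] integer patch ids -> indicators: [batch, num_patches, k].
    onehot = jax.nn.one_hot(indices, num_patches)  # [batch, k, num_patches]
    return jnp.transpose(onehot, (0, 2, 1))

indicators = make_indicators_sketch(jnp.array([[0, 3], [2, 1]]), num_patches=5)
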
  def init_param_state(self, param):
    # TODO(basv): do we want to init momentum randomly?
    return _SymEulerSGMCMCParamState(
        jax.random.normal(nn.make_rng(), param.shape, param.dtype),
        jnp.ones_like(param))
    def training_loss_fn(self, flax_module, train_state, batch, dropout_rng,
                         mixup_rng, sampled_layer):
        """Runs forward pass and computes loss.

    Args:
      flax_module: A flax module.
      train_state: TrainState, the state of training including the current
        global_step, model_state, rng, and optimizer.
      batch: Batches from different environments.
      dropout_rng: FLAX PRNG key.
      mixup_rng: FLAX PRNG key.
      sampled_layer: str; Name of the layer on which mixup will be applied.

    Returns:
      loss, new_model_state, and the computed logits for each batch.
    """

        with nn.stochastic(dropout_rng):
            with nn.stateful(train_state.model_state) as new_model_state:
                logits, reps, _ = flax_module(batch['inputs'],
                                              train=True,
                                              return_activations=True)

                # Get matching between examples from the mini-batch:
                matching_matrix = pipeline_utils.get_self_matching_matrix(
                    batch,
                    reps[sampled_layer],
                    mode=self.hparams.get('intra_mixup_mode', 'random'),
                    label_cost=self.hparams.get('intra_mixup_label_cost', 1.0),
                    l2_cost=self.hparams.get('intra_mixup_l2_cost', 0.001))

        beta_params = self.hparams.get('beta_schedule_params') or {
            'initial_value': 1.0,
            'mode': 'constant'
        }
        alpha_params = self.hparams.get('alpha_schedule_params') or {
            'initial_value': 1.0,
            'mode': 'constant'
        }
        step = train_state.global_step
        beta = pipeline_utils.scheduler(step, beta_params)
        alpha = pipeline_utils.scheduler(step, alpha_params)

        with nn.stochastic(mixup_rng):
            with nn.stateful(new_model_state) as new_model_state:
                new_logits, sample_lambdas = self.interpolate_and_predict(
                    nn.make_rng(), flax_module, matching_matrix, reps,
                    sampled_layer, alpha, beta)

            new_batch = copy.deepcopy(batch)

            # Compute labels for the interpolated states:
            new_batch['label'] = tensor_util.convex_interpolate(
                batch['label'], batch['label'][jnp.argmax(matching_matrix,
                                                          axis=-1)],
                sample_lambdas)

            # Compute weights for the interpolated states:
            if batch.get('weights') is not None:
                new_batch['weights'] = tensor_util.convex_interpolate(
                    batch['weights'],
                    batch['weights'][jnp.argmax(matching_matrix,
                                                axis=-1)], sample_lambdas)

        # Standard loss:
        loss = self.task.loss_function(logits, batch, flax_module.params)
        # Add the loss from interpolated states:
        loss += self.task.loss_function(new_logits, new_batch)

        return loss, (new_model_state, logits)
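
The label interpolation at the end follows the usual mixup recipe: partner indices come from the matching matrix, and labels are combined convexly with the same lambdas used for the representations. A self-contained sketch with one-hot labels (the convex_interpolate below is a stand-in for tensor_util.convex_interpolate):

import jax.numpy as jnp

def convex_interpolate(a, b, lam):
    # lam: [batch]; a, b: [batch, ...]. Broadcast lambda over trailing dims.
    lam = lam.reshape((-1,) + (1,) * (a.ndim - 1))
    return lam * a + (1.0 - lam) * b

labels = jnp.eye(4)                  # one-hot labels for a batch of 4
matching_matrix = jnp.eye(4)[::-1]   # match example i with example 3 - i
partner = jnp.argmax(matching_matrix, axis=-1)
sample_lambdas = jnp.array([0.2, 0.5, 0.7, 0.9])
mixed_labels = convex_interpolate(labels, labels[partner], sample_lambdas)
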