def align_batches(self, x, y, x_labels, y_labels, supervised=True):
  """Computes alignment between two mini-batches.

  In the MultiEnvDomainMappingClassification, this calls the random
  alignment (based on labels) function.

  Args:
    x: jnp array; Batch of representations with shape '[bs, feature_size]'.
    y: jnp array; Batch of representations with shape '[bs, feature_size]'.
    x_labels: jnp array; Labels of x with shape '[bs, 1]'.
    y_labels: jnp array; Labels of y with shape '[bs, 1]'.
    supervised: bool; If True, performs label-based alignment (tries to align
      examples that have similar labels). If False, y_labels cannot be used
      and the method falls back to random alignment.

  Returns:
    Aligned indexes of x, aligned indexes of y.
  """
  del y

  # Get aligned example pairs.
  if supervised:
    rng = nn.make_rng()
    new_rngs = jax.random.split(rng, len(x_labels))
    aligned_pairs_idx = domain_mapping_utils.align_examples(
        new_rngs, x_labels, jnp.arange(len(x_labels)), y_labels)
  else:
    number_of_examples = len(x)
    rng = nn.make_rng()
    matching_matrix = jnp.eye(number_of_examples)
    matching_matrix = jax.random.permutation(rng, matching_matrix)
    aligned_pairs_idx = (jnp.arange(len(x)),
                         jnp.argmax(matching_matrix, axis=-1))

  return aligned_pairs_idx
def align_batches(self, x, y, x_labels, y_labels):
  """Computes optimal transport between two batches with the Sinkhorn algorithm.

  This calls a Sinkhorn solver in dual (log) space with a finite number of
  iterations and uses the dual unregularized transport cost as the OT cost.

  Args:
    x: jnp array; Batch of representations with shape '[bs, feature_size]'.
    y: jnp array; Batch of representations with shape '[bs, feature_size]'.
    x_labels: jnp array; Labels of x with shape '[bs, 1]'.
    y_labels: jnp array; Labels of y with shape '[bs, 1]'.

  Returns:
    matching: jnp array; The (rounded) coupling matrix between x and y with
      shape '[bs, bs]'.
  """
  epsilon = self.task_params.get('sinkhorn_eps', 0.1)
  num_iters = self.task_params.get('sinkhorn_iters', 50)
  label_weight = self.task_params.get('ot_label_cost', 0.)
  l2_weight = self.task_params.get('ot_l2_cost', 0.)
  noise_weight = self.task_params.get('ot_noise_cost', 1.0)

  x = x.reshape((x.shape[0], -1))
  y = y.reshape((x.shape[0], -1))

  # Solve sinkhorn in log space.
  num_x = x.shape[0]
  num_y = y.shape[0]
  x = x.reshape((num_x, -1))
  y = y.reshape((num_y, -1))

  # Marginal of rows (a) and columns (b).
  a = jnp.ones(shape=(num_x,), dtype=x.dtype)
  b = jnp.ones(shape=(num_y,), dtype=y.dtype)

  # TODO(samiraabnar): Check range of l2 cost?
  cost = domain_mapping_utils.pairwise_l2(x, y)

  # Adjust cost such that representations with different labels
  # get assigned a very high cost.
  same_labels = domain_mapping_utils.pairwise_equality_1d(x_labels, y_labels)
  adjusted_cost = (1 - same_labels) * label_weight + l2_weight * cost

  # Add noise to the cost.
  adjusted_cost += noise_weight * jax.random.uniform(
      nn.make_rng(), minval=0, maxval=1.0)

  _, matching, _ = domain_mapping_utils.sinkhorn_dual_solver(
      a, b, adjusted_cost, epsilon, num_iters)
  matching = domain_mapping_utils.round_coupling(matching, a, b)
  if self.task_params.get('interpolation_mode', 'hard') == 'hard':
    matching = domain_mapping_utils.sample_best_permutation(
        nn.make_rng(), coupling=matching, cost=adjusted_cost)
  return matching
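# Illustrative sketch (not part of the library): how the adjusted transport
# cost above can be assembled with plain JAX. `pairwise_l2` and
# `pairwise_equality_1d` stand in for the domain_mapping_utils helpers; the
# repo's actual implementations may differ.
import jax.numpy as jnp


def pairwise_l2(x, y):
  # ||x_i - y_j||^2 for every pair (i, j); shape [num_x, num_y].
  return jnp.sum((x[:, None, :] - y[None, :, :])**2, axis=-1)


def pairwise_equality_1d(x_labels, y_labels):
  # 1.0 where labels match, 0.0 otherwise; shape [num_x, num_y].
  return (x_labels[:, None] == y_labels[None, :]).astype(jnp.float32)


x = jnp.array([[0., 0.], [1., 1.]])
y = jnp.array([[0., 1.], [2., 2.]])
x_labels = jnp.array([0, 1])
y_labels = jnp.array([1, 1])

label_weight, l2_weight = 10.0, 1.0
cost = pairwise_l2(x, y)
same_labels = pairwise_equality_1d(x_labels, y_labels)
# Pairs with different labels are penalized by `label_weight`.
adjusted_cost = (1 - same_labels) * label_weight + l2_weight * cost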
def select_patches_perturbed_topk(flatten_scores,
                                  sigma,
                                  *,
                                  k,
                                  num_samples=1000):
  """Select patches using a differentiable top-k based on perturbation.

  Uses https://q-berthet.github.io/papers/BerBloTeb20.pdf,
  see off_the_grid.lib.ops.perturbed_topk for more info.

  Args:
    flatten_scores: The flatten scores of shape (batch, num_patches).
    sigma: Standard deviation of the noise.
    k: The number of patches to extract.
    num_samples: Number of noisy inputs used to compute the output
      expectation.

  Returns:
    Indicator vectors of the selected patches (batch, num_patches, k).
  """
  batch_size = flatten_scores.shape[0]

  batch_topk_fn = jax.vmap(
      functools.partial(perturbed_topk.perturbed_sorted_topk_indicators,
                        num_samples=num_samples,
                        sigma=sigma,
                        k=k))

  rng_keys = jax.random.split(nn.make_rng(), batch_size)
  indicators = batch_topk_fn(flatten_scores, rng_keys)
  topk_indicators_flatten = einops.rearrange(indicators, "b k d -> b d k")
  return topk_indicators_flatten
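# Illustrative sketch (not the repo's perturbed_topk module): the forward pass
# of perturbed top-k averages hard top-k indicator vectors computed on noisy
# copies of the scores. The real implementation also defines a custom backward
# pass (Berthet et al., 2020); this sketch only shows the expectation.
import jax
import jax.numpy as jnp


def perturbed_topk_indicators_sketch(scores, rng, k, sigma, num_samples=100):
  # scores: [num_patches]. Returns soft indicators of shape [k, num_patches].
  noise = jax.random.normal(rng, (num_samples,) + scores.shape)
  noisy_scores = scores[None, :] + sigma * noise
  # Indices of the top-k scores for every noisy sample, sorted by score.
  topk_indices = jnp.argsort(-noisy_scores, axis=-1)[:, :k]
  # One-hot indicators, averaged over the noise samples.
  indicators = jax.nn.one_hot(topk_indices, scores.shape[-1])
  return indicators.mean(axis=0)


indicators = perturbed_topk_indicators_sketch(
    jnp.array([0.1, 2.0, -1.0, 0.5]), jax.random.PRNGKey(0), k=2, sigma=0.5)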
def maybe_inter_env_interpolation(self, batch, env_ids, flax_model,
                                  interpolate_fn, sampled_layer, sampled_reps,
                                  selected_env_reps, train_state):
  if len(env_ids) > 1 and self.hparams.get('inter_env_interpolation', True):
    # We call the alignment method of the task class:
    aligned_pairs = self.task.get_env_aligned_pairs_idx(
        selected_env_reps, batch, env_ids)
    pair_keys, alignments = zip(*aligned_pairs.items())

    # Convert alignments which is the array of aligned indices to match mat.
    alignments = jnp.asarray(alignments)
    num_env_pairs = alignments.shape[0]
    batch_size = alignments.shape[2]
    matching_matrix = jnp.zeros(
        shape=(num_env_pairs, batch_size, batch_size), dtype=jnp.float32)
    matching_matrix = matching_matrix.at[:, alignments[:, 0],
                                         alignments[:, 1]].set(1.0)

    # Convert pair keys to pair ids (indices in the env_ids list).
    pair_ids = [(env_ids.index(int(x[0])), env_ids.index(int(x[1])))
                for x in pair_keys]

    # Get sampled layer activations and group them similar to env pairs.
    paired_reps = jnp.array([
        (sampled_reps[envs[0]], sampled_reps[envs[1]]) for envs in pair_ids
    ])

    # Set alpha and beta for sampling lambda:
    beta_params = pipeline_utils.get_weight_param(self.hparams, 'inter_beta',
                                                  1.0)
    alpha_params = pipeline_utils.get_weight_param(self.hparams,
                                                   'inter_alpha', 1.0)
    beta = pipeline_utils.scheduler(train_state.global_step, beta_params)
    alpha = pipeline_utils.scheduler(train_state.global_step, alpha_params)

    # Get interpolated reps for each env pair:
    inter_reps, sample_lambdas = interpolate_fn(
        jax.random.split(nn.make_rng(), len(paired_reps[:, 0])),
        matching_matrix, paired_reps[:, 0], paired_reps[:, 1],
        self.hparams.get('num_of_lambdas_samples_for_inter_mixup', 1), alpha,
        beta, -1)

    # Get interpolated batches for each env pair:
    interpolated_batches = self.get_interpolated_batches(
        batch, inter_reps, pair_ids, sample_lambdas,
        self.hparams.get('intra_interpolation_method',
                         'plain_convex_combination'))
    if self.hparams.get('stop_grad_for_inter_mixup', True):
      interpolated_batches = jax.lax.stop_gradient(interpolated_batches)

    # Compute logits for the interpolated states:
    _, interpolated_logits, _, train_state = self.stateful_forward_pass(
        flax_model, train_state, interpolated_batches, sampled_layer)

    return (interpolated_batches, interpolated_logits, sample_lambdas,
            train_state)

  return None, None, 0, train_state
def stateless_forward_pass(self,
                           flax_model,
                           train_state,
                           batch,
                           input_key='input'):
  (all_env_logits, all_env_reps,
   selected_env_reps, _) = self.forward_pass(flax_model, train_state, batch,
                                             nn.make_rng(), input_key)
  return all_env_logits, all_env_reps, selected_env_reps
def get_self_matching_matrix(batch,
                             reps,
                             mode='random',
                             label_cost=1.0,
                             l2_cost=1.0):
  """Align examples in a batch.

  Args:
    batch: list(dict); Batch of examples (with inputs, and label keys).
    reps: list(jnp array); List of representations of a selected layer for
      each batch.
    mode: str; Determines alignment method.
    label_cost: float; Weight of label cost when Sinkhorn matching is used.
    l2_cost: float; Weight of l2 cost when Sinkhorn matching is used.

  Returns:
    Matching matrix with shape `[num_batches, batch_size, batch_size]`.
  """
  if mode == 'random':
    number_of_examples = batch['inputs'].shape[0]
    rng = nn.make_rng()
    matching_matrix = jnp.eye(number_of_examples)
    matching_matrix = jax.random.permutation(rng, matching_matrix)
  elif mode == 'sinkhorn':
    epsilon = 0.1
    num_iters = 100
    reps = reps.reshape((reps.shape[0], -1))
    x = y = reps
    x_labels = y_labels = batch['label']

    # Solve sinkhorn in log space.
    num_x = x.shape[0]
    num_y = y.shape[0]

    # Marginal of rows (a) and columns (b).
    a = jnp.ones(shape=(num_x,), dtype=x.dtype)
    b = jnp.ones(shape=(num_y,), dtype=y.dtype)

    cost = domain_mapping_utils.pairwise_l2(x, y)
    cost += jnp.eye(num_x) * jnp.max(cost) * 10

    # Adjust cost such that representations with different labels
    # get assigned a very high cost.
    same_labels = domain_mapping_utils.pairwise_equality_1d(x_labels, y_labels)
    adjusted_cost = (1 - same_labels) * label_cost + l2_cost * cost

    _, matching, _ = domain_mapping_utils.sinkhorn_dual_solver(
        a, b, adjusted_cost, epsilon, num_iters)
    matching_matrix = domain_mapping_utils.round_coupling(
        matching, jnp.ones((matching.shape[0],)),
        jnp.ones((matching.shape[1],)))
  else:
    raise ValueError('%s mode for self matching alignment is not supported.' %
                     mode)

  return matching_matrix
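# Illustrative sketch: in 'random' mode the matching matrix is simply a
# randomly permuted identity matrix, i.e. each example is paired with one
# other example chosen at random (manifold-mixup-style pairing).
import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
matching_matrix = jax.random.permutation(rng, jnp.eye(4))
# Each row has a single 1; argmax along the last axis gives the partner index.
partner_indices = jnp.argmax(matching_matrix, axis=-1)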
def stateful_forward_pass(self,
                          flax_model,
                          train_state,
                          batch,
                          input_key='input',
                          train=True):
  (env_logits, all_env_reps,
   selected_env_reps, new_model_state) = self.forward_pass(
       flax_model, train_state, batch, nn.make_rng(), input_key, train)

  # Model state, e.g. batch statistics, are averaged over all environments
  # because we use vmapped_flax_module_train.
  new_model_state = jax.tree_util.tree_map(
      functools.partial(jnp.mean, axis=0), new_model_state)

  # Update the model state already, since there is going to be another
  # forward pass.
  train_state = train_state.replace(model_state=new_model_state)

  return all_env_reps, env_logits, selected_env_reps, train_state
def maybe_intra_env_interpolation(self, batch, env_ids, flax_model,
                                  interpolate_fn, sampled_layer, sampled_reps,
                                  train_state):
  if self.hparams.get('intra_env_interpolation', True):
    # Set alpha and beta for sampling lambda:
    beta_params = pipeline_utils.get_weight_param(self.hparams, 'beta', 1.0)
    alpha_params = pipeline_utils.get_weight_param(self.hparams, 'alpha', 1.0)
    step = train_state.global_step
    beta = pipeline_utils.scheduler(step, beta_params)
    alpha = pipeline_utils.scheduler(step, alpha_params)

    # This is just a random matching (similar to the manifold mixup paper).
    self_aligned_matching_matrix, self_pair_ids = self.get_intra_env_matchings(
        batch, sampled_reps, env_ids)

    # Compute interpolated representations of the sampled layer:
    same_env_inter_reps, sample_lambdas = interpolate_fn(
        jax.random.split(nn.make_rng(), len(sampled_reps)),
        self_aligned_matching_matrix, sampled_reps, sampled_reps,
        self.hparams.get('num_of_lambdas_samples_for_mixup', 1), alpha, beta,
        -1)

    # Get interpolated batches (interpolated inputs, labels, and weights).
    same_env_interpolated_batches = self.get_interpolated_batches(
        batch, same_env_inter_reps, self_pair_ids, sample_lambdas,
        self.hparams.get('intra_interpolation_method',
                         'plain_convex_combination'))
    if self.hparams.get('stop_grad_for_intra_mixup', True):
      same_env_interpolated_batches = jax.lax.stop_gradient(
          same_env_interpolated_batches)

    # Compute logits for the interpolated states:
    (_, same_env_interpolated_logits, _,
     train_state) = self.stateful_forward_pass(flax_model, train_state,
                                               same_env_interpolated_batches,
                                               sampled_layer)

    return (same_env_interpolated_batches, same_env_interpolated_logits,
            sample_lambdas, train_state)

  return None, None, 0, train_state
def setup_transformers(self, hidden_reps_dim):
  """Sets up linear transformers for the auxiliary loss.

  Args:
    hidden_reps_dim: int; Dimensionality of the representational space (size
      of the representations used for computing the domain mapping loss).
  """
  transformer_class = self.get_transformer_module(hidden_reps_dim)
  self.state_transformers = {}
  env_keys = list(map(int, self.dataset.splits.train.keys()))

  # Get list of all possible environment pairs (this includes
  # different permutations).
  env_pairs = list(itertools.permutations(env_keys, 2))
  rng = nn.make_rng()
  for env_pair in env_pairs:
    rng, params_rng = jax.random.split(rng)
    _, init_params = transformer_class.init_by_shape(
        params_rng, [((1, hidden_reps_dim), jnp.float32)])
    self.state_transformers[env_pair] = nn.Model(transformer_class,
                                                 init_params)
def maybe_reset_train_state(self):
  optimizer = jax_utils.unreplicate(self.train_state.optimizer)

  if self.hparams.get('reinitilize_params_at_each_step', False):
    # Keep a handle on the module definition before dropping the old params.
    module = optimizer.target.module
    del optimizer.target
    (flax_model, _, _) = pipeline_utils.create_flax_module(
        module, self.task.dataset.meta_data['input_shape'], self.hparams,
        nn.make_rng(),
        self.task.dataset.meta_data.get('input_dtype', jnp.float32))
  else:
    flax_model = optimizer.target

  # Reset optimizer.
  if self.hparams.get('reinitialize_optimizer_at_each_step', False):
    optimizer = optimizers.get_optimizer(self.hparams).create(flax_model)
  else:
    optimizer = optimizer.replace(target=flax_model)

  optimizer = jax_utils.replicate(optimizer)
  self.train_state = self.train_state.replace(optimizer=optimizer)
def apply_param_gradient(self, step, hyper_params, param, state, grad):
  del step
  assert hyper_params.learning_rate is not None, "no learning rate provided."
  if hyper_params.weight_decay != 0:
    raise NotImplementedError("Weight decay not supported")

  noise = jax.random.normal(
      key=nn.make_rng(), shape=param.shape, dtype=param.dtype)

  momentum = state.momentum
  h = hyper_params.step_size
  gamma = hyper_params.friction
  t = hyper_params.temperature
  n = hyper_params.train_size

  new_momentum = (
      (1 - h * gamma) * momentum - h * n * grad +
      jnp.sqrt(2 * gamma * h * t) * jnp.sqrt(state.preconditioner) * noise)

  new_param = param + h * (1. / state.preconditioner) * new_momentum
  new_state = _SymEulerSGMCMCParamState(new_momentum, state.preconditioner)
  return new_param, new_state
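# Illustrative sketch (not part of the optimizer class): the update above is a
# symplectic-Euler style discretization of momentum SGMCMC. With step size h,
# friction gamma, temperature t, dataset size n and a (diagonal)
# preconditioner M, it reads roughly:
#
#   m_{k+1}     = (1 - h * gamma) * m_k - h * n * grad
#                 + sqrt(2 * gamma * h * t * M) * eps,   eps ~ N(0, I)
#   theta_{k+1} = theta_k + h * M^{-1} * m_{k+1}
#
# A minimal standalone version of one step, with hypothetical default values:
import jax
import jax.numpy as jnp


def sym_euler_sgmcmc_step(param, momentum, grad, rng, *, h=1e-4, gamma=1.0,
                          t=1.0, n=50_000, preconditioner=1.0):
  noise = jax.random.normal(rng, param.shape, param.dtype)
  new_momentum = ((1 - h * gamma) * momentum - h * n * grad +
                  jnp.sqrt(2 * gamma * h * t) * jnp.sqrt(preconditioner) *
                  noise)
  new_param = param + h * (1. / preconditioner) * new_momentum
  return new_param, new_momentum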
def apply(self, x, config, num_classes, train=True):
  """Creates a model definition."""
  b, c = x.shape[0], x.shape[3]
  k = config.k
  sigma = config.ptopk_sigma
  num_samples = config.ptopk_num_samples

  sigma *= self.state("sigma_mutiplier",
                      shape=(),
                      initializer=nn.initializers.ones).value

  stats = {"x": x, "sigma": sigma}

  feature_extractor = models.ResNet50.shared(train=train, name="ResNet_0")

  rpn_feature = feature_extractor(x)
  rpn_scores, rpn_stats = ProposalNet(
      jax.lax.stop_gradient(rpn_feature),
      communication=Communication(config.communication),
      train=train)
  stats.update(rpn_stats)

  # rpn_scores are a list of score images. We keep track of the structure
  # because it is used in the aggregation step later-on.
  rpn_scores_shapes = [s.shape for s in rpn_scores]
  rpn_scores_flat = jnp.concatenate(
      [jnp.reshape(s, [b, -1]) for s in rpn_scores], axis=1)
  top_k_indicators = sample_patches.select_patches_perturbed_topk(
      rpn_scores_flat, k=k, sigma=sigma, num_samples=num_samples)
  top_k_indicators = jnp.transpose(top_k_indicators, [0, 2, 1])
  offset = 0
  weights = []
  for sh in rpn_scores_shapes:
    cur = top_k_indicators[:, :, offset:offset + sh[1] * sh[2]]
    cur = jnp.reshape(cur, [b, k, sh[1], sh[2]])
    weights.append(cur)
    offset += sh[1] * sh[2]
  chex.assert_equal(offset, top_k_indicators.shape[-1])

  part_imgs = weighted_anchor_aggregator(x, weights)
  chex.assert_shape(part_imgs, (b * k, 224, 224, c))
  stats["part_imgs"] = jnp.reshape(part_imgs, [b, k * 224, 224, c])

  part_features = feature_extractor(part_imgs)
  part_features = jnp.mean(part_features, axis=[1, 2])  # GAP the spatial dims

  part_features = nn.dropout(  # features from parts
      jnp.reshape(part_features, [b * k, 2048]),
      0.5,
      deterministic=not train,
      rng=nn.make_rng())

  features = nn.dropout(  # features from whole image
      jnp.reshape(jnp.mean(rpn_feature, axis=[1, 2]), [b, -1]),
      0.5,
      deterministic=not train,
      rng=nn.make_rng())

  # Mean pool all part features, add it to features and predict logits.
  concat_out = jnp.mean(jnp.reshape(part_features, [b, k, 2048]),
                        axis=1) + features
  concat_logits = nn.Dense(concat_out, num_classes)
  raw_logits = nn.Dense(features, num_classes)
  part_logits = jnp.reshape(nn.Dense(part_features, num_classes), [b, k, -1])

  all_logits = {
      "raw_logits": raw_logits,
      "concat_logits": concat_logits,
      "part_logits": part_logits,
  }

  # Add entropy into it for entropy regularization.
  stats["rpn_scores_entropy"] = jax.scipy.special.entr(
      jax.nn.softmax(stats["raw_scores"])).sum(axis=1).mean(axis=0)

  return all_logits, stats
def maybe_gradual_interpolation(self, batch, unlabeled_batch, env_ids,
                                unlabeled_env_ids, flax_model, interpolate_fn,
                                sampled_layer, selected_env_reps,
                                selected_unlabeled_env_reps, sampled_reps,
                                sampled_unlabeled_reps, logits,
                                unlabled_logits, train_state,
                                teacher_train_state):
  # Compute alignment based on the selected reps.
  aligned_pairs = self.task.get_bipartite_env_aligned_pairs_idx(
      selected_env_reps, batch, env_ids, selected_unlabeled_env_reps,
      unlabeled_batch, unlabeled_env_ids)
  pair_keys, matching_matrix = zip(*aligned_pairs.items())
  matching_matrix = jnp.array(matching_matrix)

  # Convert pair keys to pair ids (indices in the env_ids list).
  pair_ids = [(env_ids.index(int(x[0])), unlabeled_env_ids.index(int(x[1])))
              for x in pair_keys]

  # Get sampled layer activations and group them similar to env pairs.
  paired_reps = jnp.array([(sampled_reps[envs[0]],
                            sampled_unlabeled_reps[envs[1]])
                           for envs in pair_ids])

  # Set alpha and beta for sampling lambda:
  beta_params = pipeline_utils.get_weight_param(self.hparams,
                                                'unlabeled_beta', 1.0)
  alpha_params = pipeline_utils.get_weight_param(self.hparams,
                                                 'unlabeled_alpha', 1.0)
  step = train_state.global_step
  beta = pipeline_utils.scheduler(step, beta_params)
  alpha = pipeline_utils.scheduler(step, alpha_params)
  if self.hparams.get('unlabeled_lambda_params', None):
    lambda_params = pipeline_utils.get_weight_param(self.hparams,
                                                    'unlabeled_lambda', .0)
    lmbda = pipeline_utils.scheduler(step, lambda_params)
  else:
    lmbda = -1

  # Get interpolated reps for each env pair:
  inter_reps, sample_lambdas = interpolate_fn(
      jax.random.split(nn.make_rng(), len(paired_reps[:, 0])),
      matching_matrix, paired_reps[:, 0], paired_reps[:, 1],
      self.hparams.get('num_of_lambda_samples_for_inter_mixup', 1), alpha,
      beta, lmbda)

  # Get interpolated batches for each env pair:
  interpolated_batches = self.get_interpolated_batches(
      batch, inter_reps, pair_ids, sample_lambdas,
      self.hparams.get('interpolation_method', 'plain_convex_combination'))

  if self.hparams.get('stop_gradient_for_interpolations', False):
    interpolated_batches = jax.lax.stop_gradient(interpolated_batches)

  if self.hparams.get('interpolated_labels'):
    # Get logits for the interpolated states by interpolating pseudo labels
    # on source and target.
    if self.hparams.get('interpolation_method') == 'plain_convex_combination':
      teacher_interpolated_logits = jax.vmap(tensor_util.convex_interpolate)(
          logits, unlabled_logits, sample_lambdas)
    else:
      teacher_interpolated_logits = logits
  else:
    # Get logits for the interpolated states from the teacher.
    teacher_interpolated_logits, _, _, _ = self.forward_pass(
        teacher_train_state.optimizer.target, teacher_train_state,
        interpolated_batches, nn.make_rng(), sampled_layer)

  # Do we want to propagate the gradients to the teacher?
  if self.hparams.get('stop_gradient_for_teacher', True):
    teacher_interpolated_logits = jax.lax.stop_gradient(
        teacher_interpolated_logits)

  for i in range(len(interpolated_batches)):
    (interpolated_batches[i]['label'],
     interpolated_batches[i]['weights']) = pipeline_utils.logit_transformer(
         logits=teacher_interpolated_logits[i],
         temp=self.hparams.get('label_temp') or 1.0,
         confidence_quantile_threshold=self.hparams.get(
             'confidence_quantile_threshold', 0.1),
         self_supervised_label_transformation=self.hparams.get(
             'self_supervised_label_transformation', 'sharp'),
         logit_indices=None)

  # Compute logits for the interpolated states:
  (_, interpolated_logits, _,
   train_state) = self.stateful_forward_pass(flax_model, train_state,
                                             interpolated_batches,
                                             sampled_layer)

  return (interpolated_batches, interpolated_logits, sample_lambdas, alpha,
          beta, train_state)
def apply(self,
          x,
          *,
          patch_size,
          k,
          downscale,
          scorer_has_se,
          normalization_str="identity",
          selection_method,
          selection_method_kwargs=None,
          selection_method_inference=None,
          patch_dropout=0.,
          hard_topk_probability=0.,
          random_patch_probability=0.,
          use_iterative_extraction,
          append_position_to_input,
          feature_network,
          aggregation_method,
          aggregation_method_kwargs=None,
          train):
  """Process a high resolution image by selecting a subset of useful patches.

  This model processes the input as follows:
  1. Compute scores per patch on a downscaled version of the input.
  2. Select "important" patches using sampling or top-k methods.
  3. Extract the patches from the high-resolution image.
  4. Compute representation vector for each patch with a feature network.
  5. Aggregate the patch representations to obtain an image representation.

  Args:
    x: Input tensor of shape (batch, height, width, channels).
    patch_size: Size of the (squared) patches to extract.
    k: Number of patches to extract per image.
    downscale: Downscale multiplier for the input of the scorer network.
    scorer_has_se: Whether scorer network has Squeeze-excite layers.
    normalization_str: String specifying the normalization of the scores.
    selection_method: Method that selects which patches should be extracted,
      based on their scores. Either returns indices (hard selection) or
      indicator vectors (which could yield interpolated patches).
    selection_method_kwargs: Keyword args for the selection_method.
    selection_method_inference: Selection method used at inference.
    patch_dropout: Probability to replace a patch by 0 values.
    hard_topk_probability: Probability to use the true topk on the scores to
      select the patches. This operation has no gradient, so the scorer's
      weights won't be trained.
    random_patch_probability: Probability to replace each patch by a random
      patch in the image during training.
    use_iterative_extraction: If True, uses a for loop instead of patch
      indexing for memory efficiency.
    append_position_to_input: Append normalized (height, width) position to
      the channels of the input.
    feature_network: Network to be applied on each patch individually to
      obtain patch representation vectors.
    aggregation_method: Method to aggregate the representations of the k
      patches of each image to obtain the image representation.
    aggregation_method_kwargs: Keyword arguments for aggregation_method.
    train: If the model is being trained. Disable dropout otherwise.

  Returns:
    A representation vector for each image in the batch.
  """
  selection_method = SelectionMethod(selection_method)
  aggregation_method = AggregationMethod(aggregation_method)
  if selection_method_inference:
    selection_method_inference = SelectionMethod(selection_method_inference)

  selection_method_kwargs = selection_method_kwargs or {}
  aggregation_method_kwargs = aggregation_method_kwargs or {}

  stats = {}

  # Compute new dimension of the scoring image.
  b, h, w, c = x.shape
  scoring_shape = (b, h // downscale, w // downscale, c)

  # === Compute the scores with a small CNN.
  if selection_method == SelectionMethod.RANDOM:
    scores_h, scores_w = Scorer.compute_output_size(h // downscale,
                                                    w // downscale)
    num_patches = scores_h * scores_w
  else:
    # Downscale input to run scorer on.
    scoring_x = jax.image.resize(x, scoring_shape, method="bilinear")
    scores = Scorer(scoring_x,
                    use_squeeze_excite=scorer_has_se,
                    name="scorer")
    flatten_scores = einops.rearrange(scores, "b h w -> b (h w)")
    num_patches = flatten_scores.shape[-1]
    scores_h, scores_w = scores.shape[1:3]

    # Compute entropy before normalization.
    prob_scores = jax.nn.softmax(flatten_scores)
    stats["entropy_before_normalization"] = jax.scipy.special.entr(
        prob_scores).sum(axis=1).mean(axis=0)

    # Normalize the flatten scores.
    normalization_fn = create_normalization_fn(normalization_str)
    flatten_scores = normalization_fn(flatten_scores)
    scores = flatten_scores.reshape(scores.shape)
    stats["scores"] = scores[Ellipsis, None]

  # Concatenate height and width position to the input channels.
  if append_position_to_input:
    coords = utils.create_grid([h, w], value_range=(0., 1.))
    x = jnp.concatenate([x, coords[jnp.newaxis, Ellipsis].repeat(b, axis=0)],
                        axis=-1)
    c += 2

  # Overwrite the selection method at inference.
  if selection_method_inference and not train:
    selection_method = selection_method_inference

  # === Patch selection
  # Select the patches by sampling or top-k. Some methods return the indices
  # of the selected patches, other methods return indicator vectors.
  extract_by_indices = selection_method in [
      SelectionMethod.HARD_TOPK, SelectionMethod.RANDOM
  ]
  if selection_method is SelectionMethod.SINKHORN_TOPK:
    indicators = select_patches_sinkhorn_topk(flatten_scores,
                                              k=k,
                                              **selection_method_kwargs)
  elif selection_method is SelectionMethod.PERTURBED_TOPK:
    sigma = selection_method_kwargs["sigma"]
    num_samples = selection_method_kwargs["num_samples"]
    sigma *= self.state("sigma_mutiplier",
                        shape=(),
                        initializer=nn.initializers.ones).value
    stats["sigma"] = sigma
    indicators = select_patches_perturbed_topk(flatten_scores,
                                               k=k,
                                               sigma=sigma,
                                               num_samples=num_samples)
  elif selection_method is SelectionMethod.HARD_TOPK:
    indices = select_patches_hard_topk(flatten_scores, k=k)
  elif selection_method is SelectionMethod.RANDOM:
    batch_random_indices_fn = jax.vmap(
        functools.partial(jax.random.choice,
                          a=num_patches,
                          shape=(k,),
                          replace=False))
    indices = batch_random_indices_fn(jax.random.split(nn.make_rng(), b))

  # Compute scores entropy for regularization.
  if selection_method not in [SelectionMethod.RANDOM]:
    prob_scores = flatten_scores
    # Normalize the scores if it is not already done.
    if "softmax" not in normalization_str:
      prob_scores = jax.nn.softmax(prob_scores)
    stats["entropy"] = jax.scipy.special.entr(prob_scores).sum(axis=1).mean(
        axis=0)

  # Randomly use hard topk at training.
  if (train and hard_topk_probability > 0 and selection_method not in
      [SelectionMethod.HARD_TOPK, SelectionMethod.RANDOM]):
    true_indices = select_patches_hard_topk(flatten_scores, k=k)
    random_values = jax.random.uniform(nn.make_rng(), (b,))
    use_hard = random_values < hard_topk_probability
    if extract_by_indices:
      indices = jnp.where(use_hard[:, None], true_indices, indices)
    else:
      true_indicators = make_indicators(true_indices, num_patches)
      indicators = jnp.where(use_hard[:, None, None], true_indicators,
                             indicators)

  # Sample some random patches during training with random_patch_probability.
  if (train and random_patch_probability > 0 and
      selection_method is not SelectionMethod.RANDOM):
    single_random_patches = functools.partial(jax.random.choice,
                                              a=num_patches,
                                              shape=(k,),
                                              replace=False)
    random_indices = jax.vmap(single_random_patches)(jax.random.split(
        nn.make_rng(), b))
    random_values = jax.random.uniform(nn.make_rng(), (b, k))
    use_random = random_values < random_patch_probability
    if extract_by_indices:
      indices = jnp.where(use_random, random_indices, indices)
    else:
      random_indicators = make_indicators(random_indices, num_patches)
      indicators = jnp.where(use_random[:, None, :], random_indicators,
                             indicators)

  # === Patch extraction
  if extract_by_indices:
    patches = extract_patches_from_indices(x,
                                           indices,
                                           patch_size=patch_size,
                                           grid_shape=(scores_h, scores_w))
    indicators = make_indicators(indices, num_patches)
  else:
    patches = extract_patches_from_indicators(
        x,
        indicators,
        patch_size,
        grid_shape=(scores_h, scores_w),
        iterative=use_iterative_extraction,
        patch_dropout=patch_dropout,
        train=train)

  chex.assert_shape(patches, (b, k, patch_size, patch_size, c))

  stats["extracted_patches"] = einops.rearrange(patches,
                                                "b k i j c -> b i (k j) c")
  # Remove position channels for plotting.
  if append_position_to_input:
    stats["extracted_patches"] = (stats["extracted_patches"][Ellipsis, :-2])

  # === Compute patch features
  flatten_patches = einops.rearrange(patches, "b k i j c -> (b k) i j c")
  representations = feature_network(flatten_patches, train=train)
  if representations.ndim > 2:
    collapse_axis = tuple(range(1, representations.ndim - 1))
    representations = representations.mean(axis=collapse_axis)
  representations = einops.rearrange(representations, "(b k) d -> b k d", k=k)

  stats["patch_representations"] = representations

  # === Aggregate the k patches
  # - for sampling we are forced to take an expectation
  # - for topk we have multiple options: mean, max, transformer.
  if aggregation_method is AggregationMethod.TRANSFORMER:
    patch_pos_encoding = nn.Dense(einops.rearrange(indicators,
                                                   "b d k -> b k d"),
                                  features=representations.shape[-1])

    chex.assert_equal_shape([representations, patch_pos_encoding])
    representations += patch_pos_encoding
    representations = transformer.Transformer(representations,
                                              **aggregation_method_kwargs,
                                              is_training=train)
  elif aggregation_method is AggregationMethod.MEANPOOLING:
    representations = representations.mean(axis=1)
  elif aggregation_method is AggregationMethod.MAXPOOLING:
    representations = representations.max(axis=1)
  elif aggregation_method is AggregationMethod.SUM_LAYERNORM:
    representations = representations.sum(axis=1)
    representations = nn.LayerNorm(representations)

  representations = nn.Dense(representations,
                             features=representations.shape[-1],
                             name="classification_dense1")
  representations = nn.swish(representations)

  return representations, stats
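# Illustrative sketch (not the repo's make_indicators): indices produced by
# hard top-k or random selection can be converted into the same indicator
# format that perturbed top-k produces, i.e. one one-hot vector of length
# num_patches per selected patch. Extraction from indicators is then a
# weighted sum over patches, which is what allows interpolated ("soft")
# patches.
import jax
import jax.numpy as jnp

indices = jnp.array([[2, 0], [1, 3]])  # (batch=2, k=2) selected patch ids.
num_patches = 4
indicators = jax.nn.one_hot(indices, num_patches)  # (batch, k, num_patches)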
def init_param_state(self, param):
  # TODO(basv): do we want to init momentum randomly?
  return _SymEulerSGMCMCParamState(
      jax.random.normal(nn.make_rng(), param.shape, param.dtype),
      jnp.ones_like(param))
def training_loss_fn(self, flax_module, train_state, batch, dropout_rng,
                     mixup_rng, sampled_layer):
  """Runs forward pass and computes loss.

  Args:
    flax_module: A flax module.
    train_state: TrainState; The state of training including the current
      global_step, model_state, rng, and optimizer.
    batch: Batches from different environments.
    dropout_rng: FLAX PRNG key.
    mixup_rng: FLAX PRNG key.
    sampled_layer: str; Name of the layer on which mixup will be applied.

  Returns:
    loss, new_module_state and computed logits for each batch.
  """
  with nn.stochastic(dropout_rng):
    with nn.stateful(train_state.model_state) as new_model_state:
      logits, reps, _ = flax_module(batch['inputs'],
                                    train=True,
                                    return_activations=True)

    # Get matching between examples from the mini batch:
    matching_matrix = pipeline_utils.get_self_matching_matrix(
        batch,
        reps[sampled_layer],
        mode=self.hparams.get('intra_mixup_mode', 'random'),
        label_cost=self.hparams.get('intra_mixup_label_cost', 1.0),
        l2_cost=self.hparams.get('intra_mixup_l2_cost', 0.001))

  beta_params = self.hparams.get('beta_schedule_params') or {
      'initial_value': 1.0,
      'mode': 'constant'
  }
  alpha_params = self.hparams.get('alpha_schedule_params') or {
      'initial_value': 1.0,
      'mode': 'constant'
  }
  step = train_state.global_step
  beta = pipeline_utils.scheduler(step, beta_params)
  alpha = pipeline_utils.scheduler(step, alpha_params)

  with nn.stochastic(mixup_rng):
    with nn.stateful(new_model_state) as new_model_state:
      new_logits, sample_lambdas = self.interpolate_and_predict(
          nn.make_rng(), flax_module, matching_matrix, reps, sampled_layer,
          alpha, beta)
      new_batch = copy.deepcopy(batch)
      # Compute labels for the interpolated states:
      new_batch['label'] = tensor_util.convex_interpolate(
          batch['label'],
          batch['label'][jnp.argmax(matching_matrix, axis=-1)],
          sample_lambdas)
      # Compute weights for the interpolated states:
      if batch.get('weights') is not None:
        new_batch['weights'] = tensor_util.convex_interpolate(
            batch['weights'],
            batch['weights'][jnp.argmax(matching_matrix, axis=-1)],
            sample_lambdas)

  # Standard loss:
  loss = self.task.loss_function(logits, batch, flax_module.params)
  # Add the loss from interpolated states:
  loss += self.task.loss_function(new_logits, new_batch)

  return loss, (new_model_state, logits)
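# Illustrative sketch (hypothetical helper, not tensor_util's API): the label
# interpolation above is ordinary mixup on the targets, pairing each example
# with its partner taken from the matching matrix.
import jax.numpy as jnp


def convex_interpolate_sketch(a, b, lambdas):
  # a, b: [batch, ...]; lambdas: [batch] mixing coefficients in [0, 1].
  lambdas = lambdas.reshape((-1,) + (1,) * (a.ndim - 1))
  return lambdas * a + (1.0 - lambdas) * b


labels = jnp.array([[1., 0.], [0., 1.]])
matching_matrix = jnp.array([[0., 1.], [1., 0.]])
partner = jnp.argmax(matching_matrix, axis=-1)
mixed_labels = convex_interpolate_sketch(labels, labels[partner],
                                         jnp.array([0.7, 0.3]))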