def align_batches(self, x, y, x_labels, y_labels):
  """Finds a hard one-to-one alignment between two mini batches.

  In the MultiEnvHungarianDomainMappingClassification, this calls the
  hungarian matching function: a pairwise cost is built from the
  representations, pairs whose labels differ are penalised by the
  `ot_label_cost` task parameter, and the cheapest assignment is solved with
  `scipy.optimize.linear_sum_assignment`.

  Args:
    x: jnp array; Batch of representations with shape '[bs, feature_size]'.
    y: jnp array; Batch of representations with shape '[bs, feature_size]'.
    x_labels: jnp array; labels of x with shape '[bs, 1]'.
    y_labels: jnp array; labels of y with shape '[bs, 1]'.

  Returns:
    Tuple of (aligned indexes of x, aligned indexes of y).
  """
  mismatch_penalty = self.task_params.get('ot_label_cost', 0.)
  pairwise_cost = domain_mapping_utils.pairwise_l2(x, y)

  # Pairs with different labels pay an extra `mismatch_penalty` so that the
  # assignment prefers same-label pairs.
  labels_match = domain_mapping_utils.pairwise_equality_1d(x_labels, y_labels)
  assignment_cost = pairwise_cost + (1 - labels_match) * mismatch_penalty

  # `linear_sum_assignment` returns the minimum-cost hard alignment as row
  # and column index arrays.
  row_indices, col_indices = scipy.optimize.linear_sum_assignment(
      assignment_cost)
  return row_indices, col_indices
def get_self_matching_matrix(batch,
                             reps,
                             mode='random',
                             label_cost=1.0,
                             l2_cost=1.0,
                             epsilon=0.1,
                             num_iters=100):
  """Aligns the examples of a batch with other examples of the same batch.

  Args:
    batch: dict; Batch of examples. The 'inputs' entry is read in 'random'
      mode and the 'label' entry is read in 'sinkhorn' mode.
    reps: jnp array; Representations of a selected layer for the batch, with
      the example dimension first. Only used in 'sinkhorn' mode, where every
      example is flattened to a vector.
    mode: str; Determines alignment method, either 'random' or 'sinkhorn'.
    label_cost: float; Weight of label cost when Sinkhorn matching is used.
    l2_cost: float; Weight of l2 cost when Sinkhorn matching is used.
    epsilon: float; Value passed as `epsilon` to the Sinkhorn dual solver.
    num_iters: int; Number of iterations passed to the Sinkhorn dual solver.

  Returns:
    Matching matrix. In 'random' mode this is a row-shuffled identity matrix
    with shape `[batch_size, batch_size]`; in 'sinkhorn' mode it is the
    rounded coupling between the batch and itself (presumably the same shape
    -- confirm against `domain_mapping_utils.round_coupling`).

  Raises:
    ValueError: If `mode` is neither 'random' nor 'sinkhorn'.
  """
  if mode not in ('random', 'sinkhorn'):
    raise ValueError(
        '%s mode for self matching alignment is not supported.' % mode)

  if mode == 'random':
    number_of_examples = batch['inputs'].shape[0]
    rng = nn.make_rng()
    # Shuffling the rows of the identity yields a random permutation matrix.
    return jax.random.permutation(rng, jnp.eye(number_of_examples))

  # mode == 'sinkhorn'
  reps = reps.reshape((reps.shape[0], -1))
  x = y = reps
  x_labels = y_labels = batch['label']

  # Solve sinkhorn in log space.
  num_x = x.shape[0]
  num_y = y.shape[0]

  # Marginal of rows (a) and columns (b)
  a = jnp.ones(shape=(num_x,), dtype=x.dtype)
  b = jnp.ones(shape=(num_y,), dtype=y.dtype)
  cost = domain_mapping_utils.pairwise_l2(x, y)
  # Inflate the diagonal by ten times the largest cost so that matching an
  # example with itself becomes expensive.
  cost += jnp.eye(num_x) * jnp.max(cost) * 10

  # Adjust cost such that representations with different labels
  # get assigned a very high cost.
  same_labels = domain_mapping_utils.pairwise_equality_1d(x_labels, y_labels)
  adjusted_cost = (1 - same_labels) * label_cost + l2_cost * cost
  _, matching, _ = domain_mapping_utils.sinkhorn_dual_solver(
      a, b, adjusted_cost, epsilon, num_iters)
  return domain_mapping_utils.round_coupling(
      matching, jnp.ones((matching.shape[0],)),
      jnp.ones((matching.shape[1],)))
def align_batches(self, x, y, x_labels, y_labels):
  """Computes a matching between two batches with the Sinkhorn algorithm.

  Builds a cost matrix from a weighted label-mismatch term, a weighted
  pairwise l2 term and uniform noise, runs the sinkhorn solver in dual (log)
  space for a finite number of iterations, and post-processes the resulting
  coupling with `domain_mapping_utils.round_coupling`. All weights and solver
  settings are read from `self.task_params`.

  Args:
    x: jnp array; Batch of representations with shape '[bs, feature_size]'.
      Extra trailing dimensions are flattened.
    y: jnp array; Batch of representations with shape '[bs, feature_size]'.
      Extra trailing dimensions are flattened.
    x_labels: jnp array; labels of x with shape '[bs, 1]'.
    y_labels: jnp array; labels of y with shape '[bs, 1]'.

  Returns:
    matching: The rounded coupling between x and y. When the
      'interpolation_mode' task parameter is 'hard' (the default), this is
      instead the output of `domain_mapping_utils.sample_best_permutation`
      applied to that coupling.
  """
  epsilon = self.task_params.get('sinkhorn_eps', 0.1)
  num_iters = self.task_params.get('sinkhorn_iters', 50)
  label_weight = self.task_params.get('ot_label_cost', 0.)
  l2_weight = self.task_params.get('ot_l2_cost', 0.)
  noise_weight = self.task_params.get('ot_noise_cost', 1.0)

  # Flatten each example to a vector. Each batch uses its OWN leading
  # dimension so that batches of different sizes are handled correctly.
  num_x = x.shape[0]
  num_y = y.shape[0]
  x = x.reshape((num_x, -1))
  y = y.reshape((num_y, -1))

  # Solve sinkhorn in log space.
  # Marginal of rows (a) and columns (b)
  a = jnp.ones(shape=(num_x,), dtype=x.dtype)
  b = jnp.ones(shape=(num_y,), dtype=y.dtype)

  # TODO(samiraabnar): Check range of l2 cost?
  cost = domain_mapping_utils.pairwise_l2(x, y)

  # Adjust cost such that representations with different labels
  # get assigned a very high cost.
  same_labels = domain_mapping_utils.pairwise_equality_1d(x_labels, y_labels)
  adjusted_cost = (1 - same_labels) * label_weight + l2_weight * cost

  # Add independent noise to every entry of the cost. The shape must be
  # given explicitly: without it a single scalar is drawn, which shifts all
  # costs equally and therefore cannot influence the matching.
  adjusted_cost += noise_weight * jax.random.uniform(
      nn.make_rng(), shape=adjusted_cost.shape, minval=0, maxval=1.0)

  _, matching, _ = domain_mapping_utils.sinkhorn_dual_solver(
      a, b, adjusted_cost, epsilon, num_iters)
  matching = domain_mapping_utils.round_coupling(matching, a, b)

  if self.task_params.get('interpolation_mode', 'hard') == 'hard':
    matching = domain_mapping_utils.sample_best_permutation(
        nn.make_rng(), coupling=matching, cost=adjusted_cost)

  return matching
def ot_loss(self, x, y, x_labels, y_labels):
  """Computes optimal transport between two batches with Sinkhorn algorithm.

  The sinkhorn solver is run in dual (log) space for a finite number of
  iterations, and the first value it returns is used as the OT cost. Solver
  settings and the label penalty are read from `self.task_params`.

  Args:
    x: jnp array; Batch of representations with shape '[bs, feature_size]'.
      Extra trailing dimensions are flattened.
    y: jnp array; Batch of representations with shape '[bs, feature_size]'.
      Extra trailing dimensions are flattened.
    x_labels: jnp array; labels of x with shape '[bs, 1]'.
    y_labels: jnp array; labels of y with shape '[bs, 1]'.

  Returns:
    ot_cost: scalar optimal transport loss.
  """
  task_params = self.task_params
  sinkhorn_eps = task_params.get('sinkhorn_eps', 0.1)
  sinkhorn_iters = task_params.get('sinkhorn_iters', 100)
  mismatch_penalty = task_params.get('ot_label_cost', 0.)

  # Flatten every example to a vector before computing pairwise costs.
  num_x, num_y = x.shape[0], y.shape[0]
  x = x.reshape((num_x, -1))
  y = y.reshape((num_y, -1))

  # Uniform marginals over rows and columns, each summing to one.
  row_marginal = jnp.ones(shape=(num_x,), dtype=x.dtype) / float(num_x)
  col_marginal = jnp.ones(shape=(num_y,), dtype=y.dtype) / float(num_y)

  pairwise_cost = domain_mapping_utils.pairwise_l2(x, y)

  # Pairs whose labels differ pay an additional `mismatch_penalty`.
  labels_match = domain_mapping_utils.pairwise_equality_1d(x_labels, y_labels)
  total_cost = pairwise_cost + (1 - labels_match) * mismatch_penalty

  transport_cost, _, _ = domain_mapping_utils.sinkhorn_dual_solver(
      row_marginal, col_marginal, total_cost, sinkhorn_eps, sinkhorn_iters)
  return transport_cost