Esempio n. 1
0
    def align_batches(self, x, y, x_labels, y_labels):
        """Computes a hard one-to-one alignment between two mini batches.

        In MultiEnvHungarianDomainMappingClassification this solves a linear
        assignment (Hungarian-style matching) problem with
        `scipy.optimize.linear_sum_assignment`.

        Args:
          x: jnp array; Batch of representations with shape '[bs, feature_size]'.
          y: jnp array; Batch of representations with shape '[bs, feature_size]'.
          x_labels: jnp array; labels of x with shape '[bs, 1]'.
          y_labels: jnp array; labels of y with shape '[bs, 1]'.

        Returns:
          A pair `(x_ind, y_ind)` of index arrays: the row and column indices of
          the cheapest assignment of the label-adjusted l2 cost matrix.
        """
        mismatch_penalty = self.task_params.get('ot_label_cost', 0.)

        l2_distances = domain_mapping_utils.pairwise_l2(x, y)
        labels_match = domain_mapping_utils.pairwise_equality_1d(
            x_labels, y_labels)

        # Pairs whose labels differ pay an extra `ot_label_cost`, which pushes
        # the assignment towards same-label pairs.
        total_cost = l2_distances + (1 - labels_match) * mismatch_penalty

        # Exact (hard) minimum-cost assignment.
        rows, cols = scipy.optimize.linear_sum_assignment(total_cost)
        return rows, cols
Esempio n. 2
0
def get_self_matching_matrix(batch,
                             reps,
                             mode='random',
                             label_cost=1.0,
                             l2_cost=1.0,
                             epsilon=0.1,
                             num_iters=100):
    """Align the examples of a single batch with one another.

  Args:
    batch: dict; Batch of examples. 'random' mode reads the leading dimension
      of `batch['inputs']`; 'sinkhorn' mode reads `batch['label']`.
    reps: jnp array; Representations of a selected layer for the examples in
      `batch`, batch on the leading axis. Flattened to `[batch_size, -1]` in
      'sinkhorn' mode; unused in 'random' mode.
    mode: str; Alignment method, either 'random' or 'sinkhorn'.
    label_cost: float; Weight of label cost when Sinkhorn matching is used.
    l2_cost: float; Weight of l2 cost when Sinkhorn matching is used.
    epsilon: float; Regularization value passed to the Sinkhorn solver.
    num_iters: int; Number of iterations passed to the Sinkhorn solver.

  Returns:
    Matching matrix with shape `[batch_size, batch_size]`. In 'random' mode it
    is a random permutation matrix.

  Raises:
    ValueError: if `mode` is neither 'random' nor 'sinkhorn'.
  """
    if mode == 'random':
        number_of_examples = batch['inputs'].shape[0]
        rng = nn.make_rng()
        # Shuffling the rows of the identity gives a random permutation matrix.
        # NOTE: unlike 'sinkhorn' mode, this may map an example to itself.
        matching_matrix = jnp.eye(number_of_examples)
        matching_matrix = jax.random.permutation(rng, matching_matrix)
    elif mode == 'sinkhorn':
        reps = reps.reshape((reps.shape[0], -1))
        # The batch is aligned with itself.
        x = y = reps
        x_labels = y_labels = batch['label']

        # Solve sinkhorn in log space.
        num_x = x.shape[0]
        num_y = y.shape[0]

        # Unnormalized (all-ones) marginals of rows (a) and columns (b).
        a = jnp.ones(shape=(num_x, ), dtype=x.dtype)
        b = jnp.ones(shape=(num_y, ), dtype=y.dtype)
        cost = domain_mapping_utils.pairwise_l2(x, y)
        # Make matching an example with itself far costlier than any other pair.
        cost += jnp.eye(num_x) * jnp.max(cost) * 10

        # Adjust cost such that representations with different labels
        # get assigned a higher cost.
        same_labels = domain_mapping_utils.pairwise_equality_1d(
            x_labels, y_labels)

        adjusted_cost = (1 - same_labels) * label_cost + l2_cost * cost
        _, matching, _ = domain_mapping_utils.sinkhorn_dual_solver(
            a, b, adjusted_cost, epsilon, num_iters)

        # Round the soft coupling using all-ones marginals.
        matching_matrix = domain_mapping_utils.round_coupling(
            matching, jnp.ones((matching.shape[0], )),
            jnp.ones((matching.shape[1], )))
    else:
        raise ValueError(
            '%s mode for self matching alignment is not supported.' % mode)
    return matching_matrix
Esempio n. 3
0
    def align_batches(self, x, y, x_labels, y_labels):
        """Aligns two batches using the Sinkhorn algorithm.

        Builds a cost from a label-mismatch term, a weighted pairwise l2 term
        and per-entry random noise. It then runs a Sinkhorn solver in dual
        (log) space for a finite number of iterations and rounds the resulting
        coupling. When `task_params['interpolation_mode']` is 'hard' (the
        default), a permutation is then sampled from the rounded coupling.

        Args:
          x: jnp array; Batch of representations with shape '[bs, feature_size]'.
          y: jnp array; Batch of representations with shape '[bs, feature_size]'.
          x_labels: jnp array; labels of x with shape '[bs, 1]'.
          y_labels: jnp array; labels of y with shape '[bs, 1]'.

        Returns:
          matching: jnp array; the rounded coupling between x and y. In 'hard'
            interpolation mode, it is instead the result of
            `domain_mapping_utils.sample_best_permutation`.
        """
        epsilon = self.task_params.get('sinkhorn_eps', 0.1)
        num_iters = self.task_params.get('sinkhorn_iters', 50)
        label_weight = self.task_params.get('ot_label_cost', 0.)
        l2_weight = self.task_params.get('ot_l2_cost', 0.)
        noise_weight = self.task_params.get('ot_noise_cost', 1.0)

        # Flatten each batch using its OWN leading dimension, so `y` is never
        # reshaped with `x`'s batch size.
        num_x = x.shape[0]
        num_y = y.shape[0]
        x = x.reshape((num_x, -1))
        y = y.reshape((num_y, -1))

        # Marginal of rows (a) and columns (b).
        a = jnp.ones(shape=(num_x, ), dtype=x.dtype)
        b = jnp.ones(shape=(num_y, ), dtype=y.dtype)

        # TODO(samiraabnar): Check range of l2 cost?
        cost = domain_mapping_utils.pairwise_l2(x, y)

        # Adjust cost such that representations with different labels
        # get assigned a higher cost.
        same_labels = domain_mapping_utils.pairwise_equality_1d(
            x_labels, y_labels)
        adjusted_cost = (1 - same_labels) * label_weight + l2_weight * cost

        # Add independent noise to every cost entry. A single scalar (the
        # default shape of `jax.random.uniform`) would shift all costs
        # equally, which changes neither the Sinkhorn coupling nor the best
        # permutation, i.e. it would have no effect.
        adjusted_cost += noise_weight * jax.random.uniform(
            nn.make_rng(), shape=adjusted_cost.shape, minval=0, maxval=1.0)
        _, matching, _ = domain_mapping_utils.sinkhorn_dual_solver(
            a, b, adjusted_cost, epsilon, num_iters)
        matching = domain_mapping_utils.round_coupling(matching, a, b)
        if self.task_params.get('interpolation_mode', 'hard') == 'hard':
            matching = domain_mapping_utils.sample_best_permutation(
                nn.make_rng(), coupling=matching, cost=adjusted_cost)

        return matching
Esempio n. 4
0
    def ot_loss(self, x, y, x_labels, y_labels):
        """Returns the Sinkhorn optimal transport cost between two batches.

        Both batches are flattened and given uniform marginals. A Sinkhorn
        solver is run in dual (log) space for a finite number of iterations,
        and its first output is used as the OT cost.

        Args:
          x: jnp array; Batch of representations with shape '[bs, feature_size]'.
          y: jnp array; Batch of representations with shape '[bs, feature_size]'.
          x_labels: jnp array; labels of x with shape '[bs, 1]'.
          y_labels: jnp array; labels of y with shape '[bs, 1]'.

        Returns:
          ot_cost: scalar optimal transport loss.
        """
        params = self.task_params
        epsilon = params.get('sinkhorn_eps', 0.1)
        num_iters = params.get('sinkhorn_iters', 100)
        label_cost = params.get('ot_label_cost', 0.)

        # Flatten each batch to a [batch_size, -1] matrix.
        num_x, num_y = x.shape[0], y.shape[0]
        flat_x = x.reshape((num_x, -1))
        flat_y = y.reshape((num_y, -1))

        # Uniform row (a) and column (b) marginals, each summing to one.
        a = jnp.ones(shape=(num_x, ), dtype=flat_x.dtype) / float(num_x)
        b = jnp.ones(shape=(num_y, ), dtype=flat_y.dtype) / float(num_y)

        l2_cost = domain_mapping_utils.pairwise_l2(flat_x, flat_y)
        labels_match = domain_mapping_utils.pairwise_equality_1d(
            x_labels, y_labels)
        # Pairs with different labels pay an additive `ot_label_cost` penalty.
        total_cost = l2_cost + (1 - labels_match) * label_cost

        solution = domain_mapping_utils.sinkhorn_dual_solver(
            a, b, total_cost, epsilon, num_iters)
        return solution[0]