Example 1
    def _get_candidates(self, originals: ep.Tensor, best_advs: ep.Tensor) -> ep.Tensor:
        """
        Find the lowest epsilon to misclassified x following the direction: q of class 1 / q + eps*direction of class 0
        """
        epsilons = ep.zeros(originals, len(originals))
        direction_2 = ep.zeros_like(originals)
        while (epsilons == 0).any():
            # if epsilon == 0, we are still searching for a good direction
            direction_2 = ep.where(
                atleast_kd(epsilons == 0, direction_2.ndim),
                self._basis.get_vector(self._directions_ortho),
                direction_2
            )
            
            for i, eps_i in enumerate(epsilons):
                if eps_i == 0:
                    self._directions_ortho[i] = ep.concatenate((self._directions_ortho[i], direction_2[i].expand_dims(0)), axis=0)
                    if len(self._directions_ortho[i]) > self.n_ortho + 1:
                        self._directions_ortho[i] = ep.concatenate((self._directions_ortho[i][:1], self._directions_ortho[i][self.n_ortho:]))
                        
            function_evolution = self._get_evolution_function(originals, best_advs, direction_2)
            new_epsilons = self._get_best_theta(function_evolution, epsilons)

            self.theta_max = ep.where(new_epsilons == 0, self.theta_max * self.rho, self.theta_max)
            self.theta_max = ep.where((new_epsilons != 0) * (epsilons == 0), self.theta_max / self.rho, self.theta_max)
            epsilons = new_epsilons

        function_evolution = self._get_evolution_function(originals, best_advs, direction_2)
        if self.with_alpha_line_search:
            epsilons = self._binary_search_on_alpha(function_evolution, epsilons)

        epsilons = epsilons.expand_dims(0)
        if self.with_interpolation:
            epsilons = ep.concatenate((epsilons, (epsilons[0] / 2).expand_dims(0)), axis=0)

        candidates = ep.concatenate([function_evolution(eps).expand_dims(0) for eps in epsilons], axis=0)

        if self.with_interpolation:
            d = self.distance(best_advs, originals)
            delta = self.distance(self._binary_search(originals, candidates[1], boost=True), originals)
            theta_star = epsilons[0]

            num = theta_star * (4 * delta - d * (self._cos(theta_star.raw) + 3))
            den = 4 * (2 * delta - d * (self._cos(theta_star.raw) + 1))

            theta_hat = num / den
            q_interp = function_evolution(theta_hat)
            if self.with_distance_line_search:
                q_interp = self._binary_search(originals, q_interp, boost=True)
            candidates = ep.concatenate((candidates, q_interp.expand_dims(0)), axis=0)

        return candidates
Example 2
    def run(
        self,
        model: Model,
        inputs: T,
        criterion: Union[Criterion, T],
        *,
        early_stop: Optional[float] = None,
        starting_points: Optional[ep.Tensor] = None,
        **kwargs: Any,
    ) -> T:
        originals, restore_type = ep.astensor_(inputs)

        self._nqueries = {i: 0 for i in range(len(originals))}
        self._set_cos_sin_function(originals)
        self.theta_max = ep.ones(originals, len(originals)) * self._theta_max
        criterion = get_criterion(criterion)
        self._criterion_is_adversarial = get_is_adversarial(criterion, model)

        # Get Starting Point
        if starting_points is not None:
            best_advs = starting_points
        else:
            init_attack: MinimizationAttack = LinearSearchBlendedUniformNoiseAttack(steps=50)
            best_advs = init_attack.run(model, originals, criterion, early_stop=early_stop)

        assert self._is_adversarial(best_advs).all()

        # Initialize the direction orthogonalized with the first direction
        fd = best_advs - originals
        norm = ep.norms.l2(fd.flatten(1), axis=1)
        fd = fd / atleast_kd(norm, fd.ndim)
        self._directions_ortho = {i: v.expand_dims(0) for i, v in enumerate(fd)}

        # Load Basis
        if "basis_params" in kwargs:
            self._basis = Basis(originals, **kwargs["basis_params"])
        else:
            self._basis = Basis(originals)

        for _ in range(self._steps):
            # Get candidates. Shape: (n_candidates, batch_size, image_size)
            candidates = self._get_candidates(originals, best_advs)
            candidates = candidates.transpose((1, 0, 2, 3, 4))

            
            best_candidates = ep.zeros_like(best_advs).raw
            for i, o in enumerate(originals):
                o_repeated = ep.concatenate([o.expand_dims(0)] * len(candidates[i]), axis=0)
                # keep the candidate closest to the original (this is a minimization attack)
                index = ep.argmin(self.distance(o_repeated, candidates[i])).raw
                best_candidates[i] = candidates[i][index].raw

            is_success = self.distance(best_candidates, originals) < self.distance(best_advs, originals)
            best_advs = ep.where(atleast_kd(is_success, best_candidates.ndim), ep.astensor(best_candidates), best_advs)

            if all(v > self._max_queries for v in self._nqueries.values()):
                print("Max queries attained for all the images.")
                break
        return restore_type(best_advs)
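
A minimal usage sketch for the run() above, assuming it belongs to a SurFree-style foolbox MinimizationAttack; the class name SurFree and the basis_params keys are illustrative, not confirmed by the snippet:

import foolbox as fbn

fmodel = fbn.PyTorchModel(net.eval(), bounds=(0, 1))  # net: a trained torch module
attack = SurFree(steps=100)  # hypothetical attack class exposing run() above
advs = attack.run(
    fmodel,
    images,
    fbn.criteria.Misclassification(labels),
    basis_params={"basis_type": "dct"},  # forwarded as Basis(originals, **basis_params)
)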
Example 3
    def _gram_schmidt(self, v: ep.Tensor, ortho_with: ep.Tensor) -> ep.Tensor:
        v_repeated = ep.concatenate([v.expand_dims(0)] * len(ortho_with), axis=0)

        # inner product of v with every vector we orthogonalize against
        gs_coeff = (ortho_with * v_repeated).flatten(1).sum(1)
        proj = atleast_kd(gs_coeff, ortho_with.ndim) * ortho_with
        v = v - proj.sum(0)
        return v / ep.norms.l2(v)
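
A quick check of the projection step (a sketch assuming NumPy-backed EagerPy tensors and a hypothetical owning class Basis; self is unused by the method body):

import numpy as np
import eagerpy as ep

ortho_with = ep.astensor(np.eye(4)[:2])      # two orthonormal directions, shape (2, 4)
v = ep.astensor(np.random.randn(4))
u = Basis._gram_schmidt(None, v, ortho_with)
print((ortho_with * u).sum(axis=1).numpy())  # ~[0. 0.]: u is orthogonal to both rows
print(ep.norms.l2(u).item())                 # ~1.0: the result is normalized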
Example 4
def uniform_l1_n_balls(dummy: ep.Tensor, batch_size: int, n: int) -> ep.Tensor:
    # Sample uniformly from the n-dimensional L1 ball, see https://mathoverflow.net/a/9188:
    # consecutive differences of sorted uniforms give a point in the simplex (the maximum
    # of the n uniforms supplies the correct radial distribution), and random signs
    # select the orthant.
    u = ep.uniform(dummy, (batch_size, n))
    v = u.sort(axis=-1)
    vp = ep.concatenate([ep.zeros(v, (batch_size, 1)), v[:, :n - 1]], axis=-1)
    assert v.shape == vp.shape
    x = v - vp
    sign = ep.uniform(dummy, (batch_size, n), low=-1.0, high=1.0).sign()
    return sign * x
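
A quick sanity check (NumPy-backed EagerPy tensors assumed): every sample should lie inside the unit L1 ball.

import numpy as np
import eagerpy as ep

dummy = ep.astensor(np.zeros(1))
samples = uniform_l1_n_balls(dummy, batch_size=1000, n=8)
assert (samples.abs().sum(axis=-1) <= 1.0).all().item()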
Example 5
def l2_clipping_aware_rescaling(x,
                                delta,
                                eps: float,
                                a: float = 0.0,
                                b: float = 1.0):  # type: ignore
    """Calculates eta such that norm(clip(x + eta * delta, a, b) - x) == eps.

    Assumes x and delta have a batch dimension and eps, a, b, and p are
    scalars. If the equation cannot be solved because eps is too large, the
    left hand side is maximized.

    Args:
        x: A batch of inputs (PyTorch Tensor, TensorFlow Eager Tensor, NumPy
            Array, JAX Array, or EagerPy Tensor).
        delta: A batch of perturbation directions (same shape and type as x).
        eps: The target norm (non-negative float).
        a: The lower bound of the data domain (float).
        b: The upper bound of the data domain (float).

    Returns:
        eta: A batch of scales with the same number of dimensions as x but all
            axis == 1 except for the batch dimension.
    """
    (x, delta), restore_fn = ep.astensors_(x, delta)
    N = x.shape[0]
    assert delta.shape[0] == N
    rows = ep.arange(x, N)

    delta2 = delta.square().reshape((N, -1))
    space = ep.where(delta >= 0, b - x, x - a).reshape((N, -1))
    f2 = space.square() / ep.maximum(delta2, 1e-20)
    ks = ep.argsort(f2, axis=-1)
    f2_sorted = f2[rows[:, ep.newaxis], ks]
    m = ep.cumsum(delta2[rows[:, ep.newaxis],
                         ks.flip(axis=1)], axis=-1).flip(axis=1)
    dx = f2_sorted[:, 1:] - f2_sorted[:, :-1]
    dx = ep.concatenate((f2_sorted[:, :1], dx), axis=-1)
    dy = m * dx
    y = ep.cumsum(dy, axis=-1)
    c = y >= eps**2

    # work-around to get first nonzero element in each row
    f = ep.arange(x, c.shape[-1], 0, -1)
    j = ep.argmax(c.astype(f.dtype) * f, axis=-1)

    eta2 = f2_sorted[rows, j] - (y[rows, j] - eps**2) / m[rows, j]
    # it can happen that for certain rows even the largest j is not large enough
    # (i.e. c[:, -1] is False), then we will just use it (without any correction) as it's
    # the best we can do (this should also be the only cases where m[j] can be
    # 0 and they are thus not a problem)
    eta2 = ep.where(c[:, -1], eta2, f2_sorted[:, -1])
    eta = ep.sqrt(eta2)
    eta = eta.reshape((-1, ) + (1, ) * (x.ndim - 1))

    # sanity check (disabled):
    # xp = ep.clip(x + eta * delta, a, b)
    # l2 = (xp - x).reshape((N, -1)).square().sum(axis=-1).sqrt()
    return restore_fn(eta)
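
A usage sketch (plain NumPy arrays work directly thanks to ep.astensors_ / restore_fn): rescale a direction so that the clipped perturbation has L2 norm eps.

import numpy as np

x = np.random.rand(4, 3, 8, 8)        # batch of inputs in [0, 1]
delta = np.random.randn(4, 3, 8, 8)   # arbitrary perturbation directions
eta = l2_clipping_aware_rescaling(x, delta, eps=0.5)
xp = np.clip(x + eta * delta, 0.0, 1.0)
print(np.linalg.norm((xp - x).reshape(4, -1), axis=-1))  # ~[0.5 0.5 0.5 0.5]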
Example 6
    def process_raw(self) -> None:
        raw_inputs = self.raw_inputs
        raw_outputs = self.raw_outputs
        assert len(raw_inputs) == len(raw_outputs)
        assert (self.inputs is None) == (self.outputs is None)

        if self.inputs is None:
            if len(raw_inputs) == 0:
                raise ValueError(
                    "DatasetAttack can only be called after data has been provided using 'feed()'"
                )
        else:
            assert self.outputs is not None
            raw_inputs = [self.inputs] + raw_inputs
            raw_outputs = [self.outputs] + raw_outputs

        self.inputs = ep.concatenate(raw_inputs, axis=0)
        self.outputs = ep.concatenate(raw_outputs, axis=0)
        self.raw_inputs = []
        self.raw_outputs = []
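
A usage sketch of the surrounding foolbox DatasetAttack (batch names are illustrative): batches are queued with feed() and concatenated by process_raw on first use.

import foolbox as fbn

attack = fbn.attacks.DatasetAttack()
attack.feed(fmodel, inputs_batch_1)  # reference images drawn from the data distribution
attack.feed(fmodel, inputs_batch_2)
advs = attack.run(fmodel, x, fbn.criteria.Misclassification(y))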
Example 7
    def get_vector(self, ortho_with: Optional[Dict] = None, bounds: Tuple[float, float] = (0, 1)) -> ep.Tensor:
        if ortho_with is None:
            ortho_with = {i: None for i in range(len(self._originals))}
        r: ep.Tensor = self._function_generation()

        vectors = ep.concatenate(
            [self._gram_schmidt(r[i], ortho_with[i]).expand_dims(0) for i in ortho_with],
            axis=0,
        )
        return vectors
Example 8
def test_newtonfool_run_raises(
    fmodel_and_data_ext_for_attacks: ModelAndDataAndDescription,
) -> None:
    (fmodel, x, y), _, _ = fmodel_and_data_ext_for_attacks
    if isinstance(x, ep.NumPyTensor):
        pytest.skip()

    with pytest.raises(ValueError, match="unsupported criterion"):
        attack = fbn.attacks.NewtonFoolAttack()
        attack.run(fmodel, x, fbn.TargetedMisclassification(y))

    with pytest.raises(ValueError, match="expected labels to have shape"):
        attack = fbn.attacks.NewtonFoolAttack(steps=10)
        attack.run(fmodel, x, ep.concatenate((y, y), 0))
Example 9
def run_attacks(MODEL_DIR, res_path):
    rel_dirs = [x for x in os.listdir(MODEL_DIR) if '2020' in x]
    alpha = [re.findall('a=([0-9, \.]*)_', d)[0] for d in rel_dirs]
    res = dict.fromkeys(alpha)
    learner = prep_learner()

    for model_path, curr_alpha in tqdm(zip(rel_dirs, alpha), total=len(alpha)):
        conf.save_path = Path(path.join(MODEL_DIR, model_path))
        fix_str = [
            x for x in os.listdir(path.join(MODEL_DIR, model_path))
            if 'model' in x
        ][0][8:]
        learner.load_state(conf,
                           fix_str,
                           model_only=True,
                           from_save_folder=True)

        # probes
        set_probes(learner)

        # wrap each model for multi-GPU eval; reassign so the wrapped modules are kept
        # (assigning to the loop variable would discard the DataParallel wrapper)
        learner.models = [
            torch.nn.DataParallel(model.cuda(), device_ids=list(range(4))).eval()
            for model in learner.models
        ]

        res[curr_alpha] = dict()
        for (attack,
             eps), attack_name in tqdm(zip(attack_list, attack_list_names),
                                       desc='attacking ' + str(curr_alpha),
                                       total=len(attack_list)):
            fmodel = JointModelEP(
                [PyTorchModel(m, bounds=(0, 1)) for m in learner.models],
                'cuda')
            attack = attack()
            success_tot = []
            for images, labels in tqdm(learner.eval_loader,
                                       total=len(learner.eval_loader),
                                       desc=attack_name):
                images, labels = ep.astensors(images.to('cuda'),
                                              labels.to('cuda'))
                _, _, success = attack(fmodel, images, labels, epsilons=eps)
                success_tot.append(success)
            success_tot = ep.concatenate(success_tot, -1)

            # calculate and report the robust accuracy
            robust_accuracy = 1 - success_tot.float32().mean(axis=-1)
            for epsilon, acc in zip(eps, robust_accuracy):
                res[curr_alpha][attack_name + '_' + str(epsilon)] = acc.item()

            pickle.dump(res, open(res_path, 'wb'))
        pickle.dump(res, open(res_path, 'wb'))
Example 10
def test_vat_run_raises(
    fmodel_and_data_ext_for_attacks: ModelDescriptionAndData,
) -> None:
    (fmodel, x, y), _ = fmodel_and_data_ext_for_attacks
    if isinstance(x, ep.NumPyTensor):
        pytest.skip()

    with pytest.raises(ValueError, match="unsupported criterion"):
        attack = fbn.attacks.VirtualAdversarialAttack(steps=10)
        attack.run(fmodel, x, fbn.TargetedMisclassification(y), epsilon=1.0)

    with pytest.raises(ValueError, match="expected labels to have shape"):
        attack = fbn.attacks.VirtualAdversarialAttack(steps=10)
        attack.run(fmodel, x, ep.concatenate((y, y), 0), epsilon=1.0)
Example 11
def test_targeted_attacks_call_raises_exception(
    fmodel_and_data_ext_for_attacks: Tuple[Tuple[fbn.Model, ep.Tensor, ep.Tensor], bool],
    attack_exception_text_and_grad: Tuple[fbn.Attack, bool],
) -> None:

    attack, attack_uses_grad = attack_exception_text_and_grad
    (fmodel, x, y), _ = fmodel_and_data_ext_for_attacks

    if isinstance(x, ep.NumPyTensor) and attack_uses_grad:
        pytest.skip()

    x = (x - fmodel.bounds.lower) / (fmodel.bounds.upper - fmodel.bounds.lower)
    fmodel = fmodel.transform_bounds((0, 1))

    num_classes = fmodel(x).shape[-1]
    target_classes = (y + 1) % num_classes
    invalid_target_classes = ep.concatenate((target_classes, target_classes), 0)
    invalid_targeted_criterion = fbn.TargetedMisclassification(invalid_target_classes)

    class DummyCriterion(fbn.Criterion):
        """Criterion without any functionality which is just meant to be
        rejected by the attacks
        """
        def __repr__(self) -> str:
            return ""

        def __call__(self, perturbed: fbn.criteria.T,
                     outputs: fbn.criteria.T) -> fbn.criteria.T:
            return perturbed

    invalid_criterion = DummyCriterion()

    # check if targeted attack criterion with invalid number of classes is rejected
    with pytest.raises(ValueError):
        attack(fmodel, x, invalid_targeted_criterion, epsilons=1000.0)

    # check that an unsupported criterion type is rejected
    with pytest.raises(ValueError):
        attack(fmodel, x, invalid_criterion, epsilons=1000.0)
Example 12
def test_concatenate_axis1(dummy: Tensor) -> Tensor:
    t1 = ep.arange(dummy, 12).float32().reshape((3, 4))
    t2 = ep.arange(dummy, 20, 32, 2).float32().reshape((3, 2))
    return ep.concatenate([t1, t2], axis=1)
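
With a NumPy dummy tensor the result is easy to verify by hand; t2 = arange(20, 32, 2) is [20, 22, 24, 26, 28, 30]:

# t1 (3 x 4)        t2 (3 x 2)    result (3 x 6)
# [[ 0  1  2  3]    [[20 22]      [[ 0  1  2  3 20 22]
#  [ 4  5  6  7]     [24 26]       [ 4  5  6  7 24 26]
#  [ 8  9 10 11]]    [28 30]]      [ 8  9 10 11 28 30]]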
Example 13
def wasserstein_distance(X,
                         Y,
                         matching=False,
                         order=1.,
                         internal_p=np.inf,
                         enable_autodiff=False):
    '''
    :param X: (n x 2) numpy.array encoding the (finite points of the) first diagram. Must not contain essential points
                (i.e. with infinite coordinate).
    :param Y: (m x 2) numpy.array encoding the second diagram.
    :param matching: if True, computes and returns the optimal matching between X and Y, encoded as
                     a (n x 2) np.array  [...[i,j]...], meaning the i-th point in X is matched to
                     the j-th point in Y, with the convention (-1) represents the diagonal.
    :param order: exponent for Wasserstein; Default value is 1.
    :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2);
                       Default value is `np.inf`.
    :param enable_autodiff: If X and Y are torch.tensor or tensorflow.Tensor, make the computation
        transparent to automatic differentiation. This requires the package EagerPy and is currently incompatible
        with `matching=True`.

        .. note:: This considers the function defined on the coordinates of the off-diagonal points of X and Y
            and lets the various frameworks compute its gradient. It never pulls new points from the diagonal.
    :type enable_autodiff: bool
    :returns: the Wasserstein distance of order q (1 <= q < infinity) between persistence diagrams with
              respect to the internal_p-norm as ground metric.
              If matching is set to True, also returns the optimal matching between X and Y.
    '''
    n = len(X)
    m = len(Y)

    # handle empty diagrams
    if n == 0:
        if m == 0:
            if not matching:
                # What if enable_autodiff?
                return 0.
            else:
                return 0., np.array([])
        else:
            if not matching:
                return _perstot(Y, order, internal_p, enable_autodiff)
            else:
                return _perstot(Y, order, internal_p,
                                enable_autodiff), np.array([[-1, j]
                                                            for j in range(m)])
    elif m == 0:
        if not matching:
            return _perstot(X, order, internal_p, enable_autodiff)
        else:
            return _perstot(X, order, internal_p,
                            enable_autodiff), np.array([[i, -1]
                                                        for i in range(n)])

    if enable_autodiff:
        import eagerpy as ep

        X_orig = ep.astensor(X)
        Y_orig = ep.astensor(Y)
        X = X_orig.numpy()
        Y = Y_orig.numpy()
    M = _build_dist_matrix(X, Y, order=order, internal_p=internal_p)
    a = np.ones(n + 1)  # weight vector of the first diagram; uniform here.
    a[-1] = m
    b = np.ones(m + 1)  # weight vector of the second diagram; uniform here.
    b[-1] = n

    if matching:
        assert not enable_autodiff, "matching and enable_autodiff are currently incompatible"
        P = ot.emd(a=a, b=b, M=M, numItermax=2000000)
        ot_cost = np.sum(np.multiply(P, M))
        P[-1, -1] = 0  # Remove matching corresponding to the diagonal
        match = np.argwhere(P)
        # Now we turn to -1 points encoding the diagonal
        match[:, 0][match[:, 0] >= n] = -1
        match[:, 1][match[:, 1] >= m] = -1
        return ot_cost**(1. / order), match

    if enable_autodiff:
        P = ot.emd(a=a, b=b, M=M, numItermax=2000000)
        pairs_X_Y = np.argwhere(P[:-1, :-1])
        pairs_X_diag = np.nonzero(P[:-1, -1])
        pairs_Y_diag = np.nonzero(P[-1, :-1])
        dists = []
        # empty arrays are not handled properly by the helpers, so we avoid calling them
        if len(pairs_X_Y):
            dists.append(
                (Y_orig[pairs_X_Y[:, 1]] - X_orig[pairs_X_Y[:, 0]]).norms.lp(
                    internal_p, axis=-1).norms.lp(order))
        if len(pairs_X_diag[0]):
            dists.append(
                _perstot_autodiff(X_orig[pairs_X_diag], order, internal_p))
        if len(pairs_Y_diag[0]):
            dists.append(
                _perstot_autodiff(Y_orig[pairs_Y_diag], order, internal_p))
        dists = [dist.reshape(1) for dist in dists]
        return ep.concatenate(dists).norms.lp(order).raw
        # We can also concatenate the 3 vectors to compute just one norm.

    # Computation of the ot cost using the ot.emd2 library.
    # Note: it is the Wasserstein distance to the power q.
    # The default numItermax=100000 is not sufficient for some examples with 5000 points; what is a good value?
    ot_cost = ot.emd2(a, b, M, numItermax=2000000)

    return ot_cost**(1. / order)
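
A small worked example (values checked by hand): under internal_p=np.inf, (0, 1) is matched to (0, 1.1) at cost 0.1 and (1, 3) goes to the diagonal at cost (3 - 1) / 2 = 1.

import numpy as np

X = np.array([[0.0, 1.0], [1.0, 3.0]])
Y = np.array([[0.0, 1.1]])
print(wasserstein_distance(X, Y, order=1.0, internal_p=np.inf))  # ~1.1
cost, match = wasserstein_distance(X, Y, matching=True, order=1.0, internal_p=np.inf)
print(match)  # [[0 0] [1 -1]]: the second point of X is matched to the diagonal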
Example 14
def wasserstein_distance(X,
                         Y,
                         matching=False,
                         order=1.,
                         internal_p=np.inf,
                         enable_autodiff=False,
                         keep_essential_parts=True):
    '''
    Compute the Wasserstein distance between persistence diagram using Python Optimal Transport backend.
    Diagrams can contain points with infinity coordinates (essential parts).
    Points with (-inf,-inf) and (+inf,+inf) coordinates are considered as belonging to the diagonal.
    If the distance between two diagrams is +inf (which happens if the cardinalities of essential
    parts differ) and optimal matching is required, it will be set to ``None``.

    :param X: The first diagram.
    :type X: n x 2 numpy.array
    :param Y:  The second diagram.
    :type Y: m x 2 numpy.array
    :param matching: if ``True``, computes and returns the optimal matching between X and Y, encoded as
        a (n x 2) np.array  [...[i,j]...], meaning the i-th point in X is matched to
        the j-th point in Y, with the convention that (-1) represents the diagonal.
    :param order: Wasserstein exponent q (1 <= q < infinity).
    :type order: float
    :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2).
    :type internal_p: float
    :param enable_autodiff: If X and Y are ``torch.tensor`` or ``tensorflow.Tensor``, make the computation
        transparent to automatic differentiation. This requires the package EagerPy and is currently incompatible
        with ``matching=True`` and with ``keep_essential_parts=True``.

        .. note:: This considers the function defined on the coordinates of the off-diagonal finite points of X and Y
            and lets the various frameworks compute its gradient. It never pulls new points from the diagonal.
    :type enable_autodiff: bool
    :param keep_essential_parts: If ``False``, only considers the finite points in the diagrams.
                                 Otherwise, include essential parts in cost and matching computation.
    :type keep_essential_parts: bool
    :returns: The Wasserstein distance of order q (1 <= q < infinity) between persistence diagrams with
              respect to the internal_p-norm as ground metric.
              If matching is set to True, also returns the optimal matching between X and Y.
              If cost is +inf, any matching is optimal and thus it returns `None` instead.
    '''

    # First step: handle empty diagrams
    n = len(X)
    m = len(Y)

    if n == 0:
        if m == 0:
            if not matching:
                # What if enable_autodiff?
                return 0.
            else:
                return 0., np.array([])
        else:
            cost = _perstot(Y, order, internal_p, enable_autodiff)
            if cost == np.inf:
                return _warn_infty(matching)
            else:
                if not matching:
                    return cost
                else:
                    return cost, np.array([[-1, j] for j in range(m)])
    elif m == 0:
        cost = _perstot(X, order, internal_p, enable_autodiff)
        if cost == np.inf:
            return _warn_infty(matching)
        else:
            if not matching:
                return cost
            else:
                return cost, np.array([[i, -1] for i in range(n)])

    # Check whether essential parts and autodiff were requested together
    if enable_autodiff and keep_essential_parts:
        warnings.warn(
            "enable_autodiff=True and keep_essential_parts=True are incompatible. "
            "keep_essential_parts is set to False: only points with finite coordinates "
            "are considered in the following."
        )
        keep_essential_parts = False

    # Second step: handle essential parts if needed.
    if keep_essential_parts:
        essential_cost, essential_matching = _handle_essential_parts(X, Y, order=order)
        if essential_cost == np.inf:
            # Tell the user that the cost is infinite and the matching (if requested) is None.
            # This also avoids computing the transport cost between the finite parts when the
            # essential parts' cardinalities do not match (saves time).
            return _warn_infty(matching)
    else:
        essential_cost = 0
        essential_matching = None

    # Now the standard pipeline for finite parts
    if enable_autodiff:
        import eagerpy as ep

        X_orig = ep.astensor(X)
        Y_orig = ep.astensor(Y)
        X = X_orig.numpy()
        Y = Y_orig.numpy()

    # Extract finite points of the diagrams.
    X, Y = _finite_part(X), _finite_part(Y)
    n = len(X)
    m = len(Y)

    M = _build_dist_matrix(X, Y, order=order, internal_p=internal_p)
    a = np.ones(n + 1)  # weight vector of the first diagram; uniform here.
    a[-1] = m
    b = np.ones(m + 1)  # weight vector of the second diagram; uniform here.
    b[-1] = n

    if matching:
        assert not enable_autodiff, "matching and enable_autodiff are currently incompatible"
        P = ot.emd(a=a, b=b, M=M, numItermax=2000000)
        ot_cost = np.sum(np.multiply(P, M))
        P[-1, -1] = 0  # Remove matching corresponding to the diagonal
        match = np.argwhere(P)
        # Now we turn to -1 points encoding the diagonal
        match[:, 0][match[:, 0] >= n] = -1
        match[:, 1][match[:, 1] >= m] = -1
        # Finally, incorporate the essential part matching
        if essential_matching is not None and essential_matching.size:
            match = np.concatenate([match, essential_matching])
        return (ot_cost + essential_cost)**(1. / order), match

    if enable_autodiff:
        P = ot.emd(a=a, b=b, M=M, numItermax=2000000)
        pairs_X_Y = np.argwhere(P[:-1, :-1])
        pairs_X_diag = np.nonzero(P[:-1, -1])
        pairs_Y_diag = np.nonzero(P[-1, :-1])
        dists = []
        # empty arrays are not handled properly by the helpers, so we avoid calling them
        if len(pairs_X_Y):
            dists.append(
                (Y_orig[pairs_X_Y[:, 1]] - X_orig[pairs_X_Y[:, 0]]).norms.lp(
                    internal_p, axis=-1).norms.lp(order))
        if len(pairs_X_diag[0]):
            dists.append(
                _perstot_autodiff(X_orig[pairs_X_diag], order, internal_p))
        if len(pairs_Y_diag[0]):
            dists.append(
                _perstot_autodiff(Y_orig[pairs_Y_diag], order, internal_p))
        dists = [dist.reshape(1) for dist in dists]
        return ep.concatenate(dists).norms.lp(order).raw
        # We can also concatenate the 3 vectors to compute just one norm.

    # Computation of the ot cost using the ot.emd2 library.
    # Note: it is the Wasserstein distance to the power q.
    # The default numItermax=100000 is not sufficient for some examples with 5000 points; what is a good value?
    ot_cost = ot.emd2(a, b, M, numItermax=2000000)

    return (ot_cost + essential_cost)**(1. / order)
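
A sketch of the essential-part handling described in the docstring (expected values inferred from the docstring, not verified): essential points are matched by their finite coordinate, and mismatched essential cardinalities make the distance infinite.

import numpy as np

X = np.array([[0.0, 1.0], [2.0, np.inf]])
Y = np.array([[0.0, 1.0], [2.5, np.inf]])
# finite parts coincide; the essential parts should contribute |2.5 - 2.0| = 0.5
print(wasserstein_distance(X, Y, order=1.0, internal_p=np.inf))  # expected ~0.5

Y2 = np.array([[0.0, 1.0]])  # no essential part: cardinalities differ
print(wasserstein_distance(X, Y2, order=1.0))  # inf (a warning is emitted)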
Example 15
    def x_path(self):
        path = ep.concatenate(self._x_path, axis=0)
        return path[:-1, ...]  # removes the last point