Beispiel #1
0
def sample_random_triplets(key, inputs, n_random, distance_fn, sig):
    """Draw random triplets for every point and order each pair by score.

    Every point acts as the anchor of `n_random` triplets. The two remaining
    members of each triplet come from `rejection_sample`. Each member gets a
    score: the negated squared distance to the anchor divided by the product
    of the two points' scale factors. The pair is then arranged so that the
    member with the larger score sits in column 1.

    Args:
      key: Random key forwarded to `rejection_sample`.
      inputs: Input points; the first axis indexes the points.
      n_random: Number of random triplets per point.
      distance_fn: Distance function forwarded to `sliced_distances`.
      sig: Scaling factors for the distances, indexed by point id.

    Returns:
      A `(triplets, weights)` tuple. `triplets` has one
      `(anchor, higher-score, lower-score)` row per sample. `weights` is the
      score difference `p_sim - p_out` taken before the reordering, so it is
      negative for every row whose pair was swapped.
    """
    n_points = inputs.shape[0]
    # Point ids repeated n_random times each, as a single column.
    anchors = jnp.repeat(jnp.arange(n_points), n_random).reshape([-1, 1])
    candidates = rejection_sample(
        key, (n_points * n_random, 2), n_points, anchors)
    drawn = jnp.concatenate((anchors, candidates), 1)
    anc, first, second = drawn[:, 0], drawn[:, 1], drawn[:, 2]

    def scaled_score(other):
        # Negated squared distance, normalised by both scale factors.
        sq_dist = sliced_distances(anc, other, inputs, distance_fn)**2
        return -sq_dist / (sig[anc] * sig[other])

    p_sim = scaled_score(first)
    p_out = scaled_score(second)
    weights = p_sim - p_out

    # Swap the two candidates wherever the second one scored higher; the
    # (N, 1) mask broadcasts across both columns.
    swap = (p_sim < p_out).reshape([-1, 1])
    ordered = jnp.where(swap, candidates[:, ::-1], candidates)
    return jnp.concatenate((anchors, ordered), 1), weights
Beispiel #2
0
def transpose(m: Array) -> Array:
    """Rebuild `m` from the diagonals of sliding `_W`-row windows.

    `m` is zero-padded with `_W // 2` rows above and below, and its columns
    are reversed. Output row `i` is the main diagonal of the `_W`-row window
    of that array starting at row `i`.

    NOTE(review): this looks like the transpose of a banded matrix stored as
    `_W` diagonals per row -- confirm against the callers.

    Args:
      m: Two-dimensional array (the padding spec has exactly two axes).

    Returns:
      Array with one row per row of `m`, each row being a window diagonal.
    """
    half_band = _W // 2
    padded = jnp.pad(m, ((half_band, half_band), (0, 0)))
    mirrored = padded[:, ::-1]

    def window_diagonal(start):
        window = lax.dynamic_slice_in_dim(mirrored, start, _W, axis=0)
        return jnp.diag(window)

    row_starts = jnp.arange(m.shape[0])
    return vmap(window_diagonal)(row_starts)
Beispiel #3
0
    def lr_flip(o):
        """Mirror an image and its bounding boxes left-to-right.

        Args:
            o: An `(image, boxes)` pair. `boxes` stores `x1, y1, x2, y2`
                along its last axis. The `1. - x` arithmetic assumes the
                x-coordinates are normalised to `[0, 1]` -- TODO confirm
                with the caller.

        Returns:
            A `(flipped_image, flipped_boxes)` pair; the boxes are returned
            with shape `(-1, 4)`.
        """
        image, boxes = o
        flipped_image = np.fliplr(image)

        left, top, right, bottom = np.split(boxes, 4, axis=-1)
        width = right - left
        old_left = boxes[..., 0][..., np.newaxis]

        # The old left edge becomes the new right edge after mirroring; the
        # new left edge lies one box-width before it.
        new_right = 1. - old_left
        new_left = new_right - width

        flipped_boxes = np.concatenate(
            [new_left, top, new_right, bottom], axis=-1).reshape(-1, 4)
        return flipped_image, flipped_boxes
Beispiel #4
0
 def flip_multiply(x, y):
     """Multiply `x` by `y` after reversing `y` left-to-right (`np.fliplr`)."""
     mirrored_y = np.fliplr(y)
     return x * mirrored_y
Beispiel #5
0
def random_flip(m, k):
    """Return `m` where `k == 0`, otherwise `m` mirrored left-to-right.

    Args:
      m: Array to flip; `jnp.fliplr` needs at least two dimensions.
      k: Selector compared against zero; the comparison result is passed to
        `jnp.where` as the condition.

    Returns:
      The selection between `m` and `jnp.fliplr(m)` made by `jnp.where`.
    """
    mirrored = jnp.fliplr(m)
    return jnp.where(k != 0, mirrored, m)
Beispiel #6
0
def generate_triplets(key,
                      inputs,
                      n_inliers,
                      n_outliers,
                      n_random,
                      weight_temp=0.5,
                      distance='euclidean',
                      verbose=False):
    """Build nearest-neighbour and random triplets with their weights.

    A `pynndescent` index supplies neighbour candidates for every point.
    `sample_knn_triplets` and `find_triplet_weights` turn them into weighted
    triplets. When `n_random > 0`, triplets from `sample_random_triplets`
    are appended with their weights scaled by 0.1. Finally the weights are
    shifted so that their minimum is zero and passed through
    `tempered_log(1. + weights, weight_temp)`.

    Args:
      key: Random key; split once per sampling step.
      inputs: Input points; the first axis indexes the points.
      n_inliers: Number of inliers.
      n_outliers: Number of outliers.
      n_random: Number of random triplets per point; values `<= 0` skip the
        random triplets.
      weight_temp: Temperature of the log transformation on the weights.
      distance: Distance type, used as the index metric and passed to
        `get_distance_fn`.
      verbose: Whether to log progress.

    Returns:
      A `(triplets, weights)` tuple.
    """
    n_points = inputs.shape[0]
    # Columns kept after the neighbour scaling step: one leading column plus
    # n_inliers neighbours.
    n_keep = n_inliers + 1

    # Query 50 more candidates than needed, capped at the number of points.
    n_extra = min(n_inliers + 50, n_points)
    nn_index = pynndescent.NNDescent(inputs, metric=distance)
    nn_index.prepare()
    found = nn_index.query(inputs, n_extra)[0]
    # Prepend each point's own id as column 0.
    own_ids = np.arange(n_points).reshape([-1, 1])
    neighbors = np.concatenate((own_ids, found), 1)
    if verbose:
        logging.info('found nearest neighbors')

    distance_fn = get_distance_fn(distance)
    # Compute scaled neighbors and the scale parameter.
    knn_distances, neighbors, sig = find_scaled_neighbors(
        inputs, neighbors, distance_fn)
    neighbors = neighbors[:, :n_keep]
    knn_distances = knn_distances[:, :n_keep]

    key, knn_key = random.split(key)
    triplets = sample_knn_triplets(knn_key, neighbors, n_inliers, n_outliers)
    weights = find_triplet_weights(
        inputs,
        triplets,
        neighbors[:, 1:n_keep],
        distance_fn,
        sig,
        distances=knn_distances[:, 1:n_keep])

    # Swap the last two columns of every triplet whose weight is negative.
    swap_mask = jnp.tile((weights < 0).reshape([-1, 1]), [1, 2])
    anchors = triplets[:, 0].reshape([-1, 1])
    pairs = triplets[:, 1:]
    ordered_pairs = jnp.where(swap_mask, jnp.fliplr(pairs), pairs)
    triplets = jnp.concatenate((anchors, ordered_pairs), 1)

    if n_random > 0:
        key, rand_key = random.split(key)
        rand_triplets, rand_weights = sample_random_triplets(
            rand_key, inputs, n_random, distance_fn, sig)
        triplets = jnp.concatenate((triplets, rand_triplets), 0)
        # Random triplets contribute with a tenth of their raw weight.
        weights = jnp.concatenate((weights, 0.1 * rand_weights))

    # Shift to a zero minimum so the argument of the tempered log is >= 1.
    weights -= jnp.min(weights)
    weights = tempered_log(1. + weights, weight_temp)
    return triplets, weights