def sample_random_triplets(key, inputs, n_random, distance_fn, sig):
  """Sample uniformly random triplets.

  For every point, ``n_random`` candidate pairs are drawn with
  ``rejection_sample`` (the point itself is passed as the anchor). Each pair is
  then ordered so that the member with the larger scaled similarity comes
  first, and the triplet weight is the size of that similarity margin.

  Args:
    key: Random key.
    inputs: Input points; ``inputs.shape[0]`` is the number of points.
    n_random: Number of random triplets per point.
    distance_fn: Distance function, forwarded to ``sliced_distances``.
    sig: Per-point scaling factors for the distances, indexed by point id.

  Returns:
    triplets: Sampled triplets with columns (anchor, similar, outlier); one row
      per (point, draw), i.e. ``n_points * n_random`` rows.
    weights: Non-negative weights ``|p_sim - p_out|``, one per triplet,
      consistent with the (possibly swapped) order of the returned pair.
  """
  n_points = inputs.shape[0]
  # Each point id repeated n_random times, as an (n_points * n_random, 1) column.
  anchors = jnp.tile(
      jnp.arange(n_points).reshape([-1, 1]), [1, n_random]).reshape([-1, 1])
  # presumably (n_points * n_random, 2) candidate (sim, out) ids per anchor.
  pairs = rejection_sample(key, (n_points * n_random, 2), n_points, anchors)

  anc = anchors[:, 0]
  sim = pairs[:, 0]
  out = pairs[:, 1]
  p_sim = -(sliced_distances(anc, sim, inputs, distance_fn)**2) / (
      sig[anc] * sig[sim])
  p_out = -(sliced_distances(anc, out, inputs, distance_fn)**2) / (
      sig[anc] * sig[out])
  margin = p_sim - p_out

  # Wherever the "outlier" is actually more similar, swap the pair so the
  # triplet always lists the more similar point first.
  flip = p_sim < p_out
  pairs = jnp.where(flip[:, None], jnp.fliplr(pairs), pairs)
  triplets = jnp.concatenate((anchors, pairs), 1)

  # Swapping the pair negates the margin; report its magnitude so the weight
  # matches the returned ordering instead of the pre-swap one.
  weights = jnp.abs(margin)
  return triplets, weights
def transpose(m: Array) -> Array:
  """Re-index ``m`` through sliding-window diagonals of a padded, mirrored copy.

  ``m`` is zero-padded with ``_W // 2`` rows above and below, then mirrored
  left-to-right. Output row ``i`` is the main diagonal of the ``_W``-row window
  that starts at padded row ``i``. The result has one row per row of ``m``.

  NOTE(review): assuming ``_W`` is odd and ``m`` stores a banded matrix as one
  row of ``_W`` entries centred on the diagonal, this yields the same band
  storage for the transposed matrix. Confirm against where ``_W`` is defined.
  """
  half_width = _W // 2
  padded = jnp.pad(m, ((half_width, half_width), (0, 0)))
  mirrored = jnp.fliplr(padded)

  def window_diagonal(start):
    # _W consecutive rows beginning at `start`; their diagonal is one output row.
    window = lax.dynamic_slice_in_dim(mirrored, start, _W, axis=0)
    return jnp.diag(window)

  row_ids = jnp.arange(m.shape[0])
  return vmap(window_diagonal)(row_ids)
def lr_flip(o):
  """Mirror an image left-to-right and remap its bounding boxes to match.

  Args:
    o: A pair ``(image, boxes)``. ``image`` is flipped with ``np.fliplr``
      (reverses axis 1). ``boxes`` has a trailing axis of size 4 holding
      ``(x1, y1, x2, y2)``. x-coordinates are mirrored about 1, so they are
      presumably normalized to [0, 1] (NOTE(review): confirm with callers).

  Returns:
    ``(flipped_image, flipped_boxes)``, with ``flipped_boxes`` reshaped to
    ``(-1, 4)``. y-coordinates are unchanged.
  """
  image, boxes = o
  mirrored_image = np.fliplr(image)

  left, top, right, bottom = np.split(boxes, 4, axis=-1)
  box_width = right - left
  # Left edge with its trailing axis kept, so it lines up with the split parts.
  left_edge = np.expand_dims(boxes[..., 0], axis=-1)
  new_left = 1. - left_edge - box_width
  new_right = 1. - left_edge

  flipped_boxes = np.stack([new_left, top, new_right, bottom], axis=-1)
  return mirrored_image, flipped_boxes.reshape(-1, 4)
def flip_multiply(x, y):
  """Multiply ``x`` element-wise by ``y`` mirrored left-to-right.

  ``y`` is flipped with ``np.fliplr`` (reverses axis 1, so ``y`` must be at
  least 2-D). The product then follows ordinary ``*`` broadcasting.
  """
  mirrored_y = np.fliplr(y)
  return x * mirrored_y
def random_flip(m, k):
  """Return ``m`` unchanged when ``k == 0``, otherwise mirrored left-to-right.

  The choice is made with ``jnp.where`` rather than a Python ``if``. Presumably
  this lets ``k`` be a traced array under ``jit``/``vmap``.
  """
  keep_original = k == 0
  mirrored = jnp.fliplr(m)
  return jnp.where(keep_original, m, mirrored)
def generate_triplets(key,
                      inputs,
                      n_inliers,
                      n_outliers,
                      n_random,
                      weight_temp=0.5,
                      distance='euclidean',
                      verbose=False):
  """Generate triplets and their weights.

  Steps:
  1. Build a ``pynndescent`` index over ``inputs`` and query
     ``min(n_inliers + 50, n_points)`` neighbors per point.
  2. Prepend each point's own id as column 0 and rescale with
     ``find_scaled_neighbors``.
  3. Sample k-NN triplets, then order each (sim, out) pair so that the
     weight is non-negative.
  4. Optionally append ``n_random`` random triplets per point, with weights
     scaled by 0.1.
  5. Finally, shift all weights so the minimum is 0 and apply
     ``tempered_log(1 + w, weight_temp)``.

  Args:
    key: Random key.
    inputs: Input points; ``inputs.shape[0]`` is the number of points.
    n_inliers: Number of inliers.
    n_outliers: Number of outliers.
    n_random: Number of random triplets per point (0 disables them).
    weight_temp: Temperature of the log transformation on the weights.
    distance: Distance type; passed to ``pynndescent`` as ``metric`` and to
      ``get_distance_fn``.
    verbose: Whether to print progress.

  Returns:
    triplets and weights: triplets with columns (anchor, similar, outlier)
    and their tempered-log weights.
  """
  n_points = inputs.shape[0]
  n_extra = min(n_inliers + 50, n_points)
  index = pynndescent.NNDescent(inputs, metric=distance)
  index.prepare()
  neighbors = index.query(inputs, n_extra)[0]
  # Column 0 holds each point's own id.
  neighbors = np.concatenate(
      (np.arange(n_points).reshape([-1, 1]), neighbors), 1)
  if verbose:
    logging.info('found nearest neighbors')
  distance_fn = get_distance_fn(distance)
  # compute scaled neighbors and the scale parameter
  knn_distances, neighbors, sig = find_scaled_neighbors(
      inputs, neighbors, distance_fn)
  neighbors = neighbors[:, :n_inliers + 1]
  knn_distances = knn_distances[:, :n_inliers + 1]

  key, use_key = random.split(key)
  triplets = sample_knn_triplets(use_key, neighbors, n_inliers, n_outliers)
  weights = find_triplet_weights(
      inputs,
      triplets,
      neighbors[:, 1:n_inliers + 1],
      distance_fn,
      sig,
      distances=knn_distances[:, 1:n_inliers + 1])

  # A negative weight means the "outlier" is the more similar point: swap the
  # pair so the triplet is ordered correctly.
  flip = weights < 0
  anchors, pairs = triplets[:, 0].reshape([-1, 1]), triplets[:, 1:]
  pairs = jnp.where(flip[:, None], jnp.fliplr(pairs), pairs)
  triplets = jnp.concatenate((anchors, pairs), 1)
  # Swapping the pair negates the weight. Without this, flipped triplets keep
  # negative weights and the min-shift below gives the most strongly flipped
  # ones the *smallest* importance.
  weights = jnp.abs(weights)

  if n_random > 0:
    key, use_key = random.split(key)
    rand_triplets, rand_weights = sample_random_triplets(
        use_key, inputs, n_random, distance_fn, sig)
    triplets = jnp.concatenate((triplets, rand_triplets), 0)
    weights = jnp.concatenate((weights, 0.1 * rand_weights))

  weights -= jnp.min(weights)
  weights = tempered_log(1. + weights, weight_temp)
  return triplets, weights