def find_nn_gpu(F0, F1, nn_max_n=-1, return_distance=False, dist_type='SquareL2'):
    """Find, for every row of ``F0``, the index of its nearest row in ``F1``.

    Pairwise distances come from ``pdist(..., dist_type=dist_type)`` and the
    minimum is taken along dim 1 (over the rows of ``F1``).

    Args:
        F0: query features; must support row slicing (``F0[a:b]``) and ``len``.
        F1: reference features.
        nn_max_n: if > 1, ``F0`` is processed in chunks of at most ``nn_max_n``
            rows so that the pairwise distance matrix never holds all of ``F0``
            at once. Otherwise ``F0`` is processed in a single call.
        return_distance: if True, also return the nearest-neighbour distances.
        dist_type: forwarded to ``pdist``.

    Returns:
        ``inds`` (a CPU tensor with one index into ``F1`` per row of ``F0``),
        or ``(inds, dists)`` when ``return_distance`` is True. ``dists`` is a
        detached CPU tensor of shape ``(len(F0), 1)``.
    """
    N = len(F0)
    # Too much memory if F0 or F1 is large, so divide F0 into chunks.
    # An empty F0 skips this path: torch.cat on an empty list would raise.
    if nn_max_n > 1 and N > 0:
        dists, inds = [], []
        # range(0, N, nn_max_n) visits every row exactly once, including a
        # shorter final chunk, so no separate tail handling is needed.
        for start in range(0, N, nn_max_n):
            dist = pdist(F0[start:start + nn_max_n], F1, dist_type=dist_type)
            min_dist, ind = dist.min(dim=1)
            dists.append(min_dist.detach().unsqueeze(1).cpu())
            inds.append(ind.cpu())
        dists = torch.cat(dists)
        inds = torch.cat(inds)
        assert len(inds) == N
    else:
        dist = pdist(F0, F1, dist_type=dist_type)
        min_dist, inds = dist.min(dim=1)
        dists = min_dist.detach().unsqueeze(1).cpu()
        inds = inds.cpu()

    if return_distance:
        return inds, dists
    return inds
def contrastive_hardest_negative_loss(self, F0, F1, positive_pairs, num_pos=5192, num_hn_samples=2048, thresh=None):
    """Hardest-negative contrastive loss with hash-based false-negative filtering.

    For a random subset of positive pairs, the hardest negative of each anchor
    is its nearest (L2) feature among a random subset of the other cloud's
    points. A hardest negative whose index pair also occurs in
    ``positive_pairs`` is discarded.

    Args:
        F0, F1: per-point features of the two clouds.
        positive_pairs: torch integer tensor; column 0 indexes ``F0`` and
            column 1 indexes ``F1``. It must be a tensor because ``.long()``
            is called on its columns.
        num_pos: maximum number of positive pairs sampled.
        num_hn_samples: maximum number of candidate negatives sampled per cloud.
        thresh: unused; kept for interface compatibility.

    Returns:
        ``(pos_loss, neg_loss)``:
        - ``pos_loss`` is the mean of ``relu(||f0 - f1||^2 - self.pos_thresh)``.
        - ``neg_loss`` is the average of the two directions' means of
          ``relu(self.neg_thresh - d_neg)^2``. A direction left with no valid
          negatives contributes 0 instead of NaN.
    """
    N0, N1 = len(F0), len(F1)
    N_pos_pairs = len(positive_pairs)
    hash_seed = max(N0, N1)
    sel0 = np.random.choice(N0, min(N0, num_hn_samples), replace=False)
    sel1 = np.random.choice(N1, min(N1, num_hn_samples), replace=False)

    if N_pos_pairs > num_pos:
        pos_sel = np.random.choice(N_pos_pairs, num_pos, replace=False)
        sample_pos_pairs = positive_pairs[pos_sel]
    else:
        sample_pos_pairs = positive_pairs

    # Find negatives for all F1[positive_pairs[:, 1]]
    subF0, subF1 = F0[sel0], F1[sel1]

    pos_ind0 = sample_pos_pairs[:, 0].long()
    pos_ind1 = sample_pos_pairs[:, 1].long()
    posF0, posF1 = F0[pos_ind0], F1[pos_ind1]

    D01 = pdist(posF0, subF1, dist_type='L2')
    D10 = pdist(posF1, subF0, dist_type='L2')

    D01min, D01ind = D01.min(1)
    D10min, D10ind = D10.min(1)

    if not isinstance(positive_pairs, np.ndarray):
        positive_pairs = np.array(positive_pairs, dtype=np.int64)

    # Presumably _hash maps each (i, j) index pair to a unique integer key,
    # so np.isin on keys tests pair membership -- verify against _hash.
    pos_keys = _hash(positive_pairs, hash_seed)

    # Map subset-local argmin positions back to full-cloud point indices.
    D01ind = sel1[D01ind.cpu().numpy()]
    D10ind = sel0[D10ind.cpu().numpy()]
    neg_keys0 = _hash([pos_ind0.cpu().numpy(), D01ind], hash_seed)
    neg_keys1 = _hash([D10ind, pos_ind1.cpu().numpy()], hash_seed)

    # Keep only hardest negatives that are not themselves positive pairs.
    mask0 = torch.from_numpy(
        np.logical_not(np.isin(neg_keys0, pos_keys, assume_unique=False)))
    mask1 = torch.from_numpy(
        np.logical_not(np.isin(neg_keys1, pos_keys, assume_unique=False)))

    pos_loss = F.relu((posF0 - posF1).pow(2).sum(1) - self.pos_thresh)
    neg_loss0 = F.relu(self.neg_thresh - D01min[mask0]).pow(2)
    neg_loss1 = F.relu(self.neg_thresh - D10min[mask1]).pow(2)

    def _mean_or_zero(x):
        # mean() of an empty tensor is NaN. sum() of an empty tensor is 0 and
        # keeps the device, dtype and autograd graph.
        return x.mean() if x.numel() > 0 else x.sum()

    return pos_loss.mean(), (_mean_or_zero(neg_loss0) + _mean_or_zero(neg_loss1)) / 2
def contrastive_hardest_negative_loss(self, xyz0_rot, xyz1, F0, F1, positive_pairs, num_pos=5192, num_hn_samples=2048, matching_search_voxel_size=1.2, thresh=None):
    """Hardest-negative contrastive loss with geometric false-negative filtering.

    For a random subset of positive pairs, the hardest negative of each anchor
    is its nearest (L2) feature among a random subset of the other cloud's
    points. A hardest negative is kept only if the Euclidean distance between
    the corresponding rows of ``xyz0_rot`` and ``xyz1`` exceeds
    ``matching_search_voxel_size``.

    Args:
        xyz0_rot, xyz1: per-point coordinate tensors indexed by the same point
            indices as ``F0`` / ``F1``. Presumably ``xyz0_rot`` is already
            transformed into ``xyz1``'s frame -- confirm with the caller.
        F0, F1: per-point features of the two clouds.
        positive_pairs: torch integer tensor; column 0 indexes cloud 0 and
            column 1 indexes cloud 1. It must be a tensor because ``.long()``
            is called on its columns.
        num_pos: maximum number of positive pairs sampled.
        num_hn_samples: maximum number of candidate negatives sampled per cloud.
        matching_search_voxel_size: spatial distance at or below which a
            hardest negative is treated as a false negative and dropped.
        thresh: unused; kept for interface compatibility.

    Returns:
        ``(pos_loss, neg_loss)``:
        - ``pos_loss`` is the mean of ``relu(||f0 - f1||^2 - self.pos_thresh)``.
        - ``neg_loss`` is the average of the two directions' means of
          ``relu(self.neg_thresh - d_neg)^2``. A direction left with no valid
          negatives contributes 0 instead of NaN.
    """
    N0, N1 = len(F0), len(F1)
    N_pos_pairs = len(positive_pairs)
    sel0 = np.random.choice(N0, min(N0, num_hn_samples), replace=False)
    sel1 = np.random.choice(N1, min(N1, num_hn_samples), replace=False)

    if N_pos_pairs > num_pos:
        pos_sel = np.random.choice(N_pos_pairs, num_pos, replace=False)
        sample_pos_pairs = positive_pairs[pos_sel]
    else:
        sample_pos_pairs = positive_pairs

    # Find negatives for all F1[positive_pairs[:, 1]]
    subF0, subF1 = F0[sel0], F1[sel1]

    pos_ind0 = sample_pos_pairs[:, 0].long()
    pos_ind1 = sample_pos_pairs[:, 1].long()
    posF0, posF1 = F0[pos_ind0], F1[pos_ind1]

    D01 = pdist(posF0, subF1, dist_type='L2')
    D10 = pdist(posF1, subF0, dist_type='L2')

    D01min, D01ind = D01.min(1)
    D10min, D10ind = D10.min(1)

    # Map subset-local argmin positions back to full-cloud point indices.
    D01ind = sel1[D01ind.cpu().numpy()]
    D10ind = sel0[D10ind.cpu().numpy()]

    def d(x, y):
        # Row-wise Euclidean distance.
        return torch.sqrt(torch.sum((x - y)**2, dim=1))

    # Drop hardest negatives that lie spatially close to the anchor.
    mask0 = d(xyz0_rot[pos_ind0, :], xyz1[D01ind, :]) > matching_search_voxel_size
    mask1 = d(xyz0_rot[D10ind, :], xyz1[pos_ind1, :]) > matching_search_voxel_size

    pos_loss = F.relu((posF0 - posF1).pow(2).sum(1) - self.pos_thresh)
    neg_loss0 = F.relu(self.neg_thresh - D01min[mask0]).pow(2)
    neg_loss1 = F.relu(self.neg_thresh - D10min[mask1]).pow(2)

    def _mean_or_zero(x):
        # mean() of an empty tensor is NaN. sum() of an empty tensor is 0 and
        # keeps the device, dtype and autograd graph.
        return x.mean() if x.numel() > 0 else x.sum()

    return pos_loss.mean(), (_mean_or_zero(neg_loss0) + _mean_or_zero(neg_loss1)) / 2
def triplet_loss(self, F0, F1, positive_pairs, num_pos=1024, num_hn_samples=512, num_rand_triplet=1024):
    """Triplet margin loss over hardest negatives plus random negatives.

    Each term has the form ``relu(d_pos + self.neg_thresh - d_neg)``, where
    distances are feature-space L2 norms with a 1e-7 stabiliser under the
    square root. Three groups of triplets are concatenated:

    1. Random triplets: sampled positive pairs combined with random points
       of ``F1``. Random negatives that are actually positive pairs are
       dropped.
    2. Hardest negatives in ``F1`` for the sampled anchors in ``F0``.
    3. Hardest negatives in ``F0`` for the sampled anchors in ``F1``.

    Hardest negatives whose index pair occurs in ``positive_pairs`` are
    excluded from groups 2 and 3.

    Args:
        F0, F1: per-point features of the two clouds.
        positive_pairs: torch integer tensor; column 0 indexes ``F0`` and
            column 1 indexes ``F1``.
        num_pos: maximum number of positive pairs used for hardest negatives.
        num_hn_samples: maximum number of candidate negatives sampled per cloud.
        num_rand_triplet: maximum number of random triplets.

    Returns:
        ``(loss, mean_pos_dist, mean_hardest_neg_dist)``. The last value is a
        Python float. ``loss`` is 0 rather than NaN if no triplet survives
        the filtering.
    """
    N0, N1 = len(F0), len(F1)
    num_pos_pairs = len(positive_pairs)
    hash_seed = max(N0, N1)
    sel0 = np.random.choice(N0, min(N0, num_hn_samples), replace=False)
    sel1 = np.random.choice(N1, min(N1, num_hn_samples), replace=False)

    if num_pos_pairs > num_pos:
        pos_sel = np.random.choice(num_pos_pairs, num_pos, replace=False)
        sample_pos_pairs = positive_pairs[pos_sel]
    else:
        sample_pos_pairs = positive_pairs

    # Find negatives for all F1[positive_pairs[:, 1]]
    subF0, subF1 = F0[sel0], F1[sel1]

    pos_ind0 = sample_pos_pairs[:, 0].long()
    pos_ind1 = sample_pos_pairs[:, 1].long()
    posF0, posF1 = F0[pos_ind0], F1[pos_ind1]

    D01 = pdist(posF0, subF1, dist_type='L2')
    D10 = pdist(posF1, subF0, dist_type='L2')

    D01min, D01ind = D01.min(1)
    D10min, D10ind = D10.min(1)

    if not isinstance(positive_pairs, np.ndarray):
        positive_pairs = np.array(positive_pairs, dtype=np.int64)

    # Presumably _hash maps each (i, j) index pair to a unique integer key,
    # so np.isin on keys tests pair membership -- verify against _hash.
    pos_keys = _hash(positive_pairs, hash_seed)

    # Map subset-local argmin positions back to full-cloud point indices.
    D01ind = sel1[D01ind.cpu().numpy()]
    D10ind = sel0[D10ind.cpu().numpy()]
    neg_keys0 = _hash([pos_ind0.cpu().numpy(), D01ind], hash_seed)
    neg_keys1 = _hash([D10ind, pos_ind1.cpu().numpy()], hash_seed)

    mask0 = torch.from_numpy(
        np.logical_not(np.isin(neg_keys0, pos_keys, assume_unique=False)))
    mask1 = torch.from_numpy(
        np.logical_not(np.isin(neg_keys1, pos_keys, assume_unique=False)))
    pos_dist = torch.sqrt((posF0 - posF1).pow(2).sum(1) + 1e-7)

    # Random triplets. The pairs and the negatives are combined element-wise
    # below, so both must have the same length. Previously they were sized
    # independently and crashed whenever num_pos_pairs and N1 differed in
    # how they capped num_rand_triplet.
    num_rand = min(num_pos_pairs, num_rand_triplet)
    rand_inds = np.random.choice(num_pos_pairs, num_rand, replace=False)
    rand_pairs = positive_pairs[rand_inds]
    # Sample with replacement only when F1 has too few points to supply
    # num_rand distinct negatives. In the common case this is the same call
    # as before.
    negatives = np.random.choice(N1, num_rand, replace=N1 < num_rand)

    # Remove positives from negatives
    rand_neg_keys = _hash([rand_pairs[:, 0], negatives], hash_seed)
    rand_mask = np.logical_not(np.isin(rand_neg_keys, pos_keys, assume_unique=False))
    anchors, positives = rand_pairs[rand_mask].T
    negatives = negatives[rand_mask]

    rand_pos_dist = torch.sqrt((F0[anchors] - F1[positives]).pow(2).sum(1) + 1e-7)
    rand_neg_dist = torch.sqrt((F0[anchors] - F1[negatives]).pow(2).sum(1) + 1e-7)

    losses = F.relu(
        torch.cat([
            rand_pos_dist + self.neg_thresh - rand_neg_dist,
            pos_dist[mask0] + self.neg_thresh - D01min[mask0],
            pos_dist[mask1] + self.neg_thresh - D10min[mask1],
        ]))
    # mean() of an empty tensor is NaN; sum() gives 0 and keeps the graph.
    loss = losses.mean() if losses.numel() > 0 else losses.sum()

    return loss, pos_dist.mean(), (D01min.mean() + D10min.mean()).item() / 2