def evaluate_encoder(encoder: nn.Module,
                     test_loader: torch.utils.data.DataLoader,
                     loss_fn: nn.TripletMarginLoss = None,
                     writer: SummaryWriter = None,
                     epoch: int = -1):
    """Evaluate ``encoder`` on triplet test data with kNN retrieval metrics.

    For every batch ``(x, y_pos, y_neg)`` the anchor and positive embeddings
    are collected; anchors are indexed with ``Searcher`` and positives are
    looked up against them, so query ``i``'s correct answer is index ``i``.

    Args:
        encoder: embedding network; must expose ``embedding_dim``.
        test_loader: yields ``(x, y_pos, y_neg)`` triplets (moved to CUDA here).
        loss_fn: optional triplet loss; when ``None`` no loss is computed.
        writer: optional TensorBoard writer for accuracy/loss/embeddings.
        epoch: global step for TensorBoard; ``-1`` means "one-off evaluation".

    Returns:
        ``(top1_pct, top100_pct, lookup, embeddings_x, embeddings_y)`` where
        ``lookup`` is the raw ``(distances, indices)`` search result.
    """
    with torch.no_grad():
        encoder.eval()
        searcher = Searcher.get_simple_index(encoder.embedding_dim)
        embeddings_x = []
        embeddings_y = []
        epoch_losses = []
        for step, (x, y_pos, y_neg) in enumerate(test_loader):
            x, y_pos, y_neg = x.cuda(), y_pos.cuda(), y_neg.cuda()
            x_enc = encoder(x)
            y_pos_enc = encoder(y_pos)
            y_neg_enc = encoder(y_neg)
            if loss_fn:
                loss_val = loss_fn(x_enc, y_pos_enc, y_neg_enc)
                epoch_losses.append(loss_val.item())
            embeddings_x.append(x_enc)
            embeddings_y.append(y_pos_enc)
            print(' Test batch {} of {}'.format(step + 1, len(test_loader)), file=sys.stderr)
        embeddings_x = torch.cat(embeddings_x, dim=0)
        embeddings_y = torch.cat(embeddings_y, dim=0)
        searcher.add(embeddings_x)
        lookup = searcher.search(embeddings_y, 100)
        # lookup[1][i] is the list of the 100 nearest anchor indices for query i;
        # the correct anchor for query i is i itself.
        total = len(lookup[1])
        correct_100 = sum(y in x for y, x in enumerate(lookup[1])) / total
        correct_50 = sum(y in x[:50] for y, x in enumerate(lookup[1])) / total
        correct_10 = sum(y in x[:10] for y, x in enumerate(lookup[1])) / total
        correct_1 = sum(y == x[0] for y, x in enumerate(lookup[1])) / total
        # BUGFIX: only report the loss when it was actually computed;
        # np.mean([]) on the loss_fn=None path produced a spurious nan
        # (plus a RuntimeWarning) and logged it to TensorBoard.
        if epoch_losses:
            print(f'Test loss: {np.mean(epoch_losses):.4f}')
        print(
            'Test accuracy:\n top1 {}\n top10 {}\n top50 {}\n top100 {}'
            .format(correct_1, correct_10, correct_50, correct_100))
        if writer:
            writer.add_scalars('Accuracy', {
                'top1': correct_1,
                'top10': correct_10,
                'top50': correct_50,
                'top100': correct_100,
            }, global_step=epoch)
            if epoch_losses:
                writer.add_scalar('Loss/test', np.mean(epoch_losses), global_step=epoch)
            # Dump an embedding projection occasionally (and for one-off runs).
            if epoch == -1 or epoch % 5 == 1:
                mat = torch.cat([embeddings_x[:1000], embeddings_y[:1000]], dim=0)
                # Matching labels mark anchor/positive pairs in the projector.
                labels = list(range(1000)) + list(range(1000))
                writer.add_embedding(mat, labels, tag='Embeddings', global_step=epoch)
        return correct_1 * 100, correct_100 * 100, lookup, embeddings_x, embeddings_y
def evaluate_all(small_encoder: SmallEncoder,
                 large_encoder: LargeEncoder,
                 test_loader: torch.utils.data.DataLoader,
                 lookup_samples: int):
    """Two-stage evaluation: small-encoder kNN lookup + re-ranking verification.

    Stage 1 indexes small-encoder anchor embeddings and retrieves the top
    ``lookup_samples`` candidates per query. Stage 2 re-ranks each candidate
    list by L2 distance, either with the large encoder alone or with the
    summed small+large distances, and counts how often the true match
    (index ``idx``) wins.

    Args:
        small_encoder: fast encoder used for the initial kNN lookup.
        large_encoder: slower encoder used for verification/re-ranking.
        test_loader: yields ``(x_small, y_small, x_large, y_large)`` batches.
        lookup_samples: number of nearest neighbours to retrieve per query.

    Prints lookup, verification, and final accuracy; returns nothing.
    """
    with torch.no_grad():
        small_encoder.eval()
        large_encoder.eval()
        searcher = Searcher.get_simple_index(small_encoder.embedding_dim)
        s_embeddings_x = []
        s_embeddings_y = []
        l_embeddings_x = []
        l_embeddings_y = []
        print('Calculating embeddings')
        for step, (x_s, y_s, x_l, y_l) in enumerate(test_loader):
            s_embeddings_x.append(small_encoder(x_s.cuda()))
            s_embeddings_y.append(small_encoder(y_s.cuda()))
            l_embeddings_x.append(large_encoder(x_l.cuda()))
            l_embeddings_y.append(large_encoder(y_l.cuda()))
            print(' Test batch {} of {}'.format(step + 1, len(test_loader)), file=sys.stderr)
        print('Merging results')
        s_embeddings_x = torch.cat(s_embeddings_x, dim=0).cpu()
        s_embeddings_y = torch.cat(s_embeddings_y, dim=0).cpu()
        l_embeddings_x = torch.cat(l_embeddings_x, dim=0).cpu()
        l_embeddings_y = torch.cat(l_embeddings_y, dim=0).cpu()
        print('Running kNN')
        searcher.add(s_embeddings_x)
        lookup = searcher.search(s_embeddings_y, lookup_samples)
        # correct_100: queries whose true match appears anywhere in the
        # candidate list (an upper bound for verification); correct_1: exact
        # top-1 hits from the small encoder alone.
        correct_100 = sum(y in x for y, x in enumerate(lookup[1]))
        correct_1 = sum(y == x[0] for y, x in enumerate(lookup[1]))
        print('Running verification')
        verified_1 = 0
        verified_1l = 0
        s_embeddings_x = s_embeddings_x.numpy()
        s_embeddings_y = s_embeddings_y.numpy()
        l_embeddings_x = l_embeddings_x.numpy()
        l_embeddings_y = l_embeddings_y.numpy()
        for idx, (knn, y_s, y_l) in enumerate(zip(lookup[1], s_embeddings_y, l_embeddings_y)):
            # Dual-encoder score: sum of small- and large-embedding MSE.
            dists = [((((s_embeddings_x[v] - y_s) ** 2).mean()
                       + ((l_embeddings_x[v] - y_l) ** 2).mean()), v)
                     for v in knn]
            best = min(dists, key=itemgetter(0))[1]
            # Single (large) encoder score only.
            dists_l = [(((l_embeddings_x[v] - y_l) ** 2).mean(), v) for v in knn]
            best_l = min(dists_l, key=itemgetter(0))[1]
            if best == idx:
                verified_1 += 1
            if best_l == idx:
                verified_1l += 1
        print('Lookup accuracy: {}, correct guess: {}'.format(
            correct_100 / len(lookup[1]), correct_1 / len(lookup[1])))
        # BUGFIX: guard against correct_100 == 0 (no lookup hits at all),
        # which previously raised ZeroDivisionError; also repaired the
        # string literal that was broken across a line break.
        if correct_100:
            print('Verification accuracy: (single encoder) {}, (dual encoder) {}'.format(
                verified_1l / correct_100, verified_1 / correct_100))
        else:
            print('Verification accuracy: n/a (no lookup hits)')
        print('Final accuracy: (single encoder) {}, (dual encoder) {}'.format(
            verified_1l / len(lookup[1]), verified_1 / len(lookup[1])))
# --- Indexing script setup (fragment) ---------------------------------------
# NOTE(review): this chunk is incomplete on both ends — `parser` and the other
# --… arguments are defined before the visible region, and `counter` /
# `embeddings` feed a loop past the end of it. Tokens left unchanged.
parser.add_argument('--index', type=str, help='Track lookup index save location', required=True)
args = parser.parse_args()
# Start from a clean slate: remove any previous database/index files.
with contextlib.suppress(FileNotFoundError):
    os.remove(args.database)
with contextlib.suppress(FileNotFoundError):
    os.remove(args.index)
# Fresh SQLite metadata store: one row per indexed sample, holding the track
# name, sample offset, and both encoders' embeddings as blobs.
database = sqlite3.connect(args.database)
database.execute(
    """create table samples (id integer primary key, name text, offset integer, s_hash blob, l_hash blob);"""
)
searcher = Searcher.get_simple_index(SmallEncoder.embedding_dim)
data = AudioIndexingDataset(args.data)
# batch_size=None: the dataset yields pre-batched items; workers prefetch.
data_loader = torch.utils.data.DataLoader(data, batch_size=None, num_workers=4, prefetch_factor=2)
# Load both encoders onto the GPU and switch to inference mode.
small_encoder = SmallEncoder().cuda()
load_model_state(args.small_encoder, small_encoder)
large_encoder = LargeEncoder().cuda()
load_model_state(args.large_encoder, large_encoder)
large_encoder.eval()
small_encoder.eval()
# Accumulators for the indexing loop that follows (outside this view).
counter = 0
embeddings = []