Code example #1
0
File: train.py  Project: apardyl/ml-audio-recognition
def evaluate_encoder(encoder: nn.Module,
                     test_loader: torch.utils.data.DataLoader,
                     loss_fn: nn.TripletMarginLoss = None,
                     writer: SummaryWriter = None,
                     epoch: int = -1):
    """Evaluate a triplet encoder on a held-out set via nearest-neighbour retrieval.

    Embeds every (anchor, positive, negative) triplet from ``test_loader``,
    indexes the anchor embeddings, and measures how often each positive's
    true anchor appears among its top-1/10/50/100 nearest neighbours.

    Args:
        encoder: embedding network; must expose ``embedding_dim``.
        test_loader: yields ``(x, y_pos, y_neg)`` tensor triplets.
        loss_fn: optional triplet loss; when given, the mean test loss is
            computed and reported as well.
        writer: optional TensorBoard writer for scalars and embeddings.
        epoch: global step for TensorBoard logging; -1 means a standalone
            evaluation (embeddings are always projected in that case).

    Returns:
        Tuple of (top-1 accuracy in %, top-100 accuracy in %, raw kNN lookup
        result, anchor embeddings, positive embeddings).
    """
    with torch.no_grad():
        encoder.eval()
        searcher = Searcher.get_simple_index(encoder.embedding_dim)
        embeddings_x = []
        embeddings_y = []
        epoch_losses = []
        for step, (x, y_pos, y_neg) in enumerate(test_loader):
            x, y_pos, y_neg = x.cuda(), y_pos.cuda(), y_neg.cuda()
            x_enc = encoder(x)
            y_pos_enc = encoder(y_pos)
            y_neg_enc = encoder(y_neg)
            if loss_fn:
                loss_val = loss_fn(x_enc, y_pos_enc, y_neg_enc)
                epoch_losses.append(loss_val.item())
            embeddings_x.append(x_enc)
            embeddings_y.append(y_pos_enc)
            print('    Test batch {} of {}'.format(step + 1, len(test_loader)),
                  file=sys.stderr)

        embeddings_x = torch.cat(embeddings_x, dim=0)
        embeddings_y = torch.cat(embeddings_y, dim=0)
        searcher.add(embeddings_x)
        lookup = searcher.search(embeddings_y, 100)
        # lookup[1][i] holds the indices of the 100 nearest anchors for
        # positive i; a hit means the true index i is among them.
        total = len(lookup[1])
        correct_100 = sum(y in x for y, x in enumerate(lookup[1])) / total
        correct_50 = sum(y in x[:50] for y, x in enumerate(lookup[1])) / total
        correct_10 = sum(y in x[:10] for y, x in enumerate(lookup[1])) / total
        correct_1 = sum(y == x[0] for y, x in enumerate(lookup[1])) / total
        # BUG FIX: np.mean([]) is NaN and emits a RuntimeWarning; only compute
        # and report the loss when loss_fn actually produced values.
        mean_loss = float(np.mean(epoch_losses)) if epoch_losses else None
        if mean_loss is not None:
            print(f'Test loss: {mean_loss:.4f}')
        print(
            'Test accuracy:\n    top1 {}\n    top10 {}\n    top50 {}\n    top100 {}'
            .format(correct_1, correct_10, correct_50, correct_100))
        if writer:
            writer.add_scalars('Accuracy', {
                'top1': correct_1,
                'top10': correct_10,
                'top50': correct_50,
                'top100': correct_100,
            },
                               global_step=epoch)
            if mean_loss is not None:
                writer.add_scalar('Loss/test', mean_loss, global_step=epoch)
            if epoch == -1 or epoch % 5 == 1:
                # BUG FIX: clamp to the actual number of embeddings so the
                # label list always matches the projected matrix (the original
                # hard-coded 1000 and broke on smaller test sets).
                k = min(1000, embeddings_x.size(0))
                mat = torch.cat([embeddings_x[:k], embeddings_y[:k]], dim=0)
                labels = list(range(k)) + list(range(k))
                writer.add_embedding(mat,
                                     labels,
                                     tag='Embeddings',
                                     global_step=epoch)
        return correct_1 * 100, correct_100 * 100, lookup, embeddings_x, embeddings_y
Code example #2
0
File: test.py  Project: apardyl/ml-audio-recognition
def evaluate_all(small_encoder: SmallEncoder, large_encoder: LargeEncoder, test_loader: torch.utils.data.DataLoader,
                 lookup_samples: int):
    """Evaluate the two-stage (small lookup + large verification) pipeline.

    Stage 1 retrieves the ``lookup_samples`` nearest anchors for every
    positive using the small encoder's embeddings. Stage 2 re-ranks each
    candidate list by L2 distance, either with the large encoder alone or
    with both encoders combined, and counts how often the true anchor wins.

    Args:
        small_encoder: fast encoder used to build the kNN index; must expose
            ``embedding_dim``.
        large_encoder: slower, more accurate encoder used for verification.
        test_loader: yields ``(x_small, y_small, x_large, y_large)`` batches.
        lookup_samples: number of nearest neighbours retrieved per query.

    Side effects:
        Prints lookup, verification, and final accuracies to stdout and
        per-batch progress to stderr.
    """
    with torch.no_grad():
        small_encoder.eval()
        large_encoder.eval()
        searcher = Searcher.get_simple_index(small_encoder.embedding_dim)
        s_embeddings_x = []
        s_embeddings_y = []
        l_embeddings_x = []
        l_embeddings_y = []
        print('Calculating embeddings')
        for step, (x_s, y_s, x_l, y_l) in enumerate(test_loader):
            s_embeddings_x.append(small_encoder(x_s.cuda()))
            s_embeddings_y.append(small_encoder(y_s.cuda()))
            l_embeddings_x.append(large_encoder(x_l.cuda()))
            l_embeddings_y.append(large_encoder(y_l.cuda()))
            print('    Test batch {} of {}'.format(step + 1, len(test_loader)), file=sys.stderr)
        print('Merging results')
        s_embeddings_x = torch.cat(s_embeddings_x, dim=0).cpu()
        s_embeddings_y = torch.cat(s_embeddings_y, dim=0).cpu()
        l_embeddings_x = torch.cat(l_embeddings_x, dim=0).cpu()
        l_embeddings_y = torch.cat(l_embeddings_y, dim=0).cpu()
        print('Running kNN')
        searcher.add(s_embeddings_x)
        lookup = searcher.search(s_embeddings_y, lookup_samples)
        total = len(lookup[1])
        # correct_100: true anchor appears anywhere in the candidate list;
        # correct_1: true anchor is already the single nearest neighbour.
        correct_100 = sum(y in x for y, x in enumerate(lookup[1]))
        correct_1 = sum(y == x[0] for y, x in enumerate(lookup[1]))
        print('Running verification')
        verified_1 = 0
        verified_1l = 0
        s_embeddings_x = s_embeddings_x.numpy()
        s_embeddings_y = s_embeddings_y.numpy()
        l_embeddings_x = l_embeddings_x.numpy()
        l_embeddings_y = l_embeddings_y.numpy()
        for idx, (knn, y_s, y_l) in enumerate(zip(lookup[1], s_embeddings_y, l_embeddings_y)):
            # Dual-encoder score: sum of mean squared distances in both spaces.
            dists = [((((s_embeddings_x[v] - y_s) ** 2).mean() + ((l_embeddings_x[v] - y_l) ** 2).mean()), v)
                     for v in knn]
            best = min(dists, key=itemgetter(0))[1]
            # Single-encoder score: large-encoder distance only.
            dists_l = [(((l_embeddings_x[v] - y_l) ** 2).mean(), v) for v in knn]
            best_l = min(dists_l, key=itemgetter(0))[1]
            if best == idx:
                verified_1 += 1
            if best_l == idx:
                verified_1l += 1
        print('Lookup accuracy: {}, correct guess: {}'.format(correct_100 / total, correct_1 / total))
        # BUG FIX: guard against division by zero when no lookup succeeded
        # (correct_100 == 0); report NaN instead of crashing.
        acc_single = verified_1l / correct_100 if correct_100 else float('nan')
        acc_dual = verified_1 / correct_100 if correct_100 else float('nan')
        print('Verification accuracy: (single encoder) {}, (dual encoder) {}'.format(acc_single,
                                                                                     acc_dual))
        print('Final accuracy: (single encoder) {}, (dual encoder) {}'.format(verified_1l / total,
                                                                              verified_1 / total))
Code example #3
0
    parser.add_argument('--index',
                        type=str,
                        help='Track lookup index save location',
                        required=True)
    args = parser.parse_args()

    with contextlib.suppress(FileNotFoundError):
        os.remove(args.database)
    with contextlib.suppress(FileNotFoundError):
        os.remove(args.index)

    database = sqlite3.connect(args.database)
    database.execute(
        """create table samples (id integer primary key, name text, offset integer, s_hash blob, l_hash blob);"""
    )
    searcher = Searcher.get_simple_index(SmallEncoder.embedding_dim)

    data = AudioIndexingDataset(args.data)
    data_loader = torch.utils.data.DataLoader(data,
                                              batch_size=None,
                                              num_workers=4,
                                              prefetch_factor=2)
    small_encoder = SmallEncoder().cuda()
    load_model_state(args.small_encoder, small_encoder)
    large_encoder = LargeEncoder().cuda()
    load_model_state(args.large_encoder, large_encoder)
    large_encoder.eval()
    small_encoder.eval()
    counter = 0

    embeddings = []