Ejemplo n.º 1
0
def test(model, epoch, writer, xvector_dir):
    """Extract test-set x-vectors with ``model`` and score the verification trials.

    Embeddings are written by ``verification_extract`` to
    ``<xvector_dir>/test/epoch_<epoch>``. The trials in ``args.trials``
    (utterances under ``args.test_dir``) are then scored with cosine distance
    when ``args.cos_sim`` is set, otherwise L2. EER, its threshold and the
    two minDCF values (labelled 0.01 and 0.001) are printed and logged.

    Relies on the module-level ``args``, ``kwargs`` (extra DataLoader options)
    and ``extract_dir`` (dataset to embed) defined elsewhere in the script.

    Args:
        model: network passed to ``verification_extract`` to embed utterances.
        epoch: epoch number; used in the output directory name and as the
            logging step.
        writer: object with an ``add_scalar(tag, value, step)`` method
            (e.g. a TensorBoard ``SummaryWriter``).
        xvector_dir: root directory for extracted x-vectors.
    """
    # Same as valid_test(): embed in evaluation mode so layers such as
    # dropout / batch-norm use inference behaviour regardless of the caller.
    model.eval()

    this_xvector_dir = "%s/test/epoch_%s" % (xvector_dir, epoch)

    # batch_size=1 — presumably because utterances differ in length; confirm.
    extract_loader = torch.utils.data.DataLoader(extract_dir,
                                                 batch_size=1,
                                                 shuffle=False,
                                                 **kwargs)
    verification_extract(extract_loader, model, this_xvector_dir, epoch)

    verify_dir = ScriptVerifyDataset(dir=args.test_dir,
                                     trials_file=args.trials,
                                     xvectors_dir=this_xvector_dir,
                                     loader=read_vec_flt)
    verify_loader = torch.utils.data.DataLoader(verify_dir,
                                                batch_size=128,
                                                shuffle=False,
                                                **kwargs)
    eer, eer_threshold, mindcf_01, mindcf_001 = verification_test(
        test_loader=verify_loader,
        dist_type=('cos' if args.cos_sim else 'l2'),
        log_interval=args.log_interval,
        xvector_dir=this_xvector_dir,
        epoch=epoch)
    print(
        '\33[91mTest  ERR: {:.4f}%, Threshold: {:.4f}, mindcf-0.01: {:.4f}, mindcf-0.001: {:.4f}.\33[0m\n'
        .format(100. * eer, eer_threshold, mindcf_01, mindcf_001))

    writer.add_scalar('Test/EER', 100. * eer, epoch)
    writer.add_scalar('Test/Threshold', eer_threshold, epoch)
    writer.add_scalar('Test/mindcf-0.01', mindcf_01, epoch)
    writer.add_scalar('Test/mindcf-0.001', mindcf_001, epoch)

    # Same as valid_test(): return cached, unused GPU memory to the device.
    torch.cuda.empty_cache()
Ejemplo n.º 2
0
def valid_test(train_extract_loader, model, epoch, xvector_dir):
    """Extract x-vectors for the training-side trial list and score it.

    The model is put in evaluation mode. Embeddings are written by
    ``verification_extract`` to ``<xvector_dir>/train/epoch_<epoch>``. The
    trials in ``args.train_trials`` (utterances under ``args.train_test_dir``)
    are scored with cosine distance when ``args.cos_sim`` is set, otherwise
    L2. EER, its threshold and the two minDCF values (labelled 0.01 and 0.001)
    are printed and logged under the ``Train/`` prefix.

    Relies on the module-level ``args``, ``kwargs`` (extra DataLoader options)
    and ``writer`` (an object with ``add_scalar``) defined elsewhere in the
    script. NOTE(review): unlike ``test()``, ``writer`` is not a parameter.

    Args:
        train_extract_loader: DataLoader over the utterances to embed.
        model: network passed to ``verification_extract``.
        epoch: epoch number; used in the output directory name and as the
            logging step.
        xvector_dir: root directory for extracted x-vectors.
    """
    # switch to evaluate mode
    model.eval()

    this_xvector_dir = "%s/train/epoch_%s" % (xvector_dir, epoch)
    verification_extract(train_extract_loader, model, this_xvector_dir, epoch)

    verify_dir = ScriptVerifyDataset(dir=args.train_test_dir,
                                     trials_file=args.train_trials,
                                     xvectors_dir=this_xvector_dir,
                                     loader=read_vec_flt)
    verify_loader = torch.utils.data.DataLoader(verify_dir,
                                                batch_size=128,
                                                shuffle=False,
                                                **kwargs)
    eer, eer_threshold, mindcf_01, mindcf_001 = verification_test(
        test_loader=verify_loader,
        dist_type=('cos' if args.cos_sim else 'l2'),
        log_interval=args.log_interval,
        xvector_dir=this_xvector_dir,
        epoch=epoch)

    # The trailing \33[0m resets the terminal colour started by \33[91m;
    # without it every later line of output stayed red.
    print('Test  Epoch {}:\n\33[91mTrain EER: {:.4f}%, Threshold: {:.4f}, '
          'mindcf-0.01: {:.4f}, mindcf-0.001: {:.4f}.\33[0m'.format(epoch,
                                                                    100. * eer,
                                                                    eer_threshold,
                                                                    mindcf_01,
                                                                    mindcf_001))

    writer.add_scalar('Train/EER', 100. * eer, epoch)
    writer.add_scalar('Train/Threshold', eer_threshold, epoch)
    writer.add_scalar('Train/mindcf-0.01', mindcf_01, epoch)
    writer.add_scalar('Train/mindcf-0.001', mindcf_001, epoch)

    # Return cached, unused GPU memory to the device after extraction/scoring.
    torch.cuda.empty_cache()
def main():
    """Restore a trained model from a checkpoint, evaluate it and score test trials.

    Everything is configured through the module-level ``args`` namespace. The
    datasets ``train_dir``, ``valid_dir`` and ``verfify_dir``, the DataLoader
    options ``kwargs`` and the helpers ``create_model``, ``valid``, ``extract``
    and ``test`` are expected to be defined elsewhere in the script.

    When ``args.valid`` or ``args.extract`` is set, the model is rebuilt and
    its weights are loaded from ``args.resume``. It is then run on the
    validation set (``args.valid``) and/or used to extract x-vectors into
    ``args.xvector_dir`` (``args.extract``). Afterwards, in every case, the
    trials in ``args.trials`` are scored against the x-vectors found in
    ``args.xvector_dir``.

    Raises:
        FileNotFoundError: if a model is needed and ``args.resume`` is not a file.
    """
    # Print the experiment configuration.
    print('\nCurrent time is \33[91m{}\33[0m.'.format(str(time.asctime())))
    print('Parsed options: {}'.format(vars(args)))

    # Comma-separated CLI strings -> ints. (k - 1) / 2 is the usual "same"
    # padding for odd kernel sizes.
    kernel_size = tuple(int(x) for x in args.kernel_size.split(','))
    padding = tuple(int((x - 1) / 2) for x in kernel_size)
    channels = [int(x) for x in args.channels.split(',')]

    model_kwargs = {'embedding_size': args.embedding_size,
                    'resnet_size': args.resnet_size,
                    'inst_norm': args.inst_norm,
                    'input_dim': args.feat_dim,
                    'fast': args.fast,
                    'num_classes': train_dir.num_spks,
                    'alpha': args.alpha,
                    'channels': channels,
                    'stride': args.stride,
                    'avg_size': args.avg_size,
                    'time_dim': args.time_dim,
                    'encoder_type': args.encoder_type,
                    'kernel_size': kernel_size,
                    'padding': padding,
                    'dropout_p': args.dropout_p}

    print('Model options: {}'.format(model_kwargs))

    if args.valid or args.extract:
        model = create_model(args.model, **model_kwargs)
        # Replace the classifier head according to args.loss_type — presumably
        # so its parameters line up with those stored in the checkpoint.
        if args.loss_type == 'asoft':
            model.classifier = AngleLinear(in_features=args.embedding_size,
                                           out_features=train_dir.num_spks, m=args.m)
        elif args.loss_type == 'amsoft':
            model.classifier = AdditiveMarginLinear(feat_dim=args.embedding_size,
                                                    n_classes=train_dir.num_spks)

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not os.path.isfile(args.resume):
            raise FileNotFoundError('checkpoint not found: {}'.format(args.resume))
        print('=> loading checkpoint {}'.format(args.resume))
        # Load onto the CPU: load_state_dict copies the tensors into the (still
        # CPU-resident) model anyway, and this also works on hosts without the
        # GPU the checkpoint was saved from. The model is moved to GPU below.
        checkpoint = torch.load(args.resume, map_location='cpu')

        # Keys containing 'num_batches_tracked' are dropped before loading.
        filtered = {k: v for k, v in checkpoint['state_dict'].items()
                    if 'num_batches_tracked' not in k}
        model.load_state_dict(filtered)

        # Override the dropout rate when the model exposes a `dropout` module.
        try:
            model.dropout.p = args.dropout_p
        except AttributeError:
            pass  # model has no usable `dropout` attribute; nothing to override

        print('Epoch is : ' + str(args.start_epoch))

        if args.cuda:
            model.cuda()

        if args.valid:
            valid_loader = torch.utils.data.DataLoader(valid_dir, batch_size=args.test_batch_size,
                                                       shuffle=False, **kwargs)
            valid(valid_loader, model)

        # An earlier revision ran `del train_dir` here to free memory. Deleting
        # the name inside main() makes `train_dir` local to the whole function,
        # so the `train_dir.num_spks` lookups above raised UnboundLocalError.
        print('Memery Usage: %.4f GB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024))

        if args.extract:
            verify_loader = torch.utils.data.DataLoader(verfify_dir, batch_size=args.test_batch_size,
                                                        shuffle=False, **kwargs)
            extract(verify_loader, model, args.xvector_dir)

    # Score the trial list against the x-vectors in args.xvector_dir (which may
    # have just been written by extract() above).
    test_dir = ScriptVerifyDataset(dir=args.test_dir, trials_file=args.trials,
                                   xvectors_dir=args.xvector_dir, loader=read_vec_flt)
    test_loader = torch.utils.data.DataLoader(test_dir, batch_size=args.test_batch_size * 64,
                                              shuffle=False, **kwargs)
    # NOTE(review): `test` is called with only a loader here. The
    # `test(model, epoch, writer, xvector_dir)` shown alongside this code takes
    # four arguments — confirm which `test` is in scope in this script.
    test(test_loader)

# python TrainAndTest/Spectrogram/train_surescnn10_kaldi.py > Log/SuResCNN10/spect_161/

# test easy spectrogram soft 161 vox1
#   Test EER is 1.6076%, Threshold is 0.31004807353019714
#   mindcf-0.01 0.2094, mindcf-0.001 0.3767.

# test hard spectrogram soft 161 vox1
#   Test EER is 2.9182%, Threshold is 0.35036733746528625
#   mindcf-0.01 0.3369, mindcf-0.001 0.5494.