Example #1
def test(model, epoch, writer, xvector_dir):
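    # Extract x-vectors for the test set at this epoch, then score the trial list.
    # extract_dir, args and kwargs are presumably defined at module scope.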
    this_xvector_dir = "%s/test/epoch_%s" % (xvector_dir, epoch)

    extract_loader = torch.utils.data.DataLoader(extract_dir,
                                                 batch_size=1,
                                                 shuffle=False,
                                                 **kwargs)
    verification_extract(extract_loader, model, this_xvector_dir, epoch)

    verify_dir = ScriptVerifyDataset(dir=args.test_dir,
                                     trials_file=args.trials,
                                     xvectors_dir=this_xvector_dir,
                                     loader=read_vec_flt)
    verify_loader = torch.utils.data.DataLoader(verify_dir,
                                                batch_size=128,
                                                shuffle=False,
                                                **kwargs)
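    # Score every trial pair and report EER plus minDCF at p-target 0.01 and 0.001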
    eer, eer_threshold, mindcf_01, mindcf_001 = verification_test(
        test_loader=verify_loader,
        dist_type=('cos' if args.cos_sim else 'l2'),
        log_interval=args.log_interval,
        xvector_dir=this_xvector_dir,
        epoch=epoch)
    print(
        '\33[91mTest  EER: {:.4f}%, Threshold: {:.4f}, mindcf-0.01: {:.4f}, mindcf-0.001: {:.4f}.\33[0m\n'
        .format(100. * eer, eer_threshold, mindcf_01, mindcf_001))

    writer.add_scalar('Test/EER', 100. * eer, epoch)
    writer.add_scalar('Test/Threshold', eer_threshold, epoch)
    writer.add_scalar('Test/mindcf-0.01', mindcf_01, epoch)
    writer.add_scalar('Test/mindcf-0.001', mindcf_001, epoch)
Example #2
def valid_test(train_extract_loader, model, epoch, xvector_dir):
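    # Evaluate verification performance on trials drawn from the training set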
    # switch to evaluate mode
    model.eval()

    this_xvector_dir = "%s/train/epoch_%s" % (xvector_dir, epoch)
    verification_extract(train_extract_loader, model, this_xvector_dir, epoch)

    verify_dir = ScriptVerifyDataset(dir=args.train_test_dir,
                                     trials_file=args.train_trials,
                                     xvectors_dir=this_xvector_dir,
                                     loader=read_vec_flt)
    verify_loader = torch.utils.data.DataLoader(verify_dir,
                                                batch_size=128,
                                                shuffle=False,
                                                **kwargs)
    eer, eer_threshold, mindcf_01, mindcf_001 = verification_test(
        test_loader=verify_loader,
        dist_type=('cos' if args.cos_sim else 'l2'),
        log_interval=args.log_interval,
        xvector_dir=this_xvector_dir,
        epoch=epoch)

    print('Test  Epoch {}:\n\33[91mTrain EER: {:.4f}%, Threshold: {:.4f}, ' \
          'mindcf-0.01: {:.4f}, mindcf-0.001: {:.4f}.'.format(epoch,
                                                              100. * eer,
                                                              eer_threshold,
                                                              mindcf_01,
                                                              mindcf_001))

    writer.add_scalar('Train/EER', 100. * eer, epoch)
    writer.add_scalar('Train/Threshold', eer_threshold, epoch)
    writer.add_scalar('Train/mindcf-0.01', mindcf_01, epoch)
    writer.add_scalar('Train/mindcf-0.001', mindcf_001, epoch)

    torch.cuda.empty_cache()
def valid_test(train_extract_loader, valid_loader, model, epoch, xvector_dir):
    # switch to evaluate mode
    model.eval()

    valid_loader_a, valid_loader_b = valid_loader
    valid_pbar = tqdm(enumerate(zip(valid_loader_a, valid_loader_b)))
    correct_a = 0.
    correct_b = 0.

    total_datasize_a = 0.
    total_datasize_b = 0.

    softmax = nn.Softmax(dim=1)
    with torch.no_grad():
        for batch_idx, ((data_a, label_a), (data_b, label_b)) in valid_pbar:

            label_a = label_a.cuda()
            label_b = label_b.cuda()

            # compute output
            data = torch.cat((data_a, data_b), dim=0)
            data = data.cuda()

            _, feats = model(data)
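            # Split the shared forward pass back into the A-set and B-set halves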
            classifier_a, classifier_b = model.cls_forward(feats[:len(data_a)], feats[len(data_a):])

            # pdb.set_trace()
            predicted_labels = softmax(classifier_a)
            predicted_one_labels = torch.max(predicted_labels, dim=1)[1]
            minibatch_correct = float((predicted_one_labels.cuda() == label_a).sum().item())
            minibatch_a = minibatch_correct / len(predicted_one_labels)
            correct_a += minibatch_correct
            total_datasize_a += len(predicted_one_labels)

            predicted_labels = softmax(classifier_b)
            predicted_one_labels = torch.max(predicted_labels, dim=1)[1]
            minibatch_correct = float((predicted_one_labels.cuda() == label_b).sum().item())
            minibatch_b = minibatch_correct / len(predicted_one_labels)
            correct_b += minibatch_correct
            total_datasize_b += len(predicted_one_labels)

            if batch_idx % args.log_interval == 0:
                valid_pbar.set_description(
                    'Valid Epoch: {:2d} ({:4d} samples) Batch Accuracy: A set: {:.4f}%, B set: {:.4f}%'.format(
                        epoch,
                        len(valid_loader_a.dataset),
                        100. * minibatch_a,
                        100. * minibatch_b))
                # break

    valid_accuracy_a = 100. * correct_a / total_datasize_a
    valid_accuracy_b = 100. * correct_b / total_datasize_b
    writer.add_scalar('Train/Valid_Accuracy_A', valid_accuracy_a, epoch)
    writer.add_scalar('Train/Valid_Accuracy_B', valid_accuracy_b, epoch)

    torch.cuda.empty_cache()

    this_xvector_dir = "%s/train/epoch_%s" % (xvector_dir, epoch)
    verification_extract(train_extract_loader, model, this_xvector_dir, epoch)

    verify_dir = ScriptVerifyDataset(dir=args.train_test_dir, trials_file=args.train_trials,
                                     xvectors_dir=this_xvector_dir,
                                     loader=read_vec_flt)
    verify_loader = torch.utils.data.DataLoader(verify_dir, batch_size=128, shuffle=False, **kwargs)
    eer, eer_threshold, mindcf_01, mindcf_001 = verification_test(test_loader=verify_loader,
                                                                  dist_type=('cos' if args.cos_sim else 'l2'),
                                                                  log_interval=args.log_interval,
                                                                  xvector_dir=this_xvector_dir,
                                                                  epoch=epoch)

    print('Test  Epoch {}:\n\33[91mTrain EER: {:.4f}%, Threshold: {:.4f}, ' \
          'mindcf-0.01: {:.4f}, mindcf-0.001: {:.4f}.'.format(epoch,
                                                              100. * eer,
                                                              eer_threshold,
                                                              mindcf_01,
                                                              mindcf_001))

    print('Valid on A Accuracy: %.4f %%. Valid on B Accuracy: %.4f %%.\33[0m' % (
        valid_accuracy_a, valid_accuracy_b))

    writer.add_scalar('Train/EER', 100. * eer, epoch)
    writer.add_scalar('Train/Threshold', eer_threshold, epoch)
    writer.add_scalar('Train/mindcf-0.01', mindcf_01, epoch)
    writer.add_scalar('Train/mindcf-0.001', mindcf_001, epoch)

    torch.cuda.empty_cache()
def main():
    # Views the training images and displays the distance on anchor-negative and anchor-positive
    # test_display_triplet_distance = False
    # print the experiment configuration
    print('\nCurrent time is \33[91m{}\33[0m.'.format(str(time.asctime())))
    print('Parsed options: {}'.format(vars(args)))
    print('Number of Speakers: {}.\n'.format(train_dir.num_spks))

    # instantiate model and initialize weights
    kernel_size = args.kernel_size.split(',')
    kernel_size = [int(x) for x in kernel_size]
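    # (k - 1) // 2 yields 'same'-style padding for odd kernel sizes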
    padding = [int((x - 1) / 2) for x in kernel_size]

    kernel_size = tuple(kernel_size)
    padding = tuple(padding)
    stride = args.stride.split(',')
    stride = [int(x) for x in stride]

    channels = args.channels.split(',')
    channels = [int(x) for x in channels]

    model_kwargs = {
        'input_dim': args.input_dim,
        'feat_dim': args.feat_dim,
        'kernel_size': kernel_size,
        'filter': args.filter,
        'inst_norm': args.inst_norm,
        'stride': stride,
        'fast': args.fast,
        'avg_size': args.avg_size,
        'time_dim': args.time_dim,
        'padding': padding,
        'encoder_type': args.encoder_type,
        'vad': args.vad,
        'embedding_size': args.embedding_size,
        'ince': args.inception,
        'resnet_size': args.resnet_size,
        'num_classes': train_dir.num_spks,
        'channels': channels,
        'alpha': args.alpha,
        'dropout_p': args.dropout_p
    }

    print('Model options: {}'.format(model_kwargs))
    model = create_model(args.model, **model_kwargs)

    start_epoch = 0
    if args.save_init and not args.finetune:
        check_path = '{}/checkpoint_{}.pth'.format(args.check_path,
                                                   start_epoch)
        torch.save(model, check_path)

    if args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint {}'.format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']

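            # Drop BatchNorm's num_batches_tracked buffers so checkpoints saved
            # across PyTorch versions load cleanly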
            filtered = {
                k: v
                for k, v in checkpoint['state_dict'].items()
                if 'num_batches_tracked' not in k
            }
            model_dict = model.state_dict()
            model_dict.update(filtered)
            model.load_state_dict(model_dict)
            #
            # model.dropout.p = args.dropout_p
        else:
            print('=> no checkpoint found at {}'.format(args.resume))

    ce_criterion = nn.CrossEntropyLoss()
    if args.loss_type == 'soft':
        xe_criterion = None
    elif args.loss_type == 'asoft':
        ce_criterion = None
        model.classifier = AngleLinear(in_features=args.embedding_size,
                                       out_features=train_dir.num_spks,
                                       m=args.m)
        xe_criterion = AngleSoftmaxLoss(lambda_min=args.lambda_min,
                                        lambda_max=args.lambda_max)
    elif args.loss_type == 'center':
        xe_criterion = CenterLoss(num_classes=train_dir.num_spks,
                                  feat_dim=args.embedding_size)
    elif args.loss_type == 'amsoft':
        ce_criterion = None
        model.classifier = AdditiveMarginLinear(feat_dim=args.embedding_size,
                                                n_classes=train_dir.num_spks)
        xe_criterion = AMSoftmaxLoss(margin=args.margin, s=args.s)

    optimizer = create_optimizer(model.parameters(), args.optimizer,
                                 **opt_kwargs)
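    # Center loss keeps learnable class centers; update them with a 5x learning rate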
    if args.loss_type == 'center':
        optimizer = torch.optim.SGD([{
            'params': xe_criterion.parameters(),
            'lr': args.lr * 5
        }, {
            'params': model.parameters()
        }],
                                    lr=args.lr,
                                    weight_decay=args.weight_decay,
                                    momentum=args.momentum)
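    # When fine-tuning a margin-based head, give the classifier a 10x learning rate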
    if args.finetune:
        if args.loss_type == 'asoft' or args.loss_type == 'amsoft':
            classifier_params = list(map(id, model.classifier.parameters()))
            rest_params = filter(lambda p: id(p) not in classifier_params,
                                 model.parameters())
            optimizer = torch.optim.SGD(
                [{
                    'params': model.classifier.parameters(),
                    'lr': args.lr * 10
                }, {
                    'params': rest_params
                }],
                lr=args.lr,
                weight_decay=args.weight_decay,
                momentum=args.momentum)
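    # A learnable filter-bank front end is updated far more slowly (0.05x lr)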
    if args.filter:
        filter_params = list(map(id, model.filter_layer.parameters()))
        rest_params = filter(lambda p: id(p) not in filter_params,
                             model.parameters())
        optimizer = torch.optim.SGD([{
            'params': model.filter_layer.parameters(),
            'lr': args.lr * 0.05
        }, {
            'params': rest_params
        }],
                                    lr=args.lr,
                                    weight_decay=args.weight_decay,
                                    momentum=args.momentum)

    if args.scheduler == 'exp':
        scheduler = ExponentialLR(optimizer, gamma=args.gamma)
    else:
        milestones = args.milestones.split(',')
        milestones = [int(x) for x in milestones]
        milestones.sort()
        scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    ce = [ce_criterion, xe_criterion]

    start = args.start_epoch + start_epoch
    print('Start epoch is : ' + str(start))
    # start = 0
    end = start + args.epochs

    train_loader = torch.utils.data.DataLoader(train_dir,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               **kwargs)
    valid_loader = torch.utils.data.DataLoader(valid_dir,
                                               batch_size=args.batch_size // 2,
                                               shuffle=False,
                                               **kwargs)
    test_loader = torch.utils.data.DataLoader(test_dir,
                                              batch_size=args.test_batch_size,
                                              shuffle=False,
                                              **kwargs)
    # sitw_test_loader = torch.utils.data.DataLoader(sitw_test_dir, batch_size=args.test_batch_size,
    #                                                shuffle=False, **kwargs)
    # sitw_dev_loader = torch.utils.data.DataLoader(sitw_dev_part, batch_size=args.test_batch_size, shuffle=False,
    #                                               **kwargs)

    if args.cuda:
        model = model.cuda()
        for i in range(len(ce)):
            if ce[i] is not None:
                ce[i] = ce[i].cuda()
        try:
            print('Dropout is {}.'.format(model.dropout_p))
        except AttributeError:
            pass

    # for epoch in range(start, end):
    #     # pdb.set_trace()
    #     print('\n\33[1;34m Current \'{}\' learning rate is '.format(args.optimizer), end='')
    #     for param_group in optimizer.param_groups:
    #         print('{:.5f} '.format(param_group['lr']), end='')
    #     print(' \33[0m')
    #
    #     train(train_loader, model, ce, optimizer, epoch)
    #     if epoch % 4 == 1 or epoch == (end - 1):
    #         check_path = '{}/checkpoint_{}.pth'.format(args.check_path, epoch)
    #         torch.save({'epoch': epoch,
    #                     'state_dict': model.state_dict(),
    #                     'criterion': ce},
    #                    check_path)
    #
    #     if epoch % 2 == 1 and epoch != (end - 1):
    #         test(test_loader, valid_loader, model, epoch)
    #     # sitw_test(sitw_test_loader, model, epoch)
    #     # sitw_test(sitw_dev_loader, model, epoch)
    #     scheduler.step()

    # exit(1)
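    # Keep extracted x-vectors alongside the checkpoints, under an xvector directory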
    xvector_dir = args.check_path
    xvector_dir = xvector_dir.replace('checkpoint', 'xvector')

    if args.extract:
        extract_dir = KaldiExtractDataset(dir=args.test_dir,
                                          transform=transform_V,
                                          filer_loader=file_loader)
        extract_loader = torch.utils.data.DataLoader(extract_dir,
                                                     batch_size=1,
                                                     shuffle=False,
                                                     **kwargs)
        verification_extract(extract_loader, model, xvector_dir)

    verify_dir = ScriptVerifyDataset(dir=args.test_dir,
                                     trials_file=args.trials,
                                     xvectors_dir=xvector_dir,
                                     loader=read_vec_flt)
    verify_loader = torch.utils.data.DataLoader(verify_dir,
                                                batch_size=64,
                                                shuffle=False,
                                                **kwargs)
    verification_test(test_loader=verify_loader,
                      dist_type=('cos' if args.cos_sim else 'l2'),
                      log_interval=args.log_interval,
                      save=args.save_score)

    writer.close()
def main():
    # Views the training images and displays the distance on anchor-negative and anchor-positive

    # print the experiment configuration
    print('\nCurrent time is\33[91m {}\33[0m.'.format(str(time.asctime())))
    opts = vars(args)
    keys = list(opts.keys())
    keys.sort()

    options = []
    for k in keys:
        options.append("\'%s\': \'%s\'" % (str(k), str(opts[k])))

    print('Parsed options: \n{ %s }' % (', '.join(options)))
    print('Number of Speakers in training set: {}\n'.format(
        train_config_dir.num_spks))

    # instantiate model and initialize weights
    kernel_size = args.kernel_size.split(',')
    kernel_size = [int(x) for x in kernel_size]

    context = args.context.split(',')
    context = [int(x) for x in context]
    if args.padding == '':
        padding = [int((x - 1) / 2) for x in kernel_size]
    else:
        padding = args.padding.split(',')
        padding = [int(x) for x in padding]

    kernel_size = tuple(kernel_size)
    padding = tuple(padding)
    stride = args.stride.split(',')
    stride = [int(x) for x in stride]

    channels = args.channels.split(',')
    channels = [int(x) for x in channels]

    model_kwargs = {
        'input_dim': args.input_dim,
        'feat_dim': args.feat_dim,
        'kernel_size': kernel_size,
        'context': context,
        'filter_fix': args.filter_fix,
        'mask': args.mask_layer,
        'mask_len': args.mask_len,
        'block_type': args.block_type,
        'filter': args.filter,
        'exp': args.exp,
        'inst_norm': args.inst_norm,
        'input_norm': args.input_norm,
        'stride': stride,
        'fast': args.fast,
        'avg_size': args.avg_size,
        'time_dim': args.time_dim,
        'padding': padding,
        'encoder_type': args.encoder_type,
        'vad': args.vad,
        'transform': args.transform,
        'embedding_size': args.embedding_size,
        'ince': args.inception,
        'resnet_size': args.resnet_size,
        'num_classes': train_config_dir.num_spks,
        'channels': channels,
        'alpha': args.alpha,
        'dropout_p': args.dropout_p,
        'loss_type': args.loss_type,
        'm': args.m,
        'margin': args.margin,
        's': args.s,
        'all_iteraion': args.all_iteraion
    }

    print('Model options: {}'.format(model_kwargs))
    model = create_model(args.model, **model_kwargs)

    # optionally resume from a checkpoint
    # resume = args.ckp_dir + '/checkpoint_{}.pth'.format(args.epoch)
    assert os.path.isfile(args.resume), \
        '=> no checkpoint found at {}'.format(args.resume)

    print('=> loading checkpoint {}'.format(args.resume))
    checkpoint = torch.load(args.resume)
    epoch = checkpoint['epoch']

    checkpoint_state_dict = checkpoint['state_dict']
    if isinstance(checkpoint_state_dict, tuple):
        checkpoint_state_dict = checkpoint_state_dict[0]
    filtered = {
        k: v
        for k, v in checkpoint_state_dict.items()
        if 'num_batches_tracked' not in k
    }

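    # Checkpoints saved from nn.DataParallel prefix every key with 'module.'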
    if list(filtered.keys())[0].startswith('module'):
        new_state_dict = OrderedDict()
        for k, v in filtered.items():
            name = k[7:]  # strip the leading 'module.' prefix (7 characters)
            new_state_dict[name] = v  # keep the value under its stripped key

        model.load_state_dict(new_state_dict)
    else:
        model_dict = model.state_dict()
        model_dict.update(filtered)
        model.load_state_dict(model_dict)
    # model.dropout.p = args.dropout_p

    if args.cuda:
        model.cuda()

    extracted_set = []

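    # args.xvector apparently selects which embedding layer is extracted; each
    # variant is written to its own xvectors_a / xvectors_b subdirectory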
    vec_type = 'xvectors_a' if args.xvector else 'xvectors_b'
    if args.train_dir != '':
        train_dir = KaldiExtractDataset(dir=args.train_dir,
                                        filer_loader=file_loader,
                                        transform=transform_V,
                                        extract_trials=False)
        train_loader = torch.utils.data.DataLoader(train_dir,
                                                   batch_size=args.batch_size,
                                                   shuffle=False,
                                                   **kwargs)
        # Extract Train set vectors
        # extract(train_loader, model, dataset='train', extract_path=args.extract_path + '/x_vector')
        train_xvector_dir = args.xvector_dir + '/%s/epoch_%d/train' % (
            vec_type, epoch)
        verification_extract(train_loader,
                             model,
                             train_xvector_dir,
                             epoch=epoch,
                             test_input=args.test_input,
                             verbose=True,
                             xvector=args.xvector)
        # copy wav.scp and utt2spk ...
        extracted_set.append('train')

    assert args.test_dir != ''
    test_dir = KaldiExtractDataset(dir=args.test_dir,
                                   filer_loader=file_loader,
                                   transform=transform_V,
                                   extract_trials=False)
    test_loader = torch.utils.data.DataLoader(test_dir,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              **kwargs)

    # Extract test set vectors
    test_xvector_dir = args.xvector_dir + '/%s/epoch_%d/test' % (vec_type,
                                                                 epoch)
    # extract(test_loader, model, set_id='test', extract_path=args.extract_path + '/x_vector')
    verification_extract(test_loader,
                         model,
                         test_xvector_dir,
                         epoch=epoch,
                         test_input=args.test_input,
                         verbose=True,
                         xvector=args.xvector)
    # copy wav.scp and utt2spk ...
    extracted_set.append('test')

    if len(extracted_set) > 0:
        print('Extract x-vector completed for %s in %s!\n' %
              (','.join(extracted_set), args.xvector_dir + '/%s' % vec_type))
def main():
    # Views the training images and displays the distance on anchor-negative and anchor-positive
    # print the experiment configuration
    print('\nCurrent time is \33[91m{}\33[0m'.format(str(time.asctime())))
    print('Parsed options: {}'.format(vars(args)))
    print('Number of Classes: {}\n'.format(train_dir.num_spks))

    # instantiate model and initialize weights
    kernel_size = args.kernel_size.split(',')
    kernel_size = [int(x) for x in kernel_size]
    padding = [int((x - 1) / 2) for x in kernel_size]

    kernel_size = tuple(kernel_size)
    padding = tuple(padding)

    model_kwargs = {'input_dim': args.feat_dim,
                    'kernel_size': kernel_size,
                    'stride': args.stride,
                    'avg_size': args.avg_size,
                    'time_dim': args.time_dim,
                    'padding': padding,
                    'resnet_size': args.resnet_size,
                    'embedding_size': args.embedding_size,
                    'num_classes': len(train_dir.speakers),
                    'dropout_p': args.dropout_p}

    print('Model options: {}'.format(model_kwargs))

    model = create_model(args.model, **model_kwargs)

    start = 1
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint {}'.format(args.resume))
            checkpoint = torch.load(args.resume)
            start = checkpoint['epoch']
            filtered = {k: v for k, v in checkpoint['state_dict'].items() if 'num_batches_tracked' not in k}
            model.load_state_dict(filtered)
            # optimizer.load_state_dict(checkpoint['optimizer'])
            # scheduler.load_state_dict(checkpoint['scheduler'])
            # criterion.load_state_dict(checkpoint['criterion'])
        else:
            print('=> no checkpoint found at {}'.format(args.resume))

    ce_criterion = nn.CrossEntropyLoss()
    if args.loss_type == 'soft':
        xe_criterion = None
    elif args.loss_type == 'asoft':
        ce_criterion = None
        model.classifier = AngleLinear(in_features=args.embedding_size, out_features=train_dir.num_spks, m=args.m)
        xe_criterion = AngleSoftmaxLoss(lambda_min=args.lambda_min, lambda_max=args.lambda_max)
    elif args.loss_type == 'center':
        xe_criterion = CenterLoss(num_classes=train_dir.num_spks, feat_dim=args.embedding_size)
    elif args.loss_type == 'amsoft':
        model.classifier = AdditiveMarginLinear(feat_dim=args.embedding_size, n_classes=train_dir.num_spks)
        xe_criterion = AMSoftmaxLoss(margin=args.margin, s=args.s)

    if args.cuda:
        model.cuda()

    optimizer = create_optimizer(model.parameters(), args.optimizer, **opt_kwargs)
    if args.loss_type == 'center':
        optimizer = torch.optim.SGD([{'params': xe_criterion.parameters(), 'lr': args.lr * 5},
                                     {'params': model.parameters()}],
                                    lr=args.lr, weight_decay=args.weight_decay,
                                    momentum=args.momentum)

    milestones = args.milestones.split(',')
    milestones = [int(x) for x in milestones]
    milestones.sort()
    # print('Scheduler options: {}'.format(milestones))
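    # MultiStepLR decays the learning rate by 10x at each milestone epoch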
    scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    if args.save_init and not args.finetune:
        check_path = '{}/checkpoint_{}.pth'.format(args.check_path, start)
        torch.save({'epoch': start, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict()}, check_path)

    start += args.start_epoch
    print('Start epoch is : ' + str(start))
    end = args.epochs + 1

    # pdb.set_trace()
    train_loader = torch.utils.data.DataLoader(train_dir,
                                               batch_size=args.batch_size,
                                               shuffle=True, **kwargs)
    valid_loader = torch.utils.data.DataLoader(valid_dir,
                                               batch_size=args.batch_size // 2,
                                               shuffle=False, **kwargs)
    test_loader = torch.utils.data.DataLoader(test_dir,
                                              batch_size=args.test_batch_size,
                                              shuffle=False, **kwargs)

    ce = [ce_criterion, xe_criterion]
    if args.cuda:
        model = model.cuda()
        for i in range(len(ce)):
            if ce[i] is not None:
                ce[i] = ce[i].cuda()

    for epoch in range(start, end):
        # pdb.set_trace()
        print('\n\33[1;34m Current \'{}\' learning rate is '.format(args.optimizer), end='')
        for param_group in optimizer.param_groups:
            print('{:.5f} '.format(param_group['lr']), end='')
        print(' \33[0m')

        train(train_loader, model, optimizer, ce, scheduler, epoch)
        test(test_loader, valid_loader, model, epoch)

        scheduler.step()
        # break
    verify_dir = KaldiExtractDataset(dir=args.test_dir, transform=transform_T, filer_loader=file_loader)
    verify_loader = torch.utils.data.DataLoader(verify_dir, batch_size=args.test_batch_size, shuffle=False,
                                                **kwargs)
    verification_extract(verify_loader, model, args.xvector_dir)
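    # Read the saved Kaldi-format vectors back for trial scoring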
    file_loader = read_vec_flt
    test_dir = ScriptVerifyDataset(dir=args.test_dir, trials_file=args.trials,
                                   xvectors_dir=args.xvector_dir, loader=file_loader)
    test_loader = torch.utils.data.DataLoader(test_dir, batch_size=args.test_batch_size * 64, shuffle=False, **kwargs)
    verification_test(test_loader=test_loader, dist_type='cos' if args.cos_sim else 'l2',
                      log_interval=args.log_interval)

    writer.close()