def main():
    """Load a trained speaker-embedding model from a checkpoint and extract
    x-vectors for the (optional) train set and the test set.

    Relies on module-level globals defined elsewhere in this file:
    ``args``, ``train_config_dir``, ``file_loader``, ``transform_V``,
    ``transform_T``, ``kwargs``, ``create_model``, ``KaldiExtractDataset``
    and ``verification_extract``.
    """
    # Print the experiment configuration so the run is reproducible from logs.
    print('\nCurrent time is\33[91m {}\33[0m.'.format(str(time.asctime())))
    opts = vars(args)
    options = ["\'%s\': \'%s\'" % (str(k), str(opts[k]))
               for k in sorted(opts.keys())]
    print('Parsed options: \n{ %s }' % (', '.join(options)))
    print('Number of Speakers in training set: {}\n'.format(
        train_config_dir.num_spks))

    # Parse comma-separated CLI strings into integer lists for the conv layers.
    kernel_size = [int(x) for x in args.kernel_size.split(',')]
    context = [int(x) for x in args.context.split(',')]
    if args.padding == '':
        # Default to 'same'-style padding derived from each kernel size.
        padding = [int((x - 1) / 2) for x in kernel_size]
    else:
        padding = [int(x) for x in args.padding.split(',')]

    kernel_size = tuple(kernel_size)
    padding = tuple(padding)
    stride = [int(x) for x in args.stride.split(',')]
    channels = [int(x) for x in args.channels.split(',')]

    model_kwargs = {
        'input_dim': args.input_dim,
        'feat_dim': args.feat_dim,
        'kernel_size': kernel_size,
        'context': context,
        'filter_fix': args.filter_fix,
        'mask': args.mask_layer,
        'mask_len': args.mask_len,
        'block_type': args.block_type,
        'filter': args.filter,
        'exp': args.exp,
        'inst_norm': args.inst_norm,
        'input_norm': args.input_norm,
        'stride': stride,
        'fast': args.fast,
        'avg_size': args.avg_size,
        'time_dim': args.time_dim,
        'padding': padding,
        'encoder_type': args.encoder_type,
        'vad': args.vad,
        'transform': args.transform,
        'embedding_size': args.embedding_size,
        'ince': args.inception,
        'resnet_size': args.resnet_size,
        'num_classes': train_config_dir.num_spks,
        'channels': channels,
        'alpha': args.alpha,
        'dropout_p': args.dropout_p,
        'loss_type': args.loss_type,
        'm': args.m,
        'margin': args.margin,
        's': args.s,
        # (sic) key spelling is what create_model expects — do not "fix" it.
        'all_iteraion': args.all_iteraion
    }

    print('Model options: {}'.format(model_kwargs))
    model = create_model(args.model, **model_kwargs)

    # Resume from a checkpoint. The original used `assert cond, print(msg)`,
    # which prints the message unconditionally and passes None as the assert
    # message; pass the formatted string directly instead.
    assert os.path.isfile(args.resume), \
        '=> no checkpoint found at {}'.format(args.resume)

    print('=> loading checkpoint {}'.format(args.resume))
    # NOTE(review): no map_location is given, so tensors load onto the device
    # they were saved from — confirm this matches the deployment environment.
    checkpoint = torch.load(args.resume)
    epoch = checkpoint['epoch']

    checkpoint_state_dict = checkpoint['state_dict']
    # Some checkpoints store the state dict inside a tuple; unwrap it.
    if isinstance(checkpoint_state_dict, tuple):
        checkpoint_state_dict = checkpoint_state_dict[0]
    # Drop BatchNorm bookkeeping entries; they need not match the fresh model.
    filtered = {
        k: v
        for k, v in checkpoint_state_dict.items()
        if 'num_batches_tracked' not in k
    }

    if list(filtered.keys())[0].startswith('module'):
        # Checkpoint was saved from nn.DataParallel: strip the 'module.'
        # prefix (first 7 characters) from every key before loading.
        new_state_dict = OrderedDict()
        for k, v in filtered.items():
            new_state_dict[k[7:]] = v
        model.load_state_dict(new_state_dict)
    else:
        # Merge into the model's own state dict so keys absent from the
        # checkpoint keep their freshly-initialized values.
        model_dict = model.state_dict()
        model_dict.update(filtered)
        model.load_state_dict(model_dict)

    if args.cuda:
        model.cuda()

    # Track which splits were extracted, for the summary message below.
    extracted_set = []

    vec_type = 'xvectors_a' if args.xvector else 'xvectors_b'
    if args.train_dir != '':
        train_dir = KaldiExtractDataset(dir=args.train_dir,
                                        filer_loader=file_loader,
                                        transform=transform_V,
                                        extract_trials=False)
        train_loader = torch.utils.data.DataLoader(train_dir,
                                                   batch_size=args.batch_size,
                                                   shuffle=False,
                                                   **kwargs)
        # Extract train-set vectors.
        train_xvector_dir = args.xvector_dir + '/%s/epoch_%d/train' % (
            vec_type, epoch)
        verification_extract(train_loader,
                             model,
                             train_xvector_dir,
                             epoch=epoch,
                             test_input=args.test_input,
                             verbose=True,
                             xvector=args.xvector)
        extracted_set.append('train')

    # A test set is mandatory.
    assert args.test_dir != ''
    test_dir = KaldiExtractDataset(dir=args.test_dir,
                                   filer_loader=file_loader,
                                   transform=transform_V,
                                   extract_trials=False)
    test_loader = torch.utils.data.DataLoader(test_dir,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              **kwargs)

    # Extract test-set vectors.
    test_xvector_dir = args.xvector_dir + '/%s/epoch_%d/test' % (vec_type,
                                                                 epoch)
    verification_extract(test_loader,
                         model,
                         test_xvector_dir,
                         epoch=epoch,
                         test_input=args.test_input,
                         verbose=True,
                         xvector=args.xvector)
    extracted_set.append('test')

    if len(extracted_set) > 0:
        print('Extract x-vector completed for %s in %s!\n' %
              (','.join(extracted_set), args.xvector_dir + '/%s' % vec_type))
    transform_T.transforms.append(mvnormal())

# pdb.set_trace()
# Choose a feature-matrix loader that matches the on-disk feature format.
if args.feat_format == 'kaldi':
    file_loader = read_mat
    # Kaldi loaders keep many file handles open across workers; use the
    # file_system sharing strategy to avoid fd exhaustion.
    torch.multiprocessing.set_sharing_strategy('file_system')
elif args.feat_format == 'npy':
    file_loader = np.load

# Without a validation pass there is no need to hold out validation utterances.
if not args.valid:
    args.num_valid = 0

train_dir = ScriptTrainDataset(dir=args.train_dir, samples_per_speaker=args.input_per_spks, loader=file_loader,
                               transform=transform, num_valid=args.num_valid)

verfify_dir = KaldiExtractDataset(dir=args.test_dir, transform=transform_T, filer_loader=file_loader)

if args.valid:
    # NOTE(review): an unresolved merge-conflict marker
    # ('<<<<<<< HEAD:TrainAndTest/test_vox1.py') was left inside this branch;
    # it has been removed so the file parses. Confirm against upstream that
    # this HEAD side is the intended resolution.
    valid_dir = ScriptValidDataset(valid_set=train_dir.valid_set, loader=file_loader, spk_to_idx=train_dir.spk_to_idx,
                                   valid_uid2feat=train_dir.valid_uid2feat, valid_utt2spk_dict=train_dir.valid_utt2spk_dict,
                                   transform=transform)


def main():
    # Views the training images and displays the distance on anchor-negative and anchor-positive
    # test_display_triplet_distance = False
    # print the experiment configuration
    print('\nCurrent time is \33[91m{}\33[0m.'.format(str(time.asctime())))
    print('Parsed options: {}'.format(vars(args)))
    # print('Number of Speakers: {}.\n'.format(train_dir.num_spks))
# NOTE(review): the original lines here were garbled by extraction — a page
# artifact ('示例#3' i.e. "Example #3", followed by a stray '0') replaced the
# conditional header. Reconstructed from the identical feat_format dispatch
# earlier in this file; confirm against the upstream source.
if args.feat_format == 'kaldi':
    file_loader = read_mat
elif args.feat_format == 'npy':
    file_loader = np.load
torch.multiprocessing.set_sharing_strategy('file_system')

# Training egs for the two input branches; both share the same feature
# dimension, loader and transform.
train_dir_a = EgsDataset(dir=args.train_dir_a, feat_dim=args.feat_dim,
                         loader=file_loader, transform=transform)
train_dir_b = EgsDataset(dir=args.train_dir_b, feat_dim=args.feat_dim,
                         loader=file_loader, transform=transform)

# Extraction sets: the train subset carries its own trials file, the test
# set uses the dataset defaults.
train_extract_dir = KaldiExtractDataset(dir=args.train_test_dir,
                                        filer_loader=file_loader,
                                        transform=transform_V,
                                        trials_file=args.train_trials)
extract_dir = KaldiExtractDataset(dir=args.test_dir,
                                  filer_loader=file_loader,
                                  transform=transform_V)
# test_dir = ScriptTestDataset(dir=args.test_dir, loader=file_loader, transform=transform_T)

# if len(test_dir) < args.veri_pairs:
#     args.veri_pairs = len(test_dir)
#     print('There are %d verification pairs.' % len(test_dir))
# else:
#     test_dir.partition(args.veri_pairs)

valid_dir_a = EgsDataset(dir=args.valid_dir_a,
                         feat_dim=args.feat_dim,
                         loader=file_loader,