def main():
    """Load a trained speaker-embedding model from a checkpoint and extract
    x-vectors for the (optional) train set and the (required) test set.

    Relies on module-level globals defined elsewhere in this file:
    ``args`` (parsed CLI options), ``train_config_dir``, ``file_loader``,
    ``transform_V`` and ``kwargs`` (extra DataLoader options).

    Raises:
        FileNotFoundError: if ``args.resume`` does not point to a checkpoint.
        ValueError: if ``args.test_dir`` is empty.
    """
    # ---- print the experiment configuration --------------------------------
    print('\nCurrent time is\33[91m {}\33[0m.'.format(str(time.asctime())))
    opts = vars(args)
    options = ["\'%s\': \'%s\'" % (str(k), str(opts[k])) for k in sorted(opts)]
    print('Parsed options: \n{ %s }' % (', '.join(options)))
    print('Number of Speakers in training set: {}\n'.format(
        train_config_dir.num_spks))

    # ---- instantiate model and initialize weights ---------------------------
    # Comma-separated CLI strings -> lists/tuples of ints.
    kernel_size = [int(x) for x in args.kernel_size.split(',')]
    context = [int(x) for x in args.context.split(',')]
    if args.padding == '':
        # Default: "same"-style padding derived from each kernel size.
        padding = [int((x - 1) / 2) for x in kernel_size]
    else:
        padding = [int(x) for x in args.padding.split(',')]
    kernel_size = tuple(kernel_size)
    padding = tuple(padding)
    stride = [int(x) for x in args.stride.split(',')]
    channels = [int(x) for x in args.channels.split(',')]

    model_kwargs = {
        'input_dim': args.input_dim,
        'feat_dim': args.feat_dim,
        'kernel_size': kernel_size,
        'context': context,
        'filter_fix': args.filter_fix,
        'mask': args.mask_layer,
        'mask_len': args.mask_len,
        'block_type': args.block_type,
        'filter': args.filter,
        'exp': args.exp,
        'inst_norm': args.inst_norm,
        'input_norm': args.input_norm,
        'stride': stride,
        'fast': args.fast,
        'avg_size': args.avg_size,
        'time_dim': args.time_dim,
        'padding': padding,
        'encoder_type': args.encoder_type,
        'vad': args.vad,
        'transform': args.transform,
        'embedding_size': args.embedding_size,
        'ince': args.inception,
        'resnet_size': args.resnet_size,
        'num_classes': train_config_dir.num_spks,
        'channels': channels,
        'alpha': args.alpha,
        'dropout_p': args.dropout_p,
        'loss_type': args.loss_type,
        'm': args.m,
        'margin': args.margin,
        's': args.s,
        # (sic) 'all_iteraion' is misspelled, but it is the key the model
        # factory expects — do not "fix" it here without changing create_model.
        'all_iteraion': args.all_iteraion
    }
    print('Model options: {}'.format(model_kwargs))
    model = create_model(args.model, **model_kwargs)

    # ---- resume from a checkpoint (required) --------------------------------
    # resume = args.ckp_dir + '/checkpoint_{}.pth'.format(args.epoch)
    # BUGFIX: the original used `assert cond, print(msg)` — print() returns
    # None, so the AssertionError carried no message, and asserts vanish
    # entirely under `python -O`.  Raise an explicit error instead.
    if not os.path.isfile(args.resume):
        raise FileNotFoundError('=> no checkpoint found at {}'.format(args.resume))
    print('=> loading checkpoint {}'.format(args.resume))
    # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only host;
    # model.cuda() below moves the weights back to the GPU when requested.
    checkpoint = torch.load(args.resume, map_location='cpu')
    epoch = checkpoint['epoch']

    checkpoint_state_dict = checkpoint['state_dict']
    # Some checkpoints store the state dict as the first element of a tuple.
    if isinstance(checkpoint_state_dict, tuple):
        checkpoint_state_dict = checkpoint_state_dict[0]
    # Drop BatchNorm bookkeeping entries that may not match a fresh model.
    filtered = {k: v for k, v in checkpoint_state_dict.items()
                if 'num_batches_tracked' not in k}

    if list(filtered.keys())[0].startswith('module'):
        # Checkpoint was saved from a DataParallel wrapper: strip the leading
        # 'module.' prefix (7 characters) from every parameter key.
        new_state_dict = OrderedDict()
        for k, v in filtered.items():
            new_state_dict[k[7:]] = v
        model.load_state_dict(new_state_dict)
    else:
        # Merge into the model's own state dict so keys absent from the
        # checkpoint keep their freshly-initialized values.
        model_dict = model.state_dict()
        model_dict.update(filtered)
        model.load_state_dict(model_dict)
    # model.dropout.p = args.dropout_p

    if args.cuda:
        model.cuda()

    extracted_set = []
    vec_type = 'xvectors_a' if args.xvector else 'xvectors_b'

    # ---- optionally extract x-vectors for the training set ------------------
    if args.train_dir != '':
        train_dir = KaldiExtractDataset(dir=args.train_dir,
                                        filer_loader=file_loader,
                                        transform=transform_V,
                                        extract_trials=False)
        train_loader = torch.utils.data.DataLoader(train_dir,
                                                   batch_size=args.batch_size,
                                                   shuffle=False, **kwargs)
        train_xvector_dir = args.xvector_dir + '/%s/epoch_%d/train' % (
            vec_type, epoch)
        # Also copies wav.scp, utt2spk, ... alongside the extracted vectors.
        verification_extract(train_loader, model, train_xvector_dir,
                             epoch=epoch, test_input=args.test_input,
                             verbose=True, xvector=args.xvector)
        extracted_set.append('train')

    # ---- extract x-vectors for the test set (required) ----------------------
    # BUGFIX: was a bare `assert` — stripped under `python -O`; raise instead.
    if args.test_dir == '':
        raise ValueError('args.test_dir must not be empty')
    test_dir = KaldiExtractDataset(dir=args.test_dir,
                                   filer_loader=file_loader,
                                   transform=transform_V,
                                   extract_trials=False)
    test_loader = torch.utils.data.DataLoader(test_dir,
                                              batch_size=args.batch_size,
                                              shuffle=False, **kwargs)
    test_xvector_dir = args.xvector_dir + '/%s/epoch_%d/test' % (vec_type, epoch)
    verification_extract(test_loader, model, test_xvector_dir,
                         epoch=epoch, test_input=args.test_input,
                         verbose=True, xvector=args.xvector)
    extracted_set.append('test')

    if extracted_set:
        print('Extract x-vector completed for %s in %s!\n'
              % (','.join(extracted_set), args.xvector_dir + '/%s' % vec_type))
transform_T.transforms.append(mvnormal()) # pdb.set_trace() if args.feat_format == 'kaldi': file_loader = read_mat torch.multiprocessing.set_sharing_strategy('file_system') elif args.feat_format == 'npy': file_loader = np.load if not args.valid: args.num_valid = 0 train_dir = ScriptTrainDataset(dir=args.train_dir, samples_per_speaker=args.input_per_spks, loader=file_loader, transform=transform, num_valid=args.num_valid) verfify_dir = KaldiExtractDataset(dir=args.test_dir, transform=transform_T, filer_loader=file_loader) if args.valid: <<<<<<< HEAD:TrainAndTest/test_vox1.py valid_dir = ScriptValidDataset(valid_set=train_dir.valid_set, loader=file_loader, spk_to_idx=train_dir.spk_to_idx, valid_uid2feat=train_dir.valid_uid2feat, valid_utt2spk_dict=train_dir.valid_utt2spk_dict, transform=transform) def main(): # Views the training images and displays the distance on anchor-negative and anchor-positive # test_display_triplet_distance = False # print the experiment configuration print('\nCurrent time is \33[91m{}\33[0m.'.format(str(time.asctime()))) print('Parsed options: {}'.format(vars(args))) # print('Number of Speakers: {}.\n'.format(train_dir.num_spks))
# NOTE(review): this fragment starts mid-`if` — the matching
# `if args.feat_format == 'kaldi':` head is outside this view.
    file_loader = read_mat
elif args.feat_format == 'npy':
    file_loader = np.load
# 'file_system' sharing strategy avoids "too many open files" errors in
# DataLoader worker processes.
torch.multiprocessing.set_sharing_strategy('file_system')
# Two training egs sets, loaded from separate directories.
# NOTE(review): presumably two domains/conditions (a/b) — verify against the
# training loop that consumes them.
train_dir_a = EgsDataset(dir=args.train_dir_a, feat_dim=args.feat_dim,
                         loader=file_loader, transform=transform)
train_dir_b = EgsDataset(dir=args.train_dir_b, feat_dim=args.feat_dim,
                         loader=file_loader, transform=transform)
# Extraction datasets: one over the train-test split (with its trials file),
# one over the test set.
train_extract_dir = KaldiExtractDataset(dir=args.train_test_dir,
                                        transform=transform_V,
                                        filer_loader=file_loader,
                                        trials_file=args.train_trials)
extract_dir = KaldiExtractDataset(dir=args.test_dir, transform=transform_V,
                                  filer_loader=file_loader)
# test_dir = ScriptTestDataset(dir=args.test_dir, loader=file_loader, transform=transform_T)
# if len(test_dir) < args.veri_pairs:
#     args.veri_pairs = len(test_dir)
#     print('There are %d verification pairs.' % len(test_dir))
# else:
#     test_dir.partition(args.veri_pairs)
# NOTE(review): the statement below is truncated — its argument list
# continues past the end of this chunk.
valid_dir_a = EgsDataset(dir=args.valid_dir_a, feat_dim=args.feat_dim, loader=file_loader,