def test(model, epoch, writer, xvector_dir):
    """Extract test-set x-vectors for this epoch and score the verification trials.

    EER, EER threshold and minDCF (p_target = 0.01 / 0.001) are printed and
    logged to TensorBoard under the ``Test/`` namespace.

    Relies on module-level globals: ``extract_dir``, ``args``, ``kwargs``,
    ``read_vec_flt`` — assumed defined elsewhere in this file (TODO confirm).
    """
    this_xvector_dir = "%s/test/epoch_%s" % (xvector_dir, epoch)

    # batch_size=1: utterances are variable-length, so no batching on extract.
    extract_loader = torch.utils.data.DataLoader(extract_dir, batch_size=1,
                                                 shuffle=False, **kwargs)
    verification_extract(extract_loader, model, this_xvector_dir, epoch)

    # Score the trial pairs against the freshly extracted vectors.
    verify_dir = ScriptVerifyDataset(dir=args.test_dir, trials_file=args.trials,
                                     xvectors_dir=this_xvector_dir, loader=read_vec_flt)
    verify_loader = torch.utils.data.DataLoader(verify_dir, batch_size=128,
                                                shuffle=False, **kwargs)
    eer, eer_threshold, mindcf_01, mindcf_001 = verification_test(
        test_loader=verify_loader,
        dist_type=('cos' if args.cos_sim else 'l2'),
        log_interval=args.log_interval,
        xvector_dir=this_xvector_dir,
        epoch=epoch)

    # Fixed message typo: the metric reported here is the EER (was "ERR").
    print('\33[91mTest EER: {:.4f}%, Threshold: {:.4f}, mindcf-0.01: {:.4f}, mindcf-0.001: {:.4f}.\33[0m\n'
          .format(100. * eer, eer_threshold, mindcf_01, mindcf_001))

    writer.add_scalar('Test/EER', 100. * eer, epoch)
    writer.add_scalar('Test/Threshold', eer_threshold, epoch)
    writer.add_scalar('Test/mindcf-0.01', mindcf_01, epoch)
    writer.add_scalar('Test/mindcf-0.001', mindcf_001, epoch)
def valid_test(train_extract_loader, model, epoch, xvector_dir):
    """Extract train-split x-vectors for this epoch and score the train trials.

    EER, threshold and minDCF values are printed and logged to the
    module-level TensorBoard ``writer`` under the ``Train/`` namespace.

    Relies on module-level globals: ``args``, ``kwargs``, ``writer``,
    ``read_vec_flt`` — assumed defined elsewhere in this file (TODO confirm).
    """
    # switch to evaluate mode
    model.eval()

    this_xvector_dir = "%s/train/epoch_%s" % (xvector_dir, epoch)
    verification_extract(train_extract_loader, model, this_xvector_dir, epoch)

    # Score the train trial pairs against the freshly extracted vectors.
    verify_dir = ScriptVerifyDataset(dir=args.train_test_dir, trials_file=args.train_trials,
                                     xvectors_dir=this_xvector_dir, loader=read_vec_flt)
    verify_loader = torch.utils.data.DataLoader(verify_dir, batch_size=128,
                                                shuffle=False, **kwargs)
    eer, eer_threshold, mindcf_01, mindcf_001 = verification_test(
        test_loader=verify_loader,
        dist_type=('cos' if args.cos_sim else 'l2'),
        log_interval=args.log_interval,
        xvector_dir=this_xvector_dir,
        epoch=epoch)

    # Bug fix: the original opened red (\33[91m) but never reset the colour,
    # leaking red into all subsequent terminal output. Append \33[0m.
    print('Test Epoch {}:\n\33[91mTrain EER: {:.4f}%, Threshold: {:.4f}, '
          'mindcf-0.01: {:.4f}, mindcf-0.001: {:.4f}.\33[0m'.format(
              epoch, 100. * eer, eer_threshold, mindcf_01, mindcf_001))

    writer.add_scalar('Train/EER', 100. * eer, epoch)
    writer.add_scalar('Train/Threshold', eer_threshold, epoch)
    writer.add_scalar('Train/mindcf-0.01', mindcf_01, epoch)
    writer.add_scalar('Train/mindcf-0.001', mindcf_001, epoch)

    # Free cached GPU memory before returning to training.
    torch.cuda.empty_cache()
def valid_test(train_extract_loader, valid_loader, model, epoch, xvector_dir):
    """Validate classification accuracy on two held-out loaders, then score
    the train-split verification trials.

    ``valid_loader`` is a pair ``(valid_loader_a, valid_loader_b)``; each
    batch from both halves is classified in a single forward pass. The two
    accuracies, plus EER / threshold / minDCF from the verification pass, are
    printed and logged to the module-level ``writer`` under ``Train/``.

    Relies on module-level globals: ``args``, ``kwargs``, ``writer``,
    ``read_vec_flt`` — assumed defined elsewhere in this file (TODO confirm).
    """
    # switch to evaluate mode
    model.eval()

    valid_loader_a, valid_loader_b = valid_loader
    valid_pbar = tqdm(enumerate(zip(valid_loader_a, valid_loader_b)))

    correct_a = 0.
    correct_b = 0.
    total_datasize_a = 0.
    total_datasize_b = 0.
    softmax = nn.Softmax(dim=1)

    with torch.no_grad():
        for batch_idx, ((data_a, label_a), (data_b, label_b)) in valid_pbar:
            label_a = label_a.cuda()
            label_b = label_b.cuda()

            # One forward pass for both sets: concatenate along the batch dim,
            # then split the features back at len(data_a).
            data = torch.cat((data_a, data_b), dim=0)
            data = data.cuda()
            _, feats = model(data)
            # Locals renamed from the misspelled 'classfier_*'.
            classifier_a, classifier_b = model.cls_forward(feats[:len(data_a)],
                                                           feats[len(data_a):])

            predicted_labels = softmax(classifier_a)
            predicted_one_labels = torch.max(predicted_labels, dim=1)[1]
            minibatch_correct = float((predicted_one_labels.cuda() == label_a).sum().item())
            minibatch_a = minibatch_correct / len(predicted_one_labels)
            correct_a += minibatch_correct
            total_datasize_a += len(predicted_one_labels)

            predicted_labels = softmax(classifier_b)
            predicted_one_labels = torch.max(predicted_labels, dim=1)[1]
            minibatch_correct = float((predicted_one_labels.cuda() == label_b).sum().item())
            minibatch_b = minibatch_correct / len(predicted_one_labels)
            correct_b += minibatch_correct
            total_datasize_b += len(predicted_one_labels)

            if batch_idx % args.log_interval == 0:
                valid_pbar.set_description(
                    'Valid Epoch: {:2d} for {:4d} Batch Accuracy: A set: {:.4f}%, B set: {:.4f}%'.format(
                        epoch, len(valid_loader_a.dataset), 100. * minibatch_a, 100. * minibatch_b
                    ))
            # break

    valid_accuracy_a = 100. * correct_a / total_datasize_a
    valid_accuracy_b = 100. * correct_b / total_datasize_b
    writer.add_scalar('Train/Valid_Accuracy_A', valid_accuracy_a, epoch)
    writer.add_scalar('Train/Valid_Accuracy_B', valid_accuracy_b, epoch)
    torch.cuda.empty_cache()

    # Verification on the train trials with freshly extracted vectors.
    this_xvector_dir = "%s/train/epoch_%s" % (xvector_dir, epoch)
    verification_extract(train_extract_loader, model, this_xvector_dir, epoch)

    verify_dir = ScriptVerifyDataset(dir=args.train_test_dir, trials_file=args.train_trials,
                                     xvectors_dir=this_xvector_dir, loader=read_vec_flt)
    verify_loader = torch.utils.data.DataLoader(verify_dir, batch_size=128,
                                                shuffle=False, **kwargs)
    eer, eer_threshold, mindcf_01, mindcf_001 = verification_test(
        test_loader=verify_loader,
        dist_type=('cos' if args.cos_sim else 'l2'),
        log_interval=args.log_interval,
        xvector_dir=this_xvector_dir,
        epoch=epoch)

    print('Test Epoch {}:\n\33[91mTrain EER: {:.4f}%, Threshold: {:.4f}, '
          'mindcf-0.01: {:.4f}, mindcf-0.001: {:.4f}.'.format(epoch, 100. * eer, eer_threshold,
                                                              mindcf_01, mindcf_001))
    print('Valid on A Accuracy: %.4f %%. Valid on B Accuracy: %.4f %%.\33[0m' % (
        valid_accuracy_a, valid_accuracy_b))

    writer.add_scalar('Train/EER', 100. * eer, epoch)
    writer.add_scalar('Train/Threshold', eer_threshold, epoch)
    writer.add_scalar('Train/mindcf-0.01', mindcf_01, epoch)
    writer.add_scalar('Train/mindcf-0.001', mindcf_001, epoch)

    torch.cuda.empty_cache()
def main():
    """Driver for the training script: build the model, optionally resume,
    configure losses/optimizer/scheduler, and — when ``args.extract`` is set —
    extract test-set x-vectors and score the verification trials. The epoch
    loop itself is commented out below.

    Relies on module-level globals: ``args``, ``train_dir``, ``valid_dir``,
    ``test_dir``, ``kwargs``, ``opt_kwargs``, ``writer``, ``transform_V``,
    ``file_loader`` — assumed defined elsewhere in this file (TODO confirm).
    """
    # Views the training images and displays the distance on anchor-negative and anchor-positive
    # test_display_triplet_distance = False

    # print the experiment configuration
    print('\nCurrent time is \33[91m{}\33[0m.'.format(str(time.asctime())))
    print('Parsed options: {}'.format(vars(args)))
    print('Number of Speakers: {}.\n'.format(train_dir.num_spks))

    # instantiate model and initialize weights
    # Comma-separated CLI strings -> int sequences; padding derived per-kernel
    # as (k - 1) // 2 ("same"-style padding for odd kernels).
    kernel_size = args.kernel_size.split(',')
    kernel_size = [int(x) for x in kernel_size]
    padding = [int((x - 1) / 2) for x in kernel_size]
    kernel_size = tuple(kernel_size)
    padding = tuple(padding)
    stride = args.stride.split(',')
    stride = [int(x) for x in stride]
    channels = args.channels.split(',')
    channels = [int(x) for x in channels]

    model_kwargs = {'input_dim': args.input_dim, 'feat_dim': args.feat_dim,
                    'kernel_size': kernel_size, 'filter': args.filter,
                    'inst_norm': args.inst_norm, 'stride': stride, 'fast': args.fast,
                    'avg_size': args.avg_size, 'time_dim': args.time_dim,
                    'padding': padding, 'encoder_type': args.encoder_type,
                    'vad': args.vad, 'embedding_size': args.embedding_size,
                    'ince': args.inception, 'resnet_size': args.resnet_size,
                    'num_classes': train_dir.num_spks, 'channels': channels,
                    'alpha': args.alpha, 'dropout_p': args.dropout_p}
    print('Model options: {}'.format(model_kwargs))
    model = create_model(args.model, **model_kwargs)

    start_epoch = 0
    if args.save_init and not args.finetune:
        # Snapshot the randomly initialised model before any training.
        check_path = '{}/checkpoint_{}.pth'.format(args.check_path, start_epoch)
        torch.save(model, check_path)

    if args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint {}'.format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            # Drop BatchNorm 'num_batches_tracked' entries, then merge into the
            # freshly built model's state dict (missing keys keep init values).
            filtered = {
                k: v
                for k, v in checkpoint['state_dict'].items()
                if 'num_batches_tracked' not in k
            }
            model_dict = model.state_dict()
            model_dict.update(filtered)
            model.load_state_dict(model_dict)
            # # model.dropout.p = args.dropout_p
        else:
            print('=> no checkpoint found at {}'.format(args.resume))

    # Loss selection: margin-based losses ('asoft'/'amsoft') replace the
    # classifier head and disable the plain cross-entropy term.
    ce_criterion = nn.CrossEntropyLoss()
    if args.loss_type == 'soft':
        xe_criterion = None
    elif args.loss_type == 'asoft':
        ce_criterion = None
        model.classifier = AngleLinear(in_features=args.embedding_size,
                                       out_features=train_dir.num_spks, m=args.m)
        xe_criterion = AngleSoftmaxLoss(lambda_min=args.lambda_min, lambda_max=args.lambda_max)
    elif args.loss_type == 'center':
        xe_criterion = CenterLoss(num_classes=train_dir.num_spks, feat_dim=args.embedding_size)
    elif args.loss_type == 'amsoft':
        ce_criterion = None
        model.classifier = AdditiveMarginLinear(feat_dim=args.embedding_size,
                                                n_classes=train_dir.num_spks)
        xe_criterion = AMSoftmaxLoss(margin=args.margin, s=args.s)

    optimizer = create_optimizer(model.parameters(), args.optimizer, **opt_kwargs)
    if args.loss_type == 'center':
        # Center-loss parameters train with a 5x learning rate vs the backbone.
        optimizer = torch.optim.SGD([{'params': xe_criterion.parameters(), 'lr': args.lr * 5},
                                     {'params': model.parameters()}],
                                    lr=args.lr, weight_decay=args.weight_decay,
                                    momentum=args.momentum)
    if args.finetune:
        if args.loss_type == 'asoft' or args.loss_type == 'amsoft':
            # Fine-tuning: the (new) classifier head gets a 10x learning rate;
            # the pretrained body keeps the base rate.
            classifier_params = list(map(id, model.classifier.parameters()))
            rest_params = filter(lambda p: id(p) not in classifier_params, model.parameters())
            optimizer = torch.optim.SGD(
                [{'params': model.classifier.parameters(), 'lr': args.lr * 10},
                 {'params': rest_params}],
                lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)

    if args.filter:
        # Learnable filterbank front-end trains with a much smaller (0.05x) rate.
        # NOTE(review): this overrides any optimizer built above — confirm intended.
        filter_params = list(map(id, model.filter_layer.parameters()))
        rest_params = filter(lambda p: id(p) not in filter_params, model.parameters())
        optimizer = torch.optim.SGD([{'params': model.filter_layer.parameters(), 'lr': args.lr * 0.05},
                                     {'params': rest_params}],
                                    lr=args.lr, weight_decay=args.weight_decay,
                                    momentum=args.momentum)

    if args.scheduler == 'exp':
        scheduler = ExponentialLR(optimizer, gamma=args.gamma)
    else:
        milestones = args.milestones.split(',')
        milestones = [int(x) for x in milestones]
        milestones.sort()
        scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    # ce[0]: plain cross-entropy (or None); ce[1]: extra/margin loss (or None).
    ce = [ce_criterion, xe_criterion]

    start = args.start_epoch + start_epoch
    print('Start epoch is : ' + str(start))
    # start = 0
    end = start + args.epochs

    train_loader = torch.utils.data.DataLoader(train_dir, batch_size=args.batch_size,
                                               shuffle=False, **kwargs)
    valid_loader = torch.utils.data.DataLoader(valid_dir, batch_size=int(args.batch_size / 2),
                                               shuffle=False, **kwargs)
    test_loader = torch.utils.data.DataLoader(test_dir, batch_size=args.test_batch_size,
                                              shuffle=False, **kwargs)
    # sitw_test_loader = torch.utils.data.DataLoader(sitw_test_dir, batch_size=args.test_batch_size,
    #                                                shuffle=False, **kwargs)
    # sitw_dev_loader = torch.utils.data.DataLoader(sitw_dev_part, batch_size=args.test_batch_size, shuffle=False,
    #                                               **kwargs)

    if args.cuda:
        model = model.cuda()
        for i in range(len(ce)):
            if ce[i] != None:
                ce[i] = ce[i].cuda()

    # Not every model exposes dropout_p; best-effort report only.
    try:
        print('Dropout is {}.'.format(model.dropout_p))
    except:
        pass

    # for epoch in range(start, end):
    #     # pdb.set_trace()
    #     print('\n\33[1;34m Current \'{}\' learning rate is '.format(args.optimizer), end='')
    #     for param_group in optimizer.param_groups:
    #         print('{:.5f} '.format(param_group['lr']), end='')
    #     print(' \33[0m')
    #
    #     train(train_loader, model, ce, optimizer, epoch)
    #     if epoch % 4 == 1 or epoch == (end - 1):
    #         check_path = '{}/checkpoint_{}.pth'.format(args.check_path, epoch)
    #         torch.save({'epoch': epoch,
    #                     'state_dict': model.state_dict(),
    #                     'criterion': ce},
    #                    check_path)
    #
    #     if epoch % 2 == 1 and epoch != (end - 1):
    #         test(test_loader, valid_loader, model, epoch)
    #     # sitw_test(sitw_test_loader, model, epoch)
    #     # sitw_test(sitw_dev_loader, model, epoch)
    #     scheduler.step()
    # exit(1)

    # X-vectors land next to the checkpoints, in a parallel 'xvector' tree.
    xvector_dir = args.check_path
    xvector_dir = xvector_dir.replace('checkpoint', 'xvector')
    if args.extract:
        extract_dir = KaldiExtractDataset(dir=args.test_dir, transform=transform_V,
                                          filer_loader=file_loader)
        extract_loader = torch.utils.data.DataLoader(extract_dir, batch_size=1,
                                                     shuffle=False, **kwargs)
        verification_extract(extract_loader, model, xvector_dir)

        verify_dir = ScriptVerifyDataset(dir=args.test_dir, trials_file=args.trials,
                                         xvectors_dir=xvector_dir, loader=read_vec_flt)
        verify_loader = torch.utils.data.DataLoader(verify_dir, batch_size=64,
                                                    shuffle=False, **kwargs)
        verification_test(test_loader=verify_loader,
                          dist_type=('cos' if args.cos_sim else 'l2'),
                          log_interval=args.log_interval, save=args.save_score)

    writer.close()
def main():
    """Extraction driver: load a trained checkpoint and dump x-vectors for the
    (optional) train set and the (mandatory) test set.

    Vectors are written under ``args.xvector_dir/<vec_type>/epoch_<n>/{train,test}``
    where ``vec_type`` depends on ``args.xvector``.

    Relies on module-level globals: ``args``, ``kwargs``, ``train_config_dir``,
    ``file_loader``, ``transform_V`` — assumed defined elsewhere in this file
    (TODO confirm).
    """
    # Views the training images and displays the distance on anchor-negative and anchor-positive
    # print the experiment configuration
    print('\nCurrent time is\33[91m {}\33[0m.'.format(str(time.asctime())))

    # Pretty-print the parsed options in sorted-key order.
    opts = vars(args)
    keys = list(opts.keys())
    keys.sort()
    options = []
    for k in keys:
        options.append("\'%s\': \'%s\'" % (str(k), str(opts[k])))
    print('Parsed options: \n{ %s }' % (', '.join(options)))
    print('Number of Speakers in training set: {}\n'.format(train_config_dir.num_spks))

    # instantiate model and initialize weights
    kernel_size = args.kernel_size.split(',')
    kernel_size = [int(x) for x in kernel_size]
    context = args.context.split(',')
    context = [int(x) for x in context]
    if args.padding == '':
        # Default "same"-style padding derived from the kernel sizes.
        padding = [int((x - 1) / 2) for x in kernel_size]
    else:
        padding = args.padding.split(',')
        padding = [int(x) for x in padding]
    kernel_size = tuple(kernel_size)
    padding = tuple(padding)
    stride = args.stride.split(',')
    stride = [int(x) for x in stride]
    channels = args.channels.split(',')
    channels = [int(x) for x in channels]

    model_kwargs = {'input_dim': args.input_dim, 'feat_dim': args.feat_dim,
                    'kernel_size': kernel_size, 'context': context,
                    'filter_fix': args.filter_fix, 'mask': args.mask_layer,
                    'mask_len': args.mask_len, 'block_type': args.block_type,
                    'filter': args.filter, 'exp': args.exp,
                    'inst_norm': args.inst_norm, 'input_norm': args.input_norm,
                    'stride': stride, 'fast': args.fast, 'avg_size': args.avg_size,
                    'time_dim': args.time_dim, 'padding': padding,
                    'encoder_type': args.encoder_type, 'vad': args.vad,
                    'transform': args.transform, 'embedding_size': args.embedding_size,
                    'ince': args.inception, 'resnet_size': args.resnet_size,
                    'num_classes': train_config_dir.num_spks, 'channels': channels,
                    'alpha': args.alpha, 'dropout_p': args.dropout_p,
                    'loss_type': args.loss_type, 'm': args.m, 'margin': args.margin,
                    's': args.s, 'all_iteraion': args.all_iteraion}
    print('Model options: {}'.format(model_kwargs))
    model = create_model(args.model, **model_kwargs)

    # optionally resume from a checkpoint
    # resume = args.ckp_dir + '/checkpoint_{}.pth'.format(args.epoch)
    # Bug fix: the original used `assert cond, print(msg)`, which makes the
    # AssertionError message None; pass the message string itself instead.
    assert os.path.isfile(args.resume), \
        '=> no checkpoint found at {}'.format(args.resume)
    print('=> loading checkpoint {}'.format(args.resume))
    checkpoint = torch.load(args.resume)
    epoch = checkpoint['epoch']

    checkpoint_state_dict = checkpoint['state_dict']
    if isinstance(checkpoint_state_dict, tuple):
        checkpoint_state_dict = checkpoint_state_dict[0]
    # Drop BatchNorm 'num_batches_tracked' entries before loading.
    filtered = {
        k: v
        for k, v in checkpoint_state_dict.items()
        if 'num_batches_tracked' not in k
    }
    # filtered = {k: v for k, v in checkpoint['state_dict'].items() if 'num_batches_tracked' not in k}
    if list(filtered.keys())[0].startswith('module'):
        # Checkpoint was saved from a DataParallel-wrapped model: strip the
        # leading 'module.' (7 characters) from every key.
        new_state_dict = OrderedDict()
        for k, v in filtered.items():
            name = k[7:]  # drop the 'module.' prefix
            new_state_dict[name] = v  # keep the key/value pairing one-to-one
        model.load_state_dict(new_state_dict)
    else:
        model_dict = model.state_dict()
        model_dict.update(filtered)
        model.load_state_dict(model_dict)
    # model.dropout.p = args.dropout_p

    if args.cuda:
        model.cuda()

    extracted_set = []
    vec_type = 'xvectors_a' if args.xvector else 'xvectors_b'

    if args.train_dir != '':
        train_dir = KaldiExtractDataset(dir=args.train_dir, filer_loader=file_loader,
                                        transform=transform_V, extract_trials=False)
        train_loader = torch.utils.data.DataLoader(train_dir, batch_size=args.batch_size,
                                                   shuffle=False, **kwargs)
        # Extract Train set vectors
        # extract(train_loader, model, dataset='train', extract_path=args.extract_path + '/x_vector')
        train_xvector_dir = args.xvector_dir + '/%s/epoch_%d/train' % (vec_type, epoch)
        verification_extract(train_loader, model, train_xvector_dir, epoch=epoch,
                             test_input=args.test_input, verbose=True,
                             xvector=args.xvector)
        # copy wav.scp and utt2spk ...
        extracted_set.append('train')

    assert args.test_dir != ''
    test_dir = KaldiExtractDataset(dir=args.test_dir, filer_loader=file_loader,
                                   transform=transform_V, extract_trials=False)
    test_loader = torch.utils.data.DataLoader(test_dir, batch_size=args.batch_size,
                                              shuffle=False, **kwargs)
    # Extract test set vectors
    test_xvector_dir = args.xvector_dir + '/%s/epoch_%d/test' % (vec_type, epoch)
    # extract(test_loader, model, set_id='test', extract_path=args.extract_path + '/x_vector')
    verification_extract(test_loader, model, test_xvector_dir, epoch=epoch,
                         test_input=args.test_input, verbose=True,
                         xvector=args.xvector)
    # copy wav.scp and utt2spk ...
    extracted_set.append('test')

    if len(extracted_set) > 0:
        print('Extract x-vector completed for %s in %s!\n' %
              (','.join(extracted_set), args.xvector_dir + '/%s' % vec_type))
def main():
    """Training driver: build the model, optionally resume, train for
    ``args.epochs`` epochs, then extract test-set x-vectors and score the
    verification trials.

    Bug fixes vs. the original:
      * the checkpoint file was ``torch.load``-ed twice;
      * the late locals formerly named ``test_dir`` / ``file_loader`` shadowed
        the module-level globals of the same names, so the earlier reads of
        those globals inside this function raised ``UnboundLocalError``
        (Python treats any name assigned in a function as local throughout).

    Relies on module-level globals: ``args``, ``train_dir``, ``valid_dir``,
    ``test_dir``, ``kwargs``, ``opt_kwargs``, ``writer``, ``transform_T``,
    ``file_loader``, ``read_vec_flt`` — assumed defined elsewhere (TODO confirm).
    """
    # Views the training images and displays the distance on anchor-negative and anchor-positive
    # print the experiment configuration
    print('\nCurrent time is \33[91m{}\33[0m'.format(str(time.asctime())))
    print('Parsed options: {}'.format(vars(args)))
    print('Number of Classes: {}\n'.format(train_dir.num_spks))

    # instantiate model and initialize weights
    kernel_size = args.kernel_size.split(',')
    kernel_size = [int(x) for x in kernel_size]
    padding = [int((x - 1) / 2) for x in kernel_size]
    kernel_size = tuple(kernel_size)
    padding = tuple(padding)

    model_kwargs = {'input_dim': args.feat_dim, 'kernel_size': kernel_size,
                    'stride': args.stride, 'avg_size': args.avg_size,
                    'time_dim': args.time_dim, 'padding': padding,
                    'resnet_size': args.resnet_size,
                    'embedding_size': args.embedding_size,
                    'num_classes': len(train_dir.speakers),
                    'dropout_p': args.dropout_p}
    print('Model options: {}'.format(model_kwargs))
    model = create_model(args.model, **model_kwargs)

    start = 1
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint {}'.format(args.resume))
            checkpoint = torch.load(args.resume)  # load once (was loaded twice)
            start = checkpoint['epoch']
            # Drop BatchNorm 'num_batches_tracked' entries before loading.
            filtered = {k: v for k, v in checkpoint['state_dict'].items()
                        if 'num_batches_tracked' not in k}
            model.load_state_dict(filtered)
            # optimizer.load_state_dict(checkpoint['optimizer'])
            # scheduler.load_state_dict(checkpoint['scheduler'])
            # criterion.load_state_dict(checkpoint['criterion'])
        else:
            print('=> no checkpoint found at {}'.format(args.resume))

    # Loss selection: margin-based losses replace the classifier head.
    ce_criterion = nn.CrossEntropyLoss()
    if args.loss_type == 'soft':
        xe_criterion = None
    elif args.loss_type == 'asoft':
        ce_criterion = None
        model.classifier = AngleLinear(in_features=args.embedding_size,
                                       out_features=train_dir.num_spks, m=args.m)
        xe_criterion = AngleSoftmaxLoss(lambda_min=args.lambda_min, lambda_max=args.lambda_max)
    elif args.loss_type == 'center':
        xe_criterion = CenterLoss(num_classes=train_dir.num_spks, feat_dim=args.embedding_size)
    elif args.loss_type == 'amsoft':
        model.classifier = AdditiveMarginLinear(feat_dim=args.embedding_size,
                                                n_classes=train_dir.num_spks)
        xe_criterion = AMSoftmaxLoss(margin=args.margin, s=args.s)

    if args.cuda:
        model.cuda()

    optimizer = create_optimizer(model.parameters(), args.optimizer, **opt_kwargs)
    if args.loss_type == 'center':
        # Center-loss parameters train with a 5x learning rate vs the backbone.
        optimizer = torch.optim.SGD([{'params': xe_criterion.parameters(), 'lr': args.lr * 5},
                                     {'params': model.parameters()}],
                                    lr=args.lr, weight_decay=args.weight_decay,
                                    momentum=args.momentum)

    milestones = args.milestones.split(',')
    milestones = [int(x) for x in milestones]
    milestones.sort()
    # print('Scheduler options: {}'.format(milestones))
    scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    if args.save_init and not args.finetune:
        # Snapshot the initial state (model + optimizer + scheduler).
        check_path = '{}/checkpoint_{}.pth'.format(args.check_path, start)
        torch.save({'epoch': start,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict()}, check_path)

    start += args.start_epoch
    print('Start epoch is : ' + str(start))
    end = args.epochs + 1

    # pdb.set_trace()
    train_loader = torch.utils.data.DataLoader(train_dir, batch_size=args.batch_size,
                                               shuffle=True, **kwargs)
    valid_loader = torch.utils.data.DataLoader(valid_dir, batch_size=int(args.batch_size / 2),
                                               shuffle=False, **kwargs)
    test_loader = torch.utils.data.DataLoader(test_dir, batch_size=args.test_batch_size,
                                              shuffle=False, **kwargs)

    # ce[0]: plain cross-entropy (or None); ce[1]: extra/margin loss (or None).
    ce = [ce_criterion, xe_criterion]
    if args.cuda:
        model = model.cuda()
        for i in range(len(ce)):
            if ce[i] != None:
                ce[i] = ce[i].cuda()

    for epoch in range(start, end):
        # pdb.set_trace()
        print('\n\33[1;34m Current \'{}\' learning rate is '.format(args.optimizer), end='')
        for param_group in optimizer.param_groups:
            print('{:.5f} '.format(param_group['lr']), end='')
        print(' \33[0m')

        train(train_loader, model, optimizer, ce, scheduler, epoch)
        test(test_loader, valid_loader, model, epoch)
        scheduler.step()
        # break

    # Final verification pass: extract x-vectors for the test set, then score
    # the trials. Locals renamed ('verfify_dir' typo; and so the module-level
    # 'test_dir'/'file_loader' globals used above are not shadowed).
    verify_extract_dir = KaldiExtractDataset(dir=args.test_dir, transform=transform_T,
                                             filer_loader=file_loader)
    verify_loader = torch.utils.data.DataLoader(verify_extract_dir,
                                                batch_size=args.test_batch_size,
                                                shuffle=False, **kwargs)
    verification_extract(verify_loader, model, args.xvector_dir)

    vec_loader = read_vec_flt
    score_dir = ScriptVerifyDataset(dir=args.test_dir, trials_file=args.trials,
                                    xvectors_dir=args.xvector_dir, loader=vec_loader)
    score_loader = torch.utils.data.DataLoader(score_dir,
                                               batch_size=args.test_batch_size * 64,
                                               shuffle=False, **kwargs)
    verification_test(test_loader=score_loader,
                      dist_type='cos' if args.cos_sim else 'l2',
                      log_interval=args.log_interval)
    writer.close()