def main():
    # Print the experiment configuration
    print('\33[91m\nCurrent time is {}\33[0m'.format(str(time.asctime())))
    # print('Parsed options:\n{}\n'.format(vars(args)))
    print('Number of Speakers: {}\n'.format(len(train_dir.classes)))

    context = [[-2, 2], [-2, 0, 2], [-3, 0, 3], [0], [0]]  # the same configuration as the x-vector
    node_num = [64, 128, 256, 512, 1024, 1024, 512, 512]
    full_context = [True, False, False, True, True]

    # train_set = trainset.TrainSet('../all_feature/')

    train_loader = torch.utils.data.DataLoader(train_dir, batch_size=args.batch_size,
                                               collate_fn=PadCollate(dim=2), shuffle=True, **kwargs)
    valid_loader = torch.utils.data.DataLoader(valid_dir, batch_size=args.test_batch_size,
                                               collate_fn=PadCollate(dim=2), shuffle=False, **kwargs)
    test_loader = torch.utils.data.DataLoader(test_dir, batch_size=args.test_batch_size,
                                              shuffle=False, **kwargs)

    model = Time_Delay(context, 64, len(train_dir.classes), node_num, full_context)
    if args.cuda:
        model.cuda()

    optimizer = create_optimizer(model, args.lr)
    # torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    # torch.set_num_threads(16)

    # Optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint {}'.format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # Drop 'num_batches_tracked' entries so the state dict loads across PyTorch versions
            filtered = {k: v for k, v in checkpoint['state_dict'].items() if 'num_batches_tracked' not in k}
            model.load_state_dict(filtered)
            optimizer.load_state_dict(checkpoint['optimizer'])
            # criterion.load_state_dict(checkpoint['criterion'])
        else:
            print('=> no checkpoint found at {}'.format(args.resume))

    start = args.start_epoch
    print('Start epoch is: ' + str(start))
    end = start + args.epochs

    for epoch in range(start, end):
        train(train_loader, model, optimizer, epoch)
        test(test_loader, valid_loader, model, epoch)
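# The loaders above rely on PadCollate, whose implementation is not shown in this file.
# Below is a minimal sketch of what such a collate function might do (an assumption, not
# the repo's PadCollate): pad every utterance in a batch to the longest one along the
# given time axis so variable-length features can be stacked into a single tensor.
import torch
import torch.nn.functional as F


class PadCollateSketch(object):
    """Pad variable-length feature tensors to a common length along `dim` (hypothetical helper)."""

    def __init__(self, dim=2):
        self.dim = dim

    def __call__(self, batch):
        # batch: list of (features, label) pairs whose sizes differ only along self.dim
        max_len = max(feat.shape[self.dim] for feat, _ in batch)
        padded = []
        for feat, _ in batch:
            pad_len = max_len - feat.shape[self.dim]
            # F.pad expects pads for the last dimension first
            pad = [0, 0] * (feat.dim() - 1 - self.dim) + [0, pad_len]
            padded.append(F.pad(feat, pad))
        feats = torch.stack(padded)
        labels = torch.tensor([label for _, label in batch])
        return feats, labels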
def main():
    # Print the experiment configuration
    print('\33[91mCurrent time is {}\33[0m'.format(str(time.asctime())))
    print('Parsed options: {}'.format(vars(args)))
    print('Number of Classes: {}\n'.format(len(train_dir.speakers)))

    # Instantiate the model and initialize weights
    model = XVectorTDNN(len(train_dir.speakers), dropout_p=0.0)
    if args.cuda:
        model.cuda()

    valid_loader = torch.utils.data.DataLoader(valid_dir, batch_size=int(args.batch_size / 2),
                                               collate_fn=PadCollate(dim=1), shuffle=False, **kwargs)
    test_loader = torch.utils.data.DataLoader(test_part, batch_size=args.test_batch_size,
                                              shuffle=False, **kwargs)

    # Evaluate the checkpoint saved for each epoch
    epochs = np.arange(1, 15)
    for epoch in epochs:
        if os.path.isfile(args.resume.format(epoch)):
            print('=> loading checkpoint {}'.format(args.resume.format(epoch)))
            checkpoint = torch.load(args.resume.format(epoch))
            start = checkpoint['epoch']
            filtered = {k: v for k, v in checkpoint['state_dict'].items() if 'num_batches_tracked' not in k}
            model.load_state_dict(filtered)
            # criterion.load_state_dict(checkpoint['criterion'])
        else:
            print('=> no checkpoint found at {}'.format(args.resume.format(epoch)))
            continue

        model.eval()
        test(test_loader, valid_loader, model, start)
def main():
    # Print the experiment configuration
    print('\33[91m\nCurrent time is {}\33[0m'.format(str(time.asctime())))
    # print('Parsed options:\n{}\n'.format(vars(args)))
    print('Number of Speakers: {}\n'.format(len(train_dir.classes)))

    # Instantiate the model and initialize weights
    model = SuperficialResCNN(layers=[1, 1, 1, 1], embedding_size=args.embedding_size,
                              n_classes=len(train_dir.classes), m=args.margin)
    # model = ResCNNSpeaker(embedding_size=args.embedding_size, resnet_size=10, num_classes=len(train_dir.classes))
    if args.cuda:
        model.cuda()

    optimizer = create_optimizer(model, args.lr)
    # criterion = AngularSoftmax(in_feats=args.embedding_size,
    #                            num_classes=len(train_dir.classes))

    # Optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint {}'.format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            filtered = {k: v for k, v in checkpoint['state_dict'].items() if 'num_batches_tracked' not in k}
            model.load_state_dict(filtered)
            # optimizer.load_state_dict(checkpoint['optimizer'])
            # criterion.load_state_dict(checkpoint['criterion'])
        else:
            print('=> no checkpoint found at {}'.format(args.resume))

    start = args.start_epoch
    print('Start epoch is: ' + str(start))
    end = start + args.epochs

    scheduler = StepLR(optimizer, step_size=15, gamma=0.1)

    train_loader = torch.utils.data.DataLoader(train_dir, batch_size=args.batch_size, shuffle=True,
                                               collate_fn=PadCollate(dim=2), **kwargs)
    valid_loader = torch.utils.data.DataLoader(valid_dir, batch_size=args.test_batch_size, shuffle=False,
                                               collate_fn=PadCollate(dim=2), **kwargs)
    test_loader = torch.utils.data.DataLoader(test_dir, batch_size=args.test_batch_size, shuffle=False, **kwargs)

    for epoch in range(start, end):
        train(train_loader, model, optimizer, epoch)
        test(test_loader, valid_loader, model, epoch)
        scheduler.step()

    writer.close()
def main():
    # Print the experiment configuration
    print('\nCurrent time is \33[91m{}\33[0m.'.format(str(time.asctime())))
    print('Parsed options: {}'.format(vars(args)))
    print('Number of Speakers: {}.\n'.format(train_dir.num_spks))

    # Instantiate the model and initialize weights
    kernel_size = args.kernel_size.split(',')
    kernel_size = [int(x) for x in kernel_size]
    padding = [int((x - 1) / 2) for x in kernel_size]
    kernel_size = tuple(kernel_size)
    padding = tuple(padding)

    channels = args.channels.split(',')
    channels = [int(x) for x in channels]

    model_kwargs = {'embedding_size': args.embedding_size,
                    'inst_norm': args.inst_norm,
                    'resnet_size': args.resnet_size,
                    'num_classes': train_dir.num_spks,
                    'channels': channels,
                    'avg_size': args.avg_size,
                    'alpha': args.alpha,
                    'kernel_size': kernel_size,
                    'padding': padding,
                    'dropout_p': args.dropout_p}
    print('Model options: {}'.format(model_kwargs))

    model = create_model(args.model, **model_kwargs)

    start_epoch = 0
    if args.save_init:
        check_path = '{}/checkpoint_{}.pth'.format(args.check_path, start_epoch)
        torch.save(model, check_path)

    if args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint {}'.format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            filtered = {k: v for k, v in checkpoint['state_dict'].items() if 'num_batches_tracked' not in k}
            model_dict = model.state_dict()
            model_dict.update(filtered)
            model.load_state_dict(model_dict)
            # model.dropout.p = args.dropout_p
        else:
            print('=> no checkpoint found at {}'.format(args.resume))

    # Select the classification loss; see the hedged AM-softmax sketch after this script
    ce_criterion = nn.CrossEntropyLoss()
    if args.loss_type == 'soft':
        xe_criterion = None
    elif args.loss_type == 'asoft':
        ce_criterion = None
        model.classifier = AngleLinear(in_features=args.embedding_size, out_features=train_dir.num_spks, m=args.m)
        xe_criterion = AngleSoftmaxLoss(lambda_min=args.lambda_min, lambda_max=args.lambda_max)
    elif args.loss_type == 'center':
        xe_criterion = CenterLoss(num_classes=train_dir.num_spks, feat_dim=args.embedding_size)
    elif args.loss_type == 'amsoft':
        ce_criterion = None
        model.classifier = AdditiveMarginLinear(feat_dim=args.embedding_size, n_classes=train_dir.num_spks)
        xe_criterion = AMSoftmaxLoss(margin=args.margin, s=args.s)

    optimizer = create_optimizer(model.parameters(), args.optimizer, **opt_kwargs)
    if args.loss_type == 'center':
        # The center-loss parameters get a larger learning rate in their own group
        optimizer = torch.optim.SGD([{'params': xe_criterion.parameters(), 'lr': args.lr * 5},
                                     {'params': model.parameters()}],
                                    lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)
    if args.finetune:
        if args.loss_type == 'asoft' or args.loss_type == 'amsoft':
            classifier_params = list(map(id, model.classifier.parameters()))
            rest_params = filter(lambda p: id(p) not in classifier_params, model.parameters())
            optimizer = torch.optim.SGD([{'params': model.classifier.parameters(), 'lr': args.lr * 5},
                                         {'params': rest_params}],
                                        lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)

    if args.scheduler == 'exp':
        scheduler = ExponentialLR(optimizer, gamma=args.gamma)
    else:
        milestones = args.milestones.split(',')
        milestones = [int(x) for x in milestones]
        milestones.sort()
        scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    ce = [ce_criterion, xe_criterion]

    start = args.start_epoch + start_epoch
    print('Start epoch is: ' + str(start))
    end = start + args.epochs

    train_loader = torch.utils.data.DataLoader(train_dir, batch_size=args.batch_size,
                                               collate_fn=PadCollate(dim=2, fix_len=False,
                                                                     min_chunk_size=250, max_chunk_size=450),
                                               shuffle=True, **kwargs)
    valid_loader = torch.utils.data.DataLoader(valid_dir, batch_size=int(args.batch_size / 2),
                                               collate_fn=PadCollate(dim=2, fix_len=False,
                                                                     min_chunk_size=250, max_chunk_size=450),
                                               shuffle=False, **kwargs)
    test_loader = torch.utils.data.DataLoader(test_dir, batch_size=args.test_batch_size, shuffle=False, **kwargs)
    # sitw_test_loader = torch.utils.data.DataLoader(sitw_test_dir, batch_size=args.test_batch_size,
    #                                                shuffle=False, **kwargs)
    # sitw_dev_loader = torch.utils.data.DataLoader(sitw_dev_part, batch_size=args.test_batch_size,
    #                                               shuffle=False, **kwargs)

    if args.cuda:
        model = model.cuda()
        for i in range(len(ce)):
            if ce[i] is not None:
                ce[i] = ce[i].cuda()

    print('Dropout is {}.'.format(model.dropout_p))

    for epoch in range(start, end):
        print('\n\33[1;34m Current \'{}\' learning rate is '.format(args.optimizer), end='')
        for param_group in optimizer.param_groups:
            print('{:.5f} '.format(param_group['lr']), end='')
        print(' \33[0m')

        train(train_loader, model, ce, optimizer, epoch)

        if epoch % 4 == 1 or epoch == (end - 1):
            check_path = '{}/checkpoint_{}.pth'.format(args.check_path, epoch)
            torch.save({'epoch': epoch, 'state_dict': model.state_dict(), 'criterion': ce}, check_path)

        if epoch % 2 == 1 and epoch != (end - 1):
            test(test_loader, valid_loader, model, epoch)
            # sitw_test(sitw_test_loader, model, epoch)
            # sitw_test(sitw_dev_loader, model, epoch)

        scheduler.step()

    # Extract x-vectors for the test set and run the verification trials
    extract_dir = KaldiExtractDataset(dir=args.test_dir, transform=transform_V, filer_loader=file_loader)
    extract_loader = torch.utils.data.DataLoader(extract_dir, batch_size=1, shuffle=False, **kwargs)
    xvector_dir = args.check_path
    xvector_dir = xvector_dir.replace('checkpoint', 'xvector')
    verification_extract(extract_loader, model, xvector_dir)

    verify_dir = ScriptVerifyDataset(dir=args.test_dir, trials_file=args.trials,
                                     xvectors_dir=xvector_dir, loader=read_vec_flt)
    verify_loader = torch.utils.data.DataLoader(verify_dir, batch_size=64, shuffle=False, **kwargs)
    verification_test(test_loader=verify_loader, dist_type=('cos' if args.cos_sim else 'l2'),
                      log_interval=args.log_interval)

    writer.close()
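# When args.loss_type == 'amsoft' above, the classifier is swapped for AdditiveMarginLinear
# and trained with AMSoftmaxLoss. Neither class is defined in this file; the snippet below
# is a minimal sketch of the standard additive-margin softmax these names suggest, assuming
# the classifier outputs cosine similarities from length-normalized features and weights.
import torch
import torch.nn.functional as F


def am_softmax_loss(cosine, target, margin=0.3, s=30.0):
    # cosine: (batch, n_classes) cosine similarities; target: (batch,) class indices
    one_hot = torch.zeros_like(cosine).scatter_(1, target.unsqueeze(1), 1.0)
    logits = s * (cosine - one_hot * margin)  # subtract the margin only on the target class
    return F.cross_entropy(logits, target)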
def main():
    # Print the experiment configuration
    print('\33[91m\nCurrent time is {}\33[0m'.format(str(time.asctime())))
    print('Parsed options: {}'.format(vars(args)))
    print('Number of Speakers: {}\n'.format(len(train_dir.classes)))

    # device = torch.device('cuda:2') if torch.cuda.is_available() else torch.device('cpu')
    context = [[-2, 2], [-2, 0, 2], [-3, 0, 3], [0], [0]]  # the same configuration as the x-vector
    node_num = [512, 512, 512, 512, 1500, 3000, 512, 512]
    full_context = [True, False, False, True, True]

    # train_set = trainset.TrainSet('../all_feature/')

    train_loader = torch.utils.data.DataLoader(train_dir, batch_size=args.batch_size, shuffle=True,
                                               collate_fn=PadCollate(dim=2), **kwargs)
    valid_loader = torch.utils.data.DataLoader(valid_dir, batch_size=args.batch_size, shuffle=False,
                                               collate_fn=PadCollate(dim=2), **kwargs)
    test_loader = torch.utils.data.DataLoader(test_part, batch_size=args.test_batch_size, shuffle=False, **kwargs)

    model = Time_Delay(context, 24, len(train_dir.classes), node_num, full_context)
    if args.cuda:
        # model.to(device)
        model = model.cuda()

    optimizer = create_optimizer(model.parameters(), args.optimizer, **opt_kwargs)
    scheduler = MultiStepLR(optimizer, milestones=[16, 24], gamma=0.1)
    ce_loss = nn.CrossEntropyLoss().cuda()
    # torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    # torch.set_num_threads(16)

    # Optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint {}'.format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            filtered = {k: v for k, v in checkpoint['state_dict'].items() if 'num_batches_tracked' not in k}
            model.load_state_dict(filtered)
            optimizer.load_state_dict(checkpoint['optimizer'])
            # criterion.load_state_dict(checkpoint['criterion'])
        else:
            print('=> no checkpoint found at {}'.format(args.resume))

    start = args.start_epoch
    print('Start epoch is: ' + str(start))
    end = start + args.epochs

    for epoch in range(start, end):
        train(train_loader, model, ce_loss, optimizer, epoch)
        test(test_loader, valid_loader, model, epoch)
        scheduler.step()
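# The context / full_context lists above describe the frame splicing of a standard x-vector
# TDNN. A minimal sketch of how they could map onto dilated 1-D convolutions (an assumption,
# not the repo's Time_Delay class): [-2, 2] with full_context=True covers five consecutive
# frames (kernel 5, dilation 1), while [-2, 0, 2] with full_context=False touches three
# frames spaced two apart (kernel 3, dilation 2).
import torch.nn as nn


def tdnn_layer(in_dim, out_dim, context, full_context):
    if full_context:
        kernel = context[-1] - context[0] + 1
        dilation = 1
    else:
        kernel = len(context)
        dilation = context[1] - context[0]
    return nn.Conv1d(in_dim, out_dim, kernel_size=kernel, dilation=dilation)


# e.g. the first layer of the model above: 24-dimensional input features, context [-2, 2]
# first_layer = tdnn_layer(24, 512, [-2, 2], True)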
def main():
    # Print the experiment configuration
    print('\nCurrent time is \33[91m{}\33[0m.'.format(str(time.asctime())))
    print('Parsed options: {}'.format(vars(args)))
    print('Number of Speakers: {}.\n'.format(train_dir.num_spks))

    model_kwargs = {'embedding_size': args.embedding_size,
                    'num_classes': train_dir.num_spks,
                    'input_dim': args.feat_dim,
                    'dropout_p': args.dropout_p}
    print('Model options: {}'.format(model_kwargs))

    model = create_model(args.model, **model_kwargs)
    # model = ASTDNN(num_classes=train_dir.num_spks, input_dim=args.feat_dim,
    #                embedding_size=args.embedding_size,
    #                dropout_p=args.dropout_p)

    start_epoch = 0
    if args.save_init:
        check_path = '{}/checkpoint_{}.pth'.format(args.check_path, start_epoch)
        torch.save(model, check_path)

    if args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint {}'.format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            filtered = {k: v for k, v in checkpoint['state_dict'].items() if 'num_batches_tracked' not in k}
            model_dict = model.state_dict()
            model_dict.update(filtered)
            model.load_state_dict(model_dict)
            # try:
            #     model.dropout.p = args.dropout_p
            # except AttributeError:
            #     pass
        else:
            print('=> no checkpoint found at {}'.format(args.resume))

    # Select the classification loss; see the hedged center-loss sketch after this script
    ce_criterion = nn.CrossEntropyLoss()
    if args.loss_type == 'soft':
        xe_criterion = None
    elif args.loss_type == 'asoft':
        ce_criterion = None
        model.classifier = AngleLinear(in_features=args.embedding_size, out_features=train_dir.num_spks, m=args.m)
        xe_criterion = AngleSoftmaxLoss(lambda_min=args.lambda_min, lambda_max=args.lambda_max)
    elif args.loss_type == 'center':
        xe_criterion = CenterLoss(num_classes=train_dir.num_spks, feat_dim=args.embedding_size)
    elif args.loss_type == 'amsoft':
        ce_criterion = None
        model.classifier = AdditiveMarginLinear(feat_dim=args.embedding_size, n_classes=train_dir.num_spks)
        xe_criterion = AMSoftmaxLoss(margin=args.margin, s=args.s)

    optimizer = create_optimizer(model.parameters(), args.optimizer, **opt_kwargs)
    if args.loss_type == 'center':
        # The center-loss parameters get a larger learning rate in their own group
        optimizer = torch.optim.SGD([{'params': xe_criterion.parameters(), 'lr': args.lr * 5},
                                     {'params': model.parameters()}],
                                    lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)
    if args.finetune:
        if args.loss_type == 'asoft' or args.loss_type == 'amsoft':
            classifier_params = list(map(id, model.classifier.parameters()))
            rest_params = filter(lambda p: id(p) not in classifier_params, model.parameters())
            optimizer = torch.optim.SGD([{'params': model.classifier.parameters(), 'lr': args.lr * 5},
                                         {'params': rest_params}],
                                        lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)

    if args.scheduler == 'exp':
        scheduler = ExponentialLR(optimizer, gamma=args.gamma)
    else:
        milestones = args.milestones.split(',')
        milestones = [int(x) for x in milestones]
        milestones.sort()
        scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    ce = [ce_criterion, xe_criterion]

    start = args.start_epoch + start_epoch
    print('Start epoch is: ' + str(start))
    end = start + args.epochs

    train_loader = torch.utils.data.DataLoader(train_dir, batch_size=args.batch_size,
                                               collate_fn=PadCollate(dim=2, fix_len=False,
                                                                     min_chunk_size=250, max_chunk_size=450),
                                               shuffle=True, **kwargs)
    valid_loader = torch.utils.data.DataLoader(valid_dir, batch_size=int(args.batch_size / 2),
                                               collate_fn=PadCollate(dim=2, fix_len=False,
                                                                     min_chunk_size=250, max_chunk_size=450),
                                               shuffle=False, **kwargs)
    test_loader = torch.utils.data.DataLoader(test_dir, batch_size=args.test_batch_size, shuffle=False, **kwargs)
    # sitw_test_loader = torch.utils.data.DataLoader(sitw_test_dir, batch_size=args.test_batch_size,
    #                                                shuffle=False, **kwargs)
    # sitw_dev_loader = torch.utils.data.DataLoader(sitw_dev_part, batch_size=args.test_batch_size,
    #                                               shuffle=False, **kwargs)

    if args.cuda:
        model = model.cuda()
        for i in range(len(ce)):
            if ce[i] is not None:
                ce[i] = ce[i].cuda()

    for epoch in range(start, end):
        print('\n\33[1;34m Current \'{}\' learning rate is '.format(args.optimizer), end='')
        for param_group in optimizer.param_groups:
            print('{:.5f} '.format(param_group['lr']), end='')
        print(' \33[0m')

        train(train_loader, model, ce, optimizer, epoch)
        test(test_loader, valid_loader, model, epoch)
        # sitw_test(sitw_test_loader, model, epoch)
        # sitw_test(sitw_dev_loader, model, epoch)
        scheduler.step()

    writer.close()
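# The 'center' branch above pairs plain cross-entropy with CenterLoss, whose parameters are
# put in a separate optimizer group with a larger learning rate. CenterLoss is not defined
# in this file; the snippet below is a minimal sketch of the standard center loss that name
# suggests: one learnable center per class and a penalty on the squared distance between
# each embedding and its class center.
import torch
import torch.nn as nn


class CenterLossSketch(nn.Module):
    """L_c = 1/2 * sum_i ||x_i - c_{y_i}||^2 with one learnable center per class."""

    def __init__(self, num_classes, feat_dim):
        super().__init__()
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))

    def forward(self, embeddings, labels):
        # Select each sample's class center and penalize the squared distance to it
        batch_centers = self.centers[labels]
        return 0.5 * ((embeddings - batch_centers) ** 2).sum(dim=1).mean()

# Because the centers are parameters of the criterion, they must be handed to the optimizer,
# which is why the scripts above build a parameter group from xe_criterion.parameters()
# instead of optimizing only model.parameters().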
def main():
    # Trains with triplets and monitors the anchor-positive and anchor-negative distances
    # Print the experiment configuration
    print('\33[91m\nCurrent time is {}.\33[0m'.format(str(time.asctime())))
    print('Parsed options: {}'.format(vars(args)))
    print('Number of Speakers: {}\n'.format(len(train_dir.classes)))

    # Instantiate the model and initialize weights
    model = DeepSpeakerModel(resnet_size=10, embedding_size=args.embedding_size,
                             num_classes=len(train_dir.classes))
    if args.cuda:
        model.cuda()
    if args.data_parallel:
        model = torch.nn.DataParallel(model, device_ids=[2, 3])

    optimizer = create_optimizer(model, args.lr)

    # Optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint {}'.format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # Filter out unnecessary entries from the checkpoint state dict
            filtered = {k: v for k, v in checkpoint['state_dict'].items() if 'num_batches_tracked' not in k}
            model.load_state_dict(filtered)
            optimizer.load_state_dict(checkpoint['optimizer'])
        else:
            print('=> no checkpoint found at {}'.format(args.resume))

    start = args.start_epoch
    print('Start epoch is: ' + str(start))
    end = start + args.epochs

    train_loader = torch.utils.data.DataLoader(train_dir, batch_size=args.batch_size, shuffle=True,
                                               collate_fn=TripletPadCollate(dim=2), **kwargs)
    valid_loader = torch.utils.data.DataLoader(valid_dir, batch_size=args.test_batch_size, shuffle=False,
                                               collate_fn=PadCollate(dim=2), **kwargs)
    test_loader = torch.utils.data.DataLoader(test_dir, batch_size=args.test_batch_size, shuffle=False, **kwargs)

    for epoch in range(start, end):
        train(train_loader, model, optimizer, epoch)
        test(test_loader, valid_loader, model, epoch)

    writer.close()
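# This script batches triplets via TripletPadCollate. The train() it calls is not shown here;
# the function below is a minimal sketch of the standard margin-based triplet loss such a
# setup usually optimizes (an assumption, not necessarily the repo's training objective).
import torch.nn.functional as F


def triplet_loss(anchor, positive, negative, margin=0.1):
    # Pull the anchor-positive distance below the anchor-negative distance by at least `margin`
    d_ap = F.pairwise_distance(anchor, positive)
    d_an = F.pairwise_distance(anchor, negative)
    return F.relu(d_ap - d_an + margin).mean()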
def main():
    # Print the experiment configuration
    print('\nCurrent time is \33[91m{}\33[0m.'.format(str(time.asctime())))

    opts = vars(args)
    keys = list(opts.keys())
    keys.sort()
    options = []
    for k in keys:
        options.append("\'%s\': \'%s\'" % (str(k), str(opts[k])))
    print('Parsed options: \n{ %s }' % (', '.join(options)))
    print('Number of Speakers: {}.\n'.format(train_dir.num_spks))

    # Instantiate the model and initialize weights
    kernel_size = args.kernel_size.split(',')
    kernel_size = [int(x) for x in kernel_size]

    context = args.context.split(',')
    context = [int(x) for x in context]

    if args.padding == '':
        padding = [int((x - 1) / 2) for x in kernel_size]
    else:
        padding = args.padding.split(',')
        padding = [int(x) for x in padding]

    kernel_size = tuple(kernel_size)
    padding = tuple(padding)

    stride = args.stride.split(',')
    stride = [int(x) for x in stride]

    channels = args.channels.split(',')
    channels = [int(x) for x in channels]

    model_kwargs = {'input_dim': args.input_dim, 'feat_dim': args.feat_dim, 'kernel_size': kernel_size,
                    'context': context, 'filter_fix': args.filter_fix, 'mask': args.mask_layer,
                    'mask_len': args.mask_len, 'block_type': args.block_type, 'filter': args.filter,
                    'exp': args.exp, 'inst_norm': args.inst_norm, 'input_norm': args.input_norm,
                    'stride': stride, 'fast': args.fast, 'avg_size': args.avg_size, 'time_dim': args.time_dim,
                    'padding': padding, 'encoder_type': args.encoder_type, 'vad': args.vad,
                    'transform': args.transform, 'embedding_size': args.embedding_size, 'ince': args.inception,
                    'resnet_size': args.resnet_size, 'num_classes': train_dir.num_spks, 'channels': channels,
                    'alpha': args.alpha, 'dropout_p': args.dropout_p}
    print('Model options: {}'.format(model_kwargs))

    # Scoring metric for verification; see the hedged scoring sketch after this script
    dist_type = 'cos' if args.cos_sim else 'l2'
    print('Testing with %s distance.' % dist_type)

    model = create_model(args.model, **model_kwargs)

    start_epoch = 0
    if args.save_init and not args.finetune:
        check_path = '{}/checkpoint_{}.pth'.format(args.check_path, start_epoch)
        torch.save(model, check_path)

    iteration = 0  # if args.resume else 0
    if args.finetune and args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint {}'.format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']

            checkpoint_state_dict = checkpoint['state_dict']
            if isinstance(checkpoint_state_dict, tuple):
                checkpoint_state_dict = checkpoint_state_dict[0]
            filtered = {k: v for k, v in checkpoint_state_dict.items() if 'num_batches_tracked' not in k}

            if list(filtered.keys())[0].startswith('module'):
                new_state_dict = OrderedDict()
                for k, v in filtered.items():
                    name = k[7:]  # strip the 'module.' prefix added by (Distributed)DataParallel
                    new_state_dict[name] = v
                model.load_state_dict(new_state_dict)
            else:
                model_dict = model.state_dict()
                model_dict.update(filtered)
                model.load_state_dict(model_dict)
            # model.dropout.p = args.dropout_p
        else:
            print('=> no checkpoint found at {}'.format(args.resume))

    ce_criterion = nn.CrossEntropyLoss()
    if args.loss_type == 'soft':
        xe_criterion = None
    elif args.loss_type == 'asoft':
        ce_criterion = None
        model.classifier = AngleLinear(in_features=args.embedding_size, out_features=train_dir.num_spks, m=args.m)
        xe_criterion = AngleSoftmaxLoss(lambda_min=args.lambda_min, lambda_max=args.lambda_max)
    elif args.loss_type == 'center':
        xe_criterion = CenterLoss(num_classes=train_dir.num_spks, feat_dim=args.embedding_size)
    elif args.loss_type == 'gaussian':
        xe_criterion = GaussianLoss(num_classes=train_dir.num_spks, feat_dim=args.embedding_size)
    elif args.loss_type == 'coscenter':
        xe_criterion = CenterCosLoss(num_classes=train_dir.num_spks, feat_dim=args.embedding_size)
    elif args.loss_type == 'mulcenter':
        xe_criterion = MultiCenterLoss(num_classes=train_dir.num_spks, feat_dim=args.embedding_size,
                                       num_center=args.num_center)
    elif args.loss_type == 'amsoft':
        ce_criterion = None
        model.classifier = AdditiveMarginLinear(feat_dim=args.embedding_size, n_classes=train_dir.num_spks)
        xe_criterion = AMSoftmaxLoss(margin=args.margin, s=args.s)
    elif args.loss_type == 'arcsoft':
        ce_criterion = None
        model.classifier = AdditiveMarginLinear(feat_dim=args.embedding_size, n_classes=train_dir.num_spks)
        xe_criterion = ArcSoftmaxLoss(margin=args.margin, s=args.s, iteraion=iteration,
                                      all_iteraion=args.all_iteraion)
    elif args.loss_type == 'wasse':
        xe_criterion = Wasserstein_Loss(source_cls=args.source_cls)
    elif args.loss_type == 'ring':
        xe_criterion = RingLoss(ring=args.ring)
        args.alpha = 0.0

    model_para = model.parameters()
    if args.loss_type in ['center', 'mulcenter', 'gaussian', 'coscenter', 'ring']:
        assert args.lr_ratio > 0
        model_para = [{'params': xe_criterion.parameters(), 'lr': args.lr * args.lr_ratio},
                      {'params': model.parameters()}]

    if args.finetune:
        if args.loss_type == 'asoft' or args.loss_type == 'amsoft':
            classifier_params = list(map(id, model.classifier.parameters()))
            rest_params = filter(lambda p: id(p) not in classifier_params, model.parameters())
            assert args.lr_ratio > 0
            model_para = [{'params': model.classifier.parameters(), 'lr': args.lr * args.lr_ratio},
                          {'params': rest_params}]
        if args.filter in ['fDLR', 'fBLayer', 'fLLayer', 'fBPLayer']:
            filter_params = list(map(id, model.filter_layer.parameters()))
            rest_params = filter(lambda p: id(p) not in filter_params, model.parameters())
            model_para = [{'params': model.filter_layer.parameters(), 'lr': args.lr * args.lr_ratio},
                          {'params': rest_params}]

    optimizer = create_optimizer(model_para, args.optimizer, **opt_kwargs)

    if not args.finetune and args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint {}'.format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']

            checkpoint_state_dict = checkpoint['state_dict']
            if isinstance(checkpoint_state_dict, tuple):
                checkpoint_state_dict = checkpoint_state_dict[0]
            filtered = {k: v for k, v in checkpoint_state_dict.items() if 'num_batches_tracked' not in k}

            if list(filtered.keys())[0].startswith('module'):
                new_state_dict = OrderedDict()
                for k, v in filtered.items():
                    name = k[7:]  # strip the 'module.' prefix added by (Distributed)DataParallel
                    new_state_dict[name] = v
                model.load_state_dict(new_state_dict)
            else:
                model_dict = model.state_dict()
                model_dict.update(filtered)
                model.load_state_dict(model_dict)
            # model.dropout.p = args.dropout_p
        else:
            print('=> no checkpoint found at {}'.format(args.resume))

    # Save the model configuration to a text file
    with open(osp.join(args.check_path,
                       'model.%s.conf' % time.strftime("%Y.%m.%d", time.localtime())), 'w') as f:
        f.write('model: ' + str(model) + '\n')
        f.write('CrossEntropy: ' + str(ce_criterion) + '\n')
        f.write('Other Loss: ' + str(xe_criterion) + '\n')
        f.write('Optimizer: ' + str(optimizer) + '\n')

    milestones = args.milestones.split(',')
    milestones = [int(x) for x in milestones]
    milestones.sort()
    if args.scheduler == 'exp':
        scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=args.gamma)
    elif args.scheduler == 'rop':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, patience=args.patience, min_lr=1e-5)
    else:
        scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    ce = [ce_criterion, xe_criterion]

    start = args.start_epoch + start_epoch
    print('Start epoch is: ' + str(start))
    end = start + args.epochs

    train_loader = torch.utils.data.DataLoader(train_dir, batch_size=args.batch_size,
                                               collate_fn=PadCollate(dim=2,
                                                                     num_batch=int(np.ceil(len(train_dir) / args.batch_size)),
                                                                     min_chunk_size=args.min_chunk_size,
                                                                     max_chunk_size=args.max_chunk_size),
                                               shuffle=args.shuffle, **kwargs)
    valid_loader = torch.utils.data.DataLoader(valid_dir, batch_size=int(args.batch_size / 2),
                                               collate_fn=PadCollate(dim=2, fix_len=True,
                                                                     min_chunk_size=args.chunk_size,
                                                                     max_chunk_size=args.chunk_size + 1),
                                               shuffle=False, **kwargs)
    train_extract_loader = torch.utils.data.DataLoader(train_extract_dir, batch_size=1, shuffle=False,
                                                       **extract_kwargs)

    if args.cuda:
        if len(args.gpu_id) > 1:
            print("Continue with gpu: %s ..." % str(args.gpu_id))
            torch.distributed.init_process_group(backend="nccl",
                                                 # init_method='tcp://localhost:23456',
                                                 init_method='file:///home/ssd2020/yangwenhao/lstm_speaker_verification/data/sharedfile',
                                                 rank=0, world_size=1)
            model = DistributedDataParallel(model.cuda(), find_unused_parameters=True)
        else:
            model = model.cuda()

        for i in range(len(ce)):
            if ce[i] is not None:
                ce[i] = ce[i].cuda()

    try:
        print('Dropout is {}.'.format(model.dropout_p))
    except AttributeError:
        pass

    xvector_dir = args.check_path
    xvector_dir = xvector_dir.replace('checkpoint', 'xvector')

    start_time = time.time()
    try:
        for epoch in range(start, end):
            lr_string = '\n\33[1;34m Current \'{}\' learning rate is '.format(args.optimizer)
            for param_group in optimizer.param_groups:
                lr_string += '{:.6f} '.format(param_group['lr'])
            print('%s \33[0m' % lr_string)

            train(train_loader, model, ce, optimizer, epoch)
            valid_loss = valid_class(valid_loader, model, ce, epoch)

            if (epoch == 1 or epoch != (end - 2)) and (
                    epoch % 4 == 1 or epoch in milestones or epoch == (end - 1)):
                model.eval()
                check_path = '{}/checkpoint_{}.pth'.format(args.check_path, epoch)
                # Unwrap DistributedDataParallel before saving so the keys carry no 'module.' prefix
                model_state_dict = model.module.state_dict() \
                    if isinstance(model, DistributedDataParallel) else model.state_dict()
                torch.save({'epoch': epoch, 'state_dict': model_state_dict, 'criterion': ce}, check_path)

                valid_test(train_extract_loader, model, epoch, xvector_dir)
                test(model, epoch, writer, xvector_dir)

                if epoch != (end - 1):
                    try:
                        shutil.rmtree("%s/train/epoch_%s" % (xvector_dir, epoch))
                        shutil.rmtree("%s/test/epoch_%s" % (xvector_dir, epoch))
                    except Exception as e:
                        print('rm dir xvectors error:', e)

            if args.scheduler == 'rop':
                scheduler.step(valid_loss)
            else:
                scheduler.step()
    except KeyboardInterrupt:
        end = epoch

    writer.close()

    stop_time = time.time()
    t = float(stop_time - start_time)
    print("Running %.4f minutes for each epoch.\n" % (t / 60 / (max(end - start, 1))))
    exit(0)
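# The test() / valid_test() helpers above score trial pairs of extracted x-vectors with either
# cosine similarity ('cos') or L2 distance ('l2'), as selected by args.cos_sim. Their code is
# not in this file; the sketch below only illustrates the two scoring rules on numpy arrays of
# embeddings (an assumption about the metric, not the repo's implementation).
import numpy as np


def pair_scores(emb_a, emb_b, dist_type='cos'):
    # emb_a, emb_b: (n_pairs, embedding_size); a higher score should mean "same speaker"
    if dist_type == 'cos':
        a = emb_a / np.linalg.norm(emb_a, axis=1, keepdims=True)
        b = emb_b / np.linalg.norm(emb_b, axis=1, keepdims=True)
        return np.sum(a * b, axis=1)
    # For L2, negate the distance so that larger still means more similar
    return -np.linalg.norm(emb_a - emb_b, axis=1)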