def train(lr, l2, max_gnorm, momentum, margin, lambda_, swap, latent_size,
          n_frames, model, ncoef, epochs, batch_size, valid_batch_size,
          n_workers, cuda, train_hdf_file, valid_hdf_file, cp_path, softmax,
          delta, logdir):
    """Build the data loaders, model, optimizer and ``TrainLoop`` for one
    configuration, run training, and return the training result.

    Every setting is taken from the arguments of this function; the global
    ``args`` namespace is not consulted.

    Args:
        lr: Learning rate given to ``optim.SGD``.
        l2: Passed to ``optim.SGD`` as ``weight_decay``.
        max_gnorm: Forwarded to ``TrainLoop`` as ``max_gnorm``.
        momentum: Momentum given to ``optim.SGD``.
        margin: Forwarded to ``TrainLoop`` as ``margin``.
        lambda_: Forwarded to ``TrainLoop`` as ``lambda_``.
        swap: Forwarded to ``TrainLoop`` as ``swap``.
        latent_size: Cast to ``int`` and passed to the model as ``n_z``.
        n_frames: Cast to ``int`` and passed to both datasets as
            ``max_nb_frames``.
        model: Architecture name (e.g. ``'resnet_lstm'``, ``'TDNN'``,
            ``'transformer'``) selecting the constructor in ``model_``.
        ncoef: Passed to the model constructor as ``ncoef``.
        epochs: Passed to ``TrainLoop.train`` as ``n_epochs``.
        batch_size: Batch size of the training loader.
        valid_batch_size: Batch size of the validation loader.
        n_workers: ``num_workers`` for both data loaders.
        cuda: When truthy, a device is obtained from ``get_freer_gpu()`` and
            the model is moved to it.
        train_hdf_file: ``hdf5_name`` for the training ``Loader``.
        valid_hdf_file: ``hdf5_name`` for the validation ``Loader_valid``.
        cp_path: Checkpoint path; also used to derive ``cp_name`` through
            ``get_file_name``.
        softmax: Passed to the model constructor as ``sm_type``.
        delta: Passed to both datasets and to the model constructor.
        logdir: When truthy, a TensorBoard ``SummaryWriter`` is created at
            ``logdir + cp_name``; otherwise no logger is used.

    Returns:
        Whatever ``TrainLoop.train(n_epochs=epochs)`` returns.

    Raises:
        ValueError: If ``model`` is not one of the known architecture names.
    """
    # Architecture name -> name of the constructor looked up on ``model_``.
    model_class_names = {
        'resnet_mfcc': 'ResNet_mfcc',
        'resnet_34': 'ResNet_34',
        'resnet_lstm': 'ResNet_lstm',
        'resnet_qrnn': 'ResNet_qrnn',
        'resnet_stats': 'ResNet_stats',
        'resnet_large': 'ResNet_large',
        'resnet_small': 'ResNet_small',
        'resnet_2d': 'ResNet_2d',
        'TDNN': 'TDNN',
        'TDNN_att': 'TDNN_att',
        'TDNN_multihead': 'TDNN_multihead',
        'TDNN_lstm': 'TDNN_lstm',
        'TDNN_aspp': 'TDNN_aspp',
        'TDNN_mod': 'TDNN_mod',
        'TDNN_multipool': 'TDNN_multipool',
        'transformer': 'transformer_enc',
    }

    # ``model`` arrives as a name and is rebound to the network object below,
    # so keep the name under its own variable.
    model_name = model
    if model_name not in model_class_names:
        # Fail before any data is loaded instead of crashing later on
        # ``str.to`` / ``str.parameters``.
        raise ValueError(
            'Unknown model {!r}; expected one of: {}'.format(
                model_name, ', '.join(sorted(model_class_names))))

    device = None
    if cuda:
        device = get_freer_gpu()
        if model_name == 'resnet_qrnn':
            import cupy
            # Activate the cupy device whose index is the text after the
            # last ':' in str(device).
            cupy.cuda.Device(int(str(device).split(':')[-1])).use()

    cp_name = get_file_name(cp_path)

    if logdir:
        from torch.utils.tensorboard import SummaryWriter
        # NOTE(review): purge_step is given the bool True (numerically 1);
        # kept as in the original -- confirm the intended purge step.
        writer = SummaryWriter(log_dir=logdir+cp_name, comment=model_name,
                               purge_step=True)
    else:
        writer = None

    train_dataset = Loader(hdf5_name=train_hdf_file,
                           max_nb_frames=int(n_frames), delta=delta)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=n_workers, worker_init_fn=set_np_randomseed)

    valid_dataset = Loader_valid(hdf5_name=valid_hdf_file,
                                 max_nb_frames=int(n_frames), delta=delta)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=valid_batch_size, shuffle=True,
        num_workers=n_workers, worker_init_fn=set_np_randomseed)

    # All architectures share the same constructor signature.
    model_cls = getattr(model_, model_class_names[model_name])
    model = model_cls(n_z=int(latent_size),
                      proj_size=train_dataset.n_speakers, ncoef=ncoef,
                      sm_type=softmax, delta=delta)

    if cuda:
        model = model.to(device)

    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum,
                          weight_decay=l2)

    trainer = TrainLoop(model, optimizer, train_loader, valid_loader,
                        max_gnorm=max_gnorm, margin=margin, lambda_=lambda_,
                        verbose=-1, device=device, cp_name=cp_name,
                        save_cp=True, checkpoint_path=cp_path, swap=swap,
                        softmax=True, pretrain=False, mining=True, cuda=cuda,
                        logger=writer)

    return trainer.train(n_epochs=epochs)
purge_step=0) writer.add_hparams(hparam_dict=args_dict, metric_dict={'best_eer': 0.0}) else: writer = None train_dataset = Loader(hdf5_name=args.train_hdf_file, max_nb_frames=args.n_frames, delta=args.delta) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, worker_init_fn=set_np_randomseed) valid_dataset = Loader_valid(hdf5_name=args.valid_hdf_file, max_nb_frames=args.n_frames, delta=args.delta) valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.valid_batch_size, shuffle=True, num_workers=args.workers, worker_init_fn=set_np_randomseed) if args.model == 'resnet_mfcc': model = model_.ResNet_mfcc(n_z=args.latent_size, proj_size=train_dataset.n_speakers if args.softmax != 'none' or args.pretrain else 0, ncoef=args.ncoef, sm_type=args.softmax, delta=args.delta) elif args.model == 'resnet_34':
# Seed the CUDA random number generator from the command-line seed.
# NOTE(review): whatever statement guards this call is not visible in this
# chunk -- confirm it is meant to sit at this nesting level.
torch.cuda.manual_seed(args.seed)

# Training data: ncoef is only forwarded to the dataset when args.mfcc is set.
train_dataset = Loader(
    hdf5_name=args.train_hdf_file,
    max_len=args.max_len,
    vad=args.vad,
    ncoef=args.ncoef if args.mfcc else None,
)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=args.workers,
    worker_init_fn=set_np_randomseed,
)

# Validation is optional: without a validation file there is no loader.
if args.valid_hdf_file is None:
    valid_loader = None
else:
    valid_dataset = Loader_valid(
        hdf5_name=args.valid_hdf_file,
        max_len=args.max_len,
        vad=args.vad,
        ncoef=args.ncoef if args.mfcc else None,
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=args.valid_batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=args.workers,
        worker_init_fn=set_np_randomseed,
    )

if args.cuda:
    device = get_freer_gpu()
    import cupy
    # Activate the cupy device whose index is the text after the last ':'
    # in str(device).
    cupy.cuda.Device(int(str(device).rpartition(':')[2])).use()