def train(opt):
    """Seed all RNGs, build (or defer) model/optimizer/criterion, and run
    supervised training.

    Args:
        opt: parsed options namespace. Fields read here include ``seed``,
            ``use_cuda``, ``resume``, data-list/dataset paths, learning-rate
            schedule settings, and the ``SupervisedTrainer`` knobs passed
            through below.

    Returns:
        The trained model as returned by ``SupervisedTrainer.train``.
    """
    # Seed Python and (all) CUDA/CPU torch RNGs for reproducibility.
    random.seed(opt.seed)
    torch.manual_seed(opt.seed)
    torch.cuda.manual_seed_all(opt.seed)
    device = check_envirionment(opt.use_cuda)  # project helper (typo'd name is upstream's)

    # The criterion is needed whether or not we resume; build it once here
    # instead of duplicating the identical construction in both branches
    # (the original repeated this block verbatim).
    criterion = LabelSmoothedCrossEntropyLoss(
        num_classes=len(char2id),
        ignore_index=PAD_token,
        smoothing=opt.label_smoothing,
        dim=-1,
        reduction=opt.reduction,
        architecture=opt.architecture,
    ).to(device)

    if not opt.resume:
        # Fresh run: load the manifest, split it, and build model + optimizer.
        audio_paths, script_paths = load_data_list(opt.data_list_path, opt.dataset_path)
        epoch_time_step, trainset_list, validset = split_dataset(opt, audio_paths, script_paths)

        model = build_model(opt, device)
        optimizer = optim.Adam(
            model.module.parameters(),  # model is wrapped (e.g. DataParallel) — params live on .module
            lr=opt.init_lr,
            weight_decay=opt.weight_decay,
        )

        if opt.rampup_period > 0:
            # Linear warm-up from init_lr to high_plateau_lr over rampup_period steps.
            scheduler = RampUpLR(optimizer, opt.init_lr, opt.high_plateau_lr, opt.rampup_period)
            optimizer = Optimizer(optimizer, scheduler, opt.rampup_period, opt.max_grad_norm)
        else:
            optimizer = Optimizer(optimizer, None, 0, opt.max_grad_norm)
    else:
        # Resuming: the trainer restores model/optimizer/datasets from the
        # latest checkpoint, so these are deliberately passed as None.
        trainset_list = None
        validset = None
        model = None
        optimizer = None
        epoch_time_step = None

    trainer = SupervisedTrainer(
        optimizer=optimizer,
        criterion=criterion,
        trainset_list=trainset_list,
        validset=validset,
        num_workers=opt.num_workers,
        high_plateau_lr=opt.high_plateau_lr,
        low_plateau_lr=opt.low_plateau_lr,
        decay_threshold=opt.decay_threshold,
        exp_decay_period=opt.exp_decay_period,
        device=device,
        teacher_forcing_step=opt.teacher_forcing_step,
        min_teacher_forcing_ratio=opt.min_teacher_forcing_ratio,
        print_every=opt.print_every,
        save_result_every=opt.save_result_every,
        checkpoint_every=opt.checkpoint_every,
        architecture=opt.architecture,
    )
    model = trainer.train(
        model=model,
        batch_size=opt.batch_size,
        epoch_time_step=epoch_time_step,
        num_epochs=opt.num_epochs,
        teacher_forcing_ratio=opt.teacher_forcing_ratio,
        resume=opt.resume,
    )
    return model
def inference(opt):
    """Load the test-time model, build the evaluation dataset, and run decoding.

    Args:
        opt: parsed options namespace; supplies the device flag, data paths,
            batch/worker settings, and decode options (``decode``, ``k``).
    """
    # Independent of the model: gather the evaluation file lists and targets.
    audio_paths, script_paths = load_data_list(opt.data_list_path, opt.dataset_path)
    target_dict = load_targets(script_paths)

    # Restore the model onto the chosen device.
    device = check_envirionment(opt.use_cuda)
    model = load_test_model(opt, device)

    # Evaluation data must not be augmented — disable both augmentations.
    testset = SpectrogramDataset(
        audio_paths=audio_paths,
        script_paths=script_paths,
        sos_id=SOS_token,
        eos_id=EOS_token,
        target_dict=target_dict,
        opt=opt,
        spec_augment=False,
        noise_augment=False,
    )

    Evaluator(
        testset,
        opt.batch_size,
        device,
        opt.num_workers,
        opt.print_every,
        opt.decode,
        opt.k,
    ).evaluate(model)
def train(opt):
    """Fine-tune an ensemble built from pre-trained checkpoints.

    Args:
        opt: parsed options namespace; supplies seeding, data paths,
            ``ensemble_method``, optimizer/trainer settings.

    Returns:
        The trained ensemble model. (Added for parity with the single-model
        ``train``; previously this function implicitly returned None even
        though the trained model was assigned.)
    """
    # Seed Python and torch RNGs for reproducibility.
    random.seed(opt.seed)
    torch.manual_seed(opt.seed)
    torch.cuda.manual_seed_all(opt.seed)
    device = check_envirionment(opt.use_cuda)  # project helper (typo'd name is upstream's)

    audio_paths, script_paths = load_data_list(opt.data_list_path, opt.dataset_path)
    epoch_time_step, trainset_list, validset = split_dataset(opt, audio_paths, script_paths)

    # TODO(review): these checkpoint paths are placeholders — wire them to an
    # option (e.g. an opt.ensemble_paths list) before real use.
    model = build_ensemble(['model_path1', 'model_path2', 'model_path3'], opt.ensemble_method, device)

    optimizer = optim.Adam(model.module.parameters(), lr=opt.init_lr)
    optimizer = Optimizer(optimizer, None, 0, opt.max_grad_norm)  # no LR schedule for ensemble fine-tuning
    criterion = nn.NLLLoss(reduction='sum', ignore_index=PAD_token).to(device)

    trainer = SupervisedTrainer(
        optimizer=optimizer,
        criterion=criterion,
        trainset_list=trainset_list,
        validset=validset,
        num_workers=opt.num_workers,
        high_plateau_lr=opt.high_plateau_lr,
        low_plateau_lr=opt.low_plateau_lr,
        decay_threshold=opt.decay_threshold,
        exp_decay_period=opt.exp_decay_period,
        device=device,
        teacher_forcing_step=opt.teacher_forcing_step,
        min_teacher_forcing_ratio=opt.min_teacher_forcing_ratio,
        print_every=opt.print_every,
        save_result_every=opt.save_result_every,
        checkpoint_every=opt.checkpoint_every,
    )
    model = trainer.train(
        model=model,
        batch_size=opt.batch_size,
        epoch_time_step=epoch_time_step,
        num_epochs=opt.num_epochs,
        teacher_forcing_ratio=opt.teacher_forcing_ratio,
        resume=opt.resume,
    )

    # NOTE(review): reading optimizer/criterion/trainset_list/validset off the
    # *model* looks suspect — those objects are held by `trainer` here; confirm
    # that `trainer.train` attaches them to the model before relying on this.
    Checkpoint(
        model,
        model.optimizer,
        model.criterion,
        model.trainset_list,
        model.validset,
        opt.num_epochs,
    ).save()

    return model