Example 1
0
def train(opt):
    """Train a speech-recognition model according to the options in ``opt``.

    When ``opt.resume`` is falsy, the dataset, model, optimizer and criterion
    are all built from scratch here.  When resuming, only the criterion is
    built; the remaining objects are passed as ``None`` and are expected to be
    restored from a checkpoint inside ``SupervisedTrainer.train`` (resume=True).

    Args:
        opt: parsed option namespace (seeds, paths, LR schedule, trainer knobs).

    Returns:
        The trained model produced by ``SupervisedTrainer.train``.
    """
    # Seed every RNG (python, torch CPU, all CUDA devices) for reproducibility.
    random.seed(opt.seed)
    torch.manual_seed(opt.seed)
    torch.cuda.manual_seed_all(opt.seed)
    device = check_envirionment(opt.use_cuda)

    # The criterion was previously constructed identically in both branches;
    # build it once here instead.  It is parameter-free and consumes no RNG
    # state, so hoisting it does not change behavior.
    criterion = LabelSmoothedCrossEntropyLoss(
        num_classes=len(char2id), ignore_index=PAD_token,
        smoothing=opt.label_smoothing, dim=-1,
        reduction=opt.reduction, architecture=opt.architecture
    ).to(device)

    if not opt.resume:
        audio_paths, script_paths = load_data_list(opt.data_list_path, opt.dataset_path)

        epoch_time_step, trainset_list, validset = split_dataset(opt, audio_paths, script_paths)
        model = build_model(opt, device)

        optimizer = optim.Adam(model.module.parameters(), lr=opt.init_lr, weight_decay=opt.weight_decay)

        if opt.rampup_period > 0:
            # Warm-up: ramp the LR from init_lr up to high_plateau_lr over
            # rampup_period steps before handing control to the wrapper.
            scheduler = RampUpLR(optimizer, opt.init_lr, opt.high_plateau_lr, opt.rampup_period)
            optimizer = Optimizer(optimizer, scheduler, opt.rampup_period, opt.max_grad_norm)
        else:
            optimizer = Optimizer(optimizer, None, 0, opt.max_grad_norm)

    else:
        # On resume the trainer restores these from the checkpoint, so pass
        # placeholders; only the criterion must exist up front.
        trainset_list = None
        validset = None
        model = None
        optimizer = None
        epoch_time_step = None

    trainer = SupervisedTrainer(
        optimizer=optimizer, criterion=criterion, trainset_list=trainset_list,
        validset=validset, num_workers=opt.num_workers,
        high_plateau_lr=opt.high_plateau_lr, low_plateau_lr=opt.low_plateau_lr,
        decay_threshold=opt.decay_threshold, exp_decay_period=opt.exp_decay_period,
        device=device, teacher_forcing_step=opt.teacher_forcing_step,
        min_teacher_forcing_ratio=opt.min_teacher_forcing_ratio, print_every=opt.print_every,
        save_result_every=opt.save_result_every, checkpoint_every=opt.checkpoint_every,
        architecture=opt.architecture
    )
    model = trainer.train(
        model=model,
        batch_size=opt.batch_size,
        epoch_time_step=epoch_time_step,
        num_epochs=opt.num_epochs,
        teacher_forcing_ratio=opt.teacher_forcing_ratio,
        resume=opt.resume
    )
    return model
Example 2
0
def train(opt):
    """Fine-tune an ensemble of pre-trained models and checkpoint the result.

    Builds the dataset splits, loads three saved models into a single
    ensemble, trains it with ``SupervisedTrainer`` under an NLL criterion,
    and finally saves a ``Checkpoint`` of the trained state.
    """
    # Reproducibility: seed python's RNG plus torch on CPU and every GPU.
    random.seed(opt.seed)
    torch.manual_seed(opt.seed)
    torch.cuda.manual_seed_all(opt.seed)
    device = check_envirionment(opt.use_cuda)

    # Prepare training/validation splits from the data list.
    audio_paths, script_paths = load_data_list(opt.data_list_path,
                                               opt.dataset_path)
    epoch_time_step, trainset_list, validset = split_dataset(
        opt, audio_paths, script_paths)

    # Combine three previously trained checkpoints into one ensemble model.
    ensemble = build_ensemble(['model_path1', 'model_path2', 'model_path3'],
                              opt.ensemble_method, device)

    # Plain Adam wrapped in the project's Optimizer (no scheduler, grad clip).
    adam = optim.Adam(ensemble.module.parameters(), lr=opt.init_lr)
    optimizer = Optimizer(adam, None, 0, opt.max_grad_norm)
    criterion = nn.NLLLoss(reduction='sum', ignore_index=PAD_token).to(device)

    trainer = SupervisedTrainer(
        optimizer=optimizer,
        criterion=criterion,
        trainset_list=trainset_list,
        validset=validset,
        num_workers=opt.num_workers,
        high_plateau_lr=opt.high_plateau_lr,
        low_plateau_lr=opt.low_plateau_lr,
        decay_threshold=opt.decay_threshold,
        exp_decay_period=opt.exp_decay_period,
        device=device,
        teacher_forcing_step=opt.teacher_forcing_step,
        min_teacher_forcing_ratio=opt.min_teacher_forcing_ratio,
        print_every=opt.print_every,
        save_result_every=opt.save_result_every,
        checkpoint_every=opt.checkpoint_every)

    trained = trainer.train(model=ensemble,
                            batch_size=opt.batch_size,
                            epoch_time_step=epoch_time_step,
                            num_epochs=opt.num_epochs,
                            teacher_forcing_ratio=opt.teacher_forcing_ratio,
                            resume=opt.resume)

    # Persist the trained model together with its training state.
    # NOTE(review): the original returns nothing here — presumably callers
    # rely only on the saved checkpoint; confirm before adding a return.
    Checkpoint(trained, trained.optimizer, trained.criterion,
               trained.trainset_list, trained.validset,
               opt.num_epochs).save()