Exemplo n.º 1
0
def train(config: DictConfig) -> nn.DataParallel:
    """Build (or resume) a model from ``config`` and run supervised training.

    Seeds ``random`` and ``torch`` (CPU and all CUDA devices) from
    ``config.train.seed``, optionally limits intra-op threads, loads the
    KsponSpeech vocabulary and then either builds a fresh dataset split /
    model / optimizer or, when ``config.train.resume`` is truthy, leaves
    those as ``None`` for ``SupervisedTrainer.train`` to restore.

    Args:
        config: Hydra/OmegaConf config with ``train`` and ``model`` sections.
            ``train.num_threads`` and ``train.vocab_path`` are optional.

    Returns:
        The model returned by ``SupervisedTrainer.train``.
    """
    random.seed(config.train.seed)
    torch.manual_seed(config.train.seed)
    torch.cuda.manual_seed_all(config.train.seed)
    device = check_envirionment(config.train.use_cuda)

    # ``num_threads`` is optional. Cast once so the guard and the call agree
    # (previously the guard used int(...) but the raw value was passed on).
    if hasattr(config.train, "num_threads"):
        num_threads = int(config.train.num_threads)
        if num_threads > 0:
            torch.set_num_threads(num_threads)

    # NOTE(review): the default below is an absolute, machine-specific path
    # kept only for backward compatibility; set ``train.vocab_path`` to use a
    # vocabulary file elsewhere.
    default_vocab_path = (
        '/home/seungmin/dmount/KoSpeech/data/vocab/'
        f'aihub_{config.train.output_unit}_vocabs.csv'
    )
    vocab_path = getattr(config.train, 'vocab_path', None) or default_vocab_path
    vocab = KsponSpeechVocabulary(
        vocab_path,
        output_unit=config.train.output_unit,
    )

    if not config.train.resume:
        epoch_time_step, trainset_list, validset = split_dataset(
            config, config.train.transcripts_path, vocab)
        model = build_model(config, vocab, device)

        optimizer = get_optimizer(model, config)
        lr_scheduler = get_lr_scheduler(config, optimizer, epoch_time_step)

        optimizer = Optimizer(optimizer, lr_scheduler,
                              config.train.total_steps,
                              config.train.max_grad_norm)
    else:
        # Resuming: the trainer is expected to restore these from a checkpoint.
        trainset_list = None
        validset = None
        model = None
        optimizer = None
        epoch_time_step = None

    # Same loss for fresh and resumed runs (previously duplicated per branch).
    criterion = get_criterion(config, vocab)

    trainer = SupervisedTrainer(
        optimizer=optimizer,
        criterion=criterion,
        trainset_list=trainset_list,
        validset=validset,
        num_workers=config.train.num_workers,
        device=device,
        teacher_forcing_step=config.model.teacher_forcing_step,
        min_teacher_forcing_ratio=config.model.min_teacher_forcing_ratio,
        print_every=config.train.print_every,
        save_result_every=config.train.save_result_every,
        checkpoint_every=config.train.checkpoint_every,
        architecture=config.model.architecture,
        vocab=vocab,
        joint_ctc_attention=config.model.joint_ctc_attention,
    )
    model = trainer.train(
        model=model,
        batch_size=config.train.batch_size,
        epoch_time_step=epoch_time_step,
        num_epochs=config.train.num_epochs,
        teacher_forcing_ratio=config.model.teacher_forcing_ratio,
        resume=config.train.resume,
    )
    return model
Exemplo n.º 2
0
def train(config: DictConfig) -> nn.DataParallel:
    """Seed RNGs, pick a vocabulary, assemble training objects and train.

    Supported datasets are ``'kspon'`` (character/grapheme CSV vocab or a
    SentencePiece ``'subword'`` vocab) and ``'libri'``. When
    ``config.train.resume`` is truthy, the dataset split, model and optimizer
    are passed to the trainer as ``None``.

    Args:
        config: Hydra/OmegaConf config with ``train`` and ``model`` sections.

    Returns:
        The model returned by ``SupervisedTrainer.train``.

    Raises:
        ValueError: if ``config.train.dataset`` is not a supported dataset.
    """
    train_cfg = config.train
    model_cfg = config.model

    seed = train_cfg.seed
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    device = check_envirionment(train_cfg.use_cuda)

    dataset = train_cfg.dataset
    if dataset == 'libri':
        vocab = LibriSpeechVocabulary(LIBRISPEECH_VOCAB_PATH, LIBRISPEECH_TOKENIZER_PATH)
    elif dataset == 'kspon':
        unit = train_cfg.output_unit
        if unit != 'subword':
            vocab = KsponSpeechVocabulary(
                f'../../../data/vocab/aihub_{unit}_vocabs.csv',
                output_unit=unit,
            )
        else:
            vocab = KsponSpeechVocabulary(
                vocab_path=KSPONSPEECH_VOCAB_PATH,
                output_unit=unit,
                sp_model_path=KSPONSPEECH_SP_MODEL_PATH,
            )
    else:
        raise ValueError("Unsupported Dataset : {0}".format(dataset))

    resume = train_cfg.resume
    if resume:
        # Resumed runs get their state back from the trainer's checkpoint.
        model = optimizer = None
        trainset_list = validset = epoch_time_step = None
    else:
        epoch_time_step, trainset_list, validset = split_dataset(config, train_cfg.transcripts_path, vocab)
        model = build_model(config, vocab, device)

        base_optimizer = get_optimizer(model, config)
        scheduler = get_lr_scheduler(config, base_optimizer, epoch_time_step)
        optimizer = Optimizer(base_optimizer, scheduler, train_cfg.warmup_steps, train_cfg.max_grad_norm)
    criterion = get_criterion(config, vocab)

    trainer = SupervisedTrainer(
        optimizer=optimizer,
        criterion=criterion,
        trainset_list=trainset_list,
        validset=validset,
        num_workers=train_cfg.num_workers,
        device=device,
        teacher_forcing_step=model_cfg.teacher_forcing_step,
        min_teacher_forcing_ratio=model_cfg.min_teacher_forcing_ratio,
        print_every=train_cfg.print_every,
        save_result_every=train_cfg.save_result_every,
        checkpoint_every=train_cfg.checkpoint_every,
        architecture=model_cfg.architecture,
        vocab=vocab,
        joint_ctc_attention=model_cfg.joint_ctc_attention,
    )
    return trainer.train(
        model=model,
        batch_size=train_cfg.batch_size,
        epoch_time_step=epoch_time_step,
        num_epochs=train_cfg.num_epochs,
        teacher_forcing_ratio=model_cfg.teacher_forcing_ratio,
        resume=resume,
    )
Exemplo n.º 3
0
def train(config: DictConfig):
    """Train a model with an explicitly configured tri-stage LR schedule.

    Seeds ``random`` and ``torch`` from ``config.train.seed``, selects the
    KsponSpeech or LibriSpeech vocabulary, and (unless resuming) builds the
    dataset split, model, optimizer and a ``TriStageLRScheduler`` spanning
    ``num_epochs * epoch_time_step`` steps.

    Args:
        config: Hydra/OmegaConf config with ``train`` and ``model`` sections.

    Returns:
        The model returned by ``SupervisedTrainer.train``.

    Raises:
        ValueError: if ``config.train.dataset`` is not a supported dataset.
    """
    train_cfg = config.train
    model_cfg = config.model

    seed = train_cfg.seed
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    device = check_envirionment(train_cfg.use_cuda)

    dataset = train_cfg.dataset
    if dataset == 'kspon':
        unit = train_cfg.output_unit
        vocab = (
            KsponSpeechVocabulary(
                vocab_path='../../../data/vocab/kspon_sentencepiece.vocab',
                output_unit=unit,
                sp_model_path='../../../data/vocab/kspon_sentencepiece.model',
            )
            if unit == 'subword'
            else KsponSpeechVocabulary(
                f'../../../data/vocab/aihub_{unit}_vocabs.csv',
                output_unit=unit,
            )
        )
    elif dataset == 'libri':
        vocab = LibriSpeechVocabulary('../../../data/vocab/tokenizer.vocab',
                                      '../../../data/vocab/tokenizer.model')
    else:
        raise ValueError("Unsupported Dataset : {0}".format(dataset))

    # Defaults for a resumed run; overwritten below for a fresh run.
    model = optimizer = epoch_time_step = None
    trainset_list = validset = None

    resume = train_cfg.resume
    if not resume:
        epoch_time_step, trainset_list, validset = split_dataset(
            config, train_cfg.transcripts_path, vocab)
        model = build_model(config, vocab, device)

        base_optimizer = get_optimizer(model, config)
        total_steps = int(train_cfg.num_epochs * epoch_time_step)
        scheduler = TriStageLRScheduler(
            optimizer=base_optimizer,
            init_lr=train_cfg.init_lr,
            peak_lr=train_cfg.peak_lr,
            final_lr=train_cfg.final_lr,
            init_lr_scale=train_cfg.init_lr_scale,
            final_lr_scale=train_cfg.final_lr_scale,
            warmup_steps=train_cfg.warmup_steps,
            total_steps=total_steps,
        )
        # NOTE(review): the scheduler spans ``total_steps`` but the Optimizer
        # wrapper receives ``warmup_steps`` as its third argument. If that
        # argument is the period after which scheduling stops, the hold/decay
        # stages never run — confirm against Optimizer's definition.
        optimizer = Optimizer(base_optimizer, scheduler,
                              train_cfg.warmup_steps,
                              train_cfg.max_grad_norm)
    criterion = get_criterion(config, vocab)

    trainer = SupervisedTrainer(
        optimizer=optimizer,
        criterion=criterion,
        trainset_list=trainset_list,
        validset=validset,
        num_workers=train_cfg.num_workers,
        device=device,
        teacher_forcing_step=model_cfg.teacher_forcing_step,
        min_teacher_forcing_ratio=model_cfg.min_teacher_forcing_ratio,
        print_every=train_cfg.print_every,
        save_result_every=train_cfg.save_result_every,
        checkpoint_every=train_cfg.checkpoint_every,
        architecture=model_cfg.architecture,
        vocab=vocab,
        joint_ctc_attention=model_cfg.joint_ctc_attention,
    )
    return trainer.train(
        model=model,
        batch_size=train_cfg.batch_size,
        epoch_time_step=epoch_time_step,
        num_epochs=train_cfg.num_epochs,
        teacher_forcing_ratio=model_cfg.teacher_forcing_ratio,
        resume=resume,
    )
Exemplo n.º 4
0
def train(opt):
    """Build (or resume) a model from parsed options and run training.

    Seeds ``random`` and ``torch`` from ``opt.seed`` and selects the
    KsponSpeech or LibriSpeech vocabulary. Unless ``opt.resume`` is truthy,
    it splits the dataset, builds the model, creates the requested optimizer
    and wraps it with a ``TriStageLRScheduler``. Finally it delegates to
    ``SupervisedTrainer.train``.

    Args:
        opt: options object with attribute access (e.g. an argparse namespace).

    Returns:
        The model returned by ``SupervisedTrainer.train``.

    Raises:
        ValueError: if ``opt.dataset`` or ``opt.optimizer`` is unsupported.
    """
    random.seed(opt.seed)
    torch.manual_seed(opt.seed)
    torch.cuda.manual_seed_all(opt.seed)
    device = check_envirionment(opt.use_cuda)

    if opt.dataset == 'kspon':
        if opt.output_unit == 'subword':
            vocab = KsponSpeechVocabulary(
                vocab_path='../data/vocab/kspon_sentencepiece.vocab',
                output_unit=opt.output_unit,
                sp_model_path='../data/vocab/kspon_sentencepiece.model')
        else:
            vocab = KsponSpeechVocabulary(
                f'../data/vocab/aihub_{opt.output_unit}_vocabs.csv',
                output_unit=opt.output_unit)

    elif opt.dataset == 'libri':
        vocab = LibriSpeechVocabulary('../data/vocab/tokenizer.vocab',
                                      '../data/vocab/tokenizer.model')

    else:
        raise ValueError("Unsupported Dataset : {0}".format(opt.dataset))

    if not opt.resume:
        epoch_time_step, trainset_list, validset = split_dataset(
            opt, opt.transcripts_path, vocab)
        model = build_model(opt, vocab, device)

        # Lower-cased option name -> optimizer class; all share (params, lr, weight_decay).
        optimizer_classes = {
            'adam': optim.Adam,
            'radam': RAdam,
            'adamp': AdamP,
            'adadelta': optim.Adadelta,
            'adagrad': optim.Adagrad,
        }
        optimizer_cls = optimizer_classes.get(opt.optimizer.lower())
        if optimizer_cls is None:
            raise ValueError(
                f"Unsupported Optimizer : {opt.optimizer}, "
                "Supported Optimizer : Adam, RAdam, AdamP, Adadelta, Adagrad")
        optimizer = optimizer_cls(model.module.parameters(),
                                  lr=opt.init_lr,
                                  weight_decay=opt.weight_decay)

        lr_scheduler = TriStageLRScheduler(optimizer=optimizer,
                                           init_lr=opt.init_lr,
                                           peak_lr=opt.peak_lr,
                                           final_lr=opt.final_lr,
                                           init_lr_scale=opt.init_lr_scale,
                                           final_lr_scale=opt.final_lr_scale,
                                           warmup_steps=opt.warmup_steps,
                                           total_steps=int(opt.num_epochs *
                                                           epoch_time_step))
        # NOTE(review): Optimizer receives ``warmup_steps`` as its third
        # argument while the scheduler spans the whole run. If that argument
        # is the period after which scheduling stops, the decay stage never
        # runs — confirm against Optimizer's definition.
        optimizer = Optimizer(optimizer, lr_scheduler, opt.warmup_steps,
                              opt.max_grad_norm)
    else:
        # Resuming: the trainer is expected to restore these from a checkpoint.
        trainset_list = None
        validset = None
        model = None
        optimizer = None
        epoch_time_step = None

    # Build the loss after the branch so a resumed run uses the same loss as a
    # fresh one. Previously, resuming always chose label-smoothed CE, even for
    # deepspeech2, which is trained with CTC.
    if opt.architecture == 'deepspeech2':
        criterion = nn.CTCLoss(blank=vocab.blank_id,
                               reduction=opt.reduction).to(device)
    else:
        criterion = LabelSmoothedCrossEntropyLoss(
            num_classes=len(vocab),
            ignore_index=vocab.pad_id,
            smoothing=opt.label_smoothing,
            dim=-1,
            reduction=opt.reduction,
            architecture=opt.architecture).to(device)

    trainer = SupervisedTrainer(
        optimizer=optimizer,
        criterion=criterion,
        trainset_list=trainset_list,
        validset=validset,
        num_workers=opt.num_workers,
        device=device,
        teacher_forcing_step=opt.teacher_forcing_step,
        min_teacher_forcing_ratio=opt.min_teacher_forcing_ratio,
        print_every=opt.print_every,
        save_result_every=opt.save_result_every,
        checkpoint_every=opt.checkpoint_every,
        architecture=opt.architecture,
        vocab=vocab)
    model = trainer.train(model=model,
                          batch_size=opt.batch_size,
                          epoch_time_step=epoch_time_step,
                          num_epochs=opt.num_epochs,
                          teacher_forcing_ratio=opt.teacher_forcing_ratio,
                          resume=opt.resume)
    return model