Code example #1
File: main.py  Project: Rhcsky/KoSpeech
def train(config: DictConfig) -> nn.DataParallel:
    # Fix all random seeds so runs are reproducible.
    random.seed(config.train.seed)
    torch.manual_seed(config.train.seed)
    torch.cuda.manual_seed_all(config.train.seed)
    device = check_envirionment(config.train.use_cuda)
    # Optionally cap the number of CPU threads PyTorch may use.
    if hasattr(config.train, "num_threads") and int(config.train.num_threads) > 0:
        torch.set_num_threads(config.train.num_threads)

    vocab = KsponSpeechVocabulary(
        f'/home/seungmin/dmount/KoSpeech/data/vocab/aihub_{config.train.output_unit}_vocabs.csv',
        output_unit=config.train.output_unit,
    )

    if not config.train.resume:
        # Fresh run: build the datasets, the model, the optimizer with its
        # LR schedule, and the loss criterion from scratch.
        epoch_time_step, trainset_list, validset = split_dataset(
            config, config.train.transcripts_path, vocab)
        model = build_model(config, vocab, device)

        optimizer = get_optimizer(model, config)
        lr_scheduler = get_lr_scheduler(config, optimizer, epoch_time_step)

        # Wrap the optimizer together with the scheduler and gradient clipping.
        optimizer = Optimizer(optimizer, lr_scheduler,
                              config.train.total_steps,
                              config.train.max_grad_norm)
        criterion = get_criterion(config, vocab)

    else:
        # Resuming from a checkpoint: the trainer restores the model, optimizer
        # and dataset state itself, so only the criterion is rebuilt here.
        trainset_list = None
        validset = None
        model = None
        optimizer = None
        epoch_time_step = None
        criterion = get_criterion(config, vocab)

    trainer = SupervisedTrainer(
        optimizer=optimizer,
        criterion=criterion,
        trainset_list=trainset_list,
        validset=validset,
        num_workers=config.train.num_workers,
        device=device,
        teacher_forcing_step=config.model.teacher_forcing_step,
        min_teacher_forcing_ratio=config.model.min_teacher_forcing_ratio,
        print_every=config.train.print_every,
        save_result_every=config.train.save_result_every,
        checkpoint_every=config.train.checkpoint_every,
        architecture=config.model.architecture,
        vocab=vocab,
        joint_ctc_attention=config.model.joint_ctc_attention,
    )
    model = trainer.train(
        model=model,
        batch_size=config.train.batch_size,
        epoch_time_step=epoch_time_step,
        num_epochs=config.train.num_epochs,
        teacher_forcing_ratio=config.model.teacher_forcing_ratio,
        resume=config.train.resume,
    )
    return model
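The three snippets on this page show only the train() body. To run any of them, an import preamble along the following lines is needed; the kospeech module paths below are assumptions based on the public KoSpeech package layout, and get_optimizer / get_lr_scheduler / get_criterion are project-local helpers of the Rhcsky/KoSpeech fork whose location is not visible in the snippets.

import random

import torch
import torch.nn as nn
from omegaconf import DictConfig

# Assumed locations, following the public KoSpeech package layout:
from kospeech.utils import check_envirionment
from kospeech.vocabs.ksponspeech import KsponSpeechVocabulary
from kospeech.vocabs.librispeech import LibriSpeechVocabulary
from kospeech.data.data_loader import split_dataset
from kospeech.model_builder import build_model
from kospeech.optim import Optimizer
from kospeech.trainer import SupervisedTrainer
# get_optimizer, get_lr_scheduler and get_criterion are helpers defined
# elsewhere in the Rhcsky/KoSpeech fork and are not imported here.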
Code example #2
def train(config: DictConfig) -> nn.DataParallel:
    random.seed(config.train.seed)
    torch.manual_seed(config.train.seed)
    torch.cuda.manual_seed_all(config.train.seed)
    device = check_envirionment(config.train.use_cuda)

    # Pick the vocabulary that matches the chosen dataset and output unit.
    if config.train.dataset == 'kspon':
        if config.train.output_unit == 'subword':
            vocab = KsponSpeechVocabulary(
                vocab_path=KSPONSPEECH_VOCAB_PATH,
                output_unit=config.train.output_unit,
                sp_model_path=KSPONSPEECH_SP_MODEL_PATH,
            )
        else:
            vocab = KsponSpeechVocabulary(
                f'../../../data/vocab/aihub_{config.train.output_unit}_vocabs.csv',
                output_unit=config.train.output_unit,
            )

    elif config.train.dataset == 'libri':
        vocab = LibriSpeechVocabulary(LIBRISPEECH_VOCAB_PATH, LIBRISPEECH_TOKENIZER_PATH)

    else:
        raise ValueError("Unsupported Dataset : {0}".format(config.train.dataset))

    if not config.train.resume:
        epoch_time_step, trainset_list, validset = split_dataset(config, config.train.transcripts_path, vocab)
        model = build_model(config, vocab, device)

        optimizer = get_optimizer(model, config)
        lr_scheduler = get_lr_scheduler(config, optimizer, epoch_time_step)

        optimizer = Optimizer(optimizer, lr_scheduler, config.train.warmup_steps, config.train.max_grad_norm)
        criterion = get_criterion(config, vocab)

    else:
        trainset_list = None
        validset = None
        model = None
        optimizer = None
        epoch_time_step = None
        criterion = get_criterion(config, vocab)

    trainer = SupervisedTrainer(
        optimizer=optimizer,
        criterion=criterion,
        trainset_list=trainset_list,
        validset=validset,
        num_workers=config.train.num_workers,
        device=device,
        teacher_forcing_step=config.model.teacher_forcing_step,
        min_teacher_forcing_ratio=config.model.min_teacher_forcing_ratio,
        print_every=config.train.print_every,
        save_result_every=config.train.save_result_every,
        checkpoint_every=config.train.checkpoint_every,
        architecture=config.model.architecture,
        vocab=vocab,
        joint_ctc_attention=config.model.joint_ctc_attention,
    )
    model = trainer.train(
        model=model,
        batch_size=config.train.batch_size,
        epoch_time_step=epoch_time_step,
        num_epochs=config.train.num_epochs,
        teacher_forcing_ratio=config.model.teacher_forcing_ratio,
        resume=config.train.resume,
    )
    return model
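For reference, a minimal way to call one of these train() variants directly; in the project itself the function is normally driven by a Hydra entry point, so the config file path below is only an illustrative assumption and must contain the train.* and model.* keys the snippet reads.

from omegaconf import OmegaConf

# Hypothetical config path; the real project composes its DictConfig via Hydra.
config = OmegaConf.load("configs/train.yaml")
model = train(config)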
Code example #3
def train(config: DictConfig):
    random.seed(config.train.seed)
    torch.manual_seed(config.train.seed)
    torch.cuda.manual_seed_all(config.train.seed)
    device = check_envirionment(config.train.use_cuda)

    if config.train.dataset == 'kspon':
        if config.train.output_unit == 'subword':
            vocab = KsponSpeechVocabulary(
                vocab_path='../../../data/vocab/kspon_sentencepiece.vocab',
                output_unit=config.train.output_unit,
                sp_model_path='../../../data/vocab/kspon_sentencepiece.model',
            )
        else:
            vocab = KsponSpeechVocabulary(
                f'../../../data/vocab/aihub_{config.train.output_unit}_vocabs.csv',
                output_unit=config.train.output_unit)

    elif config.train.dataset == 'libri':
        vocab = LibriSpeechVocabulary('../../../data/vocab/tokenizer.vocab',
                                      '../../../data/vocab/tokenizer.model')

    else:
        raise ValueError("Unsupported Dataset : {0}".format(
            config.train.dataset))

    if not config.train.resume:
        epoch_time_step, trainset_list, validset = split_dataset(
            config, config.train.transcripts_path, vocab)
        model = build_model(config, vocab, device)

        optimizer = get_optimizer(model, config)

        # Explicit tri-stage schedule: warm up to peak_lr, hold, then decay to final_lr.
        lr_scheduler = TriStageLRScheduler(
            optimizer=optimizer,
            init_lr=config.train.init_lr,
            peak_lr=config.train.peak_lr,
            final_lr=config.train.final_lr,
            init_lr_scale=config.train.init_lr_scale,
            final_lr_scale=config.train.final_lr_scale,
            warmup_steps=config.train.warmup_steps,
            total_steps=int(config.train.num_epochs * epoch_time_step))
        optimizer = Optimizer(optimizer, lr_scheduler,
                              config.train.warmup_steps,
                              config.train.max_grad_norm)
        criterion = get_criterion(config, vocab)

    else:
        trainset_list = None
        validset = None
        model = None
        optimizer = None
        epoch_time_step = None
        criterion = get_criterion(config, vocab)

    trainer = SupervisedTrainer(
        optimizer=optimizer,
        criterion=criterion,
        trainset_list=trainset_list,
        validset=validset,
        num_workers=config.train.num_workers,
        device=device,
        teacher_forcing_step=config.model.teacher_forcing_step,
        min_teacher_forcing_ratio=config.model.min_teacher_forcing_ratio,
        print_every=config.train.print_every,
        save_result_every=config.train.save_result_every,
        checkpoint_every=config.train.checkpoint_every,
        architecture=config.model.architecture,
        vocab=vocab,
        joint_ctc_attention=config.model.joint_ctc_attention,
    )
    model = trainer.train(
        model=model,
        batch_size=config.train.batch_size,
        epoch_time_step=epoch_time_step,
        num_epochs=config.train.num_epochs,
        teacher_forcing_ratio=config.model.teacher_forcing_ratio,
        resume=config.train.resume,
    )
    return model
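Example #3 constructs the TriStageLRScheduler by hand instead of going through get_lr_scheduler. As a rough illustration of what a tri-stage policy computes (linear warmup to peak_lr, a constant hold phase, then exponential decay towards final_lr), here is a standalone sketch; the phase split and decay formula are assumptions and do not claim to match KoSpeech's implementation exactly.

import math

def tri_stage_lr(step, init_lr, peak_lr, final_lr,
                 warmup_steps, hold_steps, decay_steps):
    # Stage 1: linear warmup from init_lr up to peak_lr.
    if step < warmup_steps:
        return init_lr + (peak_lr - init_lr) * step / max(1, warmup_steps)
    step -= warmup_steps
    # Stage 2: hold the learning rate at its peak.
    if step < hold_steps:
        return peak_lr
    step -= hold_steps
    # Stage 3: decay exponentially from peak_lr towards final_lr.
    decay_rate = -math.log(final_lr / peak_lr) / max(1, decay_steps)
    return peak_lr * math.exp(-decay_rate * min(step, decay_steps))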