Example #1
import pytorch_lightning as pl

from nemo.collections.common.callbacks import LogEpochTimeCallback
from nemo.collections.tts.models import FastPitchModel
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager


def main(cfg):
    # Finetuning sanity checks: a scheduler and an out-of-range learning rate
    # are usually unintended when starting from a pretrained checkpoint.
    if hasattr(cfg.model.optim, 'sched'):
        logging.warning(
            "You are using an optimizer scheduler while finetuning. Are you sure this is intended?"
        )
    if cfg.model.optim.lr > 1e-3 or cfg.model.optim.lr < 1e-5:
        logging.warning("The recommended learning rate for finetuning is 2e-4")
    trainer = pl.Trainer(**cfg.trainer)
    exp_manager(trainer, cfg.get("exp_manager", None))
    model = FastPitchModel(cfg=cfg.model, trainer=trainer)
    # Restore weights from the pretrained checkpoint named in the config, if any.
    model.maybe_init_from_pretrained_checkpoint(cfg=cfg)
    lr_logger = pl.callbacks.LearningRateMonitor()
    epoch_time_logger = LogEpochTimeCallback()
    trainer.callbacks.extend([lr_logger, epoch_time_logger])
    trainer.fit(model)
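In the NeMo example scripts, a `main` like this is wired to a YAML config through the `hydra_runner` decorator, so every value in `cfg` can be overridden from the command line. A minimal sketch of the entry point; the `conf/` directory and the config name `fastpitch_align` are assumptions, not from the original:

from nemo.core.config import hydra_runner


@hydra_runner(config_path="conf", config_name="fastpitch_align")  # config name is an assumption
def main(cfg):
    ...  # body as in Example #1


if __name__ == '__main__':
    main()

Overrides then use Hydra's dotted syntax, e.g. `python fastpitch_finetune.py model.optim.lr=2e-4 trainer.max_epochs=100` (script name illustrative).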
Example #2
import os

from omegaconf import OmegaConf
from nemo.collections.tts.models import FastPitchModel


def fastpitch_model():
    # Build a FastPitchModel from the example YAML config; all datasets are
    # disabled so the model can be constructed without data on disk.
    test_root = os.path.dirname(os.path.abspath(__file__))
    conf = OmegaConf.load(os.path.join(test_root, '../../../examples/tts/conf/fastpitch.yaml'))
    conf.train_dataset = conf.validation_datasets = '.'
    conf.model.train_ds = conf.model.test_ds = conf.model.validation_ds = None
    return FastPitchModel(cfg=conf.model)
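The helper above reads like a pytest fixture from the NeMo test suite. A hedged sketch of how it would typically be registered and consumed; the `@pytest.fixture` decorator and the test below are illustrative, not from the original:

import pytest


@pytest.fixture()
def fastpitch_model():
    ...  # body as in Example #2


def test_fastpitch_builds(fastpitch_model):
    # pytest injects the fixture's return value as the argument.
    assert fastpitch_model is not None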
Example #3
import pytorch_lightning as pl

from nemo.collections.common.callbacks import LogEpochTimeCallback
from nemo.collections.tts.models import FastPitchModel
from nemo.utils.exp_manager import exp_manager


def main(cfg):
    # Same training pipeline as Example #1, without the finetuning checks.
    trainer = pl.Trainer(**cfg.trainer)
    exp_manager(trainer, cfg.get("exp_manager", None))
    model = FastPitchModel(cfg=cfg.model, trainer=trainer)
    lr_logger = pl.callbacks.LearningRateMonitor()
    epoch_time_logger = LogEpochTimeCallback()
    trainer.callbacks.extend([lr_logger, epoch_time_logger])
    trainer.fit(model)
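This `main` can also be driven without a Hydra entry point by loading the config directly. A minimal sketch, assuming the example config path from Example #2 and a short smoke-run override:

from omegaconf import OmegaConf

cfg = OmegaConf.load("examples/tts/conf/fastpitch.yaml")  # path is an assumption
cfg.trainer.max_epochs = 1  # shorten the run for a smoke test
main(cfg)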
Example #4
from nemo.collections.tts.models import FastPitchModel


def fastpitch_model():
    # Download and restore the pretrained English FastPitch model from NGC.
    return FastPitchModel.from_pretrained(model_name="tts_en_fastpitch")
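Once restored, the pretrained model can synthesize speech through NeMo's spectrogram-generator interface: `parse` turns text into token ids, `generate_spectrogram` produces a mel spectrogram, and a separately loaded vocoder converts it to audio. A minimal sketch; the vocoder name `tts_hifigan` and the use of `soundfile` for writing the result are assumptions:

import soundfile as sf

from nemo.collections.tts.models import FastPitchModel, HifiGanModel

spec_generator = FastPitchModel.from_pretrained(model_name="tts_en_fastpitch").eval()
vocoder = HifiGanModel.from_pretrained(model_name="tts_hifigan").eval()  # name is an assumption

tokens = spec_generator.parse("Hello world.")              # text -> token ids
spec = spec_generator.generate_spectrogram(tokens=tokens)  # tokens -> mel spectrogram
audio = vocoder.convert_spectrogram_to_audio(spec=spec)    # mel -> waveform

sf.write("hello.wav", audio.detach().cpu().numpy()[0], samplerate=22050)  # FastPitch runs at 22.05 kHz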