示例#1
0
文件: infer.py 项目: battyone/sonosco
def main(config_path, audio_path, plot):
    """Transcribe a single audio file with a TDSSeq2Seq model and log the text.

    Args:
        config_path: path to a YAML config; its ``infer`` section is used.
        audio_path: path to the audio file to transcribe.
        plot: if truthy, display the attention matrix of the first batch item
            with matplotlib after decoding.
    """
    config = parse_yaml(config_path)["infer"]
    device = torch.device("cuda" if CUDA_ENABLED else "cpu")

    loader = Deserializer()
    model = loader.deserialize(TDSSeq2Seq, config["model_checkpoint_path"])
    model.to(device)
    model.eval()

    # The decoder uses the label set stored with the model itself.
    decoder = GreedyDecoder(model.decoder.labels)

    processor = AudioDataProcessor(**config)
    spect, lens = processor.parse_audio_for_inference(audio_path)
    spect = spect.to(device)

    # Watch out: lens is modified by this call!
    # Afterwards it equals the number of encoded states.
    with torch.no_grad():
        out, output_lens, attention = model(spect, lens)
        decoded_output, _ = decoder.decode(out, output_lens)
        LOGGER.info(decoded_output)

    if plot:
        import matplotlib.pyplot as plt
        # Tensor.numpy() only works on CPU tensors; when running on CUDA the
        # attention must be copied to host memory first.
        plt.matshow(attention[0].cpu().numpy())
        plt.show()
示例#2
0
def main(config_path):
    """Train a Seq2Seq (LAS-style) model, optionally resuming a saved trainer.

    The ``train`` section of the YAML file at *config_path* drives everything.
    If it names a ``checkpoint_path``, a complete ModelTrainer is deserialized
    from it and re-attached to freshly created data loaders. Otherwise a new
    Seq2Seq model and ModelTrainer are built from the config. In both cases
    the trainer is handed to an Experiment, which runs until it finishes or
    is interrupted with Ctrl+C.
    """
    config = parse_yaml(config_path)["train"]
    experiment = Experiment.create(config, LOGGER)

    train_data, val_data, test_data = create_data_loaders(**config)
    data_loaders = {
        'train_data_loader': train_data,
        'val_data_loader': val_data,
        'test_data_loader': test_data,
    }

    if config.get('checkpoint_path'):
        # Resume: the serialized trainer carries the model and optimizer state;
        # only the (non-serialized) data loaders are supplied again.
        LOGGER.info(
            f"Loading model from checkpoint: {config['checkpoint_path']}")
        trainer: ModelTrainer = Deserializer().deserialize(
            ModelTrainer, config["checkpoint_path"], data_loaders)
    else:
        device = torch.device("cuda" if CUDA_ENABLED else "cpu")

        # The model vocabulary is the configured labels plus end/start tokens;
        # the decoder section of the config is completed from it.
        vocabulary = config["labels"] + EOS + SOS
        decoder_section = config["decoder"]
        decoder_section["vocab_size"] = len(vocabulary)
        decoder_section["sos_id"] = vocabulary.index(SOS)
        decoder_section["eos_id"] = vocabulary.index(EOS)

        model = Seq2Seq(config["encoder"], decoder_section)
        model.to(device)

        trainer = ModelTrainer(
            model,
            loss=cross_entropy_loss,
            epochs=config["max_epochs"],
            lr=config["learning_rate"],
            weight_decay=config['weight_decay'],
            metrics=[word_error_rate, character_error_rate],
            decoder=GreedyDecoder(config['labels']),
            device=device,
            test_step=config["test_step"],
            custom_model_eval=True,
            **data_loaders,
        )

        comparison = LasTextComparisonCallback(labels=vocabulary,
                                               log_dir=experiment.plots_path,
                                               args=config['recognizer'])
        trainer.add_callback(comparison)

    experiment.setup_model_trainer(trainer, checkpoints=True, tensorboard=True)
    try:
        experiment.start()
    except KeyboardInterrupt:
        experiment.stop()
示例#3
0
 def __init__(self, model_path: str) -> None:
     """
     Create the Sonosco automatic-speech-recognition interface.

     Args:
         model_path: filesystem location of the model to recognize with
     """
     super().__init__()
     # Store the model location and a Deserializer instance; presumably the
     # latter is used elsewhere in the class to load ``model_path``.
     self.model_path: str = model_path
     self.loader: Deserializer = Deserializer()
示例#4
0
def main(config_path):
    """Evaluate a Seq2Seq checkpoint on the validation data and log the metrics.

    Reads the ``train`` section of the YAML config at *config_path*. That
    section must contain ``checkpoint_path``; if it does not, an error is
    logged and the process exits with status 1.

    The metrics dict is filled by ``ModelTrainer._compute_validation_error``
    and then logged.
    """
    config = parse_yaml(config_path)["train"]

    # Fail fast: nothing below is meaningful without a checkpoint, and this is
    # a fatal condition, so log it at error level.
    if not config.get('checkpoint_path'):
        LOGGER.error("No checkpoint path specified")
        sys.exit(1)

    device = torch.device("cuda" if CUDA_ENABLED else "cpu")

    # Complete the decoder section the same way the training script does.
    # NOTE(review): the model below is deserialized and does not read this;
    # it is kept because config is also passed to create_data_loaders(**config).
    char_list = config["labels"] + EOS + SOS
    config["decoder"]["vocab_size"] = len(char_list)
    config["decoder"]["sos_id"] = char_list.index(SOS)
    config["decoder"]["eos_id"] = char_list.index(EOS)

    LOGGER.info(f"Loading model from checkpoint: {config['checkpoint_path']}")
    loader = Deserializer()
    model = loader.deserialize(Seq2Seq, config["checkpoint_path"])
    model.to(device)

    train_loader, val_loader, test_loader = create_data_loaders(**config)

    trainer = ModelTrainer(model,
                           loss=cross_entropy_loss,
                           epochs=config["max_epochs"],
                           train_data_loader=train_loader,
                           val_data_loader=val_loader,
                           test_data_loader=test_loader,
                           lr=config["learning_rate"],
                           weight_decay=config['weight_decay'],
                           metrics=[word_error_rate, character_error_rate],
                           decoder=GreedyDecoder(config['labels']),
                           device=device,
                           test_step=config["test_step"],
                           custom_model_eval=True)

    # A defaultdict without a default_factory behaves exactly like a dict
    # (missing keys still raise KeyError), so use a plain dict.
    metrics = {}
    trainer._compute_validation_error(metrics)
    LOGGER.info(metrics)
示例#5
0
def main(config_path, audio_path):
    """Transcribe one audio file using the Seq2Seq model's ``recognize`` search.

    Args:
        config_path: YAML config whose ``infer`` section supplies the
            checkpoint path, labels, recognizer options and audio settings.
        audio_path: audio file to transcribe.

    The first hypothesis returned by ``recognize`` is converted to text
    and logged.
    """
    infer_cfg = parse_yaml(config_path)["infer"]

    model: Seq2Seq = Deserializer().deserialize(
        Seq2Seq, infer_cfg["model_checkpoint_path"])
    model.to(DEVICE)
    model.eval()

    labels = infer_cfg["labels"]
    decoder = GreedyDecoder(labels)

    audio_processor = AudioDataProcessor(**infer_cfg)
    spectrogram, lengths = audio_processor.parse_audio_for_inference(audio_path)
    spectrogram = spectrogram.to(DEVICE)

    with torch.no_grad():
        # recognize() runs on the first item only and returns a sequence of
        # hypotheses; each one is a dict holding its token ids under 'yseq'.
        hypotheses = model.recognize(spectrogram[0], lengths, labels,
                                     infer_cfg["recognizer"])
        best_hypothesis = hypotheses[0]
        token_batch = torch.tensor([best_hypothesis['yseq']])
        transcription = decoder.convert_to_strings(token_batch)
        LOGGER.info(transcription)
示例#6
0
文件: train.py 项目: battyone/sonosco
def main(config_path):
    """Train a TDSSeq2Seq model, optionally starting from saved weights.

    The ``train`` section of the YAML config at *config_path* provides the
    model, data and optimisation settings. If ``checkpoint_path`` is set, the
    model is deserialized from it; otherwise it is built from the
    ``encoder``/``decoder`` sections. Training runs inside an Experiment and
    stops cleanly on Ctrl+C.
    """
    config = parse_yaml(config_path)["train"]
    experiment = Experiment.create(config, LOGGER)

    device = torch.device("cuda" if CUDA_ENABLED else "cpu")

    deserializer = Deserializer()
    if not config.get('checkpoint_path'):
        model = TDSSeq2Seq(config["encoder"], config["decoder"])
    else:
        LOGGER.info("Starting from checkpoint")
        model = deserializer.deserialize(TDSSeq2Seq, config["checkpoint_path"])
    model.to(device)

    train_loader, val_loader, test_loader = create_data_loaders(**config)

    trainer = ModelTrainer(
        model,
        loss=cross_entropy_loss,
        epochs=config["max_epochs"],
        train_data_loader=train_loader,
        val_data_loader=val_loader,
        test_data_loader=test_loader,
        lr=config["learning_rate"],
        custom_model_eval=True,
        metrics=[word_error_rate, character_error_rate],
        decoder=GreedyDecoder(config["decoder"]['labels']),
        device=device,
        test_step=config["test_step"],
    )

    # TensorBoard text comparisons (free-running and teacher-forced), plus a
    # callback that turns off the soft-window attention.
    plots_dir = experiment.plots_path
    trainer.add_callback(TbTextComparisonCallback(log_dir=plots_dir))
    trainer.add_callback(TbTeacherForcingTextComparisonCallback(log_dir=plots_dir))
    trainer.add_callback(DisableSoftWindowAttention())

    experiment.setup_model_trainer(trainer, checkpoints=True, tensorboard=True)

    try:
        experiment.start()
    except KeyboardInterrupt:
        experiment.stop()
示例#7
0
def test_mode_trainer_serialization():
    """Round-trip a ModelTrainer through Serializer and Deserializer.

    Builds a trainer from ``model_trainer_config_test.yaml`` and serializes it
    together with its config into a throwaway temporary directory. It then
    deserializes the trainer, re-attaching the data loaders, and checks that
    a trainer and an equal config come back.
    """
    import os
    import tempfile

    config_path = "model_trainer_config_test.yaml"
    config = parse_yaml(config_path)["train"]

    device = torch.device("cuda" if CUDA_ENABLED else "cpu")

    char_list = config["labels"] + EOS + SOS

    config["decoder"]["vocab_size"] = len(char_list)
    config["decoder"]["sos_id"] = char_list.index(SOS)
    config["decoder"]["eos_id"] = char_list.index(EOS)

    # Create the model, either from a checkpoint or from scratch.
    if config.get('checkpoint_path'):
        LOGGER.info(
            f"Loading model from checkpoint: {config['checkpoint_path']}")
        model = Deserializer().deserialize(Seq2Seq, config["checkpoint_path"])
    else:
        model = Seq2Seq(config["encoder"], config["decoder"])
    model.to(device)

    train_loader, val_loader, test_loader = create_data_loaders(**config)

    trainer = ModelTrainer(model,
                           loss=cross_entropy_loss,
                           epochs=config["max_epochs"],
                           train_data_loader=train_loader,
                           val_data_loader=val_loader,
                           test_data_loader=test_loader,
                           lr=config["learning_rate"],
                           weight_decay=config['weight_decay'],
                           metrics=[word_error_rate, character_error_rate],
                           decoder=GreedyDecoder(config['labels']),
                           device=device,
                           test_step=config["test_step"],
                           custom_model_eval=True)

    # Serialize into a temporary directory instead of a hard-coded path in
    # one developer's home directory, so the test runs on any machine and
    # leaves nothing behind.
    with tempfile.TemporaryDirectory() as tmp_dir:
        serialized_path = os.path.join(tmp_dir, 'ser')
        Serializer().serialize(trainer, serialized_path, config=config)
        trainer_deserialized, deserialized_config = Deserializer().deserialize(
            ModelTrainer,
            serialized_path, {
                'train_data_loader': train_loader,
                'val_data_loader': val_loader,
                'test_data_loader': test_loader,
            },
            with_config=True)

        assert trainer_deserialized is not None
        assert deserialized_config == config