Ejemplo n.º 1
0
def _unwrap_parallel(module):
    """Return ``module.module`` if that attribute exists, else ``module``.

    Wrappers such as ``torch.nn.DataParallel`` expose the wrapped network
    as ``.module``; state dicts are loaded into that inner network.
    """
    return module.module if hasattr(module, "module") else module


def torch_joint_resume(snapshot_path, trainer):
    """Resume joint ASR/TTS training from a PyTorch snapshot.

    The snapshot is read with ``torch.load`` and is indexed with the keys
    ``"trainer"``, ``"asr_model"``, ``"tts_model"``, ``"asr_optimizer"``
    and ``"tts_optimizer"``. States are restored in that order.

    Args:
        snapshot_path (str): Snapshot file path.
        trainer (chainer.training.Trainer): Chainer's trainer instance.
            Its updater must provide ``model``, ``tts_model`` and the
            optimizers named ``"main"`` and ``"tts"``.

    """
    from chainer.serializers import NpzDeserializer

    # load snapshot; returning ``storage`` unchanged from ``map_location``
    # leaves every tensor on the storage it was deserialized to
    snapshot_dict = torch.load(snapshot_path,
                               map_location=lambda storage, loc: storage)

    # restore trainer states
    NpzDeserializer(snapshot_dict["trainer"]).load(trainer)

    updater = trainer.updater

    # restore asr model states: when the updater's model keeps the actual
    # network under ``.model``, load into that inner network
    asr_target = updater.model
    if hasattr(asr_target, "model"):
        asr_target = asr_target.model
    _unwrap_parallel(asr_target).load_state_dict(snapshot_dict["asr_model"])

    # restore tts model states
    # NOTE(review): unlike the ASR model, an inner ``.model`` attribute is
    # not unwrapped here -- confirm this matches how the snapshot is saved.
    _unwrap_parallel(updater.tts_model).load_state_dict(
        snapshot_dict["tts_model"])

    # restore optimizer states
    updater.get_optimizer("main").load_state_dict(
        snapshot_dict["asr_optimizer"])
    updater.get_optimizer("tts").load_state_dict(
        snapshot_dict['tts_optimizer'])

    # delete opened snapshot
    del snapshot_dict
Ejemplo n.º 2
0
def load_model_from_trainer_npz(path, model):
    """Load ``model`` from a trainer snapshot stored as an npz file.

    Only the entries under the ``updater/model:main/`` prefix of the
    archive are deserialized into ``model``.

    Args:
        path: Location of the npz file, passed to ``np.load``.
        model: Object that receives the deserialized values.

    """
    # the archive is closed on exit from the ``with`` block
    with np.load(path) as npz_file:
        NpzDeserializer(npz_file, path="updater/model:main/").load(model)