コード例 #1
0
def load_and_setup_big_WN(model_name, parser, checkpoint, device, logger, forward_is_infer=False, ema=True, jitable=False):
    """Build ``model_name`` and optionally restore it from ``checkpoint``.

    Args:
        model_name: Model identifier handed to the ``models`` helpers. The
            value ``"WaveGlow"`` additionally triggers ``remove_weightnorm``.
        parser: Argument parser passed to ``models.parse_model_args``.
        checkpoint: Argument for ``torch.load``, or ``None`` to keep the
            freshly constructed weights.
        device: Passed to ``models.get_model``; the returned model is moved
            to it and it is stored as ``model.device``.
        logger: Forwarded to ``models.get_model``.
        forward_is_infer: Forwarded to ``models.get_model``.
        ema: Accepted but not used by this function.
        jitable: Forwarded to ``models.get_model``.

    Returns:
        The model in eval mode, moved to ``device``.
    """
    model_parser = models.parse_model_args(model_name, parser, add_help=False)
    model_args, _ = model_parser.parse_known_args()

    model_config = models.get_model_config(model_name, model_args)
    model = models.get_model(model_name, model_config, device, logger,
                             forward_is_infer=forward_is_infer, jitable=jitable)

    if checkpoint is not None:
        # NOTE: torch.load unpickles the file, so only trusted checkpoints
        # should be passed in.
        checkpoint_data = torch.load(checkpoint, map_location="cpu")
        status = ''

        if 'state_dict' in checkpoint_data:
            sd = checkpoint_data['state_dict']

            # Strip only a *leading* 'module.' wrapper prefix. A blanket
            # str.replace would also mangle keys that merely contain the
            # text, e.g. 'submodule.weight' -> 'subweight'.
            prefix = 'module.'
            if any(key.startswith(prefix) for key in sd):
                sd = {
                    (k[len(prefix):] if k.startswith(prefix) else k): v
                    for k, v in sd.items()
                }
            status += ' ' + str(model.load_state_dict(sd, strict=True))
        else:
            # No state dict: the checkpoint carries the model object itself.
            model = checkpoint_data['model']
        print(f'Loaded {model_name}{status}')

    if model_name == "WaveGlow":
        model = model.remove_weightnorm(model)

    model.device = device
    model.eval()
    return model.to(device)
コード例 #2
0
def load_and_setup_model(model_name, parser, checkpoint, device, logger, forward_is_infer=False, ema=True, jitable=False):
    """Build ``model_name`` and optionally restore it from ``checkpoint``.

    The object returned by ``torch.load`` is handled as follows:

    * not a ``dict``: used directly as the model;
    * ``dict`` with a ``'model'`` entry: that entry is used as the model
      (it takes precedence when ``'state_dict'`` is present as well);
    * ``dict`` with only ``'state_dict'``: the weights are loaded
      non-strictly into the freshly built model, after its ``speaker_emb``
      has been replaced by a fixed 5 x 384 embedding;
    * any other ``dict``: ``ValueError``.

    Args:
        model_name: Model identifier handed to the ``models`` helpers. The
            value ``"WaveGlow"`` additionally triggers ``remove_weightnorm``.
        parser: Argument parser passed to ``models.parse_model_args``.
        checkpoint: Argument for ``torch.load``, or ``None`` to keep the
            freshly constructed weights.
        device: Passed to ``models.get_model``; the returned model is moved
            to it and it is stored as ``model.device``.
        logger: Forwarded to ``models.get_model``.
        forward_is_infer: Forwarded to ``models.get_model``.
        ema: Accepted but not used by this function.
        jitable: Forwarded to ``models.get_model``.

    Returns:
        The model in eval mode, moved to ``device``.

    Raises:
        ValueError: If the checkpoint is a dict with neither a ``'model'``
            nor a ``'state_dict'`` entry.
    """
    model_parser = models.parse_model_args(model_name, parser, add_help=False)
    model_args, _ = model_parser.parse_known_args()

    model_config = models.get_model_config(model_name, model_args)
    model = models.get_model(model_name, model_config, device, logger,
                             forward_is_infer=forward_is_infer, jitable=jitable)

    if checkpoint is not None:
        # NOTE: torch.load unpickles the file, so only trusted checkpoints
        # should be passed in.
        checkpoint_data = torch.load(checkpoint, map_location="cpu")

        if not isinstance(checkpoint_data, dict):
            # The file holds the model object itself.
            model = checkpoint_data
        elif 'model' in checkpoint_data:
            model = checkpoint_data['model']
        elif 'state_dict' in checkpoint_data:
            sd = checkpoint_data['state_dict']

            # Strip only a *leading* 'module.' wrapper prefix. A blanket
            # str.replace would also mangle keys that merely contain the
            # text, e.g. 'submodule.weight' -> 'subweight'.
            prefix = 'module.'
            if any(key.startswith(prefix) for key in sd):
                sd = {
                    (k[len(prefix):] if k.startswith(prefix) else k): v
                    for k, v in sd.items()
                }

            # Temporary hard-coded speaker table: the embedding is always
            # replaced, and random weights are supplied when the checkpoint
            # does not provide any.
            TEMP_NUM_SPEAKERS = 5
            symbols_embedding_dim = 384
            model.speaker_emb = nn.Embedding(TEMP_NUM_SPEAKERS, symbols_embedding_dim).to(device)
            if "speaker_emb.weight" not in sd:
                sd["speaker_emb.weight"] = torch.rand((TEMP_NUM_SPEAKERS, symbols_embedding_dim))

            model.load_state_dict(sd, strict=False)
        else:
            raise ValueError(
                f"Checkpoint for {model_name} has neither a 'model' nor a 'state_dict' entry"
            )
        print(f'Loaded {model_name}')

    if model_name == "WaveGlow":
        model = model.remove_weightnorm(model)
    model.eval()
    model.device = device
    return model.to(device)