def load_and_setup_big_WN(model_name, parser, checkpoint, device, logger,
                          forward_is_infer=False, ema=True, jitable=False):
    """Build ``model_name`` from CLI arguments, optionally load a checkpoint, and prepare it for inference.

    Args:
        model_name: Name understood by the project's ``models`` registry
            (``"WaveGlow"`` additionally gets ``remove_weightnorm`` applied).
        parser: Argument parser that ``models.parse_model_args`` extends with the
            model's own options; only known arguments are consumed
            (``parse_known_args``), unknown ones are ignored.
        checkpoint: Path accepted by ``torch.load``, or ``None`` to keep the
            freshly constructed model.
        device: Device the returned model is moved to; also stored on
            ``model.device``.
        logger: Forwarded unchanged to ``models.get_model``.
        forward_is_infer: Forwarded to ``models.get_model``.
        ema: Accepted for interface compatibility; currently not used.
        jitable: Forwarded to ``models.get_model``.

    Returns:
        The model in eval mode, moved to ``device``.

    Raises:
        RuntimeError: From ``load_state_dict(strict=True)`` when the checkpoint's
            keys or shapes do not match the constructed model.
    """
    model_parser = models.parse_model_args(model_name, parser, add_help=False)
    model_args, _unknown_args = model_parser.parse_known_args()
    model_config = models.get_model_config(model_name, model_args)
    model = models.get_model(model_name, model_config, device, logger,
                             forward_is_infer=forward_is_infer, jitable=jitable)

    if checkpoint is not None:
        # NOTE(security): torch.load unpickles arbitrary objects; only load
        # checkpoints from trusted sources.
        checkpoint_data = torch.load(checkpoint, map_location="cpu")
        status = ''
        if 'state_dict' in checkpoint_data:
            sd = checkpoint_data['state_dict']
            prefix = 'module.'
            # Presumably saved from a DataParallel/DDP wrapper, which prefixes
            # every key with 'module.'. Strip only that leading prefix so keys
            # containing 'module.' elsewhere are left intact. The dict is only
            # rebuilt when needed so the original mapping (and any metadata it
            # carries) is passed through untouched otherwise.
            if any(key.startswith(prefix) for key in sd):
                sd = {(k[len(prefix):] if k.startswith(prefix) else k): v
                      for k, v in sd.items()}
            status += ' ' + str(model.load_state_dict(sd, strict=True))
        else:
            # The checkpoint stores the whole model object instead of a state dict.
            model = checkpoint_data['model']
        print(f'Loaded {model_name}{status}')

    if model_name == "WaveGlow":
        model = model.remove_weightnorm(model)
    model.device = device
    model.eval()
    return model.to(device)
def load_and_setup_model(model_name, parser, checkpoint, device, logger,
                         forward_is_infer=False, ema=True, jitable=False):
    """Build ``model_name`` from CLI arguments, optionally load a checkpoint, and prepare it for inference.

    Three checkpoint layouts are accepted:

    * a dict with ``'state_dict'``: weights are loaded non-strictly into the
      constructed model, after attaching a temporary speaker embedding (see HACK
      note below);
    * a dict with ``'model'``: that object replaces the constructed model;
    * anything else: the loaded object itself is used as the model.

    Args:
        model_name: Name understood by the project's ``models`` registry
            (``"WaveGlow"`` additionally gets ``remove_weightnorm`` applied).
        parser: Argument parser that ``models.parse_model_args`` extends with the
            model's own options; only known arguments are consumed.
        checkpoint: Path accepted by ``torch.load``, or ``None`` to keep the
            freshly constructed model.
        device: Device the returned model is moved to; also stored on
            ``model.device``.
        logger: Forwarded unchanged to ``models.get_model``.
        forward_is_infer: Forwarded to ``models.get_model``.
        ema: Accepted for interface compatibility; currently not used.
        jitable: Forwarded to ``models.get_model``.

    Returns:
        The model in eval mode, moved to ``device``.
    """
    model_parser = models.parse_model_args(model_name, parser, add_help=False)
    model_args, _unknown_args = model_parser.parse_known_args()
    model_config = models.get_model_config(model_name, model_args)
    model = models.get_model(model_name, model_config, device, logger,
                             forward_is_infer=forward_is_infer, jitable=jitable)

    if checkpoint is not None:
        # NOTE(security): torch.load unpickles arbitrary objects; only load
        # checkpoints from trusted sources.
        checkpoint_data = torch.load(checkpoint, map_location="cpu")
        status = ''
        # A pickled module is not a mapping; membership tests on it could raise,
        # so only probe for keys when the checkpoint is actually a dict.
        is_dict = isinstance(checkpoint_data, dict)
        if is_dict and 'state_dict' in checkpoint_data:
            sd = checkpoint_data['state_dict']
            prefix = 'module.'
            # Presumably saved from a DataParallel/DDP wrapper; strip only the
            # leading 'module.' so keys containing it elsewhere stay intact.
            if any(key.startswith(prefix) for key in sd):
                sd = {(k[len(prefix):] if k.startswith(prefix) else k): v
                      for k, v in sd.items()}

            # HACK: temporary multi-speaker support. Every model loaded from a
            # state dict gets a fresh speaker embedding table; when the
            # checkpoint has no weights for it, they are filled with uniform
            # random values in [0, 1) (torch.rand), so outputs are not
            # reproducible across runs in that case.
            TEMP_NUM_SPEAKERS = 5
            symbols_embedding_dim = 384
            model.speaker_emb = nn.Embedding(TEMP_NUM_SPEAKERS,
                                             symbols_embedding_dim).to(device)
            if "speaker_emb.weight" not in sd:
                sd["speaker_emb.weight"] = torch.rand(
                    (TEMP_NUM_SPEAKERS, symbols_embedding_dim))

            # strict=False tolerates missing/unexpected keys; surface them in
            # the log line instead of discarding them silently.
            status += ' ' + str(model.load_state_dict(sd, strict=False))
        elif is_dict and 'model' in checkpoint_data:
            model = checkpoint_data['model']
        else:
            # Presumably the checkpoint is the pickled model object itself.
            model = checkpoint_data
        print(f'Loaded {model_name}{status}')

    if model_name == "WaveGlow":
        model = model.remove_weightnorm(model)
    model.eval()
    model.device = device
    return model.to(device)