Esempio n. 1
0
    def _create_random_model(self):
        """Set up a model from the dummy test config and save it as a checkpoint.

        Loads ``dummy_model_config.json`` from the tests output directory,
        rebuilds the module-level ``symbols``/``phonemes`` tables when the
        config carries a ``characters`` section, builds a model with zero
        speakers via ``setup_model`` and writes it with ``save_checkpoint``
        into the tests output directory.
        """
        # pylint: disable=global-statement
        global symbols, phonemes
        config_file = os.path.join(get_tests_output_path(),
                                   'dummy_model_config.json')
        dummy_config = load_config(config_file)
        if 'characters' in dummy_config.keys():
            # The config overrides the default character/phoneme tables.
            symbols, phonemes = make_symbols(**dummy_config.characters)

        # The input vocabulary is either the phoneme or the character table.
        if dummy_config.use_phonemes:
            vocabulary = phonemes
        else:
            vocabulary = symbols
        random_model = setup_model(len(vocabulary), 0, dummy_config)
        checkpoint_dir = os.path.join(get_tests_output_path())
        save_checkpoint(random_model, None, 10, 10, 1, checkpoint_dir)
Esempio n. 2
0
    def load_tts(self, tts_checkpoint, tts_config, use_cuda):
        """Load the TTS model and its config and prepare it for inference.

        Stores the parsed config in ``self.tts_config``, the audio processor in
        ``self.ap``, the input vocabulary size in ``self.input_size`` and the
        ready-to-use network in ``self.tts_model``.

        Args:
            tts_checkpoint: path of the checkpoint file passed to ``torch.load``.
            tts_config: path of the model config passed to ``load_config``.
            use_cuda: when truthy, the model is moved to the GPU.
        """
        # pylint: disable=global-statement
        global symbols, phonemes

        print(" > Loading TTS model ...")
        print(" | > model config: ", tts_config)
        print(" | > checkpoint file: ", tts_checkpoint)

        self.tts_config = load_config(tts_config)
        self.use_phonemes = self.tts_config.use_phonemes
        self.ap = AudioProcessor(**self.tts_config.audio)

        if 'characters' in self.tts_config.keys():
            # The config overrides the default character/phoneme tables.
            symbols, phonemes = make_symbols(**self.tts_config.characters)

        self.input_size = len(phonemes) if self.use_phonemes else len(symbols)

        # TODO: fix this for multi-speaker model - load speakers
        # NOTE(review): the speaker file is read from ``self.config`` rather
        # than ``self.tts_config`` - presumably a separate config object set
        # elsewhere on the instance; confirm.
        num_speakers = 0
        if self.config.tts_speakers is not None:
            self.tts_speakers = load_speaker_mapping(self.config.tts_speakers)
            num_speakers = len(self.tts_speakers)

        self.tts_model = setup_model(self.input_size,
                                     num_speakers=num_speakers,
                                     c=self.tts_config)

        # Load the checkpoint on CPU first; the model is moved to GPU afterwards
        # only when requested.
        cp = torch.load(tts_checkpoint, map_location=torch.device('cpu'))
        self.tts_model.load_state_dict(cp['model'])
        if use_cuda:
            self.tts_model.cuda()
        self.tts_model.eval()
        self.tts_model.decoder.max_decoder_steps = 3000

        # Checkpoints without a stored reduction factor keep the decoder's
        # current value.
        if 'r' not in cp:
            return
        self.tts_model.decoder.set_r(cp['r'])
        print(f" > model reduction factor: {cp['r']}")
Esempio n. 3
0
def main(args):  # pylint: disable=redefined-outer-name
    """Build the model and its training state from the config, then train.

    Relies on module-level state defined elsewhere in the file: the config
    ``c``, ``num_gpus``, ``use_cuda``, ``OUT_PATH`` and ``c_logger``. It also
    (re)binds the module-level ``meta_data_train``, ``meta_data_eval``,
    ``symbols`` and ``phonemes``.

    Args:
        args: parsed command-line arguments. ``rank``, ``group_id`` and
            ``restore_path`` are read; ``restore_step`` is set by this
            function (checkpoint step when restoring, otherwise 0).

    Raises:
        RuntimeError: if external speaker embeddings are enabled but no
            usable speaker mapping / embedding file can be found.
        AssertionError: if a restored run (without external embeddings) sees
            a speaker that is missing from the stored speaker mapping.
    """
    # pylint: disable=global-variable-undefined
    global meta_data_train, meta_data_eval, symbols, phonemes
    # Audio processor
    ap = AudioProcessor(**c.audio)
    if 'characters' in c.keys():
        symbols, phonemes = make_symbols(**c.characters)

    # DISTRIBUTED
    if num_gpus > 1:
        init_distributed(args.rank, num_gpus, args.group_id,
                         c.distributed["backend"], c.distributed["url"])
    num_chars = len(phonemes) if c.use_phonemes else len(symbols)

    # load data instances
    meta_data_train, meta_data_eval = load_meta_data(c.datasets)

    # set the portion of the data used for training
    if 'train_portion' in c.keys():
        meta_data_train = meta_data_train[:int(len(meta_data_train) * c.train_portion)]
    if 'eval_portion' in c.keys():
        meta_data_eval = meta_data_eval[:int(len(meta_data_eval) * c.eval_portion)]

    # parse speakers
    if c.use_speaker_embedding:
        speakers = get_speakers(meta_data_train)
        if args.restore_path:
            # When restoring, the speaker mapping is first looked up in the
            # directory of the restored checkpoint.
            prev_out_path = os.path.dirname(args.restore_path)
            speaker_mapping = load_speaker_mapping(prev_out_path)
            if c.use_external_speaker_embedding_file:  # restore checkpoint and use External Embedding file
                if not speaker_mapping:
                    print("WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file")
                    speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file)
                    if not speaker_mapping:
                        raise RuntimeError("You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.external_speaker_embedding_file")
                speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]['embedding'])
            else:  # restore checkpoint and don't use External Embedding file
                speaker_embedding_dim = None
                assert all(speaker in speaker_mapping for speaker in speakers), \
                    "As of now you, you cannot " \
                    "introduce new speakers to " \
                    "a previously trained model."
        elif c.use_external_speaker_embedding_file and c.external_speaker_embedding_file:  # new training using External Embedding file
            speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file)
            speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]['embedding'])
        elif c.use_external_speaker_embedding_file and not c.external_speaker_embedding_file:  # new training, External Embedding enabled but no file given
            # Raising a plain string is a TypeError in Python 3; raise a real
            # exception so the message actually reaches the user.
            raise RuntimeError("use_external_speaker_embedding_file is True, so you need pass a external speaker embedding file, run GE2E-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb or AngularPrototypical-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb notebook in notebooks/ folder")
        else:  # new training without External Embedding file
            speaker_mapping = {name: i for i, name in enumerate(speakers)}
            speaker_embedding_dim = None
        save_speaker_mapping(OUT_PATH, speaker_mapping)
        num_speakers = len(speaker_mapping)
        print("Training with {} speakers: {}".format(num_speakers,
                                                     ", ".join(speakers)))
    else:
        num_speakers = 0
        speaker_embedding_dim = None
        speaker_mapping = None

    model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim)

    params = set_weight_decay(model, c.wd)
    optimizer = RAdam(params, lr=c.lr, weight_decay=0)
    if c.stopnet and c.separate_stopnet:
        # The stopnet gets its own optimizer when it is trained separately.
        optimizer_st = RAdam(model.decoder.stopnet.parameters(),
                             lr=c.lr,
                             weight_decay=0)
    else:
        optimizer_st = None

    if c.apex_amp_level == "O1":
        # pylint: disable=import-outside-toplevel
        from apex import amp
        model.cuda()
        model, optimizer = amp.initialize(model, optimizer, opt_level=c.apex_amp_level)
    else:
        amp = None

    # setup criterion
    criterion = TacotronLoss(c, stopnet_pos_weight=10.0, ga_sigma=0.4)

    if args.restore_path:
        checkpoint = torch.load(args.restore_path, map_location='cpu')
        try:
            # TODO: fix optimizer init, model.cuda() needs to be called before
            # optimizer restore
            # optimizer.load_state_dict(checkpoint['optimizer'])
            if c.reinit_layers:
                # Deliberately jump to the partial-initialization path below.
                raise RuntimeError
            model.load_state_dict(checkpoint['model'])
        except (KeyError, RuntimeError):
            # RuntimeError covers both the explicit reinit request above and
            # load_state_dict() rejecting mismatching keys/shapes.
            print(" > Partial model initialization.")
            model_dict = model.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint['model'], c)
            # torch.save(model_dict, os.path.join(OUT_PATH, 'state_dict.pt'))
            # print("State Dict saved for debug in: ", os.path.join(OUT_PATH, 'state_dict.pt'))
            model.load_state_dict(model_dict)
            del model_dict

        if amp and 'amp' in checkpoint:
            amp.load_state_dict(checkpoint['amp'])

        # Always restart from the configured learning rate.
        for group in optimizer.param_groups:
            group['lr'] = c.lr
        print(" > Model restored from step %d" % checkpoint['step'],
              flush=True)
        args.restore_step = checkpoint['step']
    else:
        args.restore_step = 0

    if use_cuda:
        model.cuda()
        criterion.cuda()

    # DISTRIBUTED
    if num_gpus > 1:
        model = apply_gradient_allreduce(model)

    if c.noam_schedule:
        scheduler = NoamLR(optimizer,
                           warmup_steps=c.warmup_steps,
                           last_epoch=args.restore_step - 1)
    else:
        scheduler = None

    num_params = count_parameters(model)
    print("\n > Model has {} parameters".format(num_params), flush=True)

    # No best loss is carried over from a checkpoint here, so start from +inf.
    best_loss = float('inf')

    global_step = args.restore_step
    for epoch in range(0, c.epochs):
        c_logger.print_epoch_start(epoch, c.epochs)
        # set gradual training
        if c.gradual_training is not None:
            r, c.batch_size = gradual_training_scheduler(global_step, c)
            c.r = r
            model.decoder.set_r(r)
            if c.bidirectional_decoder:
                model.decoder_backward.set_r(r)
            print("\n > Number of output frames:", model.decoder.r)
        train_avg_loss_dict, global_step = train(model, criterion, optimizer,
                                                 optimizer_st, scheduler, ap,
                                                 global_step, epoch, amp, speaker_mapping)
        eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch, speaker_mapping)
        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
        # Model selection uses the eval loss when evaluation is enabled,
        # otherwise the training loss.
        target_loss = train_avg_loss_dict['avg_postnet_loss']
        if c.run_eval:
            target_loss = eval_avg_loss_dict['avg_postnet_loss']
        best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r,
                                    OUT_PATH, amp_state_dict=amp.state_dict() if amp else None)
Esempio n. 4
0
    num_speakers = 0

    # load speakers
    if args.speakers_json != '':
        # Use a context manager so the file handle is closed right after parsing.
        with open(args.speakers_json, 'r') as speakers_file:
            speaker_mapping = json.load(speakers_file)
        num_speakers = len(speaker_mapping)
        if C.use_external_speaker_embedding_file:
            if args.speaker_fileid is not None:
                speaker_embedding = speaker_mapping[args.speaker_fileid]['embedding']
            else:  # if speaker_fileid is not specified use the first sample in speakers.json
                speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[0]]['embedding']
            speaker_embedding_dim = len(speaker_embedding)

    # load the model
    # NOTE(review): `speaker_embedding_dim` is only assigned in the branch
    # above; presumably it is initialised before this fragment - confirm.
    num_chars = len(phonemes) if C.use_phonemes else len(symbols)
    model = setup_model(num_chars, num_speakers, C, speaker_embedding_dim)
    # Load the checkpoint on CPU; the model is moved to GPU below if requested.
    cp = torch.load(args.model_path, map_location=torch.device('cpu'))
    model.load_state_dict(cp['model'])
    model.eval()
    if args.use_cuda:
        model.cuda()
    # Apply the reduction factor stored in the checkpoint to the decoder.
    model.decoder.set_r(cp['r'])

    # load vocoder model
    if args.vocoder_path != "":
        VC = load_config(args.vocoder_config_path)
        vocoder_model = setup_generator(VC)
        vocoder_model.load_state_dict(torch.load(args.vocoder_path, map_location="cpu")["model"])
        vocoder_model.remove_weight_norm()
        if args.use_cuda:
            vocoder_model.cuda()
                    type=str,
                    help='Path to config file of torch model.')
# Destination of the converted TF model; per the help text the path includes
# the output file name.
parser.add_argument(
    '--output_path',
    type=str,
    help='path to output file including file name to save TF model.')
args = parser.parse_args()

# load model config
config_path = args.config_path
c = load_config(config_path)
# The torch model below is built with zero speakers.
num_speakers = 0

# init torch model
# Vocabulary size comes from the phoneme table or the character table,
# depending on the config flag.
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
model = setup_model(num_chars, num_speakers, c)
# Checkpoint tensors are mapped to CPU on load.
checkpoint = torch.load(args.torch_model_path,
                        map_location=torch.device('cpu'))
# The model weights are read from the checkpoint's 'model' entry.
state_dict = checkpoint['model']
model.load_state_dict(state_dict)

# init tf model
model_tf = Tacotron2(num_chars=num_chars,
                     num_speakers=num_speakers,
                     r=model.decoder.r,
                     postnet_output_dim=c.audio['num_mels'],
                     decoder_output_dim=c.audio['num_mels'],
                     attn_type=c.attention_type,
                     attn_win=c.windowing,
                     attn_norm=c.attention_norm,
                     prenet_type=c.prenet_type,