Example #1
    def __init__(self, path, INPUT_SR, TARGET_SR, WINDOW_LENGTH):
        self.INPUT_SR = INPUT_SR
        self.TARGET_SR = TARGET_SR
        self.WINDOW_LENGTH = WINDOW_LENGTH
        self.CONFIG = load_config('config_fr.json')
        self.CONFIG['audio']['sample_rate'] = self.INPUT_SR
        self.AP_INPUT = AudioProcessor(**self.CONFIG['audio'])
        self.CONFIG['audio']['sample_rate'] = self.TARGET_SR
        self.AP_TARGET = AudioProcessor(**self.CONFIG['audio'])
        self.files = glob.glob(path + '/**/*.wav', recursive=True)
        # If you change your dataset, delete cache.json
        if os.path.isfile('./cache.json'):
            with open('./cache.json', "r") as json_file:
                self.pre_repertoir = json.load(json_file)
        else:
            print("> Computing wave files length...")
            self.pre_repertoir = [
                librosa.get_duration(filename=file)
                for file in tqdm(self.files)
            ]
            with open('./cache.json', mode="w") as json_file:
                json.dump(self.pre_repertoir, json_file)
        self.repertoir = [
            int(item / WINDOW_LENGTH) for item in self.pre_repertoir
        ]
        self.length = self.get_len()
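
This snippet is the constructor of the Mel2MelDataset class shown in full in Example #7. A hypothetical instantiation, assuming config_fr.json is present; the path, rates, and window length are illustrative:

from torch.utils.data import DataLoader

dataset = Mel2MelDataset('/data/wavs', INPUT_SR=16000,
                         TARGET_SR=22050, WINDOW_LENGTH=2.0)
loader = DataLoader(dataset, batch_size=8, shuffle=True)
batch = next(iter(loader))
print(batch['image'].shape, batch['mask'].shape)  # keys returned by __getitem__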
Example #2
    def load_tts(self, tts_checkpoint, tts_config, use_cuda):
        # pylint: disable=global-statement
        global symbols, phonemes

        print(" > Loading TTS model ...")
        print(" | > model config: ", tts_config)
        print(" | > checkpoint file: ", tts_checkpoint)

        self.tts_config = load_config(tts_config)
        self.use_phonemes = self.tts_config.use_phonemes
        self.ap = AudioProcessor(**self.tts_config.audio)

        if 'characters' in self.tts_config.keys():
            symbols, phonemes = make_symbols(**self.tts_config.characters)

        if self.use_phonemes:
            self.input_size = len(phonemes)
        else:
            self.input_size = len(symbols)
        # TODO: fix this for multi-speaker model - load speakers
        if self.config.tts_speakers is not None:
            self.tts_speakers = load_speaker_mapping(self.config.tts_speakers)
            num_speakers = len(self.tts_speakers)
        else:
            num_speakers = 0
        self.tts_model = setup_model(self.input_size,
                                     num_speakers=num_speakers,
                                     c=self.tts_config)
        # load model state
        cp = torch.load(tts_checkpoint, map_location=torch.device('cpu'))
        # load the model
        self.tts_model.load_state_dict(cp['model'])
        if use_cuda:
            self.tts_model.cuda()
        self.tts_model.eval()
        self.tts_model.decoder.max_decoder_steps = 3000
        if 'r' in cp:
            self.tts_model.decoder.set_r(cp['r'])
            print(f" > model reduction factor: {cp['r']}")
Example #3
    def test_scaler(self):
        scaler_stats_path = os.path.join(get_tests_input_path(), 'scale_stats.npy')
        conf.audio['stats_path'] = scaler_stats_path
        conf.audio['preemphasis'] = 0.0
        conf.audio['do_trim_silence'] = True
        conf.audio['signal_norm'] = True

        ap = AudioProcessor(**conf.audio)
        mel_mean, mel_std, linear_mean, linear_std, _ = ap.load_stats(scaler_stats_path)
        ap.setup_scaler(mel_mean, mel_std, linear_mean, linear_std)

        self.ap.signal_norm = False
        self.ap.preemphasis = 0.0

        # test scaler forward and backward transforms
        wav = self.ap.load_wav(WAV_FILE)
        mel_reference = self.ap.melspectrogram(wav)
        mel_norm = ap.melspectrogram(wav)
        mel_denorm = ap._denormalize(mel_norm)
        assert abs(mel_reference - mel_denorm).max() < 1e-4
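
load_stats reads the entries written by the statistics script in Example #16. A sketch of inspecting such a stats file directly (path is illustrative):

import numpy as np

stats = np.load('scale_stats.npy', allow_pickle=True).item()
print(stats.keys())  # mel_mean, mel_std, linear_mean, linear_std, audio_config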
Example #4
    def __init__(self, *args, **kwargs):
        super(TestTTSDataset, self).__init__(*args, **kwargs)
        self.max_loader_iter = 4
        self.ap = AudioProcessor(**c.audio)
Example #5
class TestTTSDataset(unittest.TestCase):
    def __init__(self, *args, **kwargs):
        super(TestTTSDataset, self).__init__(*args, **kwargs)
        self.max_loader_iter = 4
        self.ap = AudioProcessor(**c.audio)

    def _create_dataloader(self, batch_size, r, bgs):
        items = ljspeech(c.data_path, 'metadata.csv')
        dataset = TTSDataset.MyDataset(
            r,
            c.text_cleaner,
            compute_linear_spec=True,
            ap=self.ap,
            meta_data=items,
            tp=c.characters if 'characters' in c.keys() else None,
            batch_group_size=bgs,
            min_seq_len=c.min_seq_len,
            max_seq_len=float("inf"),
            use_phonemes=False)
        dataloader = DataLoader(dataset,
                                batch_size=batch_size,
                                shuffle=False,
                                collate_fn=dataset.collate_fn,
                                drop_last=True,
                                num_workers=c.num_loader_workers)
        return dataloader, dataset

    def test_loader(self):
        if ok_ljspeech:
            dataloader, dataset = self._create_dataloader(2, c.r, 0)

            for i, data in enumerate(dataloader):
                if i == self.max_loader_iter:
                    break
                text_input = data[0]
                text_lengths = data[1]
                speaker_name = data[2]
                linear_input = data[3]
                mel_input = data[4]
                mel_lengths = data[5]
                stop_target = data[6]
                item_idx = data[7]

                neg_values = text_input[text_input < 0]
                check_count = len(neg_values)
                assert check_count == 0, \
                    " !! Negative values in text_input: {}".format(check_count)
                # TODO: more assertion here
                assert isinstance(speaker_name[0], str)
                assert linear_input.shape[0] == c.batch_size
                assert linear_input.shape[2] == self.ap.fft_size // 2 + 1
                assert mel_input.shape[0] == c.batch_size
                assert mel_input.shape[2] == c.audio['num_mels']
                # check normalization ranges
                if self.ap.symmetric_norm:
                    assert mel_input.max() <= self.ap.max_norm
                    assert mel_input.min() >= -self.ap.max_norm  #pylint: disable=invalid-unary-operand-type
                    assert mel_input.min() < 0
                else:
                    assert mel_input.max() <= self.ap.max_norm
                    assert mel_input.min() >= 0

    def test_batch_group_shuffle(self):
        if ok_ljspeech:
            dataloader, dataset = self._create_dataloader(2, c.r, 16)
            last_length = 0
            frames = dataset.items
            for i, data in enumerate(dataloader):
                if i == self.max_loader_iter:
                    break
                text_input = data[0]
                text_lengths = data[1]
                speaker_name = data[2]
                linear_input = data[3]
                mel_input = data[4]
                mel_lengths = data[5]
                stop_target = data[6]
                item_idx = data[7]

                avg_length = mel_lengths.numpy().mean()
                assert avg_length >= last_length
            dataloader.dataset.sort_items()
            is_items_reordered = False
            for idx, item in enumerate(dataloader.dataset.items):
                if item != frames[idx]:
                    is_items_reordered = True
                    break
            assert is_items_reordered

    def test_padding_and_spec(self):
        if ok_ljspeech:
            dataloader, dataset = self._create_dataloader(1, 1, 0)

            for i, data in enumerate(dataloader):
                if i == self.max_loader_iter:
                    break
                text_input = data[0]
                text_lengths = data[1]
                speaker_name = data[2]
                linear_input = data[3]
                mel_input = data[4]
                mel_lengths = data[5]
                stop_target = data[6]
                item_idx = data[7]

                # check mel_spec consistency
                wav = np.asarray(self.ap.load_wav(item_idx[0]),
                                 dtype=np.float32)
                mel = self.ap.melspectrogram(wav).astype('float32')
                mel = torch.FloatTensor(mel).contiguous()
                mel_dl = mel_input[0]
                # NOTE: Below needs to check == 0 but due to an unknown reason
                # there is a slight difference between two matrices.
                # TODO: Check this assert cond more in detail.
                assert abs(mel.T - mel_dl).max() < 1e-5, abs(mel.T -
                                                             mel_dl).max()

                # check mel-spec correctness
                mel_spec = mel_input[0].cpu().numpy()
                wav = self.ap.inv_melspectrogram(mel_spec.T)
                self.ap.save_wav(wav, OUTPATH + '/mel_inv_dataloader.wav')
                shutil.copy(item_idx[0],
                            OUTPATH + '/mel_target_dataloader.wav')

                # check linear-spec
                linear_spec = linear_input[0].cpu().numpy()
                wav = self.ap.inv_spectrogram(linear_spec.T)
                self.ap.save_wav(wav, OUTPATH + '/linear_inv_dataloader.wav')
                shutil.copy(item_idx[0],
                            OUTPATH + '/linear_target_dataloader.wav')

                # batch size is 1 here, so nothing is padded: the last time steps must be non-zero
                assert linear_input[0, -1].sum() != 0
                assert linear_input[0, -2].sum() != 0
                assert mel_input[0, -1].sum() != 0
                assert mel_input[0, -2].sum() != 0
                assert stop_target[0, -1] == 1
                assert stop_target[0, -2] == 0
                assert stop_target.sum() == 1
                assert len(mel_lengths.shape) == 1
                assert mel_lengths[0] == linear_input[0].shape[0]
                assert mel_lengths[0] == mel_input[0].shape[0]

            # Test for batch size 2
            dataloader, dataset = self._create_dataloader(2, 1, 0)

            for i, data in enumerate(dataloader):
                if i == self.max_loader_iter:
                    break
                text_input = data[0]
                text_lengths = data[1]
                speaker_name = data[2]
                linear_input = data[3]
                mel_input = data[4]
                mel_lengths = data[5]
                stop_target = data[6]
                item_idx = data[7]

                if mel_lengths[0] > mel_lengths[1]:
                    idx = 0
                else:
                    idx = 1

                # check the longer item in the batch (it is not padded)
                assert linear_input[idx, -1].sum() != 0
                assert linear_input[idx, -2].sum() != 0, linear_input
                assert mel_input[idx, -1].sum() != 0
                assert mel_input[idx, -2].sum() != 0, mel_input
                assert stop_target[idx, -1] == 1
                assert stop_target[idx, -2] == 0
                assert stop_target[idx].sum() == 1
                assert len(mel_lengths.shape) == 1
                assert mel_lengths[idx] == mel_input[idx].shape[0]
                assert mel_lengths[idx] == linear_input[idx].shape[0]

                # check the shorter item in the batch (it is zero-padded at the end)
                assert linear_input[1 - idx, -1].sum() == 0
                assert mel_input[1 - idx, -1].sum() == 0
                assert stop_target[1, mel_lengths[1] - 1] == 1
                assert stop_target[1, mel_lengths[1]:].sum() == 0
                assert len(mel_lengths.shape) == 1
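
The collate_fn used by these tests returns an 8-element batch, so the repeated data[i] indexing above can equivalently be written as a single unpack:

(text_input, text_lengths, speaker_name, linear_input,
 mel_input, mel_lengths, stop_target, item_idx) = data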
Example #6
from torch import nn, optim

from mozilla_voice_tts.tts.layers.losses import MSELossMasked
from mozilla_voice_tts.tts.models.tacotron2 import Tacotron2
from mozilla_voice_tts.utils.io import load_config
from mozilla_voice_tts.utils.audio import AudioProcessor

#pylint: disable=unused-variable

torch.manual_seed(1)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

c = load_config(os.path.join(get_tests_input_path(), 'test_config.json'))

ap = AudioProcessor(**c.audio)
WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")


class TacotronTrainTest(unittest.TestCase):
    def test_train_step(self):  # pylint: disable=no-self-use
        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
        input_lengths = torch.randint(100, 128, (8, )).long().to(device)
        input_lengths = torch.sort(input_lengths, descending=True)[0]
        mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
        mel_postnet_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
        mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
        mel_lengths[0] = 30
        stop_targets = torch.zeros(8, 30, 1).float().to(device)
        speaker_ids = torch.randint(0, 5, (8, )).long().to(device)
Example #7
class Mel2MelDataset(Dataset):
    def __init__(self, path, INPUT_SR, TARGET_SR, WINDOW_LENGTH):
        self.INPUT_SR = INPUT_SR
        self.TARGET_SR = TARGET_SR
        self.WINDOW_LENGTH = WINDOW_LENGTH
        self.CONFIG = load_config('config_fr.json')
        self.CONFIG['audio']['sample_rate'] = self.INPUT_SR
        self.AP_INPUT = AudioProcessor(**self.CONFIG['audio'])
        self.CONFIG['audio']['sample_rate'] = self.TARGET_SR
        self.AP_TARGET = AudioProcessor(**self.CONFIG['audio'])
        self.files = glob.glob(path + '/**/*.wav', recursive=True)
        # If you change your dataset, delete cache.json
        if os.path.isfile('./cache.json'):
            with open('./cache.json', "r") as json_file:
                self.pre_repertoir = json.load(json_file)
        else:
            print("> Computing wave files length...")
            self.pre_repertoir = [
                librosa.get_duration(filename=file)
                for file in tqdm(self.files)
            ]
            with open('./cache.json', mode="w") as json_file:
                json.dump(self.pre_repertoir, json_file)
        self.repertoir = [
            int(item / WINDOW_LENGTH) for item in self.pre_repertoir
        ]
        self.length = self.get_len()

    def __len__(self):
        return self.length

    def __getitem__(self, id):
        ref = self.get_reference(id)
        input_wav, _ = librosa.load(self.files[ref[0]],
                                    offset=self.WINDOW_LENGTH * ref[1],
                                    duration=self.WINDOW_LENGTH)
        target_wav = librosa.resample(input_wav, self.INPUT_SR, self.TARGET_SR)
        input = torch.tensor(self.AP_INPUT.melspectrogram(input_wav))
        target = torch.tensor(self.AP_TARGET.melspectrogram(target_wav))
        scale_factor = (target.shape[0] / input.shape[0],
                        target.shape[1] / input.shape[1])
        input = torch.nn.functional.interpolate(
            input.unsqueeze(0).unsqueeze(0),
            scale_factor=scale_factor,
            mode='bilinear').reshape(target.shape)
        return {
            'image':
            self.normalize(input).unsqueeze(0).type(torch.FloatTensor),
            'mask': self.normalize(target).unsqueeze(0).type(torch.FloatTensor)
        }

    def get_reference(self, id):
        i = 0
        sum = 0
        while True:
            if (sum > id):
                return (i - 1, id - sum + self.repertoir[i - 1])
            else:
                sum += self.repertoir[i]
                i += 1

    def get_len(self):
        sum = 0
        for num in self.repertoir:
            sum += num
        return sum

    def normalize(self, tensor):
        return tensor / 8 + 0.5

    def denormalize(self, tensor):
        # inverse of normalize(): undo the +0.5 shift, then rescale by 8
        return (tensor - 0.5) * 8
Example #8
import torch
from tests import get_tests_input_path, get_tests_output_path, get_tests_path

from mozilla_voice_tts.utils.audio import AudioProcessor
from mozilla_voice_tts.utils.io import load_config
from mozilla_voice_tts.vocoder.layers.losses import MultiScaleSTFTLoss, STFTLoss, TorchSTFT

TESTS_PATH = get_tests_path()

OUT_PATH = os.path.join(get_tests_output_path(), "audio_tests")
os.makedirs(OUT_PATH, exist_ok=True)

WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")

C = load_config(os.path.join(get_tests_input_path(), 'test_config.json'))
ap = AudioProcessor(**C.audio)


def test_torch_stft():
    torch_stft = TorchSTFT(ap.fft_size, ap.hop_length, ap.win_length)
    # librosa stft
    wav = ap.load_wav(WAV_FILE)
    M_librosa = abs(ap._stft(wav))  # pylint: disable=protected-access
    # torch stft
    wav = torch.from_numpy(wav[None, :]).float()
    M_torch = torch_stft(wav)
    # check the difference b/w librosa and torch outputs
    assert (M_librosa - M_torch[0].data.numpy()).max() < 1e-5


def test_stft_loss():
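    # The original body of this test is truncated here. A plausible completion
    # using the STFTLoss class imported above (a sketch, not the original test):
    stft_loss = STFTLoss(ap.fft_size, ap.hop_length, ap.win_length)
    wav = ap.load_wav(WAV_FILE)
    wav = torch.from_numpy(wav[None, :]).float()
    # identical signals should produce (near) zero loss
    loss_m, loss_sc = stft_loss(wav, wav)
    assert loss_m + loss_sc == 0
    # a random signal against the reference should produce a positive loss
    loss_m, loss_sc = stft_loss(wav, torch.rand_like(wav))
    assert loss_m + loss_sc > 0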
Example #9
def main(args):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global meta_data_train, meta_data_eval, symbols, phonemes
    # Audio processor
    ap = AudioProcessor(**c.audio)
    if 'characters' in c.keys():
        symbols, phonemes = make_symbols(**c.characters)

    # DISTRIBUTED
    if num_gpus > 1:
        init_distributed(args.rank, num_gpus, args.group_id,
                         c.distributed["backend"], c.distributed["url"])
    num_chars = len(phonemes) if c.use_phonemes else len(symbols)

    # load data instances
    meta_data_train, meta_data_eval = load_meta_data(c.datasets)

    # set the portion of the data used for training
    if 'train_portion' in c.keys():
        meta_data_train = meta_data_train[:int(len(meta_data_train) * c.train_portion)]
    if 'eval_portion' in c.keys():
        meta_data_eval = meta_data_eval[:int(len(meta_data_eval) * c.eval_portion)]

    # parse speakers
    if c.use_speaker_embedding:
        speakers = get_speakers(meta_data_train)
        if args.restore_path:
            if c.use_external_speaker_embedding_file: # restoring a checkpoint and using an external embedding file
                prev_out_path = os.path.dirname(args.restore_path)
                speaker_mapping = load_speaker_mapping(prev_out_path)
                if not speaker_mapping:
                    print("WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file")
                    speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file)
                    if not speaker_mapping:
                        raise RuntimeError("You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.external_speaker_embedding_file")
                speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]['embedding'])
            elif not c.use_external_speaker_embedding_file: # restoring a checkpoint without an external embedding file
                prev_out_path = os.path.dirname(args.restore_path)
                speaker_mapping = load_speaker_mapping(prev_out_path)
                speaker_embedding_dim = None
                assert all([speaker in speaker_mapping
                            for speaker in speakers]), "As of now, you cannot " \
                                                    "introduce new speakers to " \
                                                    "a previously trained model."
        elif c.use_external_speaker_embedding_file and c.external_speaker_embedding_file: # starting a new training run with an external embedding file
            speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file)
            speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]['embedding'])
        elif c.use_external_speaker_embedding_file and not c.external_speaker_embedding_file: # external embedding enabled but no file given
            raise "use_external_speaker_embedding_file is True, so you need pass a external speaker embedding file, run GE2E-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb or AngularPrototypical-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb notebook in notebooks/ folder"
        else: # starting a new training run without an external embedding file
            speaker_mapping = {name: i for i, name in enumerate(speakers)}
            speaker_embedding_dim = None
        save_speaker_mapping(OUT_PATH, speaker_mapping)
        num_speakers = len(speaker_mapping)
        print("Training with {} speakers: {}".format(num_speakers,
                                                     ", ".join(speakers)))
    else:
        num_speakers = 0
        speaker_embedding_dim = None
        speaker_mapping = None

    model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim)

    params = set_weight_decay(model, c.wd)
    optimizer = RAdam(params, lr=c.lr, weight_decay=0)
    if c.stopnet and c.separate_stopnet:
        optimizer_st = RAdam(model.decoder.stopnet.parameters(),
                             lr=c.lr,
                             weight_decay=0)
    else:
        optimizer_st = None

    if c.apex_amp_level == "O1":
        # pylint: disable=import-outside-toplevel
        from apex import amp
        model.cuda()
        model, optimizer = amp.initialize(model, optimizer, opt_level=c.apex_amp_level)
    else:
        amp = None

    # setup criterion
    criterion = TacotronLoss(c, stopnet_pos_weight=10.0, ga_sigma=0.4)

    if args.restore_path:
        checkpoint = torch.load(args.restore_path, map_location='cpu')
        try:
            # TODO: fix optimizer init, model.cuda() needs to be called before
            # optimizer restore
            # optimizer.load_state_dict(checkpoint['optimizer'])
            if c.reinit_layers:
                raise RuntimeError
            model.load_state_dict(checkpoint['model'])
        except KeyError:
            print(" > Partial model initialization.")
            model_dict = model.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint['model'], c)
            # torch.save(model_dict, os.path.join(OUT_PATH, 'state_dict.pt'))
            # print("State Dict saved for debug in: ", os.path.join(OUT_PATH, 'state_dict.pt'))
            model.load_state_dict(model_dict)
            del model_dict

        if amp and 'amp' in checkpoint:
            amp.load_state_dict(checkpoint['amp'])

        for group in optimizer.param_groups:
            group['lr'] = c.lr
        print(" > Model restored from step %d" % checkpoint['step'],
              flush=True)
        args.restore_step = checkpoint['step']
    else:
        args.restore_step = 0

    if use_cuda:
        model.cuda()
        criterion.cuda()

    # DISTRIBUTED
    if num_gpus > 1:
        model = apply_gradient_allreduce(model)

    if c.noam_schedule:
        scheduler = NoamLR(optimizer,
                           warmup_steps=c.warmup_steps,
                           last_epoch=args.restore_step - 1)
    else:
        scheduler = None

    num_params = count_parameters(model)
    print("\n > Model has {} parameters".format(num_params), flush=True)

    if 'best_loss' not in locals():
        best_loss = float('inf')

    global_step = args.restore_step
    for epoch in range(0, c.epochs):
        c_logger.print_epoch_start(epoch, c.epochs)
        # set gradual training
        if c.gradual_training is not None:
            r, c.batch_size = gradual_training_scheduler(global_step, c)
            c.r = r
            model.decoder.set_r(r)
            if c.bidirectional_decoder:
                model.decoder_backward.set_r(r)
            print("\n > Number of output frames:", model.decoder.r)
        train_avg_loss_dict, global_step = train(model, criterion, optimizer,
                                                 optimizer_st, scheduler, ap,
                                                 global_step, epoch, amp, speaker_mapping)
        eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch, speaker_mapping)
        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
        target_loss = train_avg_loss_dict['avg_postnet_loss']
        if c.run_eval:
            target_loss = eval_avg_loss_dict['avg_postnet_loss']
        best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r,
                                    OUT_PATH, amp_state_dict=amp.state_dict() if amp else None)
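
The gradual-training block above expects gradual_training_scheduler to map the current global step to an (r, batch_size) pair. A minimal sketch of such a scheduler, assuming c.gradual_training is a list of [start_step, r, batch_size] triples sorted by start_step (the real helper lives in the training utilities):

def gradual_training_scheduler(global_step, config):
    """Return (r, batch_size) for the current training step."""
    new_values = None
    for values in config.gradual_training:  # [[start_step, r, batch_size], ...]
        if global_step >= values[0]:
            new_values = values
    return new_values[1], new_values[2]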
Example #10
class Synthesizer(object):
    def __init__(self, config):
        self.wavernn = None
        self.vocoder_model = None
        self.config = config
        print(config)
        self.seg = self.get_segmenter("en")
        self.use_cuda = self.config.use_cuda
        if self.use_cuda:
            assert torch.cuda.is_available(), \
                "CUDA is not available on this machine."
        self.load_tts(self.config.tts_checkpoint, self.config.tts_config,
                      self.config.use_cuda)
        if self.config.vocoder_checkpoint:
            self.load_vocoder(self.config.vocoder_checkpoint,
                              self.config.vocoder_config, self.config.use_cuda)
        if self.config.wavernn_lib_path:
            self.load_wavernn(self.config.wavernn_lib_path,
                              self.config.wavernn_checkpoint,
                              self.config.wavernn_config, self.config.use_cuda)

    @staticmethod
    def get_segmenter(lang):
        return pysbd.Segmenter(language=lang, clean=True)

    def load_tts(self, tts_checkpoint, tts_config, use_cuda):
        # pylint: disable=global-statement
        global symbols, phonemes

        print(" > Loading TTS model ...")
        print(" | > model config: ", tts_config)
        print(" | > checkpoint file: ", tts_checkpoint)

        self.tts_config = load_config(tts_config)
        self.use_phonemes = self.tts_config.use_phonemes
        self.ap = AudioProcessor(**self.tts_config.audio)

        if 'characters' in self.tts_config.keys():
            symbols, phonemes = make_symbols(**self.tts_config.characters)

        if self.use_phonemes:
            self.input_size = len(phonemes)
        else:
            self.input_size = len(symbols)
        # TODO: fix this for multi-speaker model - load speakers
        if self.config.tts_speakers is not None:
            self.tts_speakers = load_speaker_mapping(self.config.tts_speakers)
            num_speakers = len(self.tts_speakers)
        else:
            num_speakers = 0
        self.tts_model = setup_model(self.input_size,
                                     num_speakers=num_speakers,
                                     c=self.tts_config)
        # load model state
        cp = torch.load(tts_checkpoint, map_location=torch.device('cpu'))
        # load the model
        self.tts_model.load_state_dict(cp['model'])
        if use_cuda:
            self.tts_model.cuda()
        self.tts_model.eval()
        self.tts_model.decoder.max_decoder_steps = 3000
        if 'r' in cp:
            self.tts_model.decoder.set_r(cp['r'])
            print(f" > model reduction factor: {cp['r']}")

    def load_vocoder(self, model_file, model_config, use_cuda):
        self.vocoder_config = load_config(model_config)
        self.vocoder_model = setup_generator(self.vocoder_config)
        self.vocoder_model.load_state_dict(
            torch.load(model_file, map_location="cpu")["model"])
        self.vocoder_model.remove_weight_norm()
        self.vocoder_model.inference_padding = 0

        if use_cuda:
            self.vocoder_model.cuda()
        self.vocoder_model.eval()

    def load_wavernn(self, lib_path, model_file, model_config, use_cuda):
        # TODO: set a function in wavernn code base for model setup and call it here.
        sys.path.append(
            lib_path)  # set this if WaveRNN is not installed globally
        #pylint: disable=import-outside-toplevel
        from WaveRNN.models.wavernn import Model
        print(" > Loading WaveRNN model ...")
        print(" | > model config: ", model_config)
        print(" | > model file: ", model_file)
        self.wavernn_config = load_config(model_config)
        # This is the default architecture we use for our models.
        # You might need to update it
        self.wavernn = Model(
            rnn_dims=512,
            fc_dims=512,
            mode=self.wavernn_config.mode,
            mulaw=self.wavernn_config.mulaw,
            pad=self.wavernn_config.pad,
            use_aux_net=self.wavernn_config.use_aux_net,
            use_upsample_net=self.wavernn_config.use_upsample_net,
            upsample_factors=self.wavernn_config.upsample_factors,
            feat_dims=80,
            compute_dims=128,
            res_out_dims=128,
            res_blocks=10,
            hop_length=self.ap.hop_length,
            sample_rate=self.ap.sample_rate,
        ).cuda()

        check = torch.load(model_file, map_location="cpu")
        self.wavernn.load_state_dict(check['model'])
        if use_cuda:
            self.wavernn.cuda()
        self.wavernn.eval()

    def save_wav(self, wav, path):
        # wav *= 32767 / max(1e-8, np.max(np.abs(wav)))
        wav = np.array(wav)
        self.ap.save_wav(wav, path)

    def split_into_sentences(self, text):
        return self.seg.segment(text)

    def tts(self, text, speaker_id=None):
        start_time = time.time()
        wavs = []
        sens = self.split_into_sentences(text)
        print(sens)
        speaker_id = id_to_torch(speaker_id)
        if speaker_id is not None and self.use_cuda:
            speaker_id = speaker_id.cuda()

        for sen in sens:
            # preprocess the given text
            inputs = text_to_seqvec(sen, self.tts_config)
            inputs = numpy_to_torch(inputs, torch.long, cuda=self.use_cuda)
            inputs = inputs.unsqueeze(0)
            # synthesize voice
            _, postnet_output, _, _ = run_model_torch(self.tts_model, inputs,
                                                      self.tts_config, False,
                                                      speaker_id, None)
            if self.vocoder_model:
                # use native vocoder model
                vocoder_input = postnet_output[0].transpose(0, 1).unsqueeze(0)
                wav = self.vocoder_model.inference(vocoder_input)
                if self.use_cuda:
                    wav = wav.cpu().numpy()
                else:
                    wav = wav.numpy()
                wav = wav.flatten()
            elif self.wavernn:
                # use 3rd-party WaveRNN
                vocoder_input = None
                if self.tts_config.model == "Tacotron":
                    vocoder_input = torch.FloatTensor(
                        self.ap.out_linear_to_mel(
                            linear_spec=postnet_output.T).T).T.unsqueeze(0)
                else:
                    vocoder_input = postnet_output[0].transpose(0,
                                                                1).unsqueeze(0)
                if self.use_cuda:
                    vocoder_input = vocoder_input.cuda()
                wav = self.wavernn.generate(
                    vocoder_input,
                    batched=self.config.is_wavernn_batched,
                    target=11000,
                    overlap=550)
            else:
                # use GL
                if self.use_cuda:
                    postnet_output = postnet_output[0].cpu()
                else:
                    postnet_output = postnet_output[0]
                postnet_output = postnet_output.numpy()
                wav = inv_spectrogram(postnet_output, self.ap, self.tts_config)

            # trim silence
            wav = trim_silence(wav, self.ap)

            wavs += list(wav)
            wavs += [0] * 10000

        out = io.BytesIO()
        self.save_wav(wavs, out)

        # compute stats
        process_time = time.time() - start_time
        audio_time = len(wavs) / self.tts_config.audio['sample_rate']
        print(f" > Processing time: {process_time}")
        print(f" > Real-time factor: {process_time / audio_time}")
        return out
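
A hypothetical end-to-end use of this class, assuming a config object carrying the attributes read in __init__ (tts_checkpoint, tts_config, use_cuda, and so on):

synthesizer = Synthesizer(config)        # config prepared elsewhere
wav_buffer = synthesizer.tts("Hello world.")
with open('output.wav', 'wb') as f:
    f.write(wav_buffer.getvalue())       # tts() returns an io.BytesIO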
Example #11
def gan_dataset_case(batch_size, seq_len, hop_len, conv_pad, return_segments,
                     use_noise_augment, use_cache, num_workers):
    ''' run dataloader with given parameters and check conditions '''
    ap = AudioProcessor(**C.audio)
    _, train_items = load_wav_data(test_data_path, 10)
    dataset = GANDataset(ap,
                         train_items,
                         seq_len=seq_len,
                         hop_len=hop_len,
                         pad_short=2000,
                         conv_pad=conv_pad,
                         return_segments=return_segments,
                         use_noise_augment=use_noise_augment,
                         use_cache=use_cache)
    loader = DataLoader(dataset=dataset,
                        batch_size=batch_size,
                        shuffle=True,
                        num_workers=num_workers,
                        pin_memory=True,
                        drop_last=True)

    max_iter = 10
    count_iter = 0

    # return random segments or return the whole audio
    if return_segments:
        for item1, _ in loader:
            feat1, wav1 = item1
            # feat2, wav2 = item2
            expected_feat_shape = (batch_size, ap.num_mels,
                                   seq_len // hop_len + conv_pad * 2)

            # check shapes
            assert np.all(feat1.shape == expected_feat_shape
                          ), f" [!] {feat1.shape} vs {expected_feat_shape}"
            assert (feat1.shape[2] - conv_pad * 2) * hop_len == wav1.shape[2]

            # check feature vs audio match
            if not use_noise_augment:
                for idx in range(batch_size):
                    audio = wav1[idx].squeeze()
                    feat = feat1[idx]
                    mel = ap.melspectrogram(audio)
                    # the first 2 and the last 2 frames are skipped due to the padding
                    # differences in stft
                    assert (feat - mel[:, :feat1.shape[-1]])[:, 2:-2].sum(
                    ) <= 0, f' [!] {(feat - mel[:, :feat1.shape[-1]])[:, 2:-2].sum()}'

            count_iter += 1
            # if count_iter == max_iter:
            #     break
    else:
        for item in loader:
            feat, wav = item
            expected_feat_shape = (batch_size, ap.num_mels,
                                   (wav.shape[-1] // hop_len) + (conv_pad * 2))
            assert np.all(feat.shape == expected_feat_shape
                          ), f" [!] {feat.shape} vs {expected_feat_shape}"
            assert (feat.shape[2] - conv_pad * 2) * hop_len == wav.shape[2]
            count_iter += 1
            if count_iter == max_iter:
                break
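
A hypothetical invocation exercising the segment-returning path; all parameter values are illustrative:

gan_dataset_case(batch_size=1, seq_len=2048, hop_len=256, conv_pad=0,
                 return_segments=True, use_noise_augment=False,
                 use_cache=False, num_workers=0)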
Example #12
    def __init__(self, *args, **kwargs):
        super(TestAudio, self).__init__(*args, **kwargs)
        self.ap = AudioProcessor(**conf.audio)
Example #13
class TestAudio(unittest.TestCase):
    def __init__(self, *args, **kwargs):
        super(TestAudio, self).__init__(*args, **kwargs)
        self.ap = AudioProcessor(**conf.audio)

    def test_audio_synthesis(self):
        """ 1. load wav
            2. set normalization parameters
            3. extract mel-spec
            4. invert to wav and save the output
        """
        print(" > Sanity check for the process wav -> mel -> wav")

        def _test(max_norm, signal_norm, symmetric_norm, clip_norm):
            self.ap.max_norm = max_norm
            self.ap.signal_norm = signal_norm
            self.ap.symmetric_norm = symmetric_norm
            self.ap.clip_norm = clip_norm
            wav = self.ap.load_wav(WAV_FILE)
            mel = self.ap.melspectrogram(wav)
            wav_ = self.ap.inv_melspectrogram(mel)
            file_name = "/audio_test-melspec_max_norm_{}-signal_norm_{}-symmetric_{}-clip_norm_{}.wav"\
                .format(max_norm, signal_norm, symmetric_norm, clip_norm)
            print(" | > Creating wav file at : ", file_name)
            self.ap.save_wav(wav_, OUT_PATH + file_name)

        # maxnorm = 1.0
        _test(1., False, False, False)
        _test(1., True, False, False)
        _test(1., True, True, False)
        _test(1., True, False, True)
        _test(1., True, True, True)
        # maxnorm = 4.0
        _test(4., False, False, False)
        _test(4., True, False, False)
        _test(4., True, True, False)
        _test(4., True, False, True)
        _test(4., True, True, True)

    def test_normalize(self):
        """Check normalization and denormalization for range values and consistency """
        print(" > Testing normalization and denormalization.")
        wav = self.ap.load_wav(WAV_FILE)
        wav = self.ap.sound_norm(wav)  # normalize audio to get a better normalization range below.
        self.ap.signal_norm = False
        x = self.ap.melspectrogram(wav)
        x_old = x

        self.ap.signal_norm = True
        self.ap.symmetric_norm = False
        self.ap.clip_norm = False
        self.ap.max_norm = 4.0
        x_norm = self.ap._normalize(x)
        print(f" > MaxNorm: {self.ap.max_norm}, ClipNorm:{self.ap.clip_norm}, SymmetricNorm:{self.ap.symmetric_norm}, SignalNorm:{self.ap.signal_norm} Range-> {x_norm.max()} --  {x_norm.min()}")
        assert (x_old - x).sum() == 0
        # check value range
        assert x_norm.max() <= self.ap.max_norm + 1, x_norm.max()
        assert x_norm.min() >= 0 - 1, x_norm.min()
        # check denorm.
        x_ = self.ap._denormalize(x_norm)
        assert (x - x_).sum() < 1e-3, (x - x_).mean()

        self.ap.signal_norm = True
        self.ap.symmetric_norm = False
        self.ap.clip_norm = True
        self.ap.max_norm = 4.0
        x_norm = self.ap._normalize(x)
        print(f" > MaxNorm: {self.ap.max_norm}, ClipNorm:{self.ap.clip_norm}, SymmetricNorm:{self.ap.symmetric_norm}, SignalNorm:{self.ap.signal_norm} Range-> {x_norm.max()} --  {x_norm.min()}")


        assert (x_old - x).sum() == 0
        # check value range
        assert x_norm.max() <= self.ap.max_norm, x_norm.max()
        assert x_norm.min() >= 0, x_norm.min()
        # check denorm.
        x_ = self.ap._denormalize(x_norm)
        assert (x - x_).sum() < 1e-3, (x - x_).mean()

        self.ap.signal_norm = True
        self.ap.symmetric_norm = True
        self.ap.clip_norm = False
        self.ap.max_norm = 4.0
        x_norm = self.ap._normalize(x)
        print(f" > MaxNorm: {self.ap.max_norm}, ClipNorm:{self.ap.clip_norm}, SymmetricNorm:{self.ap.symmetric_norm}, SignalNorm:{self.ap.signal_norm} Range-> {x_norm.max()} --  {x_norm.min()}")


        assert (x_old - x).sum() == 0
        # check value range
        assert x_norm.max() <= self.ap.max_norm + 1, x_norm.max()
        assert x_norm.min() >= -self.ap.max_norm - 2, x_norm.min()  #pylint: disable=invalid-unary-operand-type
        assert x_norm.min() <= 0, x_norm.min()
        # check denorm.
        x_ = self.ap._denormalize(x_norm)
        assert (x - x_).sum() < 1e-3, (x - x_).mean()

        self.ap.signal_norm = True
        self.ap.symmetric_norm = True
        self.ap.clip_norm = True
        self.ap.max_norm = 4.0
        x_norm = self.ap._normalize(x)
        print(f" > MaxNorm: {self.ap.max_norm}, ClipNorm:{self.ap.clip_norm}, SymmetricNorm:{self.ap.symmetric_norm}, SignalNorm:{self.ap.signal_norm} Range-> {x_norm.max()} --  {x_norm.min()}")


        assert (x_old - x).sum() == 0
        # check value range
        assert x_norm.max() <= self.ap.max_norm, x_norm.max()
        assert x_norm.min() >= -self.ap.max_norm, x_norm.min()  #pylint: disable=invalid-unary-operand-type
        assert x_norm.min() <= 0, x_norm.min()
        # check denorm.
        x_ = self.ap._denormalize(x_norm)
        assert (x - x_).sum() < 1e-3, (x - x_).mean()

        self.ap.signal_norm = True
        self.ap.symmetric_norm = False
        self.ap.max_norm = 1.0
        x_norm = self.ap._normalize(x)
        print(f" > MaxNorm: {self.ap.max_norm}, ClipNorm:{self.ap.clip_norm}, SymmetricNorm:{self.ap.symmetric_norm}, SignalNorm:{self.ap.signal_norm} Range-> {x_norm.max()} --  {x_norm.min()}")


        assert (x_old - x).sum() == 0
        assert x_norm.max() <= self.ap.max_norm, x_norm.max()
        assert x_norm.min() >= 0, x_norm.min()
        x_ = self.ap._denormalize(x_norm)
        assert (x - x_).sum() < 1e-3

        self.ap.signal_norm = True
        self.ap.symmetric_norm = True
        self.ap.max_norm = 1.0
        x_norm = self.ap._normalize(x)
        print(f" > MaxNorm: {self.ap.max_norm}, ClipNorm:{self.ap.clip_norm}, SymmetricNorm:{self.ap.symmetric_norm}, SignalNorm:{self.ap.signal_norm} Range-> {x_norm.max()} --  {x_norm.min()}")


        assert (x_old - x).sum() == 0
        assert x_norm.max() <= self.ap.max_norm, x_norm.max()
        assert x_norm.min() >= -self.ap.max_norm, x_norm.min()  #pylint: disable=invalid-unary-operand-type
        assert x_norm.min() < 0, x_norm.min()
        x_ = self.ap._denormalize(x_norm)
        assert (x - x_).sum() < 1e-3

    def test_scaler(self):
        scaler_stats_path = os.path.join(get_tests_input_path(), 'scale_stats.npy')
        conf.audio['stats_path'] = scaler_stats_path
        conf.audio['preemphasis'] = 0.0
        conf.audio['do_trim_silence'] = True
        conf.audio['signal_norm'] = True

        ap = AudioProcessor(**conf.audio)
        mel_mean, mel_std, linear_mean, linear_std, _ = ap.load_stats(scaler_stats_path)
        ap.setup_scaler(mel_mean, mel_std, linear_mean, linear_std)

        self.ap.signal_norm = False
        self.ap.preemphasis = 0.0

        # test scaler forward and backward transforms
        wav = self.ap.load_wav(WAV_FILE)
        mel_reference = self.ap.melspectrogram(wav)
        mel_norm = ap.melspectrogram(wav)
        mel_denorm = ap._denormalize(mel_norm)
        assert abs(mel_reference - mel_denorm).max() < 1e-4
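
The value ranges asserted in test_normalize follow from the normalization rule sketched below. This is an approximation of AudioProcessor._normalize for orientation only, with min_level_db assumed to be -100:

import numpy as np

def normalize_sketch(S, max_norm, symmetric, clip, min_level_db=-100):
    # map the dB spectrogram onto [0, 1] relative to the noise floor
    S_norm = (S - min_level_db) / -min_level_db
    if symmetric:
        S_norm = (2 * max_norm) * S_norm - max_norm  # range [-max_norm, max_norm]
        return np.clip(S_norm, -max_norm, max_norm) if clip else S_norm
    S_norm = max_norm * S_norm                       # range [0, max_norm]
    return np.clip(S_norm, 0, max_norm) if clip else S_norm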
Example #14
        type=str,
        help="if CONFIG.use_external_speaker_embedding_file is true, name of speaker embedding reference file present in speakers.json, else target speaker_fileid if the model is multi-speaker.",
        default=None)
    parser.add_argument(
        '--gst_style',
        help="Wav path file for GST stylereference.",
        default=None)

    args = parser.parse_args()

    # load the config
    C = load_config(args.config_path)
    C.forward_attn_mask = True

    # load the audio processor
    ap = AudioProcessor(**C.audio)

    # if the vocabulary was passed, replace the default
    if 'characters' in C.keys():
        symbols, phonemes = make_symbols(**C.characters)

    speaker_embedding = None
    speaker_embedding_dim = None
    num_speakers = 0

    # load speakers
    if args.speakers_json != '':
        speaker_mapping = json.load(open(args.speakers_json, 'r'))
        num_speakers = len(speaker_mapping)
        if C.use_external_speaker_embedding_file:
            if args.speaker_fileid is not None:
Example #15
def main(args):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global train_data, eval_data
    print(f" > Loading wavs from: {c.data_path}")
    if c.feature_path is not None:
        print(f" > Loading features from: {c.feature_path}")
        eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path,
                                                   c.eval_split_size)
    else:
        eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)

    # setup audio processor
    ap = AudioProcessor(**c.audio)

    # DISTRIBUTED
    # if num_gpus > 1:
    # init_distributed(args.rank, num_gpus, args.group_id,
    #  c.distributed["backend"], c.distributed["url"])

    # setup models
    model_gen = setup_generator(c)
    model_disc = setup_discriminator(c)

    # setup optimizers
    optimizer_gen = RAdam(model_gen.parameters(), lr=c.lr_gen, weight_decay=0)
    optimizer_disc = RAdam(model_disc.parameters(),
                           lr=c.lr_disc,
                           weight_decay=0)

    # schedulers
    scheduler_gen = None
    scheduler_disc = None
    if 'lr_scheduler_gen' in c:
        scheduler_gen = getattr(torch.optim.lr_scheduler, c.lr_scheduler_gen)
        scheduler_gen = scheduler_gen(optimizer_gen,
                                      **c.lr_scheduler_gen_params)
    if 'lr_scheduler_disc' in c:
        scheduler_disc = getattr(torch.optim.lr_scheduler, c.lr_scheduler_disc)
        scheduler_disc = scheduler_disc(optimizer_disc,
                                        **c.lr_scheduler_disc_params)

    # setup criterion
    criterion_gen = GeneratorLoss(c)
    criterion_disc = DiscriminatorLoss(c)

    if args.restore_path:
        checkpoint = torch.load(args.restore_path, map_location='cpu')
        try:
            print(" > Restoring Generator Model...")
            model_gen.load_state_dict(checkpoint['model'])
            print(" > Restoring Generator Optimizer...")
            optimizer_gen.load_state_dict(checkpoint['optimizer'])
            print(" > Restoring Discriminator Model...")
            model_disc.load_state_dict(checkpoint['model_disc'])
            print(" > Restoring Discriminator Optimizer...")
            optimizer_disc.load_state_dict(checkpoint['optimizer_disc'])
            if 'scheduler' in checkpoint:
                print(" > Restoring Generator LR Scheduler...")
                scheduler_gen.load_state_dict(checkpoint['scheduler'])
                # NOTE: Not sure if necessary
                scheduler_gen.optimizer = optimizer_gen
            if 'scheduler_disc' in checkpoint:
                print(" > Restoring Discriminator LR Scheduler...")
                scheduler_disc.load_state_dict(checkpoint['scheduler_disc'])
                scheduler_disc.optimizer = optimizer_disc
        except RuntimeError:
            # restore only matching layers.
            print(" > Partial model initialization...")
            model_dict = model_gen.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint['model'], c)
            model_gen.load_state_dict(model_dict)

            model_dict = model_disc.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint['model_disc'], c)
            model_disc.load_state_dict(model_dict)
            del model_dict

        # reset lr after a restore instead of keeping the checkpoint value.
        for group in optimizer_gen.param_groups:
            group['lr'] = c.lr_gen

        for group in optimizer_disc.param_groups:
            group['lr'] = c.lr_disc

        print(" > Model restored from step %d" % checkpoint['step'],
              flush=True)
        args.restore_step = checkpoint['step']
    else:
        args.restore_step = 0

    if use_cuda:
        model_gen.cuda()
        criterion_gen.cuda()
        model_disc.cuda()
        criterion_disc.cuda()

    # DISTRIBUTED
    # if num_gpus > 1:
    #     model = apply_gradient_allreduce(model)

    num_params = count_parameters(model_gen)
    print(" > Generator has {} parameters".format(num_params), flush=True)
    num_params = count_parameters(model_disc)
    print(" > Discriminator has {} parameters".format(num_params), flush=True)

    if 'best_loss' not in locals():
        best_loss = float('inf')

    global_step = args.restore_step
    for epoch in range(0, c.epochs):
        c_logger.print_epoch_start(epoch, c.epochs)
        _, global_step = train(model_gen, criterion_gen, optimizer_gen,
                               model_disc, criterion_disc, optimizer_disc,
                               scheduler_gen, scheduler_disc, ap, global_step,
                               epoch)
        eval_avg_loss_dict = evaluate(model_gen, criterion_gen, model_disc,
                                      criterion_disc, ap, global_step, epoch)
        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
        target_loss = eval_avg_loss_dict[c.target_loss]
        best_loss = save_best_model(target_loss,
                                    best_loss,
                                    model_gen,
                                    optimizer_gen,
                                    scheduler_gen,
                                    model_disc,
                                    optimizer_disc,
                                    scheduler_disc,
                                    global_step,
                                    epoch,
                                    OUT_PATH,
                                    model_losses=eval_avg_loss_dict)
Example #16
def main():
    """Run preprocessing process."""
    parser = argparse.ArgumentParser(
        description="Compute mean and variance of spectrogtram features.")
    parser.add_argument(
        "--config_path",
        type=str,
        required=True,
        help="TTS config file path to define audio processin parameters.")
    parser.add_argument("--out_path",
                        default=None,
                        type=str,
                        help="directory to save the output file.")
    args = parser.parse_args()

    # load config
    CONFIG = load_config(args.config_path)
    CONFIG.audio['signal_norm'] = False  # do not apply earlier normalization
    CONFIG.audio['stats_path'] = None  # discard pre-defined stats

    # load audio processor
    ap = AudioProcessor(**CONFIG.audio)

    # load the meta data of target dataset
    dataset_items = load_meta_data(CONFIG.datasets)[0]  # take only train data
    print(f" > There are {len(dataset_items)} files.")

    mel_sum = 0
    mel_square_sum = 0
    linear_sum = 0
    linear_square_sum = 0
    N = 0
    for item in tqdm(dataset_items):
        # compute features
        wav = ap.load_wav(item[1])
        linear = ap.spectrogram(wav)
        mel = ap.melspectrogram(wav)

        # compute stats
        N += mel.shape[1]
        mel_sum += mel.sum(1)
        linear_sum += linear.sum(1)
        mel_square_sum += (mel**2).sum(axis=1)
        linear_square_sum += (linear**2).sum(axis=1)

    mel_mean = mel_sum / N
    mel_scale = np.sqrt(mel_square_sum / N - mel_mean**2)
    linear_mean = linear_sum / N
    linear_scale = np.sqrt(linear_square_sum / N - linear_mean**2)

    output_file_path = os.path.join(args.out_path, "scale_stats.npy")
    stats = {}
    stats['mel_mean'] = mel_mean
    stats['mel_std'] = mel_scale
    stats['linear_mean'] = linear_mean
    stats['linear_std'] = linear_scale

    print(f' > Avg mel spec mean: {mel_mean.mean()}')
    print(f' > Avg mel spec scale: {mel_scale.mean()}')
    print(f' > Avg linear spec mean: {linear_mean.mean()}')
    print(f' > Avg linear spec scale: {linear_scale.mean()}')

    # set default config values for mean-var scaling
    CONFIG.audio['stats_path'] = output_file_path
    CONFIG.audio['signal_norm'] = True
    # remove redundant values
    del CONFIG.audio['max_norm']
    del CONFIG.audio['min_level_db']
    del CONFIG.audio['symmetric_norm']
    del CONFIG.audio['clip_norm']
    stats['audio_config'] = CONFIG.audio
    np.save(output_file_path, stats, allow_pickle=True)
    print(f' > scale_stats.npy is saved to {output_file_path}')
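
The loop above accumulates sums and squared sums, so the scale comes out as std = sqrt(sum(x**2) / N - mean**2). A typical invocation of this script (script name and paths are illustrative):

# python compute_statistics.py --config_path config.json --out_path ./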
Example #17
def main(args):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global meta_data_train
    global meta_data_eval

    ap = AudioProcessor(**c.audio)
    model = SpeakerEncoder(input_dim=c.model['input_dim'],
                           proj_dim=c.model['proj_dim'],
                           lstm_dim=c.model['lstm_dim'],
                           num_lstm_layers=c.model['num_lstm_layers'])
    optimizer = RAdam(model.parameters(), lr=c.lr)

    if c.loss == "ge2e":
        criterion = GE2ELoss(loss_method='softmax')
    elif c.loss == "angleproto":
        criterion = AngleProtoLoss()
    else:
        raise Exception("%s is not a supported loss" % c.loss)

    if args.restore_path:
        checkpoint = torch.load(args.restore_path)
        try:
            # TODO: fix optimizer init, model.cuda() needs to be called before
            # optimizer restore
            # optimizer.load_state_dict(checkpoint['optimizer'])
            if c.reinit_layers:
                raise RuntimeError
            model.load_state_dict(checkpoint['model'])
        except KeyError:
            print(" > Partial model initialization.")
            model_dict = model.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint, c)
            model.load_state_dict(model_dict)
            del model_dict
        for group in optimizer.param_groups:
            group['lr'] = c.lr
        print(" > Model restored from step %d" % checkpoint['step'],
              flush=True)
        args.restore_step = checkpoint['step']
    else:
        args.restore_step = 0

    if use_cuda:
        model = model.cuda()
        criterion.cuda()

    if c.lr_decay:
        scheduler = NoamLR(optimizer,
                           warmup_steps=c.warmup_steps,
                           last_epoch=args.restore_step - 1)
    else:
        scheduler = None

    num_params = count_parameters(model)
    print("\n > Model has {} parameters".format(num_params), flush=True)

    # pylint: disable=redefined-outer-name
    meta_data_train, meta_data_eval = load_meta_data(c.datasets)

    global_step = args.restore_step
    _, global_step = train(model, criterion, optimizer, scheduler, ap,
                           global_step)
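
A quick shape check for the encoder built above; the dimensions are illustrative and mirror the c.model entries the constructor reads:

import torch

model = SpeakerEncoder(input_dim=40, proj_dim=256,
                       lstm_dim=768, num_lstm_layers=3)
mels = torch.rand(4, 20, 40)   # (batch, time frames, input_dim)
embeddings = model(mels)
print(embeddings.shape)        # expected: (4, 256)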