Пример #1
0
 def build_model(self):
     """Create the SSRN sub-graph and keep its two outputs on the instance.

     Reads ``self.hp``, ``self.mels``, ``self.training``, ``self.speakers``
     and ``self.reuse``; writes ``self.Z_logits`` and ``self.Z``.
     """
     with tf.variable_scope("SSRN"):
         ## OSW: 'mels' is the input in both training and synthesis -- it can
         ## be either a variable or a placeholder.
         ssrn_options = dict(
             training=self.training,
             speaker_codes=self.speakers,
             reuse=self.reuse,
         )
         ssrn_outputs = SSRN(self.hp, self.mels, **ssrn_options)
         self.Z_logits, self.Z = ssrn_outputs
Пример #2
0
    def __init__(self, num=1, mode="train"):
        """Build a placeholder-fed Text2Mel graph followed by SSRN.

        Args:
          num: Either 1 or 2. 1 for Text2Mel 2 for SSRN. This constructor
            accepts the argument but never reads it.
          mode: Either "train" or "synthesize". Only used to derive the
            `training` flag handed to the sub-networks.
        """
        # Vocabulary lookup tables (char -> index and index -> char).
        vocab_tables = load_vocab()
        self.char2idx, self.idx2char = vocab_tables

        # The sub-networks only need to know whether this is a training graph.
        training = (mode == "train")

        # Graph inputs, all supplied through feeds:
        ## L: Text. (B, N), int32
        ## mels: Reduced melspectrogram. (B, T/r, n_mels) float32
        ## mags: Magnitude. (B, T, n_fft//2+1) float32
        text_shape = (None, None)
        mel_shape = (None, None, hp.n_mels)
        self.L = tf.placeholder(tf.int32, shape=text_shape)
        self.mels = tf.placeholder(tf.float32, shape=mel_shape)
        self.prev_max_attentions = tf.placeholder(tf.int32, shape=(None, ))

        with tf.variable_scope("Text2Mel"):
            # Decoder inputs S, (B, T//r, n_mels): a zero frame followed by
            # the mels with their last frame dropped (shift right by one).
            leading_zero_frame = tf.zeros_like(self.mels[:, :1, :])
            mels_without_last = self.mels[:, :-1, :]
            self.S = tf.concat((leading_zero_frame, mels_without_last), 1)

            # Networks
            with tf.variable_scope("TextEnc"):
                text_encoding = TextEnc(self.L, training=training)
                self.K, self.V = text_encoding  # (N, Tx, e)

            with tf.variable_scope("AudioEnc"):
                self.Q = AudioEnc(self.S, training=training)

            with tf.variable_scope("Attention"):
                # R: (B, T/r, 2d)
                # alignments: (B, N, T/r)
                # max_attentions: (B,)
                attention_outputs = Attention(
                    self.Q,
                    self.K,
                    self.V,
                    mononotic_attention=not training,
                    prev_max_attentions=self.prev_max_attentions)
                self.R, self.alignments, self.max_attentions = attention_outputs

            with tf.variable_scope("AudioDec"):
                decoder_outputs = AudioDec(self.R, training=training)
                self.Y_logits, self.Y = decoder_outputs  # (B, T/r, n_mels)

        # SSRN always consumes the predicted mels (self.Y) in this graph.
        with tf.variable_scope("SSRN"):
            ssrn_outputs = SSRN(self.Y, training=training)
            self.Z_logits, self.Z = ssrn_outputs

        with tf.variable_scope("gs"):
            self.global_step = tf.Variable(
                0, name='global_step', trainable=False)
    def build_model(self):
        """Load the SSRN data, register it, then create the SSRN sub-graph.

        Calls ``self.load_data_in_memory`` and ``self.add_data`` with
        ``model='ssrn'`` before building the network. Afterwards
        ``self.Z_logits`` and ``self.Z`` hold the two values returned by
        ``SSRN``.
        """
        model_name = 'ssrn'
        self.load_data_in_memory(model=model_name)
        self.add_data(reuse=self.reuse, model=model_name)

        with tf.variable_scope("SSRN"):
            ## OSW: 'mels' is the input in both training and synthesis -- it
            ## can be either a variable or a placeholder.
            ssrn_options = dict(
                training=self.training,
                speaker_codes=self.speakers,
                reuse=self.reuse,
            )
            ssrn_outputs = SSRN(self.hp, self.mels, **ssrn_options)
            self.Z_logits, self.Z = ssrn_outputs
Пример #4
0
    def __init__(self, num=1):
        """Build the synthesis-only graph: Text2Mel feeding directly into SSRN.

        Every sub-network is created with ``training=False`` and all inputs
        are placeholders.

        Args:
          num: Accepted but never read by this constructor.
        """
        # Vocabulary lookup tables come from the instance's own loader.
        vocab_tables = self.load_vocab()
        self.char2idx, self.idx2char = vocab_tables

        # This graph is never used for training.
        training = False

        # Graph inputs, all supplied through feeds at synthesis time.
        text_shape = (None, None)
        mel_shape = (None, None, hp.n_mels)
        self.L = tf.placeholder(tf.int32, shape=text_shape)
        self.mels = tf.placeholder(tf.float32, shape=mel_shape)
        self.prev_max_attentions = tf.placeholder(tf.int32, shape=(None, ))

        with tf.variable_scope("Text2Mel"):
            # Decoder inputs S, (B, T//r, n_mels): a zero frame followed by
            # the mels with their last frame dropped (shift right by one).
            leading_zero_frame = tf.zeros_like(self.mels[:, :1, :])
            mels_without_last = self.mels[:, :-1, :]
            self.S = tf.concat((leading_zero_frame, mels_without_last), 1)

            # Networks
            with tf.variable_scope("TextEnc"):
                text_encoding = TextEnc(self.L, training=training)
                self.K, self.V = text_encoding  # (N, Tx, e)

            with tf.variable_scope("AudioEnc"):
                self.Q = AudioEnc(self.S, training=training)

            with tf.variable_scope("Attention"):
                # R: (B, T/r, 2d)
                # alignments: (B, N, T/r)
                # max_attentions: (B,)
                attention_outputs = Attention(
                    self.Q,
                    self.K,
                    self.V,
                    mononotic_attention=not training,
                    prev_max_attentions=self.prev_max_attentions)
                self.R, self.alignments, self.max_attentions = attention_outputs

            with tf.variable_scope("AudioDec"):
                decoder_outputs = AudioDec(self.R, training=training)
                self.Y_logits, self.Y = decoder_outputs  # (B, T/r, n_mels)

        # During inference, the predicted melspectrogram values are fed.
        with tf.variable_scope("SSRN"):
            ssrn_outputs = SSRN(self.Y, training=training)
            self.Z_logits, self.Z = ssrn_outputs

        with tf.variable_scope("gs"):
            self.global_step = tf.Variable(
                0, name='global_step', trainable=False)
Пример #5
0
    def __init__(self, num=1, mode="train"):
        '''Build the Text2Mel and/or SSRN graph for training or synthesis.

        In "train" mode the inputs come from `get_batch()` and only the
        network selected by `num` is built, together with its losses, the
        optimizer, the clipped-gradient update op and the merged summaries.
        In any other mode the inputs are placeholders and both networks are
        chained: Text2Mel predicts `self.Y`, which SSRN then consumes.

        Args:
          num: Either 1 or 2. 1 for Text2Mel 2 for SSRN.
          mode: Either "train" or "synthesize".
        '''
        # Load vocabulary
        self.char2idx, self.idx2char = load_vocab()

        # Set flag
        training = mode == "train"

        # Graph
        # Data Feeding
        ## L: Text. (B, N), int32
        ## mels: Reduced melspectrogram. (B, T/r, n_mels) float32
        ## mags: Magnitude. (B, T, n_fft//2+1) float32
        if mode == "train":
            (self.L, self.mels, self.mags, self.fnames,
             self.num_batch) = get_batch()
            self.prev_max_attentions = tf.ones(shape=(hp.B, ), dtype=tf.int32)
            self.gts = tf.convert_to_tensor(guided_attention())
        else:  # Synthesize
            self.L = tf.placeholder(tf.int32, shape=(None, None))
            self.mels = tf.placeholder(tf.float32,
                                       shape=(None, None, hp.n_mels))
            self.prev_max_attentions = tf.placeholder(tf.int32, shape=(None, ))

        if num == 1 or (not training):
            with tf.variable_scope("Text2Mel"):
                # Get S or decoder inputs. (B, T//r, n_mels)
                # A zero frame is prepended and the last frame dropped, i.e.
                # the mels are shifted right by one step.
                self.S = tf.concat(
                    (tf.zeros_like(self.mels[:, :1, :]), self.mels[:, :-1, :]),
                    1)

                # Networks
                with tf.variable_scope("TextEnc"):
                    self.K, self.V = TextEnc(self.L,
                                             training=training)  # (N, Tx, e)

                with tf.variable_scope("AudioEnc"):
                    self.Q = AudioEnc(self.S, training=training)

                with tf.variable_scope("Attention"):
                    # R: (B, T/r, 2d)
                    # alignments: (B, N, T/r)
                    # max_attentions: (B,)
                    self.R, self.alignments, self.max_attentions = Attention(
                        self.Q,
                        self.K,
                        self.V,
                        mononotic_attention=(not training),
                        prev_max_attentions=self.prev_max_attentions)
                with tf.variable_scope("AudioDec"):
                    self.Y_logits, self.Y = AudioDec(
                        self.R, training=training)  # (B, T/r, n_mels)
        else:  # num==2 & training. Note that during training,
            # the ground truth melspectrogram values are fed.
            with tf.variable_scope("SSRN"):
                self.Z_logits, self.Z = SSRN(self.mels, training=training)

        if not training:
            # During inference, the predicted melspectrogram values are fed.
            with tf.variable_scope("SSRN"):
                self.Z_logits, self.Z = SSRN(self.Y, training=training)

        with tf.variable_scope("gs"):
            self.global_step = tf.Variable(0,
                                           name='global_step',
                                           trainable=False)

        if training:
            if num == 1:  # Text2Mel
                # mel L1 loss
                self.loss_mels = tf.reduce_mean(tf.abs(self.Y - self.mels))

                # mel binary divergence loss
                self.loss_bd1 = tf.reduce_mean(
                    tf.nn.sigmoid_cross_entropy_with_logits(
                        logits=self.Y_logits, labels=self.mels))

                # guided_attention loss
                # Pad the alignments with -1 and crop to (max_N, max_T) so
                # they line up with self.gts; the -1 cells mark padding and
                # are masked out of the loss below.
                self.A = tf.pad(self.alignments, [(0, 0), (0, hp.max_N),
                                                  (0, hp.max_T)],
                                mode="CONSTANT",
                                constant_values=-1.)[:, :hp.max_N, :hp.max_T]
                self.attention_masks = tf.to_float(tf.not_equal(self.A, -1))
                self.loss_att = tf.reduce_sum(
                    tf.abs(self.A * self.gts) * self.attention_masks)
                self.mask_sum = tf.reduce_sum(self.attention_masks)
                self.loss_att /= self.mask_sum

                # total loss
                self.loss = self.loss_mels + self.loss_bd1 + self.loss_att

                tf.summary.scalar('train/loss_mels', self.loss_mels)
                tf.summary.scalar('train/loss_bd1', self.loss_bd1)
                tf.summary.scalar('train/loss_att', self.loss_att)
                tf.summary.image(
                    'train/mel_gt',
                    tf.expand_dims(tf.transpose(self.mels[:1], [0, 2, 1]), -1))
                tf.summary.image(
                    'train/mel_hat',
                    tf.expand_dims(tf.transpose(self.Y[:1], [0, 2, 1]), -1))
            else:  # SSRN
                # mag L1 loss
                self.loss_mags = tf.reduce_mean(tf.abs(self.Z - self.mags))

                # mag binary divergence loss
                self.loss_bd2 = tf.reduce_mean(
                    tf.nn.sigmoid_cross_entropy_with_logits(
                        logits=self.Z_logits, labels=self.mags))

                # total loss
                self.loss = self.loss_mags + self.loss_bd2

                tf.summary.scalar('train/loss_mags', self.loss_mags)
                tf.summary.scalar('train/loss_bd2', self.loss_bd2)
                tf.summary.image(
                    'train/mag_gt',
                    tf.expand_dims(tf.transpose(self.mags[:1], [0, 2, 1]), -1))
                tf.summary.image(
                    'train/mag_hat',
                    tf.expand_dims(tf.transpose(self.Z[:1], [0, 2, 1]), -1))

            # Training Scheme
            self.lr = learning_rate_decay(hp.lr, self.global_step)
            self.optimizer = tf.train.AdamOptimizer(learning_rate=self.lr)
            tf.summary.scalar("lr", self.lr)

            ## gradient clipping
            self.gvs = self.optimizer.compute_gradients(self.loss)
            # A variable that does not feed the loss comes back with a None
            # gradient; pass it through untouched instead of handing None to
            # tf.clip_by_value (apply_gradients ignores such pairs).
            self.clipped = [
                (grad if grad is None else tf.clip_by_value(grad, -1., 1.),
                 var) for grad, var in self.gvs
            ]
            # Build the update op exactly once, after every gradient has been
            # clipped. Creating it inside the loop added one apply op per
            # variable to the graph and kept only the last of them.
            self.train_op = self.optimizer.apply_gradients(
                self.clipped, global_step=self.global_step)

            # Summary
            self.merged = tf.summary.merge_all()
Пример #6
0
    config_ssrn = state_ssrn["config"]
    if config_ssrn != config_t2m:
        print(
            "WARNING: Text2Mel and SSRN have different saved configs. Will use Text2Mel config!"
        )
    Config.set_config(config_t2m)

    # Load networks
    print("Loading Text2Mel...")
    text2mel = Text2Mel().to(device)
    text2mel.eval()
    text2mel_step = state_t2m["global_step"]
    text2mel.load_state_dict(state_t2m["model"])

    print("Loading SSRN...")
    ssrn = SSRN().to(device)
    ssrn.eval()
    ssrn_step = state_ssrn["global_step"]
    ssrn.load_state_dict(state_ssrn["model"])

    while True:
        text = input("> ")
        text = spell_out_numbers(text, args.language)
        text = normalize(text)
        text = text + Config.vocab_end_of_text
        text = vocab_lookup(text)

        L = torch.tensor(text, device=device, requires_grad=False).unsqueeze(0)
        S = torch.zeros(1,
                        Config.max_T,
                        Config.F,
Пример #7
0
def train():
    """Train the SSRN model and validate it after every epoch.

    For each epoch the function runs one pass over the training set with
    gradient-norm clipping, prints the summed and averaged loss, saves a
    checkpoint every `save_every` epochs, then evaluates on the validation
    set under `torch.no_grad()` and stores one prediction / ground-truth
    pair as `.npy` files.

    All paths and hyper-parameters are hard-coded below. The "average" values
    printed are the summed per-batch losses divided by `len(<dataset>.bases)`.
    """
    import os  # local import: only needed to create the output folders

    print("training begins...")
    # training and validation relative directories
    train_directory = "../../ETTT/Pytorch-DCTTS/LJSpeech_data/"
    val_directory = "../../ETTT/Pytorch-DCTTS/LJSpeech_val/"
    t_data = LJDataset(train_directory)
    v_data = LJDataset(val_directory)
    train_len = len(t_data.bases)
    val_len = len(v_data.bases)

    # training parameters
    batch_size = 40
    epochs = 500
    save_every = 5
    learning_rate = 1e-4
    max_grad_norm = 1.0

    # Make sure the output folders exist; torch.save / np.save do not create
    # missing parent directories.
    os.makedirs("save_stuff/checkpoint", exist_ok=True)
    os.makedirs("save_stuff/mel_pred", exist_ok=True)

    # create model and optim
    hp = Hparams()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = SSRN(hp, device)
    optim = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # The loaders do not depend on the epoch, so build them once. With
    # shuffle=True a new order is still drawn every time they are iterated.
    t_loader = DataLoader(t_data,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=False,
                          collate_fn=ssrn_collate_fn)
    v_loader = DataLoader(v_data,
                          batch_size=batch_size // 10,
                          shuffle=True,
                          drop_last=False,
                          collate_fn=ssrn_collate_fn)

    # NOTE(review): the model is never switched between train()/eval() mode;
    # confirm SSRN has no mode-dependent layers (dropout, batch norm).

    # main training loop
    for ep in tqdm(range(epochs)):
        total_loss = 0  # epoch loss
        for data in tqdm(t_loader):
            # initialize batch_loss
            batch_loss = model.compute_batch_loss(data)
            # batch update
            optim.zero_grad()
            batch_loss.backward()
            torch.nn.utils.clip_grad_norm_(max_norm=max_grad_norm,
                                           parameters=model.parameters())
            optim.step()
            # .item() yields a plain Python float, detached from the graph
            total_loss += batch_loss.item()
        # one epoch complete, add to total loss and print
        print(
            "epoch {}, total loss:{}, average total loss:{}, validating now..."
            .format(ep, float(total_loss),
                    float(total_loss) / train_len))
        # if time to save, we save model
        if ep % save_every == 0:
            torch.save(
                model.state_dict(),
                "save_stuff/checkpoint/epoch_{}_ssrn_model.pt".format(ep))

        # Validation phase
        with torch.no_grad():
            total_loss = 0
            for data in tqdm(v_loader):
                loss = model.compute_batch_loss(data)
                total_loss += loss.item()
            # printing
            print("validation loss:{}, average validation loss:{}".format(
                float(total_loss),
                float(total_loss) / val_len))
            # Keep the last (x, y) pair of the final validation batch as the
            # sample that gets dumped to disk.
            for dat in data:
                x, y = dat
            # predict
            # Call the module itself rather than .forward() so that any
            # registered hooks run as well.
            predict, _ = model((x.view(1, 80, -1)).to(device))
            np.save("save_stuff/mel_pred/epoch_{}_mel_pred.npy".format(ep),
                    predict.detach().cpu().numpy())
            np.save(
                "save_stuff/mel_pred/epoch_{}_ground_truth.npy".format(ep),
                y)
                    key, value, getattr(Config, key))
        if conflicts:
            print(warning)
            if args.cc:
                print("Will use the current config file.\n")
            else:
                print(
                    "Will fall back to saved config. If you want to use the current config file, run with flag "
                    "'-cc'\n")
                Config.set_config(conf)

    # Tensorboard
    writer = SummaryWriter(args.log_dir)

    print("Loading SSRN...")
    net = SSRN().to(device)
    net.apply(weight_init)

    l1_criterion = nn.L1Loss().to(device)
    bd_criterion = nn.BCEWithLogitsLoss().to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    global_step = 0

    # Learning rate decay. Noam scheme
    warmup_steps = 4000.0

    def decay(_):
        """Return the Noam learning-rate multiplier for the current step.

        The argument supplied by the scheduler is ignored; the step is taken
        from the enclosing `global_step` (plus one) and the warm-up length
        from the enclosing `warmup_steps`.
        """
        current_step = global_step + 1
        linear_warmup = current_step * warmup_steps**-1.5
        inverse_sqrt = current_step**-0.5
        return warmup_steps**0.5 * min(linear_warmup, inverse_sqrt)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=decay)