Example #1
    def eval_model(self, global_epoch):
        running_loss = 0.
        for step, (melX, melY, lengths) in enumerate(self.valid_loader):
            self.model.eval()
            melX = melX.to(self.device)
            melY = melY.to(self.device)
            lengths = lengths.to(self.device)
            target_mask = sequence_mask(lengths, max_len=melY.size(1)).unsqueeze(-1)

            melX_outputs = self.model(melX)

            mel_l1_loss, mel_binary_div = self.spec_loss(melX_outputs, melY, target_mask)
            loss = (1 - self.w) * mel_l1_loss + self.w * mel_binary_div

            running_loss += loss.item()

        if global_epoch % self.eval_interval == 0:
            idx = min(1, len(lengths) - 1)
            mel_output = melX_outputs[idx].cpu().data.numpy()
            mel_output = prepare_spec_image(audio._denormalize(mel_output))
            self.writer.add_image("(Eval) Predicted mel spectrogram", mel_output, global_epoch)

            # Target mel spectrogram
            melY = melY[idx].cpu().data.numpy()
            melY = prepare_spec_image(audio._denormalize(melY))
            self.writer.add_image("(Eval) Target mel spectrogram", melY, global_epoch)
            melX = melX[idx].cpu().data.numpy()
            melX = prepare_spec_image(audio._denormalize(melX))
            self.writer.add_image("(Eval) Source mel spectrogram", melX, global_epoch)

        avg_loss = running_loss / len(self.valid_loader)
        self.writer.add_scalar("valid loss (per epoch)", avg_loss, global_epoch)
        print("Valid Loss: {}".format(avg_loss))
Example #2
    def forward(self, input, length, words, display):
        # input is [bsz, len, 2*nhid]
        # length is [bsz, ]
        # words is [bsz, len]
        bsz, l, nhid2 = input.size()
        mask = sequence_mask(length, max_length=l)  # [bsz, len]
        num_blocks = int(math.ceil(l / self.block_size))
        block_embeddings, block_alphas = [], []
        for i in range(num_blocks):
            begin = i * self.block_size
            end = min(l, self.block_size * (i + 1))
            # e, a, info = self.attender(input[:, begin:end, :].contiguous(), mask[:, begin:end], display)
            e, a, _ = self.attender(input[:, begin:end, :].contiguous(),
                                    length - begin,
                                    words,
                                    display=False)
            block_embeddings.append(e)
            block_alphas.append(a)
        block_embeddings = torch.stack(block_embeddings)  # [nblock, bsz, hop, 2*nhid]
        block_embeddings = block_embeddings.view(num_blocks, -1, nhid2)  # [nblock, bsz*hop, 2*nhid]
        block_embeddings = torch.transpose(block_embeddings, 0, 1).contiguous()  # [bsz*hop, nblock, 2*nhid]
        block_mask = mask[:, ::self.block_size].contiguous()  # [bsz, nblock]
        block_mask = block_mask.unsqueeze(1).expand(-1, self.att_hops, -1)
        block_mask = block_mask.contiguous().view(-1, num_blocks)  # [bsz*hop, nblock]
        # sent_embeddings, sent_alphas, sent_info = self.attender(block_embeddings, block_mask, display)
        sent_embeddings, sent_alphas, _ = self.attender(block_embeddings,
                                                        block_mask.sum(dim=1),
                                                        words,
                                                        display=False)
        # [bsz*hop, hop, 2*nhid], [bsz*hop, hop, nblock]
        sent_embeddings = sent_embeddings.view(bsz, self.att_hops**2, nhid2)  # [bsz, hop*hop, 2*nhid]

        # construct alphas
        # block_alphas is list of [bsz, hop, blocksz] of length nblock
        if num_blocks < self.att_hops:  # TODO: hop should be 1 if num_blocks is small.
            sent_alphas = []
        else:
            sent_alphas = sent_alphas.view(bsz, self.att_hops, self.att_hops,
                                           num_blocks)
            sent_alphas = list(
                map(torch.squeeze,
                    torch.chunk(torch.transpose(sent_alphas, 0, 1).contiguous(),
                                chunks=self.att_hops))
            )  # list of [bsz, hop, nblock] of length hop; list() so it can be concatenated with block_alphas below

        info = []
        if display:
            pass

        return sent_embeddings, block_alphas + sent_alphas, info
Example #3
    def forward(self, input, length, words, display):
        # input is [bsz, len, 2*nhid]
        # length is [bsz, ]
        # words is [bsz, len]
        bsz, l, nhid2 = input.size()
        mask = sequence_mask(length, max_length=l)  # [bsz, len]
        z_soft = self.hard_attender(input.view(-1, nhid2)).view(bsz, l, 1)  # [bsz, l, 1]
        z_hard = torch.bernoulli(z_soft).byte() & mask.unsqueeze(2)
        gate = (z_hard.float() - z_soft).detach() + z_soft
        gate_h = gate * input  # [bsz, l, 2*nhid]
        '''
        # TODO: how to optimize the case when gate=0
        new_length = z_hard.int().sum(dim=1).squeeze().long() # [bsz, ]
        new_l = unwrap_scalar_variable(torch.max(new_length))
        new_input = [[Variable(input.data.new(nhid2).zero_())]*new_l for _ in range(bsz)]
        for i in range(bsz):
            k = 0
            for j in range(l): # TODO faster iteration
                if unwrap_scalar_variable(z_hard[i][j])==1:
                    new_input[i][k] = gate_h[i][j]
                    k += 1
            new_input[i] = torch.stack(new_input[i])
        new_input = torch.stack(new_input) # [bsz, new_l, 2*nhid]
        new_mask = sequence_mask(new_length, max_length=new_l) # [bsz, new_l]
        return self.soft_attender(new_input, new_mask)
        '''
        embeddings, alphas, _ = self.soft_attender(gate_h, length, words,
                                                   False)

        info = []
        if display:
            for i in range(bsz):
                s = '\n'
                for j in range(self.att_hops):
                    for k in range(unwrap_scalar_variable(length[i])):
                        if unwrap_scalar_variable(z_hard[i][k]) == 1:
                            s += '%s(%.2f) ' % (
                                self.dictionary.itow(unwrap_scalar_variable(words[i][k])),
                                unwrap_scalar_variable(alphas[i][j][k]))
                        else:
                            s += '--%s-- ' % self.dictionary.itow(
                                unwrap_scalar_variable(words[i][k]))
                    s += '\n\n'
                info.append(s)

        supplements = {
            'attention': alphas,
            'info': info,
        }

        return embeddings, supplements
Example #4
    def forward(self, inputs, input_lens):
        max_len = inputs.size(1)
        mask = ~sequence_mask(input_lens, max_len).unsqueeze(1).to(
            inputs.device)

        output = inputs
        for layer in self.encoder_layers:
            output = layer(output, mask)
        output = self.layer_norm(output)

        return output, mask
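Note: Example #4 inverts the mask (`~sequence_mask(...)`), so there `True` marks padded positions. A hedged sketch of how such a padding mask is commonly consumed inside an attention layer, using the `sequence_mask` sketch above; the score tensor and shapes are illustrative assumptions, not part of the example:

import torch

lengths = torch.tensor([6, 3])
pad_mask = ~sequence_mask(lengths, 6).unsqueeze(1)      # (B, 1, T), True where padded

scores = torch.randn(2, 6, 6)                           # raw attention scores (B, T_query, T_key)
scores = scores.masked_fill(pad_mask, float('-inf'))    # padded keys get zero weight
attn = torch.softmax(scores, dim=-1)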
Example #5
    def forward(self, input, target, lengths=None, mask=None, max_len=None):
        if lengths is None and mask is None:
            raise RuntimeError("Should provide either lengths or mask")

        # (B, T, 1)
        if mask is None:
            mask = sequence_mask(lengths, max_len).unsqueeze(-1)

        # (B, T, D)
        mask_ = mask.expand_as(input)
        loss = self.criterion(input * mask_, target * mask_)
        return loss / mask_.sum()
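Note: a hedged usage sketch for the masked loss in Example #5. It assumes `self.criterion` is a summed loss (e.g. L1 with reduction='sum'), so dividing by `mask_.sum()` averages over valid (unpadded) elements only; the tensors below are dummies:

import torch
import torch.nn as nn

B, T, D = 2, 5, 3
pred = torch.randn(B, T, D)
target = torch.randn(B, T, D)
lengths = torch.tensor([5, 3])

criterion = nn.L1Loss(reduction='sum')
mask = sequence_mask(lengths, T).unsqueeze(-1)      # (B, T, 1)
mask_ = mask.float().expand_as(pred)                # (B, T, D)
loss = criterion(pred * mask_, target * mask_) / mask_.sum()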
Example #6
    def train(self, global_step=0, global_epoch=1):
        while global_epoch < self.epoch:
            running_loss = 0.
            for step, (melX, melY, lengths) in enumerate(tqdm(self.train_loader)):
                self.model.train()

                # Learn rate scheduler
                current_lr = noam_learning_rate_decay(self.args.learn_rate, global_step)
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = current_lr
                self.optimizer.zero_grad()

                # Transform data to CUDA device
                melX = melX.to(self.device)
                melY = melY.to(self.device)
                lengths = lengths.to(self.device)

                target_mask = sequence_mask(lengths, max_len=melY.size(1)).unsqueeze(-1)

                # Apply model
                melX_output = self.model(melX) # TODO : code model

                # Losses
                mel_l1_loss, mel_binary_div = self.spec_loss(melX_output, melY, target_mask)
                loss = (1 - self.w) * mel_l1_loss + self.w * mel_binary_div

                # Update
                loss.backward()
                self.optimizer.step()
                # Logs
                self.writer.add_scalar("loss", float(loss.item()), global_step)
                self.writer.add_scalar("mel_l1_loss", float(mel_l1_loss.item()), global_step)
                self.writer.add_scalar("mel_binary_div_loss", float(mel_binary_div.item()), global_step)
                self.writer.add_scalar("learning rate", current_lr, global_step)

                global_step += 1
                running_loss += loss.item()

            if (global_epoch % self.checkpoint_interval == 0):
                self.save_checkpoint(global_step, global_epoch)
            if global_epoch % self.eval_interval == 0:
                self.save_states(global_epoch, melX_output, melX, melY, lengths)
            self.eval_model(global_epoch)
            avg_loss = running_loss / len(self.train_loader)
            self.writer.add_scalar("train loss (per epoch)", avg_loss, global_epoch)
            print("Train Loss: {}".format(avg_loss))
            global_epoch += 1
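Note: the training loops in Examples #6 and #9 call `noam_learning_rate_decay`, which is not defined in these snippets. A sketch of the standard Noam schedule ("Attention Is All You Need") that such a helper usually implements; the warmup length is an assumption:

def noam_learning_rate_decay(init_lr, global_step, warmup_steps=4000):
    # Linear warmup followed by inverse-square-root decay.
    step = global_step + 1
    return init_lr * warmup_steps ** 0.5 * min(step * warmup_steps ** -1.5, step ** -0.5)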
Example #7
def masked_cross_entropy(logits, target, length, per_example=False):
    """
    Args:
        logits (Variable, FloatTensor): [batch, max_len, num_classes]
            - unnormalized probability for each class
        target (Variable, LongTensor): [batch, max_len]
            - index of true class for each corresponding step
        length (Variable, LongTensor): [batch]
            - length of each data in a batch
    Returns:
        If per_example is True: losses summed over time per example, shape [batch].
        Otherwise: a tuple (total masked loss, number of valid tokens); dividing the
        first by the second gives the length-masked average loss per token.
    """
    batch_size, max_len, num_classes = logits.size()

    # [batch_size * max_len, num_classes]
    logits_flat = logits.view(-1, num_classes)

    # [batch_size * max_len, num_classes]
    log_probs_flat = F.log_softmax(logits_flat, dim=1)

    # [batch_size * max_len, 1]
    target_flat = target.view(-1, 1)

    # Negative Log-likelihood: -sum {  1* log P(target)  + 0 log P(non-target)} = -sum( log P(target) )
    # [batch_size * max_len, 1]
    losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)

    # [batch_size, max_len]
    losses = losses_flat.view(batch_size, max_len)

    # [batch_size, max_len]
    mask = sequence_mask(sequence_length=length, max_len=max_len)

    # Apply masking on loss
    losses = losses * mask.float()

    # word-wise cross entropy
    # loss = losses.sum() / length.float().sum()

    if per_example:
        # loss: [batch_size]
        return losses.sum(1)
    else:
        loss = losses.sum()
        return loss, length.float().sum()
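Note: a hedged usage sketch for `masked_cross_entropy` (Example #7) with dummy tensors. The default path returns the summed loss together with the number of valid tokens, so the per-token average is obtained by dividing the two:

import torch

batch, max_len, num_classes = 2, 4, 10
logits = torch.randn(batch, max_len, num_classes)
target = torch.randint(0, num_classes, (batch, max_len))
length = torch.tensor([4, 2])

loss_sum, num_tokens = masked_cross_entropy(logits, target, length)
avg_loss = loss_sum / num_tokens    # cross entropy averaged over valid positions only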
Example #8
    def forward(self, input, length, words, display):
        # input is [bsz, len, 2*nhid]
        # length is [bsz, ]
        # words is [bsz, len]
        bsz, l, nhid2 = input.size()
        mask = sequence_mask(length, max_length=l)  # [bsz, len]
        compressed_embeddings = input.view(-1, nhid2)  # [bsz*len, 2*nhid]

        alphas = self.attender(compressed_embeddings)  # [bsz*len, hop]
        alphas = alphas.view(bsz, l, -1)  # [bsz, len, hop]
        alphas = torch.transpose(alphas, 1, 2).contiguous()  # [bsz, hop, len]
        alphas = self.softmax(alphas.view(-1, l))  # [bsz*hop, len]
        alphas = alphas.view(bsz, -1, l)  # [bsz, hop, len]

        mask = mask.unsqueeze(1).expand(-1, self.att_hops,
                                        -1)  # [bsz, hop, len]
        alphas = alphas * mask.float() + 1e-20
        alphas = alphas / alphas.sum(2, keepdim=True)  # renorm

        info = []
        if display:
            for i in range(bsz):
                s = '\n'
                for j in range(self.att_hops):
                    for k in range(unwrap_scalar_variable(length[i])):
                        s += '%s(%.2f) ' % (
                            self.dictionary.itow(unwrap_scalar_variable(words[i][k])),
                            unwrap_scalar_variable(alphas[i][j][k]))
                    s += '\n\n'
                info.append(s)

        supplements = {
            'attention': alphas,  # [bsz, hop, len]
            'info': info,
        }

        return torch.bmm(alphas, input), supplements  # [bsz, hop, 2*nhid]
Example #9
    def train(self,
              train_seq2seq,
              train_postnet,
              global_epoch=1,
              global_step=0):
        while global_epoch < self.epoch:
            running_loss = 0.
            running_linear_loss = 0.
            running_mel_loss = 0.
            for step, (ling, mel, linear, lengths,
                       speaker_ids) in enumerate(tqdm(self.train_loader)):
                self.model.train()
                ismultispeaker = speaker_ids is not None
                # Learn rate scheduler
                current_lr = noam_learning_rate_decay(
                    self.hparams.initial_learning_rate, global_step)
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = current_lr
                self.optimizer.zero_grad()

                # Transform data to CUDA device
                if train_seq2seq:
                    ling = ling.to(self.device)
                    mel = mel.to(self.device)
                if train_postnet:
                    linear = linear.to(self.device)
                lengths = lengths.to(self.device)
                speaker_ids = speaker_ids.to(
                    self.device) if ismultispeaker else None
                target_mask = sequence_mask(lengths,
                                            max_len=mel.size(1)).unsqueeze(-1)

                # Apply model
                if train_seq2seq and train_postnet:
                    _, mel_outputs, linear_outputs = self.model(
                        ling, mel, speaker_ids=speaker_ids)
                #elif train_seq2seq:
                #    mel_style = self.model.gst(tmel)
                #    style_embed = mel_style.expand_as(smel)
                #    mel_input = smel + style_embed
                #    mel_outputs = self.model.seq2seq(mel_input)
                #    linear_outputs = None
                #elif train_postnet:
                #    linear_outputs = self.model.postnet(smel)
                #    mel_outputs = None

                # Losses
                if train_seq2seq:
                    mel_l1_loss, mel_binary_div = self.spec_loss(
                        mel_outputs, mel, target_mask)
                    mel_loss = (1 - self.w) * mel_l1_loss + self.w * mel_binary_div
                if train_postnet:
                    linear_l1_loss, linear_binary_div = self.spec_loss(
                        linear_outputs, linear, target_mask)
                    linear_loss = (1 - self.w) * linear_l1_loss + self.w * linear_binary_div

                # Combine losses
                if train_seq2seq and train_postnet:
                    loss = mel_loss + linear_loss
                elif train_seq2seq:
                    loss = mel_loss
                elif train_postnet:
                    loss = linear_loss

                # Update
                loss.backward()
                self.optimizer.step()
                # Logs
                if train_seq2seq:
                    self.writer.add_scalar("mel loss", float(mel_loss.item()),
                                           global_step)
                    self.writer.add_scalar("mel_l1_loss",
                                           float(mel_l1_loss.item()),
                                           global_step)
                    self.writer.add_scalar("mel_binary_div_loss",
                                           float(mel_binary_div.item()),
                                           global_step)
                if train_postnet:
                    self.writer.add_scalar("linear_loss",
                                           float(linear_loss.item()),
                                           global_step)
                    self.writer.add_scalar("linear_l1_loss",
                                           float(linear_l1_loss.item()),
                                           global_step)
                    self.writer.add_scalar("linear_binary_div_loss",
                                           float(linear_binary_div.item()),
                                           global_step)
                self.writer.add_scalar("loss", float(loss.item()), global_step)
                self.writer.add_scalar("learning rate", current_lr,
                                       global_step)

                global_step += 1
                running_loss += loss.item()
                running_linear_loss += linear_loss.item()
                running_mel_loss += mel_loss.item()

            if (global_epoch % self.checkpoint_interval == 0):
                self.save_checkpoint(global_step, global_epoch)
            if global_epoch % self.eval_interval == 0:
                self.save_states(global_epoch, mel_outputs, linear_outputs,
                                 ling, mel, linear, lengths)
            self.eval_model(global_epoch, train_seq2seq, train_postnet)
            avg_loss = running_loss / len(self.train_loader)
            avg_linear_loss = running_linear_loss / len(self.train_loader)
            avg_mel_loss = running_mel_loss / len(self.train_loader)
            self.writer.add_scalar("train loss (per epoch)", avg_loss,
                                   global_epoch)
            self.writer.add_scalar("train linear loss (per epoch)",
                                   avg_linear_loss, global_epoch)
            self.writer.add_scalar("train mel loss (per epoch)", avg_mel_loss,
                                   global_epoch)
            print("Train Loss: {}".format(avg_loss))
            global_epoch += 1
Example #10
    def eval_model(self, global_epoch, train_seq2seq, train_postnet):
        happy_ref = np.load('../feat/Acoustic_frame/mel/emc00103.npy')
        happy_ref = torch.from_numpy(happy_ref).unsqueeze(0)
        sad_ref = np.load('../feat/Acoustic_frame/mel/ema00203.npy')
        sad_ref = torch.from_numpy(sad_ref).unsqueeze(0)
        angry_ref = np.load('../feat/Acoustic_frame/mel/eme00303.npy')
        angry_ref = torch.from_numpy(angry_ref).unsqueeze(0)
        running_loss = 0.
        running_linear_loss = 0.
        running_mel_loss = 0.
        for step, (ling, mel, linear, lengths,
                   speaker_ids) in enumerate(self.valid_loader):
            self.model.eval()
            ismultispeaker = speaker_ids is not None
            if train_seq2seq:
                ling = ling.to(self.device)
                mel = mel.to(self.device)
                happy_ref = happy_ref.to(self.device)
                sad_ref = sad_ref.to(self.device)
                angry_ref = angry_ref.to(self.device)
            if train_postnet:
                linear = linear.to(self.device)
            lengths = lengths.to(self.device)
            speaker_ids = speaker_ids.to(
                self.device) if ismultispeaker else None
            target_mask = sequence_mask(lengths,
                                        max_len=mel.size(1)).unsqueeze(-1)
            with torch.no_grad():
                # Apply model
                if train_seq2seq and train_postnet:
                    _, mel_outputs, linear_outputs = self.model(
                        ling, mel, speaker_ids=speaker_ids)
                """
                elif train_seq2seq:
                    mel_style = self.model.gst(tmel)
                    style_embed = mel_style.expand_as(smel)
                    mel_input = smel + style_embed
                    mel_outputs = self.model.seq2seq(mel_input)
                    linear_outputs = None
                elif train_postnet:
                    linear_outputs = self.model.postnet(tmel)
                    mel_outputs = None
                """

            # Losses
            if train_seq2seq:
                mel_l1_loss, mel_binary_div = self.spec_loss(
                    mel_outputs, mel, target_mask)
                mel_loss = (1 - self.w) * mel_l1_loss + self.w * mel_binary_div
            if train_postnet:
                linear_l1_loss, linear_binary_div = self.spec_loss(
                    linear_outputs, linear, target_mask)
                linear_loss = (
                    1 - self.w) * linear_l1_loss + self.w * linear_binary_div

            # Combine losses
            if train_seq2seq and train_postnet:
                loss = mel_loss + linear_loss
            elif train_seq2seq:
                loss = mel_loss
            elif train_postnet:
                loss = linear_loss
            running_loss += loss.item()
            running_linear_loss += linear_loss.item()
            running_mel_loss += mel_loss.item()
        B = ling.size(0)
        if ismultispeaker:
            speaker_ids = np.zeros(B)
            speaker_ids = torch.LongTensor(speaker_ids).to(self.device)
        else:
            speaker_ids = None
        _, happy_mel_outputs, happy_linear_outputs = self.model(
            ling, happy_ref, speaker_ids)
        _, sad_mel_outputs, sad_linear_outputs = self.model(
            ling, sad_ref, speaker_ids)
        _, angry_mel_outputs, angry_linear_outputs = self.model(
            ling, angry_ref, speaker_ids)

        if global_epoch % self.eval_interval == 0:
            for idx in range(B):
                if mel_outputs is not None:
                    happy_mel_output = happy_mel_outputs[idx].cpu().data.numpy()
                    happy_mel_output = prepare_spec_image(
                        audio._denormalize(happy_mel_output))
                    self.writer.add_image(
                        "(Eval) Happy mel spectrogram {}".format(idx),
                        happy_mel_output, global_epoch)

                    sad_mel_output = sad_mel_outputs[idx].cpu().data.numpy()
                    sad_mel_output = prepare_spec_image(
                        audio._denormalize(sad_mel_output))
                    self.writer.add_image(
                        "(Eval) Sad mel spectrogram {}".format(idx),
                        sad_mel_output, global_epoch)

                    angry_mel_output = angry_mel_outputs[idx].cpu().data.numpy()
                    angry_mel_output = prepare_spec_image(
                        audio._denormalize(angry_mel_output))
                    self.writer.add_image(
                        "(Eval) Angry mel spectrogram {}".format(idx),
                        angry_mel_output, global_epoch)

                    mel_output = mel_outputs[idx].cpu().data.numpy()
                    mel_output = prepare_spec_image(
                        audio._denormalize(mel_output))
                    self.writer.add_image(
                        "(Eval) Predicted mel spectrogram {}".format(idx),
                        mel_output, global_epoch)

                    mel1 = mel[idx].cpu().data.numpy()
                    mel1 = prepare_spec_image(audio._denormalize(mel1))
                    self.writer.add_image(
                        "(Eval) Source mel spectrogram {}".format(idx), mel1,
                        global_epoch)

                if linear_outputs is not None:
                    linear_output = linear_outputs[idx].cpu().data.numpy()
                    spectrogram = prepare_spec_image(
                        audio._denormalize(linear_output))
                    self.writer.add_image(
                        "(Eval) Predicted spectrogram {}".format(idx),
                        spectrogram, global_epoch)
                    signal = audio.inv_spectrogram(linear_output.T)
                    signal /= np.max(np.abs(signal))
                    path = join(
                        self.checkpoint_dir,
                        "epoch{:09d}_{}_predicted.wav".format(
                            global_epoch, idx))
                    audio.save_wav(signal, path)
                    try:
                        self.writer.add_audio(
                            "(Eval) Predicted audio signal {}".format(idx),
                            signal,
                            global_epoch,
                            sample_rate=self.fs)
                    except Exception as e:
                        warn(str(e))
                        pass

                    happy_linear_output = happy_linear_outputs[idx].cpu().data.numpy()
                    spectrogram = prepare_spec_image(
                        audio._denormalize(happy_linear_output))
                    self.writer.add_image(
                        "(Eval) Happy spectrogram {}".format(idx), spectrogram,
                        global_epoch)
                    signal = audio.inv_spectrogram(happy_linear_output.T)
                    signal /= np.max(np.abs(signal))
                    path = join(
                        self.checkpoint_dir,
                        "epoch{:09d}_{}_happy.wav".format(global_epoch, idx))
                    audio.save_wav(signal, path)
                    try:
                        self.writer.add_audio(
                            "(Eval) Happy audio signal {}".format(idx),
                            signal,
                            global_epoch,
                            sample_rate=self.fs)
                    except Exception as e:
                        warn(str(e))
                        pass

                    angry_linear_output = angry_linear_outputs[idx].cpu().data.numpy()
                    spectrogram = prepare_spec_image(
                        audio._denormalize(angry_linear_output))
                    self.writer.add_image(
                        "(Eval) Angry spectrogram {}".format(idx), spectrogram,
                        global_epoch)
                    signal = audio.inv_spectrogram(angry_linear_output.T)
                    signal /= np.max(np.abs(signal))
                    path = join(
                        self.checkpoint_dir,
                        "epoch{:09d}_{}_angry.wav".format(global_epoch, idx))
                    audio.save_wav(signal, path)
                    try:
                        self.writer.add_audio(
                            "(Eval) Angry audio signal {}".format(idx),
                            signal,
                            global_epoch,
                            sample_rate=self.fs)
                    except Exception as e:
                        warn(str(e))
                        pass

                    sad_linear_output = sad_linear_outputs[idx].cpu().data.numpy()
                    spectrogram = prepare_spec_image(
                        audio._denormalize(sad_linear_output))
                    self.writer.add_image(
                        "(Eval) Sad spectrogram {}".format(idx), spectrogram,
                        global_epoch)
                    signal = audio.inv_spectrogram(sad_linear_output.T)
                    signal /= np.max(np.abs(signal))
                    path = join(
                        self.checkpoint_dir,
                        "epoch{:09d}_{}_sad.wav".format(global_epoch, idx))
                    audio.save_wav(signal, path)
                    try:
                        self.writer.add_audio(
                            "(Eval) Sad audio signal {}".format(idx),
                            signal,
                            global_epoch,
                            sample_rate=self.fs)
                    except Exception as e:
                        warn(str(e))
                        pass

                    linear1 = linear[idx].cpu().data.numpy()
                    spectrogram = prepare_spec_image(
                        audio._denormalize(linear1))
                    self.writer.add_image(
                        "(Eval) Target spectrogram {}".format(idx),
                        spectrogram, global_epoch)
                    signal = audio.inv_spectrogram(linear1.T)
                    signal /= np.max(np.abs(signal))
                    try:
                        self.writer.add_audio(
                            "(Eval) Target audio signal {}".format(idx),
                            signal,
                            global_epoch,
                            sample_rate=self.fs)
                    except Exception as e:
                        warn(str(e))
                        pass

        avg_loss = running_loss / len(self.valid_loader)
        avg_linear_loss = running_linear_loss / len(self.valid_loader)
        avg_mel_loss = running_mel_loss / len(self.valid_loader)
        self.writer.add_scalar("valid loss (per epoch)", avg_loss,
                               global_epoch)
        self.writer.add_scalar("valid linear loss (per epoch)",
                               avg_linear_loss, global_epoch)
        self.writer.add_scalar("valid mel loss (per epoch)", avg_mel_loss,
                               global_epoch)
        print("Valid Loss: {}".format(avg_loss))
Example #11
    def forward(self, input_seq, decoder_hidden, encoder_outputs, src_lens,
                rate_sents, cate_sents, srclen_cates, senti_sents):
        """ Args:
            - input_seq      : (batch_size)
            - decoder_hidden : (t=0) last encoder hidden state (num_layers * num_directions, batch_size, hidden_size)
                               (t>0) previous decoder hidden state (num_layers, batch_size, hidden_size)
            - encoder_outputs: (max_src_len, batch_size, hidden_size * num_directions)

            Returns:
            - output           : (batch_size, vocab_size)
            - decoder_hidden   : (num_layers, batch_size, hidden_size)
            - attention_weights: (batch_size, max_src_len)
        """
        # (batch_size) => (seq_len=1, batch_size)
        input_seq = input_seq.unsqueeze(0)

        # (seq_len=1, batch_size) => (seq_len=1, batch_size, word_vec_size)
        emb = self.embedding(input_seq)

        # Add external embeddings: (batch_size, feature_size) => (num_layers, batch_size, feature_size)
        if self.ext_rate_embedding:
            ext_rate_embedding = self.ext_rate_embedding(rate_sents)
            ext_rate_embedding = ext_rate_embedding.unsqueeze(0).repeat(
                self.num_layers, 1, 1)
        if self.ext_appcate_embedding:
            ext_appcate_embedding = self.ext_appcate_embedding(cate_sents)
            ext_appcate_embedding = ext_appcate_embedding.unsqueeze(0).repeat(
                self.num_layers, 1, 1)
        if self.ext_seqlen_embedding:
            ext_seqlen_embedding = self.ext_seqlen_embedding(srclen_cates)
            ext_seqlen_embedding = ext_seqlen_embedding.unsqueeze(0).repeat(
                self.num_layers, 1, 1)
        if self.ext_senti_embedding:
            ext_senti_embedding = self.ext_senti_embedding(senti_sents)
            ext_senti_embedding = ext_senti_embedding.unsqueeze(0).repeat(
                self.num_layers, 1, 1)

        # rnn returns:
        # - decoder_output: (seq_len=1, batch_size, hidden_size)
        # - decoder_hidden: (num_layers, batch_size, hidden_size)

        if self.tie_ext_feature:
            if self.ext_rate_embedding:
                decoder_hidden = torch.cat(
                    (decoder_hidden, ext_rate_embedding), 2)
            if self.ext_appcate_embedding:
                decoder_hidden = torch.cat(
                    (decoder_hidden, ext_appcate_embedding), 2)
            if self.ext_seqlen_embedding:
                decoder_hidden = torch.cat(
                    (decoder_hidden, ext_seqlen_embedding), 2)
            if self.ext_senti_embedding:
                decoder_hidden = torch.cat(
                    (decoder_hidden, ext_senti_embedding), 2)
            # decoder_hidden = torch.cat((decoder_hidden, ext_rate_embedding, ext_appcate_embedding, ext_seqlen_embedding, ext_senti_embedding), 2)
        decoder_hidden = F.tanh(self.W_r(decoder_hidden))
        decoder_output, decoder_hidden = self.rnn(emb, decoder_hidden)

        # (seq_len=1, batch_size, hidden_size) => (batch_size, seq_len=1, hidden_size)
        decoder_output = decoder_output.transpose(0, 1)
        """ 
        ------------------------------------------------------------------------------------------
        Notes of computing attention scores
        ------------------------------------------------------------------------------------------
        # For-loop version:

        max_src_len = encoder_outputs.size(0)
        batch_size = encoder_outputs.size(1)
        attention_scores = Variable(torch.zeros(batch_size, max_src_len))

        # For every batch, every time step of encoder's hidden state, calculate attention score.
        for b in range(batch_size):
            for t in range(max_src_len):
                # Loung. eq(8) -- general form content-based attention:
                attention_scores[b,t] = decoder_output[b].dot(attention.W_a(encoder_outputs[t,b]))

        ------------------------------------------------------------------------------------------
        # Vectorized version:

        1. decoder_output: (batch_size, seq_len=1, hidden_size)
        2. encoder_outputs: (max_src_len, batch_size, hidden_size * num_directions)
        3. W_a(encoder_outputs): (max_src_len, batch_size, hidden_size)
                        .transpose(0,1)  : (batch_size, max_src_len, hidden_size) 
                        .transpose(1,2)  : (batch_size, hidden_size, max_src_len)
        4. attention_scores: 
                        (batch_size, seq_len=1, hidden_size) * (batch_size, hidden_size, max_src_len) 
                        => (batch_size, seq_len=1, max_src_len)
        """

        if self.attention:
            # attention_scores: (batch_size, seq_len=1, max_src_len)
            attention_scores = torch.bmm(
                decoder_output,
                self.W_a(encoder_outputs).transpose(0, 1).transpose(1, 2))

            # attention_mask: (batch_size, seq_len=1, max_src_len)
            attention_mask = sequence_mask(src_lens).unsqueeze(1)

            # Fill score positions where the mask is 0 (padding) with `-float('inf')`
            # so they receive zero attention weight after the softmax.
            attention_scores.data.masked_fill_(1 - attention_mask.data,
                                               -float('inf'))

            # attention_weights: (batch_size, seq_len=1, max_src_len) => (batch_size, max_src_len) for `F.softmax`
            # => (batch_size, seq_len=1, max_src_len)
            try:  # torch 0.3.x
                attention_weights = F.softmax(attention_scores.squeeze(1),
                                              dim=1).unsqueeze(1)
            except:
                attention_weights = F.softmax(
                    attention_scores.squeeze(1)).unsqueeze(1)

            # context_vector:
            # (batch_size, seq_len=1, max_src_len) * (batch_size, max_src_len, encoder_hidden_size * num_directions)
            # => (batch_size, seq_len=1, encoder_hidden_size * num_directions)
            context_vector = torch.bmm(attention_weights,
                                       encoder_outputs.transpose(0, 1))

            # concat_input: (batch_size, seq_len=1, encoder_hidden_size * num_directions + decoder_hidden_size)
            concat_input = torch.cat([context_vector, decoder_output], -1)

            # (batch_size, seq_len=1, encoder_hidden_size * num_directions + decoder_hidden_size) => (batch_size, seq_len=1, decoder_hidden_size)
            concat_output = F.tanh(self.W_c(concat_input))

            # Prepare returns:
            # (batch_size, seq_len=1, max_src_len) => (batch_size, max_src_len)
            attention_weights = attention_weights.squeeze(1)
        else:
            attention_weights = None
            concat_output = decoder_output

        # If input and output embeddings are tied,
        # project `decoder_hidden_size` to `word_vec_size`.
        if self.tie_embeddings:
            output = self.W_s(self.W_proj(concat_output))
        else:
            # (batch_size, seq_len=1, decoder_hidden_size) => (batch_size, seq_len=1, vocab_size)
            output = self.W_s(concat_output)

        # Prepare returns:
        # (batch_size, seq_len=1, vocab_size) => (batch_size, vocab_size)
        output = output.squeeze(1)

        del src_lens

        return output, decoder_hidden, attention_weights
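Note: the masking in Example #11 (`masked_fill_(1 - attention_mask.data, -float('inf'))`) targets 0.3.x-era PyTorch, where masks were ByteTensors and `Variable.data` was still in use. On current PyTorch, where a `sequence_mask` like the sketch above returns a bool tensor, an equivalent (assumed, not from the original code) would be:

import torch
import torch.nn.functional as F

src_lens = torch.tensor([5, 3])
attention_scores = torch.randn(2, 1, 5)                    # (batch, seq_len=1, max_src_len)
attention_mask = sequence_mask(src_lens, 5).unsqueeze(1)   # bool, True at valid positions

attention_scores = attention_scores.masked_fill(~attention_mask, float('-inf'))
attention_weights = F.softmax(attention_scores, dim=-1)    # padded positions get weight 0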