Exemple #1
0
    def __init__(self, idim, adim, aheads, eunits, depth, use_scaled_pos_enc,
                 use_masking, normalize_before, concat_after, pointwise_layer,
                 conv_kernel):
        """Build the mel encoder: a 1x1 ``Conv1d`` input layer plus an ``Encoder``.

        Args:
            idim: Number of input channels of the ``Conv1d`` input layer; also
                forwarded to ``Encoder`` as ``idim``.
            adim: Attention dimension; also the ``Conv1d`` output channels.
            aheads: Number of attention heads.
            eunits: Forwarded to ``Encoder`` as ``linear_units``.
            depth: Forwarded to ``Encoder`` as ``num_blocks``.
            use_scaled_pos_enc: When true ``ScaledPositionalEncoding`` is used,
                otherwise ``PositionalEncoding``.
            use_masking: Stored on the instance; not used by this constructor.
            normalize_before: Forwarded to ``Encoder`` unchanged.
            concat_after: Forwarded to ``Encoder`` unchanged.
            pointwise_layer: Forwarded as ``positionwise_layer_type``.
            conv_kernel: Forwarded as ``positionwise_conv_kernel_size``.
        """
        super(MelEncoder, self).__init__()

        self.use_scaled_pos_enc = use_scaled_pos_enc
        self.use_masking = use_masking

        # get positional encoding class
        pos_enc_class = (ScaledPositionalEncoding
                         if self.use_scaled_pos_enc else PositionalEncoding)

        # kernel size 1: a per-frame projection from idim to adim channels.
        # (The former ``padding_idx`` local was never used - a Conv1d input
        # layer has no padding index - so it has been dropped.)
        input_layer = torch.nn.Conv1d(idim, adim, 1)
        self.encoder = Encoder(
            idim=idim,
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=eunits,
            num_blocks=depth,
            input_layer=input_layer,
            dropout_rate=0.2,
            positional_dropout_rate=0.2,
            attention_dropout_rate=0.2,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
            concat_after=concat_after,
            positionwise_layer_type=pointwise_layer,
            positionwise_conv_kernel_size=conv_kernel,
        )
def main(opt):
    """Evaluate a saved decoder on the test split and write its predictions.

    Options missing from ``opt`` are filled in from the options stored in the
    infos pickle; options present in both must agree, except for those listed
    in ``ignore``. When ``opt.image_folder`` is empty the model is evaluated
    through ``eval_utils.eval_split`` and the predictions are dumped as JSON
    into ``opt.checkpoint_path``.

    Args:
        opt: Namespace-like options object. Reads ``infos_path``, ``input_h5``,
            ``input_json``, ``batch_size``, ``id``, ``model``, ``image_folder``,
            ``checkpoint_path`` and ``split``; it is mutated in place.

    Raises:
        AssertionError: If an option differs between ``opt`` and the saved
            options and is not in the ignore list.
    """
    # NOTE(review): pickle.load can execute arbitrary code from the file;
    # only point ``opt.infos_path`` at trusted checkpoints.
    with open(opt.infos_path, 'rb') as f:
        infos = pickle.load(f)
    saved_opt = infos['opt']

    # override and collect parameters
    if len(opt.input_h5) == 0:
        opt.input_h5 = saved_opt.input_h5

    if len(opt.input_json) == 0:
        opt.input_json = saved_opt.input_json

    if opt.batch_size == 0:
        opt.batch_size = saved_opt.batch_size

    if len(opt.id) == 0:
        opt.id = saved_opt.id

    # 'start_from' was previously misspelled 'strat_from', which made the
    # consistency check below reject runs whose start_from differed.
    ignore = ['id', 'batch_size', 'beam_size', 'start_from', 'language_eval']

    for key, value in vars(saved_opt).items():
        if key in ignore:
            continue
        if key in vars(opt):
            # raised explicitly (instead of ``assert``) so the check survives
            # ``python -O``; the exception type is unchanged
            if not (vars(opt)[key] == value):
                raise AssertionError(key + " option not consistent")
        else:
            # copy over options that only exist in the saved model
            vars(opt).update({key: value})

    vocab = infos['vocab']
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    encoder = Encoder().to(device)
    decoder = Decoder(opt).to(device)
    # only the decoder weights are restored from ``opt.model``
    decoder.load_state_dict(torch.load(opt.model, map_location=str(device)))
    encoder.eval()
    decoder.eval()
    criterion = utils.LanguageModelCriterion().to(device)

    # NOTE(review): nothing happens when ``opt.image_folder`` is non-empty -
    # confirm whether a raw-image evaluation branch is missing.
    if len(opt.image_folder) == 0:
        # NOTE(review): ``loader`` is not handed to eval_split; presumably
        # eval_split builds its own loader - confirm.
        loader = get_loader(opt, 'test')
        loader.ix_to_word = vocab
        loss, split_predictions, lang_stats = eval_utils.eval_split(
            encoder, decoder, criterion, opt, vars(opt))
        print('loss: ', loss)
        print(lang_stats)

        result_name = ("captions_" + opt.split + "2014_" + opt.id +
                       "_results.json")
        result_json_path = os.path.join(opt.checkpoint_path, result_name)
        with open(result_json_path, "w") as f:
            json.dump(split_predictions, f)
Exemple #3
0
    def send(self, sock, msg):
        """Send ``msg`` to the peer connected on ``sock``.

        The peer address (from ``sock.getpeername()``) and this node's own
        ``ip_addr``/``port`` are handed to ``make_header`` together with
        ``msg``; the result is passed through ``Encoder.json_decode`` and
        written to the socket. Socket errors are logged, not raised.

        Args:
            sock: Connected socket to write to.
            msg: Message payload forwarded to ``make_header``.
        """
        try:
            # getpeername() yields (host, port) for IPv4 but a 4-tuple for
            # IPv6; only the first two fields are needed here.
            peer_ip, peer_port = sock.getpeername()[:2]
            to_send = {
                "peer_ip": peer_ip,
                "peer_port": peer_port,
                "ip": self.ip_addr,
                "port": self.port
            }

            data = self.make_header(msg, to_send)
            # NOTE(review): calling json_decode on the outgoing path looks
            # inverted relative to its name - confirm against Encoder.
            data = Encoder.json_decode(data)
            # sendall() keeps writing until the whole payload is out, whereas
            # send() may transmit only part of it.
            sock.sendall(data)
        except socket.error:
            Logger.log_error("cant send to peer")
Exemple #4
0
 def handle(self):
     """Read one chunk from the request, convert it and log the result.

     Up to ``BUFFER_SIZE`` bytes are received, surrounding whitespace is
     stripped, and the output of ``Encoder.json_encode`` is kept on
     ``self.data``.
     """
     received = self.request.recv(BUFFER_SIZE)
     self.data = Encoder.json_encode(received.strip())
     # logged for tests
     Logger.log_info(self.data)
Exemple #5
0
    def __init__(self, idim: int, odim: int, hp: Dict):
        """Initialize feed-forward Transformer module.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
            hp (Dict): Hyperparameter container; it is read through attribute
                access on its ``model``, ``data`` and ``audio`` sections.
        """
        # argument type check comes first, then the torch.nn.Module set-up
        assert check_argument_types()
        torch.nn.Module.__init__(self)

        # store hyperparameters
        self.idim = idim
        self.odim = odim

        model_hp = hp.model
        self.use_scaled_pos_enc = model_hp.use_scaled_pos_enc
        self.use_masking = model_hp.use_masking

        # choose the positional encoding class
        if self.use_scaled_pos_enc:
            pos_enc_class = ScaledPositionalEncoding
        else:
            pos_enc_class = PositionalEncoding

        # keyword arguments shared by the encoder and the decoder stacks
        transformer_kwargs = dict(
            attention_heads=model_hp.aheads,
            dropout_rate=0.2,
            positional_dropout_rate=0.2,
            attention_dropout_rate=0.2,
            pos_enc_class=pos_enc_class,
            positionwise_layer_type=model_hp.positionwise_layer_type,
            positionwise_conv_kernel_size=model_hp.
            positionwise_conv_kernel_size,
        )

        # define encoder; index 0 of the embedding table is the padding index
        token_embedding = torch.nn.Embedding(num_embeddings=idim,
                                             embedding_dim=model_hp.adim,
                                             padding_idx=0)
        self.encoder = Encoder(
            idim=idim,
            attention_dim=model_hp.adim,
            linear_units=model_hp.eunits,
            num_blocks=model_hp.elayers,
            input_layer=token_embedding,
            normalize_before=model_hp.encoder_normalize_before,
            concat_after=model_hp.encoder_concat_after,
            **transformer_kwargs)

        # the duration, energy and pitch predictors are all sized from the
        # duration_predictor_* hyperparameters
        predictor_kwargs = dict(
            idim=model_hp.adim,
            n_layers=model_hp.duration_predictor_layers,
            n_chans=model_hp.duration_predictor_chans,
            kernel_size=model_hp.duration_predictor_kernel_size,
            dropout_rate=model_hp.duration_predictor_dropout_rate,
        )

        self.duration_predictor = DurationPredictor(**predictor_kwargs)

        self.energy_predictor = EnergyPredictor(min=hp.data.e_min,
                                                max=hp.data.e_max,
                                                **predictor_kwargs)
        self.energy_embed = torch.nn.Linear(model_hp.adim, model_hp.adim)

        self.pitch_predictor = PitchPredictor(min=hp.data.p_min,
                                              max=hp.data.p_max,
                                              **predictor_kwargs)
        self.pitch_embed = torch.nn.Linear(model_hp.adim, model_hp.adim)

        # define length regulator
        self.length_regulator = LengthRegulator()

        # AdaSpeech-specific modules
        n_mels = hp.audio.n_mels
        self.utterance_encoder = UtteranceEncoder(idim=n_mels)
        self.phoneme_level_encoder = PhonemeLevelEncoder(idim=n_mels)
        self.phoneme_level_predictor = PhonemeLevelPredictor(
            idim=model_hp.adim)
        self.phone_level_embed = torch.nn.Linear(model_hp.phn_latent_dim,
                                                 model_hp.adim)
        self.acoustic_criterion = AcousticPredictorLoss()

        # define decoder
        # NOTE: the Encoder class doubles as the decoder because fastspeech's
        # decoder is the same as its encoder
        self.decoder = Encoder(
            idim=model_hp.adim,
            attention_dim=model_hp.ddim,
            linear_units=model_hp.dunits,
            num_blocks=model_hp.dlayers,
            input_layer="linear",
            normalize_before=model_hp.decoder_normalize_before,
            concat_after=model_hp.decoder_concat_after,
            **transformer_kwargs)

        # define postnet (skipped entirely when postnet_layers is 0)
        if model_hp.postnet_layers == 0:
            self.postnet = None
        else:
            self.postnet = Postnet(
                idim=idim,
                odim=odim,
                n_layers=model_hp.postnet_layers,
                n_chans=model_hp.postnet_chans,
                n_filts=model_hp.postnet_filts,
                use_batch_norm=model_hp.use_batch_norm,
                dropout_rate=model_hp.postnet_dropout_rate,
            )

        # define final projection
        self.feat_out = torch.nn.Linear(model_hp.ddim,
                                        odim * model_hp.reduction_factor)

        # initialize parameters
        self._reset_parameters(
            init_type=model_hp.transformer_init,
            init_enc_alpha=model_hp.initial_encoder_alpha,
            init_dec_alpha=model_hp.initial_decoder_alpha,
        )

        # define criterions
        self.duration_criterion = DurationPredictorLoss()
        self.energy_criterion = EnergyPredictorLoss()
        self.pitch_criterion = PitchPredictorLoss()
        self.criterion = torch.nn.L1Loss(reduction="mean")
        self.use_weighted_masking = model_hp.use_weighted_masking
Exemple #6
0
    def __init__(self, idim, odim):
        """Initialize feed-forward Transformer module.

        All remaining hyperparameters are read from the module-level ``hp``
        object.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.

        Raises:
            NotImplementedError: If ``hp.reduction_factor`` is not 1.

        """
        # initialize base classes
        torch.nn.Module.__init__(self)

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.reduction_factor = hp.reduction_factor
        self.use_scaled_pos_enc = hp.use_scaled_pos_enc
        self.use_masking = hp.use_masking

        # TODO(kan-bayashi): support reduction_factor > 1
        if self.reduction_factor != 1:
            raise NotImplementedError("Support only reduction_factor = 1.")

        # use idx 0 as padding idx
        padding_idx = 0

        # get positional encoding class
        pos_enc_class = ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding

        # define encoder
        encoder_input_layer = torch.nn.Embedding(num_embeddings=idim,
                                                 embedding_dim=hp.adim,
                                                 padding_idx=padding_idx)
        self.encoder = Encoder(
            idim=idim,
            attention_dim=hp.adim,
            attention_heads=hp.aheads,
            linear_units=hp.eunits,
            num_blocks=hp.elayers,
            input_layer=encoder_input_layer,
            dropout_rate=0.2,
            positional_dropout_rate=0.2,
            attention_dropout_rate=0.2,
            pos_enc_class=pos_enc_class,
            normalize_before=hp.encoder_normalize_before,
            concat_after=hp.encoder_concat_after,
            positionwise_layer_type=hp.positionwise_layer_type,
            positionwise_conv_kernel_size=hp.positionwise_conv_kernel_size)

        self.duration_predictor = DurationPredictor(
            idim=hp.adim,
            n_layers=hp.duration_predictor_layers,
            n_chans=hp.duration_predictor_chans,
            kernel_size=hp.duration_predictor_kernel_size,
            dropout_rate=hp.duration_predictor_dropout_rate,
        )

        # energy and pitch predictors reuse the duration_predictor_* settings
        self.energy_predictor = EnergyPredictor(
            idim=hp.adim,
            n_layers=hp.duration_predictor_layers,
            n_chans=hp.duration_predictor_chans,
            kernel_size=hp.duration_predictor_kernel_size,
            dropout_rate=hp.duration_predictor_dropout_rate,
        )
        self.energy_embed = torch.nn.Linear(hp.adim, hp.adim)

        self.pitch_predictor = PitchPredictor(
            idim=hp.adim,
            n_layers=hp.duration_predictor_layers,
            n_chans=hp.duration_predictor_chans,
            kernel_size=hp.duration_predictor_kernel_size,
            dropout_rate=hp.duration_predictor_dropout_rate,
        )
        self.pitch_embed = torch.nn.Linear(hp.adim, hp.adim)

        # define length regulator
        self.length_regulator = LengthRegulator()

        # define decoder
        # NOTE: we use encoder as decoder because fastspeech's decoder is the same as encoder
        # The decoder consumes the adim-wide encoder states and feeds the
        # adim-wide ``feat_out`` projection below, so it is sized from
        # hp.adim rather than a hard-coded 256 (identical when adim == 256).
        self.decoder = Encoder(
            idim=hp.adim,
            attention_dim=hp.adim,
            attention_heads=hp.aheads,
            linear_units=hp.dunits,
            num_blocks=hp.dlayers,
            input_layer=None,
            dropout_rate=0.2,
            positional_dropout_rate=0.2,
            attention_dropout_rate=0.2,
            pos_enc_class=pos_enc_class,
            normalize_before=hp.decoder_normalize_before,
            concat_after=hp.decoder_concat_after,
            positionwise_layer_type=hp.positionwise_layer_type,
            positionwise_conv_kernel_size=hp.positionwise_conv_kernel_size)

        # define postnet
        self.postnet = (None if hp.postnet_layers == 0 else Postnet(
            idim=idim,
            odim=odim,
            n_layers=hp.postnet_layers,
            n_chans=hp.postnet_chans,
            n_filts=hp.postnet_filts,
            use_batch_norm=hp.use_batch_norm,
            dropout_rate=hp.postnet_dropout_rate,
        ))

        # define final projection
        self.feat_out = torch.nn.Linear(hp.adim, odim * hp.reduction_factor)

        # initialize parameters
        self._reset_parameters(init_type=hp.transformer_init,
                               init_enc_alpha=hp.initial_encoder_alpha,
                               init_dec_alpha=hp.initial_decoder_alpha)

        # define criterions
        self.duration_criterion = DurationPredictorLoss()
        self.energy_criterion = EnergyPredictorLoss()
        self.pitch_criterion = PitchPredictorLoss()
        self.criterion = torch.nn.L1Loss(reduction='mean')
Exemple #7
0
class FeedForwardTransformer(torch.nn.Module):
    """Feed Forward Transformer for TTS a.k.a. FastSpeech.

    This is a module of FastSpeech, feed-forward Transformer with duration predictor described in
    `FastSpeech: Fast, Robust and Controllable Text to Speech`_, which does not require any auto-regressive
    processing during inference, resulting in fast decoding compared with auto-regressive Transformer.

    .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
        https://arxiv.org/pdf/1905.09263.pdf

    """
    def __init__(self, idim, odim):
        """Initialize feed-forward Transformer module.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.

        """
        # initialize base classes

        torch.nn.Module.__init__(self)

        # fill missing arguments

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.reduction_factor = hp.reduction_factor
        self.use_scaled_pos_enc = hp.use_scaled_pos_enc
        self.use_masking = hp.use_masking

        # TODO(kan-bayashi): support reduction_factor > 1
        if self.reduction_factor != 1:
            raise NotImplementedError("Support only reduction_factor = 1.")

        # use idx 0 as padding idx
        padding_idx = 0

        # get positional encoding class
        pos_enc_class = ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding

        # define encoder
        encoder_input_layer = torch.nn.Embedding(num_embeddings=idim,
                                                 embedding_dim=hp.adim,
                                                 padding_idx=padding_idx)
        self.encoder = Encoder(
            idim=idim,
            attention_dim=hp.adim,
            attention_heads=hp.aheads,
            linear_units=hp.eunits,
            num_blocks=hp.elayers,
            input_layer=encoder_input_layer,
            dropout_rate=0.2,
            positional_dropout_rate=0.2,
            attention_dropout_rate=0.2,
            pos_enc_class=pos_enc_class,
            normalize_before=hp.encoder_normalize_before,
            concat_after=hp.encoder_concat_after,
            positionwise_layer_type=hp.positionwise_layer_type,
            positionwise_conv_kernel_size=hp.positionwise_conv_kernel_size)

        self.duration_predictor = DurationPredictor(
            idim=hp.adim,
            n_layers=hp.duration_predictor_layers,
            n_chans=hp.duration_predictor_chans,
            kernel_size=hp.duration_predictor_kernel_size,
            dropout_rate=hp.duration_predictor_dropout_rate,
        )

        self.energy_predictor = EnergyPredictor(
            idim=hp.adim,
            n_layers=hp.duration_predictor_layers,
            n_chans=hp.duration_predictor_chans,
            kernel_size=hp.duration_predictor_kernel_size,
            dropout_rate=hp.duration_predictor_dropout_rate,
        )
        self.energy_embed = torch.nn.Linear(hp.adim, hp.adim)

        self.pitch_predictor = PitchPredictor(
            idim=hp.adim,
            n_layers=hp.duration_predictor_layers,
            n_chans=hp.duration_predictor_chans,
            kernel_size=hp.duration_predictor_kernel_size,
            dropout_rate=hp.duration_predictor_dropout_rate,
        )
        self.pitch_embed = torch.nn.Linear(hp.adim, hp.adim)

        # define length regulator
        self.length_regulator = LengthRegulator()

        # define decoder
        # NOTE: we use encoder as decoder because fastspeech's decoder is the same as encoder
        self.decoder = Encoder(
            idim=256,
            attention_dim=256,
            attention_heads=hp.aheads,
            linear_units=hp.dunits,
            num_blocks=hp.dlayers,
            input_layer=None,
            dropout_rate=0.2,
            positional_dropout_rate=0.2,
            attention_dropout_rate=0.2,
            pos_enc_class=pos_enc_class,
            normalize_before=hp.decoder_normalize_before,
            concat_after=hp.decoder_concat_after,
            positionwise_layer_type=hp.positionwise_layer_type,
            positionwise_conv_kernel_size=hp.positionwise_conv_kernel_size)

        # define postnet
        self.postnet = (None if hp.postnet_layers == 0 else Postnet(
            idim=idim,
            odim=odim,
            n_layers=hp.postnet_layers,
            n_chans=hp.postnet_chans,
            n_filts=hp.postnet_filts,
            use_batch_norm=hp.use_batch_norm,
            dropout_rate=hp.postnet_dropout_rate,
        ))

        # define final projection
        self.feat_out = torch.nn.Linear(hp.adim, odim * hp.reduction_factor)

        # initialize parameters
        self._reset_parameters(init_type=hp.transformer_init,
                               init_enc_alpha=hp.initial_encoder_alpha,
                               init_dec_alpha=hp.initial_decoder_alpha)

        # define criterions
        self.duration_criterion = DurationPredictorLoss()
        self.energy_criterion = EnergyPredictorLoss()
        self.pitch_criterion = PitchPredictorLoss()
        self.criterion = torch.nn.L1Loss(reduction='mean')

    def _forward(self,
                 xs,
                 ilens,
                 ys=None,
                 olens=None,
                 ds=None,
                 es=None,
                 ps=None,
                 is_inference=False):
        # forward encoder
        x_masks = self._source_mask(ilens)
        hs, _ = self.encoder(xs, x_masks)  # (B, Tmax, adim)
        # print("Ys :",ys.shape)     torch.Size([32, 868, 80])

        # forward duration predictor and length regulator
        d_masks = make_pad_mask(ilens).to(xs.device)

        if is_inference:
            d_outs = self.duration_predictor.inference(hs,
                                                       d_masks)  # (B, Tmax)
            hs = self.length_regulator(hs, d_outs, ilens)  # (B, Lmax, adim)
            e_outs = self.energy_predictor.inference(hs)
            p_outs = self.pitch_predictor.inference(hs)
            one_hot_energy = energy_to_one_hot(e_outs,
                                               False)  # (B, Lmax, adim)
            one_hot_pitch = pitch_to_one_hot(p_outs, False)  # (B, Lmax, adim)
        else:
            with torch.no_grad():
                # ds = self.duration_calculator(xs, ilens, ys, olens)  # (B, Tmax)
                one_hot_energy = energy_to_one_hot(
                    es)  # (B, Lmax, adim)   torch.Size([32, 868, 256])
                # print("one_hot_energy:", one_hot_energy.shape)
                one_hot_pitch = pitch_to_one_hot(
                    ps)  # (B, Lmax, adim)   torch.Size([32, 868, 256])
                # print("one_hot_pitch:", one_hot_pitch.shape)
            mel_masks = make_pad_mask(olens).to(xs.device)
            # print("Before Hs:", hs.shape)  torch.Size([32, 121, 256])
            d_outs = self.duration_predictor(hs, d_masks)  # (B, Tmax)
            # print("d_outs:", d_outs.shape)        torch.Size([32, 121])
            hs = self.length_regulator(hs, ds, ilens)  # (B, Lmax, adim)
            # print("After Hs:",hs.shape)  torch.Size([32, 868, 256])
            e_outs = self.energy_predictor(hs, mel_masks)
            # print("e_outs:", e_outs.shape)  torch.Size([32, 868])
            p_outs = self.pitch_predictor(hs, mel_masks)
            # print("p_outs:", p_outs.shape)   torch.Size([32, 868])
        hs = hs + self.pitch_embed(one_hot_pitch)  # (B, Lmax, adim)
        hs = hs + self.energy_embed(one_hot_energy)  # (B, Lmax, adim)
        # forward decoder
        if olens is not None:
            h_masks = self._source_mask(olens)
        else:
            h_masks = None
        zs, _ = self.decoder(hs, h_masks)  # (B, Lmax, adim)
        before_outs = self.feat_out(zs).view(zs.size(0), -1,
                                             self.odim)  # (B, Lmax, odim)

        # postnet -> (B, Lmax//r * r, odim)
        if self.postnet is None:
            after_outs = before_outs
        else:
            after_outs = before_outs + self.postnet(before_outs.transpose(
                1, 2)).transpose(1, 2)

        if is_inference:
            return before_outs, after_outs, d_outs
        else:
            return before_outs, after_outs, d_outs, e_outs, p_outs

    def forward(self, xs, ilens, ys, olens, ds, es, ps, *args, **kwargs):
        """Calculate forward propagation.

        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            olens (LongTensor): Batch of the lengths of each target (B,).
            spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim).

        Returns:
            Tensor: Loss value.

        """
        # remove unnecessary padded part (for multi-gpus)
        xs = xs[:, :max(ilens)]
        ys = ys[:, :max(olens)]

        # forward propagation
        before_outs, after_outs, d_outs, e_outs, p_outs = self._forward(
            xs, ilens, ys, olens, ds, es, ps, is_inference=False)

        # modifiy mod part of groundtruth
        if hp.reduction_factor > 1:
            olens = olens.new(
                [olen - olen % self.reduction_factor for olen in olens])
            max_olen = max(olens)
            ys = ys[:, :max_olen]

        # apply mask to remove padded part
        if self.use_masking:
            in_masks = make_non_pad_mask(ilens).to(xs.device)
            d_outs = d_outs.masked_select(in_masks)
            ds = ds.masked_select(in_masks)
            out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
            mel_masks = make_non_pad_mask(olens).to(ys.device)
            before_outs = before_outs.masked_select(out_masks)
            es = es.masked_select(mel_masks)  # Write size
            ps = ps.masked_select(mel_masks)  # Write size
            e_outs = e_outs.masked_select(mel_masks)  # Write size
            p_outs = p_outs.masked_select(mel_masks)  # Write size
            after_outs = (after_outs.masked_select(out_masks)
                          if after_outs is not None else None)
            ys = ys.masked_select(out_masks)

        # calculate loss
        before_loss = self.criterion(before_outs, ys)
        after_loss = 0
        if after_outs is not None:
            after_loss = self.criterion(after_outs, ys)
            l1_loss = before_loss + after_loss
        duration_loss = self.duration_criterion(d_outs, ds)
        energy_loss = self.energy_criterion(e_outs, es)
        pitch_loss = self.pitch_criterion(p_outs, ps)

        # make weighted mask and apply it
        if hp.use_weighted_masking:
            out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
            out_weights = out_masks.float() / out_masks.sum(
                dim=1, keepdim=True).float()
            out_weights /= ys.size(0) * ys.size(2)
            duration_masks = make_non_pad_mask(ilens).to(ys.device)
            duration_weights = (
                duration_masks.float() /
                duration_masks.sum(dim=1, keepdim=True).float())
            duration_weights /= ds.size(0)

            # apply weight
            l1_loss = l1_loss.mul(out_weights).masked_select(out_masks).sum()
            duration_loss = (duration_loss.mul(duration_weights).masked_select(
                duration_masks).sum())

        loss = l1_loss + duration_loss + energy_loss + pitch_loss
        report_keys = [
            {
                "l1_loss": l1_loss.item()
            },
            {
                "before_loss": before_loss.item()
            },
            {
                "after_loss": after_loss.item()
            },
            {
                "duration_loss": duration_loss.item()
            },
            {
                "energy_loss": energy_loss.item()
            },
            {
                "pitch_loss": pitch_loss.item()
            },
            {
                "loss": loss.item()
            },
        ]

        # report extra information
        if self.use_scaled_pos_enc:
            report_keys += [
                {
                    "encoder_alpha": self.encoder.embed[-1].alpha.data.item()
                },
                {
                    "decoder_alpha": self.decoder.embed[-1].alpha.data.item()
                },
            ]
        #self.reporter.report(report_keys)

        return loss, report_keys

    def calculate_all_attentions(self, xs, ilens, ys, olens, ds, es, ps, *args,
                                 **kwargs):
        """Calculate all of the attention weights.

        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            olens (LongTensor): Batch of the lengths of each target (B,).
            spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim).

        Returns:
            dict: Dict of attention weights and outputs.

        """
        with torch.no_grad():
            # remove unnecessary padded part (for multi-gpus)
            xs = xs[:, :max(ilens)]
            ys = ys[:, :max(olens)]

            # forward propagation
            outs, _, _, _, _ = self._forward(xs,
                                             ilens,
                                             ys,
                                             olens,
                                             ds,
                                             es,
                                             ps,
                                             is_inference=False)

        att_ws_dict = dict()
        if hp.attn_plot:
            for name, m in self.named_modules():
                if isinstance(m, MultiHeadedAttention):
                    attn = m.attn.cpu().numpy()
                    if "encoder" in name:
                        attn = [
                            a[:, :l, :l] for a, l in zip(attn, ilens.tolist())
                        ]
                    elif "decoder" in name:
                        if "src" in name:
                            attn = [
                                a[:, :ol, :il] for a, il, ol in zip(
                                    attn, ilens.tolist(), olens.tolist())
                            ]
                        elif "self" in name:
                            attn = [
                                a[:, :l, :l]
                                for a, l in zip(attn, olens.tolist())
                            ]
                        else:
                            logging.warning("unknown attention module: " +
                                            name)
                    else:
                        logging.warning("unknown attention module: " + name)
                    att_ws_dict[name] = attn
        att_ws_dict["predicted_fbank"] = [
            m[:l].T for m, l in zip(outs.cpu().numpy(), olens.tolist())
        ]

        return att_ws_dict

    def inference(self, x, inference_args, *args, **kwargs):
        """Generate the sequence of features given the sequences of characters.

        Args:
            x (Tensor): Input sequence of characters (T,).
            inference_args (Namespace): Dummy for compatibility.
            spemb (Tensor, optional): Speaker embedding vector (spk_embed_dim).

        Returns:
            Tensor: Output sequence of features (1, L, odim).
            None: Dummy for compatibility.
            None: Dummy for compatibility.

        """
        # setup batch axis
        ilens = torch.tensor([x.shape[0]], dtype=torch.long, device=x.device)
        xs = x.unsqueeze(0)

        # inference
        _, outs, _ = self._forward(xs, ilens, is_inference=True)  # (L, odim)

        return outs[0], None, None

    def _source_mask(self, ilens):
        """Make masks for self-attention.

        Examples:
            >>> ilens = [5, 3]
            >>> self._source_mask(ilens)
            tensor([[[1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1]],
                    [[1, 1, 1, 0, 0],
                     [1, 1, 1, 0, 0],
                     [1, 1, 1, 0, 0],
                     [0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0]]], dtype=torch.uint8)

        """
        x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
        return x_masks.unsqueeze(-2) & x_masks.unsqueeze(-1)

    def _reset_parameters(self,
                          init_type,
                          init_enc_alpha=1.0,
                          init_dec_alpha=1.0):
        # initialize parameters
        initialize(self, init_type)

        # initialize alpha in scaled positional encoding
        if self.use_scaled_pos_enc:
            self.encoder.embed[-1].alpha.data = torch.tensor(init_enc_alpha)
            self.decoder.embed[-1].alpha.data = torch.tensor(init_dec_alpha)

    def _transfer_from_teacher(self, transferred_encoder_module):
        """Copy encoder weights from `self.teacher` into this model's encoder.

        Args:
            transferred_encoder_module (str): Which part to copy. "all" copies
                every parameter of the encoder; "embed" copies only the weight
                of `encoder.embed[0]`.

        Raises:
            AssertionError: If the student and teacher encoders differ in
                parameter count, parameter names or parameter shapes.
            NotImplementedError: If `transferred_encoder_module` is neither
                "all" nor "embed".

        """
        if transferred_encoder_module == "all":
            student_params = list(self.encoder.named_parameters())
            teacher_params = list(self.teacher.encoder.named_parameters())
            # zip() stops at the shorter sequence, so without this check a
            # teacher with extra or missing parameters would be copied
            # partially without any error.
            assert len(student_params) == len(teacher_params), \
                "It seems that encoder structure is different."
            # Validate every pair before copying anything so that a mismatch
            # never leaves the student encoder half-overwritten.
            for (n1, p1), (n2, p2) in zip(student_params, teacher_params):
                assert n1 == n2, "It seems that encoder structure is different."
                assert p1.shape == p2.shape, "It seems that encoder size is different."
            for (_, p1), (_, p2) in zip(student_params, teacher_params):
                p1.data.copy_(p2.data)
        elif transferred_encoder_module == "embed":
            student_weight = self.encoder.embed[0].weight
            teacher_weight = self.teacher.encoder.embed[0].weight
            assert student_weight.data.shape == teacher_weight.data.shape, \
                "It seems that embed dimension is different."
            student_weight.data.copy_(teacher_weight.data)
        else:
            raise NotImplementedError("Support only all or embed.")

    @property
    def attention_plot_class(self):
        """Return the plot class used for attention weight plots.

        Returns:
            `TTSPlot` itself; the property does not instantiate it.

        """
        return TTSPlot

    @property
    def base_plot_keys(self):
        """Return the base key names to plot during training.

        The keys should match what `chainer.reporter` reports. For example,
        adding the key `loss` makes the reporter report the `main/loss` and
        `validation/main/loss` values, and a figure `loss.png` visualizing
        both values is created.

        Returns:
            list: List of strings which are base keys to plot during training.
                The two alpha keys are appended only when
                `self.use_scaled_pos_enc` is truthy.

        """
        alpha_keys = (["encoder_alpha", "decoder_alpha"]
                      if self.use_scaled_pos_enc else [])
        return ["loss", "l1_loss", "duration_loss"] + alpha_keys
Exemple #8
0
    def __init__(self, idim, odim, args=None):
        """Initialize TTS-Transformer module.

        Builds the encoder, decoder, output projections, optional postnet and
        loss functions. All hyperparameters are read from the `hp` object
        defined outside this method.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
            args: Not used in this method; kept in the signature.

        """
        # initialize base classes
        torch.nn.Module.__init__(self)


        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.use_scaled_pos_enc = hp.use_scaled_pos_enc
        self.reduction_factor = hp.reduction_factor
        # loss type and guided attention are hard-coded here, not read from hp
        self.loss_type = "L1"
        self.use_guided_attn_loss = True
        if self.use_guided_attn_loss:
            # a value of -1 falls back to hp.elayers / hp.aheads respectively
            if hp.num_layers_applied_guided_attn == -1:
                self.num_layers_applied_guided_attn = hp.elayers
            else:
                self.num_layers_applied_guided_attn = hp.num_layers_applied_guided_attn
            if hp.num_heads_applied_guided_attn == -1:
                self.num_heads_applied_guided_attn = hp.aheads
            else:
                self.num_heads_applied_guided_attn = hp.num_heads_applied_guided_attn
            self.modules_applied_guided_attn = hp.modules_applied_guided_attn

        # use idx 0 as padding idx
        padding_idx = 0

        # get positional encoding class
        pos_enc_class = ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding


        # token embedding used as the encoder's input layer
        encoder_input_layer = torch.nn.Embedding(
            num_embeddings=idim,
            embedding_dim=hp.adim,
            padding_idx=padding_idx
        )
        # NOTE(review): no `num_blocks` is passed, so the Encoder's own default
        # depth applies; confirm it agrees with hp.elayers, which the guided
        # attention setup above falls back to.
        self.encoder = Encoder(
            idim=idim,
            attention_dim=hp.adim,
            attention_heads=hp.aheads,
            linear_units=hp.eunits,
            input_layer=encoder_input_layer,
            dropout_rate=hp.transformer_enc_dropout_rate,
            positional_dropout_rate=hp.transformer_enc_positional_dropout_rate,
            attention_dropout_rate=hp.transformer_enc_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=hp.encoder_normalize_before,
            concat_after=hp.encoder_concat_after
        )



        # define core decoder
        if hp.dprenet_layers != 0:
            # decoder prenet, followed by a projection from dprenet_units to adim
            decoder_input_layer = torch.nn.Sequential(
                DecoderPrenet(
                    idim=odim,
                    n_layers=hp.dprenet_layers,
                    n_units=hp.dprenet_units,
                    dropout_rate=hp.dprenet_dropout_rate
                ),
                torch.nn.Linear(hp.dprenet_units, hp.adim)
            )
        else:
            decoder_input_layer = "linear"
        # The decoder is built with use_output_layer=False and odim=-1;
        # presumably feat_out / prob_out below take the place of its output
        # layer -- confirm against the forward pass.
        # NOTE(review): as with the encoder, no `num_blocks` is passed here.
        self.decoder = Decoder(
            odim=-1,
            attention_dim=hp.adim,
            attention_heads=hp.aheads,
            linear_units=hp.dunits,
            dropout_rate=hp.transformer_dec_dropout_rate,
            positional_dropout_rate=hp.transformer_dec_positional_dropout_rate,
            self_attention_dropout_rate=hp.transformer_dec_attn_dropout_rate,
            src_attention_dropout_rate=hp.transformer_enc_dec_attn_dropout_rate,
            input_layer=decoder_input_layer,
            use_output_layer=False,
            pos_enc_class=pos_enc_class,
            normalize_before=hp.decoder_normalize_before,
            concat_after=hp.decoder_concat_after
        )

        # define final projection
        # both projections emit `reduction_factor` groups per input vector
        self.feat_out = torch.nn.Linear(hp.adim, odim * hp.reduction_factor)
        self.prob_out = torch.nn.Linear(hp.adim, hp.reduction_factor)

        # define postnet
        # postnet is disabled (None) when hp.postnet_layers == 0
        # NOTE(review): `idim` here is this model's input dimension; confirm
        # that Postnet expects that value for its own `idim` argument.
        self.postnet = None if hp.postnet_layers == 0 else Postnet(
            idim=idim,
            odim=odim,
            n_layers=hp.postnet_layers,
            n_chans=hp.postnet_chans,
            n_filts=hp.postnet_filts,
            use_batch_norm=hp.use_batch_norm,
            dropout_rate=hp.postnet_dropout_rate
        )

        # define loss function
        self.criterion = TransformerLoss(use_masking=hp.use_masking,
                                         bce_pos_weight=hp.bce_pos_weight)
        if self.use_guided_attn_loss:
            # sigma and alpha are fixed constants here rather than hp values
            self.attn_criterion = GuidedMultiHeadAttentionLoss(
                sigma=0.4,
                alpha=1.0,
            )

        # initialize parameters
        self._reset_parameters(init_type=hp.transformer_init,
                               init_enc_alpha=hp.initial_encoder_alpha,
                               init_dec_alpha=hp.initial_decoder_alpha)
Exemple #9
0
def train(opt):
    """Train the caption decoder, periodically evaluating and checkpointing.

    The encoder is run under `torch.no_grad()` and only the decoder's
    parameters are given to the optimizer. Every `opt.save_checkpoint_every`
    iterations the model is evaluated on the validation split and the decoder
    weights, optimizer state, `infos` and `histories` are written to
    `opt.checkpoint_path`; the best-scoring model is additionally saved as
    `model-best.pth`.

    Args:
        opt: Option namespace. Attributes read here include `start_from`,
            `id`, `hidden_size`, `load_best_score`, `learning_rate`,
            `weight_decay`, `max_epochs`, `learning_rate_decay_start`,
            `learning_rate_decay_every`, `learning_rate_decay_rate`,
            `grad_clip`, `losses_log_every`, `save_checkpoint_every`,
            `input_json`, `language_eval` and `checkpoint_path`.
            `vocab_size`, `seq_length` and `current_lr` are written onto it.

    Raises:
        AssertionError: If a resumed checkpoint's `hidden_size` differs from
            the one in `opt`.

    """
    loader = get_loader(opt, 'train')
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    summary_writer = tensorboardX.SummaryWriter()
    infos = {}
    histories = {}
    if opt.start_from is not None:
        infos_path = os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')
        histories_path = os.path.join(opt.start_from,
                                      'histories_' + opt.id + '.pkl')
        # open infos and check if models are compatible
        # NOTE: pickle.load can execute arbitrary code; only resume from
        # checkpoint directories you trust.
        with open(infos_path, 'rb') as f:
            infos = pickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = ['hidden_size']
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme],\
                "Command line argument and saved model disagree on %s"%(checkme)
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = pickle.load(f)

    iteration = infos.get('iter', 0)
    current_epoch = infos.get('epoch', 0)
    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})

    # Always bind best_val_score: it is read at every checkpoint below, and
    # leaving it unbound when load_best_score != 1 raised NameError there.
    if opt.load_best_score == 1:
        best_val_score = infos.get("best_val_score", None)
    else:
        best_val_score = None

    encoder = Encoder()
    decoder = Decoder(opt)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    encoder = encoder.to(device)
    decoder = decoder.to(device)

    criterion = utils.LanguageModelCriterion().to(device)
    optimizer = optim.Adam(decoder.parameters(),
                           lr=opt.learning_rate,
                           weight_decay=opt.weight_decay)

    if vars(opt).get('start_from', None) is not None:
        optimizer_path = os.path.join(opt.start_from, 'optimizer.pth')
        optimizer.load_state_dict(torch.load(optimizer_path))

    # The normalization transform does not depend on the batch, so build it
    # once instead of once per iteration.
    transform = transforms.Normalize((0.485, 0.456, 0.406),
                                     (0.229, 0.224, 0.225))

    total_step = len(loader)
    start = time.time()
    for epoch in range(current_epoch, opt.max_epochs):
        # step-wise learning rate decay once past learning_rate_decay_start
        if epoch > opt.learning_rate_decay_start and \
            opt.learning_rate_decay_start >= 0:
            frac = (epoch - opt.learning_rate_decay_start
                    ) // opt.learning_rate_decay_every
            decay_factor = opt.learning_rate_decay_rate**frac
            opt.current_lr = opt.learning_rate * decay_factor
            utils.set_lr(optimizer, opt.current_lr)
            print("learing rate change form {} to {}".format(
                opt.learning_rate, opt.current_lr))
        else:
            opt.current_lr = opt.learning_rate
        # `iteration` is non-zero only for the first epoch after a resume;
        # the counter then starts there and the epoch ends at total_step.
        for i, data in enumerate(loader, iteration):
            if i > total_step - 1:
                iteration = 0
                break
            imgs = []
            for k in range(data['imgs'].shape[0]):
                img = torch.tensor(data['imgs'][k], dtype=torch.float)
                img = transform(img)
                imgs.append(img)
            imgs = torch.stack(imgs, dim=0).to(device)
            labels = torch.tensor(data['labels'].astype(np.int32),
                                  dtype=torch.long).to(device)
            masks = torch.tensor(data['masks'], dtype=torch.float).to(device)

            # the encoder is not trained: no gradients flow through it
            with torch.no_grad():
                features = encoder(imgs)
            preds = decoder(features, labels)

            loss = criterion(preds, labels[:, 1:], masks[:, 1:])
            optimizer.zero_grad()
            loss.backward()
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            train_loss = loss.item()

            print("iter: {}/{} (epoch {}), train loss = {:.3f}, time/batch = {}"\
                 .format(i, total_step, epoch, train_loss, utils.get_duration(start)))

            log_iter = i + epoch * total_step
            # write training loss summary
            if (i % opt.losses_log_every) == 0:
                summary_writer.add_scalar('train_loss', train_loss, log_iter)
                summary_writer.add_scalar('learning_rate', opt.current_lr,
                                          log_iter)

            # make evaluation on validation set, and save model
            if (i % opt.save_checkpoint_every == 0):
                # eval model
                eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
                eval_kwargs.update(vars(opt))
                val_loss,\
                predictions,\
                lang_stats = eval_utils.eval_split(encoder, decoder, criterion,
                                                   opt, eval_kwargs)
                summary_writer.add_scalar('valaidation loss', val_loss,
                                          log_iter)
                if lang_stats is not None:
                    for metric, score in lang_stats.items():
                        summary_writer.add_scalar(metric, score, log_iter)
                # NOTE(review): keyed by the per-epoch index `i`, so entries
                # from later epochs overwrite earlier ones -- confirm intended.
                val_result_history[i] = {
                    "loss": val_loss,
                    "lang_stats": lang_stats,
                    "predictions": predictions
                }

                # higher is better for both score definitions
                if opt.language_eval == 1:
                    current_score = lang_stats['CIDEr']
                else:
                    current_score = -val_loss.item()

                best_flag = False
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                if not os.path.exists(opt.checkpoint_path):
                    os.makedirs(opt.checkpoint_path)
                model_path = os.path.join(opt.checkpoint_path, 'model.pth')
                torch.save(decoder.state_dict(), model_path)
                print("model saved to {}".format(model_path))

                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # dump miscellaneous information needed to resume training
                infos['iter'] = i + 1
                infos['epoch'] = epoch
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.ix_to_word

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                infos_path = os.path.join(opt.checkpoint_path,
                                          'infos_' + opt.id + '.pkl')
                histories_path = os.path.join(opt.checkpoint_path,
                                              'histories_' + opt.id + '.pkl')
                with open(infos_path, 'wb') as f:
                    pickle.dump(infos, f)
                print("infos saved into {}".format(infos_path))

                with open(histories_path, 'wb') as f:
                    pickle.dump(histories, f)
                print('histories saved into {}'.format(histories_path))
                if best_flag:
                    best_model_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(decoder.state_dict(), best_model_path)
                    print("model saved to {}".format(best_model_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        pickle.dump(infos, f)

    summary_writer.close()