class Tacotron2_sa(TTSInterface, torch.nn.Module):

    @staticmethod
    def add_arguments(parser):
        """Add model-specific arguments to the parser."""
        group = parser.add_argument_group("tacotron 2 model setting")
        # encoder
        group.add_argument("--embed-dim", default=512, type=int,
                           help="Number of dimensions of embedding")
        group.add_argument("--elayers", default=1, type=int,
                           help="Number of encoder layers")
        group.add_argument("--eunits", "-u", default=512, type=int,
                           help="Number of encoder hidden units")
        group.add_argument("--econv-layers", default=3, type=int,
                           help="Number of encoder convolution layers")
        group.add_argument("--econv-chans", default=512, type=int,
                           help="Number of encoder convolution channels")
        group.add_argument("--econv-filts", default=5, type=int,
                           help="Filter size of encoder convolution")
        # decoder
        group.add_argument("--dlayers", default=2, type=int,
                           help="Number of decoder layers")
        group.add_argument("--dunits", default=1024, type=int,
                           help="Number of decoder hidden units")
        group.add_argument("--prenet-layers", default=2, type=int,
                           help="Number of prenet layers")
        group.add_argument("--prenet-units", default=256, type=int,
                           help="Number of prenet hidden units")
        group.add_argument("--postnet-layers", default=5, type=int,
                           help="Number of postnet layers")
        group.add_argument("--postnet-chans", default=512, type=int,
                           help="Number of postnet channels")
        group.add_argument("--postnet-filts", default=5, type=int,
                           help="Filter size of postnet")
        group.add_argument("--output-activation", default=None, type=str, nargs="?",
                           help="Output activation function")
        # model (parameter) related
        group.add_argument("--use-batch-norm", default=True, type=strtobool,
                           help="Whether to use batch normalization")
        group.add_argument("--use-concate", default=True, type=strtobool,
                           help="Whether to concatenate encoder embedding with decoder outputs")
        group.add_argument("--use-residual", default=True, type=strtobool,
                           help="Whether to use residual connection in conv layer")
        group.add_argument("--dropout-rate", default=0.5, type=float,
                           help="Dropout rate")
        group.add_argument("--zoneout-rate", default=0.1, type=float,
                           help="Zoneout rate")
        group.add_argument("--reduction-factor", default=1, type=int,
                           help="Reduction factor")
        group.add_argument("--spk-embed-dim", default=None, type=int,
                           help="Number of speaker embedding dimensions")
        group.add_argument("--spc-dim", default=None, type=int,
                           help="Number of spectrogram dimensions")
        group.add_argument("--pretrained-model", default=None, type=str,
                           help="Pretrained model path")
        # loss related
        group.add_argument("--use-masking", default=False, type=strtobool,
                           help="Whether to use masking in calculation of loss")
        group.add_argument("--use-weighted-masking", default=False, type=strtobool,
                           help="Whether to use weighted masking in calculation of loss")
        # duration predictor settings
        group.add_argument("--duration-predictor-layers", default=2, type=int,
                           help="Number of layers in duration predictor")
        group.add_argument("--duration-predictor-chans", default=384, type=int,
                           help="Number of channels in duration predictor")
        group.add_argument("--duration-predictor-kernel-size", default=3, type=int,
                           help="Kernel size in duration predictor")
        group.add_argument("--duration-predictor-dropout-rate", default=0.1, type=float,
                           help="Dropout rate for duration predictor")
        return parser

    def __init__(self, idim, odim, args=None, com_args=None):
        """Initialize Tacotron2 module.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
            args (Namespace, optional):
                - spk_embed_dim (int): Dimension of the speaker embedding.
                - embed_dim (int): Dimension of character embedding.
                - elayers (int): The number of encoder blstm layers.
                - eunits (int): The number of encoder blstm units.
                - econv_layers (int): The number of encoder conv layers.
                - econv_filts (int): The number of encoder conv filter size.
                - econv_chans (int): The number of encoder conv filter channels.
                - dlayers (int): The number of decoder lstm layers.
                - dunits (int): The number of decoder lstm units.
                - prenet_layers (int): The number of prenet layers.
                - prenet_units (int): The number of prenet units.
                - postnet_layers (int): The number of postnet layers.
                - postnet_filts (int): The number of postnet filter size.
                - postnet_chans (int): The number of postnet filter channels.
                - output_activation (str): The name of activation function for outputs.
                - use_batch_norm (bool): Whether to use batch normalization.
                - use_concate (bool): Whether to concatenate encoder embedding
                    with decoder lstm outputs.
                - dropout_rate (float): Dropout rate.
                - zoneout_rate (float): Zoneout rate.
                - reduction_factor (int): Reduction factor.
                - spk_embed_dim (int): Number of speaker embedding dimensions.
                - spc_dim (int): Number of spectrogram embedding dimensions
                    (only for use_cbhg=True).
                - use_masking (bool): Whether to apply masking for padded part
                    in loss calculation.
                - use_weighted_masking (bool): Whether to apply weighted masking
                    in loss calculation.
                - duration_predictor_layers (int): Number of duration predictor layers.
                - duration_predictor_chans (int): Number of duration predictor channels.
                - duration_predictor_kernel_size (int): Kernel size of duration predictor.

        """
        # initialize base classes
        TTSInterface.__init__(self)
        torch.nn.Module.__init__(self)

        # fill missing arguments
        args = fill_missing_args(args, self.add_arguments)
        args = vars(args)
        if 'use_fe_condition' not in args.keys():
            args['use_fe_condition'] = com_args.use_fe_condition
        if 'append_position' not in args.keys():
            args['append_position'] = com_args.append_position
        args = argparse.Namespace(**args)

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.embed_dim = args.embed_dim
        self.spk_embed_dim = args.spk_embed_dim
        self.reduction_factor = args.reduction_factor
        self.use_fe_condition = args.use_fe_condition
        self.append_position = args.append_position

        # define activation function for the final output
        if args.output_activation is None:
            self.output_activation_fn = None
        elif hasattr(F, args.output_activation):
            self.output_activation_fn = getattr(F, args.output_activation)
        else:
            raise ValueError(
                "there is no such activation function (%s)" % args.output_activation
            )

        # set padding idx
        padding_idx = 0

        # define network modules
        self.enc = Encoder(
            idim=idim,
            embed_dim=args.embed_dim,
            elayers=args.elayers,
            eunits=args.eunits,
            econv_layers=args.econv_layers,
            econv_chans=args.econv_chans,
            econv_filts=args.econv_filts,
            use_batch_norm=args.use_batch_norm,
            use_residual=args.use_residual,
            dropout_rate=args.dropout_rate,
            padding_idx=padding_idx,
            resume=args.encoder_resume,
        )
        dec_idim = (
            args.eunits
            if args.spk_embed_dim is None
            else args.eunits + args.spk_embed_dim
        )
        self.dec = Decoder(
            idim=dec_idim,
            odim=odim,
            dlayers=args.dlayers,
            dunits=args.dunits,
            prenet_layers=args.prenet_layers,
            prenet_units=args.prenet_units,
            postnet_layers=args.postnet_layers,
            postnet_chans=args.postnet_chans,
            postnet_filts=args.postnet_filts,
            output_activation_fn=self.output_activation_fn,
            use_batch_norm=args.use_batch_norm,
            use_concate=args.use_concate,
            dropout_rate=args.dropout_rate,
            zoneout_rate=args.zoneout_rate,
            reduction_factor=args.reduction_factor,
            use_fe_condition=args.use_fe_condition,
            append_position=args.append_position,
        )
        self.duration_predictor = DurationPredictor(
            idim=dec_idim,
            n_layers=args.duration_predictor_layers,
            n_chans=args.duration_predictor_chans,
            kernel_size=args.duration_predictor_kernel_size,
            dropout_rate=args.duration_predictor_dropout_rate,
        )
        reduction = 'none' if args.use_weighted_masking else 'mean'
        self.duration_criterion = DurationPredictorLoss(reduction=reduction)

        # -------------- pitch/energy predictor definition --------------- #
        if self.use_fe_condition:
            output_dim = 1
            # pitch prediction
            pitch_predictor_layers = 2
            pitch_predictor_chans = 384
            pitch_predictor_kernel_size = 3
            pitch_predictor_dropout_rate = 0.5
            pitch_embed_kernel_size = 9
            pitch_embed_dropout_rate = 0.5
            self.stop_gradient_from_pitch_predictor = False
            self.pitch_predictor = VariancePredictor(
                idim=dec_idim,
                n_layers=pitch_predictor_layers,
                n_chans=pitch_predictor_chans,
                kernel_size=pitch_predictor_kernel_size,
                dropout_rate=pitch_predictor_dropout_rate,
                output_dim=output_dim,
            )
            self.pitch_embed = torch.nn.Sequential(
                torch.nn.Conv1d(
                    in_channels=1,
                    out_channels=dec_idim,
                    kernel_size=pitch_embed_kernel_size,
                    padding=(pitch_embed_kernel_size - 1) // 2,
                ),
                torch.nn.Dropout(pitch_embed_dropout_rate),
            )
            # energy prediction
            energy_predictor_layers = 2
            energy_predictor_chans = 384
            energy_predictor_kernel_size = 3
            energy_predictor_dropout_rate = 0.5
            energy_embed_kernel_size = 9
            energy_embed_dropout_rate = 0.5
            self.stop_gradient_from_energy_predictor = False
            self.energy_predictor = VariancePredictor(
                idim=dec_idim,
                n_layers=energy_predictor_layers,
                n_chans=energy_predictor_chans,
                kernel_size=energy_predictor_kernel_size,
                dropout_rate=energy_predictor_dropout_rate,
                output_dim=output_dim,
            )
            self.energy_embed = torch.nn.Sequential(
                torch.nn.Conv1d(
                    in_channels=1,
                    out_channels=dec_idim,
                    kernel_size=energy_embed_kernel_size,
                    padding=(energy_embed_kernel_size - 1) // 2,
                ),
                torch.nn.Dropout(energy_embed_dropout_rate),
            )

        # define criterions
        self.prosody_criterion = prosody_criterions(
            use_masking=args.use_masking,
            use_weighted_masking=args.use_weighted_masking,
        )
        self.taco2_loss = Tacotron2Loss(
            use_masking=args.use_masking,
            use_weighted_masking=args.use_weighted_masking,
        )

        # load pretrained model
        if args.pretrained_model is not None:
            self.load_pretrained_model(args.pretrained_model)

        print('\n############## number of network parameters ##############\n')
        parameters = filter(lambda p: p.requires_grad, self.enc.parameters())
        parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
        print('Trainable Parameters for Encoder: %.5fM' % parameters)
        parameters = filter(lambda p: p.requires_grad, self.dec.parameters())
        parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
        print('Trainable Parameters for Decoder: %.5fM' % parameters)
        parameters = filter(
            lambda p: p.requires_grad, self.duration_predictor.parameters()
        )
        parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
        print('Trainable Parameters for duration_predictor: %.5fM' % parameters)
        if self.use_fe_condition:
            # pitch/energy modules only exist when use_fe_condition is enabled
            parameters = filter(
                lambda p: p.requires_grad, self.pitch_predictor.parameters()
            )
            parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
            print('Trainable Parameters for pitch_predictor: %.5fM' % parameters)
            parameters = filter(
                lambda p: p.requires_grad, self.energy_predictor.parameters()
            )
            parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
            print('Trainable Parameters for energy_predictor: %.5fM' % parameters)
            parameters = filter(
                lambda p: p.requires_grad, self.pitch_embed.parameters()
            )
            parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
            print('Trainable Parameters for pitch_embed: %.5fM' % parameters)
            parameters = filter(
                lambda p: p.requires_grad, self.energy_embed.parameters()
            )
            parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
            print('Trainable Parameters for energy_embed: %.5fM' % parameters)
        parameters = filter(lambda p: p.requires_grad, self.parameters())
        parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
        print('Trainable Parameters for whole network: %.5fM' % parameters)
        print('\n##########################################################\n')

    def forward(
        self,
        xs,
        ilens,
        ys,
        olens,
        spembs=None,
        extras=None,
        new_ys=None,
        non_zero_lens_mask=None,
        ds_nonzeros=None,
        output_masks=None,
        position=None,
        f0=None,
        energy=None,
        *args,
        **kwargs
    ):
        """Calculate forward propagation.

        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            olens (LongTensor): Batch of the lengths of each target (B,).
            spembs (Tensor, optional): Batch of speaker embedding vectors
                (B, spk_embed_dim).
            extras (Tensor, optional): Batch of groundtruth phoneme durations
                (B, Tmax, 1), used as duration predictor targets.
            new_ys (Tensor): Reorganized mel-spectrograms.
            non_zero_lens_mask (Tensor): Mask of phonemes with non-zero duration.
            ds_nonzeros (Tensor): Durations of phonemes with non-zero duration.
            output_masks (Tensor): Masks for decoder outputs.
            position (Tensor): Position values for each phoneme.
            f0 (Tensor): Pitch values.
            energy (Tensor): Energy values.

        Returns:
            Tensor: Loss value.

        """
        # remove unnecessary padded part (for multi-gpus)
        max_in = max(ilens)
        max_out = max(olens)
        if max_in != xs.shape[1]:
            xs = xs[:, :max_in]
        if max_out != ys.shape[1]:
            ys = ys[:, :max_out]

        # calculate FCL-taco2-enc outputs
        hs, hlens = self.enc(xs, ilens)
        if self.spk_embed_dim is not None:
            spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
            hs = torch.cat([hs, spembs], dim=-1)

        # duration predictor loss calculation
        ds = extras.squeeze(-1)
        d_masks = make_pad_mask(ilens).to(xs.device)
        d_outs = self.duration_predictor(hs, d_masks)  # (B, Tmax)
        duration_masks = make_non_pad_mask(ilens).to(ys.device)
        d_outs = d_outs.masked_select(duration_masks)
        duration_loss = self.duration_criterion(
            d_outs, ds.masked_select(duration_masks)
        )

        if self.use_fe_condition:
            expand_hs = hs
            fe_masks = d_masks
            if self.stop_gradient_from_pitch_predictor:
                p_outs = self.pitch_predictor(
                    expand_hs.detach(), fe_masks.unsqueeze(-1)
                )
            else:
                p_outs = self.pitch_predictor(
                    expand_hs, fe_masks.unsqueeze(-1)
                )  # B x Tmax x 1
            if self.stop_gradient_from_energy_predictor:
                e_outs = self.energy_predictor(
                    expand_hs.detach(), fe_masks.unsqueeze(-1)
                )
            else:
                e_outs = self.energy_predictor(
                    expand_hs, fe_masks.unsqueeze(-1)
                )  # B x Tmax x 1
            pitch_loss = self.prosody_criterion(p_outs, f0, ilens)
            energy_loss = self.prosody_criterion(e_outs, energy, ilens)
            p_embs = self.pitch_embed(f0.transpose(1, 2)).transpose(1, 2)
            e_embs = self.energy_embed(energy.transpose(1, 2)).transpose(1, 2)
        else:
            p_embs = None
            e_embs = None

        ylens = olens
        after_outs, before_outs = self.dec(
            hs,
            hlens,
            ds,
            ys,
            ylens,
            new_ys,
            non_zero_lens_mask,
            ds_nonzeros,
            output_masks,
            position,
            f0,
            energy,
            p_embs,
            e_embs,
        )

        # modify mod part of groundtruth
        if self.reduction_factor > 1:
            olens = olens.new(
                [olen - olen % self.reduction_factor for olen in olens]
            )
            max_out = max(olens)
            ys = ys[:, :max_out]

        # calculate taco2 loss
        l1_loss, mse_loss = self.taco2_loss(after_outs, before_outs, ys, olens)
        loss = l1_loss + mse_loss + duration_loss
        report_keys = [
            {"l1_loss": l1_loss.item()},
            {"mse_loss": mse_loss.item()},
            {"dur_loss": duration_loss.item()},
        ]
        if self.use_fe_condition:
            prosody_weight = 1.0
            loss = loss + prosody_weight * (pitch_loss + energy_loss)
            report_keys += [
                {"pitch_loss": pitch_loss.item()},
                {"energy_loss": energy_loss.item()},
            ]

        report_keys += [{"loss": loss.item()}]
        self.reporter.report(report_keys)

        return loss

    def inference(
        self,
        x,
        inference_args,
        spemb=None,
        dur=None,
        f0=None,
        energy=None,
        utt_id=None,
        *args,
        **kwargs
    ):
        """Generate the sequence of features given the sequences of characters.

        Args:
            x (Tensor): Input sequence of characters (T,).
            inference_args (Namespace): Arguments for inference.
            spemb (Tensor, optional): Speaker embedding vector (spk_embed_dim).
            dur (Tensor, optional): Groundtruth phoneme durations.
            f0 (Tensor, optional): Groundtruth pitch values.
            energy (Tensor, optional): Groundtruth energy values.
            utt_id (str, optional): Utterance id.

        Returns:
            Tensor: Output sequence of features (L, odim).

        """
        # inference
        h = self.enc.inference(x)  # Tmax x h-dim
        if self.spk_embed_dim is not None:
            spemb = F.normalize(spemb, dim=0).unsqueeze(0).expand(h.size(0), -1)
            h = torch.cat([h, spemb], dim=-1)
        ilens = torch.LongTensor([h.shape[0]]).to(h.device)
        d_masks = make_pad_mask(ilens).to(h.device)
        if dur is not None:
            d_outs = dur.reshape(-1).long()
        else:
            d_outs = self.duration_predictor.inference(
                h.unsqueeze(0), d_masks
            )  # B x Tmax
            d_outs = d_outs.squeeze(0).long()

        if self.use_fe_condition:
            if f0 is not None:
                p_outs = f0.unsqueeze(0)
                e_outs = energy.unsqueeze(0)
            else:
                expand_hs = h.unsqueeze(0)
                fe_masks = d_masks
                p_outs = self.pitch_predictor(expand_hs, fe_masks.unsqueeze(-1))
                e_outs = self.energy_predictor(expand_hs, fe_masks.unsqueeze(-1))
            p_embs = self.pitch_embed(p_outs.transpose(1, 2)).transpose(1, 2).squeeze(0)
            e_embs = self.energy_embed(e_outs.transpose(1, 2)).transpose(1, 2).squeeze(0)
        else:
            p_outs = None
            e_outs = None
            p_embs = None
            e_embs = None

        if self.append_position:
            position = []
            for iid in range(d_outs.shape[0]):
                if d_outs[iid] != 0:
                    position.append(
                        torch.FloatTensor(list(range(d_outs[iid].long())))
                        / d_outs[iid]
                    )
            position = pad_list(position, 0)
            position = position.to(h.device)
        else:
            position = None

        outs = self.dec.inference(
            h,
            d_outs,
            position,
            p_embs,
            e_embs,
        )
        return outs

    @property
    def base_plot_keys(self):
        """Return base key names to plot during training.

        The keys should match what `chainer.reporter` reports.
        If you add the key `loss`, the reporter will report `main/loss`
        and `validation/main/loss` values.
        Also, `loss.png` will be created as a figure visualizing `main/loss`
        and `validation/main/loss` values.

        Returns:
            list: List of strings which are base keys to plot during training.

        """
        plot_keys = ["loss", "l1_loss", "mse_loss", "dur_loss"]
        if self.use_fe_condition:
            plot_keys += ["pitch_loss", "energy_loss"]
        return plot_keys
class Tacotron2_sa(TTSInterface, torch.nn.Module):

    @staticmethod
    def add_arguments(parser):
        """Add model-specific arguments to the parser."""
        group = parser.add_argument_group("tacotron 2 model setting")
        # encoder
        group.add_argument("--embed-dim", default=512, type=int,
                           help="Number of dimensions of embedding")
        group.add_argument("--elayers", default=1, type=int,
                           help="Number of encoder layers")
        group.add_argument("--eunits", "-u", default=512, type=int,
                           help="Number of encoder hidden units")
        group.add_argument("--econv-layers", default=3, type=int,
                           help="Number of encoder convolution layers")
        group.add_argument("--econv-chans", default=512, type=int,
                           help="Number of encoder convolution channels")
        group.add_argument("--econv-filts", default=5, type=int,
                           help="Filter size of encoder convolution")
        # decoder
        group.add_argument("--dlayers", default=2, type=int,
                           help="Number of decoder layers")
        group.add_argument("--dunits", default=1024, type=int,
                           help="Number of decoder hidden units")
        group.add_argument("--prenet-layers", default=2, type=int,
                           help="Number of prenet layers")
        group.add_argument("--prenet-units", default=256, type=int,
                           help="Number of prenet hidden units")
        group.add_argument("--postnet-layers", default=5, type=int,
                           help="Number of postnet layers")
        group.add_argument("--postnet-chans", default=512, type=int,
                           help="Number of postnet channels")
        group.add_argument("--postnet-filts", default=5, type=int,
                           help="Filter size of postnet")
        group.add_argument("--output-activation", default=None, type=str, nargs="?",
                           help="Output activation function")
        # model (parameter) related
        group.add_argument("--use-batch-norm", default=True, type=strtobool,
                           help="Whether to use batch normalization")
        group.add_argument("--use-concate", default=True, type=strtobool,
                           help="Whether to concatenate encoder embedding with decoder outputs")
        group.add_argument("--use-residual", default=True, type=strtobool,
                           help="Whether to use residual connection in conv layer")
        group.add_argument("--dropout-rate", default=0.5, type=float,
                           help="Dropout rate")
        group.add_argument("--zoneout-rate", default=0.1, type=float,
                           help="Zoneout rate")
        group.add_argument("--reduction-factor", default=1, type=int,
                           help="Reduction factor")
        group.add_argument("--spk-embed-dim", default=None, type=int,
                           help="Number of speaker embedding dimensions")
        group.add_argument("--spc-dim", default=None, type=int,
                           help="Number of spectrogram dimensions")
        group.add_argument("--pretrained-model", default=None, type=str,
                           help="Pretrained model path")
        # loss related
        group.add_argument("--use-masking", default=False, type=strtobool,
                           help="Whether to use masking in calculation of loss")
        group.add_argument("--use-weighted-masking", default=False, type=strtobool,
                           help="Whether to use weighted masking in calculation of loss")
        # duration predictor settings
        group.add_argument("--duration-predictor-layers", default=2, type=int,
                           help="Number of layers in duration predictor")
        group.add_argument("--duration-predictor-chans", default=384, type=int,
                           help="Number of channels in duration predictor")
        group.add_argument("--duration-predictor-kernel-size", default=3, type=int,
                           help="Kernel size in duration predictor")
        group.add_argument("--duration-predictor-dropout-rate", default=0.1, type=float,
                           help="Dropout rate for duration predictor")
        return parser

    def __init__(self, idim, odim, args=None, com_args=None):
        """Initialize Tacotron2 module.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
            args (Namespace, optional):
                - spk_embed_dim (int): Dimension of the speaker embedding.
                - embed_dim (int): Dimension of character embedding.
                - elayers (int): The number of encoder blstm layers.
                - eunits (int): The number of encoder blstm units.
                - econv_layers (int): The number of encoder conv layers.
                - econv_filts (int): The number of encoder conv filter size.
                - econv_chans (int): The number of encoder conv filter channels.
                - dlayers (int): The number of decoder lstm layers.
                - dunits (int): The number of decoder lstm units.
                - prenet_layers (int): The number of prenet layers.
                - prenet_units (int): The number of prenet units.
                - postnet_layers (int): The number of postnet layers.
                - postnet_filts (int): The number of postnet filter size.
                - postnet_chans (int): The number of postnet filter channels.
                - output_activation (str): The name of activation function for outputs.
                - use_batch_norm (bool): Whether to use batch normalization.
                - use_concate (bool): Whether to concatenate encoder embedding
                    with decoder lstm outputs.
                - dropout_rate (float): Dropout rate.
                - zoneout_rate (float): Zoneout rate.
                - reduction_factor (int): Reduction factor.
                - spk_embed_dim (int): Number of speaker embedding dimensions.
                - spc_dim (int): Number of spectrogram embedding dimensions
                    (only for use_cbhg=True).
                - use_masking (bool): Whether to apply masking for padded part
                    in loss calculation.
                - use_weighted_masking (bool): Whether to apply weighted masking
                    in loss calculation.
                - duration_predictor_layers (int): Number of duration predictor layers.
                - duration_predictor_chans (int): Number of duration predictor channels.
                - duration_predictor_kernel_size (int): Kernel size of duration predictor.

        """
        # initialize base classes
        TTSInterface.__init__(self)
        torch.nn.Module.__init__(self)

        # fill missing arguments
        args = fill_missing_args(args, self.add_arguments)
        args = vars(args)
        if 'use_fe_condition' not in args.keys():
            args['use_fe_condition'] = com_args.use_fe_condition
        if 'append_position' not in args.keys():
            args['append_position'] = com_args.append_position
        args = argparse.Namespace(**args)

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.embed_dim = args.embed_dim
        self.spk_embed_dim = args.spk_embed_dim
        self.reduction_factor = args.reduction_factor
        self.use_fe_condition = args.use_fe_condition
        self.append_position = args.append_position

        # define activation function for the final output
        if args.output_activation is None:
            self.output_activation_fn = None
        elif hasattr(F, args.output_activation):
            self.output_activation_fn = getattr(F, args.output_activation)
        else:
            raise ValueError(
                "there is no such activation function (%s)" % args.output_activation
            )

        # set padding idx
        padding_idx = 0

        # define network modules
        self.enc = Encoder(
            idim=idim,
            embed_dim=args.embed_dim,
            elayers=args.elayers,
            eunits=args.eunits,
            econv_layers=args.econv_layers,
            econv_chans=args.econv_chans,
            econv_filts=args.econv_filts,
            use_batch_norm=args.use_batch_norm,
            use_residual=args.use_residual,
            dropout_rate=args.dropout_rate,
            padding_idx=padding_idx,
            is_student=False,
        )
        dec_idim = (
            args.eunits
            if args.spk_embed_dim is None
            else args.eunits + args.spk_embed_dim
        )
        self.dec = Decoder(
            idim=dec_idim,
            odim=odim,
            dlayers=args.dlayers,
            dunits=args.dunits,
            prenet_layers=args.prenet_layers,
            prenet_units=args.prenet_units,
            postnet_layers=args.postnet_layers,
            postnet_chans=args.postnet_chans,
            postnet_filts=args.postnet_filts,
            output_activation_fn=self.output_activation_fn,
            use_batch_norm=args.use_batch_norm,
            use_concate=args.use_concate,
            dropout_rate=args.dropout_rate,
            zoneout_rate=args.zoneout_rate,
            reduction_factor=args.reduction_factor,
            use_fe_condition=args.use_fe_condition,
            append_position=args.append_position,
            is_student=False,
        )
        self.duration_predictor = DurationPredictor(
            idim=dec_idim,
            n_layers=args.duration_predictor_layers,
            n_chans=args.duration_predictor_chans,
            kernel_size=args.duration_predictor_kernel_size,
            dropout_rate=args.duration_predictor_dropout_rate,
        )
        # reduction = 'none' if args.use_weighted_masking else 'mean'
        # self.duration_criterion = DurationPredictorLoss(reduction=reduction)

        # -------------- pitch/energy predictor definition --------------- #
        if self.use_fe_condition:
            output_dim = 1
            # pitch prediction
            pitch_predictor_layers = 2
            pitch_predictor_chans = 384
            pitch_predictor_kernel_size = 3
            pitch_predictor_dropout_rate = 0.5
            pitch_embed_kernel_size = 9
            pitch_embed_dropout_rate = 0.5
            self.stop_gradient_from_pitch_predictor = False
            self.pitch_predictor = VariancePredictor(
                idim=dec_idim,
                n_layers=pitch_predictor_layers,
                n_chans=pitch_predictor_chans,
                kernel_size=pitch_predictor_kernel_size,
                dropout_rate=pitch_predictor_dropout_rate,
                output_dim=output_dim,
            )
            self.pitch_embed = torch.nn.Sequential(
                torch.nn.Conv1d(
                    in_channels=1,
                    out_channels=dec_idim,
                    kernel_size=pitch_embed_kernel_size,
                    padding=(pitch_embed_kernel_size - 1) // 2,
                ),
                torch.nn.Dropout(pitch_embed_dropout_rate),
            )
            # energy prediction
            energy_predictor_layers = 2
            energy_predictor_chans = 384
            energy_predictor_kernel_size = 3
            energy_predictor_dropout_rate = 0.5
            energy_embed_kernel_size = 9
            energy_embed_dropout_rate = 0.5
            self.stop_gradient_from_energy_predictor = False
            self.energy_predictor = VariancePredictor(
                idim=dec_idim,
                n_layers=energy_predictor_layers,
                n_chans=energy_predictor_chans,
                kernel_size=energy_predictor_kernel_size,
                dropout_rate=energy_predictor_dropout_rate,
                output_dim=output_dim,
            )
            self.energy_embed = torch.nn.Sequential(
                torch.nn.Conv1d(
                    in_channels=1,
                    out_channels=dec_idim,
                    kernel_size=energy_embed_kernel_size,
                    padding=(energy_embed_kernel_size - 1) // 2,
                ),
                torch.nn.Dropout(energy_embed_dropout_rate),
            )

        # # define criterions
        # self.prosody_criterion = prosody_criterions(
        #     use_masking=args.use_masking,
        #     use_weighted_masking=args.use_weighted_masking,
        # )
        self.taco2_loss = Tacotron2Loss(
            use_masking=args.use_masking,
            use_weighted_masking=args.use_weighted_masking,
        )

        # load pretrained model
        if args.pretrained_model is not None:
            self.load_pretrained_model(args.pretrained_model)

        print('\n############## number of network parameters ##############\n')
        parameters = filter(lambda p: p.requires_grad, self.enc.parameters())
        parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
        print('Trainable Parameters for Encoder: %.5fM' % parameters)
        parameters = filter(lambda p: p.requires_grad, self.dec.parameters())
        parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
        print('Trainable Parameters for Decoder: %.5fM' % parameters)
        parameters = filter(
            lambda p: p.requires_grad, self.duration_predictor.parameters()
        )
        parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
        print('Trainable Parameters for duration_predictor: %.5fM' % parameters)
        if self.use_fe_condition:
            # pitch/energy modules only exist when use_fe_condition is enabled
            parameters = filter(
                lambda p: p.requires_grad, self.pitch_predictor.parameters()
            )
            parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
            print('Trainable Parameters for pitch_predictor: %.5fM' % parameters)
            parameters = filter(
                lambda p: p.requires_grad, self.energy_predictor.parameters()
            )
            parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
            print('Trainable Parameters for energy_predictor: %.5fM' % parameters)
            parameters = filter(
                lambda p: p.requires_grad, self.pitch_embed.parameters()
            )
            parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
            print('Trainable Parameters for pitch_embed: %.5fM' % parameters)
            parameters = filter(
                lambda p: p.requires_grad, self.energy_embed.parameters()
            )
            parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
            print('Trainable Parameters for energy_embed: %.5fM' % parameters)
        parameters = filter(lambda p: p.requires_grad, self.parameters())
        parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
        print('Trainable Parameters for whole network: %.5fM' % parameters)
        print('\n##########################################################\n')

    def forward(
        self,
        xs,
        ilens,
        ys,
        olens,
        spembs=None,
        extras=None,
        new_ys=None,
        non_zero_lens_mask=None,
        ds_nonzeros=None,
        output_masks=None,
        position=None,
        f0=None,
        energy=None,
        *args,
        **kwargs
    ):
        """Calculate forward propagation.

        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            olens (LongTensor): Batch of the lengths of each target (B,).
            spembs (Tensor, optional): Batch of speaker embedding vectors
                (B, spk_embed_dim).
            extras (Tensor, optional): Batch of groundtruth phoneme durations
                (B, Tmax, 1).
            new_ys (Tensor): Reorganized mel-spectrograms.
            non_zero_lens_mask (Tensor): Mask of phonemes with non-zero duration.
            ds_nonzeros (Tensor): Durations of phonemes with non-zero duration.
            output_masks (Tensor): Masks for decoder outputs.
            position (Tensor): Position values for each phoneme.
            f0 (Tensor): Pitch values.
            energy (Tensor): Energy values.

        Returns:
            Tensor: Batch of output features after postnet.
            Tensor: Batch of output features before postnet.
            list: Encoder items used for distillation.
            list: Decoder items used for distillation.
            list: Duration/pitch/energy predictor items used for distillation.

        """
        # remove unnecessary padded part (for multi-gpus)
        max_in = max(ilens)
        max_out = max(olens)
        if max_in != xs.shape[1]:
            xs = xs[:, :max_in]
        if max_out != ys.shape[1]:
            ys = ys[:, :max_out]

        # calculate FCL-taco2-enc outputs
        hs, hlens, enc_distill_items = self.enc(xs, ilens)
        if self.spk_embed_dim is not None:
            spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
            hs = torch.cat([hs, spembs], dim=-1)

        # duration predictor outputs (loss calculation is disabled below)
        ds = extras.squeeze(-1)
        d_masks = make_pad_mask(ilens).to(xs.device)
        d_outs = self.duration_predictor(hs, d_masks)  # (B, Tmax)
        d_outs = d_outs.unsqueeze(-1)  # (B, Tmax, 1)
        # duration_masks = make_non_pad_mask(ilens).to(ys.device)
        # d_outs = d_outs.masked_select(duration_masks)
        # duration_loss = self.duration_criterion(
        #     d_outs, ds.masked_select(duration_masks)
        # )

        if self.use_fe_condition:
            expand_hs = hs
            fe_masks = d_masks
            if self.stop_gradient_from_pitch_predictor:
                p_outs = self.pitch_predictor(
                    expand_hs.detach(), fe_masks.unsqueeze(-1)
                )
            else:
                p_outs = self.pitch_predictor(
                    expand_hs, fe_masks.unsqueeze(-1)
                )  # B x Tmax x 1
            if self.stop_gradient_from_energy_predictor:
                e_outs = self.energy_predictor(
                    expand_hs.detach(), fe_masks.unsqueeze(-1)
                )
            else:
                e_outs = self.energy_predictor(
                    expand_hs, fe_masks.unsqueeze(-1)
                )  # B x Tmax x 1
            # pitch_loss = self.prosody_criterion(p_outs, f0, ilens)
            # energy_loss = self.prosody_criterion(e_outs, energy, ilens)
            p_embs = self.pitch_embed(f0.transpose(1, 2)).transpose(1, 2)
            e_embs = self.energy_embed(energy.transpose(1, 2)).transpose(1, 2)
        else:
            p_embs = None
            e_embs = None

        ylens = olens
        after_outs, before_outs, dec_distill_items = self.dec(
            hs,
            hlens,
            ds,
            ys,
            ylens,
            new_ys,
            non_zero_lens_mask,
            ds_nonzeros,
            output_masks,
            position,
            p_embs,
            e_embs,
        )

        prosody_distill_items = [d_outs, p_outs, e_outs, p_embs, e_embs]

        enc_distill_items = self.detach_items(enc_distill_items)
        dec_distill_items = self.detach_items(dec_distill_items)
        prosody_distill_items = self.detach_items(prosody_distill_items)

        return (
            after_outs,
            before_outs,
            enc_distill_items,
            dec_distill_items,
            prosody_distill_items,
        )

    def detach_items(self, items):
        """Detach each tensor in a list so no gradients flow back through it."""
        items = [it.detach() for it in items]
        return items