def __init__(self, idim, adim, aheads, eunits, depth, use_scaled_pos_enc,
             use_masking, normalize_before, concat_after, pointwise_layer,
             conv_kernel):
    super(MelEncoder, self).__init__()
    self.use_scaled_pos_enc = use_scaled_pos_enc
    self.use_masking = use_masking

    # use idx 0 as padding idx
    padding_idx = 0

    # get positional encoding class
    pos_enc_class = (ScaledPositionalEncoding
                     if self.use_scaled_pos_enc else PositionalEncoding)

    input_layer = torch.nn.Conv1d(idim, adim, 1)
    self.encoder = Encoder(
        idim=idim,
        attention_dim=adim,
        attention_heads=aheads,
        linear_units=eunits,
        num_blocks=depth,
        input_layer=input_layer,
        dropout_rate=0.2,
        positional_dropout_rate=0.2,
        attention_dropout_rate=0.2,
        pos_enc_class=pos_enc_class,
        normalize_before=normalize_before,
        concat_after=concat_after,
        positionwise_layer_type=pointwise_layer,
        positionwise_conv_kernel_size=conv_kernel,
    )
def main(opt):
    with open(opt.infos_path, 'rb') as f:
        infos = pickle.load(f)

    # override and collect parameters
    if len(opt.input_h5) == 0:
        opt.input_h5 = infos['opt'].input_h5
    if len(opt.input_json) == 0:
        opt.input_json = infos['opt'].input_json
    if opt.batch_size == 0:
        opt.batch_size = infos['opt'].batch_size
    if len(opt.id) == 0:
        opt.id = infos['opt'].id

    ignore = ['id', 'batch_size', 'beam_size', 'start_from', 'language_eval']
    for key, value in vars(infos['opt']).items():
        if key not in ignore:
            if key in vars(opt):
                assert vars(opt)[key] == vars(infos['opt'])[key], \
                    key + " option not consistent"
            else:
                vars(opt).update({key: value})

    vocab = infos['vocab']

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    encoder = Encoder()
    decoder = Decoder(opt)
    encoder = encoder.to(device)
    decoder = decoder.to(device)
    decoder.load_state_dict(torch.load(opt.model, map_location=str(device)))
    encoder.eval()
    decoder.eval()

    criterion = utils.LanguageModelCriterion().to(device)

    if len(opt.image_folder) == 0:
        loader = get_loader(opt, 'test')
        loader.ix_to_word = vocab

    loss, split_predictions, lang_stats = \
        eval_utils.eval_split(encoder, decoder, criterion, opt, vars(opt))
    print('loss: ', loss)
    print(lang_stats)

    result_json_path = os.path.join(
        opt.checkpoint_path,
        "captions_" + opt.split + "2014_" + opt.id + "_results.json")
    with open(result_json_path, "w") as f:
        json.dump(split_predictions, f)
def send(self, sock, msg):
    """Send data to the peer over the given socket."""
    try:
        peer_ip, peer_port = sock.getpeername()
        to_send = {
            "peer_ip": peer_ip,
            "peer_port": peer_port,
            "ip": self.ip_addr,
            "port": self.port
        }
        data = self.make_header(msg, to_send)
        data = Encoder.json_decode(data)
        sock.send(data)
    except socket.error:
        Logger.log_error("can't send to peer")
def handle(self):
    raw_data = self.request.recv(BUFFER_SIZE).strip()
    self.data = Encoder.json_encode(raw_data)
    # for tests
    Logger.log_info(self.data)
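# For context, a minimal sketch of the kind of JSON serialization that the
# Encoder.json_encode / Encoder.json_decode helpers used above presumably wrap.
# The helper bodies below are assumptions for illustration, not the project's
# actual implementation.
import json


def json_encode(obj) -> bytes:
    # serialize a dict (e.g. a message header) to UTF-8 JSON bytes for sock.send
    return json.dumps(obj).encode("utf-8")


def json_decode(data: bytes):
    # parse UTF-8 JSON bytes received from a socket back into a dict
    return json.loads(data.decode("utf-8"))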
def __init__(self, idim: int, odim: int, hp: Dict):
    """Initialize feed-forward Transformer module.

    Args:
        idim (int): Dimension of the inputs.
        odim (int): Dimension of the outputs.
        hp (Dict): Hyperparameter container with model/data/audio sections.

    """
    # initialize base classes
    assert check_argument_types()
    torch.nn.Module.__init__(self)

    # store hyperparameters
    self.idim = idim
    self.odim = odim
    self.use_scaled_pos_enc = hp.model.use_scaled_pos_enc
    self.use_masking = hp.model.use_masking

    # use idx 0 as padding idx
    padding_idx = 0

    # get positional encoding class
    pos_enc_class = (ScaledPositionalEncoding
                     if self.use_scaled_pos_enc else PositionalEncoding)

    # define encoder
    encoder_input_layer = torch.nn.Embedding(num_embeddings=idim,
                                             embedding_dim=hp.model.adim,
                                             padding_idx=padding_idx)
    self.encoder = Encoder(
        idim=idim,
        attention_dim=hp.model.adim,
        attention_heads=hp.model.aheads,
        linear_units=hp.model.eunits,
        num_blocks=hp.model.elayers,
        input_layer=encoder_input_layer,
        dropout_rate=0.2,
        positional_dropout_rate=0.2,
        attention_dropout_rate=0.2,
        pos_enc_class=pos_enc_class,
        normalize_before=hp.model.encoder_normalize_before,
        concat_after=hp.model.encoder_concat_after,
        positionwise_layer_type=hp.model.positionwise_layer_type,
        positionwise_conv_kernel_size=hp.model.positionwise_conv_kernel_size,
    )

    # define duration, energy, and pitch predictors
    self.duration_predictor = DurationPredictor(
        idim=hp.model.adim,
        n_layers=hp.model.duration_predictor_layers,
        n_chans=hp.model.duration_predictor_chans,
        kernel_size=hp.model.duration_predictor_kernel_size,
        dropout_rate=hp.model.duration_predictor_dropout_rate,
    )
    self.energy_predictor = EnergyPredictor(
        idim=hp.model.adim,
        n_layers=hp.model.duration_predictor_layers,
        n_chans=hp.model.duration_predictor_chans,
        kernel_size=hp.model.duration_predictor_kernel_size,
        dropout_rate=hp.model.duration_predictor_dropout_rate,
        min=hp.data.e_min,
        max=hp.data.e_max,
    )
    self.energy_embed = torch.nn.Linear(hp.model.adim, hp.model.adim)
    self.pitch_predictor = PitchPredictor(
        idim=hp.model.adim,
        n_layers=hp.model.duration_predictor_layers,
        n_chans=hp.model.duration_predictor_chans,
        kernel_size=hp.model.duration_predictor_kernel_size,
        dropout_rate=hp.model.duration_predictor_dropout_rate,
        min=hp.data.p_min,
        max=hp.data.p_max,
    )
    self.pitch_embed = torch.nn.Linear(hp.model.adim, hp.model.adim)

    # define length regulator
    self.length_regulator = LengthRegulator()

    # AdaSpeech-specific modules
    self.utterance_encoder = UtteranceEncoder(idim=hp.audio.n_mels)
    self.phoneme_level_encoder = PhonemeLevelEncoder(idim=hp.audio.n_mels)
    self.phoneme_level_predictor = PhonemeLevelPredictor(idim=hp.model.adim)
    self.phone_level_embed = torch.nn.Linear(hp.model.phn_latent_dim,
                                             hp.model.adim)
    self.acoustic_criterion = AcousticPredictorLoss()

    # define decoder
    # NOTE: we use encoder as decoder because fastspeech's decoder is the same as encoder
    self.decoder = Encoder(
        idim=hp.model.adim,
        attention_dim=hp.model.ddim,
        attention_heads=hp.model.aheads,
        linear_units=hp.model.dunits,
        num_blocks=hp.model.dlayers,
        input_layer="linear",
        dropout_rate=0.2,
        positional_dropout_rate=0.2,
        attention_dropout_rate=0.2,
        pos_enc_class=pos_enc_class,
        normalize_before=hp.model.decoder_normalize_before,
        concat_after=hp.model.decoder_concat_after,
        positionwise_layer_type=hp.model.positionwise_layer_type,
        positionwise_conv_kernel_size=hp.model.positionwise_conv_kernel_size,
    )

    # define postnet
    self.postnet = (None if hp.model.postnet_layers == 0 else Postnet(
        idim=idim,
        odim=odim,
        n_layers=hp.model.postnet_layers,
        n_chans=hp.model.postnet_chans,
        n_filts=hp.model.postnet_filts,
        use_batch_norm=hp.model.use_batch_norm,
        dropout_rate=hp.model.postnet_dropout_rate,
    ))

    # define final projection
    self.feat_out = torch.nn.Linear(hp.model.ddim,
                                    odim * hp.model.reduction_factor)

    # initialize parameters
    self._reset_parameters(
        init_type=hp.model.transformer_init,
        init_enc_alpha=hp.model.initial_encoder_alpha,
        init_dec_alpha=hp.model.initial_decoder_alpha,
    )

    # define criterions
    self.duration_criterion = DurationPredictorLoss()
    self.energy_criterion = EnergyPredictorLoss()
    self.pitch_criterion = PitchPredictorLoss()
    self.criterion = torch.nn.L1Loss(reduction="mean")
    self.use_weighted_masking = hp.model.use_weighted_masking
class FeedForwardTransformer(torch.nn.Module):
    """Feed Forward Transformer for TTS a.k.a. FastSpeech.

    This is a module of FastSpeech, a feed-forward Transformer with a duration
    predictor described in `FastSpeech: Fast, Robust and Controllable Text to
    Speech`_, which does not require any auto-regressive processing during
    inference, resulting in fast decoding compared with the auto-regressive
    Transformer.

    .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
        https://arxiv.org/pdf/1905.09263.pdf

    """

    def __init__(self, idim, odim):
        """Initialize feed-forward Transformer module.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.

        """
        # initialize base classes
        torch.nn.Module.__init__(self)

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.reduction_factor = hp.reduction_factor
        self.use_scaled_pos_enc = hp.use_scaled_pos_enc
        self.use_masking = hp.use_masking

        # TODO(kan-bayashi): support reduction_factor > 1
        if self.reduction_factor != 1:
            raise NotImplementedError("Support only reduction_factor = 1.")

        # use idx 0 as padding idx
        padding_idx = 0

        # get positional encoding class
        pos_enc_class = (ScaledPositionalEncoding
                         if self.use_scaled_pos_enc else PositionalEncoding)

        # define encoder
        encoder_input_layer = torch.nn.Embedding(num_embeddings=idim,
                                                 embedding_dim=hp.adim,
                                                 padding_idx=padding_idx)
        self.encoder = Encoder(
            idim=idim,
            attention_dim=hp.adim,
            attention_heads=hp.aheads,
            linear_units=hp.eunits,
            num_blocks=hp.elayers,
            input_layer=encoder_input_layer,
            dropout_rate=0.2,
            positional_dropout_rate=0.2,
            attention_dropout_rate=0.2,
            pos_enc_class=pos_enc_class,
            normalize_before=hp.encoder_normalize_before,
            concat_after=hp.encoder_concat_after,
            positionwise_layer_type=hp.positionwise_layer_type,
            positionwise_conv_kernel_size=hp.positionwise_conv_kernel_size)

        # define duration, energy, and pitch predictors
        self.duration_predictor = DurationPredictor(
            idim=hp.adim,
            n_layers=hp.duration_predictor_layers,
            n_chans=hp.duration_predictor_chans,
            kernel_size=hp.duration_predictor_kernel_size,
            dropout_rate=hp.duration_predictor_dropout_rate,
        )
        self.energy_predictor = EnergyPredictor(
            idim=hp.adim,
            n_layers=hp.duration_predictor_layers,
            n_chans=hp.duration_predictor_chans,
            kernel_size=hp.duration_predictor_kernel_size,
            dropout_rate=hp.duration_predictor_dropout_rate,
        )
        self.energy_embed = torch.nn.Linear(hp.adim, hp.adim)
        self.pitch_predictor = PitchPredictor(
            idim=hp.adim,
            n_layers=hp.duration_predictor_layers,
            n_chans=hp.duration_predictor_chans,
            kernel_size=hp.duration_predictor_kernel_size,
            dropout_rate=hp.duration_predictor_dropout_rate,
        )
        self.pitch_embed = torch.nn.Linear(hp.adim, hp.adim)

        # define length regulator
        self.length_regulator = LengthRegulator()

        # define decoder
        # NOTE: we use encoder as decoder because fastspeech's decoder is the same as encoder
        self.decoder = Encoder(
            idim=256,
            attention_dim=256,
            attention_heads=hp.aheads,
            linear_units=hp.dunits,
            num_blocks=hp.dlayers,
            input_layer=None,
            dropout_rate=0.2,
            positional_dropout_rate=0.2,
            attention_dropout_rate=0.2,
            pos_enc_class=pos_enc_class,
            normalize_before=hp.decoder_normalize_before,
            concat_after=hp.decoder_concat_after,
            positionwise_layer_type=hp.positionwise_layer_type,
            positionwise_conv_kernel_size=hp.positionwise_conv_kernel_size)

        # define postnet
        self.postnet = (None if hp.postnet_layers == 0 else Postnet(
            idim=idim,
            odim=odim,
            n_layers=hp.postnet_layers,
            n_chans=hp.postnet_chans,
            n_filts=hp.postnet_filts,
            use_batch_norm=hp.use_batch_norm,
            dropout_rate=hp.postnet_dropout_rate,
        ))

        # define final projection
        self.feat_out = torch.nn.Linear(hp.adim, odim * hp.reduction_factor)

        # initialize parameters
        self._reset_parameters(init_type=hp.transformer_init,
                               init_enc_alpha=hp.initial_encoder_alpha,
                               init_dec_alpha=hp.initial_decoder_alpha)

        # define criterions
        self.duration_criterion = DurationPredictorLoss()
        self.energy_criterion = EnergyPredictorLoss()
        self.pitch_criterion = PitchPredictorLoss()
        self.criterion = torch.nn.L1Loss(reduction='mean')

    def _forward(self,
                 xs,
                 ilens,
                 ys=None,
                 olens=None,
                 ds=None,
                 es=None,
                 ps=None,
                 is_inference=False):
        # forward encoder
        x_masks = self._source_mask(ilens)
        hs, _ = self.encoder(xs, x_masks)  # (B, Tmax, adim)

        # forward duration predictor and length regulator
        d_masks = make_pad_mask(ilens).to(xs.device)
        if is_inference:
            d_outs = self.duration_predictor.inference(hs, d_masks)  # (B, Tmax)
            hs = self.length_regulator(hs, d_outs, ilens)  # (B, Lmax, adim)
            e_outs = self.energy_predictor.inference(hs)
            p_outs = self.pitch_predictor.inference(hs)
            one_hot_energy = energy_to_one_hot(e_outs, False)  # (B, Lmax, adim)
            one_hot_pitch = pitch_to_one_hot(p_outs, False)  # (B, Lmax, adim)
        else:
            with torch.no_grad():
                # ds = self.duration_calculator(xs, ilens, ys, olens)  # (B, Tmax)
                one_hot_energy = energy_to_one_hot(es)  # (B, Lmax, adim)
                one_hot_pitch = pitch_to_one_hot(ps)  # (B, Lmax, adim)
            mel_masks = make_pad_mask(olens).to(xs.device)
            d_outs = self.duration_predictor(hs, d_masks)  # (B, Tmax)
            hs = self.length_regulator(hs, ds, ilens)  # (B, Lmax, adim)
            e_outs = self.energy_predictor(hs, mel_masks)  # (B, Lmax)
            p_outs = self.pitch_predictor(hs, mel_masks)  # (B, Lmax)

        # add pitch and energy embeddings
        hs = hs + self.pitch_embed(one_hot_pitch)  # (B, Lmax, adim)
        hs = hs + self.energy_embed(one_hot_energy)  # (B, Lmax, adim)

        # forward decoder
        if olens is not None:
            h_masks = self._source_mask(olens)
        else:
            h_masks = None
        zs, _ = self.decoder(hs, h_masks)  # (B, Lmax, adim)
        before_outs = self.feat_out(zs).view(zs.size(0), -1,
                                             self.odim)  # (B, Lmax, odim)

        # postnet -> (B, Lmax//r * r, odim)
        if self.postnet is None:
            after_outs = before_outs
        else:
            after_outs = before_outs + self.postnet(
                before_outs.transpose(1, 2)).transpose(1, 2)

        if is_inference:
            return before_outs, after_outs, d_outs
        else:
            return before_outs, after_outs, d_outs, e_outs, p_outs

    def forward(self, xs, ilens, ys, olens, ds, es, ps, *args, **kwargs):
        """Calculate forward propagation.

        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            olens (LongTensor): Batch of the lengths of each target (B,).
            ds (LongTensor): Batch of ground-truth durations (B, Tmax).
            es (Tensor): Batch of ground-truth energies (B, Lmax).
            ps (Tensor): Batch of ground-truth pitches (B, Lmax).

        Returns:
            Tensor: Loss value.

""" # remove unnecessary padded part (for multi-gpus) xs = xs[:, :max(ilens)] ys = ys[:, :max(olens)] # forward propagation before_outs, after_outs, d_outs, e_outs, p_outs = self._forward( xs, ilens, ys, olens, ds, es, ps, is_inference=False) # modifiy mod part of groundtruth if hp.reduction_factor > 1: olens = olens.new( [olen - olen % self.reduction_factor for olen in olens]) max_olen = max(olens) ys = ys[:, :max_olen] # apply mask to remove padded part if self.use_masking: in_masks = make_non_pad_mask(ilens).to(xs.device) d_outs = d_outs.masked_select(in_masks) ds = ds.masked_select(in_masks) out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device) mel_masks = make_non_pad_mask(olens).to(ys.device) before_outs = before_outs.masked_select(out_masks) es = es.masked_select(mel_masks) # Write size ps = ps.masked_select(mel_masks) # Write size e_outs = e_outs.masked_select(mel_masks) # Write size p_outs = p_outs.masked_select(mel_masks) # Write size after_outs = (after_outs.masked_select(out_masks) if after_outs is not None else None) ys = ys.masked_select(out_masks) # calculate loss before_loss = self.criterion(before_outs, ys) after_loss = 0 if after_outs is not None: after_loss = self.criterion(after_outs, ys) l1_loss = before_loss + after_loss duration_loss = self.duration_criterion(d_outs, ds) energy_loss = self.energy_criterion(e_outs, es) pitch_loss = self.pitch_criterion(p_outs, ps) # make weighted mask and apply it if hp.use_weighted_masking: out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device) out_weights = out_masks.float() / out_masks.sum( dim=1, keepdim=True).float() out_weights /= ys.size(0) * ys.size(2) duration_masks = make_non_pad_mask(ilens).to(ys.device) duration_weights = ( duration_masks.float() / duration_masks.sum(dim=1, keepdim=True).float()) duration_weights /= ds.size(0) # apply weight l1_loss = l1_loss.mul(out_weights).masked_select(out_masks).sum() duration_loss = (duration_loss.mul(duration_weights).masked_select( duration_masks).sum()) loss = l1_loss + duration_loss + energy_loss + pitch_loss report_keys = [ { "l1_loss": l1_loss.item() }, { "before_loss": before_loss.item() }, { "after_loss": after_loss.item() }, { "duration_loss": duration_loss.item() }, { "energy_loss": energy_loss.item() }, { "pitch_loss": pitch_loss.item() }, { "loss": loss.item() }, ] # report extra information if self.use_scaled_pos_enc: report_keys += [ { "encoder_alpha": self.encoder.embed[-1].alpha.data.item() }, { "decoder_alpha": self.decoder.embed[-1].alpha.data.item() }, ] #self.reporter.report(report_keys) return loss, report_keys def calculate_all_attentions(self, xs, ilens, ys, olens, ds, es, ps, *args, **kwargs): """Calculate all of the attention weights. Args: xs (Tensor): Batch of padded character ids (B, Tmax). ilens (LongTensor): Batch of lengths of each input batch (B,). ys (Tensor): Batch of padded target features (B, Lmax, odim). olens (LongTensor): Batch of the lengths of each target (B,). spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim). Returns: dict: Dict of attention weights and outputs. 
""" with torch.no_grad(): # remove unnecessary padded part (for multi-gpus) xs = xs[:, :max(ilens)] ys = ys[:, :max(olens)] # forward propagation outs, _, _, _, _ = self._forward(xs, ilens, ys, olens, ds, es, ps, is_inference=False) att_ws_dict = dict() if hp.attn_plot: for name, m in self.named_modules(): if isinstance(m, MultiHeadedAttention): attn = m.attn.cpu().numpy() if "encoder" in name: attn = [ a[:, :l, :l] for a, l in zip(attn, ilens.tolist()) ] elif "decoder" in name: if "src" in name: attn = [ a[:, :ol, :il] for a, il, ol in zip( attn, ilens.tolist(), olens.tolist()) ] elif "self" in name: attn = [ a[:, :l, :l] for a, l in zip(attn, olens.tolist()) ] else: logging.warning("unknown attention module: " + name) else: logging.warning("unknown attention module: " + name) att_ws_dict[name] = attn att_ws_dict["predicted_fbank"] = [ m[:l].T for m, l in zip(outs.cpu().numpy(), olens.tolist()) ] return att_ws_dict def inference(self, x, inference_args, *args, **kwargs): """Generate the sequence of features given the sequences of characters. Args: x (Tensor): Input sequence of characters (T,). inference_args (Namespace): Dummy for compatibility. spemb (Tensor, optional): Speaker embedding vector (spk_embed_dim). Returns: Tensor: Output sequence of features (1, L, odim). None: Dummy for compatibility. None: Dummy for compatibility. """ # setup batch axis ilens = torch.tensor([x.shape[0]], dtype=torch.long, device=x.device) xs = x.unsqueeze(0) # inference _, outs, _ = self._forward(xs, ilens, is_inference=True) # (L, odim) return outs[0], None, None def _source_mask(self, ilens): """Make masks for self-attention. Examples: >>> ilens = [5, 3] >>> self._source_mask(ilens) tensor([[[1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1]], [[1, 1, 1, 0, 0], [1, 1, 1, 0, 0], [1, 1, 1, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]], dtype=torch.uint8) """ x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device) return x_masks.unsqueeze(-2) & x_masks.unsqueeze(-1) def _reset_parameters(self, init_type, init_enc_alpha=1.0, init_dec_alpha=1.0): # initialize parameters initialize(self, init_type) # initialize alpha in scaled positional encoding if self.use_scaled_pos_enc: self.encoder.embed[-1].alpha.data = torch.tensor(init_enc_alpha) self.decoder.embed[-1].alpha.data = torch.tensor(init_dec_alpha) def _transfer_from_teacher(self, transferred_encoder_module): if transferred_encoder_module == "all": for (n1, p1), (n2, p2) in zip(self.encoder.named_parameters(), self.teacher.encoder.named_parameters()): assert n1 == n2, "It seems that encoder structure is different." assert p1.shape == p2.shape, "It seems that encoder size is different." p1.data.copy_(p2.data) elif transferred_encoder_module == "embed": student_shape = self.encoder.embed[0].weight.data.shape teacher_shape = self.teacher.encoder.embed[0].weight.data.shape assert student_shape == teacher_shape, "It seems that embed dimension is different." self.encoder.embed[0].weight.data.copy_( self.teacher.encoder.embed[0].weight.data) else: raise NotImplementedError("Support only all or embed.") @property def attention_plot_class(self): """Return plot class for attention weight plot.""" return TTSPlot @property def base_plot_keys(self): """Return base key names to plot during training. keys should match what `chainer.reporter` reports. If you add the key `loss`, the reporter will report `main/loss` and `validation/main/loss` values. 
        Also `loss.png` will be created as a figure visualizing `main/loss`
        and `validation/main/loss` values.

        Returns:
            list: List of strings which are base keys to plot during training.

        """
        plot_keys = ["loss", "l1_loss", "duration_loss"]
        if self.use_scaled_pos_enc:
            plot_keys += ["encoder_alpha", "decoder_alpha"]

        return plot_keys
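# A minimal usage sketch for the FeedForwardTransformer class above (not part
# of the original code): it assumes the global `hp` hyperparameters have
# already been configured, and the sizes (idim=75 phoneme ids, odim=80 mel
# bins) as well as the random `phoneme_ids` tensor are purely illustrative.
import torch

model = FeedForwardTransformer(idim=75, odim=80)
model.eval()

# a single phoneme-id sequence of length T (ids > 0, since 0 is the padding idx)
phoneme_ids = torch.randint(1, 75, (42,), dtype=torch.long)

with torch.no_grad():
    # inference() adds the batch axis itself and returns (L, odim) features
    mel, _, _ = model.inference(phoneme_ids, None)
print(mel.shape)  # (L, 80), where L is the sum of predicted phoneme durations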
def __init__(self, idim, odim, args=None):
    """Initialize TTS-Transformer module.

    Args:
        idim (int): Dimension of the inputs.
        odim (int): Dimension of the outputs.

    """
    # initialize base classes
    torch.nn.Module.__init__(self)

    # store hyperparameters
    self.idim = idim
    self.odim = odim
    self.use_scaled_pos_enc = hp.use_scaled_pos_enc
    self.reduction_factor = hp.reduction_factor
    self.loss_type = "L1"
    self.use_guided_attn_loss = True
    if self.use_guided_attn_loss:
        if hp.num_layers_applied_guided_attn == -1:
            self.num_layers_applied_guided_attn = hp.elayers
        else:
            self.num_layers_applied_guided_attn = hp.num_layers_applied_guided_attn
        if hp.num_heads_applied_guided_attn == -1:
            self.num_heads_applied_guided_attn = hp.aheads
        else:
            self.num_heads_applied_guided_attn = hp.num_heads_applied_guided_attn
        self.modules_applied_guided_attn = hp.modules_applied_guided_attn

    # use idx 0 as padding idx
    padding_idx = 0

    # get positional encoding class
    pos_enc_class = (ScaledPositionalEncoding
                     if self.use_scaled_pos_enc else PositionalEncoding)

    # define encoder
    encoder_input_layer = torch.nn.Embedding(num_embeddings=idim,
                                             embedding_dim=hp.adim,
                                             padding_idx=padding_idx)
    self.encoder = Encoder(
        idim=idim,
        attention_dim=hp.adim,
        attention_heads=hp.aheads,
        linear_units=hp.eunits,
        input_layer=encoder_input_layer,
        dropout_rate=hp.transformer_enc_dropout_rate,
        positional_dropout_rate=hp.transformer_enc_positional_dropout_rate,
        attention_dropout_rate=hp.transformer_enc_attn_dropout_rate,
        pos_enc_class=pos_enc_class,
        normalize_before=hp.encoder_normalize_before,
        concat_after=hp.encoder_concat_after)

    # define core decoder
    if hp.dprenet_layers != 0:
        # decoder prenet
        decoder_input_layer = torch.nn.Sequential(
            DecoderPrenet(idim=odim,
                          n_layers=hp.dprenet_layers,
                          n_units=hp.dprenet_units,
                          dropout_rate=hp.dprenet_dropout_rate),
            torch.nn.Linear(hp.dprenet_units, hp.adim))
    else:
        decoder_input_layer = "linear"
    self.decoder = Decoder(
        odim=-1,
        attention_dim=hp.adim,
        attention_heads=hp.aheads,
        linear_units=hp.dunits,
        dropout_rate=hp.transformer_dec_dropout_rate,
        positional_dropout_rate=hp.transformer_dec_positional_dropout_rate,
        self_attention_dropout_rate=hp.transformer_dec_attn_dropout_rate,
        src_attention_dropout_rate=hp.transformer_enc_dec_attn_dropout_rate,
        input_layer=decoder_input_layer,
        use_output_layer=False,
        pos_enc_class=pos_enc_class,
        normalize_before=hp.decoder_normalize_before,
        concat_after=hp.decoder_concat_after)

    # define final projections
    self.feat_out = torch.nn.Linear(hp.adim, odim * hp.reduction_factor)
    self.prob_out = torch.nn.Linear(hp.adim, hp.reduction_factor)

    # define postnet
    self.postnet = None if hp.postnet_layers == 0 else Postnet(
        idim=idim,
        odim=odim,
        n_layers=hp.postnet_layers,
        n_chans=hp.postnet_chans,
        n_filts=hp.postnet_filts,
        use_batch_norm=hp.use_batch_norm,
        dropout_rate=hp.postnet_dropout_rate)

    # define loss function
    self.criterion = TransformerLoss(use_masking=hp.use_masking,
                                     bce_pos_weight=hp.bce_pos_weight)
    if self.use_guided_attn_loss:
        self.attn_criterion = GuidedMultiHeadAttentionLoss(sigma=0.4, alpha=1.0)

    # initialize parameters
    self._reset_parameters(init_type=hp.transformer_init,
                           init_enc_alpha=hp.initial_encoder_alpha,
                           init_dec_alpha=hp.initial_decoder_alpha)
def train(opt):
    loader = get_loader(opt, 'train')
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    summary_writer = tensorboardX.SummaryWriter()

    infos = {}
    histories = {}
    if opt.start_from is not None:
        infos_path = os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')
        histories_path = os.path.join(opt.start_from,
                                      'histories_' + opt.id + '.pkl')
        # open infos and check if models are compatible
        with open(infos_path, 'rb') as f:
            infos = pickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = ['hidden_size']
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], \
                    "Command line argument and saved model disagree on %s" % (checkme)
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = pickle.load(f)

    iteration = infos.get('iter', 0)
    current_epoch = infos.get('epoch', 0)
    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})

    # default to None when not resuming from a saved best score
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get("best_val_score", None)

    encoder = Encoder()
    decoder = Decoder(opt)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    encoder = encoder.to(device)
    decoder = decoder.to(device)

    criterion = utils.LanguageModelCriterion().to(device)
    optimizer = optim.Adam(decoder.parameters(),
                           lr=opt.learning_rate,
                           weight_decay=opt.weight_decay)
    if vars(opt).get('start_from', None) is not None:
        optimizer_path = os.path.join(opt.start_from, 'optimizer.pth')
        optimizer.load_state_dict(torch.load(optimizer_path))

    total_step = len(loader)
    start = time.time()
    for epoch in range(current_epoch, opt.max_epochs):
        # decay the learning rate once the decay schedule has started
        if epoch > opt.learning_rate_decay_start and \
                opt.learning_rate_decay_start >= 0:
            frac = (epoch - opt.learning_rate_decay_start) \
                // opt.learning_rate_decay_every
            decay_factor = opt.learning_rate_decay_rate ** frac
            opt.current_lr = opt.learning_rate * decay_factor
            utils.set_lr(optimizer, opt.current_lr)
            print("learning rate changed from {} to {}".format(
                opt.learning_rate, opt.current_lr))
        else:
            opt.current_lr = opt.learning_rate

        for i, data in enumerate(loader, iteration):
            if i > total_step - 1:
                iteration = 0
                break

            # normalize images with ImageNet statistics
            transform = transforms.Normalize((0.485, 0.456, 0.406),
                                             (0.229, 0.224, 0.225))
            imgs = []
            for k in range(data['imgs'].shape[0]):
                img = torch.tensor(data['imgs'][k], dtype=torch.float)
                img = transform(img)
                imgs.append(img)
            imgs = torch.stack(imgs, dim=0).to(device)

            labels = torch.tensor(data['labels'].astype(np.int32),
                                  dtype=torch.long).to(device)
            masks = torch.tensor(data['masks'], dtype=torch.float).to(device)

            # the CNN encoder is frozen; only the decoder is trained
            with torch.no_grad():
                features = encoder(imgs)
            preds = decoder(features, labels)
            loss = criterion(preds, labels[:, 1:], masks[:, 1:])

            optimizer.zero_grad()
            loss.backward()
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()

            train_loss = loss.item()
            print("iter: {}/{} (epoch {}), train loss = {:.3f}, time/batch = {}"
                  .format(i, total_step, epoch, train_loss,
                          utils.get_duration(start)))

            log_iter = i + epoch * total_step
            # write training loss summary
            if (i % opt.losses_log_every) == 0:
                summary_writer.add_scalar('train_loss', train_loss, log_iter)
                summary_writer.add_scalar('learning_rate', opt.current_lr,
                                          log_iter)

            # make evaluation on validation set, and save model
            if (i % opt.save_checkpoint_every == 0):
                # eval model
                eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
                eval_kwargs.update(vars(opt))
                val_loss, predictions, lang_stats = eval_utils.eval_split(
                    encoder, decoder, criterion, opt, eval_kwargs)
                summary_writer.add_scalar('validation_loss', val_loss,
                                          log_iter)
                if lang_stats is not None:
                    for metric, score in lang_stats.items():
                        summary_writer.add_scalar(metric, score, log_iter)

                val_result_history[i] = {
                    "loss": val_loss,
                    "lang_stats": lang_stats,
                    "predictions": predictions
                }

                if opt.language_eval == 1:
                    current_score = lang_stats['CIDEr']
                else:
                    current_score = -val_loss.item()

                best_flag = False
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True

                if not os.path.exists(opt.checkpoint_path):
                    os.makedirs(opt.checkpoint_path)
                checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
                torch.save(decoder.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # dump miscellaneous information
                infos['iter'] = i + 1
                infos['epoch'] = epoch
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.ix_to_word
                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history

                infos_path = os.path.join(opt.checkpoint_path,
                                          'infos_' + opt.id + '.pkl')
                histories_path = os.path.join(opt.checkpoint_path,
                                              'histories_' + opt.id + '.pkl')
                with open(infos_path, 'wb') as f:
                    pickle.dump(infos, f)
                    print("infos saved into {}".format(infos_path))
                with open(histories_path, 'wb') as f:
                    pickle.dump(histories, f)
                    print('histories saved into {}'.format(histories_path))

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(decoder.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        pickle.dump(infos, f)

    summary_writer.close()