def forward(self, xs, ilens, ys, olens, spembs=None):
    """Calculate forward propagation.

    Args:
        xs (Tensor): Batch of the padded sequences of character ids (B, Tmax).
        ilens (Tensor): Batch of lengths of each input sequence (B,).
        ys (Tensor): Batch of the padded sequence of target features (B, Lmax, odim).
        olens (Tensor): Batch of lengths of each output sequence (B,).
        spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim).

    Returns:
        Tensor: Batch of durations (B, Tmax).

    """
    if isinstance(self.teacher_model, Transformer):
        att_ws = self._calculate_encoder_decoder_attentions(
            xs, ilens, ys, olens, spembs=spembs
        )
        # TODO(kan-bayashi): fix this issue
        # this does not work in multi-gpu case. registered buffer is not saved.
        if int(self.diag_head_idx) == -1:
            self._init_diagonal_head(att_ws)
        att_ws = att_ws[:, self.diag_head_idx]
    else:
        # NOTE(kan-bayashi): Here we assume that the teacher is tacotron 2
        att_ws = self.teacher_model.calculate_all_attentions(
            xs, ilens, ys, spembs=spembs, keep_tensor=True
        )
    durations = [
        self._calculate_duration(att_w, ilen, olen)
        for att_w, ilen, olen in zip(att_ws, ilens, olens)
    ]
    return pad_list(durations, 0)
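All of the snippets on this page lean on pad_list to turn a list of variable-length tensors into one right-padded batch tensor. A minimal sketch of that behaviour, assuming the first axis of each tensor is the length axis (an illustration, not the library implementation itself):

import torch

def pad_list_sketch(xs, pad_value):
    """Right-pad a list of tensors (Li, ...) into one tensor (B, Lmax, ...)."""
    n_batch = len(xs)
    max_len = max(x.size(0) for x in xs)
    pad = xs[0].new_full((n_batch, max_len, *xs[0].shape[1:]), pad_value)
    for i, x in enumerate(xs):
        pad[i, : x.size(0)] = x
    return pad

# e.g. pad_list_sketch([torch.ones(3), torch.ones(5)], 0.0).shape == (2, 5)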
def common_collate_fn(
    data: Collection[Tuple[str, Dict[str, np.ndarray]]],
    float_pad_value: Union[float, int] = 0.0,
    int_pad_value: int = -32768,
    not_sequence: Collection[str] = (),
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
    """Concatenate an ndarray-list into an array and convert it to torch.Tensor.

    Examples:
        >>> from muskit.samplers.constant_batch_sampler import ConstantBatchSampler
        >>> import muskit.tasks.abs_task
        >>> from muskit.train.dataset import MuskitDataset
        >>> sampler = ConstantBatchSampler(...)
        >>> dataset = MuskitDataset(...)
        >>> keys = next(iter(sampler))
        >>> batch = [dataset[key] for key in keys]
        >>> batch = common_collate_fn(batch)
        >>> model(**batch)

        Note that the dict-keys of the batch are propagated from those of the dataset as they are.

    """
    assert check_argument_types()
    uttids = [u for u, _ in data]
    data = [d for _, d in data]

    assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
    assert all(
        not k.endswith("_lengths") for k in data[0]
    ), f"*_lengths is reserved: {list(data[0])}"

    output = {}
    for key in data[0]:
        # NOTE(kamo):
        # Each model that finally consumes these values is responsible
        # for remapping the pad_value to the value it expects for its task.
        if data[0][key].dtype.kind == "i":
            pad_value = int_pad_value
        else:
            pad_value = float_pad_value

        array_list = [d[key] for d in data]

        # Assume the first axis is length:
        # tensor_list: Batch x (Length, ...)
        tensor_list = [torch.from_numpy(a) for a in array_list]
        # tensor: (Batch, Length, ...)
        tensor = pad_list(tensor_list, pad_value)
        output[key] = tensor

        # lens: (Batch,)
        if key not in not_sequence:
            lens = torch.tensor([d[key].shape[0] for d in data], dtype=torch.long)
            output[key + "_lengths"] = lens

    output = (uttids, output)
    # logging.info(f'output:{output}')
    # TODO: allow the tuple type
    # assert check_return_type(output)
    return output
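A small usage sketch with a made-up two-utterance batch: both "feats" arrays are right-padded to the longer length and a "feats_lengths" entry is added automatically.

import numpy as np

# Hypothetical toy batch: two utterances with variable-length features.
batch = [
    ("utt1", {"feats": np.zeros((5, 80), dtype=np.float32)}),
    ("utt2", {"feats": np.zeros((3, 80), dtype=np.float32)}),
]
uttids, tensors = common_collate_fn(batch)
# uttids == ["utt1", "utt2"]
# tensors["feats"].shape == (2, 5, 80)        right-padded with float_pad_value
# tensors["feats_lengths"].tolist() == [5, 3]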
def add_sos_eos(ys_pad, sos, eos, ignore_id):
    """Add <sos> and <eos> labels.

    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :param int sos: index of <sos>
    :param int eos: index of <eos>
    :param int ignore_id: index of padding
    :return: padded tensor of decoder inputs with <sos> prepended (B, Lmax)
    :rtype: torch.Tensor
    :return: padded tensor of decoder targets with <eos> appended (B, Lmax)
    :rtype: torch.Tensor
    """
    from muskit.torch_utils.nets_utils import pad_list

    _sos = ys_pad.new([sos])
    _eos = ys_pad.new([eos])
    ys = [y[y != ignore_id] for y in ys_pad]  # parse padded ys
    ys_in = [torch.cat([_sos, y], dim=0) for y in ys]
    ys_out = [torch.cat([y, _eos], dim=0) for y in ys]
    return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)
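A small illustrative call (token ids are made up; -1 plays the role of ignore_id):

import torch

ys_pad = torch.tensor([[4, 5, 6], [7, 8, -1]])   # batch of padded targets, -1 = padding
ys_in, ys_out = add_sos_eos(ys_pad, sos=1, eos=2, ignore_id=-1)
# ys_in:  [[1, 4, 5, 6], [1, 7, 8, 2]]   <sos> prepended, padded with eos
# ys_out: [[4, 5, 6, 2], [7, 8, 2, -1]]  <eos> appended, padded with ignore_id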
def forward(
    self,
    input: torch.Tensor,
    input_lengths: torch.Tensor = None,
    feats_lengths: torch.Tensor = None,
    durations: torch.Tensor = None,
    durations_lengths: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # If not provided, assume that all inputs have the same length
    if input_lengths is None:
        input_lengths = (
            input.new_ones(input.shape[0], dtype=torch.long) * input.shape[1]
        )

    # Domain-conversion: e.g. Stft: time -> time-freq
    input_stft, energy_lengths = self.stft(input, input_lengths)

    assert input_stft.dim() >= 4, input_stft.shape
    assert input_stft.shape[-1] == 2, input_stft.shape

    # input_stft: (..., F, 2) -> (..., F)
    input_power = input_stft[..., 0] ** 2 + input_stft[..., 1] ** 2
    # sum over frequency (B, N, F) -> (B, N)
    energy = torch.sqrt(torch.clamp(input_power.sum(dim=2), min=1.0e-10))

    # (Optional): Adjust length to match with the mel-spectrogram
    if feats_lengths is not None:
        energy = [
            self._adjust_num_frames(e[:el].view(-1), fl)
            for e, el, fl in zip(energy, energy_lengths, feats_lengths)
        ]
        energy_lengths = feats_lengths

    # (Optional): Average by duration to calculate token-wise energy
    if self.use_token_averaged_energy:
        durations = durations * self.reduction_factor
        energy = [
            self._average_by_duration(e[:el].view(-1), d)
            for e, el, d in zip(energy, energy_lengths, durations)
        ]
        energy_lengths = durations_lengths

    # Padding
    if isinstance(energy, list):
        energy = pad_list(energy, 0.0)

    # Return with the shape (B, T, 1)
    return energy.unsqueeze(-1), energy_lengths
def forward(
    self,
    input: torch.Tensor,
    input_lengths: torch.Tensor = None,
    feats_lengths: torch.Tensor = None,
    durations: torch.Tensor = None,
    durations_lengths: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # If not provided, assume that all inputs have the same length
    if input_lengths is None:
        input_lengths = (
            input.new_ones(input.shape[0], dtype=torch.long) * input.shape[1]
        )

    # F0 extraction
    pitch = [self._calculate_f0(x[:xl]) for x, xl in zip(input, input_lengths)]

    # (Optional): Adjust length to match with the mel-spectrogram
    if feats_lengths is not None:
        pitch = [
            self._adjust_num_frames(p, fl).view(-1)
            for p, fl in zip(pitch, feats_lengths)
        ]

    # (Optional): Average by duration to calculate token-wise f0
    if self.use_token_averaged_f0:
        durations = durations * self.reduction_factor
        pitch = [
            self._average_by_duration(p, d).view(-1)
            for p, d in zip(pitch, durations)
        ]
        pitch_lengths = durations_lengths
    else:
        pitch_lengths = input.new_tensor([len(p) for p in pitch], dtype=torch.long)

    # Padding
    pitch = pad_list(pitch, 0.0)

    # Return with the shape (B, T, 1)
    return pitch.unsqueeze(-1), pitch_lengths
def forward(self, xs, ds, alpha=1.0):
    """Calculate forward propagation.

    Args:
        xs (Tensor): Batch of sequences of char or phoneme embeddings (B, Tmax, D).
        ds (LongTensor): Batch of durations of each frame (B, T).
        alpha (float, optional): Alpha value to control speed of speech.

    Returns:
        Tensor: Replicated input tensor based on durations (B, T*, D).

    """
    if alpha != 1.0:
        assert alpha > 0
        ds = torch.round(ds.float() * alpha).long()

    if ds.sum() == 0:
        logging.warning(
            "predicted durations include an all-zero sequence; "
            "filling the first element with 1."
        )
        # NOTE(kan-bayashi): This case must not happen in teacher forcing.
        # It can happen at inference time with a bad duration predictor,
        # so we do not need to care about the padded sequence case here.
        ds[ds.sum(dim=1).eq(0)] = 1

    repeat = [torch.repeat_interleave(x, d, dim=0) for x, d in zip(xs, ds)]
    return pad_list(repeat, self.pad_value)
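A toy illustration of the regulator's core step (shapes and values are made up): each embedding is repeated according to its duration, then the variable-length results are padded back into a batch.

import torch

xs = torch.arange(6.0).view(2, 3, 1)        # (B=2, Tmax=3, D=1)
ds = torch.tensor([[2, 1, 0], [1, 3, 2]])   # per-token durations

repeat = [torch.repeat_interleave(x, d, dim=0) for x, d in zip(xs, ds)]
# repeat[0].shape == (3, 1), repeat[1].shape == (6, 1)
out = pad_list(repeat, 0.0)                 # (B=2, T*=6, D=1)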
def forward(self, x: torch.Tensor, x_lengths: torch.Tensor = None):
    """Forward function.

    Args:
        x: (Batch, Time, Freq)
        x_lengths: (Batch,)
    """
    if x_lengths is None or all(le == x_lengths[0] for le in x_lengths):
        # Note that the same warping is applied to every sample in the batch
        y = time_warp(x, window=self.window, mode=self.mode)
    else:
        # FIXME(kamo): I have no idea how to batchify TimeWarp
        ys = []
        for i in range(x.size(0)):
            _y = time_warp(
                x[i][None, : x_lengths[i]],
                window=self.window,
                mode=self.mode,
            )[0]
            ys.append(_y)
        y = pad_list(ys, 0.0)

    return y, x_lengths
def forward(
    self,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    singing: torch.Tensor,
    singing_lengths: torch.Tensor,
    durations: Optional[torch.Tensor] = None,
    durations_lengths: Optional[torch.Tensor] = None,
    score: Optional[torch.Tensor] = None,
    score_lengths: Optional[torch.Tensor] = None,
    pitch: Optional[torch.Tensor] = None,
    pitch_lengths: Optional[torch.Tensor] = None,
    tempo: Optional[torch.Tensor] = None,
    tempo_lengths: Optional[torch.Tensor] = None,
    energy: Optional[torch.Tensor] = None,
    energy_lengths: Optional[torch.Tensor] = None,
    spembs: Optional[torch.Tensor] = None,
    sids: Optional[torch.Tensor] = None,
    lids: Optional[torch.Tensor] = None,
    flag_IsValid=False,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    """Calculate outputs and return the loss tensor.

    Args:
        text (Tensor): Text index tensor (B, T_text).
        text_lengths (Tensor): Text length tensor (B,).
        singing (Tensor): Singing waveform tensor (B, T_wav).
        singing_lengths (Tensor): Singing length tensor (B,).
        durations (Optional[Tensor]): Duration tensor (phone-id sequence).
        durations_lengths (Optional[Tensor]): Duration length tensor (B,).
        score (Optional[Tensor]): Score (MIDI pitch) tensor.
        score_lengths (Optional[Tensor]): Score length tensor (B,).
        pitch (Optional[Tensor]): Pitch tensor.
        pitch_lengths (Optional[Tensor]): Pitch length tensor (B,).
        tempo (Optional[Tensor]): Tempo tensor.
        tempo_lengths (Optional[Tensor]): Tempo length tensor (B,).
        energy (Optional[Tensor]): Energy tensor.
        energy_lengths (Optional[Tensor]): Energy length tensor (B,).
        spembs (Optional[Tensor]): Speaker embedding tensor (B, D).
        sids (Optional[Tensor]): Speaker ID tensor (B, 1).
        lids (Optional[Tensor]): Language ID tensor (B, 1).

    Returns:
        Tensor: Loss scalar tensor.
        Dict[str, float]: Statistics to be monitored.
        Tensor: Weight tensor to summarize losses.
    """
    with autocast(False):
        # if self.text_extract is not None and text is None:
        #     text, text_lengths = self.text_extract(
        #         input=text,
        #         input_lengths=text_lengths,
        #     )

        # Extract features
        # logging.info(f'singing.shape={singing.shape}, singing_lengths.shape={singing_lengths.shape}')
        if self.feats_extract is not None:
            feats, feats_lengths = self.feats_extract(
                singing, singing_lengths
            )  # singing to spec feature (frame level)
        else:
            # Use precalculated feats (feats_type != raw case)
            feats, feats_lengths = singing, singing_lengths

        # Extract auxiliary features
        # score : 128 midi pitch
        # tempo : bpm
        # duration : input -> phone-id sequence
        #            output -> frame level (mode over window) or syllable level
        ds = None
        if isinstance(self.score_feats_extract, FrameScoreFeats):
            (
                label,
                label_lengths,
                score,
                score_lengths,
                tempo,
                tempo_lengths,
            ) = self.score_feats_extract(
                durations=durations.unsqueeze(-1),
                durations_lengths=durations_lengths,
                score=score.unsqueeze(-1),
                score_lengths=score_lengths,
                tempo=tempo.unsqueeze(-1),
                tempo_lengths=tempo_lengths,
            )
            label = label[:, : label_lengths.max()]  # for data-parallel

            # calculate durations, new text & text_lengths
            # Syllable-level duration info needs phone
            # NOTE(Shuai) Duplicate adjacent phones will appear in text files sometimes
            # e.g. oniku_0000000000000000hato_0002
            #   10.951 11.107 sh
            #   11.107 11.336 i
            #   11.336 11.610 i
            #   11.610 11.657 k
            _text_cal = []
            _text_length_cal = []
            ds = []
            for i, _ in enumerate(label_lengths):
                _phone = label[i, : label_lengths[i]]
                _output, counts = torch.unique_consecutive(
                    _phone, return_counts=True
                )
                _text_cal.append(_output)
                _text_length_cal.append(len(_output))
                ds.append(counts)
            ds = pad_list(ds, pad_value=0).to(text.device)
            text = pad_list(_text_cal, pad_value=0).to(
                text.device, dtype=torch.long
            )
            text_lengths = torch.tensor(_text_length_cal).to(text.device)
        elif isinstance(self.score_feats_extract, SyllableScoreFeats):
            extractMethod_frame = FrameScoreFeats(
                fs=self.score_feats_extract.fs,
                n_fft=self.score_feats_extract.n_fft,
                win_length=self.score_feats_extract.win_length,
                hop_length=self.score_feats_extract.hop_length,
                window=self.score_feats_extract.window,
                center=self.score_feats_extract.center,
            )
            # logging.info(f"extractMethod_frame: {extractMethod_frame}")
            (
                labelFrame,
                labelFrame_lengths,
                scoreFrame,
                scoreFrame_lengths,
                tempoFrame,
                tempoFrame_lengths,
            ) = extractMethod_frame(
                durations=durations.unsqueeze(-1),
                durations_lengths=durations_lengths,
                score=score.unsqueeze(-1),
                score_lengths=score_lengths,
                tempo=tempo.unsqueeze(-1),
                tempo_lengths=tempo_lengths,
            )

            labelFrame = labelFrame[:, : labelFrame_lengths.max()]  # for data-parallel
            scoreFrame = scoreFrame[:, : scoreFrame_lengths.max()]  # for data-parallel

            # Extract syllable-level label, score, tempo information from frame level
            (
                label,
                label_lengths,
                score,
                score_lengths,
                tempo,
                tempo_lengths,
            ) = self.score_feats_extract(
                durations=labelFrame,
                durations_lengths=labelFrame_lengths,
                score=scoreFrame,
                score_lengths=scoreFrame_lengths,
                tempo=tempoFrame,
                tempo_lengths=tempoFrame_lengths,
            )

            # calculate durations, i.e. the mapping from syllable-level encoder outputs to feats
            # Syllable-level duration info needs phone & midi
            ds = []
            for i, _ in enumerate(labelFrame_lengths):
                assert labelFrame_lengths[i] == scoreFrame_lengths[i]
                assert label_lengths[i] == score_lengths[i]

                frame_length = labelFrame_lengths[i]
                _phoneFrame = labelFrame[i, :frame_length]
                _midiFrame = scoreFrame[i, :frame_length]

                # Clean _phoneFrame & _midiFrame
                for index in range(frame_length):
                    if _phoneFrame[index] == 0 and _midiFrame[index] == 0:
                        frame_length -= 1
                        feats_lengths[i] -= 1

                syllable_length = label_lengths[i]
                _phoneSyllable = label[i, :syllable_length]
                _midiSyllable = score[i, :syllable_length]

                # logging.info(f"_phoneFrame: {_phoneFrame}, _midiFrame: {_midiFrame}")
                # logging.info(f"_phoneSyllable: {_phoneSyllable}, _midiSyllable: {_midiSyllable}, _tempoSyllable: {tempo[i]}")
                start_index = 0
                ds_tmp = []
                flag_finish = 0
                for index in range(syllable_length):
                    _findPhone = _phoneSyllable[index]
                    _findMidi = _midiSyllable[index]
                    _length = 0
                    if flag_finish == 1:
                        # Fix error in _phoneSyllable & _midiSyllable
                        label[i, index] = 0
                        score[i, index] = 0
                        tempo[i, index] = 0
                        label_lengths[i] -= 1
                        score_lengths[i] -= 1
                        tempo_lengths[i] -= 1
                    else:
                        for indexFrame in range(start_index, frame_length):
                            if (
                                _phoneFrame[indexFrame] == _findPhone
                                and _midiFrame[indexFrame] == _findMidi
                            ):
                                _length += 1
                            else:
                                # logging.info(f"_findPhone: {_findPhone}, _findMidi: {_findMidi}, _length: {_length}")
                                ds_tmp.append(_length)
                                start_index = indexFrame
                                break

                            if indexFrame == frame_length - 1:
                                # logging.info(f"_findPhone: {_findPhone}, _findMidi: {_findMidi}, _length: {_length}")
                                flag_finish = 1
                                ds_tmp.append(_length)
                                start_index = indexFrame
                                # logging.info("Finish")
                                break

                # logging.info(f"ds_tmp: {ds_tmp}, sum(ds_tmp): {sum(ds_tmp)}, frame_length: {frame_length}, feats_lengths[i]: {feats_lengths[i]}")
                assert (
                    sum(ds_tmp) == frame_length
                    and sum(ds_tmp) == feats_lengths[i]
                )

                ds.append(torch.tensor(ds_tmp))
            ds = pad_list(ds, pad_value=0).to(label.device)

        if self.pitch_extract is not None and pitch is None:
            pitch, pitch_lengths = self.pitch_extract(
                input=singing,
                input_lengths=singing_lengths,
                feats_lengths=feats_lengths,
            )
        if self.energy_extract is not None and energy is None:
            energy, energy_lengths = self.energy_extract(
                singing,
                singing_lengths,
                feats_lengths=feats_lengths,
                durations=durations,
                durations_lengths=durations_lengths,
            )

        # Normalize
        if self.normalize is not None:
            feats, feats_lengths = self.normalize(feats, feats_lengths)
        if self.pitch_normalize is not None:
            pitch, pitch_lengths = self.pitch_normalize(pitch, pitch_lengths)
        if self.energy_normalize is not None:
            energy, energy_lengths = self.energy_normalize(energy, energy_lengths)

    # Make batch for svs inputs
    batch = dict(
        text=text,
        text_lengths=text_lengths,
        feats=feats,
        feats_lengths=feats_lengths,
        flag_IsValid=flag_IsValid,
    )
    if spembs is not None:
        batch.update(spembs=spembs)
    if sids is not None:
        batch.update(sids=sids)
    if lids is not None:
        batch.update(lids=lids)
    if label is not None:
        label = label.to(dtype=torch.long)
        batch.update(label=label, label_lengths=label_lengths)
    if score is not None and pitch is None:
        score = score.to(dtype=torch.long)
        batch.update(midi=score, midi_lengths=score_lengths)
    if tempo is not None:
        tempo = tempo.to(dtype=torch.long)
        batch.update(tempo=tempo, tempo_lengths=tempo_lengths)
    if ds is not None:
        batch.update(ds=ds)
    if self.pitch_extract is not None and pitch is not None:
        batch.update(midi=pitch, midi_lengths=pitch_lengths)
    if self.energy_extract is not None and energy is not None:
        batch.update(energy=energy, energy_lengths=energy_lengths)
    if self.svs.require_raw_singing:
        batch.update(singing=singing, singing_lengths=singing_lengths)
    return self.svs(**batch)
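The duration computation above boils down to collapsing repeated frame-level labels with torch.unique_consecutive and keeping the run lengths as per-token durations. A toy illustration with made-up phone ids:

import torch

frame_label = torch.tensor([3, 3, 3, 7, 7, 9])            # frame-level phone ids
tokens, counts = torch.unique_consecutive(frame_label, return_counts=True)
# tokens == [3, 7, 9]    collapsed phone sequence (new text)
# counts == [3, 2, 1]    per-phone durations in frames (ds)
assert counts.sum() == len(frame_label)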
def inference(
    self,
    text: torch.Tensor,
    durations: torch.Tensor,
    score: torch.Tensor,
    singing: Optional[torch.Tensor] = None,
    pitch: Optional[torch.Tensor] = None,
    tempo: Optional[torch.Tensor] = None,
    energy: Optional[torch.Tensor] = None,
    spembs: Optional[torch.Tensor] = None,
    sids: Optional[torch.Tensor] = None,
    lids: Optional[torch.Tensor] = None,
    **decode_config,
) -> Dict[str, torch.Tensor]:
    """Calculate features and return them as a dict.

    Args:
        text (Tensor): Text index tensor (T_text).
        durations (Optional[Tensor]): Duration tensor.
        score (Optional[Tensor]): Score (MIDI pitch) tensor.
        singing (Tensor): Singing waveform tensor (T_wav).
        pitch (Optional[Tensor]): Pitch tensor.
        tempo (Optional[Tensor]): Tempo tensor.
        energy (Optional[Tensor]): Energy tensor.
        spembs (Optional[Tensor]): Speaker embedding tensor (D,).
        sids (Optional[Tensor]): Speaker ID tensor (1,).
        lids (Optional[Tensor]): Language ID tensor (1,).

    Returns:
        Dict[str, Tensor]: Dict of outputs.
    """
    durations_lengths = torch.tensor([len(durations)])
    score_lengths = torch.tensor([len(score)])
    tempo_lengths = torch.tensor([len(tempo)])

    assert durations_lengths == score_lengths and durations_lengths == tempo_lengths

    # unsqueeze of singing must be done here, or it will cause an error in the return dim of STFT
    text = text.unsqueeze(0)  # for data-parallel
    durations = durations.unsqueeze(0)  # for data-parallel
    score = score.unsqueeze(0)  # for data-parallel
    tempo = tempo.unsqueeze(0)  # for data-parallel

    # Extract auxiliary features
    # score : 128 midi pitch
    # tempo : bpm
    # duration : input -> phone-id sequence
    #            output -> frame level (mode over window) or syllable level
    ds = None
    batch_size = text.size(0)
    assert batch_size == 1
    if isinstance(self.score_feats_extract, FrameScoreFeats):
        (
            label,
            label_lengths,
            score,
            score_lengths,
            tempo,
            tempo_lengths,
        ) = self.score_feats_extract(
            durations=durations.unsqueeze(-1),
            durations_lengths=durations_lengths,
            score=score.unsqueeze(-1),
            score_lengths=score_lengths,
            tempo=tempo.unsqueeze(-1),
            tempo_lengths=tempo_lengths,
        )

        # calculate durations, new text & text_lengths
        # Syllable-level duration info needs phone
        # NOTE(Shuai) Duplicate adjacent phones will appear in text files sometimes
        # e.g. oniku_0000000000000000hato_0002
        #   10.951 11.107 sh
        #   11.107 11.336 i
        #   11.336 11.610 i
        #   11.610 11.657 k
        _text_cal = []
        _text_length_cal = []
        ds = []
        for i in range(batch_size):
            _phone = label[i]
            _output, counts = torch.unique_consecutive(_phone, return_counts=True)
            _text_cal.append(_output)
            _text_length_cal.append(len(_output))
            ds.append(counts)
        ds = pad_list(ds, pad_value=0).to(text.device)
        text = pad_list(_text_cal, pad_value=0).to(text.device, dtype=torch.long)
        text_lengths = torch.tensor(_text_length_cal).to(text.device)
    elif isinstance(self.score_feats_extract, SyllableScoreFeats):
        extractMethod_frame = FrameScoreFeats(
            fs=self.score_feats_extract.fs,
            n_fft=self.score_feats_extract.n_fft,
            win_length=self.score_feats_extract.win_length,
            hop_length=self.score_feats_extract.hop_length,
            window=self.score_feats_extract.window,
            center=self.score_feats_extract.center,
        )
        # logging.info(f"extractMethod_frame: {extractMethod_frame}")
        (
            labelFrame,
            labelFrame_lengths,
            scoreFrame,
            scoreFrame_lengths,
            tempoFrame,
            tempoFrame_lengths,
        ) = extractMethod_frame(
            durations=durations.unsqueeze(-1),
            durations_lengths=durations_lengths,
            score=score.unsqueeze(-1),
            score_lengths=score_lengths,
            tempo=tempo.unsqueeze(-1),
            tempo_lengths=tempo_lengths,
        )

        labelFrame = labelFrame[:, : labelFrame_lengths.max()]  # for data-parallel
        scoreFrame = scoreFrame[:, : scoreFrame_lengths.max()]  # for data-parallel

        # Extract syllable-level label, score, tempo information from frame level
        (
            label,
            label_lengths,
            score,
            score_lengths,
            tempo,
            tempo_lengths,
        ) = self.score_feats_extract(
            durations=labelFrame,
            durations_lengths=labelFrame_lengths,
            score=scoreFrame,
            score_lengths=scoreFrame_lengths,
            tempo=tempoFrame,
            tempo_lengths=tempoFrame_lengths,
        )

        # calculate durations, i.e. the mapping from syllable-level encoder outputs to feats
        # Syllable-level duration info needs phone & midi
        ds = []
        for i, _ in enumerate(labelFrame_lengths):
            assert labelFrame_lengths[i] == scoreFrame_lengths[i]
            assert label_lengths[i] == score_lengths[i]

            frame_length = labelFrame_lengths[i]
            _phoneFrame = labelFrame[i, :frame_length]
            _midiFrame = scoreFrame[i, :frame_length]

            # Clean _phoneFrame & _midiFrame
            for index in range(frame_length):
                if _phoneFrame[index] == 0 and _midiFrame[index] == 0:
                    frame_length -= 1
                    feats_lengths[i] -= 1

            syllable_length = label_lengths[i]
            _phoneSyllable = label[i, :syllable_length]
            _midiSyllable = score[i, :syllable_length]

            # logging.info(f"_phoneFrame: {_phoneFrame}, _midiFrame: {_midiFrame}")
            # logging.info(f"_phoneSyllable: {_phoneSyllable}, _midiSyllable: {_midiSyllable}, _tempoSyllable: {tempo[i]}")
            start_index = 0
            ds_tmp = []
            flag_finish = 0
            for index in range(syllable_length):
                _findPhone = _phoneSyllable[index]
                _findMidi = _midiSyllable[index]
                _length = 0
                if flag_finish == 1:
                    # Fix error in _phoneSyllable & _midiSyllable
                    label[i, index] = 0
                    score[i, index] = 0
                    tempo[i, index] = 0
                    label_lengths[i] -= 1
                    score_lengths[i] -= 1
                    tempo_lengths[i] -= 1
                else:
                    for indexFrame in range(start_index, frame_length):
                        if (
                            _phoneFrame[indexFrame] == _findPhone
                            and _midiFrame[indexFrame] == _findMidi
                        ):
                            _length += 1
                        else:
                            # logging.info(f"_findPhone: {_findPhone}, _findMidi: {_findMidi}, _length: {_length}")
                            ds_tmp.append(_length)
                            start_index = indexFrame
                            break

                        if indexFrame == frame_length - 1:
                            # logging.info(f"_findPhone: {_findPhone}, _findMidi: {_findMidi}, _length: {_length}")
                            flag_finish = 1
                            ds_tmp.append(_length)
                            start_index = indexFrame
                            # logging.info("Finish")
                            break

            logging.info(
                f"ds_tmp: {ds_tmp}, sum(ds_tmp): {sum(ds_tmp)}, "
                f"frame_length: {frame_length}, feats_lengths[i]: {feats_lengths[i]}"
            )
            assert sum(ds_tmp) == frame_length and sum(ds_tmp) == feats_lengths[i]

            ds.append(torch.tensor(ds_tmp))
        ds = pad_list(ds, pad_value=0).to(label.device)

    input_dict = dict(text=text)
    if score is not None and pitch is None:
        score = score.to(dtype=torch.long)
        input_dict["midi"] = score
    if durations is not None:
        label = label.to(dtype=torch.long)
        input_dict["label"] = label
    if ds is not None:
        input_dict.update(ds=ds)
    if tempo is not None:
        tempo = tempo.to(dtype=torch.long)
        input_dict.update(tempo=tempo)
    if spembs is not None:
        input_dict.update(spembs=spembs)
    if sids is not None:
        input_dict.update(sids=sids)
    if lids is not None:
        input_dict.update(lids=lids)

    # output_dict = self.svs.inference(**input_dict, **decode_config)
    outs, probs, att_ws = self.svs.inference(**input_dict)

    if self.normalize is not None:
        # NOTE: normalize.inverse is an in-place operation
        outs_denorm = self.normalize.inverse(outs.clone()[None])[0][0]
    else:
        outs_denorm = outs

    return outs, outs_denorm, probs, att_ws
def forward(
    self,
    feats,
    midi,
    label,
    spk_ids,
    feats_lengths,
    midi_lengths,
    label_lengths,
):
    """Forward propagation.

    Args:
        feats: Batch of features (B, T_feats, adim).
        midi: Batch of frame-level midi labels (B, T_feats).
        label: Batch of frame-level phone labels (B, T_feats).
        feats_lengths: Batch of input lengths (B,).

    Note that feats, midi, and label are time-aligned at the frame level.
    """
    h_masks = self._source_mask(feats_lengths)  # (B, 1, T_feats)
    midi_loss, label_loss, speaker_loss = 0, 0, 0
    masked_midi_outs, masked_label_outs, speaker_predict = None, None, None
    batch_size = feats.shape[0]

    if "midi" in self.predict_type:
        # midi predict
        zs_midi, _ = self.predictor_midi(feats, h_masks)  # (B, T_feats, adim=80)
        zs_midi = self.linear_out_midi(zs_midi)  # (B, T_feats, midi classes=129)

        # loss calculation
        if self.predict_criterion_type == "CrossEntropy":
            probs_midi = zs_midi  # (B, T_feats, midi classes=129)
            midi_loss = self.predictor_criterion(probs_midi, midi, feats_lengths)

            # make midi predict output for cycle
            masked_probs_midi = probs_midi * h_masks.permute(0, 2, 1)  # (B, T_feats, midi classes=129)
            masked_midi_outs = torch.argmax(
                F.softmax(masked_probs_midi, dim=-1), dim=-1
            )  # (B, T_feats)
        elif self.predict_criterion_type == "CTC":
            # aggregate G.T. midi
            # NOTE(Shuai) midi needs "+1" at the start of the loss calculation and "-1"
            # at the end of prediction, because index 0 is reserved for <blank> in the CTC loss.
            # Note that the CTC loss cannot produce a time-aligned midi prediction for cycle singing.
            _midi_cal = []
            _midi_length_cal = []
            ds = []
            for i, _ in enumerate(midi_lengths):
                _midi = midi[i, : midi_lengths[i]] + 1
                _output, counts = torch.unique_consecutive(_midi, return_counts=True)
                _midi_cal.append(_output)
                _midi_length_cal.append(len(_output))
                ds.append(counts)
            # ds = pad_list(ds, pad_value=0).to(midi.device)
            midi = pad_list(_midi_cal, pad_value=0).to(midi.device, dtype=torch.long)
            midi_lengths = torch.tensor(_midi_length_cal).to(midi.device)
            # logging.info(f"midi: {midi.shape}")
            # logging.info(f"midi: {midi}")
            # logging.info(f"midi_lengths: {midi_lengths}")

            # CTC expects the shape (T, N_batch, Class num)
            probs_midi = F.log_softmax(zs_midi, dim=-1).permute(1, 0, 2)
            midi_loss = self.predictor_criterion(
                probs_midi, midi, feats_lengths, midi_lengths
            )

    if "label" in self.predict_type:
        # label predict
        zs_label, _ = self.predictor_label(feats, h_masks)  # (B, T_feats, adim=80)
        zs_label = self.linear_out_label(zs_label)  # (B, T_feats, label classes=50)

        if self.predict_criterion_type == "CrossEntropy":
            probs_label = zs_label  # (B, T_feats, label classes=50)
            label_loss = self.predictor_criterion(probs_label, label, feats_lengths)

            # make label predict output for cycle
            masked_probs_label = probs_label * h_masks.permute(0, 2, 1)  # (B, T_feats, label classes=50)
            masked_label_outs = torch.argmax(
                F.softmax(masked_probs_label, dim=-1), dim=-1
            )  # (B, T_feats)
        elif self.predict_criterion_type == "CTC":
            # aggregate G.T. label
            # CTC expects the shape (T, N_batch, Class num)
            probs_label = F.log_softmax(zs_label, dim=-1).permute(1, 0, 2)
            label_loss = self.predictor_criterion(
                probs_label, label, feats_lengths, label_lengths
            )

    if "spk" in self.predict_type:
        # speaker predict
        packed_spk_embed = nn.utils.rnn.pack_padded_sequence(
            feats,
            feats_lengths.type(torch.int64).to("cpu"),
            batch_first=True,
            enforce_sorted=False,
        )
        packed_spk_out, hn = self.predictor_spk(packed_spk_embed)
        spk_out, _ = nn.utils.rnn.pad_packed_sequence(packed_spk_out, batch_first=True)

        # hn: (dim direction * layer nums, N_batch, eunits)
        hn = hn.reshape(batch_size, -1)  # (N_batch, dim direction * layer nums * eunits_spk)
        probs_spk = self.linear_out_spk(hn)  # (N_batch, speaker classes=4)

        speaker_loss = self.predictor_spk_criterion(probs_spk, spk_ids)
        speaker_predict = torch.argmax(F.softmax(probs_spk, dim=-1), dim=-1)  # (B,)

    return (
        midi_loss,
        label_loss,
        speaker_loss,
        masked_midi_outs,
        masked_label_outs,
        speaker_predict,
    )