def dist_custom_collate(batch, dist_bins=None, angle_bins=None, crop_size=64):
    feat = torch.zeros(len(batch), crop_size, 300, 22)
    feat_mask = torch.ones(len(batch), crop_size, 300)
    dist_labels = torch.zeros(len(batch), crop_size, crop_size)
    phi_labels = torch.zeros(len(batch), crop_size)
    psi_labels = torch.zeros(len(batch), crop_size)
    for i in range(len(batch)):
        seq_len = batch[i][0].shape[1]
        msa_len = batch[i][0].shape[0]
        sampled_msa = torch.randperm(msa_len)[:300]
        if msa_len > 300:
            msa_len = 300
        if seq_len <= crop_size:
            crop = 0
        else:
            crop = random.randint(0, seq_len - crop_size)
            seq_len = crop_size
        one_hot = torch.nn.functional.one_hot(batch[i][0][sampled_msa], num_classes=22).permute(1, 0, 2)
        feat[i, :seq_len, :msa_len] = one_hot[crop:crop + seq_len]
        feat_mask[i, seq_len:, msa_len:] = 0
        dist_labels[i, :seq_len, :seq_len] = batch[i][1][crop:crop + seq_len, crop:crop + seq_len, 0]
        dist_labels[i][dist_labels[i] == 0] = -100
        phi_labels[i, :seq_len] = batch[i][2][crop:crop + seq_len, 0, 0]
        phi_labels[i, :seq_len][batch[i][2][crop:crop + seq_len, 0, 1] == 0] = -100
        psi_labels[i, :seq_len] = batch[i][2][crop:crop + seq_len, 1, 0]
        psi_labels[i, :seq_len][batch[i][2][crop:crop + seq_len, 1, 1] == 0] = -100
    if dist_bins is not None:
        dist_labels[dist_labels != -100] = torch.bucketize(dist_labels[dist_labels != -100], dist_bins).float()
        dist_labels[dist_labels == dist_bins.shape[0]] = -100
        dist_labels[dist_labels == 0] = -100
        # dist_labels[dist_labels != -100] -= 1
        dist_labels = dist_labels.long()
    if angle_bins is not None:
        phi_labels[phi_labels != -100] = torch.bucketize(phi_labels[phi_labels != -100], angle_bins).float() - 1
        psi_labels[psi_labels != -100] = torch.bucketize(psi_labels[psi_labels != -100], angle_bins).float() - 1
        phi_labels = phi_labels.long()
        psi_labels = psi_labels.long()
    return feat, feat_mask, dist_labels, phi_labels, psi_labels
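# A hedged usage sketch for the collate function above. The bin edges and the
# `dataset` object are illustrative assumptions (AlphaFold-style distograms often
# use roughly 2-20 Angstrom distance bins), not values taken from the original code.
import math
from functools import partial

example_dist_bins = torch.linspace(2.0, 20.0, 36)
example_angle_bins = torch.linspace(-math.pi, math.pi, 36)
collate = partial(dist_custom_collate, dist_bins=example_dist_bins,
                  angle_bins=example_angle_bins, crop_size=64)
# loader = torch.utils.data.DataLoader(dataset, batch_size=8, collate_fn=collate)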
def get_energy_emb(self, x, tgt=None, factor=1.0):
    out = self.energy_predictor(x)
    bins = self.energy_bins.to(x.device)
    if tgt is None:
        out = out * factor
        emb = self.embed_energy(torch.bucketize(out, bins))
    else:
        emb = self.embed_energy(torch.bucketize(tgt, bins))
    return out, emb
def get_energy_embedding(self, x, target, mask, control):
    prediction = self.energy_predictor(x, mask)
    if target is not None:
        embedding = self.energy_embedding(
            torch.bucketize(target, self.energy_bins))
    else:
        prediction = prediction * control
        embedding = self.energy_embedding(
            torch.bucketize(prediction, self.energy_bins))
    return prediction, embedding
def get_energy_embedding(
        self,
        x: torch.Tensor,
        target: Optional[torch.Tensor],
        mask: torch.Tensor,
        control: float) -> Tuple[torch.Tensor, torch.Tensor]:
    prediction = self.energy_predictor(x, mask)
    if target is not None:
        embedding = self.energy_embedding(
            torch.bucketize(target, self.energy_bins))
    else:
        prediction = prediction * control
        embedding = self.energy_embedding(
            torch.bucketize(prediction, self.energy_bins))
    return prediction, embedding
def forward(self, x, src_mask, mel_mask=None, duration_target=None,
            pitch_target=None, energy_target=None, max_len=None):
    log_duration_prediction = self.duration_predictor(x, src_mask)
    if duration_target is not None:
        if mel_mask is not None:
            max_len = mel_mask.shape[2]
        x, mel_len = self.length_regulator(x, duration_target, max_len)
    else:
        duration_rounded = torch.clamp(
            torch.round(torch.exp(log_duration_prediction) - self.log_offset), min=0)
        x, mel_len = self.length_regulator(x, duration_rounded, max_len)
        # TODO: get_mask_from_lengths
        mel_mask = get_mask_from_lengths(mel_len)

    if self.pitch_pred:
        pitch_prediction = self.pitch_predictor(x, mel_mask)
        if pitch_target is not None:
            pitch_embedding = self.pitch_embedding(
                torch.bucketize(pitch_target, self.pitch_bins))
        else:
            pitch_embedding = self.pitch_embedding(
                torch.bucketize(pitch_prediction, self.pitch_bins))
    else:
        pitch_prediction = None

    if self.energy_pred:
        energy_prediction = self.energy_predictor(x, mel_mask)
        if energy_target is not None:
            energy_embedding = self.energy_embedding(
                torch.bucketize(energy_target, self.energy_bins))
        else:
            energy_embedding = self.energy_embedding(
                torch.bucketize(energy_prediction, self.energy_bins))
    else:
        energy_prediction = None

    if self.pitch_pred:
        x = x + pitch_embedding
    if self.energy_pred:
        x = x + energy_embedding
    # x = x + pitch_embedding + energy_embedding

    return x, log_duration_prediction, pitch_prediction, energy_prediction, mel_len, mel_mask
def neighbour_list(pos, box, subcell_size):
    nsystems = pos.shape[0]
    for s in range(nsystems):
        spos = pos[s]
        sbox = box[s]
        xbins, ybins, zbins = discretize_box(sbox, subcell_size)
        xidx = torch.bucketize(spos[:, 0], xbins, out_int32=True)
        yidx = torch.bucketize(spos[:, 1], ybins, out_int32=True)
        zidx = torch.bucketize(spos[:, 2], zbins, out_int32=True)
        # per-atom subcell index (x, y, z) for each system
        binidx = torch.stack((xidx, yidx, zidx)).T
def predict_inference(self, text_encoding, pitch_encoding, energy_encoding,
                      duration_encoding, speaker_encoding, noise_encoding,
                      src_mask, max_len, speaker_normalized=True,
                      d_control=1.0, p_control=1.0, e_control=1.0):
    encodings_cat = torch.cat(
        (text_encoding, pitch_encoding, speaker_encoding, energy_encoding, noise_encoding),
        dim=-1)

    # Duration
    log_duration_prediction = self.duration_predictor(
        duration_encoding, src_mask)  # [batch_size, src_len]
    duration_rounded = torch.clamp(
        (torch.round(torch.exp(log_duration_prediction) - hp.log_offset) * d_control),
        min=0)
    encodings_cat, mel_len = self.length_regulator(encodings_cat, duration_rounded, max_len)
    mel_mask = utils.get_mask_from_lengths(mel_len)
    text_encoding, pitch_encoding, speaker_encoding, energy_encoding, noise_encoding = torch.split(
        encodings_cat, hp.encoder_hidden, dim=-1)

    # Energy
    energy_prediction = self.energy_predictor(energy_encoding, mel_mask)
    energy_prediction = energy_prediction * e_control
    energy_embedding = self.energy_embedding(
        torch.bucketize(energy_prediction, self.energy_bins))

    # Pitch
    pitch_prediction = self.pitch_predictor(
        pitch_encoding if speaker_normalized else (pitch_encoding + speaker_encoding),
        mel_mask)
    pitch_prediction = pitch_prediction * p_control
    pitch_embedding = self.pitch_embedding(
        torch.bucketize(pitch_prediction, self.pitch_bins))

    return (text_encoding, pitch_embedding, speaker_encoding, energy_embedding,
            noise_encoding, log_duration_prediction, pitch_prediction,
            energy_prediction, mel_mask)
def forward(self, predicted_patches, target, mask):
    # reshape target to patches
    p = self.patch_size
    target = rearrange(target, "b c (h p1) (w p2) -> b (h w) c (p1 p2)", p1=p, p2=p)
    avg_target = target.mean(dim=3)

    bin_size = self.max_pixel_val / self.output_channel_bits
    channel_bins = torch.arange(bin_size, self.max_pixel_val, bin_size)
    discretized_target = torch.bucketize(avg_target, channel_bins)
    discretized_target = F.one_hot(discretized_target, self.output_channel_bits)
    c, bi = self.channels, self.output_channel_bits
    discretized_target = rearrange(discretized_target, "b n c bi -> b n (c bi)", c=c, bi=bi)

    bin_mask = 2 ** torch.arange(c * bi - 1, -1, -1).to(
        discretized_target.device, discretized_target.dtype)
    target_label = torch.sum(bin_mask * discretized_target, -1)

    predicted_patches = predicted_patches[mask]
    target_label = target_label[mask]
    loss = F.cross_entropy(predicted_patches, target_label)
    return loss
def forward(self, predicted_patches, target, mask):
    p, c, mpv, bits, device = (self.patch_size, self.channels, self.max_pixel_val,
                               self.output_channel_bits, target.device)
    bin_size = mpv / (2 ** bits)

    # un-normalize input if exists
    if exists(self.mean) and exists(self.std):
        target = target * self.std + self.mean

    # reshape target to patches
    target = target.clamp(max=mpv)  # clamp just in case
    avg_target = reduce(target, 'b c (h p1) (w p2) -> b (h w) c', 'mean', p1=p, p2=p).contiguous()

    channel_bins = torch.arange(bin_size, mpv, bin_size, device=device)
    discretized_target = torch.bucketize(avg_target, channel_bins)

    bin_mask = (2 ** bits) ** torch.arange(0, c, device=device).long()
    bin_mask = rearrange(bin_mask, 'c -> () () c')

    target_label = torch.sum(bin_mask * discretized_target, dim=-1)

    loss = F.cross_entropy(predicted_patches[mask], target_label[mask])
    return loss
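# A tiny numeric check of the base-(2**bits) label encoding used above, with
# bits=2 (4 bins per channel) and max_pixel_val=1.0 chosen only for illustration.
bits, mpv = 2, 1.0
bin_size = mpv / (2 ** bits)
channel_bins = torch.arange(bin_size, mpv, bin_size)      # tensor([0.25, 0.50, 0.75])
avg_rgb = torch.tensor([0.10, 0.40, 0.90])                # per-channel patch means
digits = torch.bucketize(avg_rgb, channel_bins)           # tensor([0, 1, 3])
label = (digits * (2 ** bits) ** torch.arange(3)).sum()   # 0*1 + 1*4 + 3*16 = 52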
def interpolate_cubic_derivative(coeffs, points, new_points):
    index = torch.bucketize(new_points, points) - 1
    index = index.clamp(0, coeffs.shape[-1] - 1)
    t = (new_points - points[index]) / (points[index + 1] - points[index])
    ret = 3 * coeffs[0, index] * t + 2 * coeffs[1, index]
    ret = ret * t + coeffs[2, index]
    return ret
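# A companion sketch that evaluates the cubic itself via the same bucketize-based
# segment lookup. It assumes coeffs is laid out as rows [cubic, quadratic, linear,
# constant], which is what the derivative above implies; this is an assumption,
# not code from the original source.
def interpolate_cubic(coeffs, points, new_points):
    index = torch.bucketize(new_points, points) - 1
    index = index.clamp(0, coeffs.shape[-1] - 1)
    t = (new_points - points[index]) / (points[index + 1] - points[index])
    # Horner evaluation of a*t^3 + b*t^2 + c*t + d on each segment
    ret = coeffs[0, index] * t + coeffs[1, index]
    ret = ret * t + coeffs[2, index]
    ret = ret * t + coeffs[3, index]
    return ret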
def vector_to_edge_index(idx: Tensor, size: Tuple[int, int], bipartite: bool,
                         force_undirected: bool = False) -> Tensor:
    if bipartite:  # No need to account for self-loops.
        row = idx.div(size[1], rounding_mode='floor')
        col = idx % size[1]
        return torch.stack([row, col], dim=0)

    elif force_undirected:
        assert size[0] == size[1]
        num_nodes = size[0]

        offset = torch.arange(1, num_nodes, device=idx.device).cumsum(0)
        end = torch.arange(num_nodes, num_nodes * num_nodes, num_nodes,
                           device=idx.device)
        row = torch.bucketize(idx, end.sub_(offset), right=True)
        col = offset[row].add_(idx) % num_nodes
        return torch.stack([torch.cat([row, col]), torch.cat([col, row])], 0)

    else:
        assert size[0] == size[1]
        num_nodes = size[0]

        row = idx.div(num_nodes - 1, rounding_mode='floor')
        col = idx % (num_nodes - 1)
        col[row <= col] += 1
        return torch.stack([row, col], dim=0)
def filter_spikes(x, bounds=[-50, 150]):
    """
    Return interpolated signal and mask vanishing outside NaNs.
    """
    # find spikes
    spikes, mask = find_spikes(x, bounds)
    if not len(spikes):
        return x, mask
    # interval boundary neighbours
    neighb = spikes + torch.tensor([-2, 2])
    if mask[0]:
        neighb[0, 0] = neighb[0, 1]
    if mask[-1]:
        neighb[-1, -1] = neighb[-1, 0]
    # map domain to segment ends
    N = x.shape[0]
    idx = torch.arange(N)
    buckets = torch.cat(
        [torch.tensor([0]), neighb.reshape([-1]), torch.tensor([N])])
    interval = torch.bucketize(idx, neighb.reshape([-1]))
    i0 = bound(buckets[interval], [0, N - 1])
    i1 = bound(buckets[interval + 1], [0, N - 1])
    # interpolation
    x0 = x[torch.max(i0, neighb[0, 0])]
    x1 = x[torch.min(i1, neighb[-1, 1])]
    slope = (x1 - x0) / (i1 - i0)
    offset = x0 - slope * i0
    x_int = mask * (slope * idx + offset)
    x_int += ~mask * x
    return x_int, mask
def energy_to_one_hot(e):
    # For pytorch >= 1.6.0
    bins = torch.linspace(hp.e_min, hp.e_max, steps=255).to(
        torch.device("cuda" if hp.ngpu > 0 else "cpu"))
    e_quantize = torch.bucketize(e, bins)
    return F.one_hot(e_quantize.long(), 256).float()
def pitch_to_one_hot(f0, is_training=True):
    # Requires pytorch >= 1.6.0
    bins = torch.exp(torch.linspace(np.log(hp.p_min), np.log(hp.p_max), 255)).to(
        torch.device("cuda" if hp.ngpu > 0 else "cpu"))
    p_quantize = torch.bucketize(f0, bins)
    # p_quantize = p_quantize - 1  # -1 to convert 1 to 256 --> 0 to 255
    return F.one_hot(p_quantize.long(), 256).float()
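# The same bucketize + one_hot pattern as the two helpers above, shown standalone
# with explicit bounds instead of the hp config. The default range [0, 1] and the
# 256-bucket count are illustrative assumptions, not values from the original hparams.
def quantize_to_one_hot(values, v_min=0.0, v_max=1.0, n_bins=256):
    # 255 boundaries split [v_min, v_max] into 256 buckets
    bins = torch.linspace(v_min, v_max, steps=n_bins - 1)
    quantized = torch.bucketize(values, bins)
    return F.one_hot(quantized.long(), n_bins).float()

# e.g. quantize_to_one_hot(torch.tensor([0.1, 0.9])).shape -> torch.Size([2, 256])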
def __call__(self, points_to_interp):
    assert self.points is not None
    assert self.values is not None

    assert len(points_to_interp) == len(self.points)
    K = points_to_interp[0].shape[0]
    for x in points_to_interp:
        assert x.shape[0] == K

    idxs = []
    dists = []
    overalls = []
    for p, x in zip(self.points, points_to_interp):
        idx_right = torch.bucketize(x, p)
        idx_right[idx_right >= p.shape[0]] = p.shape[0] - 1
        idx_left = (idx_right - 1).clamp(0, p.shape[0] - 1)
        dist_left = x - p[idx_left]
        dist_right = p[idx_right] - x
        dist_left[dist_left < 0] = 0.
        dist_right[dist_right < 0] = 0.
        both_zero = (dist_left == 0) & (dist_right == 0)
        dist_left[both_zero] = dist_right[both_zero] = 1.
        idxs.append((idx_left, idx_right))
        dists.append((dist_left, dist_right))
        overalls.append(dist_left + dist_right)

    numerator = 0.
    for indexer in product([0, 1], repeat=self.n):
        as_s = [idx[onoff] for onoff, idx in zip(indexer, idxs)]
        bs_s = [dist[1 - onoff] for onoff, dist in zip(indexer, dists)]
        numerator += self.values[as_s] * \
            torch.prod(torch.stack(bs_s), dim=0)
    denominator = torch.prod(torch.stack(overalls), dim=0)
    return numerator / denominator
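# The core 1-D step of the interpolator above, shown standalone: bucketize finds the
# right-hand grid neighbour and the two distances become the linear weights. Values
# are chosen purely for illustration (and away from the both-zero edge case the full
# method handles explicitly).
p = torch.tensor([0.0, 1.0, 2.0])      # grid points
y = torch.tensor([0.0, 10.0, 20.0])    # values on the grid
x = torch.tensor([0.25, 1.5])          # query points
idx_right = torch.bucketize(x, p).clamp(max=p.shape[0] - 1)
idx_left = (idx_right - 1).clamp(0, p.shape[0] - 1)
w_left = (p[idx_right] - x) / (p[idx_right] - p[idx_left])
print(w_left * y[idx_left] + (1 - w_left) * y[idx_right])   # tensor([ 2.5000, 15.0000])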
def to_one_hot(self, x: torch.Tensor):
    # e = de_norm_mean_std(e, hp.e_mean, hp.e_std)
    # For pytorch >= 1.6.0
    quantize = torch.bucketize(x, self.pitch_bins).to(device=x.device)  # .cuda()
    return F.one_hot(quantize.long(), 256).float()
def _interpret_t(self, t):
    maxlen = self._b.size(-2) - 1
    index = torch.bucketize(t.detach(), self._t) - 1
    # clamp because t may go outside of [t[0], t[-1]]; this is fine
    index = index.clamp(0, maxlen)
    # will never access the last element of self._t; this is correct behaviour
    fractional_part = t - self._t[index]
    return fractional_part, index
def quantile_weights(input, quantiles, weights, scale):
    '''
    input, weights: 1D array
    scale: scalar
    '''
    in_sorted, indices = torch.sort(input)
    weight_sorted = weights[indices] / torch.sum(weights)
    boundary = torch.cumsum(weight_sorted, dim=0)
    boundary = boundary - weight_sorted / 2.
    boundary = torch.cat((torch.zeros(1), boundary, torch.ones(1)), dim=0)
    weight_sorted = torch.cat((torch.zeros(1), weight_sorted, torch.zeros(1)), dim=0)
    in_sorted = torch.cat((in_sorted[0].reshape(1) - scale, in_sorted,
                           in_sorted[-1].reshape(1) + scale), dim=0)
    ceiled = torch.bucketize(quantiles, boundary)
    ceiled[ceiled < 1] = 1
    floored = ceiled - 1
    ceiled[ceiled > len(boundary) - 1] = 0
    weight_ceiled = (quantiles - boundary[floored]) / (
        weight_sorted[floored] + weight_sorted[ceiled]) * 2.
    weight_floored = 1.0 - weight_ceiled
    d0 = in_sorted[floored.long()] * weight_floored
    d1 = in_sorted[ceiled.long()] * weight_ceiled
    result = d0 + d1
    return result
def _interpret_t(self, t):
    t = torch.as_tensor(t, dtype=self._b.dtype, device=self._b.device)
    maxlen = self._b.size(-2) - 1
    # clamp because t may go outside of [t[0], t[-1]]; this is fine
    index = torch.bucketize(t.detach(), self._t.detach()).sub(1).clamp(0, maxlen)
    # will never access the last element of self._t; this is correct behaviour
    fractional_part = t - self._t[index]
    return fractional_part, index
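# A tiny standalone illustration of the knot lookup both _interpret_t variants use:
# bucketize(t, knots) - 1, clamped, gives the index of the spline segment containing t,
# including queries that fall outside [knots[0], knots[-1]].
knots = torch.tensor([0.0, 1.0, 2.0, 3.0])
queries = torch.tensor([-0.5, 0.0, 1.5, 3.0, 4.2])
segment = (torch.bucketize(queries, knots) - 1).clamp(0, knots.numel() - 2)
# segment == tensor([0, 0, 1, 2, 2])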
def encode_survival(time: Union[float, int, TensorOrArray],
                    event: Union[int, bool, TensorOrArray],
                    bins: TensorOrArray) -> torch.Tensor:
    """Encodes survival time and event indicator in the format required
    for MTLR training.

    For uncensored instances, one-hot encoding of binned survival time
    is generated. Censoring is handled differently, with all possible
    values for event time encoded as 1s. For example, if 5 time bins are
    used, an instance experiencing event in bin 3 is encoded as
    [0, 0, 0, 1, 0], and an instance censored in bin 2 as [0, 0, 1, 1, 1].
    Note that an additional 'catch-all' bin is added, spanning the range
    `(bins.max(), inf)`.

    Parameters
    ----------
    time
        Time of event or censoring.
    event
        Event indicator (0 = censored).
    bins
        Bins used for time axis discretisation.

    Returns
    -------
    torch.Tensor
        Encoded survival times.
    """
    # TODO this should handle arrays and (CUDA) tensors
    if isinstance(time, (float, int, np.ndarray)):
        time = np.atleast_1d(time)
        time = torch.tensor(time)
    if isinstance(event, (int, bool, np.ndarray)):
        event = np.atleast_1d(event)
        event = torch.tensor(event)
    if isinstance(bins, np.ndarray):
        bins = torch.tensor(bins)

    try:
        device = bins.device
    except AttributeError:
        device = "cpu"

    time = np.clip(time, 0, bins.max())
    # add extra bin [max_time, inf) at the end
    y = torch.zeros((time.shape[0], bins.shape[0] + 1),
                    dtype=torch.float,
                    device=device)
    # For some reason, the `right` arg in torch.bucketize
    # works in the _opposite_ way as it does in numpy,
    # so we need to set it to True
    bin_idxs = torch.bucketize(time, bins, right=True)
    for i, (bin_idx, e) in enumerate(zip(bin_idxs, event)):
        if e == 1:
            y[i, bin_idx] = 1
        else:
            y[i, bin_idx:] = 1
    return y.squeeze()
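# A small check of the `right` semantics mentioned in the docstring above:
# torch.bucketize(..., right=False) matches np.searchsorted(..., side='left') and
# right=True matches side='right', i.e. the flag is named the opposite way round
# from np.digitize's `right` argument.
edges = torch.tensor([1.0, 2.0, 3.0])
vals = torch.tensor([0.5, 1.0, 1.5, 3.0, 4.0])
torch.bucketize(vals, edges)               # tensor([0, 0, 1, 2, 3])
torch.bucketize(vals, edges, right=True)   # tensor([0, 1, 1, 3, 3])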
def _quantile_encode_approx(tensor: torch.Tensor, n_bits: int) -> Tuple[torch.Tensor, torch.Tensor]:
    n_bins = 2 ** n_bits
    borders = torch.as_tensor(
        _quantile_qq_approximation(tensor.numpy(), n_bins + 1)[1:-1])
    quant_weight = torch.clamp_(torch.bucketize(tensor, borders), 0, n_bins - 1)
    lookup = average_buckets(tensor, quant_weight, n_bins)
    return quant_weight, lookup
def forward(self, x, src_mask, mel_mask=None, duration_target=None,
            pitch_target=None, energy_target=None, max_len=None):
    ## Duration Predictor ##
    log_duration_prediction = self.duration_predictor(x, src_mask)
    if duration_target is not None:
        x, mel_len = self.length_regulator(x, duration_target, max_len)
    else:
        duration_rounded = torch.clamp(torch.round(
            torch.exp(log_duration_prediction) - hp.log_offset), min=0)
        x, mel_len = self.length_regulator(x, duration_rounded, max_len)
        mel_mask = utils.get_mask_from_lengths(mel_len)

    ## Pitch Predictor ##
    pitch_prediction = self.pitch_predictor(x, mel_mask)
    if pitch_target is not None:
        pitch_embedding = self.pitch_embedding(
            torch.bucketize(pitch_target.detach(), self.pitch_bins.detach()))
    else:
        pitch_embedding = self.pitch_embedding(
            torch.bucketize(pitch_prediction.detach(), self.pitch_bins.detach()))

    ## Energy Predictor ##
    energy_prediction = self.energy_predictor(x, mel_mask)
    if energy_target is not None:
        energy_embedding = self.energy_embedding(
            torch.bucketize(energy_target.detach(), self.energy_bins.detach()))
    else:
        energy_embedding = self.energy_embedding(
            torch.bucketize(energy_prediction.detach(), self.energy_bins.detach()))

    x = x + pitch_embedding + energy_embedding

    return x, log_duration_prediction, pitch_prediction, energy_prediction, mel_len, mel_mask
def forward(self, x, src_mask, mel_mask=None, duration_target=None,
            pitch_target=None, energy_target=None, max_len=None,
            d_control=1.0, p_control=1.0, e_control=1.0):
    log_duration_prediction = self.duration_predictor(x, src_mask)
    if duration_target is not None:
        x, mel_len = self.length_regulator(x, duration_target, max_len)
    else:
        duration_rounded = torch.clamp((torch.round(
            torch.exp(log_duration_prediction) - hp.log_offset) * d_control), min=0)
        x, mel_len = self.length_regulator(x, duration_rounded, max_len)
        mel_mask = utils.get_mask_from_lengths(mel_len)

    pitch_prediction = self.pitch_predictor(x, mel_mask)
    if pitch_target is not None:
        pitch_embedding = self.pitch_embedding(
            torch.bucketize(pitch_target, self.pitch_bins))
    else:
        pitch_prediction = pitch_prediction * p_control
        pitch_embedding = self.pitch_embedding(
            torch.bucketize(pitch_prediction, self.pitch_bins))

    energy_prediction = self.energy_predictor(x, mel_mask)
    if energy_target is not None:
        energy_embedding = self.energy_embedding(
            torch.bucketize(energy_target, self.energy_bins))
    else:
        energy_prediction = energy_prediction * e_control
        energy_embedding = self.energy_embedding(
            torch.bucketize(energy_prediction, self.energy_bins))

    x = x + pitch_embedding + energy_embedding

    return x, log_duration_prediction, pitch_prediction, energy_prediction, mel_len, mel_mask
def get_loss(self, y, y_hat):
    """
    Do the applicable processing and return the loss for the supplied prediction \
    and the label tensors.

    :param y: Label tensor
    :type y: torch.Tensor
    :param y_hat: Predicted tensor
    :type y_hat: torch.Tensor
    :return: Prediction loss
    :rtype: torch.Tensor
    """
    if self.hparams.undersample:
        sub_mask = y < self.hparams.undersample
        subval = y[sub_mask]
        low = max(subval.min(), 0.5)
        high = subval.max()
        boundaries = torch.arange(low, high, (high - low) / 10).to(
            self.model.device
        )
        freq_idx = torch.bucketize(subval, boundaries[:-1], right=False)
        self.undersampler.fit_resample(
            subval.cpu().unsqueeze(-1),
            (boundaries.take(index=freq_idx).cpu() * 100).int(),
        )
        idx = self.undersampler.sample_indices_
        y = torch.cat((y[~sub_mask], subval[idx]))
        y_hat = torch.cat((y_hat[~sub_mask], y_hat[sub_mask][idx]))
    if self.hparams.round_to_zero:
        y_hat = y_hat[y > self.hparams.round_to_zero]
        y = y[y > self.hparams.round_to_zero]
    if self.hparams.clip_output:
        y_hat = y_hat[
            (y < self.hparams.clip_output[-1]) & (self.hparams.clip_output[0] < y)
        ]
        y = y[
            (y < self.hparams.clip_output[-1]) & (self.hparams.clip_output[0] < y)
        ]
    if self.hparams.cb_loss:
        loss_factor = self.get_cb_loss_factor(y)
    if self.hparams.boxcox:
        y = torch.from_numpy(boxcox(y.cpu(), lmbda=self.hparams.boxcox)).to(
            y.device
        )
    pre_loss = (y_hat - y) ** 2
    # if "loss_factor" in locals():
    #     pre_loss *= loss_factor
    loss = pre_loss.mean()
    assert loss == loss  # guard against NaN loss
    return loss
def color_bin(input, color_bins, bin_scale, n_tot):
    """ weighted sum of all pixels """
    batch_size, channels, height, width = input.size()
    all_col = []
    for i in range(batch_size):
        bucketed = torch.bucketize(input[i], color_bins)
        bins = torch.einsum('bcd, b -> cd', bucketed, bin_scale).to(torch.float)
        col_hist = torch.histc(bins, n_tot, min=0, max=n_tot - 1) / (height * width)
        all_col.append(col_hist)
    return torch.stack(all_col)
def probsample(n_samples, prob, return_counts=True):
    if return_counts:
        bins = torch.cat([torch.zeros(1), torch.cumsum(prob, dim=0)])
        bins[-1] = 1
        res = torch.histogram(torch.rand(n_samples), bins)
        return res.hist
    bins = torch.cumsum(prob, dim=0)
    bins[-1] = 1
    return torch.bucketize(torch.rand((n_samples,)), bins)
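# A small usage sketch: draw category indices (or per-category counts) from a 3-way
# distribution. The outputs are random; the commented values are just one possible draw.
prob = torch.tensor([0.2, 0.3, 0.5])
samples = probsample(5, prob, return_counts=False)   # e.g. tensor([2, 1, 2, 0, 2])
counts = probsample(1000, prob)                      # e.g. tensor([197., 306., 497.])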
def energy_to_one_hot(e, is_training=True):
    # e = de_norm_mean_std(e, hp.e_mean, hp.e_std)
    # For pytorch >= 1.6.0
    bins = torch.linspace(hp.e_min, hp.e_max, steps=255).to(
        torch.device("cuda" if hp.ngpu > 0 else "cpu"))
    e_quantize = torch.bucketize(e, bins)
    return F.one_hot(e_quantize.long(), 256).float()
def get_bucketed_distance_matrix(coords, mask):
    distances = torch.cdist(coords, coords, p=2)
    boundaries = torch.linspace(2, 20, steps=DISTOGRAM_BUCKETS, device=coords.device)
    discretized_distances = torch.bucketize(distances, boundaries[:-1])
    discretized_distances.masked_fill_(~(mask[:, :, None] & mask[:, None, :]), IGNORE_INDEX)
    return discretized_distances
def get_bucketed_distance_matrix(coords, mask, num_buckets=constants.DISTOGRAM_BUCKETS,
                                 ignore_index=-100):
    distances = torch.cdist(coords, coords, p=2)
    boundaries = torch.linspace(2, 20, steps=num_buckets, device=coords.device)
    discretized_distances = torch.bucketize(distances, boundaries[:-1])
    discretized_distances.masked_fill_(~(mask[..., None] & mask[..., None, :]), ignore_index)
    return discretized_distances
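# A minimal usage sketch of the parametrised version above; the bucket count 37 and
# ignore index -100 are written out explicitly here as assumptions rather than read
# from the constants module.
coords = torch.randn(2, 16, 3)                      # (batch, residues, xyz)
mask = torch.ones(2, 16, dtype=torch.bool)
labels = get_bucketed_distance_matrix(coords, mask, num_buckets=37, ignore_index=-100)
print(labels.shape)                                 # torch.Size([2, 16, 16])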
def forward(self, x, src_mask, mel_mask=None, duration_target=None,
            pitch_target=None, energy_target=None, max_len=None):
    log_duration_prediction = self.duration_predictor(x, src_mask)
    if duration_target is not None:
        x, mel_len = self.length_regulator(x, duration_target, max_len)
    else:
        duration_rounded = torch.clamp(torch.round(
            torch.exp(log_duration_prediction) - hp.log_offset), min=0)
        x, mel_len = self.length_regulator(x, duration_rounded, max_len)
        mel_mask = utils.get_mask_from_lengths(mel_len)

    pitch_prediction = self.pitch_predictor(x, mel_mask)
    if pitch_target is not None:
        pitch_embedding = self.pitch_embedding(
            torch.bucketize(pitch_target, self.pitch_bins))
    else:
        pitch_embedding = self.pitch_embedding(
            torch.bucketize(pitch_prediction, self.pitch_bins))

    energy_prediction = self.energy_predictor(x, mel_mask)
    if energy_target is not None:
        energy_embedding = self.energy_embedding(
            torch.bucketize(energy_target, self.energy_bins))
    else:
        energy_embedding = self.energy_embedding(
            torch.bucketize(energy_prediction, self.energy_bins))

    # print('pitch_target:', pitch_target.size())
    # added by the author: check that the dims of x and the embeddings match
    # print('x:', x.size())
    # print('pitch:', pitch_embedding.size())
    # print('energy:', energy_embedding.size())
    x = x + pitch_embedding + energy_embedding

    return x, log_duration_prediction, pitch_prediction, energy_prediction, mel_len, mel_mask
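# How the pitch_bins / energy_bins used by the variance adaptors above are commonly
# constructed in FastSpeech2-style models (inside the adaptor's __init__). The hp
# attribute names and the 256-bucket choice are assumptions for illustration, not
# taken from any one of the snippets above.
self.pitch_bins = nn.Parameter(
    torch.exp(torch.linspace(np.log(hp.f0_min), np.log(hp.f0_max), 255)),
    requires_grad=False)
self.energy_bins = nn.Parameter(
    torch.linspace(hp.energy_min, hp.energy_max, 255),
    requires_grad=False)
self.pitch_embedding = nn.Embedding(256, hp.encoder_hidden)
self.energy_embedding = nn.Embedding(256, hp.encoder_hidden)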