Exemplo n.º 1
0
def dist_custom_collate(batch, dist_bins=None, angle_bins=None, crop_size=64,
                        max_msa=300, num_classes=22):
    """Collate MSA samples into padded, fixed-size feature and label tensors.

    Each sample is read positionally:

    * ``sample[0]`` - integer MSA indexed ``[msa_row, residue]``; it is
      one-hot encoded with ``num_classes`` classes.
    * ``sample[1]`` - pairwise values indexed ``[residue, residue, channel]``;
      only channel 0 is used as the distance label.
    * ``sample[2]`` - angles indexed ``[residue, k, j]`` where ``k`` selects
      phi (0) / psi (1) and ``j`` selects the value (0) / a flag (1); a flag
      equal to 0 marks the label as ignored.

    At most ``max_msa`` randomly chosen MSA rows are kept, and sequences
    longer than ``crop_size`` are cropped to a random contiguous window.

    Args:
        batch: Sequence of samples as described above.
        dist_bins: Optional 1-D boundaries used to bucketize distances.
        angle_bins: Optional 1-D boundaries used to bucketize angles.
        crop_size: Padded sequence length of every output tensor.
        max_msa: Maximum number of MSA rows kept per sample.
        num_classes: Number of classes of the one-hot encoding.

    Returns:
        Tuple ``(feat, feat_mask, dist_labels, phi_labels, psi_labels)``.
        Ignored label positions hold ``-100``. Distance labels are ``long``
        when ``dist_bins`` is given, angle labels when ``angle_bins`` is given.
    """
    n_samples = len(batch)
    feat = torch.zeros(n_samples, crop_size, max_msa, num_classes)
    feat_mask = torch.ones(n_samples, crop_size, max_msa)
    dist_labels = torch.zeros(n_samples, crop_size, crop_size)
    phi_labels = torch.zeros(n_samples, crop_size)
    psi_labels = torch.zeros(n_samples, crop_size)

    for i, sample in enumerate(batch):
        msa, dist, angles = sample[0], sample[1], sample[2]
        msa_len, seq_len = msa.shape[0], msa.shape[1]

        # Random subset (and order) of MSA rows, capped at max_msa.
        sampled_msa = torch.randperm(msa_len)[:max_msa]
        msa_len = min(msa_len, max_msa)

        if seq_len <= crop_size:
            crop = 0
        else:
            crop = random.randint(0, seq_len - crop_size)
            seq_len = crop_size
        window = slice(crop, crop + seq_len)

        # (msa, residue, class) -> (residue, msa, class)
        one_hot = torch.nn.functional.one_hot(
            msa[sampled_msa], num_classes=num_classes).permute(1, 0, 2)
        feat[i, :seq_len, :msa_len] = one_hot[window]
        # NOTE(review): this only clears the region that is padding along BOTH
        # axes; cells padded along a single axis stay 1. Kept as-is because
        # downstream code may rely on it - confirm the intended mask.
        feat_mask[i, seq_len:, msa_len:] = 0

        dist_labels[i, :seq_len, :seq_len] = dist[window, window, 0]
        # Zero distances (including all padding) are ignored.
        dist_labels[i][dist_labels[i] == 0] = -100

        phi_labels[i, :seq_len] = angles[window, 0, 0]
        phi_labels[i, :seq_len][angles[window, 0, 1] == 0] = -100

        psi_labels[i, :seq_len] = angles[window, 1, 0]
        psi_labels[i, :seq_len][angles[window, 1, 1] == 0] = -100

    if dist_bins is not None:
        valid = dist_labels != -100
        dist_labels[valid] = torch.bucketize(dist_labels[valid], dist_bins).float()
        # Values beyond the last boundary or not above the first are ignored.
        dist_labels[dist_labels == dist_bins.shape[0]] = -100
        dist_labels[dist_labels == 0] = -100
        dist_labels = dist_labels.long()

    if angle_bins is not None:
        valid_phi = phi_labels != -100
        phi_labels[valid_phi] = torch.bucketize(phi_labels[valid_phi], angle_bins).float() - 1
        valid_psi = psi_labels != -100
        psi_labels[valid_psi] = torch.bucketize(psi_labels[valid_psi], angle_bins).float() - 1
        phi_labels = phi_labels.long()
        psi_labels = psi_labels.long()

    return feat, feat_mask, dist_labels, phi_labels, psi_labels
Exemplo n.º 2
0
 def get_energy_emb(self, x, tgt=None, factor=1.0):
     """Predict energy from ``x`` and embed its bucketized value.

     When ``tgt`` is given, the embedding is looked up from the bucketized
     target and the raw prediction is returned unchanged. Otherwise the
     prediction is scaled by ``factor`` and that scaled prediction is both
     returned and used for the embedding lookup.

     Returns:
         Tuple ``(prediction, embedding)``.
     """
     out = self.energy_predictor(x)
     # Bucket boundaries must live on the same device as the values.
     boundaries = self.energy_bins.to(x.device)
     if tgt is not None:
         return out, self.embed_energy(torch.bucketize(tgt, boundaries))
     out = out * factor
     return out, self.embed_energy(torch.bucketize(out, boundaries))
Exemplo n.º 3
0
 def get_energy_embedding(self, x, target, mask, control):
     """Return the energy prediction and the embedding of its bucket index.

     With a ``target``, the embedding comes from the bucketized target and
     the prediction is returned unscaled. Without one, the prediction is
     multiplied by ``control`` and the scaled prediction is bucketized.
     """
     prediction = self.energy_predictor(x, mask)
     if target is None:
         prediction = prediction * control
         values = prediction
     else:
         values = target
     bucket_ids = torch.bucketize(values, self.energy_bins)
     return prediction, self.energy_embedding(bucket_ids)
Exemplo n.º 4
0
 def get_energy_embedding(
         self, x: torch.Tensor, target: Optional[torch.Tensor],
         mask: torch.Tensor,
         control: float) -> Tuple[torch.Tensor, torch.Tensor]:
     """Predict energy and embed the bucketized energy values.

     Args:
         x: Input passed to the energy predictor.
         target: Optional ground-truth energy; when present it is bucketized
             for the embedding lookup instead of the prediction.
         mask: Mask forwarded to the energy predictor.
         control: Multiplier applied to the prediction when no target is given.

     Returns:
         ``(prediction, embedding)``; the prediction is scaled by ``control``
         only when ``target`` is ``None``.
     """
     prediction = self.energy_predictor(x, mask)
     if target is None:
         prediction = prediction * control
         source = prediction
     else:
         source = target
     embedding = self.energy_embedding(
         torch.bucketize(source, self.energy_bins))
     return prediction, embedding
Exemplo n.º 5
0
    def forward(self,
                x,
                src_mask,
                mel_mask=None,
                duration_target=None,
                pitch_target=None,
                energy_target=None,
                max_len=None):
        """Expand ``x`` by duration and add optional pitch/energy embeddings.

        With ``duration_target`` the length regulator uses the given
        durations (and ``max_len`` is taken from ``mel_mask.shape[2]`` when a
        mask is supplied). Without it, durations are derived from the
        predicted log-durations and ``mel_mask`` is rebuilt from the
        resulting lengths.

        Pitch and energy are handled only when ``self.pitch_pred`` /
        ``self.energy_pred`` are truthy; each embedding is looked up from the
        bucketized target when one is given, else from the prediction.

        Returns:
            ``(x, log_duration_prediction, pitch_prediction,
            energy_prediction, mel_len, mel_mask)``; a disabled predictor
            yields ``None`` for its prediction.
        """
        log_duration_prediction = self.duration_predictor(x, src_mask)
        if duration_target is None:
            predicted_frames = torch.exp(log_duration_prediction) - self.log_offset
            duration_rounded = torch.clamp(torch.round(predicted_frames), min=0)
            x, mel_len = self.length_regulator(x, duration_rounded, max_len)
            # TODO: get_mask_from_lengths
            mel_mask = get_mask_from_lengths(mel_len)
        else:
            if mel_mask is not None:
                max_len = mel_mask.shape[2]
            x, mel_len = self.length_regulator(x, duration_target, max_len)

        pitch_prediction = None
        if self.pitch_pred:
            pitch_prediction = self.pitch_predictor(x, mel_mask)
            pitch_values = pitch_prediction if pitch_target is None else pitch_target
            pitch_embedding = self.pitch_embedding(
                torch.bucketize(pitch_values, self.pitch_bins))

        energy_prediction = None
        if self.energy_pred:
            # The energy predictor sees x before any embedding is added.
            energy_prediction = self.energy_predictor(x, mel_mask)
            energy_values = energy_prediction if energy_target is None else energy_target
            energy_embedding = self.energy_embedding(
                torch.bucketize(energy_values, self.energy_bins))

        if self.pitch_pred:
            x = x + pitch_embedding
        if self.energy_pred:
            x = x + energy_embedding

        return (x, log_duration_prediction, pitch_prediction,
                energy_prediction, mel_len, mel_mask)
Exemplo n.º 6
0
def neighbour_list(pos, box, subcell_size):
    """Assign every position of every system to a sub-cell of its box.

    For each system ``s`` the box ``box[s]`` is discretized with
    ``discretize_box`` and the x/y/z coordinates of ``pos[s]`` are bucketized
    against the resulting boundaries.

    Args:
        pos: Positions indexed ``[system, particle, xyz]``.
        box: Per-system box description passed to ``discretize_box``.
        subcell_size: Sub-cell size passed to ``discretize_box``.

    NOTE(review): the per-system ``binidx`` is computed but never stored or
    returned, so the function currently returns ``None``; the snippet looks
    truncated - confirm what callers expect.
    """
    # The original read the undefined name `coordinates` here (NameError).
    nsystems = pos.shape[0]

    for s in range(nsystems):
        spos = pos[s]
        sbox = box[s]

        xbins, ybins, zbins = discretize_box(sbox, subcell_size)

        xidx = torch.bucketize(spos[:, 0], xbins, out_int32=True)
        yidx = torch.bucketize(spos[:, 1], ybins, out_int32=True)
        zidx = torch.bucketize(spos[:, 2], zbins, out_int32=True)

        # One (x, y, z) sub-cell index triple per particle.
        binidx = torch.stack((xidx, yidx, zidx)).T
Exemplo n.º 7
0
    def predict_inference(self,
                          text_encoding,
                          pitch_encoding,
                          energy_encoding,
                          duration_encoding,
                          speaker_encoding,
                          noise_encoding,
                          src_mask,
                          max_len,
                          speaker_normalized=True,
                          d_control=1.0,
                          p_control=1.0,
                          e_control=1.0):
        """Run duration, energy and pitch prediction at inference time.

        The five encodings are concatenated along the last axis so they can
        be length-regulated together with the predicted durations, then split
        back into chunks of ``hp.encoder_hidden`` features. Energy and pitch
        are predicted on the expanded encodings, scaled by ``e_control`` /
        ``p_control``, bucketized and embedded. The pitch predictor receives
        ``pitch_encoding`` alone when ``speaker_normalized`` is true, else
        ``pitch_encoding + speaker_encoding``.

        Returns:
            ``(text_encoding, pitch_embedding, speaker_encoding,
            energy_embedding, noise_encoding, log_duration_prediction,
            pitch_prediction, energy_prediction, mel_mask)``.
        """
        stacked = torch.cat(
            (text_encoding, pitch_encoding, speaker_encoding, energy_encoding,
             noise_encoding),
            dim=-1)

        # Duration: [batch_size, src_len]
        log_duration_prediction = self.duration_predictor(
            duration_encoding, src_mask)
        frames = torch.round(
            torch.exp(log_duration_prediction) - hp.log_offset)
        duration_rounded = torch.clamp(frames * d_control, min=0)
        stacked, mel_len = self.length_regulator(stacked, duration_rounded,
                                                 max_len)
        mel_mask = utils.get_mask_from_lengths(mel_len)

        (text_encoding, pitch_encoding, speaker_encoding, energy_encoding,
         noise_encoding) = torch.split(stacked, hp.encoder_hidden, dim=-1)

        # Energy
        energy_prediction = self.energy_predictor(energy_encoding,
                                                  mel_mask) * e_control
        energy_embedding = self.energy_embedding(
            torch.bucketize(energy_prediction, self.energy_bins))

        # Pitch
        if speaker_normalized:
            pitch_input = pitch_encoding
        else:
            pitch_input = pitch_encoding + speaker_encoding
        pitch_prediction = self.pitch_predictor(pitch_input,
                                                mel_mask) * p_control
        pitch_embedding = self.pitch_embedding(
            torch.bucketize(pitch_prediction, self.pitch_bins))

        return (text_encoding, pitch_embedding, speaker_encoding,
                energy_embedding, noise_encoding, log_duration_prediction,
                pitch_prediction, energy_prediction, mel_mask)
Exemplo n.º 8
0
    def forward(self, predicted_patches, target, mask):
        """Cross-entropy between patch predictions and discretized patch colors.

        ``target`` is split into ``patch_size`` x ``patch_size`` patches, each
        channel is averaged per patch and bucketized into
        ``output_channel_bits`` bins. The per-channel one-hot codes are
        concatenated and read as the bits of a binary number, which is the
        class label of the patch.

        Args:
            predicted_patches: Per-patch logits; indexed with ``mask``.
            target: Image batch laid out as ``b c h w``.
            mask: Index selecting the patches that contribute to the loss.

        Returns:
            The cross-entropy loss over the selected patches.
        """
        # reshape target to patches
        p = self.patch_size
        target = rearrange(target,
                           "b c (h p1) (w p2) -> b (h w) c (p1 p2) ",
                           p1=p,
                           p2=p)

        avg_target = target.mean(dim=3)

        bin_size = self.max_pixel_val / self.output_channel_bits
        # Boundaries must be on the target's device: torch.bucketize rejects
        # mixed-device inputs (the original always built them on the CPU).
        channel_bins = torch.arange(bin_size,
                                    self.max_pixel_val,
                                    bin_size,
                                    device=target.device)
        discretized_target = torch.bucketize(avg_target, channel_bins)
        discretized_target = F.one_hot(discretized_target,
                                       self.output_channel_bits)
        c, bi = self.channels, self.output_channel_bits
        discretized_target = rearrange(discretized_target,
                                       "b n c bi -> b n (c bi)",
                                       c=c,
                                       bi=bi)

        # Most-significant-bit-first weights: 2**(c*bi-1), ..., 2, 1.
        bin_mask = 2**torch.arange(c * bi - 1, -1,
                                   -1).to(discretized_target.device,
                                          discretized_target.dtype)
        target_label = torch.sum(bin_mask * discretized_target, -1)

        predicted_patches = predicted_patches[mask]
        target_label = target_label[mask]
        loss = F.cross_entropy(predicted_patches, target_label)
        return loss
Exemplo n.º 9
0
    def forward(self, predicted_patches, target, mask):
        """Cross-entropy between patch predictions and quantized patch colors.

        The target is optionally un-normalized with ``self.std`` /
        ``self.mean``, clamped to ``max_pixel_val`` and averaged per patch and
        channel. Each channel mean is bucketized into ``2**bits`` levels and
        the channel levels are combined into one label as digits in base
        ``2**bits`` (channel 0 is the least significant digit).
        """
        patch = self.patch_size
        n_channels = self.channels
        max_val = self.max_pixel_val
        bits = self.output_channel_bits
        device = target.device
        levels = 2**bits
        bin_size = max_val / levels

        # un-normalize input
        if exists(self.mean) and exists(self.std):
            target = target * self.std + self.mean

        # clamp just in case, then average every patch per channel
        target = target.clamp(max=max_val)
        avg_target = reduce(target,
                            'b c (h p1) (w p2) -> b (h w) c',
                            'mean',
                            p1=patch,
                            p2=patch).contiguous()

        boundaries = torch.arange(bin_size, max_val, bin_size, device=device)
        channel_levels = torch.bucketize(avg_target, boundaries)

        # Positional weights 1, levels, levels**2, ... broadcast over (b, n, c).
        weights = levels**torch.arange(0, n_channels, device=device).long()
        weights = rearrange(weights, 'c -> () () c')

        target_label = torch.sum(weights * channel_levels, dim=-1)

        return F.cross_entropy(predicted_patches[mask], target_label[mask])
Exemplo n.º 10
0
def interpolate_cubic_derivative(coeffs, points, new_points):
    """Evaluate ``3*a*t**2 + 2*b*t + c`` of a piecewise cubic at ``new_points``.

    ``coeffs[0]``, ``coeffs[1]`` and ``coeffs[2]`` hold the per-segment
    coefficients ``a``, ``b`` and ``c``; ``t`` is the position of a query
    inside its segment, normalized by the segment width. Queries outside
    ``points`` are evaluated with the first or last segment. The result is
    not divided by the segment width.
    """
    last_segment = coeffs.shape[-1] - 1
    seg = (torch.bucketize(new_points, points) - 1).clamp(0, last_segment)
    left, right = points[seg], points[seg + 1]
    t = (new_points - left) / (right - left)
    cubic, quadratic, linear = coeffs[0, seg], coeffs[1, seg], coeffs[2, seg]
    # Horner evaluation of the quadratic in t.
    slope = 3 * cubic * t + 2 * quadratic
    return slope * t + linear
Exemplo n.º 11
0
def vector_to_edge_index(idx: Tensor,
                         size: Tuple[int, int],
                         bipartite: bool,
                         force_undirected: bool = False) -> Tensor:
    """Convert flat edge ids into a ``[2, num_edges]`` edge index.

    * ``bipartite``: ids enumerate a dense ``size[0] x size[1]`` grid.
    * ``force_undirected``: ids enumerate the strict upper triangle of a
      square matrix; every edge is emitted in both directions.
    * otherwise: ids enumerate a square matrix without its diagonal.

    Both non-bipartite modes assert that ``size`` is square.
    """
    if bipartite:  # No need to account for self-loops.
        n_cols = size[1]
        return torch.stack(
            [idx.div(n_cols, rounding_mode='floor'), idx % n_cols], dim=0)

    assert size[0] == size[1]
    num_nodes = size[0]

    if force_undirected:
        device = idx.device
        offset = torch.arange(1, num_nodes, device=device).cumsum(0)
        row_starts = torch.arange(num_nodes,
                                  num_nodes * num_nodes,
                                  num_nodes,
                                  device=device)
        # Boundaries between consecutive rows of the upper triangle.
        row = torch.bucketize(idx, row_starts - offset, right=True)
        col = offset[row].add_(idx) % num_nodes
        sources = torch.cat([row, col])
        targets = torch.cat([col, row])
        return torch.stack([sources, targets], 0)

    per_row = num_nodes - 1
    row = idx.div(per_row, rounding_mode='floor')
    col = idx % per_row
    # Skip the diagonal: columns at or right of it shift by one.
    col[row <= col] += 1
    return torch.stack([row, col], dim=0)
Exemplo n.º 12
0
def filter_spikes(x, bounds=[-50, 150]):
    """
    Replace flagged samples of ``x`` by a piecewise-linear interpolation.

    ``find_spikes(x, bounds)`` supplies the spike intervals and a mask;
    samples where the mask is true are replaced by a line drawn between the
    samples two positions outside each interval, all other samples are kept.
    When no spikes are found ``x`` is returned unchanged.

    Returns the tuple ``(filtered signal, mask)``.
    """
    # find spikes
    spikes, mask = find_spikes(x, bounds)
    if not len(spikes):
        return x, mask

    # interval boundary neighbours, two samples outside each spike interval
    neighb = spikes + torch.tensor([-2, 2])
    if mask[0]:
        neighb[0, 0] = neighb[0, 1]
    if mask[-1]:
        neighb[-1, -1] = neighb[-1, 0]

    # map every sample position to the ends of the segment containing it
    n_samples = x.shape[0]
    positions = torch.arange(n_samples)
    edges = torch.cat([
        torch.tensor([0]),
        neighb.reshape([-1]),
        torch.tensor([n_samples]),
    ])
    segment = torch.bucketize(positions, neighb.reshape([-1]))
    start = bound(edges[segment], [0, n_samples - 1])
    stop = bound(edges[segment + 1], [0, n_samples - 1])

    # linear interpolation between the segment-end samples
    x_start = x[torch.max(start, neighb[0, 0])]
    x_stop = x[torch.min(stop, neighb[-1, 1])]
    slope = (x_stop - x_start) / (stop - start)
    intercept = x_start - slope * start
    x_int = mask * (slope * positions + intercept)
    x_int += ~mask * x
    return x_int, mask
Exemplo n.º 13
0
def energy_to_one_hot(e):
    """One-hot encode energy values over 256 buckets.

    The 255 boundaries are spaced linearly between ``hp.e_min`` and
    ``hp.e_max`` and placed on CUDA when ``hp.ngpu > 0``, else on the CPU.
    Requires PyTorch >= 1.6.0 (``torch.bucketize``).
    """
    edges = torch.linspace(hp.e_min, hp.e_max, steps=255)
    device = torch.device("cuda" if hp.ngpu > 0 else "cpu")
    bucket_ids = torch.bucketize(e, edges.to(device))
    return F.one_hot(bucket_ids.long(), 256).float()
Exemplo n.º 14
0
def pitch_to_one_hot(f0, is_training = True):
    """One-hot encode pitch values over 256 log-spaced buckets.

    The 255 boundaries are spaced evenly in log space between ``hp.p_min``
    and ``hp.p_max`` and placed on CUDA when ``hp.ngpu > 0``, else on the
    CPU. ``is_training`` is accepted but not used. Requires PyTorch >= 1.6.0
    (``torch.bucketize``).
    """
    log_edges = torch.linspace(np.log(hp.p_min), np.log(hp.p_max), 255)
    device = torch.device("cuda" if hp.ngpu > 0 else "cpu")
    edges = torch.exp(log_edges).to(device)
    bucket_ids = torch.bucketize(f0, edges)
    return F.one_hot(bucket_ids.long(), 256).float()
Exemplo n.º 15
0
    def __call__(self, points_to_interp):
        """Multilinear interpolation of ``self.values`` at the query points.

        Args:
            points_to_interp: One 1-D tensor of query coordinates per grid
                axis; all tensors must have the same length.

        Returns:
            Interpolated values, one per query point. Queries outside the
            grid use the nearest edge value along that axis.

        Raises:
            AssertionError: If the grid is not set, the number of coordinate
                tensors differs from the number of axes, or their lengths
                differ.
        """
        assert self.points is not None
        assert self.values is not None

        assert len(points_to_interp) == len(self.points)
        n_queries = points_to_interp[0].shape[0]
        for x in points_to_interp:
            assert x.shape[0] == n_queries

        idxs = []
        dists = []
        overalls = []
        for p, x in zip(self.points, points_to_interp):
            last = p.shape[0] - 1
            # Bracketing grid indices, clamped to the valid range.
            idx_right = torch.bucketize(x, p).clamp(max=last)
            idx_left = (idx_right - 1).clamp(0, last)
            # Distances to both neighbours; negative values (query outside
            # the grid) are cut to 0 so the edge value is used.
            dist_left = (x - p[idx_left]).clamp(min=0.)
            dist_right = (p[idx_right] - x).clamp(min=0.)
            # Query exactly on a node: give both sides weight 1 to avoid 0/0.
            both_zero = (dist_left == 0) & (dist_right == 0)
            dist_left[both_zero] = dist_right[both_zero] = 1.

            idxs.append((idx_left, idx_right))
            dists.append((dist_left, dist_right))
            overalls.append(dist_left + dist_right)

        numerator = 0.
        for indexer in product([0, 1], repeat=self.n):
            corner = tuple(idx[onoff] for onoff, idx in zip(indexer, idxs))
            # Each corner is weighted by the distance to the OPPOSITE side.
            weights = [dist[1 - onoff] for onoff, dist in zip(indexer, dists)]
            # Index with a tuple: a list of tensors is deprecated non-tuple
            # multidimensional indexing.
            numerator += self.values[corner] * \
                torch.prod(torch.stack(weights), dim=0)
        denominator = torch.prod(torch.stack(overalls), dim=0)
        return numerator / denominator
Exemplo n.º 16
0
    def to_one_hot(self, x: torch.Tensor):
        """One-hot encode ``x`` over 256 buckets defined by ``self.pitch_bins``.

        The boundaries are moved to ``x``'s device before bucketizing: the
        original moved only the result, which is too late because
        ``torch.bucketize`` itself rejects mixed-device inputs.

        Requires PyTorch >= 1.6.0 (``torch.bucketize``).

        Returns:
            Float tensor of shape ``x.shape + (256,)``.
        """
        boundaries = self.pitch_bins.to(device=x.device)
        quantize = torch.bucketize(x, boundaries)
        return F.one_hot(quantize.long(), 256).float()
Exemplo n.º 17
0
 def _interpret_t(self, t):
     """Locate ``t`` within the knots ``self._t``.

     Returns ``(fractional_part, index)`` where ``index`` is the piece
     containing ``t`` and ``fractional_part`` is ``t`` minus the knot at
     the start of that piece.
     """
     last_piece = self._b.size(-2) - 1
     raw_index = torch.bucketize(t.detach(), self._t) - 1
     # t may lie outside [t[0], t[-1]]; clamping maps it to the first or
     # last piece, so the final knot is never used as a piece start.
     index = raw_index.clamp(0, last_piece)
     return t - self._t[index], index
Exemplo n.º 18
0
def quantile_weights(input, quantiles, weights, scale):
    '''
    Weighted quantiles of a 1D sample by linear interpolation.

    input, weights: 1D array
    quantiles: quantile levels to evaluate
    scale: scalar; the sorted sample is extended by one point ``scale``
        below its minimum and one point ``scale`` above its maximum
    '''
    values, order = torch.sort(input)
    mass = weights[order] / torch.sum(weights)
    # Cumulative mass at the centre of each sample's weight.
    centres = torch.cumsum(mass, dim=0) - mass / 2.
    zero, one = torch.zeros(1), torch.ones(1)
    boundary = torch.cat((zero, centres, one), dim=0)
    mass = torch.cat((zero, mass, torch.zeros(1)), dim=0)
    values = torch.cat((values[0].reshape(1) - scale, values,
                        values[-1].reshape(1) + scale),
                       dim=0)

    upper = torch.bucketize(quantiles, boundary).clamp(min=1)
    lower = upper - 1
    # Indices past the last boundary wrap to 0 (after `lower` was derived).
    upper = upper.masked_fill(upper > len(boundary) - 1, 0)

    upper_share = (quantiles - boundary[lower]) / (mass[lower] +
                                                   mass[upper]) * 2.
    lower_share = 1.0 - upper_share
    return values[lower.long()] * lower_share + values[upper.long()] * upper_share
Exemplo n.º 19
0
 def _interpret_t(self, t):
     """Locate ``t`` within the knots ``self._t``.

     ``t`` is first converted to a tensor with the dtype and device of
     ``self._b``. Returns ``(fractional_part, index)`` where ``index`` is
     the piece containing ``t`` and ``fractional_part`` is ``t`` minus the
     knot at the start of that piece.
     """
     t = torch.as_tensor(t, dtype=self._b.dtype, device=self._b.device)
     last_piece = self._b.size(-2) - 1
     raw_index = torch.bucketize(t.detach(), self._t.detach()) - 1
     # t may lie outside [t[0], t[-1]]; clamping maps it to the first or
     # last piece, so the final knot is never used as a piece start.
     index = raw_index.clamp(0, last_piece)
     return t - self._t[index], index
Exemplo n.º 20
0
def encode_survival(time: Union[float, int, TensorOrArray],
                    event: Union[int, bool, TensorOrArray],
                    bins: TensorOrArray) -> torch.Tensor:
    """Encodes survival time and event indicator in the format
    required for MTLR training.

    For uncensored instances, one-hot encoding of binned survival time
    is generated. Censoring is handled differently, with all possible
    values for event time encoded as 1s. For example, if 5 time bins are used,
    an instance experiencing event in bin 3 is encoded as [0, 0, 0, 1, 0], and
    instance censored in bin 2 as [0, 0, 1, 1, 1]. Note that an additional
    'catch-all' bin is added, spanning the range `(bins.max(), inf)`.

    Parameters
    ----------
    time
        Time of event or censoring. Scalars, arrays and tensors are accepted;
        values are clamped to `[0, bins.max()]`.
    event
        Event indicator (0 = censored).
    bins
        Bins used for time axis discretisation.

    Returns
    -------
    torch.Tensor
        Encoded survival times, created on the device of `bins` and squeezed
        (a single instance yields a 1-D tensor).
    """
    # Anything that is not already a tensor (Python/NumPy scalars, arrays,
    # sequences) becomes an at-least-1-D tensor.
    if not isinstance(time, torch.Tensor):
        time = torch.tensor(np.atleast_1d(time))
    if not isinstance(event, torch.Tensor):
        event = torch.tensor(np.atleast_1d(event))

    if isinstance(bins, np.ndarray):
        bins = torch.tensor(bins)

    try:
        device = bins.device
    except AttributeError:
        device = "cpu"

    # Clamp with torch rather than np.clip: np.clip on a tensor relies on
    # NumPy's fallback conversion and cannot handle CUDA tensors. The time
    # tensor is also moved next to the bins so bucketize sees one device.
    time = torch.clamp(time.to(device), min=0, max=bins.max().item())
    # add extra bin [max_time, inf) at the end
    y = torch.zeros((time.shape[0], bins.shape[0] + 1),
                    dtype=torch.float,
                    device=device)
    # `right` in torch.bucketize has the opposite meaning to the same-named
    # argument of numpy.digitize, so it has to be True here.
    bin_idxs = torch.bucketize(time, bins, right=True)
    for i, (bin_idx, e) in enumerate(zip(bin_idxs, event)):
        if e == 1:
            y[i, bin_idx] = 1
        else:
            # Censored: the event may fall in this or any later bin.
            y[i, bin_idx:] = 1
    return y.squeeze()
Exemplo n.º 21
0
def _quantile_encode_approx(tensor: torch.Tensor,
                            n_bits: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``tensor`` into ``2**n_bits`` buckets with approximate quantile borders.

    The borders come from ``_quantile_qq_approximation`` with its first and
    last entries dropped. Returns the per-element bucket indices and the
    lookup table produced by ``average_buckets``.
    """
    n_bins = 2**n_bits
    approx_quantiles = _quantile_qq_approximation(tensor.numpy(), n_bins + 1)
    borders = torch.as_tensor(approx_quantiles[1:-1])
    bucket_ids = torch.bucketize(tensor, borders)
    # Keep indices inside [0, n_bins - 1]; in place on the fresh tensor.
    quant_weight = torch.clamp_(bucket_ids, 0, n_bins - 1)
    return quant_weight, average_buckets(tensor, quant_weight, n_bins)
Exemplo n.º 22
0
    def forward(self,
                x,
                src_mask,
                mel_mask=None,
                duration_target=None,
                pitch_target=None,
                energy_target=None,
                max_len=None):
        """Expand ``x`` by duration and add pitch and energy embeddings.

        Durations come from ``duration_target`` when given; otherwise they
        are derived from the predicted log-durations and ``mel_mask`` is
        rebuilt from the resulting lengths. Each embedding is looked up from
        the bucketized target when one is given, else from the prediction;
        values and boundaries are detached before bucketizing.

        Returns:
            ``(x, log_duration_prediction, pitch_prediction,
            energy_prediction, mel_len, mel_mask)``.
        """
        ## Duration Predictor ##
        log_duration_prediction = self.duration_predictor(x, src_mask)
        if duration_target is None:
            predicted_frames = torch.exp(log_duration_prediction) - hp.log_offset
            duration_rounded = torch.clamp(torch.round(predicted_frames), min=0)
            x, mel_len = self.length_regulator(x, duration_rounded, max_len)
            mel_mask = utils.get_mask_from_lengths(mel_len)
        else:
            x, mel_len = self.length_regulator(x, duration_target, max_len)

        ## Pitch Predictor ##
        pitch_prediction = self.pitch_predictor(x, mel_mask)
        pitch_values = pitch_prediction if pitch_target is None else pitch_target
        pitch_embedding = self.pitch_embedding(
            torch.bucketize(pitch_values.detach(), self.pitch_bins.detach()))

        ## Energy Predictor ##
        # Runs on x before any embedding is added.
        energy_prediction = self.energy_predictor(x, mel_mask)
        energy_values = energy_prediction if energy_target is None else energy_target
        energy_embedding = self.energy_embedding(
            torch.bucketize(energy_values.detach(), self.energy_bins.detach()))

        x = x + pitch_embedding + energy_embedding

        return (x, log_duration_prediction, pitch_prediction,
                energy_prediction, mel_len, mel_mask)
Exemplo n.º 23
0
    def forward(self,
                x,
                src_mask,
                mel_mask=None,
                duration_target=None,
                pitch_target=None,
                energy_target=None,
                max_len=None,
                d_control=1.0,
                p_control=1.0,
                e_control=1.0):
        """Expand ``x`` by duration and add pitch and energy embeddings.

        Each ``*_control`` factor is applied only when the matching target
        is absent: ``d_control`` scales the rounded predicted durations,
        ``p_control`` / ``e_control`` scale the predictions that are then
        both returned and bucketized. With a target, the embedding is looked
        up from the bucketized target and the prediction is left unscaled.

        Returns:
            ``(x, log_duration_prediction, pitch_prediction,
            energy_prediction, mel_len, mel_mask)``.
        """
        log_duration_prediction = self.duration_predictor(x, src_mask)
        if duration_target is None:
            predicted_frames = torch.round(
                torch.exp(log_duration_prediction) - hp.log_offset)
            duration_rounded = torch.clamp(predicted_frames * d_control, min=0)
            x, mel_len = self.length_regulator(x, duration_rounded, max_len)
            mel_mask = utils.get_mask_from_lengths(mel_len)
        else:
            x, mel_len = self.length_regulator(x, duration_target, max_len)

        pitch_prediction = self.pitch_predictor(x, mel_mask)
        if pitch_target is None:
            pitch_prediction = pitch_prediction * p_control
            pitch_values = pitch_prediction
        else:
            pitch_values = pitch_target
        pitch_embedding = self.pitch_embedding(
            torch.bucketize(pitch_values, self.pitch_bins))

        # Runs on x before any embedding is added.
        energy_prediction = self.energy_predictor(x, mel_mask)
        if energy_target is None:
            energy_prediction = energy_prediction * e_control
            energy_values = energy_prediction
        else:
            energy_values = energy_target
        energy_embedding = self.energy_embedding(
            torch.bucketize(energy_values, self.energy_bins))

        x = x + pitch_embedding + energy_embedding

        return (x, log_duration_prediction, pitch_prediction,
                energy_prediction, mel_len, mel_mask)
Exemplo n.º 24
0
    def get_loss(self, y, y_hat):
        """
        Filter the labels and predictions as configured and return the mean
        squared error between them.

        The steps are driven by ``self.hparams``: optional undersampling of
        labels below ``undersample``, dropping labels not above
        ``round_to_zero``, keeping labels strictly inside ``clip_output``,
        and a Box-Cox transform of the labels.

        :param y: Label tensor
        :type y: torch.Tensor
        :param y_hat: Predicted tensor
        :type y_hat: torch.Tensor
        :return: Prediction loss
        :rtype: torch.Tensor
        """
        hparams = self.hparams

        if hparams.undersample:
            below = y < hparams.undersample
            low_values = y[below]
            low = max(low_values.min(), 0.5)
            high = low_values.max()
            step = (high - low) / 10
            boundaries = torch.arange(low, high, step).to(self.model.device)
            freq_idx = torch.bucketize(low_values, boundaries[:-1], right=False)
            # Resample using the bucket boundary (x100, as int) as class label.
            resample_classes = (boundaries.take(index=freq_idx).cpu() * 100).int()
            self.undersampler.fit_resample(
                low_values.cpu().unsqueeze(-1),
                resample_classes,
            )
            kept = self.undersampler.sample_indices_
            y_hat = torch.cat((y_hat[~below], y_hat[below][kept]))
            y = torch.cat((y[~below], low_values[kept]))

        if hparams.round_to_zero:
            above = y > hparams.round_to_zero
            y_hat = y_hat[above]
            y = y[above]

        if hparams.clip_output:
            inside = (y < hparams.clip_output[-1]) & (hparams.clip_output[0] < y)
            y_hat = y_hat[inside]
            y = y[inside]

        if hparams.cb_loss:
            # Computed but currently not applied to the loss (see below).
            loss_factor = self.get_cb_loss_factor(y)

        if hparams.boxcox:
            transformed = boxcox(y.cpu(), lmbda=hparams.boxcox,)
            y = torch.from_numpy(transformed).to(y.device)

        squared_error = (y_hat - y) ** 2
        # if "loss_factor" in locals():
        #     squared_error *= loss_factor
        loss = squared_error.mean()
        # NaN is the only value that differs from itself.
        assert loss == loss

        return loss
Exemplo n.º 25
0
def color_bin(input, color_bins, bin_scale, n_tot):
    """Per-image normalized histogram of combined color-bin codes.

    Every pixel value is bucketized against ``color_bins``; the per-channel
    bucket indices are combined into one code per pixel as a weighted sum
    with ``bin_scale``. The codes are histogrammed into ``n_tot`` bins over
    ``[0, n_tot - 1]`` and divided by the number of pixels.

    Returns a tensor with one histogram row per image of the batch.
    """
    _, _, height, width = input.size()
    n_pixels = height * width

    def image_histogram(image):
        channel_codes = torch.bucketize(image, color_bins)
        pixel_codes = torch.einsum('bcd, b -> cd', channel_codes,
                                   bin_scale).to(torch.float)
        return torch.histc(pixel_codes, n_tot, min=0, max=n_tot - 1) / n_pixels

    return torch.stack([image_histogram(image) for image in input])
Exemplo n.º 26
0
def probsample(n_samples, prob, return_counts=True):
    """Draw ``n_samples`` categories according to the probabilities ``prob``.

    Uniform random numbers are located within the cumulative distribution,
    whose last entry is forced to 1. With ``return_counts`` the number of
    draws per category is returned, otherwise the category index of every
    individual draw.
    """
    cumulative = torch.cumsum(prob, dim=0)
    if not return_counts:
        cumulative[-1] = 1
        return torch.bucketize(torch.rand((n_samples, )), cumulative)

    edges = torch.cat([torch.zeros(1), cumulative])
    edges[-1] = 1
    return torch.histogram(torch.rand(n_samples), edges).hist
Exemplo n.º 27
0
def energy_to_one_hot(e, is_training=True):
    """One-hot encode energy values over 256 buckets.

    The 255 boundaries are spaced linearly between ``hp.e_min`` and
    ``hp.e_max`` and placed on CUDA when ``hp.ngpu > 0``, else on the CPU.
    ``is_training`` is accepted but not used. Requires PyTorch >= 1.6.0
    (``torch.bucketize``).
    """
    # e = de_norm_mean_std(e, hp.e_mean, hp.e_std)
    boundaries = torch.linspace(hp.e_min, hp.e_max, steps=255)
    target_device = torch.device("cuda" if hp.ngpu > 0 else "cpu")
    boundaries = boundaries.to(target_device)

    bucket_ids = torch.bucketize(e, boundaries).long()
    return F.one_hot(bucket_ids, 256).float()
Exemplo n.º 28
0
def get_bucketed_distance_matrix(coords, mask):
    """Bucketize pairwise Euclidean distances between coordinates.

    Boundaries are the first ``DISTOGRAM_BUCKETS - 1`` points of a linear
    grid from 2 to 20. Pairs in which either element is masked out are set
    to ``IGNORE_INDEX``.
    """
    pairwise = torch.cdist(coords, coords, p=2)
    grid = torch.linspace(2, 20, steps=DISTOGRAM_BUCKETS, device=coords.device)
    buckets = torch.bucketize(pairwise, grid[:-1])
    # A pair is valid only when both of its elements are valid.
    valid_pairs = mask[:, :, None] & mask[:, None, :]
    buckets.masked_fill_(~valid_pairs, IGNORE_INDEX)
    return buckets
Exemplo n.º 29
0
def get_bucketed_distance_matrix(coords,
                                 mask,
                                 num_buckets=constants.DISTOGRAM_BUCKETS,
                                 ignore_index=-100):
    """Bucketize pairwise Euclidean distances between coordinates.

    Boundaries are the first ``num_buckets - 1`` points of a linear grid
    from 2 to 20. Pairs in which either element is masked out are set to
    ``ignore_index``.
    """
    pairwise = torch.cdist(coords, coords, p=2)
    grid = torch.linspace(2, 20, steps=num_buckets, device=coords.device)
    buckets = torch.bucketize(pairwise, grid[:-1])
    # A pair is valid only when both of its elements are valid.
    valid_pairs = mask[..., None] & mask[..., None, :]
    buckets.masked_fill_(~valid_pairs, ignore_index)
    return buckets
Exemplo n.º 30
0
    def forward(self,
                x,
                src_mask,
                mel_mask=None,
                duration_target=None,
                pitch_target=None,
                energy_target=None,
                max_len=None):
        """Expand ``x`` by duration and add pitch and energy embeddings.

        Durations come from ``duration_target`` when given; otherwise they
        are derived from the predicted log-durations and ``mel_mask`` is
        rebuilt from the resulting lengths. Each embedding is looked up from
        the bucketized target when one is given, else from the prediction.

        Returns:
            ``(x, log_duration_prediction, pitch_prediction,
            energy_prediction, mel_len, mel_mask)``.
        """
        log_duration_prediction = self.duration_predictor(x, src_mask)
        if duration_target is None:
            predicted_frames = torch.exp(log_duration_prediction) - hp.log_offset
            duration_rounded = torch.clamp(torch.round(predicted_frames), min=0)
            x, mel_len = self.length_regulator(x, duration_rounded, max_len)
            mel_mask = utils.get_mask_from_lengths(mel_len)
        else:
            x, mel_len = self.length_regulator(x, duration_target, max_len)

        pitch_prediction = self.pitch_predictor(x, mel_mask)
        pitch_values = pitch_prediction if pitch_target is None else pitch_target
        pitch_embedding = self.pitch_embedding(
            torch.bucketize(pitch_values, self.pitch_bins))

        # Runs on x before any embedding is added.
        energy_prediction = self.energy_predictor(x, mel_mask)
        energy_values = energy_prediction if energy_target is None else energy_target
        energy_embedding = self.energy_embedding(
            torch.bucketize(energy_values, self.energy_bins))

        # Both embeddings must broadcast against x.
        x = x + pitch_embedding + energy_embedding

        return (x, log_duration_prediction, pitch_prediction,
                energy_prediction, mel_len, mel_mask)