def compute(self, data: torch.Tensor) -> torch.Tensor:
    """Optionally centre ``data``, then scale it by its NaN-ignoring L1 norm along dim 1.

    If ``self.demean`` is truthy, ``nanmean(data)`` (with a trailing axis added)
    is subtracted first. The result is then divided by the ``nansum`` of its
    absolute values over dim 1, broadcast back along a trailing axis.

    NOTE(review): ``nanmean(data)`` is called without a ``dim``. If it is
    ``torch.nanmean``, this subtracts one global mean rather than a per-row
    mean. Confirm this is intended. Rows whose absolute sum is 0 produce
    inf/NaN.
    """
    centred = data
    if self.demean:
        centred = centred - nanmean(centred).unsqueeze(-1)
    row_l1 = nansum(centred.abs(), dim=1)
    return centred / row_l1.unsqueeze(-1)
def _spectrogram(x: torch.Tensor) -> torch.Tensor:
    """Return the element-wise absolute value of ``x``.

    For complex input this is the magnitude, and the result is real-valued.
    """
    magnitude = torch.abs(x)
    return magnitude
def compute(self, data: torch.Tensor) -> torch.Tensor:
    """Return the element-wise absolute value of ``data``; ``self`` is not used."""
    result = torch.abs(data)
    return result
def invert_ulaw(x: torch.Tensor, mu: float = 255.0) -> torch.Tensor:
    """Invert mu-law companding: ``sign(x) * ((1 + mu) ** |x| - 1) / mu``.

    Args:
        x: companded values (typically in [-1, 1]).
        mu: companding constant; must be non-zero.
    """
    base = 1 + mu
    expanded = base ** x.abs() - 1
    return x.sign() * (1 / mu) * expanded
def forward(self, pred: torch.Tensor, teacher: torch.Tensor, smooth=1.0):
    """Smoothed Jaccard (IoU) loss between ``pred`` and ``teacher``.

    For every map over the last two dims:

        J = (sum(P * T) + smooth) / (sum|P| + sum|T| - sum(P * T) + smooth)

    Args:
        pred: predicted maps; the last two dims are summed over.
        teacher: target maps with the same shape as ``pred``; cast to float.
        smooth: additive smoothing that keeps J defined for empty maps.

    Returns:
        ``1 - J``, averaged over dim 1 and then over dim 0. This presumes
        4-D (batch, channel, H, W) input, which gives a scalar. TODO confirm.
    """
    teacher = teacher.float()
    intersection = (pred * teacher).sum((-1, -2))
    # The union needs the mass of BOTH maps; the old code summed |pred| twice.
    total = (pred.abs() + teacher.abs()).sum((-1, -2))
    jaccard = (intersection + smooth) / (total - intersection + smooth)
    return (1 - jaccard).mean(1).mean(0)
def linear_contribution(x: torch.Tensor) -> torch.Tensor:
    """Triangle-kernel weights: ``1 - |x|`` where ``|x| <= 1``, and zero elsewhere.

    The zeroing is done by multiplying with a 0/1 mask cast to ``x.dtype``.
    """
    magnitude = torch.abs(x)
    inside = (magnitude <= 1).to(dtype=x.dtype)
    return (1 - magnitude) * inside
def gaussian_contribution(x: torch.Tensor, sigma: float = 2.0) -> torch.Tensor:
    """Unnormalised Gaussian weights ``exp(-x^2 / (2 sigma^2))``, truncated at ``|x| <= 3*sigma + 1``.

    The weights are intentionally left unnormalised here; normalisation is
    applied afterwards.
    """
    two_var = 2 * sigma**2
    weights = torch.exp(-x.pow(2) / two_var)
    within_support = x.abs() <= 3 * sigma + 1
    return weights * within_support.to(dtype=x.dtype)
def transformed_h(qval: torch.Tensor, eps: float = 1e-2):
    """Value-rescaling function ``h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x``.

    Args:
        qval: values to compress.
        eps: weight of the linear term, which keeps ``h`` invertible.
    """
    magnitude = qval.abs()
    compressed = (magnitude + 1).sqrt() - 1
    return qval.sign() * compressed + eps * qval
def transformed_h_reverse(qval: torch.Tensor, eps: float = 1e-2):
    """Inverse of ``h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x``.

    Closed form::

        h^-1(y) = sign(y) * (((sqrt(1 + 4 eps (|y| + 1 + eps)) - 1) / (2 eps))^2 - 1)

    Args:
        qval: rescaled values to expand.
        eps: the same ``eps`` used in the forward transform; must be non-zero.
    """
    magnitude = qval.abs()
    root = (1 + 4 * eps * (magnitude + 1 + eps)).sqrt()
    expanded = ((root - 1) / (2 * eps)).pow(2) - 1
    return qval.sign() * expanded
def get_nll(self, xin: torch.Tensor, xin_ind: torch.Tensor, weights: torch.Tensor = None, debug=False):
    """Return a negative log-likelihood objective for a batch of observations.

    Given an input tensor and the corresponding index tensor (both shapes =
    (N, D)), look up the probability ``self.get_probs(xin)`` assigns to each
    observed index: ``prob_in[i, j] = probs[i, j, xin_ind[i, j]]``.

    * Without ``weights``: sum ``-log2`` over the D variables of each sample,
      then return the mean over the N samples.
    * With ``weights``: take each sample's joint probability (product over D),
      compute ``log10(p + 1e-15)`` scaled by ``|weight|``, and average over the
      samples whose weight is > 0. Return the negated average. Samples with
      weight <= 0 do not contribute.

    Args:
        xin: model input passed to ``self.get_probs``; only ``xin.shape[0]``
            (N) is read here.
        xin_ind: integer indices of the observed values, shape (N, D); moved
            to ``self.device``.
        weights: optional per-sample weights, presumably shape (N,). TODO confirm.
        debug: if True, drop into ``pdb`` before returning.

    NOTE(review): the unweighted path uses log2 and the weighted path uses
    log10, so their magnitudes are not directly comparable. Confirm this is
    intended.
    """
    xin_ind = xin_ind.to(self.device)
    probs = self.get_probs(xin)
    D = self.dim
    N = xin.shape[0]
    # Broadcastable (N, 1) x (1, D) index grids replace the explicit numpy
    # (N, D) grids; the selected elements are the same.
    batch_ind = torch.arange(N, device=probs.device).unsqueeze(-1)
    var_ind = torch.arange(D, device=probs.device).unsqueeze(0)
    prob_in = probs[batch_ind, var_ind, xin_ind]

    if weights is None:
        nll_samples = -prob_in.log2().sum(dim=-1)
        return nll_samples.mean(dim=-1)

    prob_x = prob_in.prod(-1)
    pos_ind = (weights > 0).float()
    # The epsilon keeps log10 finite when a joint probability underflows to 0.
    logp_vec = (prob_x + 1e-15).log10()
    # pos_ind is 0/1, so npos is a count. Clamping to 1 only matters when no
    # sample is positive, in which case the numerator is already 0.
    npos = pos_ind.sum(-1).clamp(min=1)
    pos_ll = (logp_vec * weights.abs() * pos_ind).sum(-1) / npos
    # Only the positive-sample term is optimised. A negative-sample term
    # (mean log-likelihood over weight <= 0) was previously computed and
    # discarded, so it has been dropped.
    min_obj = -pos_ll

    if debug:
        pdb.set_trace()
    # .any() keeps the check valid if the objective is not a scalar.
    if torch.isnan(min_obj).any():
        print(min_obj)
        pdb.set_trace()
    return min_obj
def update(self, tensor: torch.Tensor): tensor_max = tensor.abs().amax((0, 2, 3)) tensor_max.div_(self.value).clamp_(self.eps) tensor_max.mul_(self.momentum) self.running_max.mul_(1 - self.momentum) self.running_max.add_(tensor_max)
def forward(self, input: torch.Tensor, target: torch.Tensor):
    """Mean binary cross-entropy computed directly from logits.

    Uses the numerically stable identity

        BCE(x, t) = max(x, 0) - x * t + log(1 + exp(-|x|))

    which never exponentiates a positive number.

    Args:
        input: logits of any shape; flattened and cast to float.
        target: targets with the same number of elements; flattened and cast
            to float.

    Returns:
        The scalar mean loss over all elements.
    """
    # reshape (unlike view) also accepts non-contiguous tensors.
    logits = input.float().reshape(-1)
    target = target.float().reshape(-1)
    neg_abs = -logits.abs()
    # log1p keeps precision when exp(-|x|) is tiny; log(1 + e) rounds to 0 there.
    loss = logits.clamp(min=0) - logits * target + neg_abs.exp().log1p()
    return loss.mean()