def log_uniform_candidate_sampler(self, targets, choice_func=_choice):
    """Sample candidate ids from a log-uniform distribution (TensorFlow style).

    Follows TensorFlow's LogUniformSampler
    (tensorflow/core/kernels/range_sampler.h / range_sampler.cc).
    ``choice_func(self._num_words, self._num_samples)`` returns the sampled ids
    together with the number of tries it needed. The expected count of an id
    with probability ``p`` is then ``1 - (1 - p) ** num_tries``, where
    ``p = (log(id + 2) - log(id + 1)) / log(num_words + 1)``.

    Args:
        targets: tensor of target ids; (batch_size,) per the original comments.
        choice_func: callable returning ``(np_sampled_ids, num_tries)``.

    Returns:
        ``(sampled_ids, target_expected_count, sampled_expected_count)``:
        sampled_ids is (n_samples,), target_expected_count has targets' shape,
        sampled_expected_count is (n_samples,).
    """
    np_sampled_ids, num_tries = choice_func(self._num_words, self._num_samples)
    sampled_ids = torch.from_numpy(np_sampled_ids).to(targets.device)

    def expected_count(ids):
        ids = ids.float()
        # log((id + 2) / (id + 1)) == log1p(1 / (id + 1)); the log1p form keeps
        # precision for large ids, where the ratio rounds towards 1.
        probs = torch.log1p(1.0 / (ids + 1.0)) / self._log_num_words_p1
        # 1 - (1 - p) ** num_tries == -expm1(num_tries * log1p(-p)). Using expm1
        # avoids the cancellation of exp(...) - 1 when p is tiny.
        return -torch.expm1(num_tries * torch.log1p(-probs))

    target_expected_count = expected_count(targets)
    sampled_expected_count = expected_count(sampled_ids)

    # Sampling statistics only; they must not take part in autograd.
    sampled_ids.requires_grad_(False)
    target_expected_count.requires_grad_(False)
    sampled_expected_count.requires_grad_(False)
    return sampled_ids, target_expected_count, sampled_expected_count
def get_probs_and_logits(ps=None, logits=None, is_multidimensional=True):
    """
    Convert probability values to logits, or vice-versa. Exactly one of
    ``ps`` or ``logits`` must be given.

    :param ps: tensor of probabilities in *[0, 1]*. When
        ``is_multidimensional = True`` it must be normalized along axis -1.
    :param logits: tensor of logit values. In the multidimensional case these
        are log probabilities (exponentiated along the last axis they sum to 1);
        in the uni-dimensional case they are log odds.
    :param is_multidimensional: selects softmax/log (True) or sigmoid/log-odds
        (False) for the conversion.
    :return: tuple ``(ps, logits)`` of tensors.
    """
    assert (ps is None) != (logits is None)
    if ps is None:
        # Only logits were supplied: derive the probabilities.
        if is_multidimensional:
            ps = softmax(logits, -1)
        else:
            ps = F.sigmoid(logits)
        return ps, logits

    # Only probabilities were supplied: clamp away from 0/1 before taking logs.
    eps = _get_clamping_buffer(ps)
    bounded = ps.clamp(min=eps, max=1 - eps)
    if is_multidimensional:
        logits = torch.log(bounded)
    else:
        logits = torch.log(bounded) - torch.log1p(-bounded)
    return ps, logits
def log_prob(self, value):
    """Log-density of a location-scale Student's t distribution at ``value``.

    Uses ``self.df``, ``self.loc`` and ``self.scale``; ``value`` is validated
    with ``self._validate_log_prob_arg`` first.
    """
    self._validate_log_prob_arg(value)
    standardized = (value - self.loc) / self.scale
    # log of the normalising constant: log(scale * sqrt(df * pi) * B(df/2, 1/2))
    log_norm = (self.scale.log()
                + 0.5 * self.df.log()
                + 0.5 * math.log(math.pi)
                + torch.lgamma(0.5 * self.df)
                - torch.lgamma(0.5 * (self.df + 1.)))
    log_kernel = -0.5 * (self.df + 1.) * torch.log1p(standardized ** 2. / self.df)
    return log_kernel - log_norm
def log_prob(self, value):
    """Log-density of the F (Fisher-Snedecor) distribution at ``value``.

    Degrees of freedom come from ``self.df1`` and ``self.df2``; ``value`` is
    validated with ``self._validate_log_prob_arg`` first.
    """
    self._validate_log_prob_arg(value)
    half_d1 = self.df1 * 0.5
    half_d2 = self.df2 * 0.5
    ratio = self.df1 / self.df2
    half_sum = half_d1 + half_d2
    # -log B(d1/2, d2/2)
    log_inv_beta = torch.lgamma(half_sum) - torch.lgamma(half_d1) - torch.lgamma(half_d2)
    log_kernel = half_d1 * torch.log(ratio) + (half_d1 - 1) * torch.log(value)
    log_tail = half_sum * torch.log1p(ratio * value)
    return log_inv_beta + log_kernel - log_tail
def log_prob(self, value):
    """Log-pmf of a Binomial(total_count, logits) distribution at ``value``.

    log P(k) = log C(n, k) + k * logits + n * log(1 - p), where
    log(1 - p) = -softplus(logits).

    ``self.total_count`` may be a Python number or a tensor that broadcasts
    against ``self.logits`` and ``value``.
    """
    self._validate_log_prob_arg(value)
    n = torch.as_tensor(self.total_count, dtype=self.logits.dtype, device=self.logits.device)
    log_factorial_n = torch.lgamma(n + 1)
    log_factorial_k = torch.lgamma(value + 1)
    log_factorial_nmk = torch.lgamma(n - value + 1)
    # softplus(l) = max(l, 0) + log1p(exp(-|l|)) never exponentiates a positive
    # number, so it stays finite for large |logits|. The previous form
    # log1p(exp(logits + 2 * max(-logits, 0))) = log1p(exp(|logits|)) overflowed.
    log1m_probs = -(self.logits.clamp(min=0.0) + torch.log1p(torch.exp(-self.logits.abs())))
    return (log_factorial_n - log_factorial_k - log_factorial_nmk
            + value * self.logits + n * log1m_probs)
def probs_to_logits(probs, is_binary=False):
    r"""
    Convert a tensor of probabilities into logits.

    With ``is_binary=True`` each entry is the probability of the event indexed
    by `1`, and the log-odds are returned. Otherwise the values along the last
    dimension are per-event probabilities, and their logs are returned.
    Probabilities are clamped with ``clamp_probs`` before taking logs.
    """
    clamped = clamp_probs(probs)
    log_p = torch.log(clamped)
    if not is_binary:
        return log_p
    return log_p - torch.log1p(-clamped)
def __call__(self, x):
    """Mu-law encode ``x`` into integer codes ``0 .. self.qc - 1``.

    Args:
        x (FloatTensor/LongTensor or ndarray): input signal, presumably scaled
            to [-1, 1].

    Returns:
        x_mu (LongTensor or ndarray): quantised mu-law codes.

    Raises:
        TypeError: if ``x`` is neither a numpy array nor a torch tensor.
    """
    mu = self.qc - 1.
    if isinstance(x, np.ndarray):
        x_mu = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)
        return ((x_mu + 1) / 2 * mu + 0.5).astype(int)
    if isinstance(x, torch.Tensor):
        # Promote any non-float tensor (CPU or CUDA) to float32.
        if not x.is_floating_point():
            x = x.float()
        # Build mu on x's device: a CPU tensor here fails for CUDA inputs.
        mu = torch.tensor([mu], dtype=torch.float32, device=x.device)
        x_mu = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / torch.log1p(mu)
        return ((x_mu + 1) / 2 * mu + 0.5).long()
    raise TypeError(
        "Expected numpy.ndarray or torch.Tensor, got {}".format(type(x).__name__))
def triplet_loss(self, distances, anchors, positives, negatives,
                 return_delta=False):
    """Compute triplet loss

    Parameters
    ----------
    distances : torch.Tensor
        Condensed matrix of pairwise distances.
    anchors, positives, negatives : list of int
        Triplets indices.
    return_delta : bool, optional
        Return delta before clamping.

    Returns
    -------
    loss : torch.Tensor
        Triplet loss. When ``return_delta`` is True, returns
        ``(loss, delta.view(-1, 1), pos, neg)`` instead.

    Raises
    ------
    ValueError
        If ``self.clamp`` is not 'positive', 'softmargin' or 'sigmoid'.
    """
    # A condensed matrix over n embeddings has n * (n - 1) / 2 entries; invert that.
    n_embeddings = int(.5 * (1 + np.sqrt(1 + 8 * len(distances))))
    n = [n_embeddings] * len(anchors)

    # Convert indices from the square-matrix to the condensed-matrix referential.
    pos = list(map(to_condensed, n, anchors, positives))
    neg = list(map(to_condensed, n, anchors, negatives))

    # Raw triplet loss (no margin, no clamping): the lower, the better.
    delta = distances[pos] - distances[neg]

    if self.clamp == 'positive':
        loss = torch.clamp(delta + self.margin_, min=0)
    elif self.clamp == 'softmargin':
        # softplus(delta) == log1p(exp(delta)) without overflowing for large delta.
        loss = F.softplus(delta)
    elif self.clamp == 'sigmoid':
        # TODO. tune this "10" hyperparameter
        # TODO. log-sigmoid
        loss = torch.sigmoid(10 * (delta + self.margin_))
    else:
        raise ValueError(
            "Unknown clamp mode {!r}; expected 'positive', 'softmargin' "
            "or 'sigmoid'.".format(self.clamp))

    if return_delta:
        return loss, delta.view((-1, 1)), pos, neg
    return loss
def __call__(self, x_mu):
    """Decode mu-law codes ``0 .. self.qc - 1`` back to a signal in [-1, 1].

    Args:
        x_mu (FloatTensor/LongTensor or ndarray): mu-law codes.

    Returns:
        x (FloatTensor or ndarray): decoded signal.

    Raises:
        TypeError: if ``x_mu`` is neither a numpy array nor a torch tensor.
    """
    mu = self.qc - 1.
    if isinstance(x_mu, np.ndarray):
        x = ((x_mu) / mu) * 2 - 1.
        return np.sign(x) * (np.exp(np.abs(x) * np.log1p(mu)) - 1.) / mu
    if isinstance(x_mu, torch.Tensor):
        # Promote any non-float tensor (CPU or CUDA) to float32.
        if not x_mu.is_floating_point():
            x_mu = x_mu.float()
        # Build mu on x_mu's device: a CPU tensor here fails for CUDA inputs.
        mu = torch.tensor([mu], dtype=torch.float32, device=x_mu.device)
        x = ((x_mu) / mu) * 2 - 1.
        return torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.) / mu
    raise TypeError(
        "Expected numpy.ndarray or torch.Tensor, got {}".format(type(x_mu).__name__))
# Fragment of a Monte-Carlo gradient loop; the enclosing definition is not
# visible here. n, dist, Hpy, bern_param, logits and dist_bern come from the
# surrounding scope. grads, eta and difs are not used in the visible lines;
# presumably they are consumed further on.
grads = []
eta = 1.
difs = []
for i in range(n):
    # Relaxed (reparameterised) sample z and its hard counterpart b = Hpy(z).
    z = dist.rsample()
    b = Hpy(z)

    # Sample z_tilde ~ p(z | b). Assuming b is 0/1, v_prime lies in
    # [0, 1 - theta) when b == 0 and in [1 - theta, 1) when b == 1.
    theta = bern_param #logit_to_prob(bern_param)
    v = torch.rand(z.shape[0])
    v_prime = v * (b - 1.) * (theta - 1.) + b * (v * theta + 1. - theta)
    # logits are deliberately not detached: the author notes that detaching
    # them biased the estimator.
    z_tilde = logits + torch.log(v_prime) - torch.log1p(-v_prime)
    z_tilde = torch.sigmoid(z_tilde)

    # Score-function term: gradient of log p(b) with respect to bern_param.
    logprob = dist_bern.log_prob(b)
    logprobgrad = torch.autograd.grad(outputs=logprob, inputs=(bern_param), retain_graph=True)[0]
def forward(self, grad_est_type, x=None, warmup=1., inf_net=None):
    """Build the training objective for one of four discrete gradient estimators.

    A relaxed sample ``z`` (with ``logits`` and its log-density ``logqz``) is
    drawn from ``self.q``. ``b = harden(z)`` is the discrete sample and
    ``fhard = logpx_b - logq_b`` is its detached ELBO term. Depending on
    ``grad_est_type``, a surrogate objective ``cost_all`` and a control-variate
    loss ``surr_cost`` are built:

    * 'SimpLAX' / 'SimpLAX_nosoft': REINFORCE on ``logqz`` with control
      variate ``self.surr(x, z)``. 'SimpLAX' also uses the soft term ``fsoft``.
    * 'RELAX' / 'RELAX_nosoft': REINFORCE on ``log q(b)``, with the control
      variate evaluated at ``z_tilde ~ p(z | b)`` and at ``z``.

    Args:
        grad_est_type: one of 'SimpLAX', 'RELAX', 'SimpLAX_nosoft', 'RELAX_nosoft'.
        x: input batch; ``x.shape[0]`` is the batch size.
        warmup: not used by this method (kept for interface compatibility).
        inf_net: not used by this method (kept for interface compatibility).

    Returns:
        dict with keys 'logpx', 'x_recon', 'welbo', 'elbo', 'z', 'logpz',
        'logqz', 'surr_cost', 'fhard', 'logq_z', 'logits'.

    Raises:
        RuntimeError: if ``logqz`` contains NaN (diagnostics are printed first).
        ValueError: if ``grad_est_type`` is not one of the supported names.
    """
    B = x.shape[0]

    # Samples from the relaxed Bernoulli.
    z, logits, logqz = self.q.sample(x)
    if isnan(logqz).any():
        print(torch.sum(isnan(logqz).float()).data.item())
        print(torch.mean(logits).data.item())
        print(torch.max(logits).data.item())
        print(torch.min(logits).data.item())
        print(torch.max(z).data.item())
        print(torch.min(z).data.item())
        # This used to abort through an undefined name (NameError); fail explicitly.
        raise RuntimeError('NaN encountered in logqz returned by self.q.sample')

    # Discrete ELBO at the hardened sample.
    b = harden(z).detach()
    logpx_b, logq_b, alpha1 = self.f(x, b, logits, hard=True)
    fhard = (logpx_b - logq_b).detach()

    if grad_est_type == 'SimpLAX':
        # Control variate.
        logpx_z, logq_z, _ = self.f(x, z, logits, hard=False)
        fsoft = logpx_z.detach()
        c = self.surr(x, z).view(B)
        # REINFORCE with control variate; cost_all gives an unbiased gradient of fhard.
        Adv = (fhard - fsoft - c).detach()
        cost_all = Adv * logqz + c + fsoft
        # Surrogate (control-variate) loss.
        surr_cost = torch.abs(fhard - fsoft - c)
    elif grad_est_type == 'RELAX':
        # z_tilde ~ p(z | b); noise is created on the same device as logits.
        theta = logit_to_prob(logits)
        v = torch.rand(z.shape[0], z.shape[1], device=logits.device)
        v_prime = v * (b - 1.) * (theta - 1.) + b * (v * theta + 1. - theta)
        z_tilde = torch.sigmoid(logits + torch.log(v_prime) - torch.log1p(-v_prime))
        # Control variate.
        logpx_z, logq_z, _ = self.f(x, z, logits, hard=False)
        fsoft = logpx_z.detach()
        c_ztilde = self.surr(x, z_tilde).view(B)
        c_z = self.surr(x, z).view(B)
        # REINFORCE with control variate on log q(b).
        logqb = torch.sum(Bernoulli(logits=logits).log_prob(b.detach()), 1)
        Adv = (fhard - fsoft - c_ztilde).detach()
        cost_all = Adv * logqb + fsoft + c_z - c_ztilde
        surr_cost = torch.abs(fhard - fsoft - c_ztilde)
    elif grad_est_type == 'SimpLAX_nosoft':
        # Control variate (f is still evaluated at z for the logq_z diagnostic).
        logpx_z, logq_z, _ = self.f(x, z, logits, hard=False)
        c = self.surr(x, z).view(B)
        Adv = (fhard - c).detach()
        cost_all = Adv * logqz + c
        surr_cost = torch.abs(fhard - c)
    elif grad_est_type == 'RELAX_nosoft':
        # z_tilde ~ p(z | b); noise is created on the same device as logits.
        theta = logit_to_prob(logits)
        v = torch.rand(z.shape[0], z.shape[1], device=logits.device)
        v_prime = v * (b - 1.) * (theta - 1.) + b * (v * theta + 1. - theta)
        z_tilde = torch.sigmoid(logits + torch.log(v_prime) - torch.log1p(-v_prime))
        # Control variate (f is still evaluated at z for the logq_z diagnostic).
        logpx_z, logq_z, _ = self.f(x, z, logits, hard=False)
        c_ztilde = self.surr(x, z_tilde).view(B)
        c_z = self.surr(x, z).view(B)
        # REINFORCE with control variate on log q(b).
        logqb = torch.sum(Bernoulli(logits=logits).log_prob(b.detach()), 1)
        Adv = (fhard - c_ztilde).detach()
        cost_all = Adv * logqb + c_z - c_ztilde
        surr_cost = torch.abs(fhard - c_ztilde)
    else:
        raise ValueError(
            "Unknown grad_est_type {!r}; expected 'SimpLAX', 'RELAX', "
            "'SimpLAX_nosoft' or 'RELAX_nosoft'".format(grad_est_type))

    outputs = {
        'logpx': torch.mean(logpx_b),
        'x_recon': alpha1,
        'welbo': torch.mean(cost_all),
        # NOTE(review): 138.63 ~= 200 * ln(2), presumably -log p(b) under a
        # uniform Bernoulli prior over 200 latent bits; confirm.
        'elbo': torch.mean(logpx_b - logq_b - 138.63),
        'z': z,
        'logpz': torch.zeros(1),
        'logqz': torch.mean(logq_b),
        'surr_cost': torch.mean(surr_cost),
        'fhard': torch.mean(fhard),
        'logq_z': torch.mean(logq_z),
        'logits': logits,
    }
    return outputs
def _kl_geometric_geometric(p, q):
    """KL(p || q) between two Geometric distributions.

    Equals -H(p) - E_p[log q(k)]. Here E_p[k] = (1 - p) / p, which gives the
    ``log1p(-q.probs) / p.probs`` and ``q.logits`` terms.
    """
    cross_term = torch.log1p(-q.probs) / p.probs
    neg_entropy = -p.entropy()
    return neg_entropy - cross_term - q.logits
def custom_regularization(self, saver_net, trainer_net, tasknum):
    """Return the sigma-regularisation loss between a saved and a training net.

    Corresponding child layers of ``saver_net`` and ``trainer_net`` are paired
    with ``zip(named_children())``. Pairs whose trainer layer is not a
    ``BayesianLinear`` or ``BayesianConv2D`` are skipped. For each remaining
    pair, with sigma = softplus(weight_rho):

    * r = sigma_trainer**2 / sigma_saver**2 contributes sum(r - log r);
    * s = sigma_trainer**2 contributes sum(s - log s).

    The returned loss is ``self.args.beta * (sum_r + sum_s)``. The mu-based
    L2/L1 regularisers are not part of the loss, so they are not computed.

    Args:
        saver_net: network with the saved (presumably previous-task) parameters.
        trainer_net: network being trained.
        tasknum: task index; not used by the returned loss.

    Returns:
        The scaled sigma-regularisation loss (a tensor when any Bayesian layer
        is present, otherwise ``self.args.beta * 0``).
    """
    sigma_weight_reg_sum = 0
    sigma_weight_normal_reg_sum = 0

    for (_, saver_layer), (_, trainer_layer) in zip(saver_net.named_children(),
                                                    trainer_net.named_children()):
        if not isinstance(trainer_layer, (BayesianLinear, BayesianConv2D)):
            continue

        # softplus(rho) == log1p(exp(rho)) but does not overflow for large rho.
        trainer_weight_sigma = torch.nn.functional.softplus(trainer_layer.weight_rho)
        saver_weight_sigma = torch.nn.functional.softplus(saver_layer.weight_rho)

        weight_sigma = trainer_weight_sigma ** 2 / saver_weight_sigma ** 2
        normal_weight_sigma = trainer_weight_sigma ** 2

        sigma_weight_reg_sum = sigma_weight_reg_sum + (weight_sigma - torch.log(weight_sigma)).sum()
        sigma_weight_normal_reg_sum = (sigma_weight_normal_reg_sum
                                       + (normal_weight_sigma - torch.log(normal_weight_sigma)).sum())

    loss = 0
    loss = loss + self.args.beta * (sigma_weight_reg_sum + sigma_weight_normal_reg_sum)
    return loss
def inverse_sigmoid(x):
    """Return logit(x) = log(x) - log(1 - x), with x clamped into (0, 1).

    The lower bound is the dtype's smallest positive normal number, which
    avoids log(0). The upper bound is 1 - eps, which avoids log1p(-1).
    """
    info = torch.finfo(x.dtype)
    bounded = torch.clamp(x, info.tiny, 1. - info.eps)
    return bounded.log() - (-bounded).log1p()
def icdf(self, value):
    """Quantile function of Laplace(self.loc, self.scale) at probability ``value``.

    ``value`` is validated with ``self._validate_log_prob_arg`` first.
    """
    self._validate_log_prob_arg(value)
    offset = value - 0.5
    direction = offset.sign()
    log_term = torch.log1p(-2 * offset.abs())
    return self.loc - self.scale * direction * log_term
def train(model, optimizer, train_loader, val_loader, epochs=1, alpha=0.9,
          stop_accuracy=0.9940, cuda_avail=True):
    """Train ``model`` on mixed real/fake batches, with early stopping.

    Each batch is ``(X, y, fakes)``. Rows flagged in ``fakes`` contribute
    ``alpha * sum(log1p(exp(sigmoid(out[row, |y[row]|]))))``. The other rows
    contribute ``(1 - alpha) * NLLLoss(log_softmax(sigmoid(out)), y)``.
    After every epoch ``measure_scores`` is evaluated on ``val_loader``.
    Training stops early once the validation accuracy reaches ``stop_accuracy``.

    Returns:
        list with one detached numpy copy of the batch loss per training step.
    """
    train_loss = []
    nll_loss_function = torch.nn.NLLLoss()
    batches = len(train_loader)

    for epoch in range(epochs):
        total_loss = 0
        progress = tqdm(enumerate(train_loader), desc="Loss: ", total=batches)
        model.train(True)
        for _, data in progress:
            if cuda_avail:
                X = data[0].cuda()
                y = data[1].cuda().squeeze(-1).long()
                fakes = data[2].cuda()
            else:
                X, y, fakes = data[0], data[1].squeeze(-1).long(), data[2]

            model.zero_grad()
            outputs = model(X)

            loss = 0
            # Loss on fake samples: for every row, take the output at column |y|.
            # gather replaces the former per-row python loop that built a uint8 mask.
            if not torch.all(fakes == 0):
                picked = outputs.gather(1, torch.abs(y).unsqueeze(1)).squeeze(1)
                loss = alpha * torch.sum(
                    torch.log1p(torch.exp(torch.sigmoid(picked[fakes]))))
            # Loss on real samples. NLLLoss expects log-probabilities, so use
            # log_softmax (the previous softmax fed it plain probabilities).
            # NOTE(review): `1 - fakes` assumes fakes is a 0/1 uint8 mask; confirm.
            if not torch.all(fakes == 1):
                real = 1 - fakes
                loss = loss + (1 - alpha) * nll_loss_function(
                    torch.log_softmax(torch.sigmoid(outputs[real]), -1), y[real])

            total_loss += loss.detach().data
            train_loss.append(copy.deepcopy(loss.detach().data.cpu().numpy()))
            loss.backward()
            optimizer.step()
            progress.set_description("Loss: {:.4f}".format(loss.item()))

        val_loss, val_accuracy = measure_scores(model, val_loader, cuda_avail)
        print(
            f"Epoch {epoch+1}/{epochs}, training loss: {total_loss/batches}, validation loss: {val_loss}"
        )
        # Early stopping with a threshold.
        if val_accuracy >= stop_accuracy:
            break
    return train_loss
def logit(self, x):
    """Return the log-odds log(x) - log(1 - x); no clamping is applied."""
    log_odds = x.log() - (-x).log1p()
    return log_odds
def get_gaussiandistancefromprior(self, mu_new, mu_prev, rho_prev):
    """Return sum log q(w | theta) - sum log p(w) for the current weight sample.

    The flattened ``(mu, rho, w)`` from ``self.stack()`` are split into the last
    ``split`` entries and the entries before them. ``split`` is the weights
    plus biases of the last layer according to ``self.architecture``. The
    last part is scored with ``p=1`` and the earlier part with ``self.p``,
    both for the variational term and for the prior kernel.

    Args:
        mu_new: flattened means passed to ``get_gaussianlogkernelprior``.
        mu_prev: flattened previous means passed to ``get_gaussianlogkernelprior``.
        rho_prev: flattened previous pre-softplus scales.

    Returns:
        Scalar tensor ``log_qw_theta_sum - log_pw_sum``.

    Note:
        When ``self.alpha_k == 0``, ``mu_prev`` and ``mu_new`` are zeroed in
        place, which is visible to the caller.
    """
    mu, rho, w = self.stack()
    # softplus(rho) == log1p(exp(rho)) but does not overflow for large rho.
    sigma = torch.nn.functional.softplus(rho)
    split = (self.architecture[self.depth - 2] * self.architecture[self.depth - 1]
             + self.architecture[self.depth - 1])
    n_bef = len(w) - split

    w_last, mu_last, sigma_last = w[-split:], mu[-split:], sigma[-split:]
    w_bef, mu_bef, sigma_bef = w[0:n_bef], mu[0:n_bef], sigma[0:n_bef]

    log_qw_theta_sum_last = self.get_gaussianloglikelihood_qw(
        w_last, mu_last, sigma_last, p=1).sum()
    log_qw_theta_sum_bef = self.get_gaussianloglikelihood_qw(
        w_bef, mu_bef, sigma_bef, self.p).sum()
    log_qw_theta_sum = log_qw_theta_sum_last + log_qw_theta_sum_bef

    sigma_prev = torch.nn.functional.softplus(rho_prev)

    # alpha_k == 0 disables sequential learning: zero the previous/new means.
    if self.alpha_k == 0:
        with torch.no_grad():
            mu_prev.data.zero_()
            mu_new.data.zero_()

    log_pw_sum_last = self.get_gaussianlogkernelprior(
        w_last, mu_prev[-split:], sigma_prev[-split:], mu_new[-split:], p=1).sum()
    log_pw_sum_bef = self.get_gaussianlogkernelprior(
        w_bef, mu_prev[0:n_bef], sigma_prev[0:n_bef], mu_new[0:n_bef], self.p).sum()
    log_pw_sum = log_pw_sum_last + log_pw_sum_bef

    return (log_qw_theta_sum - log_pw_sum)
def forward(self, input):
    """Separate the mixed spectrogram with visual guidance and classify it.

    Reads ``input['labels']``, ``'vids'``, ``'audio_mags'``,
    ``'audio_mix_mags'`` and ``'visuals'``, and returns the ``output`` dict
    built at the end (predictions, targets and intermediate tensors).

    NOTE(review): the spectrograms are presumably 4-D (B, C, F, T), since
    dims 0 and 3 are read as B and T; confirm.
    NOTE(review): ``refine_mask``, ``left_mask`` and ``left_mags`` are only
    bound inside the refinement loop. With ``self.opt.refine_iteration == 0``
    the ``refine_spec`` line raises UnboundLocalError.
    """
    labels = input['labels']
    labels = labels.squeeze(1).long() # convert back to LongTensor
    vids = input['vids']
    audio_mags = input['audio_mags']
    audio_mix_mags = input['audio_mix_mags']
    visuals = input['visuals']
    # visuals_256 = input['visuals_256']
    # Small offset so the later division and log never see exact zeros.
    audio_mix_mags = audio_mix_mags + 1e-10

    '''1. warp the spectrogram'''
    B = audio_mix_mags.size(0)
    T = audio_mix_mags.size(3)
    if self.opt.log_freq:
        grid_warp = torch.from_numpy(warpgrid(B, 256, T, warp=True)).to(self.opt.device)
        audio_mix_mags = F.grid_sample(audio_mix_mags, grid_warp)
        audio_mags = F.grid_sample(audio_mags, grid_warp)

    '''2. calculate ground-truth masks'''
    gt_masks = audio_mags / audio_mix_mags
    # clamp to avoid large numbers in ratio masks
    gt_masks.clamp_(0., 5.)

    '''3. pass through visual stream and extract visual features'''
    visual_feature, _ = self.net_visual(Variable(visuals, requires_grad=False))

    '''4. audio-visual feature fusion through UNet and predict mask'''
    audio_log_mags = torch.log(audio_mix_mags).detach()
    # audio_norm_mags = torch.sigmoid(torch.log(audio_mags + 1e-10))
    mask_prediction = self.net_unet(audio_log_mags, visual_feature)

    '''5. masking the spectrogram of mixed audio to perform separation and predict classification label'''
    separated_spectrogram = audio_mix_mags * mask_prediction

    # generate spectrogram for the classifier
    spectrogram2classify = torch.log(separated_spectrogram + 1e-10) # get log spectrogram

    # loss weighting coefficient: log1p of the mixture magnitude, clamped to [1e-3, 10]
    if self.opt.weighted_loss:
        weight = torch.log1p(audio_mix_mags)
        weight = torch.clamp(weight, 1e-3, 10)
    else:
        weight = None

    ''' 6.classify the predicted spectrogram'''
    ''' add audio feature after resnet18 layer4, 512*8*8'''
    ''' add output for classifier, output:label,feature(after layer4)'''
    label_prediction, _ = self.net_classifier(spectrogram2classify)
    # Earlier single-step refinement variants:
    # if self.opt.visual_unet_encoder:
    #     refine_mask, left_mask = self.refine_iteration(mask_prediction, audio_mix_mags, None) #visuals_256)
    # elif self.opt.visual_cat:
    #     refine_mask, left_mask = self.refine_iteration(mask_prediction, audio_mix_mags, visual_feature)
    # else:
    #     refine_mask, left_mask = self.refine_iteration(mask_prediction, audio_mix_mags, None)

    # Iteratively refine the mask; each step starts from the previous step's mask.
    refine_masks = [None for i in range(self.opt.refine_iteration)]
    temp_mask = mask_prediction
    left_energy = [None for i in range(self.opt.refine_iteration)]
    for i in range(self.opt.refine_iteration):
        refine_mask, left_mask , left_mags = self.refine_iteration(temp_mask, audio_mix_mags, visual_feature)
        refine_masks[i] = refine_mask
        temp_mask = refine_mask
        left_energy[i] = torch.mean(left_mags)

    # spectrogram after refinement (uses the mask from the last iteration)
    refine_spec = audio_mix_mags * refine_mask
    # refine_norm_mags = torch.sigmoid(torch.log(refine_spec + 1e-10))
    refine2classify = torch.log(refine_spec + 1e-10)
    _, fake_audio_feature = self.net_classifier(refine2classify)

    ''' 7. down channels for audio feature, for cal loss'''
    if self.opt.audio_extractor:
        real_audio_mags = torch.log(audio_mags + 1e-10)
        _ ,real_audio_feature = self.net_classifier(real_audio_mags)
        real_audio_feature = self.audio_extractor(real_audio_feature)
        fake_audio_feature = self.audio_extractor(fake_audio_feature)

    output = {'gt_label': labels, 'pred_label': label_prediction, 'pred_mask': mask_prediction, 'gt_mask': gt_masks,
              'pred_spectrogram': separated_spectrogram, 'visual_object': visuals, 'audio_mags': audio_mags,
              'audio_mix_mags': audio_mix_mags, 'weight': weight, 'vids': vids, 'refine_mask': refine_mask,
              'refine_spec': refine_spec, 'left_mask':left_mask, 'refine_masks':refine_masks, 'left_mags': left_mags,
              'left_energy':left_energy}
    if self.opt.audio_extractor:
        output['real_audio_feat'] = real_audio_feature
        output['fake_audio_feat'] = fake_audio_feature
    return output
def forward(self, x):
    """Return exp(alphas) * softplus(x) - exp(betas) * softplus(-x).

    softplus(t) == log1p(exp(t)), but the built-in stays finite for large
    |x|. The explicit log1p(exp(.)) form overflowed to inf (float32, |x| > ~88).
    """
    softplus = torch.nn.functional.softplus
    y = torch.exp(self.alphas) * softplus(x) - torch.exp(self.betas) * softplus(-x)
    return y
def elementwise_logsumexp(a, b):
    """Compute log(exp(a) + exp(b)) elementwise without overflow.

    Uses max(a, b) + log1p(exp(-|a - b|)). When a and b are the same
    infinity, a - b is NaN, so the gap is forced to 0 there; the result is
    then that infinity (previously NaN).
    """
    larger = torch.max(a, b)
    gap = torch.abs(a - b)
    gap = torch.where(a == b, torch.zeros_like(gap), gap)
    return larger + torch.log1p(torch.exp(-gap))
def log1m(x: Tensor) -> Tensor:
    """Return log(1 - x) elementwise, computed via log1p for accuracy near 0."""
    return torch.log1p(-x)
def forward(ctx, x):
    """Compute log(1 - exp(x)) elementwise (x <= 0) in a numerically stable way.

    For x close to 0, log(-expm1(x)) is accurate. Further out,
    log1p(-exp(x)) is accurate. The switch happens at x = -log(2) ~= -0.693147.
    Both branches are evaluated and ``torch.where`` picks one.
    """
    near_zero = torch.log(-torch.expm1(x))
    far_from_zero = torch.log1p(-torch.exp(x))
    return torch.where(x > -0.693147, near_zero, far_from_zero)
def forward(self, distances):
    """Return ``self.weights * log(1 + distances ** self.exponent)`` elementwise."""
    powered = torch.pow(distances, self.exponent)
    return torch.log1p(powered) * self.weights
def icdf(self, value):
    """Quantile function of Laplace(self.loc, self.scale) at probability ``value``."""
    offset = value - 0.5
    log_term = torch.log1p(-2 * offset.abs())
    return self.loc - self.scale * offset.sign() * log_term
# Earlier preprocessing experiments, kept for reference:
# risk_bits = torch.distributions.bernoulli.Bernoulli(risk).sample()
# mydata = F.normalize(Variable(mydata),2,0)
# mydata = Variable(mydata)
# mydata = Variable((1-2*(mydata > 0).float())*torch.log(1+torch.abs(mydata)))
# mydatadf = pd.DataFrame(mydata.data.numpy())
# mydatadf['npi'] = true_assignments
import pandas as pd

# Wide-format training and holdout tables; each carries a STUDY_ID column
# that is dropped before conversion to tensors.
pd_mydata = pd.read_csv('~/workspace/marshfield/recode_like_data/training_wide.csv')
pd_myoutcomes = pd.read_csv('~/workspace/marshfield/recode_like_data/training_outcomes_wide.csv')
pd_mytestdata = pd.read_csv('~/workspace/marshfield/recode_like_data/holdout_wide.csv')
pd_mytestoutcomes = pd.read_csv('~/workspace/marshfield/recode_like_data/holdout_outcomes_wide.csv')

# DataFrame.as_matrix() was removed in pandas 1.0; to_numpy() replaces it.
# Features are log1p-transformed; a subject's label is 1.0 when the sum over
# its outcome columns is positive.
pd_mydata = pd_mydata.drop(columns='STUDY_ID')
mydata = torch.log1p(torch.tensor(pd_mydata.to_numpy()).float())
pd_myoutcomes = pd_myoutcomes.drop(columns='STUDY_ID')
myoutcomes = (torch.tensor(pd_myoutcomes.to_numpy()).float().sum(1) > 0).float()

pd_mytestdata = pd_mytestdata.drop(columns='STUDY_ID')
mytestdata = torch.log1p(torch.tensor(pd_mytestdata.to_numpy()).float())
pd_mytestoutcomes = pd_mytestoutcomes.drop(columns='STUDY_ID')
mytestoutcomes = (torch.tensor(pd_mytestoutcomes.to_numpy()).float().sum(1) > 0).float()

true_assignments = None
risk_bits = myoutcomes
side_label_assignments = None
mp = ModelParameters()
mp.num_epochs = 500
def _inverse(self, y):
    """Inverse of tanh: atanh(y) = 0.5 * (log1p(y) - log1p(-y)).

    ``y`` is first clamped to [-1 + eps, 1 - eps] so that rounding at +/-1
    cannot make log1p return -inf.
    """
    bounded = y.clamp(min=-1. + eps, max=1. - eps)
    return (torch.log1p(bounded) - torch.log1p(-bounded)) * 0.5
def sigma(self) -> torch.Tensor:
    """Return the standard deviation softplus(self.rho) = log(1 + exp(rho)).

    The built-in softplus avoids the overflow of log1p(exp(rho)), which
    returned inf for rho above ~88 in float32.
    """
    return torch.nn.functional.softplus(self.rho)
def _kl_continuous_bernoulli_continuous_bernoulli(p, q):
    """KL(p || q) between two ContinuousBernoulli distributions."""
    logit_gap = p.mean * (p.logits - q.logits)
    p_part = p._cont_bern_log_norm() + torch.log1p(-p.probs)
    q_part = -q._cont_bern_log_norm() - torch.log1p(-q.probs)
    return logit_gap + p_part + q_part
def rsample(self, sample_shape=torch.Size()):
    """Draw reparameterised Laplace samples by inverse-CDF of U(eps - 1, 1).

    Returns a tensor of shape ``self._extended_shape(sample_shape)``.
    """
    shape = self._extended_shape(sample_shape)
    eps = torch.finfo(self.loc.dtype).eps
    # new_empty(shape) keeps scalar shapes (torch.Size([])) as 0-dim tensors.
    # loc.new(*shape) with an empty shape returned a 0-element 1-D tensor, so
    # scalar distributions produced empty samples. The lower bound eps - 1
    # keeps |u| < 1, so log1p(-|u|) stays finite.
    # TODO: If we ever implement tensor.nextafter, below is what we want ideally.
    # u = self.loc.new(*shape).uniform_(self.loc.nextafter(-.5, 0), .5)
    u = self.loc.new_empty(shape).uniform_(eps - 1, 1)
    return self.loc - self.scale * u.sign() * torch.log1p(-u.abs())
def _kl_beta_continuous_bernoulli(p, q):
    """KL(p || q) for a Beta ``p`` and a ContinuousBernoulli ``q``.

    Equals -H(p) - E_p[log q(x)], where
    log q(x) = x * logits + log(1 - probs) + log C(probs).
    """
    neg_entropy = -p.entropy()
    cross = p.mean * q.logits
    return neg_entropy - cross - torch.log1p(-q.probs) - q._cont_bern_log_norm()
def moments(lowerB, upperB, mu, sigma):
    """Moments of a univariate normal distribution truncated to [lowerB, upperB].

    All arithmetic is done in double precision using the scaled complementary
    error function ``erfcx`` so that far tails remain numerically stable.

    lowerB is the lower bound
    upperB is the upper bound
    mu is the mean of the normal distribution (which is truncated)
    sigma is the VARIANCE of the normal distribution (which is truncated)

    The four inputs are indexed with the same boolean masks below, so they are
    expected to share ``mu``'s shape.

    Returns:
        Tuple ``(logZhat, Zhat, muHat, sigmaHat, entropy)`` of float32 tensors:
        log normalizer, normalizer (``exp(logZhat)``), mean, variance and entropy
        of the truncated distribution.
    """
    device = mu.device
    lowerB = lowerB.double()
    upperB = upperB.double()
    mu = mu.double()
    sigma = sigma.double()

    pi = torch.Tensor([np.pi]).double().expand_as(mu).to(device=device)
    # log(0.5), shape (1,), broadcast into every masked assignment below.
    log_half = torch.log(torch.Tensor([0.5])).double().to(device=device)

    logZhat = torch.empty_like(mu).double().to(device=device)
    # NOTE: Zhat, muHat, sigmaHat and entropy are recomputed for every entry at
    # the end of this function from logZhat/meanConst/varConst, so the per-case
    # writes to them below are overwritten.
    Zhat = torch.empty_like(mu).double().to(device=device)
    muHat = torch.empty_like(mu).double().to(device=device)
    sigmaHat = torch.empty_like(mu).double().to(device=device)
    entropy = torch.empty_like(mu).double().to(device=device)
    # meanConst/varConst must start at zero: the degenerate cases handled first
    # never assign them, yet the closing formulas read them for every entry.
    # With zeros those formulas yield muHat = mu and sigmaHat = sigma there
    # (instead of values computed from uninitialized memory).
    meanConst = torch.zeros_like(mu)
    varConst = torch.zeros_like(mu)

    # establish bounds (standardized, divided by sqrt(2 * variance))
    a = (lowerB - mu) / (torch.sqrt(2 * sigma))
    b = (upperB - mu) / (torch.sqrt(2 * sigma))

    # do the stable calculation
    # KEY: erfcx is stable close to erfc==0, but less stable close to erfc==1.
    # Since the Gaussian is symmetric about 0, we can flip the calculation
    # around the origin to improve stability. For example, if a is -inf and b
    # is -100, a naive application returns 0, but we can flip this around to
    # erfc(b) = erfc(100), which erfcx(100) computes STABLY. This leads to many
    # cases depending on which argument is infinite. A less careful use of
    # erfcx is troublesome because erfcx(-big numbers) = Inf, which wreaks
    # havoc on the calculations; the cases below treat this when necessary.

    # first check for problem cases
    # a problem case: both bounds infinite
    I = torch.isinf(a) & torch.isinf(b)
    if I.any():
        # check the sign
        I_sign = torch.eq(torch.sign(a), torch.sign(b)) & I
        # then this is integrating from inf to inf, for example.
        # logZhat = -inf, meanConst = inf, varConst = 0
        logZhat[I_sign] = torch.Tensor([-np.inf]).double().to(device=device)
        Zhat[I_sign] = torch.Tensor([0]).double().to(device=device)
        muHat[I_sign] = a[I_sign]
        sigmaHat[I_sign] = torch.Tensor([0]).double().to(device=device)
        entropy[I_sign] = torch.Tensor([-np.inf]).double().to(device=device)

        # integrating over the whole real line:
        # logZhat = 0, meanConst = 0 (muHat = mu), varConst = 0
        I_sign_n = ~torch.eq(torch.sign(a), torch.sign(b)) & I
        logZhat[I_sign_n] = torch.Tensor([0]).double().to(device=device)
        Zhat[I_sign_n] = torch.Tensor([1]).double().to(device=device)
        muHat[I_sign_n] = mu[I_sign_n]
        sigmaHat[I_sign_n] = sigma[I_sign_n]
        entropy[I_sign_n] = 0.5 * torch.log(
            2 * pi[I_sign_n]
            * torch.exp(torch.Tensor([1])).double().to(device=device)
            * sigma[I_sign_n])

    # a problem case: bounds pointing the wrong way
    I_taken_care_of = I
    I = (a > b) & ~I_taken_care_of
    if I.any():
        # these bounds point the wrong way, so we return 0 by convention.
        # logZhat = -inf, meanConst = 0, varConst = 0
        logZhat[I] = torch.Tensor([-np.inf]).double().to(device=device)
        Zhat[I] = torch.zeros_like(mu[I]).double().to(device=device)
        muHat[I] = mu[I]
        sigmaHat[I] = torch.zeros_like(mu[I]).double().to(device=device)
        entropy[I] = torch.Tensor([-np.inf]).double().to(device=device)

    # now real cases follow...
    I_taken_care_of = I | I_taken_care_of
    I = (torch.isinf(a)) & ~I_taken_care_of
    if I.any():
        # integrating everything up to b. In infinite precision this is just
        # normcdf(b), which is not numerically stable; erfcx scales very well
        # for probabilities close to 0 but poorly close to 1, so some trickery
        # is required.
        I_b = (b >= 26) & I
        if I_b.any():
            # very close to 1: use expm1/log1p of the other tail to extend the
            # range to roughly b == 27 (27 std devs away). Beyond that logZhat
            # from the other branch would be inf instead of 0; this returns 0.
            logZhatOtherTail = log_half + torch.log(erfcx(b[I_b])) - b[I_b]**2
            logZhat[I_b] = torch.log1p(-torch.exp(logZhatOtherTail))

        I_b_n = (b < 26) & I
        if I_b_n.any():
            # b < 26: a clean application of erfcx is stable here.
            logZhat[I_b_n] = log_half + torch.log(erfcx(-b[I_b_n])) - b[I_b_n]**2

        # The mean/var constants are not computed in log space; at the tails
        # they are numerically 0 or 1, so the moments do not move.
        # muHat = mu - (sqrt(sigma/(2*pi))*2)/erfcx(-b)
        # sigmaHat = sigma + mu^2 - muHat^2 - (sqrt(sigma/(2*pi))*2)/erfcx(-b)*(upperB + mu)
        meanConst[I] = -2. / erfcx(-b[I])
        varConst[I] = -2. / erfcx(-b[I]) * (upperB[I] + mu[I])

    I_taken_care_of = I | I_taken_care_of
    I = torch.isinf(b) & ~I_taken_care_of
    if I.any():
        # integrating from a up to Inf: the mirror image of the case above.
        I_a = (a <= -26) & I
        a_erfcx = erfcx(a[I])
        if I_a.any():
            # very close to 1: same expm1/log1p trick, mirrored (a down to -27).
            logZhatOtherTail = log_half + torch.log(erfcx(-a[I_a])) - a[I_a]**2
            logZhat[I_a] = torch.log1p(-torch.exp(logZhatOtherTail))

        I_a_n = (a > -26) & I
        if I_a_n.any():
            # a > -26: a clean application of erfcx is stable here.
            logZhat[I_a_n] = log_half + torch.log(erfcx(a[I_a_n])) - a[I_a_n]**2

        # muHat = mu + (sqrt(sigma/(2*pi))*2)/erfcx(a)
        # sigmaHat = sigma + mu^2 - muHat^2 + (sqrt(sigma/(2*pi))*2)/erfcx(a)*(lowerB + mu)
        meanConst[I] = 2. / a_erfcx
        varConst[I] = 2. / a_erfcx * (lowerB[I] + mu[I])

    I_taken_care_of = I | I_taken_care_of
    # any other case has bounds for which neither is inf
    I = ~I_taken_care_of
    if I.any():
        # finite range from a to b: choose the erfcx arguments for stability.
        I_eq = torch.eq(torch.sign(a), torch.sign(b)) & I
        if I_eq.any():
            # Same sign: exploit symmetry so erfcx only sees positive arguments.
            # Zerfcx1 = 0.5*(exp(-b^2)*erfcx(b) - exp(-a^2)*erfcx(a))
            a_m, b_m = a[I_eq], b[I_eq]
            lo_mu, up_mu = lowerB[I_eq] + mu[I_eq], upperB[I_eq] + mu[I_eq]
            maxab = torch.max(torch.abs(a_m), torch.abs(b_m))
            minab = torch.min(torch.abs(a_m), torch.abs(b_m))
            logZhat[I_eq] = log_half - minab**2 + torch.log(torch.abs(
                torch.exp(-(maxab**2 - minab**2)) * erfcx(maxab) - erfcx(minab)))
            # abs/sign flip the arguments appropriately, using
            # erfc(a) = 2 - erfc(-a).
            den_lo = erfcx(abs(a_m)) - torch.exp(a_m**2 - b_m**2) * erfcx(abs(b_m))
            den_up = torch.exp(b_m**2 - a_m**2) * erfcx(abs(a_m)) - erfcx(abs(b_m))
            meanConst[I_eq] = 2 * torch.sign(a_m) * (1 / den_lo - 1 / den_up)
            varConst[I_eq] = 2 * torch.sign(a_m) * (lo_mu / den_lo - up_mu / den_up)

        I_n_eq = ~torch.eq(torch.sign(a), torch.sign(b)) & I
        if I_n_eq.any():
            # Different signs: b > a, so b >= 0 and a <= 0. Make the larger
            # magnitude bound positive, as that is the stable end of the tail.
            I_b_big_a = (torch.abs(b) >= torch.abs(a)) & I_n_eq
            if I_b_big_a.any():
                mask = (a >= -26) & I_b_big_a
                if mask.any():
                    # do things normally
                    a_m, b_m = a[mask], b[mask]
                    lo_mu, up_mu = lowerB[mask] + mu[mask], upperB[mask] + mu[mask]
                    logZhat[mask] = log_half - a_m**2 + torch.log(
                        erfcx(a_m) - torch.exp(-(b_m**2 - a_m**2)) * erfcx(b_m))
                    den_lo = erfcx(a_m) - torch.exp(a_m**2 - b_m**2) * erfcx(b_m)
                    den_up = torch.exp(b_m**2 - a_m**2) * erfcx(a_m) - erfcx(b_m)
                    meanConst[mask] = 2 * (1 / den_lo - 1 / den_up)
                    varConst[mask] = 2 * (lo_mu / den_lo - up_mu / den_up)

                mask = (a < -26) & I_b_big_a
                if mask.any():
                    # a is too small and the calculation would be unstable, so
                    # use erfc(a) = 2 - erfc(-a): 2 minus the right tail minus
                    # the left tail (valid since a < 0 < b).
                    a_m, b_m = a[mask], b[mask]
                    lo_mu, up_mu = lowerB[mask] + mu[mask], upperB[mask] + mu[mask]
                    logZhat[mask] = log_half + torch.log(
                        2 - torch.exp(-b_m**2) * erfcx(b_m)
                        - torch.exp(-a_m**2) * erfcx(-a_m))
                    den_lo = erfcx(a_m) - torch.exp(a_m**2 - b_m**2) * erfcx(b_m)
                    den_up = torch.exp(b_m**2) * 2 - erfcx(b_m)
                    meanConst[mask] = 2 * (1 / den_lo - 1 / den_up)
                    varConst[mask] = 2 * (lo_mu / den_lo - up_mu / den_up)

            I_b_less_a = (torch.abs(b) < torch.abs(a)) & I_n_eq
            if I_b_less_a.any():
                mask = (b <= 26) & I_b_less_a
                if mask.any():
                    # do things normally but mirrored across 0
                    a_m, b_m = a[mask], b[mask]
                    lo_mu, up_mu = lowerB[mask] + mu[mask], upperB[mask] + mu[mask]
                    logZhat[mask] = log_half - b_m**2 + torch.log(
                        erfcx(-b_m) - torch.exp(-(a_m**2 - b_m**2)) * erfcx(-a_m))
                    den_lo = erfcx(-a_m) - torch.exp(a_m**2 - b_m**2) * erfcx(-b_m)
                    den_up = torch.exp(b_m**2 - a_m**2) * erfcx(-a_m) - erfcx(-b_m)
                    meanConst[mask] = -2 * (1 / den_lo - 1 / den_up)
                    varConst[mask] = -2 * (lo_mu / den_lo - up_mu / den_up)

                mask = (b > 26) & I_b_less_a
                if mask.any():
                    # b is too big and the calculation would be unstable, so
                    # again use 2 minus the right tail minus the left tail.
                    a_m, b_m = a[mask], b[mask]
                    lo_mu, up_mu = lowerB[mask] + mu[mask], upperB[mask] + mu[mask]
                    logZhat[mask] = log_half + torch.log(
                        2 - torch.exp(-a_m**2) * erfcx(-a_m)
                        - torch.exp(-b_m**2) * erfcx(b_m))
                    den_lo = erfcx(-a_m) - torch.exp(a_m**2) * 2
                    den_up = torch.exp(b_m**2 - a_m**2) * erfcx(-a_m) - erfcx(-b_m)
                    meanConst[mask] = -2 * (1 / den_lo - 1 / den_up)
                    varConst[mask] = -2 * (lo_mu / den_lo - up_mu / den_up)
            # The four different-sign cases could be collapsed into two by
            # tracking the signs of maxab/minab, but are kept explicit for clarity.

    # finally, calculate the returned values from logZhat, meanConst, varConst.
    Zhat = torch.exp(logZhat)
    # make the mean
    muHat = mu + meanConst * torch.sqrt(sigma / (2 * pi))
    # make the var
    sigmaHat = sigma + varConst * torch.sqrt(sigma / (2 * pi)) + mu**2 - muHat**2
    # make entropy
    entropy = 0.5 * ((meanConst * torch.sqrt(sigma / (2 * pi)))**2 + sigmaHat - sigma) / sigma \
        + logZhat \
        + torch.log(torch.sqrt(2 * pi * torch.exp(torch.Tensor([1]).double().to(device=device)))) \
        + torch.log(torch.sqrt(sigma))
    return logZhat.float(), Zhat.float(), muHat.float(), sigmaHat.float(), entropy.float()
def icdf(self, value):
    """Inverse CDF of a Laplace distribution with parameters ``loc``/``scale``.

    Validates ``value`` first when argument validation is enabled.
    """
    if self._validate_args:
        self._validate_sample(value)
    centered = value - 0.5
    log_tail = torch.log1p(-2 * centered.abs())
    return self.loc - self.scale * centered.sign() * log_tail
def __call__(self, spectrogram):
    """Log-compress a spectrogram as ``log(1 + compression_factor * spectrogram)``."""
    scaled = spectrogram * self.compression_factor
    return torch.log1p(scaled)
def _kl_geometric_geometric(p, q):
    """KL divergence ``KL(p || q)`` between two Geometric distributions."""
    cross_term = torch.log1p(-q.probs) / p.probs
    return -p.entropy() - cross_term - q.logits
def arctanh(self: TensorType) -> TensorType:
    """Elementwise inverse hyperbolic tangent, ``0.5 * (log1p(x) - log1p(-x))``.

    Computed by hand instead of via a builtin, per
    https://github.com/pytorch/pytorch/issues/10324 (revisit once the supported
    PyTorch version provides it).
    """
    raw = self.raw
    half_log_ratio = 0.5 * (torch.log1p(raw) - torch.log1p(-raw))
    return type(self)(half_log_ratio)
def logit(p):
    """Log-odds of probabilities ``p``: ``log(p) - log(1 - p)``."""
    log_p = torch.log(p)
    log_not_p = torch.log1p(-p)
    return log_p - log_not_p
def log1mexp(x: Tensor) -> Tensor:
    """Compute ``log(1 - exp(x))`` accurately for ``x <= 0``.

    Uses ``log(-expm1(x))`` for ``x > -log(2)`` and ``log1p(-exp(x))`` otherwise.

    Both branches are evaluated for every element, so each one is fed only
    inputs clamped into its own region. Otherwise the branch discarded by
    ``torch.where`` can be non-finite; ``log1p(-exp(x))`` is ``-inf`` for ``x``
    close to 0. That non-finite value turns into NaN gradients in the backward
    pass (``0 * inf``). Forward values are unchanged.
    """
    threshold = -0.693147  # ~ -log(2), the branch switch point
    use_expm1 = x > threshold
    # clamp(min=...) is the identity where use_expm1 holds; elsewhere it keeps expm1 finite.
    case1 = torch.log(torch.expm1(x.clamp(min=threshold)).neg())
    # clamp(max=...) is the identity where use_expm1 is False; elsewhere it avoids log1p(-1).
    case2 = torch.log1p(x.clamp(max=threshold).exp().neg())
    return torch.where(use_expm1, case1, case2)
def arccosh(x):
    """Inverse hyperbolic cosine ``log(x + sqrt(x**2 - 1))`` for ``x >= 1``.

    Evaluated as ``log(x) + log1p(sqrt(x**2 - 1) / x)``. The ratio inside the
    square root is written as ``((x - 1) / x) * ((x + 1) / x)`` so that ``x * x``
    is never formed. That product overflows to ``inf`` for large inputs
    (around 1.8e19 in float32) and made the result ``inf``. Inputs below 1
    still give NaN.
    """
    c0 = torch.log(x)
    c1 = torch.log1p(torch.sqrt((x - 1) / x * ((x + 1) / x)))
    return c0 + c1
def forward(self, batch_data, args):
    """Predict one mask per mixed source and return the weighted mask loss.

    Note that when ``args.log_freq`` is set, the entries of
    ``batch_data['mags']`` are replaced in place by their warped versions.

    Returns:
        ``(err, outputs)`` where ``err`` has shape ``(1,)`` and ``outputs`` holds
        the predicted/ground-truth masks, the (possibly warped) spectrograms and
        the loss weights.
    """
    mags = batch_data['mags']
    frames = batch_data['frames']
    # small offset keeps the later ratio and log well-defined
    mag_mix = batch_data['mag_mix'] + 1e-10

    num_mix = args.num_mix
    batch_size = mag_mix.size(0)
    num_frames = mag_mix.size(3)

    # 0.0 optionally warp the spectrograms with a sampling grid
    if args.log_freq:
        grid = torch.from_numpy(warpgrid(batch_size, 256, num_frames, warp=True)).to(args.device)
        mag_mix = F.grid_sample(mag_mix, grid)
        for idx in range(num_mix):
            mags[idx] = F.grid_sample(mags[idx], grid)

    # 0.1 loss weighting: clamped log1p of the mixture magnitude, or uniform
    if args.weighted_loss:
        weight = torch.clamp(torch.log1p(mag_mix), 1e-3, 10)
    else:
        weight = torch.ones_like(mag_mix)

    # 0.2 ground-truth masks are computed after the (optional) warping
    gt_masks = []
    for idx in range(num_mix):
        if args.binary_mask:
            # for simplicity, a source wins a bin when mag_N > 0.5 * mag_mix
            gt_masks.append((mags[idx] > 0.5 * mag_mix).float())
        else:
            # ratio mask, clamped in place to avoid large numbers
            gt_masks.append((mags[idx] / mag_mix).clamp_(0., 5.))

    # the sound network consumes the detached log magnitude
    log_mag_mix = torch.log(mag_mix).detach()

    # 1. sound network -> BxCxHxW
    feat_sound = activate(self.net_sound(log_mag_mix), args.sound_activation)

    # 2. frame network -> Bx1xC, one feature per source
    feat_frames = [
        activate(self.net_frame.forward_multiframe(frames[idx]), args.img_activation)
        for idx in range(num_mix)
    ]

    # 3. synthesizer: one predicted mask per source
    pred_masks = [
        activate(self.net_synthesizer(feat_frames[idx], feat_sound), args.output_activation)
        for idx in range(num_mix)
    ]

    # 4. loss
    err = self.crit(pred_masks, gt_masks, weight).reshape(1)
    outputs = {
        'pred_masks': pred_masks,
        'gt_masks': gt_masks,
        'mag_mix': mag_mix,
        'mags': mags,
        'weight': weight,
    }
    return err, outputs
# Fragment of a larger function (its header is not visible here): draws a
# relaxed sample z_tilde conditioned on a hard sample b.
zs = torch.stack(zs)
b = Hpy(zs)
# NOTE(original author): "I think the bug is above, see grad of stuff".
# p(z|b): sample the relaxed variable conditioned on b.
theta = logit_to_prob(bern_param)  # presumably sigmoid(bern_param) — confirm
v = torch.rand(zs.shape[0], zs.shape[1])
# v= (1-b)*v*(1-theta) + b*v*theta+(1-theta)
# z_tilde = torch.log(theta/(1-theta)) + torch.log(v/(1-v))
# Assuming b is 0/1 valued: b == 0 gives v_prime = v * (1 - theta), in [0, 1 - theta);
# b == 1 gives v_prime = v * theta + 1 - theta, in [1 - theta, 1).
v_prime = v * (b - 1.) * (theta - 1.) + b * (v * theta + 1. - theta)
# z_tilde = bern_param.detach() + torch.log(v_prime) - torch.log1p(-v_prime)
# Logit of the conditioned uniform, shifted by bern_param, then squashed to (0, 1).
z_tilde = bern_param + torch.log(v_prime) - torch.log1p(-v_prime)
z_tilde = torch.sigmoid(z_tilde)
# Earlier experiments, kept commented out:
# z_tilde = torch.tensor(z_tilde, requires_grad=True)
# print (z_tilde)
# fadfsa
# v_prime = v * (b - 1.) * (theta - 1.) + b * (v * theta + 1. - theta)
# z_tilde = bern_param.detach() + torch.log(v_prime) - torch.log1p(-v_prime)
# z_tilde = torch.sigmoid(z_tilde).detach()
# z_tilde = torch.tensor(z_tilde, requires_grad=True)
logprob = dist_bern.log_prob(b)
def sigma(self):
    """Positive standard deviation from the unconstrained ``rho``: ``log(1 + exp(rho))``.

    ``softplus`` computes the same function as ``torch.log1p(torch.exp(rho))``.
    The explicit form overflowed to ``inf``, with a NaN gradient, once
    ``exp(rho)`` overflowed (``rho`` above roughly 88 in float32). ``softplus``
    switches to the identity for large inputs instead.
    """
    return torch.nn.functional.softplus(self.rho)
def rsample(self, sample_shape=torch.Size()):
    """Reparameterized Laplace sample via the inverse CDF of a uniform draw.

    ``u`` is drawn uniformly from ``[eps - 1, 1)``, so ``log1p(-|u|)`` stays finite.
    """
    out_shape = self._extended_shape(sample_shape)
    low = _finfo(self.loc).eps - 1
    u = self.loc.new(out_shape).uniform_(low, 1)
    # TODO: If we ever implement tensor.nextafter, below is what we want ideally.
    # u = self.loc.new(shape).uniform_(self.loc.nextafter(-.5, 0), .5)
    log_tail = torch.log1p(-u.abs())
    return self.loc - self.scale * u.sign() * log_tail
def compute_feats(self, wavs):
    """Feature pipeline: encoder output -> spectral magnitude (power 0.5) -> log1p."""
    magnitude = spectral_magnitude(self.hparams.Encoder(wavs), power=0.5)
    return torch.log1p(magnitude)
def compute_loss(self, model, net_output, sample, reduce=True):
    """Combine the enabled rendering losses into one weighted scalar.

    Each enabled term is stored in ``losses`` as ``(value, weight)``. The
    returned loss is ``sum(value * weight)`` plus ``model.dummy_loss`` and
    ``self.dummy_loss * 0.``.

    Args:
        model: object exposing ``dummy_loss``.
        net_output: dict of network outputs. Which keys are read (``'colors'``,
            ``'missed'``, ``'depths'``, ``'eikonal-term'``) depends on the
            ``self.args.*_weight`` options. ``'regz-term'`` is always read, and
            ``'other_logs'`` is optional.
        sample: batch dict (``'sampled_uv'``, ``'size'``, ``'colors'``,
            ``'alpha'``, ``'depths'``, ``'mask'``).
        reduce: unused.

    Returns:
        ``(loss, logging_outputs)``. ``logging_outputs`` maps each loss name to
        ``item(value)`` and also includes any extra logs.
    """
    losses, other_logs = {}, {}

    # prepare data before computing loss
    sampled_uv = sample['sampled_uv']  # S, V, 2, N, P, P (patch-size)
    S, V, _, N, P1, P2 = sampled_uv.size()
    H, W, h, w = sample['size'][0, 0].long().cpu().tolist()
    ###debug
    if h == 0:
        h = 1
    L = N * P1 * P2
    flatten_uv = sampled_uv.view(S, V, 2, L)
    # presumably the flat pixel index of each sampled uv — confirm
    flatten_index = (flatten_uv[:, :, 0] // h + flatten_uv[:, :, 1] // w * W).long()

    assert 'colors' in sample and sample[
        'colors'] is not None, "ground-truth colors not provided"
    target_colors = sample['colors']
    masks = (sample['alpha'] > 0) if self.args.no_background_loss else None
    if L < target_colors.size(2):
        # only a subset of pixels was rendered: pick the matching targets.
        target_colors = target_colors.gather(
            2, flatten_index.unsqueeze(-1).repeat(1, 1, 1, 3))
        # masks are gathered with the same (S, V, L) index as the colors; the
        # 4-D flatten_uv previously used here does not match the colors' indexing.
        masks = masks.gather(2, flatten_index) if masks is not None else None

    if 'other_logs' in net_output:
        other_logs.update(net_output['other_logs'])

    # computing loss
    if self.args.color_weight > 0:
        color_loss = utils.rgb_loss(net_output['colors'], target_colors, masks,
                                    self.args.L1)
        losses['color_loss'] = (color_loss, self.args.color_weight)

    if self.args.alpha_weight > 0:
        _alpha = net_output['missed'].reshape(-1)
        alpha_loss = torch.log1p(
            1. / 0.11 * _alpha.float() * (1 - _alpha.float())).mean().type_as(_alpha)
        losses['alpha_loss'] = (alpha_loss, self.args.alpha_weight)

    if self.args.depth_weight > 0:
        if sample['depths'] is not None:
            # target_depths was previously read before being assigned
            # (UnboundLocalError); take it from the sample and gather it the
            # same way as the colors.
            target_depths = sample['depths']
            if L < target_depths.size(2):
                target_depths = target_depths.gather(2, flatten_index)
            depth_mask = target_depths > 0
            if masks is not None:
                depth_mask = masks & depth_mask
            depth_loss = utils.depth_loss(net_output['depths'], target_depths,
                                          depth_mask)
        else:
            # no depth map is provided, depth loss only applied on background based on masks
            max_depth_target = self.args.max_depth * torch.ones_like(
                net_output['depths'])
            if sample['mask'] is not None:
                depth_loss = utils.depth_loss(net_output['depths'], max_depth_target,
                                              (1 - sample['mask']).bool())
            else:
                # NOTE(review): masks is None unless no_background_loss is set,
                # in which case ~masks raises TypeError.
                depth_loss = utils.depth_loss(net_output['depths'], max_depth_target,
                                              ~masks)

        depth_weight = self.args.depth_weight
        if self.args.depth_weight_decay is not None:
            # SECURITY NOTE(review): eval() executes arbitrary code from the
            # config string; prefer ast.literal_eval if it can be untrusted.
            final_factor, final_steps = eval(self.args.depth_weight_decay)
            depth_weight *= max(
                0, 1 - (1 - final_factor) * self.task._num_updates / final_steps)
            other_logs['depth_weight'] = depth_weight

        losses['depth_loss'] = (depth_loss, depth_weight)

    if self.args.vgg_weight > 0:
        assert P1 * P2 > 1, "we have to use a patch-based sampling for VGG loss"
        # reshape to NCHW patches and map colors from [-1, 1] to [0, 1]
        target_colors = target_colors.reshape(-1, P1, P2, 3).permute(
            0, 3, 1, 2) * .5 + .5
        output_colors = net_output['colors'].reshape(
            -1, P1, P2, 3).permute(0, 3, 1, 2) * .5 + .5
        vgg_loss = self.vgg(output_colors, target_colors)
        losses['vgg_loss'] = (vgg_loss, self.args.vgg_weight)

    if self.args.eikonal_weight > 0:
        losses['eik_loss'] = (net_output['eikonal-term'].mean(),
                              self.args.eikonal_weight)

    # the regz_weight guard is disabled in the original: this term is always added.
    # if self.args.regz_weight > 0:
    losses['reg_loss'] = (net_output['regz-term'].mean(), self.args.regz_weight)

    loss = sum(losses[key][0] * losses[key][1] for key in losses)

    # add a dummy loss
    loss = loss + model.dummy_loss + self.dummy_loss * 0.
    logging_outputs = {key: item(losses[key][0]) for key in losses}
    logging_outputs.update(other_logs)
    return loss, logging_outputs
def forward(self, x_wno, anneal_eps, indices=None, return_factorization=False, return_R=False):
    """Compute the per-timestep objectives plus the temporal regularizer.

    Args:
        x_wno: list of ``self.nt`` numpy arrays. Annealing noise of shape
            ``(ns_t, self.nv)`` is added to each, so each is presumably
            ``(ns_t, nv)``. The arrays are copied, so the caller's are not
            modified.
        anneal_eps: annealing noise level. Each input becomes
            ``sqrt(1 - eps**2) * x + eps * N(0, 1)``.
        indices: time steps for which the normed covariance ``sigma[t]`` is computed.
        return_factorization: also fill ``factorization[t]`` for every ``t``.
        return_R: also fill ``R[t]`` for every ``t``.

    Returns:
        dict with ``'total_obj'`` (``main_obj + reg_obj``), ``'main_obj'``,
        ``'reg_obj'``, the per-timestep ``'objs'``, and the ``'sigma'``, ``'R'``
        and ``'factorization'`` lists. Entries of the last three are ``None``
        where they were not computed.

    Note: this draws from both the numpy RNG (annealing noise) and the torch RNG
    (z noise), in a fixed order.
    """
    # copy x_wno
    x_wno = [xt.copy() for xt in x_wno]

    # add annealing noise
    for t in range(self.nt):
        ns = x_wno[t].shape[0]
        anneal_noise = np.random.normal(size=(ns, self.nv))
        x_wno[t] = np.sqrt(
            1 - anneal_eps**2) * x_wno[t] + anneal_eps * anneal_noise

    x = [
        torch.tensor(xt, dtype=torch.float, device=self.device) for xt in x_wno
    ]
    # z[t] = x[t] @ ws[t].T + standard normal noise
    z = [None] * self.nt
    for t in range(self.nt):
        ns = x_wno[t].shape[0]
        z_noise = torch.randn((ns, self.m),
                              dtype=torch.float,
                              device=self.device)
        z_mean = torch.mm(x[t], self.ws[t].t())
        z[t] = z_mean + z_noise

    epsilon = 1e-8
    objs = [None] * self.nt
    sigma = [None] * self.nt
    mi_xz = [None] * self.nt
    Rs = [None] * self.nt
    factorization = [None] * self.nt

    # store all concatenations here for better memory usage
    concats = dict()

    for t in range(self.nt):
        # window of neighbouring time steps [l, r) pooled for step t
        l = max(0, t - self.window_len[t])
        r = min(self.nt, t + self.window_len[t] + 1)
        weights = []
        left_t = t
        right_t = t
        for i in range(l, r):
            cur_ns = x_wno[i].shape[0]
            # sample weight decays as gamma ** |i - t|
            coef = np.power(self.gamma, np.abs(i - t))
            # skip if the importance is too low
            if coef < 1e-6:
                continue
            left_t = min(left_t, i)
            right_t = max(right_t, i)
            weights.append(
                torch.tensor(coef * np.ones((cur_ns, )),
                             dtype=torch.float,
                             device=self.device))
        # normalized per-sample weights, shape (total samples, 1)
        weights = torch.cat(weights, dim=0)
        weights = weights / torch.sum(weights)
        weights = weights.reshape((-1, 1))

        t_range = (left_t, right_t)
        if t_range in concats:
            x_all = concats[t_range]
        else:
            x_all = torch.cat(x[left_t:right_t + 1], dim=0)
            concats[t_range] = x_all
        ns_tot = x_all.shape[0]

        z_all_mean = torch.mm(x_all, self.ws[t].t())
        z_all_noise = torch.randn((ns_tot, self.m),
                                  dtype=torch.float,
                                  device=self.device)
        z_all = z_all_mean + z_all_noise
        z2_all = ((z_all**2) * weights).sum(dim=0)  # (m,)
        R_all = torch.mm((z_all * weights).t(), x_all)  # m, nv
        R_all = R_all / torch.sqrt(z2_all).reshape(
            (self.m, 1))  # as <x^2_i> == 1 we don't divide by it

        z2 = z2_all
        R = R_all

        if return_R:
            Rs[t] = R

        # data used by the objective: pooled window or only step t
        if self.weighted_obj:
            X = x_all
            Z = z_all
        else:
            X = x[t]
            Z = z[t]

        if self.reg_type == 'MI':
            # -0.5 * log(1 - R**2), with R**2 clamped below 1
            mi_xz[t] = -0.5 * torch.log1p(
                -torch.clamp(R**2, 0, 1 - epsilon))

        # the rest depends on R and z2 only
        ri = ((R**2) / torch.clamp(1 - R**2, epsilon, 1 - epsilon)).sum(
            dim=0)  # (nv,)

        # v_xi | z conditional mean
        outer_term = (1 / (1 + ri)).reshape((1, self.nv))
        inner_term_1 = R / torch.clamp(1 - R**2, epsilon, 1) / torch.sqrt(z2).reshape(
            (self.m, 1))  # (m, nv)
        inner_term_2 = Z  # (ns, m)
        cond_mean = outer_term * torch.mm(inner_term_2,
                                          inner_term_1)  # (ns, nv)

        # calculate normed covariance matrix if needed
        need_sigma = (((indices is not None) and (t in indices))
                      or self.reg_type == 'Sigma')
        if need_sigma or return_factorization:
            inner_mat = 1.0 / (1 + ri).reshape(
                (1, self.nv)) * R / torch.clamp(1 - R**2, epsilon, 1)
            factorization[t] = inner_mat
            if need_sigma:
                sigma[t] = torch.mm(inner_mat.t(), inner_mat)
                # force a unit diagonal
                identity_matrix = torch.eye(self.nv,
                                            dtype=torch.float,
                                            device=self.device)
                sigma[t] = sigma[t] * (1 - identity_matrix) + identity_matrix

        # objective: log residual variance per variable + log z2 per factor
        if self.weighted_obj:
            obj_part_1 = 0.5 * torch.log(
                torch.clamp((((X - cond_mean)**2) * weights).sum(dim=0),
                            epsilon, np.inf)).sum(dim=0)
        else:
            obj_part_1 = 0.5 * torch.log(
                torch.clamp(((X - cond_mean)**2).mean(dim=0), epsilon,
                            np.inf)).sum(dim=0)
        obj_part_2 = 0.5 * torch.log(z2).sum(dim=0)
        objs[t] = obj_part_1 + obj_part_2

    # experiments show that main_obj scales approximately linearly with nv
    # also it is a sum over time steps, so we divide by (nt * nv)
    main_obj = 1.0 / (self.nt * self.nv) * sum(objs)

    # regularization: penalize differences between consecutive time steps of
    # the matrices selected by reg_type
    reg_matrices = [None] * self.nt
    if self.reg_type == 'W':
        reg_matrices = self.ws
    if self.reg_type == 'Sigma':
        reg_matrices = sigma
    if self.reg_type == 'MI':
        reg_matrices = mi_xz

    reg_obj = torch.tensor(0.0, dtype=torch.float, device=self.device)

    # experiments show that L1 and L2 regularizations scale approximately linearly with nv
    if self.l1 > 0:
        l1_reg = sum([
            torch.abs(reg_matrices[t + 1] - reg_matrices[t]).sum()
            for t in range(self.nt - 1)
        ])
        l1_reg = 1.0 / (self.nt * self.nv) * l1_reg
        reg_obj = reg_obj + self.l1 * l1_reg

    if self.l2 > 0:
        l2_reg = sum([((reg_matrices[t + 1] - reg_matrices[t])**2).sum()
                      for t in range(self.nt - 1)])
        l2_reg = 1.0 / (self.nt * self.nv) * l2_reg
        reg_obj = reg_obj + self.l2 * l2_reg

    total_obj = main_obj + reg_obj

    return {
        'total_obj': total_obj,
        'main_obj': main_obj,
        'reg_obj': reg_obj,
        'objs': objs,
        'sigma': sigma,
        'R': Rs,
        'factorization': factorization
    }