import torch as pt
from backpack import backpack
from backpack.extensions import BatchGrad, BatchL2Grad


def dp_sgd_backward(dis, real_imgs, fake_imgs, device, clip_norm, noise_factor, dis_opt):
    """
    Since only one part of the loss depends on the (real) data, we compute its
    gradients separately and noise them up. In order to maintain similar
    gradient sizes, gradients for the other loss are clipped to the same
    per-sample norm.
    """
    # real data loss first:
    params = list(dis.parameters())
    loss_real = -pt.mean(dis(real_imgs))
    with backpack(BatchGrad(), BatchL2Grad()):
        loss_real.backward(retain_graph=True)

    # first we get all the squared parameter norms...
    squared_param_norms_real = [p.batch_l2 for p in params]
    # ...then compute the global norms...
    global_norms_real = pt.sqrt(pt.sum(pt.stack(squared_param_norms_real), dim=0))
    # ...and finally get a vector of clipping factors
    global_clips_real = pt.clamp_max(clip_norm / global_norms_real, 1.)

    perturbed_grads = []
    for idx, param in enumerate(params):
        clipped_sample_grads = param.grad_batch * expand_vector(global_clips_real, param.grad_batch)
        clipped_grad = pt.sum(clipped_sample_grads, dim=0)  # after clipping we sum over the batch

        noise_sdev = noise_factor * 2 * clip_norm  # gaussian noise standard dev is computed (sensitivity is 2*clip)...
        perturbed_grad = clipped_grad + pt.randn_like(clipped_grad, device=device) * noise_sdev  # ...and applied
        perturbed_grads.append(perturbed_grad)  # store perturbed grads

    dis_opt.zero_grad()

    # now add fake data loss gradients:
    loss_fake = pt.mean(dis(fake_imgs))
    with backpack(BatchGrad(), BatchL2Grad()):
        loss_fake.backward()

    squared_param_norms_fake = [p.batch_l2 for p in params]
    global_norms_fake = pt.sqrt(pt.sum(pt.stack(squared_param_norms_fake), dim=0))
    global_clips_fake = pt.clamp_max(clip_norm / global_norms_fake, 1.)

    for idx, param in enumerate(params):
        clipped_sample_grads = param.grad_batch * expand_vector(global_clips_fake, param.grad_batch)
        clipped_grad = pt.sum(clipped_sample_grads, dim=0)
        param.grad = clipped_grad + perturbed_grads[idx]

    ld = loss_real.item() + loss_fake.item()
    return global_norms_real, global_clips_real, global_norms_fake, global_clips_fake, ld
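# `expand_vector` is used above but not defined in this snippet. A minimal
# sketch consistent with its call sites (broadcasting a per-sample vector of
# clipping factors over the trailing dimensions of a per-sample gradient
# tensor); the name and signature come from the calls above, the body is an
# assumption:
def expand_vector(vec, template):
    # Reshape a (batch,) vector to (batch, 1, ..., 1) so that it broadcasts
    # against a (batch, *param_shape) tensor such as param.grad_batch.
    return vec.view(vec.shape[0], *([1] * (template.dim() - 1)))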
def polar_indices(positions,  # (N, L_s, L, 2)
                  nray, nring, inner_radius):
    # Returns distance_indices, angular_indices, distance_offsets,
    # angular_offsets, each of shape (N, L_s, L).
    distances = torch.sqrt(positions[:, :, :, 0]**2 + positions[:, :, :, 1]**2)
    distance_indices = torch.clamp(distances / inner_radius, min=0, max=nring - 1).floor().long()

    angles = torch.atan2(positions[:, :, :, 1], positions[:, :, :, 0]) + math.pi
    # There is one angle value that can result in an index of exactly nray; clamp it to nray - 1.
    angular_indices = torch.clamp_max((angles / (2 * math.pi) * nray).floor().long(), nray - 1)

    distance_offsets = torch.clamp_max(distances / inner_radius - distance_indices.float() - 0.5, max=2)
    angular_offsets = angles / (2 * math.pi) * nray - angular_indices.float() - 0.5

    assert angular_indices.min() >= 0, f'Negative angular index: {angular_indices.min()}'
    assert angular_indices.max() < nray, f'Invalid angular index: {angular_indices.max()} >= {nray}'
    assert distance_indices.min() >= 0, f'Negative distance index: {distance_indices.min()}'
    assert distance_indices.max() < nring, f'Invalid distance index: {distance_indices.max()} >= {nring}'

    return distance_indices, angular_indices, distance_offsets, angular_offsets
def fixed_circle_loss_clean(output, target, mask, num_classes, softplus: torch.nn.Softplus):
    """
    output: (B, N)
    target: (B, X)
    mask: (B, N)
    num_classes = N

    loss = log(1 + \sum_i exp(s^{neg}_i) * \sum_j exp(-s^{pos}_j)),
    i.e. the circle loss with \gamma = 1 and m = 0.
    """
    output = output.float()
    # seq_len = output.size(1)
    target_mask = (target == -1)
    target[target_mask] = num_classes
    one_hot_target = F.one_hot(target, num_classes + 1)
    one_hot_target = one_hot_target.sum(dim=1)
    one_hot_target = one_hot_target[:, :-1].to(dtype=output.dtype)
    mask = mask.to(dtype=output.dtype)

    all_mask = (one_hot_target.sum(dim=-1) == 0)
    mask_for_pos = torch.clamp_max(1 - one_hot_target + mask, max=1.0)
    mask_for_neg = torch.clamp_max(one_hot_target + mask, max=1.0)

    _flag = -1e12
    ap = torch.clamp_min(1 - output.detach(), min=0.)
    an = torch.clamp_min(output.detach(), min=0.)
    delta_p = 1
    delta_n = 0

    output = (1 - 2 * one_hot_target) * output  # Positive: -1 // Negative: +1

    logit_pos = ap * (delta_p + output) + mask_for_pos * _flag
    x = logit_pos.max(dim=-1, keepdim=True)[0].detach()
    x = torch.relu_(x)
    logit_pos = logit_pos - x
    # assert output_pos.size() == output.size(), (output_pos.size(), output.size())

    logit_neg = an * (output - delta_n) + mask_for_neg * _flag
    y = logit_neg.max(dim=-1, keepdim=True)[0].detach()
    y = torch.relu_(y)
    logit_neg = logit_neg - y

    # x and y are (B, 1); squeeze so the sum stays (B,) instead of
    # broadcasting to (B, B) against the logsumexp terms.
    loss = softplus(x.squeeze(-1) + y.squeeze(-1)
                    + torch.logsumexp(logit_pos, dim=-1)
                    + torch.logsumexp(logit_neg, dim=-1))
    masked_loss = loss.masked_fill(all_mask, 0.)
    return masked_loss
def bound_loss(self, mu):
    if self.bounds_loss_coef is not None:
        soft_bound = 1.1
        # Quadratic penalty on action means outside the soft bound on each side.
        mu_loss_high = torch.clamp_min(mu - soft_bound, 0.0)**2
        mu_loss_low = torch.clamp_max(mu + soft_bound, 0.0)**2
        b_loss = (mu_loss_low + mu_loss_high).sum(axis=-1)
    else:
        b_loss = 0
    return b_loss
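# A quick sanity check of the bound loss (standalone version of the two clamp
# terms above): action means inside the soft bound of 1.1 incur zero penalty,
# means outside a quadratic one.
import torch

mu = torch.tensor([[0.5, -1.0, 1.5, -2.0]])
soft_bound = 1.1
high = torch.clamp_min(mu - soft_bound, 0.0) ** 2  # penalizes mu > 1.1
low = torch.clamp_max(mu + soft_bound, 0.0) ** 2   # penalizes mu < -1.1
print((high + low).sum(dim=-1))  # tensor([0.9700]) = 0.4**2 + 0.9**2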
from functools import reduce

import numpy as np
import torch
import torch.nn.functional as F


def discretized_mix_logistic_loss(y_hat: torch.Tensor,
                                  y: torch.Tensor,
                                  num_classes=256,
                                  log_scale_min=-7.0):
    """Discretized mixture of logistic distributions loss.

    Note that it is assumed that the input is scaled to [-1, 1].

    :param y_hat: Tensor. shape=(batch_size, 3 * num_mixtures * img_channels, height, width), predicted output.
    :param y: Tensor. shape=(batch_size, img_channels, height, width), target.
    :return: Tensor loss
    """
    # unpack parameters, [batch_size, num_mixtures * img_channels, height, width] x 3
    logit_probs, means, log_scales = y_hat.chunk(3, dim=1)
    # clamp the log-scales from below so that inv_stdv stays finite
    log_scales = torch.clamp_min(log_scales, log_scale_min)

    num_mixtures = y_hat.size(1) // y.size(1) // 3
    B, C, H, W = y.shape
    y = y.unsqueeze(1).repeat(1, num_mixtures, 1, 1, 1).permute(0, 2, 1, 3, 4).reshape(B, -1, H, W)

    centered_y = y - means
    inv_stdv = torch.exp(-log_scales)
    plus_in = inv_stdv * (centered_y + 1. / (num_classes - 1))
    cdf_plus = torch.sigmoid(plus_in)
    min_in = inv_stdv * (centered_y - 1. / (num_classes - 1))
    cdf_min = torch.sigmoid(min_in)

    log_cdf_plus = plus_in - F.softplus(plus_in)
    log_one_minus_cdf_min = -F.softplus(min_in)
    # probability for all other cases
    cdf_delta = cdf_plus - cdf_min
    # log-probability at the bin center, used as a fallback when cdf_delta is tiny
    mid_in = inv_stdv * centered_y
    log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in)

    log_probs = torch.where(
        y < -0.999, log_cdf_plus,
        torch.where(
            y > 0.999, log_one_minus_cdf_min,
            torch.where(cdf_delta > 1e-5,
                        torch.log(torch.clamp_min(cdf_delta, 1e-12)),
                        log_pdf_mid - np.log((num_classes - 1) / 2))))

    # (batch_size, num_mixtures * img_channels, height, width);
    # add the log mixture weights, then marginalize over the mixtures
    # per channel (C groups of num_mixtures components each)
    log_probs = log_probs + F.log_softmax(logit_probs, dim=1)
    log_probs = [
        log_sum_exp(log_prob) for log_prob in log_probs.chunk(C, dim=1)
    ]
    log_probs = reduce(lambda a, b: a + b, log_probs)
    return -torch.sum(log_probs)
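# `log_sum_exp` is referenced above but not defined in this snippet. A minimal
# numerically stable sketch consistent with its usage (reducing the mixture
# dimension of each per-channel chunk); the body is an assumption:
import torch


def log_sum_exp(x, dim=1):
    # Stable log(sum(exp(x))) over the mixture dimension.
    m, _ = torch.max(x, dim=dim, keepdim=True)
    return (x - m).exp().sum(dim=dim).log() + m.squeeze(dim)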
import torch
from torch import Tensor


def _binary_cross_entropy(pred: Tensor, target: Tensor) -> Tensor:
    """Binary cross-entropy loss (cf. F.binary_cross_entropy()).

    :param pred: shape = (N,)
    :param target: shape = (N,) torch.float32
    :return: shape = ()"""
    # torch.mean(-torch.log(pred) * target + -torch.log(1 - pred) * (1 - target))
    return torch.mean(
        torch.clamp_max(-torch.log(pred), 100) * target +  # clamp to prevent inf
        torch.clamp_max(-torch.log(1 - pred), 100) * (1 - target))
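# Minimal check that the clamped hand-rolled BCE above agrees with PyTorch's
# built-in on well-behaved inputs (the clamp at 100 only matters when pred
# saturates near 0 or 1):
import torch
import torch.nn.functional as F

pred = torch.tensor([0.9, 0.2, 0.6])
target = torch.tensor([1., 0., 1.])
print(_binary_cross_entropy(pred, target))
print(F.binary_cross_entropy(pred, target))  # should match closely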
def accept_reject(current_z, current_v,
                  z, v,
                  epsilon,
                  accept_hist, hist_len,
                  U, K=lambda v: torch.sum(v * v, 1)):
    """Accept/reject based on Hamiltonians for current and proposal.

    Args:
        current_z: position *before* leap-frog steps
        current_v: speed *before* leap-frog steps
        z: position *after* leap-frog steps
        v: speed *after* leap-frog steps
        epsilon: step size of leap-frog
        accept_hist: running count of accepted proposals per chain
        hist_len: number of proposals so far, used to adapt the step size
        U: function to compute potential energy
        K: function to compute kinetic energy
    """
    current_Hamil = K(current_v) + U(current_z)
    propose_Hamil = K(v) + U(z)
    prob = torch.clamp_max(torch.exp(current_Hamil - propose_Hamil), 1.)

    with torch.no_grad():
        uniform_sample = torch.rand(prob.size()).cuda()
        accept = (prob > uniform_sample).float().cuda()
        z = z.mul(accept.view(-1, 1)) + current_z.mul(1. - accept.view(-1, 1))

        accept_hist = accept_hist.add(accept)
        # Adapt the step size toward a target acceptance rate of 0.65.
        criteria = (accept_hist / hist_len > 0.65).float().cuda()
        adapt = 1.02 * criteria + 0.98 * (1. - criteria)
        epsilon = epsilon.mul(adapt).clamp(1e-4, .5)

    z.requires_grad_()
    return z, epsilon, accept_hist
def __call__(self, model: nn.Module, images: torch.Tensor,
             labels: torch.Tensor) -> AdversaryOutput:
    pixel_epsilon = self.epsilon * 255
    # number of iterations according to the policy in the reference paper
    n_iters = round(min(pixel_epsilon + 4, 1.25 * pixel_epsilon))
    lo = torch.clamp_min(images - self.epsilon, 0)
    hi = torch.clamp_max(images + self.epsilon, 1)
    result = images.clone()
    result.requires_grad = True

    for _ in range(n_iters):
        if result.grad is not None:
            result.grad.detach_()
            result.grad.zero_()
        loss = self.compute_objective(model, result, labels, "mean")
        loss.backward()
        with torch.no_grad():
            result += self.step_size * torch.sign(result.grad)
            clamp_min_tensor(result, lo)
            clamp_max_tensor(result, hi)

    result.requires_grad = False
    return AdversaryOutput(result, result - images)
def nn_transformation(model):
    b_new_list = [None for _ in range(len(model.layers))]
    act_shift_list = [None for _ in range(len(model.layers))]
    nn_model = model.get_nn_net()

    for i in range(len(model.layers)):
        cur_layer = model.layers[i]
        w_neg = T.clamp_max(cur_layer.weight, 0)
        b_tilde = cur_layer.bias - model.alpha[i] * T.sum(T.abs(w_neg), dim=1)
        b_new, act_shift = calc_b_new(cur_layer, b_tilde)
        b_new_list[i] = b_new
        act_shift_list[i] = act_shift

        w_neg_abs = T.abs(w_neg)
        w_pos = T.clamp_min(cur_layer.weight, 0)
        nn_model.add_layer(
            NNLinear(
                cur_layer.in_features,
                cur_layer.out_features,
                w_pos,
                w_neg_abs,
                b_new,
                model.alpha[i],
                act_shift,
            ))
    return nn_model
def step(self):
    """Performs a single optimization step.

    The function expects the gradients to have been computed by BackPACK
    and the parameters to have a ``batch_l2`` and ``grad_batch`` attribute.
    """
    l2_norms_all_params_list = []
    for group in self.param_groups:
        for p in group["params"]:
            l2_norms_all_params_list.append(p.batch_l2)

    l2_norms_all_params = torch.stack(l2_norms_all_params_list)
    total_norms = torch.sqrt(torch.sum(l2_norms_all_params, dim=0))
    # Per-sample clipping factors min(1, max_norm / ||g_i||); the ratio is
    # the norm bound over the gradient norm, not the other way around.
    scaling_factors = torch.clamp_max(self.max_norm / total_norms, 1.0)

    for group in self.param_groups:
        for p in group["params"]:
            clipped_grads = p.grad_batch * make_broadcastable(scaling_factors, p.grad_batch)
            clipped_grad = torch.sum(clipped_grads, dim=0)

            noise_magnitude = self.stddev * self.max_norm
            noise = torch.randn_like(clipped_grad) * noise_magnitude

            perturbed_update = clipped_grad + noise
            p.data.add_(-self.lr * perturbed_update)
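# A hedged usage sketch for the optimizer step above: BackPACK must extend
# both the model and the loss function so that ``grad_batch`` and ``batch_l2``
# are populated on the parameters before step() runs. The DPSGD class name,
# its constructor arguments, and the toy data are placeholders, and
# make_broadcastable is assumed to behave like expand_vector earlier in this
# file.
import torch
from backpack import backpack, extend
from backpack.extensions import BatchGrad, BatchL2Grad

model = extend(torch.nn.Linear(10, 2))
lossfunc = extend(torch.nn.CrossEntropyLoss())
optimizer = DPSGD(model.parameters(), lr=0.1, max_norm=1.0, stddev=1.0)  # hypothetical wrapper class

x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))
loss = lossfunc(model(x), y)
with backpack(BatchGrad(), BatchL2Grad()):
    loss.backward()
optimizer.step()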
def forward(self, state):
    x = self.linear1(state)
    x = F.gelu(x)
    x = self.linear2(x)
    x = F.gelu(x)
    x = self.linear3(x)

    x_mean = F.gelu(x)
    x_mean = self.linear_mean_4(x_mean)
    x_mean = F.gelu(x_mean)
    x_mean = self.linear_mean_5(x_mean)
    x_mean = F.gelu(x_mean)

    x_std = F.gelu(x)
    x_std = self.linear_std_4(x_std)
    x_std = F.gelu(x_std)
    x_std = self.linear_std_5(x_std)
    x_std = F.gelu(x_std)

    mean = self.mean_layer(x_mean)
    log_std = self.log_std_layer(x_std)
    # Smoothly squash log_std into [log_std_min, log_std_max]: the first term
    # is active when tanh is positive (scaling toward log_std_max), the second
    # when tanh is negative (scaling toward log_std_min).
    log_std = torch.clamp_min(self.log_std_max * torch.tanh(log_std / self.denominator), 0) + \
        torch.clamp_max(-self.log_std_min * torch.tanh(log_std / self.denominator), 0)
    # log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
    return mean, log_std
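# A standalone demonstration of the clamp trick above, using log_std_min=-20
# and log_std_max=2 (common SAC choices, assumed here) and denominator=1:
import torch

log_std_min, log_std_max, denominator = -20.0, 2.0, 1.0
raw = torch.tensor([-5.0, 0.0, 5.0])
t = torch.tanh(raw / denominator)
out = torch.clamp_min(log_std_max * t, 0) + torch.clamp_max(-log_std_min * t, 0)
print(out)  # approximately tensor([-20., 0., 2.]) at the tanh extremes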
def _sample_embd(norm_embd, labels, batch_size, adaptive_sampling, sampling_angle_std):
    with torch.no_grad():
        unit_directions = F.normalize(torch.randn_like(norm_embd), dim=1)
        dot_prod = torch.sum(norm_embd * unit_directions, dim=1, keepdim=True)
        orthogonal_directions = unit_directions - dot_prod * norm_embd

        if adaptive_sampling and labels is not None:
            all_angle_std = sampling_angle_std.expand(batch_size, -1)
            class_indices = torch.arange(batch_size, device=labels.device)
            angle_std = all_angle_std[class_indices, labels].view(-1, 1)
        else:
            angle_std = sampling_angle_std

        angles = angle_std * torch.randn_like(dot_prod)
        alpha = torch.clamp_max(torch.where(angles > 0.0, angles, torch.neg(angles)), 0.5 * np.pi)
        cos_alpha = torch.cos(alpha)
        sin_alpha = torch.sin(alpha)
        out_norm_embd = cos_alpha * norm_embd + sin_alpha * orthogonal_directions

    return out_norm_embd
def overlay_img_with_density(img_path, density_map_path, output_path):
    """
    Combine an output density map with the image to create the red heatmap overlay.

    :param img_path:
    :param density_map_path: output .torch of density map
    :param output_path:
    :return:
    """
    img_tensor = read_image(img_path)
    density_map_tensor = torch.load(density_map_path)
    print(img_tensor.shape)
    print(density_map_tensor.shape)
    print(density_map_tensor.sum())

    density_map_tensor = torch.from_numpy(density_map_tensor).unsqueeze(dim=0).unsqueeze(dim=0)
    print("density_map_tensor.shape", density_map_tensor.shape)  # torch.Size([1, 1, 46, 82])
    upsampling_density_map_tensor = nn.functional.interpolate(density_map_tensor, scale_factor=8) / 64

    overlay_density_map = img_tensor.detach().clone()
    upsampling_density_map_tensor = (upsampling_density_map_tensor.squeeze(dim=0)
                                     / upsampling_density_map_tensor.max() * 255)
    # Boost the red channel where density is high, clamped to the valid pixel range.
    overlay_density_map[0] = torch.clamp_max(
        img_tensor[0] + upsampling_density_map_tensor[0] * 2, max=255)

    write_jpeg(overlay_density_map.type(torch.uint8), output_path, quality=100)
def _train_step(
    model,
    dataloader,
    loss_criterion,
    optimiser,
    lr_scheduler,
    l2_norm_clip=0,
    use_gpu=False,
):
    """One epoch of training"""
    outputs, losses = [], []
    for x, y in dataloader:
        if use_gpu:
            x, y = x.cuda(), y.cuda()

        out = model(x)
        loss = loss_criterion(out, y)

        optimiser.zero_grad()
        loss.backward()
        optimiser.step()

        if l2_norm_clip > 0:
            # Clip weights in final layer
            with torch.no_grad():
                norms = torch.norm(model.fc.out.weight, dim=1, keepdim=True)
                norms_clipped = torch.clamp_max(norms, l2_norm_clip)
                # Renormalise weights
                model.fc.out.weight.div_(norms).mul_(norms_clipped)

        outputs.append(out)
        losses.append(loss.item())

    lr_scheduler.step()
    return outputs, losses
def regulate_len(durations, enc_out, pace: float = 1.0,
                 mel_max_len: Optional[int] = None):
    """If target=None, then predicted durations are applied"""
    reps = torch.round(durations.float() / pace).long()
    dec_lens = reps.sum(dim=1)
    max_len = dec_lens.max()

    bsz, _, hid = enc_out.size()
    reps_padded = torch.cat([reps, (max_len - dec_lens)[:, None]], dim=1)
    pad_vec = torch.zeros(bsz, 1, hid, dtype=enc_out.dtype, device=enc_out.device)
    enc_rep = torch.cat([enc_out, pad_vec], dim=1)
    enc_rep = torch.repeat_interleave(
        enc_rep.view(-1, hid), reps_padded.view(-1), dim=0
    ).view(bsz, -1, hid)
    # enc_rep = pad_sequence([torch.repeat_interleave(o, r, dim=0)
    #                         for o, r in zip(enc_out, reps)],
    #                        batch_first=True)

    if mel_max_len is not None:
        enc_rep = enc_rep[:, :mel_max_len]
        dec_lens = torch.clamp_max(dec_lens, mel_max_len)
    return enc_rep, dec_lens
def __call__(self, model: nn.Module, images: torch.Tensor,
             labels: torch.Tensor) -> AdversaryOutput:
    lo = torch.clamp_min(images - self.epsilon, 0)
    hi = torch.clamp_max(images + self.epsilon, 1)
    result = torch.clamp(
        images + random_float_like(images, -self.epsilon, self.epsilon), 0, 1)
    result.requires_grad = True

    for _ in range(self.n_iters):
        if result.grad is not None:
            result.grad.detach_()
            result.grad.zero_()
        loss = self.compute_objective(model, result, labels, "mean")
        loss.backward()
        with torch.no_grad():
            result += self.step_size * torch.sign(result.grad)
            clamp_min_tensor(result, lo)
            clamp_max_tensor(result, hi)

    result.requires_grad = False
    return AdversaryOutput(result, result - images)
def expmap0(self, u, c):
    sqrt_c = c**0.5
    u_norm = torch.clamp_max(
        torch.clamp_min(u.norm(dim=-1, p=2, keepdim=True), self.min_norm),
        self.max_norm)
    gamma_1 = tanh(sqrt_c * u_norm) * u / (sqrt_c * u_norm)
    return gamma_1
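# Sanity check for expmap0: small tangent vectors map to nearly the same
# point, since tanh(x) ~ x near 0. A standalone sketch using torch.tanh in
# place of the module's tanh, with assumed norm bounds:
import torch

min_norm, max_norm, c = 1e-15, 1e6, 1.0
u = torch.tensor([[1e-4, 2e-4]])
u_norm = torch.clamp_max(torch.clamp_min(u.norm(dim=-1, p=2, keepdim=True), min_norm), max_norm)
gamma_1 = torch.tanh(c ** 0.5 * u_norm) * u / (c ** 0.5 * u_norm)
print(gamma_1)  # approximately equal to u for small norms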
def regulate_len(durations, enc_out, pace=1.0, mel_max_len=None):
    """A function that takes predicted durations per encoded token, and repeats
    enc_out according to the duration.

    NOTE: durations.shape[1] == enc_out.shape[1]

    Args:
        durations (torch.tensor): A tensor of shape (batch x enc_length) that
            represents how many times to repeat each token in enc_out.
        enc_out (torch.tensor): A tensor of shape (batch x enc_length x
            enc_hidden) that represents the encoded tokens.
        pace (float): The pace of speaker. Higher values result in faster
            speaking pace. Defaults to 1.0.
        mel_max_len (int): The maximum length above which the output will be
            removed. If sum(durations, dim=1) > mel_max_len, the values after
            mel_max_len will be removed. Defaults to None, which has no maximum
            length.
    """
    dtype = enc_out.dtype
    reps = durations.float() / pace
    reps = (reps + 0.5).long()
    dec_lens = reps.sum(dim=1)
    max_len = dec_lens.max()

    reps_cumsum = torch.cumsum(
        torch.nn.functional.pad(reps, (1, 0, 0, 0), value=0.0), dim=1)[:, None, :]
    reps_cumsum = reps_cumsum.to(dtype)

    range_ = torch.arange(max_len).to(enc_out.device)[None, :, None]
    mult = (reps_cumsum[:, :, :-1] <= range_) & (reps_cumsum[:, :, 1:] > range_)
    mult = mult.to(dtype)
    enc_rep = torch.matmul(mult, enc_out)

    if mel_max_len:
        enc_rep = enc_rep[:, :mel_max_len]
        dec_lens = torch.clamp_max(dec_lens, mel_max_len)
    return enc_rep, dec_lens
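# Example call to regulate_len with a toy batch (hypothetical values): two
# encoded tokens repeated 2x and 3x respectively give a decoder length of 5.
import torch

durations = torch.tensor([[2, 3]])
enc_out = torch.arange(2 * 4, dtype=torch.float32).view(1, 2, 4)
enc_rep, dec_lens = regulate_len(durations, enc_out)
print(dec_lens)        # tensor([5])
print(enc_rep.shape)   # torch.Size([1, 5, 4])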
def ascend_txt(model, perceptor, t, nom, lats, la, lb):
    out = model(lats())
    cutn, sideX, sideY = out.size()[1:]

    p_s = []
    for ch in range(cutn):
        size = int(sideX * torch.zeros(1, ).normal_(mean=.8, std=.3).clip(.5, .95))
        offsetx = torch.randint(0, sideX - size, ())
        offsety = torch.randint(0, sideY - size, ())
        apper = out[:, :, offsetx:offsetx + size, offsety:offsety + size]
        apper = torch.nn.functional.interpolate(apper, (224, 224), mode='nearest')
        p_s.append(apper)
    into = torch.cat(p_s, 0)
    into = nom((into + 1) / 2)

    iii = perceptor.encode_image(into)

    llls = lats()
    lat_l = torch.abs(1 - torch.std(llls, dim=1)).mean() + \
        torch.abs(torch.mean(llls)).mean() + \
        4 * torch.clamp_max(torch.square(llls).mean(), 1)

    for array in llls:
        mean = torch.mean(array)
        diffs = array - mean
        var = torch.mean(torch.pow(diffs, 2.0))
        std = torch.pow(var, 0.5)
        zscores = diffs / std
        skews = torch.mean(torch.pow(zscores, 3.0))
        kurtoses = torch.mean(torch.pow(zscores, 4.0)) - 3.0
        lat_l = lat_l + torch.abs(kurtoses) / llls.shape[0] + torch.abs(skews) / llls.shape[0]

    return la * lat_l, -lb * torch.cosine_similarity(t, iii, dim=-1)
import cv2
import numpy as np
import torch


def depth_map_to_rgbimg(depth_map):
    # Invert and clamp the depth values into [5, 255], then convert the
    # single-channel map to a 3-channel RGB image.
    depth_map = np.asarray(
        np.squeeze((255 - torch.clamp_max(depth_map * 4, 250)).byte().numpy()),
        np.uint8)
    depth_map = np.asarray(cv2.cvtColor(depth_map, cv2.COLOR_GRAY2RGB), np.uint8)
    return depth_map
def run_train_step(self, tensor_input, tensor_output, tensor_focal):
    tensor_input, tensor_output = tensor_input.to(device), tensor_output.to(device)

    # Get the model's prediction and calculate the loss
    model_output, depth2, depth4, depth8 = self.bts(tensor_input, tensor_focal)
    loss = self.criterion(model_output, tensor_output) * 1 / self.backprop_frequency

    if USE_APEX:
        with apex.amp.scale_loss(loss, self.optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()

    if self.current_step % self.backprop_frequency == 0:  # Make update once every x steps
        torch.nn.utils.clip_grad_norm_(self.bts.parameters(), 5)
        self.optimizer.step()
        self.optimizer.zero_grad()

    if self.current_step % 100 == 0:
        self.writer.add_scalar(
            "Loss",
            loss.item() * self.backprop_frequency / tensor_input.shape[0],
            self.current_step)

    if self.current_step % 1000 == 0:
        img = tensor_input[0].detach().transpose(0, 2).transpose(0, 1).cpu().numpy().astype(np.uint8)
        self.writer.add_image("Input", img, self.current_step, None, "HWC")

        visual_result = (255 - torch.clamp_max(
            torchvision.utils.make_grid([tensor_output[0], model_output[0]]) * 5,
            250)).byte()
        self.writer.add_image("Output/Prediction", visual_result, self.current_step)

        depths = [depth2[0], depth4[0], depth8[0]]
        depths = [depth * MAX_DEPTH for depth in depths]
        depth_visual = (255 - torch.clamp_max(
            torchvision.utils.make_grid(depths) * 5, 250)).byte()
        self.writer.add_image("Depths", depth_visual, self.current_step)

    self.current_step += 1
def train(self, train_buffer, val_buffer, callback_fn):
    for epoch in range(self.args['max_epoch']):
        for i in range(self.args['steps_per_epoch']):
            batch_data = train_buffer.sample(self.batch_size)
            batch_data.to_torch(device=self.device)

            obs = batch_data['obs']
            action = batch_data['act']
            next_obs = batch_data['obs_next']
            reward = batch_data['rew']
            done = batch_data['done'].float()

            # update critic
            p = self.critic(obs, action)
            next_action = self.actor_target.get_action(next_obs)
            target_p = self.critic_target.get_target(
                next_obs, next_action, reward, self.gamma * (1 - done))
            critic_loss = -(target_p * torch.log(p + 1e-8)).mean()
            self.critic_optim.zero_grad()
            critic_loss.backward()
            self.critic_optim.step()

            # update actor
            action_dist = self.actor(obs)
            log_prob = action_dist.log_prob(action)
            actions = torch.stack(
                [action_dist.sample() for _ in range(self.m)], dim=0)
            repeat_obs = torch.repeat_interleave(obs.unsqueeze(0), self.m, 0)
            _, values = self.critic(repeat_obs, actions, with_q=True)
            _, value = self.critic(obs, action, with_q=True)

            if self.advantage_mode == 'mean':
                advantage = value - values.mean(dim=0)
            elif self.advantage_mode == 'max':
                advantage = value - values.max(dim=0)[0]

            if self.weight_mode == 'exp':
                weight = torch.exp(advantage / self.beta)
            elif self.weight_mode == 'binary':
                weight = (advantage > 0).float()
            weight = torch.clamp_max(weight, 20).detach()

            actor_loss = -torch.mean(weight * log_prob)
            self.actor_optim.zero_grad()
            actor_loss.backward()
            self.actor_optim.step()

            # sync the target networks once every update_frequency steps
            if i % self.args['update_frequency'] == 0:
                self._sync_weight(self.critic_target, self.critic, 1.0)
                self._sync_weight(self.actor_target, self.actor, 1.0)

        res = callback_fn(self.get_policy())
        self.log_res(epoch, res)

    return self.get_policy()
def vtrace(self, target_policies_action, behavior_policies_action, rewards,
           dones, values, next_values):
    # * v-trace algorithm
    rho_s = torch.exp(target_policies_action.log() - behavior_policies_action.log())
    clip_rho_s = torch.clamp_max(rho_s, self.rho)
    c_s = torch.clamp_max(rho_s, self.c)
    deltas = clip_rho_s * (rewards + self.gamma * (1 - dones) * next_values - values)

    trajectory_size = rho_s.size(1)
    acc = 0
    vs_minus_v_xs = []
    for i in range(trajectory_size - 1, -1, -1):
        # * origin paper:
        # * acc = deltas[:, i] + self.gamma * (1 - dones)[:, i] * c_s[:, i] * (acc - next_values[:, i])
        acc = deltas[:, i] + self.gamma * (1 - dones)[:, i] * c_s[:, i] * acc
        vs_minus_v_xs.append(acc.view(-1, 1))

    vs_minus_v_xs = torch.cat(vs_minus_v_xs[::-1], dim=1).unsqueeze(-1)
    vs = vs_minus_v_xs + values
    return clip_rho_s, vs
def weighted_binary_focal_loss(pred, target, alpha=0.25, gamma=2.):
    """f(pred) = alpha * (1 - pred)^gamma * -ln(pred) for positive samples,
    (1 - alpha) * pred^gamma * -ln(1 - pred) for negative samples.

    :param pred: shape = (N,)
    :param target: shape = (N,)
    :param alpha: float
        The weight of the negative sample and the positive sample.
        (alpha * positive + (1 - alpha) * negative)
    :param gamma: float
    :return: shape = ()"""
    # target == -1: neither a positive sample nor a negative sample.
    return torch.sum(
        torch.where(
            target == -1, torch.tensor(0., device=target.device),
            alpha * (1 - pred)**gamma * target * torch.clamp_max(-torch.log(pred), 100) +
            (1 - alpha) * pred**gamma * (1 - target) * torch.clamp_max(-torch.log(1 - pred), 100)))
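# Quick behavioral check of the focal loss above: confident correct
# predictions contribute little, and ignored samples (target == -1)
# contribute exactly zero.
import torch

pred = torch.tensor([0.9, 0.1, 0.5])
target = torch.tensor([1., 0., -1.])
print(weighted_binary_focal_loss(pred, target))
# only the first two terms contribute; the third is masked out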
def bucketize(h_ts, t_ts, tau=21600, vmax=25):
    diff = (t_ts.unsqueeze(-1) - h_ts).float() / tau
    diff = diff.masked_fill_(diff < 0, -1) + 1
    bucket_ids = torch.floor(torch.log2(diff))
    bucket_ids += 1
    # bucket_ids = bucket_ids.masked_fill(~torch.isfinite(bucket_ids), 0)
    # padding index = 0
    bucket_ids = torch.clamp_min(bucket_ids, 0)
    bucket_ids = torch.clamp_max(bucket_ids.long(), vmax)
    return bucket_ids
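# Example bucket assignment (tau = 21600 s = 6 h, values chosen for
# illustration): deltas of 1 h, 6 h, and 24 h land in exponentially spaced
# buckets 1, 2, and 3, while a negative delta falls into the padding bucket 0.
import torch

h_ts = torch.tensor([0])
t_ts = torch.tensor([3600, 21600, 86400, -100])
print(bucketize(h_ts, t_ts).squeeze(-1))  # tensor([1, 2, 3, 0])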
def forward(self, pred, label):
    one_hot = label > 0.5
    sample_weight = label != self._ignore_label

    if not self._from_logits:
        pred = torch.sigmoid(pred)

    alpha = torch.where(one_hot, self._alpha * sample_weight,
                        (1 - self._alpha) * sample_weight)
    pt = torch.where(sample_weight, 1.0 - torch.abs(label - pred),
                     torch.ones_like(pred))
    beta = (1 - pt)**self._gamma

    # Normalize the focal weights so the loss scale stays comparable across
    # images, then optionally cap the multiplier.
    sw_sum = torch.sum(sample_weight, dim=(-2, -1), keepdim=True)
    beta_sum = torch.sum(beta, dim=(-2, -1), keepdim=True)
    mult = sw_sum / (beta_sum + self._eps)
    if self._detach_delimeter:
        mult = mult.detach()
    beta = beta * mult
    if self._max_mult > 0:
        beta = torch.clamp_max(beta, self._max_mult)

    with torch.no_grad():
        ignore_area = torch.sum(label == self._ignore_label,
                                dim=tuple(range(1, label.dim()))).cpu().numpy()
        sample_mult = torch.mean(mult, dim=tuple(range(1, mult.dim()))).cpu().numpy()
        if np.any(ignore_area == 0):
            self._k_sum = 0.9 * self._k_sum + 0.1 * sample_mult[ignore_area == 0].mean()

            beta_pmax, _ = torch.flatten(beta, start_dim=1).max(dim=1)
            beta_pmax = beta_pmax.mean().item()
            self._m_max = 0.8 * self._m_max + 0.2 * beta_pmax

    loss = -alpha * beta * torch.log(
        torch.min(pt + self._eps,
                  torch.ones(1, dtype=torch.float).to(pt.device)))
    loss = self._weight * (loss * sample_weight)

    if self._size_average:
        bsum = torch.sum(sample_weight,
                         dim=misc.get_dims_with_exclusion(sample_weight.dim(), self._batch_axis))
        loss = torch.sum(loss,
                         dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) / (bsum + self._eps)
    else:
        loss = torch.sum(loss,
                         dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis))

    return loss
def nn_transformation(model):
    b_new_list = [None for _ in range(len(model.layers))]
    act_shift_list = [None for _ in range(len(model.layers))]
    nn_model = model.get_nn_net()

    for i in range(len(model.layers)):
        cur_layer = model.layers[i]
        classname = cur_layer.__class__.__name__
        if classname.find('Linear') != -1:
            sum_dim = 1
        elif classname.find('Conv2d') != -1:
            sum_dim = (1, 2, 3)
        else:
            # Fail early: sum_dim would otherwise be unbound below.
            raise Exception("This layer type cannot be converted")

        w_neg = T.clamp_max(cur_layer.weight, 0)
        b_tilde = cur_layer.bias - model.alpha[i] * T.sum(T.abs(w_neg), dim=sum_dim)
        b_new, act_shift = calc_b_new(cur_layer, b_tilde)
        b_new_list[i] = b_new
        act_shift_list[i] = act_shift

        w_neg_abs = T.abs(w_neg)
        w_pos = T.clamp_min(cur_layer.weight, 0)

        if classname.find('Linear') != -1:
            new_layer = NNLinear(
                cur_layer.in_features,
                cur_layer.out_features,
                w_pos,
                w_neg_abs,
                b_new,
                model.alpha[i],
                act_shift,
            )
        else:  # Conv2d
            new_layer = NNConv2d(
                cur_layer.in_channels,
                cur_layer.out_channels,
                cur_layer.kernel_size,
                w_pos,
                w_neg_abs,
                b_new,
                model.alpha[i],
                act_shift,
                stride=cur_layer.stride,
                padding=cur_layer.padding,
                dilation=cur_layer.dilation,
                groups=cur_layer.groups,
                padding_mode=cur_layer.padding_mode,
            )
        nn_model.add_layer(new_layer)

    return nn_model
def sample_pdf(bins, weights, N_importance, det=False, eps=1e-5):
    """
    Sample @N_importance samples from @bins with distribution defined by @weights.

    Inputs:
        bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2"
        weights: (N_rays, N_samples_)
        N_importance: the number of samples to draw from the distribution
        det: deterministic or not
        eps: a small number to prevent division by zero

    Outputs:
        samples: the sampled samples
    """
    N_rays, N_samples_ = weights.shape
    weights = weights + eps  # prevent division by zero (don't do inplace op!)
    pdf = weights / torch.sum(weights, -1, keepdim=True)  # (N_rays, N_samples_)
    cdf = torch.cumsum(pdf, -1)  # (N_rays, N_samples_), cumulative distribution function
    cdf = torch.cat([torch.zeros_like(cdf[:, :1]), cdf], -1)  # (N_rays, N_samples_+1), padded to 0~1 inclusive

    if det:
        u = torch.linspace(0, 1, N_importance, device=bins.device)
        u = u.expand(N_rays, N_importance)
    else:
        u = torch.rand(N_rays, N_importance, device=bins.device)
    u = u.contiguous()

    inds = searchsorted(cdf, u, side='right')
    below = torch.clamp_min(inds - 1, 0)
    above = torch.clamp_max(inds, N_samples_)

    inds_sampled = torch.stack([below, above], -1).view(N_rays, 2 * N_importance)
    cdf_g = torch.gather(cdf, 1, inds_sampled).view(N_rays, N_importance, 2)
    bins_g = torch.gather(bins, 1, inds_sampled).view(N_rays, N_importance, 2)

    denom = cdf_g[..., 1] - cdf_g[..., 0]
    # denom == 0 means the bin has weight 0, in which case it will not be
    # sampled anyway, so any value for it is fine (set to 1 here)
    denom[denom < eps] = 1
    samples = bins_g[..., 0] + (u - cdf_g[..., 0]) / denom * (bins_g[..., 1] - bins_g[..., 0])
    return samples
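# A toy call to sample_pdf, assuming `searchsorted` above is
# torchsearchsorted.searchsorted or an equivalent wrapper around
# torch.searchsorted(..., right=True): with uniform weights the deterministic
# samples spread evenly across the bin range.
import torch

bins = torch.linspace(0, 1, 5).unsqueeze(0)  # (1, N_samples_ + 1), N_samples_ = 4
weights = torch.ones(1, 4)
samples = sample_pdf(bins, weights, N_importance=4, det=True)
print(samples)  # roughly evenly spaced in [0, 1]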
def get_deformation(self, coords, feat, c=None, h_sample=1e-3):
    coords_neighbor = coords + (torch.rand_like(coords) * h_sample - (h_sample / 2.))
    # bound the jittered coordinates to the [0, 1] cube from both sides
    coords_neighbor = torch.clamp(coords_neighbor, 0, 1)
    weighted_feat = coords[:, :, None] * feat
    coords_encoded = self.position_encoding(coords_neighbor, self.B)
    deform_neighbor = self.decoder(
        torch.cat([coords_encoded, weighted_feat], dim=-1),
        c=c,
        only_displacment=True)
    return deform_neighbor
def regulate_len(durations, enc_out, pace=1.0, mel_max_len=None):
    """If target=None, then predicted durations are applied"""
    # Higher pace means faster speech, so fewer frames per token: divide by pace.
    reps = torch.round(durations.float() / pace).long()
    dec_lens = reps.sum(dim=1)

    enc_rep = pad_sequence(
        [torch.repeat_interleave(o, r, dim=0) for o, r in zip(enc_out, reps)],
        batch_first=True)

    if mel_max_len:
        enc_rep = enc_rep[:, :mel_max_len]
        dec_lens = torch.clamp_max(dec_lens, mel_max_len)
    return enc_rep, dec_lens