def hsigmoid(x, slope=0.2, offset=0.5):
    """Hard-sigmoid activation: ``clamp(x * slope + offset, 0, 1)``.

    PyTorch ships its own hard-sigmoid, but its parameters differ from
    Paddle's: PyTorch uses ``x / 6 + 0.5`` while Paddle uses ``x / 5 + 0.5``.
    The default ``slope`` and ``offset`` reproduce the Paddle formula.

    :param x: input tensor (left unmodified)
    :param slope: factor ``x`` is multiplied by before the offset is added
    :param offset: constant added after scaling
    :return: a new tensor whose values are limited to ``[0, 1]``
    """
    return torch.clamp(x * slope + offset, min=0, max=1)
def preprocess_adj(self, edges):
    """Build a symmetric, binary sparse adjacency matrix from an edge list.

    Every edge is mirrored so the result is symmetric. When
    ``self.residual`` is falsy, self-loops are added on the diagonal.
    Duplicate entries are merged and all stored values are capped at 1.

    The matrix size is inferred from the largest node index that occurs in
    ``edges``; nodes with a higher index that have no edge are not
    represented.

    :param edges: integer tensor of shape ``[E, 2]`` holding node-index pairs
    :return: coalesced sparse COO float tensor of shape ``[N, N]`` whose
        values live on the device of this module's first parameter
    """
    device = next(self.parameters()).device
    # Append the reversed copy of every edge -> shape [E * 2, 2].
    edges = torch.cat((edges, edges.flip(dims=[1])), dim=0)
    # torch.sparse_coo_tensor replaces the deprecated legacy constructor
    # torch.sparse.FloatTensor(indices, values).
    adj = torch.sparse_coo_tensor(
        edges.transpose(0, 1),
        torch.ones(edges.shape[0], dtype=torch.float32, device=device),
    )
    M, N = adj.shape
    assert M == N
    if not self.residual:
        diag = torch.arange(N, device=device)
        # Explicit size keeps the self-loop matrix aligned with ``adj``.
        self_loop = torch.sparse_coo_tensor(
            torch.stack((diag, diag)),
            torch.ones(N, dtype=torch.float32, device=device),
            size=(N, N),
        )
        adj = adj + self_loop
    # coalesce() sums duplicate entries; the in-place clamp on the value
    # storage then turns those sums back into 0/1 indicators.
    adj = adj.coalesce()
    torch.clamp_max_(adj._values(), 1)
    return adj
def forward(self, x, t):  # pylint: disable=arguments-differ
    """Element-wise binary cross-entropy on logits with optional extras.

    The targets are binarised first (every value > 0 becomes 1). On top of
    the plain BCE-with-logits term the following optional steps are applied,
    each controlled by an attribute of ``self``: background clamping of the
    logits, soft clamping and a lower bound on the BCE values, focal
    modulation, alpha balancing, scaling by the original target value where
    it differs from its binarised version, and a background weight.

    NOTE: when ``self.background_clamp`` is set, ``x`` is modified in place.

    :param x: logits tensor
    :param t: target tensor, same shape as ``x`` (assumed; it is used for
        boolean-mask indexing of ``x``-shaped tensors)
    :return: unreduced loss tensor (``reduction='none'``)
    """
    # Binarised copy of the targets: anything positive counts as 1.
    t_zeroone = t.clone()
    t_zeroone[t_zeroone > 0.0] = 1.0
    # x = torch.clamp(x, -20.0, 20.0)
    if self.background_clamp is not None:
        # Raise background logits that fall below the clamp value up to it.
        # This writes into the caller's tensor.
        bg_clamp_mask = (t_zeroone == 0.0) & (x < self.background_clamp)
        x[bg_clamp_mask] = self.background_clamp
    bce = torch.nn.functional.binary_cross_entropy_with_logits(
        x, t_zeroone, reduction='none')
    # torch.clamp_max_(bce, 10.0)
    if self.soft_clamp is not None:
        bce = self.soft_clamp(bce)
    if self.min_bce > 0.0:
        # Lower-bound every BCE entry (in place).
        torch.clamp_min_(bce, self.min_bce)

    if self.focal_gamma != 0.0:
        # pt is the probability the model assigns to the binarised target.
        p = torch.sigmoid(x)
        pt = p * t_zeroone + (1 - p) * (1 - t_zeroone)
        # Above code is more stable than deriving pt from bce: pt = torch.exp(-bce)

        if self.focal_clamp and self.min_bce > 0.0:
            # Cap pt at the probability that corresponds to min_bce.
            pt_threshold = math.exp(-self.min_bce)
            torch.clamp_max_(pt, pt_threshold)

        focal = 1.0 - pt
        if self.focal_gamma != 1.0:
            # The small constant keeps the base of the power strictly positive.
            focal = (focal + 1e-4)**self.focal_gamma

        if self.focal_detach:
            # Use the focal factor as a constant weight (no gradient through it).
            focal = focal.detach()

        bce = focal * bce

    if self.focal_alpha == 0.5:
        bce = 0.5 * bce
    elif self.focal_alpha >= 0.0:
        # alpha for positives, (1 - alpha) for background.
        alphat = self.focal_alpha * t_zeroone + (1 - self.focal_alpha) * (
            1 - t_zeroone)
        bce = alphat * bce
    # A negative focal_alpha skips alpha balancing entirely.

    # Where the original target was not exactly 0 or 1, scale the loss by it.
    weight_mask = t_zeroone != t
    bce[weight_mask] = bce[weight_mask] * t[weight_mask]

    if self.background_weight != 1.0:
        # Multiply the loss at background positions (t == 0) by background_weight.
        bg_weight = torch.ones_like(t, requires_grad=False)
        bg_weight[t == 0] *= self.background_weight
        bce = bce * bg_weight

    return bce
def forward(self, x):
    """Run padding, convolution, ReLU, a cap at 20 and dropout, in that order.

    :param x: input tensor accepted by ``self.pad``
    :return: activations after dropout, with every value at most 20
        before dropout is applied
    """
    x = self.pad(x)
    x = self.conv(x)
    x = self.relu(x)
    # Out-of-place clamp: autograd keeps the ReLU output for its backward
    # pass, so overwriting it in place (clamp_max_) invalidates the saved
    # tensor and makes backward() fail.
    x = torch.clamp(x, max=20)
    x = self.dropout(x)
    return x
def _calc_loss(self, x, y):
    """Compute the leaf loss, the inner-node regulariser and their sum.

    :param x: input batch, multiplied with ``self.weights`` transposed
    :param y: targets, multiplied with the transposed leaf log-probabilities
    :return: tuple ``(total_loss, loss_leafs, loss_inners)``
    """
    # Routing: probability of taking the "right" branch at each inner node.
    routing_logits = torch.addmm(self.biases, x, self.weights.t())
    probs_right = torch.sigmoid(self.betas * routing_logits)

    # Leaf loss: path-probability-weighted log-likelihood, negated and
    # averaged over the batch. The variant from the paper additionally takes
    # a log before the mean, but that diverged after a couple of epochs.
    leaf_log_probs = F.log_softmax(self.leafs, dim=1)
    leaf_path_probs = self._calc_path_probs(probs_right, self.ancestors_leafs)
    weighted = leaf_path_probs * y.matmul(leaf_log_probs.t())
    loss_leafs = torch.sum(weighted, dim=1).neg().mean()

    # Inner-node regularisation: tree balancing via binary cross-entropy
    # against self._alpha_target.
    inner_path_probs = self._calc_path_probs(probs_right, self.ancestors_inners)
    numerator = torch.sum(inner_path_probs * probs_right, dim=0)
    # Clamps guard binary_cross_entropy: no division by ~0, no value above 1.
    denominator = torch.sum(inner_path_probs, dim=0).clamp_min_(1e-5)
    alpha_inners = torch.clamp_max_(numerator / denominator, 1)
    loss_inners = F.binary_cross_entropy(
        alpha_inners,
        self._alpha_target,
        weight=self.lambda_per_inner,
        reduction='sum',
    )

    return loss_leafs + loss_inners, loss_leafs, loss_inners
def pick_action_and_get_log_probabilities(self, state=None):
    """Picks actions and then calculates the log probabilities of the actions it picked given the policy

    Actions are sampled from the behaviour model ``self.beta`` (``self.k``
    samples per row), not from the policy itself; the policy's
    log-probabilities for those sampled actions are returned together with
    a clipped, detached importance-correction factor.

    :param state: state to act on; falls back to ``self.state`` when None
    :return: tuple ``(sampled actions as numpy array, policy log-probs of
        the sampled actions, clipped correction weights, ppo_weight)``;
        ``ppo_weight`` is None unless ``self.use_ppo`` is set
    """
    if state is None:
        state = self.state
    state = unwrap_state(state, device=self.device)
    state, current_targets = self.create_state_vector(state)
    action_logits, hidden_state = self.policy(state)
    # The behaviour head sees a detached hidden state, so its loss cannot
    # push gradients into the policy through this path.
    beta_logits = self.beta(hidden_state.detach(), current_targets)
    # NOTE(review): the original indentation was lost; the three statements
    # below are assumed to belong to this no_grad block - confirm.
    with torch.no_grad():
        beta_probs = beta_logits.softmax(dim=-1)
        # need the prob_min because can not be 0 and large logits lead to tiny softmax
        beta_samples = torch.multinomial(beta_probs + PROB_MIN, self.k)
        # Gathered from the un-smoothed probabilities, so an entry can be 0
        # even though it was sampled.
        beta_prob = beta_probs.gather(1, beta_samples)
    ppo_weight = None
    action_log_prob = action_logits.log_softmax(dim=-1)
    if self.use_ppo:
        # NOTE(review): only the sampling is assumed to run under no_grad;
        # the scope of this block could not be recovered - confirm.
        with torch.no_grad():
            curr_samples = torch.multinomial(
                action_log_prob.exp() + PROB_MIN, self.k)
        # curr_prob holds log-probabilities despite its name; it is
        # exponentiated in the ratio below.
        curr_prob = action_log_prob.gather(1, curr_samples)
        action_logits_last, _ = self.last_policy(state)
        action_prob_last = action_logits_last.softmax(dim=-1).gather(
            1, curr_samples)
        # Ratio of current to previous policy probability; 1e-8 avoids
        # division by zero.
        ppo_weight = curr_prob.exp() / (action_prob_last + 1e-8)
    # action_prob also holds log-probabilities (gathered from the log-softmax).
    action_prob = action_log_prob.gather(1, beta_samples)
    # Importance ratio policy/behaviour, capped at CLIPPING_VALUE and
    # detached so it acts as a constant weight.
    correction = torch.clamp_max_(
        torch.exp(action_prob) / beta_prob, CLIPPING_VALUE).detach()
    return beta_samples.cpu().detach().numpy(
    ), action_prob, correction, ppo_weight
def forward(self, x):
    """Apply linear, ReLU, a cap at 20 and dropout along the middle axis.

    The last two axes are swapped before the linear layer and swapped back
    afterwards, so ``x`` must be 3-dimensional.

    :param x: 3-D input tensor; axis 1 is the one ``self.linear`` acts on
    :return: tensor with the same axis order as ``x``
    """
    x = x.permute(0, 2, 1)
    x = self.linear(x)
    x = self.relu(x)
    # Out-of-place clamp: autograd keeps the ReLU output for its backward
    # pass, so overwriting it in place (clamp_max_) invalidates the saved
    # tensor and makes backward() fail.
    x = torch.clamp(x, max=20)
    x = self.dropout(x)
    x = x.permute(0, 2, 1)
    return x
def _compute(self, current, target): """ Updates the target parameter(s) based on the current parameter(s). Args: current (int, float, torch.tensor, np.array, torch.nn.Module): current parameter(s). target (int, float, torch.tensor, np.array, torch.nn.Module): target parameter(s) to be modified based on the current parameter(s). Returns: int, float, torch.tensor, np.array, torch.nn.Module: updated target parameter(s). """ if isinstance(target, torch.nn.Module): if isinstance(current, torch.nn.Module): for p1, p2 in zip(target.parameters(), current.parameters()): data = p2.data + self.speed * p2.data * self.dt if self.speed < 0: torch.clamp_min_(data, self.end) else: torch.clamp_max_(data, self.end) p1.data.copy_(data) elif isinstance(target, torch.Tensor): data = current.data + self.speed * current.data * self.dt if self.speed < 0: torch.clamp_min_(data, self.end) else: torch.clamp_max_(data, self.end) target.data.copy_(data) elif isinstance(target, np.ndarray): target[:] = current + self.speed * current * self.dt if (self.speed < 0 and target < self.end) or (self.speed > 0 and target > self.end): target[:] = self.end else: target = current + self.speed * current * self.dt if (self.speed < 0 and target < self.end) or (self.speed > 0 and target > self.end): target = self.end return target
def compute_act_stabilizing_loss_abstract(
        self, inputs: torch.Tensor, eps: float,
        inputs_min: float = 0, inputs_max: float = 1):
    """Compute an extra loss for stabilizing the activations using
    abstract interpretation.

    The interval ``[inputs - eps, inputs + eps]`` is built without gradient
    tracking; its lower end is limited from below by ``inputs_min`` and its
    upper end from above by ``inputs_max``. The interval is then pushed
    through ``self.forward`` as an ``AbstractTensor`` that starts with a
    zero loss.

    :param inputs: input tensor
    :param eps: half-width of the interval around ``inputs``
    :param inputs_min: lower limit applied to the interval's lower end
    :param inputs_max: upper limit applied to the interval's upper end
    :return: loss value
    """
    with torch.no_grad():
        lower = (inputs - eps).clamp_min_(inputs_min)
        upper = (inputs + eps).clamp_max_(inputs_max)
    zero_loss = torch.zeros((), dtype=torch.float32, device=inputs.device)
    return self.forward(AbstractTensor(lower, upper, zero_loss)).loss
def forward_with_multi_sample(
        self, x: torch.Tensor, x_adv: torch.Tensor, eps: float,
        inputs_min: float = 0, inputs_max: float = 1):
    """forward with randomly sampled perturbations and compute a
    stabilization loss

    Three variants of the input are stacked along a new leading axis and
    forwarded together: ``x_adv``, ``x - delta`` (limited from below by
    ``inputs_min``) and ``x + delta`` (limited from above by
    ``inputs_max``), where every element of ``delta`` is either ``-eps``
    or ``+eps``.

    :param x: clean input tensor
    :param x_adv: tensor of the same shape as ``x`` (required by the
        concatenation); presumably an adversarial version of ``x`` - confirm
    :param eps: magnitude of the random perturbation
    :param inputs_min: lower limit applied to ``x - delta``
    :param inputs_max: upper limit applied to ``x + delta``
    :return: tuple of element 0 of the expanded network output and the loss
        carried by the output
    """
    data = [x_adv, None, None]
    eps = float(eps)
    with torch.no_grad():
        # random_(0, 2) draws 0 or 1; scaling by 2*eps and shifting by -eps
        # maps that to -eps or +eps.
        delta = torch.empty_like(x).random_(0, 2).mul_(2 * eps).sub_(eps)
        data[1] = torch.clamp_min_(x - delta, inputs_min)
        data[2] = torch.clamp_max_(x + delta, inputs_max)
    # NOTE(review): the original indentation was lost; the concatenation is
    # assumed to sit outside the no_grad block - confirm.
    # Each sample gets a new leading axis, giving a stack of size 3.
    data = torch.cat([i[np.newaxis] for i in data], dim=0)
    y = self.forward(MultiSampleTensor.from_squeeze(data))
    return y.as_expanded_tensor()[0], y.loss
def pick_action_and_log_probs(self, state=None):
    """Pick actions and return their log-probabilities.

    In batch-RL mode (``self.hyperparameters["batch_rl"]``) the
    log-probabilities of the logged target actions are evaluated under the
    agent, the off-policy (behaviour) network is updated, and clipped
    importance-sampling weights are returned. Otherwise actions are sampled
    from the agent and no weights are returned.

    :param state: state to act on; falls back to ``self.state`` when None
    :return: tuple ``(actions, action_log_probs, is_weights)``;
        ``is_weights`` is None outside batch-RL mode
    """
    if state is None:
        state = self.state
    state = unwrap_state(state, device=self.device)
    state, targets = self.create_state_vector(state)

    if self.hyperparameters["batch_rl"]:
        # Get the log probs for the target policy
        actions, action_log_probs = self.agent.log_probs_for_actions(state, targets)

        # Update the off-policy network for IS weights (behavior policy approximation)
        beta_logits = self.off_policy_agent(state.detach())
        beta_log_probs = beta_logits.log_softmax(dim=-1)
        beta_log_probs = beta_log_probs[torch.arange(beta_log_probs.size(0)), targets]
        self.update_off_policy_agent(beta_log_probs)

        # Form the ratio in log space. exp(a) / exp(b) turns into 0 / 0 = NaN
        # as soon as both exponentials underflow, whereas exp(a - b) stays
        # finite; the weight is a constant, hence the detach.
        log_ratio = (action_log_probs - beta_log_probs).detach()
        is_weights = torch.exp(log_ratio).clamp_max_(CLIPPING_VALUE)
        return actions, action_log_probs, is_weights

    action_trajectory, action_log_probs = self.agent(state, deterministic=False)
    actions = self.action_trajectory_to_action(action_trajectory)
    # if self.masking_enabled:
    #     actions, mask = self.mask_action_output(actions)
    return actions, action_log_probs, None
def forward(self, x):
    """Apply ReLU, add the bias and optionally cap the result.

    :param x: input tensor
    :return: ``relu(x) + bias``, limited from above by ``self.max_value``
        when that attribute is not None
    """
    shifted = self.relu(x) + self.bias
    if self.max_value is None:
        return shifted
    # The sum is a fresh tensor, so capping it in place does not touch x.
    return shifted.clamp_max_(self.max_value)
def forward(self, x: torch.Tensor, hc=None):
    """
    .. _SpikingLSTMCell.forward-en:

    Advance the cell by one time step.

    :param x: the input tensor with ``shape = [batch_size, input_size]``
    :type x: torch.Tensor

    :param hc: (h_0, c_0)
            h_0 : torch.Tensor
                ``shape = [batch_size, hidden_size]``, tensor containing the initial hidden state for each element in the batch
            c_0 : torch.Tensor
                ``shape = [batch_size, hidden_size]``, tensor containing the initial cell state for each element in the batch
            If (h_0, c_0) is not provided, both ``h_0`` and ``c_0`` default to zero
    :type hc: tuple or None
    :return: (h_1, c_1) :
            h_1 : torch.Tensor
                ``shape = [batch_size, hidden_size]``, tensor containing the next hidden state for each element in the batch
            c_1 : torch.Tensor
                ``shape = [batch_size, hidden_size]``, tensor containing the next cell state for each element in the batch
    :rtype: tuple
    """
    if hc is not None:
        h, c = hc[0], hc[1]
    else:
        h = torch.zeros(size=[x.shape[0], self.hidden_size], dtype=torch.float, device=x.device)
        c = torch.zeros_like(h)

    gate_input = self.linear_ih(x) + self.linear_hh(h)
    if self.surrogate_function2 is None:
        # A single surrogate function fires all four gates at once.
        i, f, g, o = torch.split(self.surrogate_function1(gate_input), self.hidden_size, dim=1)
    else:
        # g has its own surrogate function; i, f and o share the first one.
        i, f, g, o = torch.split(gate_input, self.hidden_size, dim=1)
        i = self.surrogate_function1(i)
        f = self.surrogate_function1(f)
        g = self.surrogate_function2(g)
        o = self.surrogate_function1(o)
        assert self.surrogate_function1.spiking == self.surrogate_function2.spiking

    c = c * f + i * g
    # According to the original paper: c can take the values 0, 1, or 2.
    # Since the gradients around 2 are not as informative, the output is
    # thresholded to 1 when it is 1 or 2. The gradients of this step
    # function are approximated with a gamma that takes two values, 1 or <= 1.
    # The cap runs under no_grad so it is not recorded by autograd.
    with torch.no_grad():
        c.clamp_max_(1.)
    h = c * o
    return h, c