Ejemplo n.º 1
0
def hsigmoid(x, slope=0.2, offset=0.5):
    """
    Hard-sigmoid activation: ``x * slope + offset`` limited to ``[0, 1]``.

    PyTorch ships a similar function, but its constants differ from Paddle's:
    PyTorch uses ``x / 6 + 0.5`` whereas Paddle uses ``x / 5 + 0.5``. The
    default ``slope``/``offset`` here correspond to the Paddle formula.

    :param x: input tensor (left unmodified)
    :param slope: factor ``x`` is multiplied by before clipping
    :param offset: constant added before clipping
    :return: a new tensor with every element limited to ``[0, 1]``
    """
    scaled = x * slope + offset
    # A single in-place clamp on the freshly created tensor; `x` is untouched.
    return scaled.clamp_(min=0, max=1)
Ejemplo n.º 2
0
 def preprocess_adj(self, edges):
     """Build a symmetric, binary sparse adjacency matrix from an edge list.

     Every edge ``(u, v)`` is mirrored to ``(v, u)``. Unless ``self.residual``
     is truthy, self-loops are added on the diagonal. Duplicate entries are
     merged and all stored values are capped at 1.

     :param edges: integer tensor of shape ``[E, 2]`` holding node index pairs
     :return: coalesced sparse COO tensor of shape ``[N, N]``, where ``N`` is
         inferred as the largest node index in ``edges`` plus one
     """
     device = next(self.parameters()).device
     edges = torch.cat((edges, edges.flip(dims=[1])), dim=0)  # shape=[E * 2, 2]
     # torch.sparse_coo_tensor replaces the deprecated legacy constructor
     # torch.sparse.FloatTensor; the size is still inferred from the indices.
     adj = torch.sparse_coo_tensor(edges.transpose(0, 1),
                                   torch.ones(edges.shape[0], device=device))
     M, N = adj.shape
     # Invariant rather than input validation: mirroring the edges gives both
     # index rows the same maximum, so the inferred shape is square.
     assert M == N
     if not self.residual:
         diag = torch.arange(N, device=device)
         self_loop = torch.sparse_coo_tensor(torch.stack((diag, diag), dim=0),
                                             torch.ones(N, device=device),
                                             size=(N, N))
         adj = adj + self_loop
     # Coalescing sums duplicate entries; clamp afterwards so repeated edges
     # (or edges that already were self-loops) still end up with weight 1.
     adj = adj.coalesce()
     torch.clamp_max_(adj._values(), 1)
     return adj
Ejemplo n.º 3
0
    def forward(self, x, t):  # pylint: disable=arguments-differ
        """Compute an un-reduced binary cross-entropy loss on logits ``x``.

        The target ``t`` is first binarised (every entry ``> 0`` becomes
        ``1.0``). The BCE is then optionally clamped, modulated by a focal
        term, alpha-balanced, re-weighted by the original target value and
        by a background weight, depending on the instance attributes read
        below.

        NOTE: when ``self.background_clamp`` is not ``None`` the argument
        ``x`` is modified in place.

        :param x: logits tensor
        :param t: target tensor
        :return: per-element loss tensor (``reduction='none'``)
        """
        # Binary copy of the target: every strictly positive entry becomes 1.0.
        t_zeroone = t.clone()
        t_zeroone[t_zeroone > 0.0] = 1.0
        # x = torch.clamp(x, -20.0, 20.0)
        if self.background_clamp is not None:
            # For entries whose binary target is 0, raise logits that lie
            # below the clamp value up to it. This writes into the caller's x.
            bg_clamp_mask = (t_zeroone == 0.0) & (x < self.background_clamp)
            x[bg_clamp_mask] = self.background_clamp
        bce = torch.nn.functional.binary_cross_entropy_with_logits(
            x, t_zeroone, reduction='none')
        # torch.clamp_max_(bce, 10.0)
        if self.soft_clamp is not None:
            bce = self.soft_clamp(bce)
        if self.min_bce > 0.0:
            # Lower-bound the per-element BCE (in place).
            torch.clamp_min_(bce, self.min_bce)

        if self.focal_gamma != 0.0:
            # Focal modulation: scale the BCE by (1 - pt) ** focal_gamma, where
            # pt is the sigmoid probability assigned to the binary target.
            p = torch.sigmoid(x)
            pt = p * t_zeroone + (1 - p) * (1 - t_zeroone)
            # Above code is more stable than deriving pt from bce: pt = torch.exp(-bce)

            if self.focal_clamp and self.min_bce > 0.0:
                # Cap pt at exp(-min_bce), the pt value at which the BCE
                # (= -log(pt)) equals min_bce.
                pt_threshold = math.exp(-self.min_bce)
                torch.clamp_max_(pt, pt_threshold)

            focal = 1.0 - pt
            if self.focal_gamma != 1.0:
                # NOTE(review): the 1e-4 offset presumably keeps the power and
                # its gradient finite when focal is 0 — confirm.
                focal = (focal + 1e-4)**self.focal_gamma

            if self.focal_detach:
                # Treat the focal factor as a constant weight (no gradient).
                focal = focal.detach()

            bce = focal * bce

        # Alpha balancing: 0.5 is a uniform factor; any other non-negative
        # alpha weights target==1 entries by alpha and the rest by 1 - alpha.
        # A negative alpha skips this step.
        if self.focal_alpha == 0.5:
            bce = 0.5 * bce
        elif self.focal_alpha >= 0.0:
            alphat = self.focal_alpha * t_zeroone + (1 - self.focal_alpha) * (
                1 - t_zeroone)
            bce = alphat * bce

        # Entries whose target was changed by the binarisation (positive
        # targets other than 1.0) are scaled by their original target value.
        weight_mask = t_zeroone != t
        bce[weight_mask] = bce[weight_mask] * t[weight_mask]

        if self.background_weight != 1.0:
            # Scale the loss of entries with target exactly 0.
            bg_weight = torch.ones_like(t, requires_grad=False)
            bg_weight[t == 0] *= self.background_weight
            bce = bce * bg_weight

        return bce
Ejemplo n.º 4
0
 def forward(self, x):
     """Apply pad -> conv -> activation -> clamp(max=20) -> dropout.

     :param x: input tensor, passed to ``self.pad``
     :return: output of ``self.dropout``; values entering the dropout are
         capped at 20
     """
     x = self.pad(x)
     x = self.conv(x)
     x = self.relu(x)
     # Out-of-place clamp on purpose: an in-place clamp would overwrite the
     # activation output, which e.g. nn.ReLU saves for its backward pass and
     # would therefore raise "modified by an inplace operation" in backward.
     x = torch.clamp(x, max=20)
     x = self.dropout(x)
     return x
Ejemplo n.º 5
0
    def _calc_loss(self, x, y):
        # Leaf loss
        leafs_pred = F.log_softmax(self.leafs, dim=1)
        probs_right = torch.sigmoid(
            self.betas * torch.addmm(self.biases, x, self.weights.t()))
        leaf_path_probs = self._calc_path_probs(probs_right,
                                                self.ancestors_leafs)
        # loss_leafs = torch.sum(leaf_path_probs * y.matmul(leafs_pred.t()), dim=1).neg().log().mean() # Loss
        # according to paper, yet this diverges after a couple of epochs
        loss_leafs = torch.sum(leaf_path_probs * y.matmul(leafs_pred.t()),
                               dim=1).neg().mean()

        # Regularization inners: tree balancing by binary cross-entropy with discrete uniform(2) distribution
        inner_path_probs = self._calc_path_probs(probs_right,
                                                 self.ancestors_inners)

        # clamps to avoid errors in binary_cross_entropy
        alpha_inners = torch.clamp_max_(
            torch.sum(inner_path_probs * probs_right, dim=0) /
            torch.sum(inner_path_probs, dim=0).clamp_min_(1e-5), 1)

        loss_inners = F.binary_cross_entropy(alpha_inners,
                                             self._alpha_target,
                                             weight=self.lambda_per_inner,
                                             reduction='sum')

        total_loss = loss_leafs + loss_inners

        return total_loss, loss_leafs, loss_inners
Ejemplo n.º 6
0
    def pick_action_and_get_log_probabilities(self, state=None):
        """Sample actions from the behaviour head and score them under the policy.

        :param state: state to act on; ``self.state`` is used when ``None``
        :return: tuple ``(samples, action_log_probs, correction, ppo_weight)``:
            the sampled indices as a NumPy array, the policy log-probabilities
            gathered at those indices, the clipped and detached ratio
            ``exp(log_prob) / beta_prob``, and the PPO ratio (``None`` unless
            ``self.use_ppo`` is set)
        """
        raw_state = self.state if state is None else state
        raw_state = unwrap_state(raw_state, device=self.device)
        state_vec, current_targets = self.create_state_vector(raw_state)

        action_logits, hidden_state = self.policy(state_vec)
        # The behaviour head sees a detached hidden state.
        beta_logits = self.beta(hidden_state.detach(), current_targets)

        with torch.no_grad():
            beta_probs = beta_logits.softmax(dim=-1)
            # PROB_MIN is added because the sampling weights can not be 0 and
            # large logits lead to tiny softmax values.
            beta_samples = torch.multinomial(beta_probs + PROB_MIN, self.k)
            beta_prob = beta_probs.gather(1, beta_samples)

        action_log_prob = action_logits.log_softmax(dim=-1)

        ppo_weight = None
        if self.use_ppo:
            with torch.no_grad():
                # Sample from the current policy and compare against the
                # previous policy's probabilities at the same indices.
                policy_samples = torch.multinomial(
                    action_log_prob.exp() + PROB_MIN, self.k)
                current_log_prob = action_log_prob.gather(1, policy_samples)

                previous_logits, _ = self.last_policy(state_vec)
                previous_prob = previous_logits.softmax(dim=-1).gather(
                    1, policy_samples)

                ppo_weight = current_log_prob.exp() / (previous_prob + 1e-8)

        # Log-probabilities of the behaviour samples under the policy.
        sampled_log_prob = action_log_prob.gather(1, beta_samples)
        ratio = torch.exp(sampled_log_prob) / beta_prob
        correction = ratio.clamp_max_(CLIPPING_VALUE).detach()

        samples_np = beta_samples.cpu().detach().numpy()
        return samples_np, sampled_log_prob, correction, ppo_weight
Ejemplo n.º 7
0
 def forward(self, x):
     """Apply linear -> activation -> clamp(max=20) -> dropout on the last-but-one axis.

     The last two axes of the 3-D input are swapped before ``self.linear``
     and swapped back before returning.

     :param x: 3-D input tensor
     :return: tensor with the last two axes in the same order as the input
     """
     x = x.permute(0, 2, 1)
     x = self.linear(x)
     x = self.relu(x)
     # Out-of-place clamp on purpose: an in-place clamp would overwrite the
     # activation output, which e.g. nn.ReLU saves for its backward pass and
     # would therefore raise "modified by an inplace operation" in backward.
     x = torch.clamp(x, max=20)
     x = self.dropout(x)
     x = x.permute(0, 2, 1)
     return x
    def _compute(self, current, target):
        """
        Updates the target parameter(s) based on the current parameter(s).

        Args:
            current (int, float, torch.tensor, np.array, torch.nn.Module): current parameter(s).
            target (int, float, torch.tensor, np.array, torch.nn.Module): target parameter(s) to be modified based on
            the current parameter(s).

        Returns:
            int, float, torch.tensor, np.array, torch.nn.Module: updated target parameter(s).
        """
        if isinstance(target, torch.nn.Module):
            if isinstance(current, torch.nn.Module):
                for p1, p2 in zip(target.parameters(), current.parameters()):
                    data = p2.data + self.speed * p2.data * self.dt
                    if self.speed < 0:
                        torch.clamp_min_(data, self.end)
                    else:
                        torch.clamp_max_(data, self.end)
                    p1.data.copy_(data)
        elif isinstance(target, torch.Tensor):
            data = current.data + self.speed * current.data * self.dt
            if self.speed < 0:
                torch.clamp_min_(data, self.end)
            else:
                torch.clamp_max_(data, self.end)
            target.data.copy_(data)
        elif isinstance(target, np.ndarray):
            target[:] = current + self.speed * current * self.dt
            if (self.speed < 0
                    and target < self.end) or (self.speed > 0
                                               and target > self.end):
                target[:] = self.end
        else:
            target = current + self.speed * current * self.dt
            if (self.speed < 0
                    and target < self.end) or (self.speed > 0
                                               and target > self.end):
                target = self.end
        return target
Ejemplo n.º 9
0
    def compute_act_stabilizing_loss_abstract(self,
                                              inputs: torch.Tensor,
                                              eps: float,
                                              inputs_min: float = 0,
                                              inputs_max: float = 1):
        """Compute an extra loss for stabilizing the activations using
        abstract interpretation.

        The interval ``[inputs - eps, inputs + eps]`` is built without
        gradient tracking, with the lower end floored at ``inputs_min`` and
        the upper end capped at ``inputs_max``, then run through
        ``self.forward`` wrapped in an ``AbstractTensor`` together with a
        zero-initialised scalar loss.

        :param inputs: input tensor
        :param eps: half-width of the interval around ``inputs``
        :param inputs_min: floor applied to the lower interval end
        :param inputs_max: cap applied to the upper interval end
        :return: loss value (the ``loss`` attribute of the forward result)
        """
        with torch.no_grad():
            lower = (inputs - eps).clamp_(min=inputs_min)
            upper = (inputs + eps).clamp_(max=inputs_max)
        zero_loss = torch.zeros((), dtype=torch.float32, device=inputs.device)
        result = self.forward(AbstractTensor(lower, upper, zero_loss))
        return result.loss
Ejemplo n.º 10
0
 def forward_with_multi_sample(self,
                               x: torch.Tensor,
                               x_adv: torch.Tensor,
                               eps: float,
                               inputs_min: float = 0,
                               inputs_max: float = 1):
     """Forward with randomly sampled perturbations and compute a
     stabilization loss.

     Three samples are stacked along a new leading axis (built without
     gradient tracking): ``x_adv``, ``x - delta`` floored at ``inputs_min``
     and ``x + delta`` capped at ``inputs_max``, where every element of
     ``delta`` is randomly ``-eps`` or ``+eps``.

     :return: tuple of the first entry of the expanded forward output and
         the ``loss`` attribute of the forward result
     """
     eps = float(eps)
     with torch.no_grad():
         # random_(0, 2) draws 0 or 1, so each element ends up -eps or +eps.
         delta = torch.empty_like(x).random_(0, 2).mul_(2 * eps).sub_(eps)
         # NOTE(review): each sample is clamped on one side only, although the
         # random sign can push it past the other bound — confirm intent.
         minus = (x - delta).clamp_(min=inputs_min)
         plus = (x + delta).clamp_(max=inputs_max)
         stacked = torch.stack([x_adv, minus, plus], dim=0)
     y = self.forward(MultiSampleTensor.from_squeeze(stacked))
     return y.as_expanded_tensor()[0], y.loss
Ejemplo n.º 11
0
    def pick_action_and_log_probs(self, state=None):
        """Pick actions and return their log-probabilities.

        In batch-RL mode (``self.hyperparameters["batch_rl"]`` truthy) the
        log-probabilities come from ``self.agent.log_probs_for_actions``, the
        off-policy network is updated, and clipped importance-sampling
        weights are returned. Otherwise actions are sampled from
        ``self.agent`` and no weights are returned.

        :param state: state to act on; ``self.state`` is used when ``None``
        :return: tuple ``(actions, action_log_probs, is_weights)``;
            ``is_weights`` is ``None`` outside batch-RL mode
        """
        if state is None:
            state = self.state
        state = unwrap_state(state, device=self.device)
        state, targets = self.create_state_vector(state)
        if self.hyperparameters["batch_rl"]:
            # Get the log probs for the target policy
            actions, action_log_probs = self.agent.log_probs_for_actions(state, targets)

            # Update the off-policy network for IS weights (behavior policy approximation)
            beta_logits = self.off_policy_agent(state.detach())
            beta_log_probs = beta_logits.log_softmax(dim=-1)
            beta_log_probs = beta_log_probs[torch.arange(beta_log_probs.size(0)), targets]
            self.update_off_policy_agent(beta_log_probs)

            # Form the ratio in log space: exp(a) / exp(b) turns into 0 / 0 = NaN
            # once both log-probabilities underflow, whereas exp(a - b) stays
            # finite (or overflows to inf, which the clipping then bounds).
            log_ratio = action_log_probs - beta_log_probs
            is_weights = torch.exp(log_ratio).clamp(max=CLIPPING_VALUE).detach()
            return actions, action_log_probs, is_weights

        action_trajectory, action_log_probs = self.agent(state, deterministic=False)
        actions = self.action_trajectory_to_action(action_trajectory)
        # if self.masking_enabled:
        #     actions, mask = self.mask_action_output(actions)
        return actions, action_log_probs, None
Ejemplo n.º 12
0
 def forward(self, x):
     """Return ``self.relu(x) + self.bias``, capped at ``self.max_value`` if set.

     :param x: input tensor
     :return: the shifted activation; when ``self.max_value`` is not ``None``
         every element is limited to at most that value
     """
     shifted = self.relu(x) + self.bias
     if self.max_value is None:
         return shifted
     # In-place cap on the freshly created sum.
     return shifted.clamp_(max=self.max_value)
Ejemplo n.º 13
0
    def forward(self, x: torch.Tensor, hc=None):
        """
        * :ref:`API in English <SpikingLSTMCell.forward-en>`
        * :ref:`中文API <SpikingLSTMCell.forward-cn>`

        .. _SpikingLSTMCell.forward-cn:

        .. _SpikingLSTMCell.forward-en:

        Run one step of the spiking LSTM cell.

        :param x: the input tensor with ``shape = [batch_size, input_size]``
        :type x: torch.Tensor

        :param hc: (h_0, c_0)
                h_0 : torch.Tensor
                    ``shape = [batch_size, hidden_size]``, tensor containing the initial hidden state for each element in the batch
                c_0 : torch.Tensor
                    ``shape = [batch_size, hidden_size]``, tensor containing the initial cell state for each element in the batch
                If (h_0, c_0) is not provided, both ``h_0`` and ``c_0`` default to zero
        :type hc: tuple or None
        :return: (h_1, c_1) :
                h_1 : torch.Tensor
                    ``shape = [batch_size, hidden_size]``, tensor containing the next hidden state for each element in the batch
                c_1 : torch.Tensor
                    ``shape = [batch_size, hidden_size]``, tensor containing the next cell state for each element in the batch
        :rtype: tuple
        """
        if hc is not None:
            h_prev, c_prev = hc[0], hc[1]
        else:
            h_prev = torch.zeros(size=[x.shape[0], self.hidden_size],
                                 dtype=torch.float,
                                 device=x.device)
            c_prev = torch.zeros_like(h_prev)

        pre_activation = self.linear_ih(x) + self.linear_hh(h_prev)

        if self.surrogate_function2 is None:
            # One surrogate function fires all four gates at once.
            i, f, g, o = torch.split(self.surrogate_function1(pre_activation),
                                     self.hidden_size,
                                     dim=1)
        else:
            # The candidate gate g uses its own surrogate function.
            i, f, g, o = torch.split(pre_activation, self.hidden_size, dim=1)
            i = self.surrogate_function1(i)
            f = self.surrogate_function1(f)
            g = self.surrogate_function2(g)
            o = self.surrogate_function1(o)
            assert self.surrogate_function1.spiking == self.surrogate_function2.spiking

        c = c_prev * f + i * g

        # According to the origin paper:
        #   Notice that c can take the values 0, 1, or 2. Since the gradients
        #   around 2 are not as informative, we threshold this output to
        #   output 1 when it is 1 or 2. We approximate the gradients of this
        #   step function with γ that take two values 1 or ≤ 1.
        # The threshold is applied in place and outside of autograd tracking.
        with torch.no_grad():
            c.clamp_(max=1.)

        return c * o, c