예제 #1
0
def threshold(x: torch.Tensor, method: str, alpha: float) -> torch.Tensor:
    """Apply the threshold function named by ``method`` to ``x``.

    ``"heaviside"`` ignores ``alpha``; ``"super"`` receives ``alpha``
    converted with ``torch.as_tensor``; ``"tanh"``, ``"tent"``, ``"circ"``
    and ``"heavi_erfc"`` receive ``alpha`` unchanged.

    Raises:
        ValueError: if ``method`` matches none of the supported names.
    """
    # Each entry is a zero-argument thunk, so only the selected helper is
    # looked up and invoked -- the same laziness as an if/elif chain.
    candidates = (
        ("heaviside", lambda: heaviside(x)),
        ("super", lambda: superspike_fn(x, torch.as_tensor(alpha))),
        ("tanh", lambda: heavi_tanh_fn(x, alpha)),
        ("tent", lambda: heavi_tent_fn(x, alpha)),
        ("circ", lambda: heavi_circ_fn(x, alpha)),
        ("heavi_erfc", lambda: heavi_erfc_fn(x, alpha)),
    )
    # Linear scan with ``==`` (not a dict lookup) so arbitrary ``method``
    # values are compared exactly as before, in the same order.
    for name, apply_fn in candidates:
        if method == name:
            return apply_fn()
    raise ValueError(
        f"Attempted to apply threshold function {method}, but no such " +
        "function exist. We currently support heaviside, super, " +
        "tanh, tent, circ, and heavi_erfc.")
예제 #2
0
파일: stdp.py 프로젝트: weilongzheng/norse
    def __init__(
        self,
        a_pre: torch.Tensor = torch.as_tensor(1.0),
        a_post: torch.Tensor = torch.as_tensor(1.0),
        tau_pre_inv: torch.Tensor = torch.as_tensor(1.0 / 50e-3),
        tau_post_inv: torch.Tensor = torch.as_tensor(1.0 / 50e-3),
        w_min: torch.Tensor = 0.0,
        w_max: torch.Tensor = 1.0,
        eta_plus: torch.Tensor = 1e-3,
        eta_minus: torch.Tensor = 1e-3,
        stdp_algorithm: str = "additive",
        mu: torch.Tensor = 0.0,
        hardbound: bool = True,
        convolutional: bool = False,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
    ):
        """Store STDP parameters and build the weight-dependence functions.

        Args:
            a_pre, a_post: stored as-is; their use is outside this method.
                NOTE: the tensor defaults are created once at definition
                time and shared by every instance (not mutated here).
            tau_pre_inv, tau_post_inv: stored as-is; defaults are
                ``1.0 / 50e-3``.
            w_min, w_max: weight bounds read by ``A_plus``/``A_minus`` and
                by ``bounding_func``.
            eta_plus, eta_minus: scale factors applied by ``A_plus`` and
                ``A_minus`` respectively.
            stdp_algorithm: selects how ``A_plus(w)``/``A_minus(w)`` depend
                on the weight ``w``:
                ``"additive"`` -- constant ``eta_plus``/``eta_minus``;
                ``"additive_step"`` -- constant, gated by ``heaviside`` of
                the distance to ``w_max``/``w_min``;
                ``"multiplicative_pow"`` -- ``eta * (distance) ** mu``;
                ``"multiplicative_relu"`` -- ``eta * relu(distance)``.
            mu: exponent used only by ``"multiplicative_pow"``; the other
                algorithms set ``self.mu`` to a fixed value (0.0 or 1.0).
            hardbound: if True, ``bounding_func`` clamps to
                ``[self.w_min, self.w_max]``; otherwise it is the identity.
            convolutional: if True, ``stride``, ``padding`` and ``dilation``
                are stored on the instance (otherwise they are not set).

        Raises:
            ValueError: if ``stdp_algorithm`` is not one of the four
                supported names.
        """
        self.a_pre = a_pre
        self.a_post = a_post
        self.tau_pre_inv = tau_pre_inv
        self.tau_post_inv = tau_post_inv
        self.w_min = w_min
        self.w_max = w_max
        self.eta_plus = eta_plus
        self.eta_minus = eta_minus

        self.stdp_algorithm = stdp_algorithm
        if self.stdp_algorithm == "additive":
            self.mu = torch.tensor(0.0)
            self.A_plus = lambda w: self.eta_plus
            self.A_minus = lambda w: self.eta_minus
        elif self.stdp_algorithm == "additive_step":
            self.mu = torch.tensor(0.0)
            self.A_plus = lambda w: self.eta_plus * heaviside(self.w_max - w)
            self.A_minus = lambda w: self.eta_minus * heaviside(w - self.w_min)
        elif self.stdp_algorithm == "multiplicative_pow":
            # torch.tensor(<tensor>) emits a UserWarning; copy-and-detach a
            # tensor explicitly instead, which is what torch.tensor did.
            self.mu = (
                mu.clone().detach()
                if isinstance(mu, torch.Tensor)
                else torch.tensor(mu)
            )
            self.A_plus = lambda w: self.eta_plus * torch.pow(
                self.w_max - w, self.mu)
            self.A_minus = lambda w: self.eta_minus * torch.pow(
                w - self.w_min, self.mu)
        elif self.stdp_algorithm == "multiplicative_relu":
            self.mu = torch.tensor(1.0)
            self.A_plus = lambda w: self.eta_plus * torch.nn.functional.relu(
                self.w_max - w)
            self.A_minus = lambda w: self.eta_minus * torch.nn.functional.relu(
                w - self.w_min)
        else:
            # Previously an unknown name left mu/A_plus/A_minus unset and
            # only failed later with an AttributeError; fail fast instead.
            raise ValueError(
                f"Unknown stdp_algorithm {stdp_algorithm!r}; expected one of "
                "'additive', 'additive_step', 'multiplicative_pow', "
                "'multiplicative_relu'.")

        # Hard bounds. Read the bounds from self (like A_plus/A_minus do) so
        # that both stay consistent if w_min/w_max are reassigned later.
        self.hardbound = hardbound
        if self.hardbound:
            self.bounding_func = lambda w: torch.clamp(
                w, self.w_min, self.w_max)
        else:
            self.bounding_func = lambda w: w

        # Conv2D
        self.convolutional = convolutional
        if self.convolutional:
            self.stride = stride
            self.padding = padding
            self.dilation = dilation
예제 #3
0
 def forward(ctx, input_tensor: torch.Tensor, alpha: float) -> torch.Tensor:
     """Return ``heaviside(input_tensor)``, recording state on ``ctx``.

     ``input_tensor`` is saved via ``ctx.save_for_backward`` and ``alpha``
     is kept as ``ctx.alpha``; presumably the (not shown) backward pass
     uses both for a surrogate gradient -- confirm against ``backward``.
     """
     ctx.alpha = alpha
     ctx.save_for_backward(input_tensor)
     spikes = heaviside(input_tensor)
     return spikes
예제 #4
0
def test_heaviside():
    """heaviside maps an all-positive tensor to ones, all-negative to zeros."""
    positive_input = torch.ones(100)
    negative_input = -1.0 * torch.ones(100)
    assert torch.equal(heaviside(positive_input), torch.ones(100))
    assert torch.equal(heaviside(negative_input), torch.zeros(100))
예제 #5
0
 def forward(ctx, x, alpha):
     """Emit the exact step ``heaviside(x)``; save ``x`` and ``alpha``.

     The smooth approximation
     ``0.5 + 0.5 * (x / (x.pow(2) + alpha ** 2).sqrt())`` is NOT evaluated
     here; presumably ``backward`` (not shown) uses it for the gradient.
     """
     ctx.alpha = alpha
     ctx.save_for_backward(x)
     step = heaviside(x)
     return step
예제 #6
0
 def forward(ctx, x, k):
     """Emit the exact step ``heaviside(x)``; save ``x`` and ``k``.

     The smooth approximation ``0.5 + 0.5 * torch.tanh(k * x)`` is NOT
     evaluated here; presumably ``backward`` (not shown) relies on it.
     """
     ctx.k = k
     ctx.save_for_backward(x)
     step = heaviside(x)
     return step
예제 #7
0
 def forward(ctx, x, alpha):
     """Return ``heaviside(x)`` after stashing ``x`` and ``alpha`` on ``ctx``.

     Presumably ``backward`` (not shown) reads ``ctx.alpha`` and the saved
     ``x`` -- confirm against the corresponding backward implementation.
     """
     ctx.save_for_backward(x)
     ctx.alpha = alpha
     step = heaviside(x)
     return step
예제 #8
0
def test_forward():
    """super_fn's forward output matches the hard step on +/-1 inputs."""
    size = 100
    alpha = 100.0
    assert torch.equal(
        super_fn(torch.ones(size), alpha),
        heaviside(torch.ones(size)),
    )
    assert torch.equal(
        super_fn(-1.0 * torch.ones(size), alpha),
        torch.zeros(size),
    )