class twnLinear(nn.Linear):
    """Linear layer whose forward pass uses ternary-quantized weights.

    The full-precision ``weight`` inherited from ``nn.Linear`` is kept as the
    latent weight.  Every forward call re-quantizes it into
    ``weight_ternary`` (entries in ``{-1, 0, +1}``) and a single scale
    ``weight_alpha``, then evaluates
    ``F.linear(input * alpha, weight_ternary, bias)``.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: if ``True`` the layer has an additive bias.
        cRate: factor applied to ``mean(|weight|)`` to obtain the ternary
            threshold stored in ``weight_delta``.
    """

    def __init__(self, in_features, out_features, bias=True, cRate=0.7):
        super(twnLinear, self).__init__(in_features=in_features,
                                        out_features=out_features,
                                        bias=bias)

        # Quantized copy of ``weight``; it is the tensor actually used in
        # ``forward`` and therefore the one that receives the gradient.
        self.weight_ternary = Parameter(torch.zeros(self.weight.data.size()))
        # Scalar scale applied to the input before the ternary product.
        self.weight_alpha = Parameter(torch.ones(1))
        # Last threshold used for ternarization (plain Python float).
        self.weight_delta = 0

        self.cRate = cRate

    def compute_grad(self):
        """Copy the gradient of the ternary weights onto the latent weights."""
        self.weight.grad = self.weight_ternary.grad

    def forward(self, input):
        """Re-quantize the latent weights and apply the ternary linear map.

        Args:
            input: tensor accepted by ``F.linear`` for this layer.

        Returns:
            ``F.linear(input * alpha, weight_ternary, bias)``.
        """
        latent = self.weight.data
        magnitude = latent.abs()

        # ``mean()`` yields a 0-dim tensor; indexing it with ``[0]`` raises
        # IndexError on PyTorch >= 0.4, so extract the scalar with ``item()``.
        self.weight_delta = self.cRate * \
            magnitude.mean().clamp(min=0, max=10).item()

        self.weight_ternary.data.copy_(
            latent.gt(self.weight_delta).float() -
            latent.lt(-self.weight_delta).float())

        support = self.weight_ternary.data.abs()
        # When no weight exceeds the threshold (e.g. all-zero weights) the
        # support is empty; clamping the count to 1 turns the 0/0 = NaN scale
        # into 0 and leaves every non-degenerate case unchanged.
        alpha = (magnitude * support).sum() / support.sum().clamp(min=1)
        self.weight_alpha.data.copy_(alpha.clamp(min=0, max=10))

        return F.linear(input * self.weight_alpha.data[0], self.weight_ternary,
                        self.bias)
Ejemplo n.º 2
0
class XNORConv2d(Module):
    """2-D convolution with binarized weights and a latent float copy.

    ``fp_weights`` holds the full-precision weights.  On every forward pass
    they are centred, clamped to ``[-1, 1]`` and binarized into
    ``conv.weight`` as ``sign(fp_weights) * mean(|fp_weights|)`` (one mean per
    output channel).  After ``backward``, :meth:`update_gradient` converts the
    gradient of ``conv.weight`` into a gradient for ``fp_weights``.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, groups:
            forwarded unchanged to the wrapped ``Conv2d``.
        bias: whether the wrapped ``Conv2d`` has a bias term.
        dropout_ratio: accepted for interface compatibility; not used by this
            class.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=1,
                 stride=1,
                 padding=0,
                 groups=1,
                 bias=True,
                 dropout_ratio=0):
        super(XNORConv2d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.groups = groups

        # ``bias`` used to be dropped here, so ``bias=False`` still produced
        # a biased convolution; forward it to the wrapped layer.
        self.conv = Conv2d(in_channels=in_channels,
                           out_channels=out_channels,
                           kernel_size=kernel_size,
                           stride=stride,
                           padding=padding,
                           groups=groups,
                           bias=bias)
        self.conv.weight.data.normal_(0, 0.05)
        if self.conv.bias is not None:
            self.conv.bias.data.zero_()

        # Latent full-precision weights, initialised from the conv weights.
        self.fp_weights = Parameter(zeros(self.conv.weight.size()))
        self.fp_weights.data.copy_(self.conv.weight.data)

    def forward(self, x):
        """Binarize the latent weights into ``conv.weight`` and convolve ``x``."""
        # Centre along dim 1, then keep the latent weights inside [-1, 1].
        self.fp_weights.data = self.fp_weights.data - self.fp_weights.data.mean(
            1, keepdim=True)
        self.fp_weights.data.clamp_(-1, 1)
        # Per-output-channel mean absolute value; reused by update_gradient.
        self.mean_val = self.fp_weights.abs().view(self.out_channels,
                                                   -1).mean(1, keepdim=True)

        self.conv.weight.data.copy_(self.fp_weights.data.sign() *
                                    self.mean_val.view(-1, 1, 1, 1))
        x = self.conv(x)

        return x

    def update_gradient(self):
        """Set ``fp_weights.grad`` from ``conv.weight.grad``.

        Must be called after ``forward`` (which sets ``mean_val``) and after a
        backward pass has populated ``conv.weight.grad``.
        """
        latent = self.fp_weights.data
        grad_out = self.conv.weight.grad
        # Work on detached values so the stored gradient does not keep an
        # autograd graph alive.
        scale = self.mean_val.detach().view(-1, 1, 1, 1)

        # Pass the gradient only where the latent weight is non-zero and
        # inside [-1, 1].
        proxy = latent.abs().sign()
        proxy[latent.abs() > 1] = 0
        binary_grad = grad_out * scale * proxy

        # Contribution through the per-channel mean scale.
        sign_w = self.conv.weight.data.sign()
        mean_grad = (sign_w * grad_out).view(self.out_channels,
                                             -1).mean(1).view(-1, 1, 1, 1)
        mean_grad = mean_grad * sign_w

        grad = binary_grad + mean_grad
        # Rescale by the number of elements per filter and by (1 - 1/C),
        # where C is the size of dim 1 of the weight tensor.
        self.fp_weights.grad = grad * latent[0].nelement() * (
            1 - 1 / latent.size(1))
Ejemplo n.º 3
0
class FactorizedSpatialTransformerPyramid2d(SpatialTransformerPyramid2d):
    """Pyramid readout whose feature weights factorize into scale x channel.

    Instead of one weight per (scale, channel, output) triple, this variant
    learns ``feature_scales`` of shape ``(1, scale_n + 1, 1, outdims)`` and
    ``feature_channels`` of shape ``(1, 1, c, outdims)``; :attr:`features`
    exposes their broadcast product.

    Args:
        in_shape: 3-tuple whose first entry is the channel count ``c``.
        outdims: number of outputs.
        scale_n: forwarded to ``Pyramid``; ``scale_n + 1`` scale weights are
            learned per output.
        positive: stored on the instance as ``self.positive``.
        bias: if true, a bias parameter of length ``outdims`` is registered.
        init_range: half-width of the uniform range used to initialise the
            grid.
        downsample: forwarded to ``Pyramid``.
        type: forwarded to ``Pyramid``.
    """

    def __init__(self,
                 in_shape,
                 outdims,
                 scale_n=4,
                 positive=False,
                 bias=True,
                 init_range=.1,
                 downsample=True,
                 type=None):
        # Starts the MRO lookup *after* SpatialTransformerPyramid2d, i.e. the
        # direct parent's own __init__ is skipped.
        super(SpatialTransformerPyramid2d, self).__init__()
        self.in_shape = in_shape
        n_channels, _, _ = in_shape
        self.outdims = outdims
        self.positive = positive
        self.gauss_pyramid = Pyramid(scale_n=scale_n,
                                     downsample=downsample,
                                     type=type)

        # Uninitialised storage; values are set in initialize().
        self.grid = Parameter(torch.Tensor(1, outdims, 1, 2))
        self.feature_scales = Parameter(
            torch.Tensor(1, scale_n + 1, 1, outdims))
        self.feature_channels = Parameter(
            torch.Tensor(1, 1, n_channels, outdims))

        bias_param = Parameter(torch.Tensor(outdims)) if bias else None
        self.register_parameter('bias', bias_param)

        self.init_range = init_range
        self.initialize()

    @property
    def features(self):
        """Broadcast product of both factors, viewed as ``(1, -1, 1, outdims)``."""
        combined = self.feature_scales * self.feature_channels
        return combined.view(1, -1, 1, self.outdims)

    def scale_l1(self, average=True):
        """L1 penalty on ``feature_scales`` (mean if ``average`` else sum)."""
        magnitude = self.feature_scales.abs()
        return magnitude.mean() if average else magnitude.sum()

    def channel_l1(self, average=True):
        """L1 penalty on ``feature_channels`` (mean if ``average`` else sum)."""
        magnitude = self.feature_channels.abs()
        return magnitude.mean() if average else magnitude.sum()

    def initialize(self):
        """Reset the grid uniformly and both factors to ``1 / sqrt(c)``."""
        self.grid.data.uniform_(-self.init_range, self.init_range)

        fill_value = 1 / np.sqrt(self.in_shape[0])
        self.feature_scales.data.fill_(fill_value)
        self.feature_channels.data.fill_(fill_value)

        if self.bias is not None:
            self.bias.data.fill_(0)
Ejemplo n.º 4
0
class NormalVar(Sampler):
	"""Sampler that perturbs its input with zero-mean Gaussian noise.

	A learnable per-channel parameter ``prec`` of shape
	``(1, input_channel, 1, 1)`` is used (through its absolute value) as the
	precision of the noise.  The same parameter is also registered under the
	name ``logvar``.

	Args:
		input_channel: number of channels of the inputs to perturb.
		*args, **kwargs: forwarded to ``Sampler.__init__``.
		init_logvar: initial value every entry of ``prec`` is set to.
	"""
	def __init__(self, input_channel, *args, init_logvar=1, **kwargs):
		super(NormalVar, self).__init__(*args,**kwargs)
		self.prec = Parameter(torch.ones(1,input_channel,1,1)*init_logvar)
		# Same Parameter object exposed under a second name.
		self.register_parameter('logvar',self.prec)
		self.prec.requires_grad=True

	def sample_normal(self,inputs,prec):
		"""Return detached zero-mean noise shaped like ``inputs`` with precision ``|prec|``."""
		noise = ((torch.randn_like(inputs) / ((prec).abs().sqrt()))).detach()
		return noise

	def sample_normal_nat(self, loc, scale):
		"""Return detached zero-mean noise shaped like ``loc`` with variance ``|1 / scale|``.

		The previous implementation computed this value and then returned
		``None``; the computed noise is now returned.
		"""
		# NOTE(review): as in sample_normal, only the noise term is returned;
		# the mean implied by the natural parameters (loc / scale) is not
		# added - confirm this matches the callers' expectation.
		var = 1/scale
		return (torch.randn_like(loc) * var.abs().sqrt()).detach()

	def log_prob_normal(self,state,prec):
		"""Elementwise ``-state**2 * |prec| / 2 + log|prec| / 2``.

		This is the log-density of a zero-mean normal with precision
		``|prec|`` without the constant ``-log(2*pi)/2`` term.
		"""
		logprob = -(((state) ** 2) * ((prec).abs()) / 2) + prec.abs().log() / 2
		return logprob

	def forward(self, inputs, concentration=1):
		"""Add noise to ``inputs`` while training.

		Args:
			inputs: tensor to perturb; it is detached first.
			concentration: multiplier applied to the precision for the
				second ("concentrated") noise sample.

		Returns:
			``(inputs, 0)`` when not training; otherwise ``(inputs + noise,
			logprob)`` where ``logprob`` is the difference of the two
			log-probabilities summed over dims 1-3 and squeezed.
		"""
		inputs = inputs.detach()
		if not self.training:
			return inputs,0
		prec = self.prec.abs()
		conc_prec = prec*concentration
		noise = self.sample_normal(inputs,prec)
		noise_conc= self.sample_normal(inputs,conc_prec)

		logprob = self.log_prob_normal(noise,prec)
		# NOTE(review): the concentrated sample is scored under ``prec``, not
		# ``conc_prec`` - left unchanged, confirm this is intended.
		logprob_conc = self.log_prob_normal(noise_conc,prec)

		output = inputs + noise
		logprob = (logprob-logprob_conc).sum(dim=(1,2,3),keepdim=True).squeeze()
		return output,logprob
class twnConv2d(nn.Conv2d):
    """2-D convolution whose forward pass uses ternary-quantized weights.

    The full-precision ``weight`` inherited from ``nn.Conv2d`` is kept as the
    latent weight.  Every forward call re-quantizes it into
    ``weight_ternary`` (entries in ``{-1, 0, +1}``) and a single scale
    ``weight_alpha``, then evaluates
    ``F.conv2d(input * alpha, weight_ternary, bias, ...)``.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, bias:
            forwarded unchanged to ``nn.Conv2d``.
        cRate: factor applied to ``mean(|weight|)`` to obtain the ternary
            threshold stored in ``weight_delta``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 bias=True,
                 cRate=0.7):
        super(twnConv2d, self).__init__(in_channels=in_channels,
                                        out_channels=out_channels,
                                        kernel_size=kernel_size,
                                        stride=stride,
                                        padding=padding,
                                        bias=bias)
        # Quantized copy of ``weight``; it is the tensor actually used in
        # ``forward`` and therefore the one that receives the gradient.
        self.weight_ternary = Parameter(torch.zeros(self.weight.data.size()))
        # Scalar scale applied to the input before the ternary convolution.
        self.weight_alpha = Parameter(torch.ones(1))
        # Last threshold used for ternarization (plain Python float).
        self.weight_delta = 0

        self.cRate = cRate

    def compute_grad(self):
        """Copy the gradient of the ternary weights onto the latent weights."""
        self.weight.grad = self.weight_ternary.grad

    def forward(self, input):
        """Re-quantize the latent weights and apply the ternary convolution.

        Args:
            input: tensor accepted by ``F.conv2d`` for this layer.

        Returns:
            ``F.conv2d`` of the scaled input with the ternary weights, using
            this layer's bias, stride, padding, dilation and groups.
        """
        latent = self.weight.data
        magnitude = latent.abs()

        # ``mean()`` yields a 0-dim tensor; indexing it with ``[0]`` raises
        # IndexError on PyTorch >= 0.4, so extract the scalar with ``item()``.
        self.weight_delta = self.cRate * \
            magnitude.mean().clamp(min=0, max=10).item()

        self.weight_ternary.data.copy_(
            latent.gt(self.weight_delta).float() -
            latent.lt(-self.weight_delta).float())

        support = self.weight_ternary.data.abs()
        # When no weight exceeds the threshold (e.g. all-zero weights) the
        # support is empty; clamping the count to 1 turns the 0/0 = NaN scale
        # into 0 and leaves every non-degenerate case unchanged.
        alpha = (magnitude * support).sum() / support.sum().clamp(min=1)
        self.weight_alpha.data.copy_(alpha.clamp(min=0, max=10))

        return F.conv2d(input * self.weight_alpha.data[0], self.weight_ternary,
                        self.bias, self.stride, self.padding, self.dilation,
                        self.groups)
Ejemplo n.º 6
0
class ActNorm(nn.Module):
    """Per-channel affine transform ``weight * x + bias``.

    The parameters start as the identity transform (``weight = 1``,
    ``bias = 0``).  With ``data_init=True`` they are instead derived from the
    first minibatch seen by :meth:`forward`, so that this minibatch has zero
    mean and unit standard deviation per channel after the transform.

    Args:
        in_features (int): number of channels; parameters have shape
            ``(in_features, 1, 1)``.
        return_logdet (bool): if True, ``forward`` and ``reverse`` also
            return a log-determinant term.  Default True.
        data_init (bool): initialise from the first minibatch instead of
            keeping the identity transform.  Default False.
    """
    def __init__(self, in_features, return_logdet=True, data_init=False):
        super(ActNorm, self).__init__()
        # identity transform
        self.weight = Parameter(torch.ones(in_features, 1, 1))
        self.bias = Parameter(torch.zeros(in_features, 1, 1))
        self.data_init = data_init
        # Plain attribute (not a buffer): flips to True after the first
        # data-dependent initialisation.
        self.data_initialized = False
        self.return_logdet = return_logdet

    def _init_parameters(self, input):
        """Set ``weight``/``bias`` from the statistics of one minibatch.

        ``input`` is reshaped (B, C, H, W) -> (C, B, H, W) -> (C, B*H*W) and
        the mean and standard deviation are taken per channel.
        """
        n_channels = input.shape[1]
        per_channel = input.transpose(0, 1).contiguous().view(n_channels, -1)
        mean = per_channel.mean(1)
        # Small epsilon keeps the division finite for constant channels.
        std = per_channel.std(1) + 1e-6
        shift = -(mean / std)
        self.bias.data = shift[:, None, None]
        self.weight.data = 1. / std[:, None, None]

    def forward(self, x):
        """Apply the affine transform; optionally also return the logdet.

        The logdet is ``sum(log|weight|)`` multiplied by the sizes of the
        last two dimensions of ``x``.
        """
        needs_init = self.data_init and (not self.data_initialized)
        if needs_init:
            self._init_parameters(x)
            self.data_initialized = True

        out = self.weight * x + self.bias
        if not self.return_logdet:
            return out
        logdet = self.weight.abs().log().sum() * x.shape[-1] * x.shape[-2]
        return out, logdet

    def reverse(self, y):
        """Invert the affine transform: ``(y - bias) / weight``.

        NOTE(review): the logdet returned here is computed exactly as in
        ``forward`` (same sign), not negated for the inverse map - confirm
        how callers combine it.
        """
        x = (y - self.bias) / self.weight
        if not self.return_logdet:
            return x
        logdet = self.weight.abs().log().sum() * y.shape[-1] * y.shape[-2]
        return x, logdet
Ejemplo n.º 7
0
def test_quant_num_grad_align_zero():
    """Cross-check three implementations of zero-aligned fake quantization.

    The autograd implementation (``fake_linear_quant``) provides the
    reference output and gradients.  They are compared against the CUDA
    implementation (``cuda_fake_linear_quant``) and against a hand-written
    numerical gradient built with ``d_lb_ub`` / ``d_x``.
    """
    # TODO: we should add gradients to `clamp` op here
    x = torch.randn(1, 3, 224, 224, requires_grad=True, dtype=DTYPE, device=DEVICE)
    d_qx = torch.randn_like(x).detach()
    lb = Parameter(x.detach().min() + 0.1)
    ub = Parameter(x.detach().max() - 0.1)
    k = 8

    # autograd implementation
    assert ub.detach() - lb.detach() > 1e-2
    qx = fake_linear_quant(x, lb, ub, k, align_zero=True)
    qx.backward(d_qx)

    # ``detach()`` returns a view sharing storage with ``.grad``.  The grads
    # are zeroed in place below, which would also wipe these references and
    # make the comparisons against them vacuous, so take real copies.
    qx_gt = qx.detach()
    d_lb_gt = lb.grad.detach().clone()
    d_ub_gt = ub.grad.detach().clone()
    d_x_gt = x.grad.detach().clone()

    # CUDA numerical implementation
    lb.grad.data.zero_()
    ub.grad.data.zero_()
    x.grad.data.zero_()

    qx = cuda_fake_linear_quant(x, lb, ub, k, align_zero=True)
    qx.backward(d_qx)

    qx_cuda = qx.detach()
    d_lb_cuda = lb.grad.detach()
    d_ub_cuda = ub.grad.detach()
    d_x_cuda = x.grad.detach()

    assert torch.allclose(qx_cuda, qx_gt)
    assert torch.allclose(d_lb_cuda, d_lb_gt)
    assert torch.allclose(d_ub_cuda, d_ub_gt)
    assert torch.allclose(d_x_cuda, d_x_gt)

    # numerical grad implementation
    with torch.no_grad():
        N = torch.tensor(2 ** k - 1, dtype=DTYPE, device=DEVICE)
        delta = ub.sub(lb).div(N)
        # Zero-point index and the bounds re-aligned to the quantization grid.
        z = torch.round(lb.abs().div(delta))
        lb_ = z.neg().mul(delta)
        ub_ = (N - z).mul(delta)
        x_mask = (lb_ <= x) & (x <= ub_)  # pre-compute mask
        x = torch.clamp(x, lb_.item(), ub_.item())
        i = torch.round(x.sub(lb_).div(delta))

        # after forward, calculate cache
        x_sub = x - lb_ - torch.abs(lb)
        d_i = (i - z) - (x_sub / delta)
        d_lb, d_ub = d_lb_ub(d_qx, d_i, N, torch.sign(lb))
        dx = d_x(d_qx, x_mask)

        assert torch.allclose(d_lb_gt, d_lb)
        assert torch.allclose(d_ub_gt, d_ub)
        assert torch.allclose(dx, d_x_gt)
Ejemplo n.º 8
0
class XNORLinear(Module):
    """Linear layer with binarized weights and a latent float copy.

    ``fp_weights`` holds the full-precision weights.  Each forward pass
    centres and clamps them, then writes
    ``sign(fp_weights) * mean(|fp_weights|)`` (one mean per output row) into
    the wrapped ``Linear`` layer.  :meth:`update_gradient` maps the gradient
    of the binarized weights back onto ``fp_weights``.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: whether the wrapped ``Linear`` has a bias; also stored as
            ``self.bias``.
    """
    def __init__(self, in_features, out_features, bias=True):
        super(XNORLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias

        self.linear = Linear(in_features=in_features,
                             out_features=out_features,
                             bias=bias)
        # Latent weights start as a copy of the wrapped layer's weights.
        latent = Parameter(zeros(self.linear.weight.size()))
        latent.data.copy_(self.linear.weight.data)
        self.fp_weights = latent

    def forward(self, x):
        """Binarize the latent weights into ``linear.weight`` and apply it."""
        # Centre each row, then keep the latent weights inside [-1, 1].
        row_mean = self.fp_weights.data.mean(1, keepdim=True)
        self.fp_weights.data = self.fp_weights.data - row_mean
        self.fp_weights.data.clamp_(-1, 1)

        # Per-row mean absolute value; reused by update_gradient.
        rows = self.fp_weights.abs().view(self.out_features, -1)
        self.mean_val = rows.mean(1, keepdim=True)

        binarized = self.fp_weights.data.sign() * self.mean_val.view(-1, 1)
        self.linear.weight.data.copy_(binarized)
        return self.linear(x)

    def update_gradient(self):
        """Set ``fp_weights.grad`` from ``linear.weight.grad``.

        Requires a prior ``forward`` (for ``mean_val``) and a populated
        ``linear.weight.grad``.
        """
        grad_out = self.linear.weight.grad
        weight_sign = self.linear.weight.data.sign()

        # Gradient passes only where the latent weight is non-zero and
        # within [-1, 1].
        mask = self.fp_weights.abs().sign()
        mask[self.fp_weights.data.abs() > 1] = 0
        binary_grad = grad_out * self.mean_val.view(-1, 1) * mask

        # Contribution through the per-row mean scale.
        mean_grad = weight_sign * grad_out
        mean_grad = mean_grad.view(self.out_features, -1).mean(1).view(-1, 1)
        mean_grad = mean_grad * weight_sign

        combined = binary_grad + mean_grad
        row_size = self.fp_weights.data[0].nelement()
        shrink = 1 - 1 / self.fp_weights.data.size(1)
        self.fp_weights.grad = combined * row_size * shrink
        return
Ejemplo n.º 9
0
class GraphSIR(torch.nn.Module):
    """A SIR model on a graph that considers travelling of the infected population.
    """
    def __init__(self, intra_b, intra_k, inter_adj, inter_b, device='cpu'):
        """
        Definition of the coefficients follows the SIR model

        Args:
            intra_b (torch.Tensor): intra-city transmission probability; each city can have a
                            different value, it depends on how crowded the city is
            intra_k (torch.Tensor): intra-city recovering probability; each city can have a
                            different value. Its first dimension defines the number of nodes
            inter_adj (torch.Tensor): an integer tensor of size (# of edges, 2); an undirected
                            graph is used to ensure detailed balance
            inter_b (torch.Tensor): the travelling probability of the infected
            device (str, optional): which device to run this model on, CPU is the default
        """
        super().__init__()

        self.N = intra_k.shape[0]  # number of nodes
        self.intra_b = Parameter(
            intra_b.to(device))  # b: infection probability within the city
        self.intra_k = Parameter(
            intra_k.to(device))  # k: healing probability within the city
        self.inter_adj = inter_adj.to(
            device)  # adjacency list among all the cities in the model
        # inter_b: infection coupling among different cities.  It was the only
        # coefficient left on its original device, which breaks ``forward``
        # with a device mismatch whenever ``device`` is not where ``inter_b``
        # already lives.
        self.inter_b = Parameter(inter_b.to(device))
        self.device = device  # what device to use, "cpu" as default

    def forward(self, t, s):
        """Return the time derivative of the state.

        Args:
            t: time; not used by the computation.
            s: state tensor indexed as ``s[:, 0]`` (susceptible),
                ``s[:, 1]`` (infected); one row per node.

        Returns:
            Tensor of shape ``(N, 3)`` with the derivatives of the three
            compartments.
        """
        dsdt = torch.zeros(self.N, 3, device=self.device)

        # infected travelling from i to j: added at the destination column of
        # the edge list, removed at the origin column
        i_2_j = self.inter_b.abs() * s[self.inter_adj[:, 0], 1]
        di_inter = scatter_add(
            i_2_j, self.inter_adj[:, 1], dim_size=self.N) - scatter_add(
                i_2_j, self.inter_adj[:, 0], dim_size=self.N)

        # same flow in the opposite direction (undirected graph)
        j_2_i = self.inter_b.abs() * s[self.inter_adj[:, 1], 1]
        di_inter += scatter_add(
            j_2_i, self.inter_adj[:, 0], dim_size=self.N) - scatter_add(
                j_2_i, self.inter_adj[:, 1], dim_size=self.N)

        # update the inter-city dependence
        dsdt[:, 1] += di_inter

        # Intra city development; abs() keeps the learned rates non-negative
        ds_intra = -s[:, 0] * s[:, 1] * self.intra_b.abs()
        di_intra = s[:, 0] * s[:, 1] * self.intra_b.abs(
        ) - s[:, 1] * self.intra_k.abs()
        dr_intra = s[:, 1] * self.intra_k.abs()

        # update the intra city dependence
        dsdt[:, 0] += ds_intra
        dsdt[:, 1] += di_intra
        dsdt[:, 2] += dr_intra

        return dsdt
Ejemplo n.º 10
0
class DiffBoundary:
    """Mixin adding learnable quantization bounds ``lb`` / ``ub``.

    The host object must already expose a ``weight`` tensor when
    ``__init__`` runs.  ``lb`` and ``ub`` start at the minimum and maximum of
    ``weight`` and define a uniform grid with ``2**bit_width - 1`` steps.

    Args:
        bit_width: number of bits of the quantization grid.
    """

    def __init__(self, bit_width=4):
        # TODO: add channel-wise option?
        self.bit_width = bit_width
        self.register_boundaries()

    def register_boundaries(self):
        """Create ``lb``/``ub`` as Parameters from the current weight range."""
        assert hasattr(self, "weight")
        values = self.weight.data
        self.lb = Parameter(values.min())
        self.ub = Parameter(values.max())

    def reset_boundaries(self):
        """Overwrite the data of ``lb``/``ub`` with the current weight range."""
        assert hasattr(self, "weight")
        values = self.weight.data
        self.lb.data = values.min()
        self.ub.data = values.max()

    def get_quant_weight(self, align_zero=True):
        """Return the quantized weight, zero-aligned unless ``align_zero`` is false."""
        # TODO: set `align_zero`?
        quantize = (self._get_quant_weight_align_zero
                    if align_zero else self._get_quant_weight)
        return quantize()

    def _get_quant_weight(self):
        """Quantize the detached weight on the grid spanning ``[lb, ub]``."""
        n_steps = 2**self.bit_width - 1
        step = (self.ub - self.lb) / n_steps
        # ``item()`` makes the clamp limits plain numbers, so no gradient
        # reaches the bounds through the clamp itself.
        clipped = torch.clamp(self.weight.detach(), self.lb.item(),
                              self.ub.item())
        # TODO: do we need STE here?
        index = RoundSTE.apply((clipped - self.lb).div(step))
        return (index * step) + self.lb

    def _get_quant_weight_align_zero(self):
        """Quantize the detached weight on a grid shifted to contain zero."""
        # TODO: the original author flagged this routine as unclear.
        n_steps = 2**self.bit_width - 1
        step = (self.ub - self.lb) / n_steps
        # Grid index assigned to zero, and the bounds moved onto the grid.
        zero_index = RoundSTE.apply(self.lb.abs() / step)
        aligned_lb = -zero_index * step
        aligned_ub = (n_steps - zero_index) * step
        clipped = torch.clamp(self.weight.detach(), aligned_lb.item(),
                              aligned_ub.item())
        # NOTE(review): the index is taken relative to the raw ``self.lb``,
        # not the aligned lower bound computed above - confirm this is
        # intended.
        # TODO: do we need STE here?
        index = RoundSTE.apply((clipped - self.lb).div(step))
        return (index - zero_index) * step
Ejemplo n.º 11
0
class Connection(AbstractConnection):
    # language=rst
    """
    Specifies synapses between one or two populations of neurons.
    """
    def __init__(self,
                 source: Nodes,
                 target: Nodes,
                 nu: Optional[Union[float, Sequence[float]]] = None,
                 weight_decay: float = 0.0,
                 **kwargs) -> None:
        # language=rst
        """
        Instantiates a :code:`Connection` object.

        :param source: A layer of nodes from which the connection originates.
        :param target: A layer of nodes to which the connection connects.
        :param nu: Learning rate for both pre- and post-synaptic events.
        :param weight_decay: Constant multiple to decay weights by on each iteration.

        Keyword arguments:

        :param LearningRule update_rule: Modifies connection parameters according to some rule.
        :param torch.Tensor w: Strengths of synapses.
        :param torch.Tensor b: Target population bias.
        :param float wmin: Minimum allowed value on the connection weights.
        :param float wmax: Maximum allowed value on the connection weights.
        :param float norm: Total weight per target neuron normalization constant.
        :param ByteTensor norm_by_max: Normalize the weight of a neuron by its max weight.
        :param ByteTensor norm_by_max_with_shadow_weights: Normalize the weight of a neuron by its max weight by
                                                           original weights.
        """
        super().__init__(source, target, nu, weight_decay, **kwargs)

        w = kwargs.get("w", None)
        if w is None:
            # No initial weights given: draw them uniformly at random.
            if self.wmin == -np.inf or self.wmax == np.inf:
                # At least one bound is open: sample in [0, 1) and clamp.
                w = torch.clamp(torch.rand(source.n, target.n), self.wmin,
                                self.wmax)
            else:
                # Both bounds finite: sample uniformly in [wmin, wmax).
                w = self.wmin + torch.rand(source.n,
                                           target.n) * (self.wmax - self.wmin)
        else:
            if self.wmin != -np.inf or self.wmax != np.inf:
                w = torch.clamp(w, self.wmin, self.wmax)

        # Second positional argument is ``requires_grad``.
        self.w = Parameter(w, False)

        self.b = Parameter(kwargs.get("b", torch.zeros(target.n)), False)

        # NOTE(review): this attribute name is identical to the method
        # ``normalize_by_max_from_shadow_weights`` defined below.  Unless the
        # base class stores an instance attribute under this name, the lookup
        # yields the bound method, which is always truthy - confirm against
        # ``AbstractConnection``.
        if self.norm_by_max_from_shadow_weights:
            self.shadow_w = self.w.clone().detach()
            self.prev_w = self.w.clone().detach()

    def compute(self, s: torch.Tensor) -> torch.Tensor:
        # language=rst
        """
        Compute pre-activations given spikes using connection weights.

        :param s: Incoming spikes.
        :return: Incoming spikes multiplied by synaptic weights (with or without decaying spike activation).
        """
        # Compute multiplication of spike activations by connection weights and add bias.
        post = s.float().view(-1) @ self.w + self.b
        return post.view(*self.target.shape)

    def update(self, **kwargs) -> None:
        # language=rst
        """
        Compute connection's update rule.
        """
        super().update(**kwargs)

    def normalize(self) -> None:
        # language=rst
        """
        Normalize weights so each target neuron has sum of connection weights equal to ``self.norm``.
        """
        if self.norm is not None:
            w_abs_sum = self.w.abs().sum(0).unsqueeze(0)
            # Avoid dividing by zero for targets without any incoming weight.
            w_abs_sum[w_abs_sum == 0] = 1.0
            self.w *= self.norm / w_abs_sum

    def normalize_by_max(self) -> None:
        # language=rst
        """
        Normalize weights by the max weight of the target neuron.
        """
        if self.norm_by_max:
            w_max = self.w.abs().max(0)[0]
            # Avoid dividing by zero for targets whose weights are all zero.
            w_max[w_max == 0] = 1.0
            self.w /= w_max

    def normalize_by_max_from_shadow_weights(self) -> None:
        # language=rst
        """
        Normalize weights by the max weight of the target neuron, computed on the shadow weights.

        The change of ``w`` since the previous call is accumulated into ``shadow_w``; ``w`` is then
        overwritten with ``shadow_w`` divided by its per-target maximum absolute value.
        """
        if self.norm_by_max_from_shadow_weights:
            self.shadow_w += self.w - self.prev_w
            w_max = self.shadow_w.abs().max(0)[0]
            # Avoid dividing by zero for targets whose weights are all zero.
            w_max[w_max == 0] = 1.0
            # ``w`` is a registered Parameter; rebinding the attribute to a
            # plain Tensor is rejected by ``nn.Module.__setattr__`` with a
            # TypeError, so update the existing parameter in place instead.
            self.w.copy_(self.shadow_w / w_max)
            self.prev_w = self.w.clone().detach()

    def reset_(self) -> None:
        # language=rst
        """
        Contains resetting logic for the connection.
        """
        super().reset_()
Ejemplo n.º 12
0
class Connection(AbstractConnection):  # full connection
    # language=rst
    """
    Specifies synapses between one or two populations of neurons.
    """
    def __init__(self,
                 source: Nodes,
                 target: Nodes,
                 nu: Optional[Union[float, Sequence[float]]] = None,
                 reduction: Optional[callable] = None,
                 weight_decay: float = 0.0,
                 **kwargs) -> None:
        # language=rst
        """
        Instantiates a :code:`Connection` object.

        :param source: A layer of nodes from which the connection originates.
        :param target: A layer of nodes to which the connection connects.
        :param nu: Learning rate for both pre- and post-synaptic events.
        :param reduction: Method for reducing parameter updates along the minibatch
            dimension.
        :param weight_decay: Constant multiple to decay weights by on each iteration.

        Keyword arguments:

        :param LearningRule update_rule: Modifies connection parameters according to
            some rule.
        :param torch.Tensor w: Strengths of synapses.
        :param torch.Tensor b: Target population bias.
        :param float wmin: Minimum allowed value on the connection weights.
        :param float wmax: Maximum allowed value on the connection weights.
        :param float norm: Total weight per target neuron normalization constant.
        """
        super().__init__(source, target, nu, reduction, weight_decay, **kwargs)

        # ``w`` is created here as a matrix of shape (source.n, target.n).
        w = kwargs.get("w", None)
        if w is None:  # no initial value supplied for w
            if self.wmin == -np.inf or self.wmax == np.inf:
                # Bounds are not both set: start from random values in
                # [0, 1) and clamp them.
                w = torch.clamp(torch.rand(source.n, target.n), self.wmin,
                                self.wmax)
            else:
                # Both bounds set: sample uniformly between them.
                w = self.wmin + torch.rand(source.n, target.n) * (
                    self.wmax - self.wmin)
        else:
            # Convert once, so a non-tensor ``w`` is handled the same way
            # whether or not clamping is required (previously the conversion
            # only happened on the clamping path).
            w = torch.as_tensor(w)
            if self.wmin != -np.inf or self.wmax != np.inf:
                w = torch.clamp(w, self.wmin, self.wmax)

        self.w = Parameter(w, requires_grad=False)

        b = kwargs.get("b", None)
        if b is not None:
            self.b = Parameter(b, requires_grad=False)
        else:
            self.b = None

        # The spike-window buffer used by ``compute_window`` is only set up
        # for CSRMNodes targets.
        if isinstance(self.target, CSRMNodes):
            self.s_w = None

    def compute(self, s: torch.Tensor) -> torch.Tensor:
        # language=rst
        """
        Compute pre-activations given spikes using connection weights.

        Key function: takes the incoming spikes (from the source layer) and
        returns the values fed into the target layer after multiplication by the
        weights.

        :param s: Incoming spikes.
        :return: Incoming spikes multiplied by synaptic weights (with or without
                 decaying spike activation).
        """
        # Compute multiplication of spike activations by weights and add bias.
        # Spikes are flattened to (batch, -1) before the matrix product.
        if self.b is None:
            post = s.view(s.size(0), -1).float() @ self.w
        else:
            post = s.view(s.size(0), -1).float() @ self.w + self.b
        return post.view(s.size(0), *self.target.shape)

    def compute_window(self, s: torch.Tensor) -> torch.Tensor:
        # language=rst
        """
        Compute pre-activations over a sliding window of recent spikes.

        :param s: Incoming spikes for the current step.
        :return: Windowed spikes multiplied by synaptic weights, shaped
                 ``(batch, res_window_size, *target.shape)``.
        """

        # Identity check: ``== None`` on a tensor goes through tensor
        # comparison machinery and is not the intended test.
        if self.s_w is None:
            # Construct a matrix of shape batch size * window size * dimension of layer
            self.s_w = torch.zeros(self.target.batch_size,
                                   self.target.res_window_size,
                                   *self.source.shape)

        # Add the spike vector into the first in first out matrix of windowed (res) spike trains
        self.s_w = torch.cat((self.s_w[:, 1:, :], s[:, None, :]), 1)

        # Compute multiplication of spike activations by weights and add bias.
        if self.b is None:
            post = (self.s_w.view(self.s_w.size(0), self.s_w.size(1),
                                  -1).float() @ self.w)
        else:
            post = (self.s_w.view(self.s_w.size(0), self.s_w.size(1),
                                  -1).float() @ self.w + self.b)

        return post.view(self.s_w.size(0), self.target.res_window_size,
                         *self.target.shape)

    def update(self, **kwargs) -> None:
        # language=rst
        """
        Compute connection's update rule.
        """
        super().update(**kwargs)

    def normalize(self) -> None:
        # language=rst
        """
        Normalize weights so each target neuron has sum of connection weights equal to
        ``self.norm``.
        """
        if self.norm is not None:
            w_abs_sum = self.w.abs().sum(0).unsqueeze(0)
            # Avoid dividing by zero for targets without any incoming weight.
            w_abs_sum[w_abs_sum == 0] = 1.0
            self.w *= self.norm / w_abs_sum

    def reset_state_variables(self) -> None:
        # language=rst
        """
        Contains resetting logic for the connection.
        """
        super().reset_state_variables()
Ejemplo n.º 13
0
class PointPooled2d(Readout):
    """
    Readout that samples a pyramid of average-pooled feature maps at one learned
    point per output unit.

    Each output unit owns an (x, y) location in ``grid`` and a weight vector in
    ``features``. The input feature map and ``pool_steps`` successively pooled
    copies of it are sampled at that location with ``F.grid_sample``; the samples
    are weighted by ``features``, summed over channels and optionally offset by
    ``bias``.
    """
    def __init__(
        self,
        in_shape,
        outdims,
        pool_steps,
        bias,
        pool_kern,
        init_range,
        align_corners=True,
        mean_activity=None,
        feature_reg_weight=1.0,
        gamma_readout=None,  # deprecated, use feature_reg_weight instead
        **kwargs,
    ):
        """
        This readout learns a point in the core feature space for each neuron, with help of torch.grid_sample, that best
        predicts its response. Multiple average pooling steps are applied to reduce the search space in each stage and
        thereby reach the best prediction point faster.

        The readout receives the shape of the core as 'in_shape', number of pooling stages to be performed as 'pool_steps', the kernel size and stride length
        to be used for pooling as 'pool_kern', the number of units/neurons being predicted as 'outdims', 'bias' specifying whether
        or not bias term is to be used and 'init_range' range for initialising the grid with uniform distribution, U(-init_range,init_range).
        The grid parameter contains the normalized locations (x, y coordinates in the core feature space) and is clipped to [-1, 1] as it is a
        requirement of the torch.grid_sample function. The feature parameter learns the best linear mapping from the pooled feature
        map at a given location to a unit's response.

        Args:
            in_shape (list): shape of the input feature map [channels, width, height]
            outdims (int): number of output units
            pool_steps (int): number of pooling stages
            bias (bool): adds a bias term
            pool_kern (int): filter size and stride length used for pooling the feature map
            init_range (float): initialises the grid with Uniform([-init_range, init_range])
                                [expected: positive value <=1]
            align_corners (bool): Keyword argument to grid_sample for bilinear interpolation.
                It changed behavior in PyTorch 1.3. The default of align_corners = True is setting the
                behavior to pre PyTorch 1.3 functionality for comparability.
            mean_activity: stored on the instance and passed to ``initialize``, which hands it to
                ``initialize_bias`` (defined on the base class) when a bias is used.
            feature_reg_weight (float): factor applied to the L1 feature penalty in ``regularizer``.
            gamma_readout: deprecated alias of ``feature_reg_weight``; both values are resolved through
                ``resolve_deprecated_gamma_readout`` (defined on the base class).
            **kwargs: accepted and ignored by this constructor.

        Raises:
            ValueError: if ``init_range`` is not in the interval (0, 1].
        """
        super().__init__()
        if init_range > 1.0 or init_range <= 0.0:
            raise ValueError("init_range is not within required limit!")
        self._pool_steps = pool_steps
        self.in_shape = in_shape
        c, w, h = in_shape
        self.outdims = outdims
        self.feature_reg_weight = self.resolve_deprecated_gamma_readout(
            feature_reg_weight, gamma_readout)
        self.mean_activity = mean_activity
        self.grid = Parameter(torch.Tensor(
            1, outdims, 1, 2))  # x-y coordinates for each neuron
        self.features = Parameter(
            torch.Tensor(1, c * (self._pool_steps + 1), 1, outdims)
        )  # weight matrix mapping the core features to the output units

        if bias:
            bias = Parameter(torch.Tensor(outdims))
            self.register_parameter("bias", bias)
        else:
            self.register_parameter("bias", None)

        self.pool_kern = pool_kern
        self.avg = nn.AvgPool2d(
            (pool_kern, pool_kern), stride=pool_kern, count_include_pad=False
        )  # setup kernel of size=[pool_kern,pool_kern] with stride=pool_kern
        self.init_range = init_range
        self.align_corners = align_corners
        self.initialize(mean_activity)

    @property
    def pool_steps(self):
        """Number of pooling stages applied on top of the raw feature map."""
        return self._pool_steps

    @pool_steps.setter
    def pool_steps(self, value):
        """
        Change the number of pooling stages.

        When the value differs from the current one, ``features`` is replaced by a
        freshly allocated parameter of the matching size, filled with
        ``1 / in_shape[0]`` (previously learned feature weights are discarded).
        """
        assert value >= 0 and int(
            value
        ) - value == 0, "new pool steps must be a non-negative integer"
        if value != self._pool_steps:
            logger.info("Resizing readout features")
            c, w, h = self.in_shape
            self._pool_steps = int(value)
            # Allocate on the device/dtype of the parameter being replaced so a
            # module that was moved (e.g. with .cuda()) stays consistent.
            old = self.features
            self.features = Parameter(
                torch.empty(1,
                            c * (self._pool_steps + 1),
                            1,
                            self.outdims,
                            dtype=old.dtype,
                            device=old.device))
            self.features.data.fill_(1 / self.in_shape[0])

    def initialize(self, mean_activity=None):
        """
        Initialize function initialises the grid, features or weights and bias terms.

        Args:
            mean_activity: passed to ``initialize_bias`` when a bias is used;
                falls back to ``self.mean_activity`` when None.
        """
        if mean_activity is None:
            mean_activity = self.mean_activity
        self.grid.data.uniform_(-self.init_range, self.init_range)
        self.features.data.fill_(1 / self.in_shape[0])
        if self.bias is not None:
            self.initialize_bias(mean_activity=mean_activity)

    def feature_l1(self, reduction="sum", average=None):
        """
        Returns l1 regularization term for features.
        Args:
            average(bool): Deprecated (see reduction) if True, use mean of weights for regularization
            reduction(str): Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'
        """
        return self.apply_reduction(self.features.abs(),
                                    reduction=reduction,
                                    average=average)

    def regularizer(self, reduction="sum", average=None):
        """Return ``feature_l1`` scaled by ``self.feature_reg_weight``."""
        return self.feature_l1(reduction=reduction,
                               average=average) * self.feature_reg_weight

    def forward(self, x, shift=None, out_idx=None, **kwargs):
        """
        Propagates the input forwards through the readout
        Args:
            x: input data; must be 4-dimensional with trailing dimensions equal to ``in_shape``
            shift: shifts the location of the grid (from eye-tracking data)
            out_idx: index of neurons to be predicted; a boolean numpy array is
                converted to integer indices

        Returns:
            y: neuronal activity, viewed as (N, number of selected output units)

        Raises:
            ValueError: if the trailing dimensions of ``x`` differ from ``in_shape``.
        """
        # grid_sample expects normalized coordinates in [-1, 1].
        self.grid.data = torch.clamp(self.grid.data, -1, 1)
        N, c, w, h = x.size()
        c_in, w_in, h_in = self.in_shape
        if [c_in, w_in, h_in] != [c, w, h]:
            raise ValueError(
                "the specified feature map dimension is not the readout's expected input dimension"
            )

        m = self.pool_steps + 1  # the input feature is considered the first pooling stage
        feat = self.features.view(1, m * c, self.outdims)
        grid, bias, outdims = self.grid, self.bias, self.outdims
        if out_idx is not None:
            if isinstance(out_idx, np.ndarray) and out_idx.dtype == bool:
                out_idx = np.where(out_idx)[0]
            feat = feat[:, :, out_idx]
            grid = grid[:, out_idx]
            if bias is not None:
                bias = bias[out_idx]
            outdims = len(out_idx)

        grid = grid.expand(N, outdims, 1, 2)
        if shift is not None:
            # shift grid based on shifter network's prediction
            grid = grid + shift[:, None, None, :]

        pools = [F.grid_sample(x, grid, align_corners=self.align_corners)]
        for _ in range(self.pool_steps):
            _, _, w_pool, h_pool = x.size()
            if w_pool * h_pool == 1:
                warnings.warn(
                    "redundant pooling steps: pooled feature map size is already 1X1, consider reducing it"
                )
            x = self.avg(x)
            pools.append(
                F.grid_sample(x, grid, align_corners=self.align_corners))
        y = torch.cat(pools, dim=1)
        y = (y.squeeze(-1) * feat).sum(1).view(N, outdims)

        if bias is not None:
            y = y + bias
        return y

    def __repr__(self):
        """Describe input shape, output size, bias usage, pooling steps and child modules."""
        c, w, h = self.in_shape
        r = self.__class__.__name__ + " (" + "{} x {} x {}".format(
            c, w, h) + " -> " + str(self.outdims) + ")"
        if self.bias is not None:
            r += " with bias"
        r += " and pooling for {} steps\n".format(self.pool_steps)
        for ch in self.children():
            r += "  -> " + ch.__repr__() + "\n"
        return r
Ejemplo n.º 14
0
class Connection(AbstractConnection):
    # language=rst
    """
    Connection between a source and a target population whose synapses are held
    in one weight matrix ``w`` (``source.n`` x ``target.n`` when generated here)
    plus an optional bias ``b``.
    """
    def __init__(self,
                 source: Nodes,
                 target: Nodes,
                 nu: Optional[Union[float, Sequence[float]]] = None,
                 reduction: Optional[callable] = None,
                 weight_decay: float = 0.0,
                 **kwargs) -> None:
        # language=rst
        """
        Build a :code:`Connection`, creating or clamping its weights.

        :param source: A layer of nodes from which the connection originates.
        :param target: A layer of nodes to which the connection connects.
        :param nu: Learning rate for both pre- and post-synaptic events.
        :param reduction: Method for reducing parameter updates along the minibatch
            dimension.
        :param weight_decay: Constant multiple to decay weights by on each iteration.

        Keyword arguments:

        :param LearningRule update_rule: Modifies connection parameters according to
            some rule.
        :param torch.Tensor w: Strengths of synapses. When omitted, uniform random
            weights are drawn and either clamped or rescaled to
            ``[wmin, wmax]``.
        :param torch.Tensor b: Target population bias.
        :param float wmin: Minimum allowed value on the connection weights.
        :param float wmax: Maximum allowed value on the connection weights.
        :param float norm: Total weight per target neuron normalization constant.
        """
        super().__init__(source, target, nu, reduction, weight_decay, **kwargs)

        weights = kwargs.get("w")
        if weights is not None:
            # User-supplied weights are only clamped when some bound is finite.
            if self.wmin != -np.inf or self.wmax != np.inf:
                weights = torch.clamp(weights, self.wmin, self.wmax)
        else:
            weights = torch.rand(source.n, target.n)
            if self.wmin == -np.inf or self.wmax == np.inf:
                # At least one bound is infinite: keep U(0, 1) and clamp.
                weights = torch.clamp(weights, self.wmin, self.wmax)
            else:
                # Both bounds finite: stretch U(0, 1) onto [wmin, wmax].
                weights = self.wmin + weights * (self.wmax - self.wmin)

        self.w = Parameter(weights, requires_grad=False)

        bias = kwargs.get("b")
        self.b = None if bias is None else Parameter(bias,
                                                     requires_grad=False)

    def compute(self, s: torch.Tensor) -> torch.Tensor:
        # language=rst
        """
        Turn incoming spikes into pre-activations of the target population.

        :param s: Incoming spikes; flattened to ``(batch, -1)`` and cast to float.
        :return: Flattened spikes multiplied by ``self.w`` (plus ``self.b`` when a
            bias is set), viewed as ``(batch, *target.shape)``.
        """
        batch = s.size(0)
        post = s.view(batch, -1).float() @ self.w
        if self.b is not None:
            post = post + self.b
        return post.view(batch, *self.target.shape)

    def update(self, **kwargs) -> None:
        # language=rst
        """
        Run the connection's update rule by delegating to the parent class.
        """
        super().update(**kwargs)

    def normalize(self) -> None:
        # language=rst
        """
        Rescale ``self.w`` in place so the absolute weights along dimension 0 sum
        to ``self.norm``; all-zero slices are left untouched. No-op when
        ``self.norm`` is ``None``.
        """
        if self.norm is None:
            return
        totals = self.w.abs().sum(0, keepdim=True)
        totals[totals == 0] = 1.0
        self.w *= self.norm / totals

    def reset_state_variables(self) -> None:
        # language=rst
        """
        Reset connection state by delegating to the parent class.
        """
        super().reset_state_variables()
Ejemplo n.º 15
0
class PointPyramid2d(Readout):
    """
    Readout that samples every level produced by a ``Pyramid`` module at one
    learned point per output unit and combines the samples linearly.
    """
    def __init__(
        self,
        in_shape,
        outdims,
        scale_n,
        positive,
        bias,
        init_range,
        downsample,
        type,
        align_corners=True,
        mean_activity=None,
        feature_reg_weight=1.0,
        gamma_readout=None,  # deprecated, use feature_reg_weight instead
        **kwargs,
    ):
        """
        Args:
            in_shape: shape of the input feature map, unpacked as (c, w, h)
            outdims (int): number of output units
            scale_n: forwarded to ``Pyramid``; ``features`` holds
                ``c * (scale_n + 1)`` weights per output unit
            positive (bool): if True, ``features`` is clamped to be non-negative
                at the start of every forward pass
            bias (bool): adds a bias term
            init_range (float): the grid is initialised with
                Uniform([-init_range, init_range])
            downsample: forwarded to ``Pyramid``
            type: forwarded to ``Pyramid``
            align_corners (bool): forwarded to ``F.grid_sample``
            mean_activity: passed to ``initialize`` for bias initialisation
            feature_reg_weight (float): factor applied in ``regularizer``
            gamma_readout: deprecated alias of ``feature_reg_weight``, resolved via
                ``resolve_deprecated_gamma_readout`` (defined on the base class)
            **kwargs: accepted and ignored by this constructor
        """
        super().__init__()
        self.in_shape = in_shape
        c, w, h = in_shape
        self.outdims = outdims
        self.positive = positive
        self.feature_reg_weight = self.resolve_deprecated_gamma_readout(
            feature_reg_weight, gamma_readout)
        self.mean_activity = mean_activity
        self.gauss_pyramid = Pyramid(scale_n=scale_n,
                                     downsample=downsample,
                                     type=type)
        # One (x, y) location and one weight vector per output unit.
        self.grid = Parameter(torch.Tensor(1, outdims, 1, 2))
        self.features = Parameter(
            torch.Tensor(1, c * (scale_n + 1), 1, outdims))
        self.register_parameter(
            "bias",
            Parameter(torch.Tensor(outdims)) if bias else None)
        self.init_range = init_range
        self.align_corners = align_corners
        self.initialize(mean_activity)

    def initialize(self, mean_activity=None):
        """
        Draw the grid uniformly from [-init_range, init_range], fill the features
        with ``1 / in_shape[0]`` and, when a bias exists, initialise it through
        ``initialize_bias`` using ``mean_activity`` (or ``self.mean_activity``
        when None is given).
        """
        activity = self.mean_activity if mean_activity is None else mean_activity
        self.grid.data.uniform_(-self.init_range, self.init_range)
        self.features.data.fill_(1 / self.in_shape[0])
        if self.bias is not None:
            self.initialize_bias(mean_activity=activity)

    def group_sparsity(self, group_size):
        """
        Sum, over consecutive groups of ``group_size`` feature channels, of the
        mean root-mean-square weight of the group, each divided by the number of
        full groups (``features.size(1) // group_size``).
        """
        n_features = self.features.size(1)
        n_groups = n_features // group_size
        return sum(
            (self.features[:, start:start + group_size, ...].pow(2).mean(1) +
             1e-12).sqrt().mean() / n_groups
            for start in range(0, n_features, group_size))

    def feature_l1(self, reduction="sum", average=None):
        """
        L1 penalty on the features, reduced by ``apply_reduction`` (defined on the
        base class) according to ``reduction`` / the deprecated ``average``.
        """
        magnitudes = self.features.abs()
        return self.apply_reduction(magnitudes,
                                    reduction=reduction,
                                    average=average)

    def regularizer(self, reduction="sum", average=None):
        """Return ``feature_l1`` scaled by ``self.feature_reg_weight``."""
        penalty = self.feature_l1(reduction=reduction, average=average)
        return penalty * self.feature_reg_weight

    def forward(self, x, shift=None):
        """
        Args:
            x: 4-dimensional input, unpacked as (N, c, w, h)
            shift: optional per-sample offset added to every grid location

        Returns:
            Tensor viewed as (N, outdims).
        """
        if self.positive:
            self.features.data.clamp_min_(0)
        # grid_sample expects normalized coordinates in [-1, 1].
        self.grid.data = self.grid.data.clamp(-1, 1)
        N, c, w, h = x.size()
        n_levels = self.gauss_pyramid.scale_n + 1
        feat = self.features.view(1, n_levels * c, self.outdims)

        grid = self.grid.expand(N, self.outdims, 1, 2)
        if shift is not None:
            grid = grid + shift[:, None, None, :]

        samples = [
            F.grid_sample(level, grid, align_corners=self.align_corners)
            for level in self.gauss_pyramid(x)
        ]
        stacked = torch.cat(samples, dim=1).squeeze(-1)
        y = (stacked * feat).sum(1).view(N, self.outdims)

        return y if self.bias is None else y + self.bias

    def __repr__(self):
        """Describe input shape, output size, bias usage and child modules."""
        c, w, h = self.in_shape
        text = f"{self.__class__.__name__} ({c} x {w} x {h} -> {self.outdims})"
        if self.bias is not None:
            text += " with bias"

        text += "".join("  -> " + repr(child) + "\n"
                        for child in self.children())
        return text
Ejemplo n.º 16
0
class DynamicConnection(AbstractConnection):
    # language=rst
    """
    Specifies synapses between one or two populations of neurons.
    The weight matrix is allowed to rewire dynamically: :meth:`sp` creates and
    prunes synapses, where a weight of exactly zero means "no synapse".
    """
    def __init__(self,
                 source: Nodes,
                 target: Nodes,
                 nu: Optional[Union[float, Sequence[float]]] = None,
                 reduction: Optional[callable] = None,
                 weight_decay: float = 0.0,
                 **kwargs) -> None:
        # language=rst
        """
        Instantiates a :code:`DynamicConnection` object.

        :param source: A layer of nodes from which the connection originates.
        :param target: A layer of nodes to which the connection connects.
        :param nu: Learning rate for both pre- and post-synaptic events.
        :param reduction: Method for reducing parameter updates along the minibatch
            dimension.
        :param weight_decay: Constant multiple to decay weights by on each iteration.

        Keyword arguments:

        :param LearningRule update_rule: Modifies connection parameters according to
            some rule.
        :param torch.Tensor w: Strengths of synapses.
        :param torch.Tensor b: Target population bias; defaults to zeros of
            length ``target.n``.
        :param float wmin: Minimum allowed value on the connection weights.
        :param float wmax: Maximum allowed value on the connection weights.
        :param float norm: Total weight per target neuron normalization constant.
        :param prune_thresh: Weight threshold for pruning (default ``0.0``, off)
        :param prune_prob: Probability for pruning (default ``0.0``, off)
        :param create_prob: Probability for probabilistic synaptogenesis
            (default ``0.0``, off)
        :param create: Enable activity dependent synaptogenesis (default ``False``)
        """
        super().__init__(source, target, nu, reduction, weight_decay, **kwargs)

        w = kwargs.get("w", None)
        prune_thresh = kwargs.get("prune_thresh", 0.0)
        prune_prob = kwargs.get("prune_prob", 0.0)
        create_prob = kwargs.get("create_prob", 0.0)
        create = kwargs.get("create", False)

        if w is None:
            if self.wmin == -np.inf or self.wmax == np.inf:
                w = torch.clamp(torch.rand(source.n, target.n), self.wmin,
                                self.wmax)
            else:
                w = self.wmin + torch.rand(source.n,
                                           target.n) * (self.wmax - self.wmin)
        else:
            if self.wmin != -np.inf or self.wmax != np.inf:
                w = torch.clamp(w, self.wmin, self.wmax)

        self.w = Parameter(w, requires_grad=False)

        # An explicit ``b=None`` must behave like an omitted bias; wrapping None
        # in a Parameter would yield an empty tensor that cannot be added in
        # ``compute``.
        b = kwargs.get("b", None)
        if b is None:
            b = torch.zeros(target.n)
        self.b = Parameter(b, requires_grad=False)

        self.prune_thresh = prune_thresh
        self.prune_prob = prune_prob
        self.create_prob = create_prob
        self.create = create

    def compute(self, s: torch.Tensor) -> torch.Tensor:
        # language=rst
        """
        Compute pre-activations given spikes using connection weights.

        :param s: Incoming spikes.
        :return: Incoming spikes multiplied by synaptic weights plus bias, viewed
            as ``(batch, *target.shape)``.
        """
        # Compute multiplication of spike activations by weights and add bias.
        post = s.float().view(s.size(0), -1) @ self.w + self.b
        return post.view(s.size(0), *self.target.shape)

    def update(self, **kwargs) -> None:
        # language=rst
        """
        Compute connection's update rule.

        Only the regular (functional) plasticity rule of the parent class is run
        here; structural plasticity is performed separately by :meth:`sp`.
        """

        # call regular functional plasticity rule
        super().update(**kwargs)

    def sp(self) -> Tuple:
        # language=rst
        """
        Runs structural plasticity.

        Up to four mechanisms are applied, in this order, each only when enabled:

        1. probabilistic synaptogenesis (``create_prob > 0``),
        2. activity dependent synaptogenesis (``create`` is truthy),
        3. threshold pruning (``prune_thresh > 0``),
        4. probabilistic pruning (``prune_prob > 0``).

        :return: ``(total_conns_created, total_conns_pruned)``. Each entry counts
            the weights that actually changed between zero and non-zero; it is the
            integer ``0`` when no contributing mechanism ran and a 0-dim tensor
            otherwise.
        """
        total_conns_created = 0
        total_conns_pruned = 0

        # --- Synaptogenesis mechanisms -------------------------------------

        if self.create_prob > 0.0:
            # Probabilistic synaptogenesis: every absent synapse (weight == 0)
            # is created with probability ``create_prob``.
            weights = self.w.data
            selected = torch.rand(weights.shape,
                                  device=weights.device) < self.create_prob
            new_conns = selected & (weights == 0.0)

            # All synapses created in one call share a single random strength.
            weights[new_conns] = 0.3 * (np.random.uniform(
                self.wmin, self.wmax))

            # Count only the synapses that were really created, not every
            # selected entry (existing synapses are left untouched).
            total_conns_created += new_conns.sum()

        if self.create:
            # Activity dependent synaptogenesis: connect recently active source
            # neurons to recently active target neurons.
            batch_size = self.source.batch_size

            source_x = self.source.x.view(batch_size, -1).unsqueeze(2)
            target_x = self.target.x.view(batch_size, -1).unsqueeze(1)

            # Neurons whose trace exceeds 0.5 count as recently 'active'.
            source_active = source_x.data > 0.50
            target_active = target_x.data > 0.50

            # Only proceed when both populations have at least one active neuron.
            if source_active.any() and target_active.any():
                weights = self.w.data

                # Candidate strengths in 0.3 * U(wmin, wmax); zeroed wherever a
                # synapse already exists.
                weight_mask = torch.empty(
                    weights.shape[0],
                    weights.shape[1],
                    dtype=torch.float32,
                    device=weights.device).uniform_(self.wmin,
                                                    self.wmax) * 0.3
                weight_mask[weights != 0.0] = 0.0

                # Zero the rows/columns of inactive neurons.
                # NOTE(review): only batch element 0 decides which neurons are
                # active here, although the check above looks at the whole
                # batch - confirm this is intended.
                source_inactive = torch.logical_not(source_active)
                target_inactive = torch.logical_not(target_active)
                weight_mask[source_inactive[0, :, 0], :] = 0.0
                weight_mask[:, target_inactive[0, 0, :]] = 0.0

                # Adding the mask sets new weights only where both traces are
                # above threshold and the weight was previously zero.
                self.w.data += weight_mask

                total_conns_created += (weight_mask != 0.0).sum()

        # --- Connection pruning mechanisms ---------------------------------

        if self.prune_thresh > 0.0:
            # Threshold pruning: remove existing synapses (positive or negative)
            # whose magnitude is below the threshold.
            weights = self.w.data
            weak = (weights != 0.0) & (weights.abs() < self.prune_thresh)
            weights[weak] = 0.0

            total_conns_pruned += weak.sum()

        if self.prune_prob > 0.0:
            # Probabilistic pruning: every entry is selected with probability
            # ``prune_prob``; selected existing synapses are removed.
            weights = self.w.data
            selected = torch.rand(weights.shape,
                                  device=weights.device) < self.prune_prob
            pruned = selected & (weights != 0.0)
            weights[pruned] = 0.0

            # Count only synapses that existed before, not every selected entry.
            total_conns_pruned += pruned.sum()

        return (total_conns_created, total_conns_pruned)

    def normalize(self) -> None:
        # language=rst
        """
        Normalize weights so each target neuron has sum of connection weights equal to
        ``self.norm``.

        Target neurons without any incoming weight are left at zero.
        """
        if self.norm is not None:
            w_abs_sum = self.w.abs().sum(0).unsqueeze(0)
            w_abs_sum[w_abs_sum == 0] = 1.0
            self.w *= self.norm / w_abs_sum

    def reset_state_variables(self) -> None:
        # language=rst
        """
        Contains resetting logic for the connection.
        """
        super().reset_state_variables()
Ejemplo n.º 17
0
class SpatialTransformerPyramid3d(nn.Module):
    """
    Readout for 5-dimensional inputs ``(N, c, t, w, h)``: every time step is
    passed through a ``Pyramid`` module and each level is sampled at one learned
    point per output unit.
    """
    def __init__(self, in_shape, outdims, scale_n=4, positive=True, bias=True, init_range=.05, downsample=True,
                 _skip_upsampling=False, type=None):
        """
        Args:
            in_shape: input shape, unpacked as (c, t, w, h)
            outdims (int): number of output units
            scale_n: forwarded to ``Pyramid``; ``features`` holds
                ``c * (scale_n + 1)`` weights per output unit
            positive (bool): if True, ``positive(self.features)`` is called at the
                start of every forward pass (presumably constrains the features
                to be non-negative - helper defined elsewhere, confirm)
            bias (bool): adds a bias term
            init_range (float): the grid is initialised with
                Uniform([-init_range, init_range])
            downsample, _skip_upsampling, type: forwarded to ``Pyramid``
        """
        super().__init__()
        self.in_shape = in_shape
        c, _, w, h = in_shape
        self.outdims = outdims
        self.positive = positive
        self.gauss = Pyramid(scale_n=scale_n, downsample=downsample, _skip_upsampling=_skip_upsampling, type=type)

        # One (x, y) location and one weight vector per output unit.
        self.grid = Parameter(torch.Tensor(1, outdims, 1, 2))
        self.features = Parameter(torch.Tensor(1, c * (scale_n + 1), 1, outdims))

        if bias:
            bias = Parameter(torch.Tensor(outdims))
            self.register_parameter('bias', bias)
        else:
            self.register_parameter('bias', None)
        self.init_range = init_range
        self.initialize()

    def initialize(self):
        """Draw the grid uniformly, fill features with ``1 / in_shape[0]`` and zero the bias."""
        self.grid.data.uniform_(-self.init_range, self.init_range)
        self.features.data.fill_(1 / self.in_shape[0])

        if self.bias is not None:
            self.bias.data.fill_(0)

    def feature_l1(self, average=True, subs_idx=None):
        """
        L1 penalty on the features.

        Args:
            average (bool): return the mean of the absolute features if True,
                their sum otherwise
            subs_idx: must be None; subsampling is not supported

        Raises:
            NotImplementedError: if ``subs_idx`` is given.
        """
        # ``NotImplemented`` is a comparison sentinel and not callable; the
        # exception type is ``NotImplementedError``.
        if subs_idx is not None:
            raise NotImplementedError('Subsample is not implemented.')

        if average:
            return self.features.abs().mean()
        else:
            return self.features.abs().sum()

    def forward(self, x, shift=None, subs_idx=None):
        """
        Args:
            x: 5-dimensional input, unpacked as (N, c, t, w, h)
            shift: optional offsets, indexed as ``shift[:, i, :]`` for time step
                ``i`` and added to every grid location
            subs_idx: must be None; subsampling is not supported

        Returns:
            Tensor viewed as (N, t, outdims).

        Raises:
            NotImplementedError: if ``subs_idx`` is given.
        """
        if subs_idx is not None:
            raise NotImplementedError('Subsample is not implemented.')

        if self.positive:
            positive(self.features)
        # grid_sample expects normalized coordinates in [-1, 1].
        self.grid.data = torch.clamp(self.grid.data, -1, 1)
        N, c, t, w, h = x.size()
        m = self.gauss.scale_n + 1
        feat = self.features.view(1, m * c, self.outdims)

        if shift is None:
            grid = self.grid.expand(N * t, self.outdims, 1, 2)
        else:
            # One shifted grid per time step, then fold time into the batch axis.
            grid = self.grid.expand(N, self.outdims, 1, 2)
            grid = torch.stack([grid + shift[:, i, :][:, None, None, :] for i in range(t)], 1)
            grid = grid.contiguous().view(-1, self.outdims, 1, 2)

        # Fold time into the batch axis so every frame is a 2-d feature map.
        z = x.contiguous().transpose(2, 1).contiguous().view(-1, c, w, h)
        pools = [F.grid_sample(level, grid) for level in self.gauss(z)]
        y = torch.cat(pools, dim=1).squeeze(-1)
        y = (y * feat).sum(1).view(N, t, self.outdims)

        if self.bias is not None:
            y = y + self.bias
        return y

    def __repr__(self):
        """Describe input shape (without the time axis), output size, bias usage and child modules."""
        c, t, w, h = self.in_shape
        r = self.__class__.__name__ + \
            ' (' + '{} x {} x {}'.format(c, w, h) + ' -> ' + str(self.outdims) + ')'
        if self.bias is not None:
            r += ' with bias'

        for ch in self.children():
            r += '\n  -> ' + ch.__repr__()
        return r
Ejemplo n.º 18
0
class SpatialTransformerPooled2d(nn.Module):
    """
    Readout that samples the input feature map and ``pool_steps`` successively
    average-pooled copies of it at one learned point per output unit, then
    combines the samples linearly.
    """
    def __init__(self, in_shape, outdims, pool_steps=1, positive=False, bias=True,
                 pool_kern=2, init_range=.1):
        """
        Args:
            in_shape: shape of the input feature map, unpacked as (c, w, h)
            outdims (int): number of output units
            pool_steps (int): number of pooling stages on top of the raw input
            positive (bool): if True, ``positive(self.features)`` is called at the
                start of every forward pass (helper defined elsewhere)
            bias (bool): adds a bias term
            pool_kern (int): kernel size and stride of the average pooling
            init_range (float): the grid is initialised with
                Uniform([-init_range, init_range])
        """
        super().__init__()
        self.pool_steps = pool_steps
        self.in_shape = in_shape
        n_channels, _, _ = in_shape
        self.outdims = outdims
        self.positive = positive
        # One (x, y) location and one weight vector per output unit.
        self.grid = Parameter(torch.Tensor(1, outdims, 1, 2))
        self.features = Parameter(torch.Tensor(1, n_channels * (self.pool_steps + 1), 1, outdims))
        self.register_parameter('bias', Parameter(torch.Tensor(outdims)) if bias else None)

        self.pool_kern = pool_kern
        self.avg = nn.AvgPool2d((pool_kern, pool_kern), stride=pool_kern, count_include_pad=False)
        self.init_range = init_range
        self.initialize()

    def initialize(self):
        """Draw the grid uniformly, fill features with ``1 / in_shape[0]`` and zero the bias."""
        self.grid.data.uniform_(-self.init_range, self.init_range)
        self.features.data.fill_(1 / self.in_shape[0])
        if self.bias is None:
            return
        self.bias.data.fill_(0)

    def feature_l1(self, average=True):
        """Mean (``average=True``) or sum of the absolute feature weights."""
        magnitudes = self.features.abs()
        return magnitudes.mean() if average else magnitudes.sum()

    def group_sparsity(self, group_size):
        """
        Sum, over consecutive groups of ``group_size`` feature channels, of the
        mean root-mean-square weight of the group, each divided by the number of
        full groups (``features.size(1) // group_size``).
        """
        n_features = self.features.size(1)
        n_groups = n_features // group_size
        return sum(
            (self.features[:, start:start + group_size, ...].pow(2).mean(1) + 1e-12).sqrt().mean() / n_groups
            for start in range(0, n_features, group_size))

    def forward(self, x, shift=None):
        """
        Args:
            x: 4-dimensional input, unpacked as (N, c, w, h)
            shift: optional per-sample offset added to every grid location

        Returns:
            Tensor viewed as (N, outdims).
        """
        if self.positive:
            positive(self.features)
        # grid_sample expects normalized coordinates in [-1, 1].
        self.grid.data = self.grid.data.clamp(-1, 1)
        N, c, w, h = x.size()
        n_stages = self.pool_steps + 1
        feat = self.features.view(1, n_stages * c, self.outdims)

        grid = self.grid.expand(N, self.outdims, 1, 2)
        if shift is not None:
            grid = grid + shift[:, None, None, :]

        # Sample the raw map first, then each successively pooled map.
        samples = [F.grid_sample(x, grid)]
        level = x
        for _ in range(self.pool_steps):
            level = self.avg(level)
            samples.append(F.grid_sample(level, grid))
        stacked = torch.cat(samples, dim=1).squeeze(-1)
        y = (stacked * feat).sum(1).view(N, self.outdims)

        return y if self.bias is None else y + self.bias

    def __repr__(self):
        """Describe input shape, output size, bias usage, pooling steps and child modules."""
        c, w, h = self.in_shape
        text = f'{self.__class__.__name__} ({c} x {w} x {h} -> {self.outdims})'
        if self.bias is not None:
            text += ' with bias'
        text += ' and pooling for {} steps\n'.format(self.pool_steps)
        text += ''.join('  -> ' + repr(child) + '\n' for child in self.children())
        return text
Ejemplo n.º 19
0
class SpatialTransformerPyramid2d(nn.Module):
    """
    Readout that samples every level of a ``Pyramid`` decomposition of the input at one
    learned grid location per output unit and combines the sampled channels linearly.

    Args:
        in_shape: ``(c, w, h)`` of the input feature map.
        outdims (int): number of output units.
        scale_n (int): forwarded to ``Pyramid``; ``features`` holds ``c * (scale_n + 1)`` channels.
        positive (bool): if truthy, ``positive(self.features)`` is called before use
            (helper defined elsewhere; presumably enforces non-negative features - confirm).
        bias (bool): whether to register a learnable bias of length ``outdims``.
        init_range (float): ``grid`` is initialised uniformly in ``[-init_range, init_range]``.
        downsample, _skip_upsampling, type: forwarded unchanged to ``Pyramid``.
    """

    def __init__(self, in_shape, outdims, scale_n=4, positive=False, bias=True,
                 init_range=.1, downsample=True, _skip_upsampling=False, type=None):
        super().__init__()
        self.in_shape = in_shape
        c, w, h = in_shape
        self.outdims = outdims
        self.positive = positive
        self.gauss_pyramid = Pyramid(scale_n=scale_n, downsample=downsample,
                                     _skip_upsampling=_skip_upsampling, type=type)
        # One (x, y) location per output unit.
        self.grid = Parameter(torch.Tensor(1, outdims, 1, 2))
        # One weight per (pyramid level, channel) pair and output unit.
        self.features = Parameter(torch.Tensor(1, c * (scale_n + 1), 1, outdims))

        if bias:
            bias = Parameter(torch.Tensor(outdims))
            self.register_parameter('bias', bias)
        else:
            self.register_parameter('bias', None)
        self.init_range = init_range
        self.initialize()

    def initialize(self):
        """Draw ``grid`` uniformly, fill ``features`` with ``1 / c`` and zero the bias."""
        self.grid.data.uniform_(-self.init_range, self.init_range)
        self.features.data.fill_(1 / self.in_shape[0])

        if self.bias is not None:
            self.bias.data.fill_(0)

    def group_sparsity(self, group_size):
        """
        Return the mean group norm of ``self.features`` along dimension 1.

        Dimension 1 is cut into consecutive chunks of ``group_size`` entries (the last
        chunk may be shorter); each chunk contributes the root-mean-square over the chunk
        (1e-12 offset under the root) averaged over the remaining dimensions.

        Args:
            group_size (int): number of consecutive entries of dimension 1 per group.
        """
        f = self.features.size(1)
        # Ceiling division: the number of chunks the loop actually visits. Using
        # ``f // group_size`` under-counts when ``f`` is not a multiple of ``group_size``
        # and is zero (division by zero -> inf) when group_size > f.
        n = -(-f // group_size)
        ret = 0
        for chunk in range(0, f, group_size):
            group = self.features[:, chunk:chunk + group_size, ...]
            ret = ret + (group.pow(2).mean(1) + 1e-12).sqrt().mean() / n
        return ret

    def feature_l1(self, average=True):
        """Return the mean (``average=True``) or sum of ``|features|``."""
        if average:
            return self.features.abs().mean()
        else:
            return self.features.abs().sum()

    def neuron_layer_power(self, x, neuron_id):
        """
        Return the mean squared value of the pyramid levels of ``x`` weighted by the
        feature vector of unit ``neuron_id`` (no grid sampling is applied here).

        Args:
            x: 4-d input; its size is unpacked as ``(N, c, w, h)``.
            neuron_id: index into the last dimension of the feature weights.
        """
        if self.positive:
            positive(self.features)
        self.grid.data = torch.clamp(self.grid.data, -1, 1)
        _, c, _, _ = x.size()  # also enforces a 4-d input
        m = self.gauss_pyramid.scale_n + 1
        feat = self.features.view(1, m * c, self.outdims)

        y = torch.cat(self.gauss_pyramid(x), dim=1)
        y = (y * feat[:, :, neuron_id, None, None]).sum(1)
        return y.pow(2).mean()

    def forward(self, x, shift=None):
        """
        Sample every pyramid level of ``x`` at each unit's grid location and combine the
        sampled channels linearly.

        Args:
            x: 4-d input; its size is unpacked as ``(N, c, w, h)``.
            shift: optional per-sample offset added to the grid; indexed as
                ``shift[:, None, None, :]`` (presumably shape ``(N, 2)`` - confirm with callers).

        Returns:
            Tensor viewed as ``(N, outdims)``, with ``bias`` added when it is not None.
        """
        if self.positive:
            positive(self.features)
        # Keep the stored locations inside [-1, 1] before sampling.
        self.grid.data = torch.clamp(self.grid.data, -1, 1)
        N, c, _, _ = x.size()
        m = self.gauss_pyramid.scale_n + 1
        feat = self.features.view(1, m * c, self.outdims)

        grid = self.grid.expand(N, self.outdims, 1, 2)
        if shift is not None:
            grid = grid + shift[:, None, None, :]

        pools = [F.grid_sample(xx, grid) for xx in self.gauss_pyramid(x)]
        y = torch.cat(pools, dim=1).squeeze(-1)
        y = (y * feat).sum(1).view(N, self.outdims)

        if self.bias is not None:
            y = y + self.bias
        return y

    def __repr__(self):
        c, w, h = self.in_shape
        r = self.__class__.__name__ + \
            ' (' + '{} x {} x {}'.format(c, w, h) + ' -> ' + str(self.outdims) + ')'
        if self.bias is not None:
            r += ' with bias'

        for ch in self.children():
            r += '  -> ' + ch.__repr__() + '\n'
        return r
Ejemplo n.º 20
0
class AllToAllConnection(ABC, Module):
    """
    Connection between a source and a target ``Nodes`` layer whose input to the target is
    driven by a per-synapse ``active_neurotransmitters`` state of shape
    ``(source.n, target.n)`` that is updated on every call to :meth:`compute`.
    """

    def __init__(self,
                 source: Nodes,
                 target: Nodes,
                 w: torch.Tensor,
                 tc_synaptic: float = 0.0,
                 phi: float = 0.0,
                 nu: Optional[Union[float, Sequence[float]]] = None,
                 weight_decay: float = 0.0,
                 **kwargs) -> None:
        """
        :param source: A layer of nodes from which the connection originates.
        :param target: A layer of nodes to which the connection connects.
        :param w: Strengths of synapses; wrapped in a non-trainable ``Parameter``.
        :param tc_synaptic: Divisor of the decay term applied to
            ``active_neurotransmitters`` in :meth:`compute`.
        :param phi: Factor applied to the pre-synaptic spike indicator in :meth:`compute`.
        :param nu: Learning rate for both pre- and post-synaptic events (forwarded to the
            update rule).
        :param weight_decay: Constant multiple to decay weights by on each iteration
            (forwarded to the update rule).

        Keyword arguments:

        :param LearningRule update_rule: Modifies connection parameters according to
            some rule. Defaults to ``NoOp``; ``None`` is also mapped to ``NoOp``.
        :param torch.Tensor b: Target population bias (defaults to zeros of length ``target.n``).
        :param float wmin: The minimum value on the connection weights.
        :param float wmax: The maximum value on the connection weights.
        :param float norm: Total weight per target neuron normalization.
        """
        super().__init__()  # initialisation of Module

        assert isinstance(source, Nodes), "Source is not a Nodes object"
        assert isinstance(target, Nodes), "Target is not a Nodes object"

        self.source = source
        self.target = target

        self.nu = nu
        self.weight_decay = weight_decay
        # self.reduction = reduction

        self.update_rule = kwargs.get("update_rule", NoOp)
        self.wmin = kwargs.get("wmin", -np.inf)
        self.wmax = kwargs.get("wmax", np.inf)
        self.norm = kwargs.get("norm", None)
        # self.decay = kwargs.get("decay", None)

        # Learning rule
        if self.update_rule is None:
            self.update_rule = NoOp

        # The rule class is instantiated here, before ``self.w`` exists; all kwargs are
        # passed through to it unchanged.
        self.update_rule = self.update_rule(
            connection=self,
            nu=nu,
            # reduction=reduction,
            weight_decay=weight_decay,
            **kwargs)

        # Weights
        self.w = Parameter(w, requires_grad=False)
        self.b = Parameter(kwargs.get("b", torch.zeros(target.n)),
                           requires_grad=False)

        # Parameters used to update synaptic input
        self.active_neurotransmitters = torch.zeros(self.source.n,
                                                    self.target.n)
        self.tc_synaptic = tc_synaptic
        self.phi = phi
        self.v_rev = 0

        self.cumul_I = None
        # self.cumul_weigth = self.w.t()
        # if not hasattr(self.target, "eligibility_trace"):
        #     self.target.eligibility_trace = torch.zeros(*self.w.shape)
        # self.cumul_et = self.target.eligibility_trace.t()

    # Get dirac(delta_t)
    def get_dirac(self) -> torch.Tensor:
        """
        Return, as a float tensor, the element-wise maximum of the source spikes
        (flattened to a column) and the target spikes.
        """
        pre_s = self.source.s.view(-1).unsqueeze(1)
        post_s = self.target.s
        return torch.max(pre_s, post_s).float(
        )  # True or 1 if a spike occurred either in pre or post neuron, False or 0 otherwise

    def compute(self, s: torch.Tensor) -> torch.Tensor:
        # language=rst
        """
        Compute pre-activations of downstream neurons given spikes of upstream neurons.

        Updates ``self.active_neurotransmitters`` in place, then returns
        ``(v_rev - target.v) * max(w) * S`` where ``S`` is the per-target sum of the
        active neurotransmitters, viewed as ``(1, -1)``.

        :param s: Incoming spikes.
        :return: Input to the target layer as described above.
        """

        # Update of the number of active neurotransmitters for each synapse
        # (the spike column is broadcast against a ones matrix of the state's shape).
        pre_spike_occured = torch.mul(
            s.float().view(-1, 1),
            torch.ones(*self.active_neurotransmitters.shape))
        # NOTE(review): ``tc_synaptic`` defaults to 0.0 in ``__init__``, which makes this
        # a division by zero - confirm callers always pass a non-zero value.
        update = -self.active_neurotransmitters / self.tc_synaptic + self.phi * pre_spike_occured
        # Synapses whose weight is exactly zero receive no update.
        update = torch.where(self.w != 0, update, torch.tensor(0.))
        self.active_neurotransmitters += update

        # Get input
        S = torch.sum(self.active_neurotransmitters.t(), dim=1,
                      keepdim=True).view(1, -1)
        return (self.v_rev - self.target.v) * torch.max(self.w) * S
        # if self.cumul_I == None:
        #     self.cumul_I = I
        # else :
        #     self.cumul_I = torch.cat((self.cumul_I, I),0)
        # return I

    def update(self, **kwargs) -> None:
        # language=rst
        """
        Compute connection's update rule.

        Keyword arguments:

        :param bool learning: Whether to allow connection updates. Required; a missing
            key raises ``KeyError``.
        :param ByteTensor mask: Boolean mask determining which weights to clamp to zero.
        """
        learning = kwargs["learning"]

        # self.cumul_weigth = torch.cat((self.cumul_weigth, self.w.t()),0)
        # self.cumul_et = torch.cat((self.cumul_et,self.target.eligibility_trace.t()),0)

        if learning:
            self.update_rule.update(**kwargs)

        # The mask is applied whether or not learning is enabled.
        mask = kwargs.get("mask", None)
        if mask is not None:
            self.w.masked_fill_(mask, 0)

    def normalize(self) -> None:
        """
        Normalize weights so each target neuron has sum of connection weights equal to
        ``self.norm``. Does nothing when ``self.norm`` is None.
        """
        if self.norm is not None:
            w_abs_sum = self.w.abs().sum(0).unsqueeze(0)
            # Avoid dividing by zero for target neurons with all-zero weights.
            w_abs_sum[w_abs_sum == 0] = 1.0
            self.w *= self.norm / w_abs_sum

    def reset_state_variables(self) -> None:
        """
        Contains resetting logic for the connection (currently a no-op).
        """
        pass
Ejemplo n.º 21
0
class ConcatConnection(AbstractConnection):
    """
    Connection from several source layers to one target layer, using a single weight
    matrix whose first dimension is the sum of the sources' ``n``.
    """

    def __init__(self,
                 source: Dict[str, Nodes],
                 target: Nodes,
                 nu: Optional[Union[float, Sequence[float]]] = None,
                 reduction: Optional[callable] = None,
                 weight_decay: float = 0.0,
                 **kwargs) -> None:
        """
        :param source: Mapping from name to source layer; the layers' ``n`` are summed to
            size the weight matrix.
        :param target: A layer of nodes to which the connection connects.
        :param nu: Forwarded to ``AbstractConnection``.
        :param reduction: Forwarded to ``AbstractConnection``.
        :param weight_decay: Forwarded to ``AbstractConnection``.

        Keyword arguments:

        :param torch.Tensor w: Initial weights. When given and at least one of
            ``wmin`` / ``wmax`` is finite, it is clamped to ``[wmin, wmax]``.
        :param torch.Tensor b: Target bias (defaults to zeros of length ``target.n``).
        """
        super().__init__(source, target, nu, reduction, weight_decay, **kwargs)

        w = kwargs.get("w", None)
        # Builtin ``sum``: calling ``np.sum`` on a generator is deprecated in NumPy.
        source_n = sum(nodes.n for nodes in source.values())

        if w is None:
            if self.wmin == -np.inf or self.wmax == np.inf:
                # At least one bound is infinite: start from zeros clamped to the bounds.
                w = torch.clamp(torch.zeros(source_n, target.n), self.wmin,
                                self.wmax)
            else:
                # Both bounds finite: zeros * range is zero, so every weight starts at wmin.
                w = self.wmin + torch.zeros(source_n,
                                            target.n) * (self.wmax - self.wmin)
        else:
            if self.wmin != -np.inf or self.wmax != np.inf:
                w = torch.clamp(w, self.wmin, self.wmax)

        self.w = Parameter(w, requires_grad=False)
        self.b = Parameter(kwargs.get("b", torch.zeros(target.n)),
                           requires_grad=False)

    def compute(self, s: torch.Tensor) -> torch.Tensor:
        # language=rst
        """
        Compute pre-activations given spikes using connection weights.

        :param s: Incoming spikes; flattened to ``(s.size(0), -1)`` before the product.
        :return: ``s @ w + b`` viewed as ``(s.size(0), *target.shape)``.
        """
        # Compute multiplication of spike activations by weights and add bias.
        post = s.float().view(s.size(0), -1) @ self.w + self.b
        return post.view(s.size(0), *self.target.shape)

    def update(self, **kwargs) -> None:
        # language=rst
        """
        Compute connection's update rule (delegates to the parent class).
        """
        super().update(**kwargs)

    def normalize(self) -> None:
        # language=rst
        """
        Normalize weights so each target neuron has sum of connection weights equal to
        ``self.norm``. Does nothing when ``self.norm`` is None.
        """

        if self.norm is not None:
            w_abs_sum = self.w.abs().sum(0).unsqueeze(0)
            # Avoid dividing by zero for target neurons with all-zero weights.
            w_abs_sum[w_abs_sum == 0] = 1.0
            self.w *= self.norm / w_abs_sum

    def reset_state_variables(self) -> None:
        # language=rst
        """
        Contains resetting logic for the connection (delegates to the parent class).
        """
        super().reset_state_variables()
Ejemplo n.º 22
0
class Connection(AbstractConnection):
    # language=rst
    """
    Specifies synapses between one or two populations of neurons.

    Each incoming spike starts a multi-step impulse on its source neuron; the
    accumulated impulse (``a_pre``) multiplied by the weights is the input to the target.
    """

    def __init__(
        self,
        source: Nodes,
        target: Nodes,
        impulse_amplitude: float,
        impulse_length: float,
        impulse_shape_factor: float = 0.9,
        invert: bool = False,
        nu: Optional[Union[float, Sequence[float]]] = None,
        reduction: Optional[callable] = None,
        weight_decay: float = 0.0,
        post_spike_weight_decay: float = 0.0,
        **kwargs
    ) -> None:
        # language=rst
        """
        Instantiates a :code:`Connection` object.

        :param source: A layer of nodes from which the connection originates.
        :param target: A layer of nodes to which the connection connects.
        :param impulse_amplitude: Forwarded to the parent class; read back as
            ``self.impulse_amplitude`` in :meth:`impulse_curve`.
        :param impulse_length: Forwarded to the parent class; number of steps after which
            the impulse state is reset to zero.
        :param impulse_shape_factor: Forwarded to the parent class; the ``k`` used in
            :meth:`impulse_curve`.
        :param invert: Forwarded to the parent class; selects the branch of
            :meth:`impulse_curve`.
        :param nu: Learning rate for both pre- and post-synaptic events.
        :param reduction: Method for reducing parameter updates along the minibatch dimension.
        :param weight_decay: Constant multiple to decay weights by on each iteration.
        :param post_spike_weight_decay: Forwarded to the parent class.

        Keyword arguments:

        :param LearningRule update_rule: Modifies connection parameters according to some rule.
        :param torch.Tensor w: Strengths of synapses.
        :param torch.Tensor b: Target population bias.
        :param float wmin: Minimum allowed value on the connection weights.
        :param float wmax: Maximum allowed value on the connection weights.
        :param float norm: Total weight per target neuron normalization constant.
        :param ByteTensor norm_by_max: Normalize the weight of a neuron by its max weight.
        :param ByteTensor norm_by_max_from_shadow_weights: Normalize the weight of a neuron by
            its max weight, computed from the un-normalized (shadow) weights.
        """
        super().__init__(
            source, target, impulse_amplitude, impulse_length, impulse_shape_factor,
            invert, nu, reduction, weight_decay, post_spike_weight_decay, **kwargs
        )

        w = kwargs.get("w", None)
        if w is None:
            if self.wmin == -np.inf or self.wmax == np.inf:
                w = torch.clamp(torch.rand(source.n, target.n), self.wmin, self.wmax)
            else:
                w = self.wmin + torch.rand(source.n, target.n) * (self.wmax - self.wmin)
        else:
            if self.wmin != -np.inf or self.wmax != np.inf:
                w = torch.clamp(w, self.wmin, self.wmax)

        self.w = Parameter(w, False)
        self.b = Parameter(kwargs.get("b", torch.zeros(target.n)), False)

        if self.norm_by_max_from_shadow_weights:
            # shadow_w tracks the un-normalized weights; prev_w remembers the weights
            # after the last normalization so learning deltas can be recovered.
            self.shadow_w = self.w.clone().detach()
            self.prev_w = self.w.clone().detach()

    def update_impulse_state(self, s):
        """
        Advance the per-source impulse counters by one step and return the impulse increment.

        Side effects: spikes arriving on a source whose impulse is still running are
        zeroed in ``s`` and in ``self.source.s`` (both modified in place).

        :param s: Incoming spikes.
        :return: Output of :meth:`impulse_curve` for the advanced state.
        """
        # Adds 1 wherever an impulse is already in progress.
        self.impulse_state += (self.impulse_state > 0).float().view(-1)
        s[self.impulse_state.unsqueeze(0) > 0] = 0
        self.source.s[self.impulse_state.unsqueeze(0) > 0] = 0
        # Adds 1 on new spikes of idle sources (starts a new impulse).
        self.impulse_state += (self.impulse_state == 0).float() * s.float().view(-1)
        impulse = self.impulse_curve()
        # Impulses that reached impulse_length are reset to idle.
        self.impulse_state *= (self.impulse_state < self.impulse_length).float()
        return impulse

    def impulse_curve(self):
        """
        Return the per-step increment of the impulse for the current ``impulse_state``.

        The increment is piecewise constant in the state, with a one-off jump of
        ``2 * impulse_amplitude`` (``impulse_bias``) around the turning point; which
        side of the turning point gets which slope depends on ``self.invert``.
        """
        k = self.impulse_shape_factor

        if self.invert:
            # Derivative at points that are not in the middle of the impulse.
            impulse_value_2 = self.impulse_amplitude/(self.impulse_length*k - 1)

            impulse_value_1 = self.impulse_amplitude/(self.impulse_length*(1-k))

            impulse_bias = (
                2*self.impulse_amplitude
                * (self.impulse_state > (self.impulse_length * (1-k) + 0.5)).float().view(-1)
                * (self.impulse_state <= (self.impulse_length * (1-k) + 1.5)).float().view(-1)
            )

            impulse = (
                (-impulse_value_1)
                * (self.impulse_state > 0).float().view(-1)
                * (self.impulse_state <= (self.impulse_length * (1-k) + 0.5)).float().view(-1)
                + (-impulse_value_2)
                * (self.impulse_state > (self.impulse_length * (1-k) + 1.5)).float().view(-1)
                + impulse_bias
            )

            return impulse

        else:
            # Derivative at points that are not in the middle of the impulse.
            impulse_value = self.impulse_amplitude/(self.impulse_length - 2)/k

            impulse_bias = (
                2*self.impulse_amplitude
                * (abs(self.impulse_state - (self.impulse_length * k)) < 0.5).float().view(-1)
                + 2*self.impulse_amplitude
                * (abs(self.impulse_state - (self.impulse_length * k)) == 0.5).float().view(-1)
                * (self.impulse_state < self.impulse_length * k).float().view(-1)
            )

            impulse = (
                impulse_value
                * (self.impulse_state > 0).float().view(-1)
                * (self.impulse_state < self.impulse_length * k).float().view(-1)
                + impulse_value/((1-k)/k)
                * (self.impulse_state > self.impulse_length * k).float().view(-1)
                - impulse_bias
            )

            return impulse

    def compute(self, s: torch.Tensor) -> torch.Tensor:
        # language=rst
        """
        Compute pre-activations given spikes using connection weights.

        :param s: Incoming spikes (modified in place by :meth:`update_impulse_state`).
        :return: Accumulated impulse activation multiplied by synaptic weights, viewed as
                 ``target.shape``.
        """
        impulse = self.update_impulse_state(s)
        self.a_pre += impulse
        # Sources whose impulse has finished contribute nothing.
        self.a_pre *= (self.impulse_state > 0).float()

        # Compute multiplication of spike activations by connection weights.
        a_post = self.a_pre @ self.w
        return a_post.view(*self.target.shape)

    def update(self, **kwargs) -> None:
        # language=rst
        """
        Compute connection's update rule.
        """
        super().update(**kwargs)

    def normalize(self) -> None:
        # language=rst
        """
        Normalize weights so each target neuron has sum of connection weights equal to
        ``self.norm``.
        """
        if self.norm is not None:
            w_abs_sum = self.w.abs().sum(0).unsqueeze(0)
            # Avoid dividing by zero for target neurons with all-zero weights.
            w_abs_sum[w_abs_sum == 0] = 1.0
            self.w *= self.norm / w_abs_sum

    def normalize_by_max(self) -> None:
        # language=rst
        """
        Normalize weights by the max weight of the target neuron.
        """
        if self.norm_by_max:
            w_max = self.w.abs().max(0)[0]
            w_max[w_max == 0] = 1.0
            self.w /= w_max

    def normalize_by_max_from_shadow_weights(self) -> None:
        # language=rst
        """
        Normalize weights by the max weight of the target neuron, computed on the
        shadow (un-normalized) weights.

        The change of ``w`` since the previous call is first added to ``shadow_w``;
        ``w`` is then overwritten with ``shadow_w`` divided by its per-target maximum.
        """
        if self.norm_by_max_from_shadow_weights:
            self.shadow_w += self.w - self.prev_w
            w_max = self.shadow_w.abs().max(0)[0]
            w_max[w_max == 0] = 1.0
            # Write in place: ``w`` is a registered Parameter, and assigning a plain
            # tensor to a Parameter attribute of an nn.Module raises TypeError.
            self.w.copy_(self.shadow_w / w_max)
            self.prev_w = self.w.clone().detach()

    def reset_(self) -> None:
        # language=rst
        """
        Contains resetting logic for the connection: clears the accumulated impulse
        activation and the per-source impulse counters.
        """
        super().reset_()
        self.a_pre = torch.zeros_like(self.a_pre)
        self.impulse_state = torch.zeros_like(self.impulse_state)
Ejemplo n.º 23
0
class Gaussian2d(nn.Module):
    """
    Readout that learns, for each output unit, one location in the core's feature map and
    a linear weighting of the channels found at that location.

    Locations are modelled as a Gaussian per unit: while sampling is active (by default,
    in training mode) the location is drawn as ``mu + sigma * noise``; otherwise the
    mean ``mu`` is used. Locations are clipped to [-1, 1] before being handed to
    ``grid_sample``.

    Args:
        in_shape (list): shape of the input feature map [channels, width, height]
        outdims (int): number of output units
        bias (bool): adds a bias term
        init_mu_range (float): ``mu`` is initialised with Uniform([-init_mu_range, init_mu_range]);
            must satisfy 0 < init_mu_range <= 1
        init_sigma_range (float): ``sigma`` is initialised with Uniform([0, init_sigma_range]),
            or with exactly this value for all units when ``fixed_sigma`` is True; must be positive
        batch_sample (bool): if True, a location is drawn separately for each image in the batch;
            otherwise one draw is shared by the whole batch
        align_corners (bool): passed through to ``F.grid_sample``
        fixed_sigma (bool): see ``init_sigma_range``; a warning is issued when False
    """

    def __init__(
        self,
        in_shape,
        outdims,
        bias,
        init_mu_range=0.5,
        init_sigma_range=0.5,
        batch_sample=True,
        align_corners=True,
        fixed_sigma=False,
        **kwargs
    ):
        warnings.warn(
            "Gaussian2d is deprecated and will be removed in the future. Use `layers.readout.NonIsoGaussian2d` instead",
            DeprecationWarning,
        )
        super().__init__()

        mu_range_invalid = init_mu_range <= 0.0 or init_mu_range > 1.0
        if mu_range_invalid or init_sigma_range <= 0.0:
            raise ValueError(
                "either init_mu_range doesn't belong to [0.0, 1.0] or init_sigma_range is non-positive"
            )

        n_channels, _, _ = in_shape

        # Plain configuration attributes.
        self.in_shape = in_shape
        self.outdims = outdims
        self.batch_sample = batch_sample
        self.grid_shape = (1, outdims, 1, 2)
        self.init_mu_range = init_mu_range
        self.init_sigma_range = init_sigma_range
        self.align_corners = align_corners
        self.fixed_sigma = fixed_sigma

        # Learnable parameters (registration order: mu, sigma, features, bias).
        self.mu = Parameter(torch.Tensor(*self.grid_shape))  # mean location per unit
        self.sigma = Parameter(torch.Tensor(*self.grid_shape))  # spread per unit
        self.features = Parameter(torch.Tensor(1, n_channels, 1, outdims))  # channel weights
        self.register_parameter(
            "bias", Parameter(torch.Tensor(outdims)) if bias else None
        )

        self.initialize()

    def initialize(self):
        """
        Initialise ``mu``, ``sigma``, the feature weights and (if present) the bias.
        """
        mu_bound = self.init_mu_range
        self.mu.data.uniform_(-mu_bound, mu_bound)

        sigma_bound = self.init_sigma_range
        if not self.fixed_sigma:
            self.sigma.data.uniform_(0, sigma_bound)
            warnings.warn(
                "sigma is sampled from uniform distribuiton, instead of a fixed value. Consider setting "
                "fixed_sigma to True")
        else:
            # Degenerate interval: every entry becomes exactly sigma_bound.
            self.sigma.data.uniform_(sigma_bound, sigma_bound)

        self.features.data.fill_(1 / self.in_shape[0])

        bias = self.bias
        if bias is not None:
            bias.data.fill_(0)

    def sample_grid(self, batch_size, sample=None):
        """
        Return grid locations of shape ``(batch_size, outdims, 1, 2)`` clipped to [-1, 1].

        Args:
            batch_size (int): number of independent draws
            sample (bool/None): True draws ``mu + sigma * N(0, 1)``; False returns ``mu``;
                None follows ``self.training``.
        """
        # Constrain the parameters themselves: mu to [-1, 1], sigma to be non-negative.
        with torch.no_grad():
            self.mu.clamp_(min=-1, max=1)
            self.sigma.clamp_(min=0)

        if sample is None:
            sample = self.training

        shape = (batch_size, ) + self.grid_shape[1:]
        # Allocate via ``mu.new`` so the noise shares mu's dtype and device.
        noise = self.mu.new(*shape)
        noise = noise.normal_() if sample else noise.zero_()

        locations = noise * self.sigma + self.mu
        return torch.clamp(locations, min=-1, max=1)

    @property
    def grid(self):
        """Mean grid locations (a single, non-sampled draw)."""
        return self.sample_grid(batch_size=1, sample=False)

    def feature_l1(self, average=True):
        """
        Return the L1 penalty on the feature weights.

        Args:
            average (bool): if True, use the mean of the absolute weights, else their sum
        """
        magnitudes = self.features.abs()
        return magnitudes.mean() if average else magnitudes.sum()

    def forward(self, x, sample=None, shift=None, out_idx=None):
        """
        Propagate the input through the readout.

        Args:
            x: input of size ``(N, c, w, h)``; ``(c, w, h)`` must equal ``in_shape``
            sample (bool/None): forwarded to ``sample_grid``
            shift: optional offset added to the grid, indexed as ``shift[:, None, None, :]``
            out_idx: optional index of the units to predict; a boolean numpy array is
                converted to integer positions first

        Returns:
            y: tensor of size ``(N, number of selected units)``

        Raises:
            ValueError: if the feature-map dimensions differ from ``in_shape``
        """
        batch, channels, width, height = x.size()
        expected_c, expected_w, expected_h = self.in_shape
        if (expected_c, expected_w, expected_h) != (channels, width, height):
            raise ValueError(
                "the specified feature map dimension is not the readout's expected input dimension"
            )

        weights = self.features.view(1, channels, self.outdims)
        bias = self.bias
        n_out = self.outdims

        # Either one draw per image, or a single draw expanded over the batch.
        n_draws = batch if self.batch_sample else 1
        locations = self.sample_grid(batch_size=n_draws, sample=sample)
        if not self.batch_sample:
            locations = locations.expand(batch, n_out, 1, 2)

        if out_idx is not None:
            if isinstance(out_idx, np.ndarray) and out_idx.dtype == bool:
                out_idx = np.where(out_idx)[0]
            weights = weights[:, :, out_idx]
            locations = locations[:, out_idx]
            if bias is not None:
                bias = bias[out_idx]
            n_out = len(out_idx)

        if shift is not None:
            locations = locations + shift[:, None, None, :]

        sampled = F.grid_sample(x, locations, align_corners=self.align_corners)
        y = (sampled.squeeze(-1) * weights).sum(1).view(batch, n_out)

        if bias is not None:
            y = y + bias
        return y

    def __repr__(self):
        channels, width, height = self.in_shape
        pieces = [
            self.__class__.__name__,
            " (",
            "{} x {} x {}".format(channels, width, height),
            " -> ",
            str(self.outdims),
            ")",
        ]
        if self.bias is not None:
            pieces.append(" with bias")
        for child in self.children():
            pieces.append("  -> " + child.__repr__() + "\n")
        return "".join(pieces)