Example 1
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter


class fofe_filter(nn.Module):
    def __init__(self, inplanes, alpha=0.8, length=3, inverse=False):
        super(fofe_filter, self).__init__()
        self.length = length
        self.channels = inplanes
        self.alpha = alpha
        self.fofe_filter = Parameter(torch.Tensor(inplanes, 1, length))
        self.fofe_filter.requires_grad_(False)
        self._init_filter(alpha, length, inverse)
        self.padding = (length - 1) // 2

    def _init_filter(self, alpha, length, inverse):
        # Forward FOFE weights the most recent step with 1 and decays older steps
        # geometrically; the inverse variant reverses that order.
        if not inverse:
            self.fofe_filter[:, :, :].copy_(
                torch.pow(alpha, torch.linspace(length - 1, 0, length)))
        else:
            self.fofe_filter[:, :, :].copy_(
                torch.pow(alpha, torch.arange(length, dtype=torch.float32)))

    def fofe_encode(self, x):
        # Left-pad by (length - 1) so the convolution is causal: each output step
        # only aggregates the current and the previous (length - 1) steps.
        out = F.pad(x, (self.length - 1, 0), mode='constant', value=0)
        out = F.conv1d(out,
                       self.fofe_filter,
                       bias=None,
                       stride=1,
                       padding=0,
                       groups=self.channels)
        return out

    def forward(self, x):
        if self.alpha == 1 or self.alpha == 0:
            return x
        x = self.fofe_encode(x)
        return x
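
A minimal usage sketch (not part of the original example): the (batch, channels, time) layout and the sizes below are illustrative assumptions.

# Hypothetical smoke test for fofe_filter; shapes are chosen for illustration only.
enc = fofe_filter(inplanes=16, alpha=0.8, length=3)
x = torch.randn(4, 16, 20)     # (batch, channels, time), channels == inplanes
y = enc(x)
assert y.shape == x.shape      # causal left-padding keeps the sequence length unchanged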
Example 2
class fofe_conv1d(nn.Module):
    def __init__(self,
                 emb_dims,
                 alpha=0.9,
                 length=1,
                 dilation=1,
                 inverse=False):
        super(fofe_conv1d, self).__init__()
        self.alpha = alpha
        self.length = length
        self.channels = emb_dims
        self.fofe_filter = Parameter(torch.Tensor(emb_dims, 1, length))
        self.fofe_filter.requires_grad_(False)
        self._init_filter(emb_dims, alpha, length, inverse)
        self.padding = (length - 1) // 2
        self.dilated_conv = nn.Sequential(
            nn.Conv1d(self.channels,
                      self.channels,
                      3,
                      1,
                      padding=length,
                      dilation=dilation,
                      groups=1,
                      bias=False), nn.LeakyReLU(0.1, inplace=True))

    def _init_filter(self, channels, alpha, length, inverse):
        if not inverse:
            self.fofe_filter[:, :, :].copy_(
                torch.pow(self.alpha, torch.linspace(length - 1, 0, length)))
        else:
            self.fofe_filter[:, :, :].copy_(
                torch.pow(self.alpha, torch.arange(length, dtype=torch.float32)))

    def forward(self, x):
        # Input arrives as (batch, seq_len, emb_dims); conv1d expects channels first.
        x = torch.transpose(x, -2, -1)
        if (self.length % 2 == 0):
            # Even filter lengths need one extra step of padding to keep the length.
            x = F.pad(x, (0, 1), mode='constant', value=0)
        x = F.conv1d(x,
                     self.fofe_filter,
                     bias=None,
                     stride=1,
                     padding=self.padding,
                     groups=self.channels)
        x = self.dilated_conv(x)
        return x
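
A brief usage sketch for fofe_conv1d (an illustrative addition, reusing the imports from Example 1); the batch and sequence sizes are arbitrary.

# Hypothetical smoke test; sizes are illustrative only.
layer = fofe_conv1d(emb_dims=32, alpha=0.9, length=1)
x = torch.randn(4, 50, 32)     # (batch, seq_len, emb_dims)
y = layer(x)
print(y.shape)                 # torch.Size([4, 32, 50]): channels-first after the transpose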
Example 3
        def __init__(
            self,
            modules: Sequence[Module],
            original: Union[Tensor, Parameter],
            unsafe: bool = False,
        ) -> None:
            # We require this because we need to treat differently the first parametrization
            # This should never throw, unless this class is used from the outside
            if len(modules) == 0:
                raise ValueError("ParametrizationList requires one or more modules.")

            super().__init__(modules)
            self.unsafe = unsafe

            # In plain words:
            # module.weight must keep its dtype and shape.
            # Furthermore, if right_inverse is absent or returns a single tensor,
            # that tensor must have the same dtype as the original tensor
            #
            # We check that the following invariants hold:
            #    X = module.weight
            #    Y = param.right_inverse(X)
            #    assert isinstance(Y, Tensor) or
            #           (isinstance(Y, collections.abc.Sequence) and all(isinstance(t, Tensor) for t in Y))
            #    Z = param(Y) if isinstance(Y, Tensor) else param(*Y)
            #    # Consistency checks
            #    assert X.dtype == Z.dtype and X.shape == Z.shape
            #    # If right_inverse has a single tensor output, this allows us to use set_
            #    # to move data to/from the original tensor without changing its id (which is
            #    # what the optimiser uses to track parameters)
            #    if isinstance(Y, Tensor):
            #        assert X.dtype == Y.dtype
            # Below we use original = X, new = Y

            original_shape = original.shape
            original_dtype = original.dtype

            # Compute new
            with torch.no_grad():
                new = original
                for module in reversed(self):  # type: ignore[call-overload]
                    if hasattr(module, "right_inverse"):
                        try:
                            new = module.right_inverse(new)
                        except NotImplementedError:
                            pass
                    # else, or if it throws, we assume that right_inverse is the identity

            if not isinstance(new, Tensor) and not isinstance(
                new, collections.abc.Sequence
            ):
                raise ValueError(
                    "'right_inverse' must return a Tensor or a Sequence of tensors (list, tuple...). "
                    f"Got {type(new).__name__}"
                )

            # Set the number of original tensors
            self.is_tensor = isinstance(new, Tensor)
            self.ntensors = 1 if self.is_tensor else len(new)

            # Register the tensor(s)
            if self.is_tensor:
                if original.dtype != new.dtype:
                    raise ValueError(
                        "When `right_inverse` outputs one tensor, it may not change the dtype.\n"
                        f"original.dtype: {original.dtype}\n"
                        f"right_inverse(original).dtype: {new.dtype}"
                    )
                # Copy `new` into `original` in place so that the user does not need to
                # re-register the parameter manually in the optimiser
                with torch.no_grad():
                    original.set_(new)  # type: ignore[call-overload]
                _register_parameter_or_buffer(self, "original", original)
            else:
                for i, originali in enumerate(new):
                    if not isinstance(originali, Tensor):
                        raise ValueError(
                            "'right_inverse' must return a Tensor or a Sequence of tensors "
                            "(list, tuple...). "
                            f"Got element {i} of the sequence with type {type(originali).__name__}."
                        )

                    # If the original tensor was a Parameter that required grad, we expect the user to
                    # add the new parameters to the optimizer after registering the parametrization
                    # (this is documented)
                    if isinstance(original, Parameter):
                        originali = Parameter(originali)
                    originali.requires_grad_(original.requires_grad)
                    _register_parameter_or_buffer(self, f"original{i}", originali)

            if not self.unsafe:
                # Consistency checks:
                # Since f : A -> B, right_inverse : B -> A, Z and original should live in B
                # Z = forward(right_inverse(original))
                Z = self()
                if not isinstance(Z, Tensor):
                    raise ValueError(
                        f"A parametrization must return a tensor. Got {type(Z).__name__}."
                    )
                if Z.dtype != original_dtype:
                    raise ValueError(
                        "Registering a parametrization may not change the dtype of the tensor, unless `unsafe` flag is enabled.\n"
                        f"unparametrized dtype: {original_dtype}\n"
                        f"parametrized dtype: {Z.dtype}"
                    )
                if Z.shape != original_shape:
                    raise ValueError(
                        "Registering a parametrization may not change the shape of the tensor, unless `unsafe` flag is enabled.\n"
                        f"unparametrized shape: {original_shape}\n"
                        f"parametrized shape: {Z.shape}"
                    )
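
The invariants checked above are what torch.nn.utils.parametrize.register_parametrization enforces when unsafe=False. As a hedged illustration (not taken from this file), a symmetric-matrix parametrization whose right_inverse returns a single tensor of the same dtype and shape satisfies all of them:

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class Symmetric(nn.Module):
    def forward(self, X):
        # forward: upper-triangular entries -> symmetric matrix
        return X.triu() + X.triu(1).transpose(-1, -2)

    def right_inverse(self, A):
        # A valid preimage with the same dtype and shape, so the checks above pass.
        return A.triu()

layer = nn.Linear(4, 4)
parametrize.register_parametrization(layer, "weight", Symmetric())
W = layer.weight                       # recomputed through Symmetric.forward on access
assert torch.allclose(W, W.transpose(-1, -2))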
Example 4
class Convolution(nn.Module):
    r"""Performs a 2D convolution over an input spike-wave composed of several input
    planes. Current version only supports stride of 1 with no padding.
    The input is a 4D tensor with the size :math:`(T, C_{{in}}, H_{{in}}, W_{{in}})` and the crresponsing output
    is of size :math:`(T, C_{{out}}, H_{{out}}, W_{{out}})`,
    where :math:`T` is the number of time steps, :math:`C` is the number of feature maps (channels), and
    :math:`H`, and :math:`W` are the hight and width of the input/output planes.
    * :attr:`in_channels` controls the number of input planes (channels/feature maps).
    * :attr:`out_channels` controls the number of feature maps in the current layer.
    * :attr:`kernel_size` controls the size of the convolution kernel. It can be a single integer or a tuple of two integers.
    * :attr:`weight_mean` controls the mean of the normal distribution used for initial random weights.
    * :attr:`weight_std` controls the standard deviation of the normal distribution used for initial random weights.
    .. note::
            Since this version of convolution does not support padding, it is the user's responsibility to add proper
            padding to the input before applying the convolution.
    Args:
            in_channels (int): Number of channels in the input.
            out_channels (int): Number of channels produced by the convolution.
            kernel_size (int or tuple): Size of the convolving kernel.
            weight_mean (float, optional): Mean of the initial random weights. Default: 0.8
            weight_std (float, optional): Standard deviation of the initial random weights. Default: 0.02
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 weight_mean=0.8,
                 weight_std=0.02):
        super(Convolution, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = to_pair(kernel_size)
        #self.weight_mean = weight_mean
        #self.weight_std = weight_std

        # For future use
        self.stride = 1
        self.bias = None
        self.dilation = 1
        self.groups = 1
        self.padding = 0

        # Parameters
        self.weight = Parameter(
            torch.Tensor(self.out_channels, self.in_channels,
                         *self.kernel_size))
        self.weight.requires_grad_(False)  # We do not use gradients
        self.reset_weight(weight_mean, weight_std)

    def reset_weight(self, weight_mean=0.8, weight_std=0.02):
        """Resets weights to random values based on a normal distribution.
        Args:
                weight_mean (float, optional): Mean of the random weights. Default: 0.8
                weight_std (float, optional): Standard deviation of the random weights. Default: 0.02
        """
        self.weight.normal_(weight_mean, weight_std)

    def load_weight(self, target):
        """Loads the weights from the target tensor.
        Args:
                target (Tensor): The target tensor.
        """
        self.weight.copy_(target)

    def forward(self, input):
        return fn.conv2d(input, self.weight, self.bias, self.stride,
                         self.padding, self.dilation, self.groups)
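
A brief usage sketch (illustrative; it assumes the module-level helpers to_pair and fn, e.g. torch.nn.functional, that this class relies on are available); the spike-wave dimensions are arbitrary.

# Hypothetical usage; sizes are illustrative only.
conv = Convolution(in_channels=2, out_channels=8, kernel_size=5)
spike_wave = torch.rand(15, 2, 28, 28)   # (T, C_in, H_in, W_in)
potentials = conv(spike_wave)
print(potentials.shape)                  # torch.Size([15, 8, 24, 24]): no padding, stride 1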
class ReLuTransformer(torch.nn.Module):
    __constants__ = ['bias', 'in_features', 'out_features']

    def __init__(self, D_features, x_shape):
        super(ReLuTransformer, self).__init__()
        self.in_features = D_features
        self.out_features = D_features + x_shape
        # Learnable ReLU relaxation slopes; properly initialized from the first
        # input's bounds in reset_parameters.
        self.slopes = Parameter(torch.Tensor(x_shape))
        self.slopes.requires_grad_(True)
        self.x_shape = x_shape
        self.initialized = False
        #self.hyper_diag=torch.diagflat(torch.ones(np.prod(x_shape))).view((-1,) + x_shape)

    def reset_parameters(self, A_in):
        A = A_in[1].detach()
        a0 = A_in[0].detach()
        # Heuristic initial slope in [0, 1], derived from the box bounds implied by (a0, A).
        temp_slope = 0.5 * (1.0 + torch.div(a0, torch.abs(A).sum(axis=0)))
        temp_slope = torch.clamp(temp_slope, 0, 1)
        self.slopes.data = temp_slope

    def forward(self, A_in):
        if not self.initialized:
            ReLuTransformer.reset_parameters(self, A_in)
            self.initialized = True
        A = A_in[1]     # eps x x_shape
        a0 = A_in[0]    # x_shape

        l_x = -torch.abs(A).sum(axis=0) + a0
        u_x =  torch.abs(A).sum(axis=0) + a0

        crossover_idx = torch.mul((l_x < 0).float(), (u_x > 0).float())
        offset_factor = torch.max(u_x*(1 - self.slopes), -l_x*self.slopes) / 2.

        a0_n = torch.mul((1 - crossover_idx), torch.nn.functional.relu(a0)) \
                 + torch.mul(crossover_idx, torch.mul(self.slopes, a0) + offset_factor)

        A_new_e, A_0 = self.assemble_hyper_diagonal(self.x_shape, offset_factor, crossover_idx)

        A_n = torch.mul((l_x > 0).float().unsqueeze(0), torch.cat((A, A_0),dim=0))\
                + torch.mul(crossover_idx.unsqueeze(0), torch.cat((torch.mul(self.slopes.unsqueeze(0), A), A_new_e), dim=0))

        if DEBUG:
            l_x_n = -torch.abs(A_n).sum(axis=0) + a0_n
            u_x_n = torch.abs(A_n).sum(axis=0) + a0_n

            if not (((u_x_n-u_x)/abs((u_x_n-u_x)).median())>-1e-4).all() and (u_x_n-u_x).min()<-1e-5:
                print("ReLU upper bound soundness check failed")
            if not (((l_x_n-l_x)/abs((l_x_n-l_x)).median())>-1e-4).all() and (l_x_n-l_x).min()<-1e-5:
                print("ReLU lower bound soundness check failed")
        return (a0_n,A_n)

    def assemble_hyper_diagonal(self, x_shape, offset_factor, crossover_idx):
        # Build a tensor of new error terms corresponding to a hyper diagonal in 4D
        # (eps x channel x x_shape) or 2D (eps x x_shape), with offset_factor on the
        # entries indicated by crossover_idx.

        '''A_new = torch.zeros((int(crossover_idx.sum()),) + x_shape, dtype=torch.float32)  # Channels x space dim x eps space dim
        A_0=torch.clone(A_new)
        k_i = 0
        if len(x_shape) > 1:
            for i in range(x_shape[0]):
                for j in range(x_shape[1]):
                    for l in np.array(range(x_shape[2]))[crossover_idx[i,j,:]==1]:
                        A_new[k_i, i, j, l] = offset_factor[i, j, l]
                        k_i += 1

        else:
            for i in np.array(range(x_shape[0]))[crossover_idx==1]:
                A_new[k_i, i] = offset_factor[i]
                k_i += 1
        '''
        # A_new=torch.diagflat(offset_factor)[crossover_idx.flatten() == 1, :].view((int(crossover_idx.sum()),) + x_shape)

        k = int(crossover_idx.sum())
        # One new error term per crossing entry: row i of A_new is zero everywhere except
        # at the i-th crossing position, where it carries that entry's offset_factor.
        indices = torch.cat([torch.arange(k).view(-1, 1), crossover_idx.nonzero()], dim=1).T
        A_new = torch.sparse_coo_tensor(indices,
                                        offset_factor[crossover_idx == 1].flatten(),
                                        torch.Size((k,) + x_shape)).to_dense()
        '''if len(x_shape) > 1:
            A_new=torch.mul(self.hyper_diag[crossover_idx.flatten() == 1, :, :, :],
                      offset_factor[crossover_idx == 1].flatten().view((k, 1, 1, 1)), )
        else:
            A_new = torch.mul(self.hyper_diag[crossover_idx.flatten() == 1, :],
                              offset_factor[crossover_idx == 1].flatten().view((k, 1)), )
        '''
        A_0 = torch.zeros((k,) + x_shape, dtype=torch.float32)

        return A_new, A_0
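
A small, self-contained illustration of the hyper-diagonal construction above (values are made up; it assumes torch is imported as in Example 1): each crossing entry gets its own new error term, with its offset placed at the corresponding position.

# Hypothetical inputs for a 1D x_shape.
x_shape = (4,)
crossover_idx = torch.tensor([1., 0., 1., 0.])        # two crossing components
offset_factor = torch.tensor([0.5, 0.0, 0.25, 0.0])
k = int(crossover_idx.sum())
indices = torch.cat([torch.arange(k).view(-1, 1), crossover_idx.nonzero()], dim=1).T
A_new = torch.sparse_coo_tensor(indices,
                                offset_factor[crossover_idx == 1].flatten(),
                                torch.Size((k,) + x_shape)).to_dense()
print(A_new)
# tensor([[0.5000, 0.0000, 0.0000, 0.0000],
#         [0.0000, 0.0000, 0.2500, 0.0000]])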