Пример #1
0
def _ann_to_snn_helper(prev, current, nxt):
    if isinstance(current, nn.Linear):
        layer = SubtractiveResetIFNodes(n=current.out_features, reset=0, thresh=1, refrac=0)
        connection = topology.Connection(
            source=prev, target=layer, w=current.weight.t(), b=current.bias
        )

    elif isinstance(current, nn.Conv2d):
        input_height, input_width = prev.shape[2], prev.shape[3]
        out_channels, output_height, output_width = current.out_channels, prev.shape[2], prev.shape[3]

        width = (input_height - current.kernel_size[0] + 2 * current.padding[0]) / current.stride[0] + 1
        height = (input_width - current.kernel_size[1] + 2 * current.padding[1]) / current.stride[1] + 1
        shape = (1, out_channels, int(width), int(height))

        layer = SubtractiveResetIFNodes(
            shape=shape, reset=0, thresh=1, refrac=0
        )
        connection = topology.Conv2dConnection(
            source=prev, target=layer, kernel_size=current.kernel_size, stride=current.stride,
            padding=current.padding, dilation=current.dilation, w=current.weight, b=current.bias
        )

    elif isinstance(current, nn.MaxPool2d):
        input_height, input_width = prev.shape[2], prev.shape[3]
        current.kernel_size = _pair(current.kernel_size)
        current.padding = _pair(current.padding)
        current.stride = _pair(current.stride)

        width = (input_height - current.kernel_size[0] + 2 * current.padding[0]) / current.stride[0] + 1
        height = (input_width - current.kernel_size[1] + 2 * current.padding[1]) / current.stride[1] + 1
        shape = (1, prev.shape[1], int(width), int(height))

        layer = PassThroughNodes(
            shape=shape
        )
        connection = topology.MaxPool2dConnection(
            source=prev, target=layer, kernel_size=current.kernel_size, stride=current.stride,
            padding=current.padding, dilation=current.dilation, decay=1
        )

    else:
        return None, None

    return layer, connection
Пример #2
0
def _ann_to_snn_helper(prev, current, node_type, **kwargs):
    # language=rst
    """
    Helper function for main ``ann_to_snn`` method.

    Builds the spiking layer and the connection that feeds it from ``prev``
    for a single PyTorch module.

    :param prev: Spiking layer that feeds the converted module; it is passed as
        ``source`` to the connection. For the 2-D modules its ``shape`` is
        indexed as ``(batch, channels, height, width)``.
    :param current: PyTorch module to convert. It is read but never modified.
    :param node_type: Node class instantiated for ``nn.Linear`` / ``nn.Conv2d``
        layers (called with ``reset=0, thresh=1, refrac=0``).
    :param kwargs: Extra keyword arguments forwarded to ``node_type``.
    :return: ``(layer, connection)`` for ``nn.Linear``, ``nn.Conv2d``,
        ``nn.MaxPool2d``, ``Permute`` and ``nn.ConstantPad2d``; ``(None, None)``
        for any other module.
    """
    if isinstance(current, nn.Linear):
        layer = node_type(n=current.out_features,
                          reset=0,
                          thresh=1,
                          refrac=0,
                          **kwargs)
        # ``bias`` is None for modules built with ``bias=False``.
        bias = current.bias if current.bias is not None else torch.zeros(
            layer.n)
        connection = topology.Connection(source=prev,
                                         target=layer,
                                         w=current.weight.t(),
                                         b=bias)

    elif isinstance(current, nn.Conv2d):
        input_height, input_width = prev.shape[2], prev.shape[3]

        # NOTE(review): ``current.dilation`` is forwarded to the connection but
        # is not part of this size formula; it is exact only for dilation == 1.
        output_height = (input_height - current.kernel_size[0] +
                         2 * current.padding[0]) / current.stride[0] + 1
        output_width = (input_width - current.kernel_size[1] +
                        2 * current.padding[1]) / current.stride[1] + 1
        shape = (1, current.out_channels, int(output_height), int(output_width))

        layer = node_type(shape=shape, reset=0, thresh=1, refrac=0, **kwargs)
        # One zero per output channel when the module has no bias.
        bias = current.bias if current.bias is not None else torch.zeros(
            layer.shape[1])
        connection = topology.Conv2dConnection(
            source=prev,
            target=layer,
            kernel_size=current.kernel_size,
            stride=current.stride,
            padding=current.padding,
            dilation=current.dilation,
            w=current.weight,
            b=bias,
        )

    elif isinstance(current, nn.MaxPool2d):
        input_height, input_width = prev.shape[2], prev.shape[3]
        # Normalise int-or-tuple hyper-parameters into local pairs instead of
        # writing them back onto ``current`` (which belongs to the caller's model).
        kernel_size = _pair(current.kernel_size)
        padding = _pair(current.padding)
        stride = _pair(current.stride)

        # NOTE(review): ``dilation`` and ``ceil_mode`` are not part of this
        # size formula; it matches PyTorch only for dilation == 1 and
        # ceil_mode == False.
        output_height = (input_height - kernel_size[0] +
                         2 * padding[0]) / stride[0] + 1
        output_width = (input_width - kernel_size[1] +
                        2 * padding[1]) / stride[1] + 1
        shape = (1, prev.shape[1], int(output_height), int(output_width))

        layer = PassThroughNodes(shape=shape)
        connection = topology.MaxPool2dConnection(
            source=prev,
            target=layer,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=current.dilation,
            decay=1,
        )

    elif isinstance(current, Permute):
        # Output shape is ``prev.shape`` reordered by ``current.dims``.
        layer = PassThroughNodes(shape=[prev.shape[dim] for dim in current.dims])

        connection = PermuteConnection(source=prev,
                                       target=layer,
                                       dims=current.dims)

    elif isinstance(current, nn.ConstantPad2d):
        # ``nn.ConstantPad2d.padding`` is ``(left, right, top, bottom)``, as in
        # ``torch.nn.functional.pad``: left/right grow the last (width) dim,
        # top/bottom grow the height dim.
        # NOTE(review): only ``padding`` reaches the connection; ``current.value``
        # is not forwarded.
        pad_left, pad_right, pad_top, pad_bottom = current.padding
        layer = PassThroughNodes(shape=[
            prev.shape[0],
            prev.shape[1],
            prev.shape[2] + pad_top + pad_bottom,
            prev.shape[3] + pad_left + pad_right,
        ])

        connection = ConstantPad2dConnection(source=prev,
                                             target=layer,
                                             padding=current.padding)

    else:
        return None, None

    return layer, connection
Пример #3
0
def _ann_to_snn_helper(prev, current, scale):
    # language=rst
    """
    Helper function for main ``ann_to_snn`` method.

    Builds the spiking layer and the connection that feeds it from ``prev``
    for a single PyTorch module. Weighted modules become ``LIFNodes`` layers
    with hard-coded parameters (``thresh=-52``, ``rest=-65.0``,
    ``decay=1e-2``, ``refrac=0``, ``traces=True``).

    :param prev: Spiking layer that feeds the converted module; it is passed as
        ``source`` to the connection. For the 2-D modules its ``shape`` is
        indexed as ``(batch, channels, height, width)``.
    :param current: PyTorch module to convert. It is read but never modified.
    :param scale: Multiplier applied to ``current.weight`` for ``nn.Linear`` and
        ``nn.Conv2d``. ``current.bias`` is not forwarded to the connection.
    :return: ``(layer, connection)`` for ``nn.Linear``, ``nn.Conv2d``,
        ``Permute`` and ``nn.ConstantPad2d``; ``(None, None)`` for any other
        module.
    """
    if isinstance(current, nn.Linear):
        layer = LIFNodes(n=current.out_features,
                         refrac=0,
                         traces=True,
                         thresh=-52,
                         rest=-65.0,
                         decay=1e-2)
        connection = topology.Connection(source=prev,
                                         target=layer,
                                         w=current.weight.t() * scale)

    elif isinstance(current, nn.Conv2d):
        input_height, input_width = prev.shape[2], prev.shape[3]

        # NOTE(review): ``current.dilation`` is forwarded to the connection but
        # is not part of this size formula; it is exact only for dilation == 1.
        output_height = (input_height - current.kernel_size[0] +
                         2 * current.padding[0]) / current.stride[0] + 1
        output_width = (input_width - current.kernel_size[1] +
                        2 * current.padding[1]) / current.stride[1] + 1
        shape = (1, current.out_channels, int(output_height), int(output_width))

        layer = LIFNodes(
            shape=shape,
            refrac=0,
            traces=True,
            thresh=-52,
            rest=-65.0,
            decay=1e-2,
        )
        connection = topology.Conv2dConnection(source=prev,
                                               target=layer,
                                               kernel_size=current.kernel_size,
                                               stride=current.stride,
                                               padding=current.padding,
                                               dilation=current.dilation,
                                               w=current.weight * scale)

    elif isinstance(current, Permute):
        # Output shape is ``prev.shape`` reordered by ``current.dims``.
        layer = PassThroughNodes(shape=[prev.shape[dim] for dim in current.dims])

        connection = PermuteConnection(source=prev,
                                       target=layer,
                                       dims=current.dims)

    elif isinstance(current, nn.ConstantPad2d):
        # ``nn.ConstantPad2d.padding`` is ``(left, right, top, bottom)``, as in
        # ``torch.nn.functional.pad``: left/right grow the last (width) dim,
        # top/bottom grow the height dim.
        # NOTE(review): only ``padding`` reaches the connection; ``current.value``
        # is not forwarded.
        pad_left, pad_right, pad_top, pad_bottom = current.padding
        layer = PassThroughNodes(shape=[
            prev.shape[0],
            prev.shape[1],
            prev.shape[2] + pad_top + pad_bottom,
            prev.shape[3] + pad_left + pad_right,
        ])

        connection = ConstantPad2dConnection(source=prev,
                                             target=layer,
                                             padding=current.padding)

    else:
        return None, None

    return layer, connection