Code Example #1
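This example (and the two that follow) are test methods and assume surrounding scaffolding that is not shown. A minimal sketch of that scaffold is given below; the import path for `AbsoluteMaxGatingDendriticLayer2d` (here taken from nupic.research's dendrites framework) and the test class name are assumptions, not part of the snippets.

import unittest

import torch

# Assumed import path; adjust to match your checkout of nupic.research.
from nupic.research.frameworks.dendrites import AbsoluteMaxGatingDendriticLayer2d


class AbsoluteMaxGatingDendriticLayer2dTest(unittest.TestCase):
    # The test methods from Code Examples #1-#3 go here.
    ...


if __name__ == "__main__":
    unittest.main()
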
    def test_apply_gating_dendrites(self):
        conv_layer = torch.nn.Conv2d(
            in_channels=1, out_channels=3, kernel_size=3, stride=1, bias=True
        )
        dendrite_layer = AbsoluteMaxGatingDendriticLayer2d(
            module=conv_layer,
            num_segments=20,
            dim_context=15,
            module_sparsity=0.7,
            dendrite_sparsity=0.9,
            dendrite_bias=False,
        )

        # pseudo output: batch_size=2, num_channels=3, height=2, width=2
        y = torch.tensor([
            [
                [[0.3, 0.4], [-0.2, 0.1]],
                [[-0.3, 0.5], [-0.1, 0.1]],
                [[0.0, 0.1], [0.3, 0.2]]
            ],
            [
                [[0.1, -0.2], [-0.2, 0.1]],
                [[0.0, 0.1], [-0.4, -0.1]],
                [[-0.3, 0.0], [0.2, 0.4]]
            ],
        ])

        # pseudo dendrite_activations: batch_size=2, num_channels=3, num_segments=3
        dendrite_activations = torch.tensor(
            [
                [[0.4, 0.9, -0.1], [-0.8, 0.7, 0.0], [0.6, -0.6, -0.7]],
                [[0.2, 0.8, 0.8], [-0.1, -0.4, 0.5], [0.0, 0.0, 0.0]],
            ]
        )

        # Expected absolute max dendrite activations (pre-sigmoid):
        # [[0.9  -0.8  -0.7]
        #  [0.8   0.5   0.0]]

        # Expected output: `y` gated channel-wise by the sigmoid of the absolute
        # max dendrite activations above
        expected_output = torch.tensor([
            [
                [[0.2133, 0.2844], [-0.1422, 0.0711]],
                [[-0.093, 0.155], [-0.031, 0.031]],
                [[0.0, 0.0332], [0.0995, 0.0664]]
            ],
            [
                [[0.069, -0.138], [-0.138, 0.069]],
                [[0.0, 0.0622], [-0.249, -0.0622]],
                [[-0.15, 0.0], [0.1, 0.2]]
            ],
        ])

        actual_output = dendrite_layer.apply_dendrites(y, dendrite_activations)
        all_matches = torch.allclose(expected_output, actual_output, atol=1e-4)
        self.assertTrue(all_matches)
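
As a sanity check on the numbers in `expected_output`, the absolute-max gating rule this test exercises can be reproduced with plain torch ops: for each batch item and channel, keep the dendrite activation with the largest magnitude, pass it through a sigmoid, and scale that channel of `y` by the result. This is a sketch of the rule implied by the expected values above, not the layer's internal code.

import torch

# Same pseudo data as in test_apply_gating_dendrites.
y = torch.tensor([[[[0.3, 0.4], [-0.2, 0.1]],
                   [[-0.3, 0.5], [-0.1, 0.1]],
                   [[0.0, 0.1], [0.3, 0.2]]],
                  [[[0.1, -0.2], [-0.2, 0.1]],
                   [[0.0, 0.1], [-0.4, -0.1]],
                   [[-0.3, 0.0], [0.2, 0.4]]]])
dendrite_activations = torch.tensor(
    [[[0.4, 0.9, -0.1], [-0.8, 0.7, 0.0], [0.6, -0.6, -0.7]],
     [[0.2, 0.8, 0.8], [-0.1, -0.4, 0.5], [0.0, 0.0, 0.0]]])

# Per batch item and channel, keep the activation with the largest magnitude,
# sign preserved: [[0.9, -0.8, -0.7], [0.8, 0.5, 0.0]].
indices = dendrite_activations.abs().argmax(dim=2, keepdim=True)
winners = dendrite_activations.gather(2, indices).squeeze(2)

# Gate each channel of `y` by the sigmoid of its winning activation.
gated = y * torch.sigmoid(winners)[:, :, None, None]
print(gated[0, 0, 0, 0])  # 0.3 * sigmoid(0.9) ≈ 0.2133, as in expected_output
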
Code Example #2
    def test_gradients(self):
        """
        Ensure dendrite gradients are flowing through the layer
        `AbsoluteMaxGatingDendriticLayer2d`. Note that this test doesn't actually
        consider the values of gradients, apart from whether they are zero or non-zero.
        """
        conv_layer = torch.nn.Conv2d(in_channels=2,
                                     out_channels=3,
                                     kernel_size=2,
                                     stride=1,
                                     bias=True)
        dendrite_layer = AbsoluteMaxGatingDendriticLayer2d(
            module=conv_layer,
            num_segments=3,
            dim_context=4,
            module_sparsity=0.7,
            dendrite_sparsity=0.9,
            dendrite_bias=False,
        )

        # Dendrite weights: num_channels=3, num_segments=3, dim_context=4
        dendrite_layer.segments.weights.data[:] = torch.tensor(
            [[[-0.4933, 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.3805, 0.0000],
              [0.0000, 0.0000, 0.0000, -0.1641]],
             [[0.0000, 0.0000, 0.0000, 0.3555],
              [0.0000, 0.0000, 0.0000, 0.1892],
              [0.0000, 0.0000, -0.4274, 0.0000]],
             [[0.0000, 0.0000, 0.0000, 0.0957],
              [0.0000, 0.0000, -0.0689, 0.0000],
              [0.0000, 0.0000, 0.0000, -0.3192]]])

        # Input to dendrite layer: batch_size=1, num_channels=2, height=3, width=3
        x = torch.randn((1, 2, 3, 3))

        # Context input to dendrite layer: batch_size=1, dim_context=4
        context_vectors = torch.tensor([[1.0, 0.0, 1.0, 0.0]])

        # Expected dendrite activations:
        #
        # batch item 1 (each row corresponds to an output channel)
        # [[-0.4933  0.3805    zero]
        #  [   zero    zero -0.4274]
        #  [   zero -0.0689    zero]]

        # Expected dendrite gradient mask
        #
        # batch item 1
        # [[1  0  0]
        #  [0  0  1]
        #  [0  1  0]]

        output = dendrite_layer(x, context_vectors)
        output.sum().backward()

        # Expected gradient mask
        expected_grad_mask = torch.tensor([[[1.0, 0.0, 1.0, 0.0],
                                            [0.0, 0.0, 0.0, 0.0],
                                            [0.0, 0.0, 0.0, 0.0]],
                                           [[0.0, 0.0, 0.0, 0.0],
                                            [0.0, 0.0, 0.0, 0.0],
                                            [1.0, 0.0, 1.0, 0.0]],
                                           [[0.0, 0.0, 0.0, 0.0],
                                            [1.0, 0.0, 1.0, 0.0],
                                            [0.0, 0.0, 0.0, 0.0]]])
        actual_grad_mask = 1.0 * (dendrite_layer.segments.weights.grad != 0.0)

        all_matches = (expected_grad_mask == actual_grad_mask).all()
        self.assertTrue(all_matches)
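
The expected gradient mask above has a simple structure: only the winning (absolute max) segment of each output channel receives a gradient, and within that segment only at the context dimensions that are non-zero. As a sketch, it can be built directly from the winning-segment indices and the context vector:

import torch

# Winning segment per output channel, read off the expected dendrite
# activations in the comment above: channel 0 -> segment 0,
# channel 1 -> segment 2, channel 2 -> segment 1.
winning_segments = torch.tensor([0, 2, 1])
context = torch.tensor([1.0, 0.0, 1.0, 0.0])

# Place the context pattern in each channel's winning-segment row;
# all other rows stay zero.
mask = torch.zeros(3, 3, 4)
mask[torch.arange(3), winning_segments] = context
print(mask)  # identical to expected_grad_mask in the test above
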
Code Example #3
    def test_forward(self):
        """ Validate the output values of the output tensor returned by `forward`. """

        # Initialize convolutional layer
        conv_layer = torch.nn.Conv2d(in_channels=2,
                                     out_channels=3,
                                     kernel_size=2,
                                     stride=1,
                                     bias=True)

        # Initialize dendrite layer
        dendrite_layer = AbsoluteMaxGatingDendriticLayer2d(
            module=conv_layer,
            num_segments=3,
            dim_context=4,
            module_sparsity=0.7,
            dendrite_sparsity=0.9,
            dendrite_bias=False,
        )

        # Set weights and biases of convolutional layer
        conv_layer.weight.data[:] = torch.tensor(
            [[[[0.0000, 0.3105], [-0.1523, 0.0000]],
              [[0.0000, 0.0083], [-0.2167, 0.0483]]],
             [[[0.1621, 0.0000], [-0.3283, 0.0101]],
              [[-0.1045, 0.0261], [0.0000, 0.0000]]],
             [[[0.0000, -0.0968], [0.0499, 0.0000]],
              [[0.0850, 0.0000], [0.2646, -0.3485]]]],
            requires_grad=True)
        conv_layer.bias.data[:] = torch.tensor([-0.2027, -0.1821, 0.2152],
                                               requires_grad=True)

        # Dendrite weights: num_channels=3, num_segments=3, dim_context=4
        dendrite_layer.segments.weights.data[:] = torch.tensor(
            [[[-0.4933, 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.3805, 0.0000],
              [0.0000, 0.0000, 0.0000, -0.1641]],
             [[0.0000, 0.0000, 0.0000, 0.3555],
              [0.0000, 0.0000, 0.0000, 0.1892],
              [0.0000, 0.0000, -0.4274, 0.0000]],
             [[0.0000, 0.0000, 0.0000, 0.0957],
              [0.0000, 0.0000, -0.0689, 0.0000],
              [0.0000, 0.0000, 0.0000, -0.3192]]])

        # Input to dendrite layer: batch_size=2, num_channels=2, height=3, width=3
        x = torch.tensor([[[[0.1553, 0.3405, 0.2367], [0.7661, 0.1383, 0.6675],
                            [0.6464, 0.1559, 0.9777]],
                           [[0.4114, 0.6362, 0.7020], [0.2617, 0.2275, 0.4238],
                            [0.6374, 0.8270, 0.7528]]],
                          [[[0.8331, 0.7792, 0.4369], [0.7947, 0.2609, 0.1992],
                            [0.1527, 0.3006, 0.5496]],
                           [[0.6811, 0.6871, 0.0148], [0.6084, 0.8351, 0.5382],
                            [0.7421, 0.8639, 0.7444]]]])

        # Context input to dendrite layer: batch_size=2, dim_context=4
        context_vectors = torch.tensor([[1.0, 0.0, 1.0, 0.0],
                                        [0.0, 1.0, 0.0, 0.0]])

        # Expected absolute max dendrite activations (pre-sigmoid):
        # [[-0.4933 -0.4274 -0.0689]
        #  [   zero    zero    zero]]

        # Expected output of convolutional layer:
        #
        # batch item 1 (each row corresponds to an output channel)
        # [[[-0.2541 -0.1733] [-0.3545 -0.1585]]
        #  [[-0.4334 -0.2137] [-0.2900 -0.2137]]
        #  [[ 0.2454  0.1658] [ 0.1368  0.1342]]]
        #
        # batch item 2
        # [[[-0.1676 -0.2616] [-0.2571 -0.3334]]
        #  [[-0.3586 -0.2108] [-0.1422 -0.3062]]
        #  [[ 0.1073  0.2777] [ 0.1446  0.2511]]]

        # Overall expected output of dendrite layer:
        #
        # batch item 1 (each row corresponds to an output channel)
        # [[[-0.0963335  -0.06570089] [-0.13439679 -0.06008996]]
        #  [[-0.17108351 -0.08435751] [-0.11447673 -0.08435751]]
        #  [[ 0.11847466  0.08004522] [ 0.06604455  0.06478932]]]
        #
        # batch item 2
        # [[[-0.0838 -0.1308] [-0.1285 -0.1667]]
        #  [[-0.1793 -0.1054] [-0.0711 -0.1531]]
        #  [[ 0.0536  0.1389] [ 0.0723  0.1256]]]

        expected_output = torch.tensor([[[[-0.0963335, -0.06570089],
                                          [-0.13439679, -0.06008996]],
                                         [[-0.17108351, -0.08435751],
                                          [-0.11447673, -0.08435751]],
                                         [[0.11847466, 0.08004522],
                                          [0.06604455, 0.06478932]]],
                                        [[[-0.0838, -0.1308],
                                          [-0.1285, -0.1667]],
                                         [[-0.1793, -0.1054],
                                          [-0.0711, -0.1531]],
                                         [[0.0536, 0.1389],
                                          [0.0723, 0.1256]]]])

        actual_output = dendrite_layer(x, context_vectors)
        self.assertTrue(
            torch.allclose(expected_output, actual_output, atol=1e-4))
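
A quick spot check of the arithmetic behind `expected_output`: each channel of the convolution output (listed in the comments above) is scaled by the sigmoid of that channel's absolute max dendrite activation. For batch item 1, output channel 0:

import torch

# Values copied from the comments above (batch item 1, output channel 0).
conv_out = torch.tensor([[-0.2541, -0.1733], [-0.3545, -0.1585]])
abs_max_activation = torch.tensor(-0.4933)

gated = conv_out * torch.sigmoid(abs_max_activation)
print(gated)  # ≈ [[-0.0963, -0.0657], [-0.1344, -0.0601]], matching expected_output

# For batch item 2 every winning activation is zero, so sigmoid(0) = 0.5 and the
# expected output is half of the convolution output.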