Пример #1
0
    def test_sparsity(self):
        """
        Validate that sparsity is enforced per unit per segment.

        Every weight is first set to a non-zero value, so any zero observed
        after ``rezero_weights()`` must come from the sparsity mask. Each
        (unit, segment) slice must then have exactly the requested fraction
        of zeros.
        """
        sparsity = 9 / 15
        dendrite_segments = DendriteSegments(num_units=10,
                                             num_segments=20,
                                             dim_context=15,
                                             sparsity=sparsity,
                                             bias=True)

        weights = dendrite_segments.weights
        # Write through ``.data``: if ``weights`` is a leaf tensor that
        # requires grad (e.g. an nn.Parameter), an in-place slice assignment
        # on it directly is rejected by autograd.
        weights.data[:] = 1
        dendrite_segments.rezero_weights()

        for unit in range(dendrite_segments.num_units):
            for segment in range(dendrite_segments.num_segments):
                w = weights[unit, segment, :]
                num_off = (w == 0).sum().item()
                actual_sparsity = num_off / w.numel()
                self.assertEqual(
                    sparsity,
                    actual_sparsity,
                    f"Sparsity {actual_sparsity} != {sparsity} "
                    f"for unit {unit} and segment {segment}",
                )
Пример #2
0
    def __init__(self, input_size, output_size,
                 hidden_sizes=(10, 10, 10),
                 dim_context=2,
                 num_segments=(5, 5, 5),
                 module_sparsity=(.75, .75, .75),
                 dendrite_sparsity=(.5, .5, .5),
                 dendrite_bias=(False, False, False),
                 activation_fn=nn.ReLU):
        """
        Build three Linear -> BatchNorm1d -> activation blocks, each paired
        with its own ``DendriteSegments`` module, followed by a linear
        classifier.

        The per-layer arguments (``hidden_sizes``, ``num_segments``,
        ``dendrite_sparsity``, ``dendrite_bias``) are indexed at positions
        0, 1 and 2, so each needs three entries. The ``hidden_sizes``
        default used to be ``(10, 10)``, which made ``hidden_sizes[2]``
        raise IndexError whenever the default was used.

        :param input_size: number of input features of the first block
        :param output_size: number of outputs of the final classifier
        :param hidden_sizes: output width of each of the three blocks; also
                             the ``num_units`` of the matching dendrite
                             segments
        :param dim_context: context dimensionality shared by all three
                            ``DendriteSegments`` modules
        :param num_segments: number of dendrite segments per layer
        :param module_sparsity: not read by this constructor
        :param dendrite_sparsity: dendrite weight sparsity per layer
        :param dendrite_bias: whether each layer's dendrite segments use a bias
        :param activation_fn: callable taking no arguments that returns the
                              activation module (e.g. ``nn.ReLU``)
        """
        super().__init__()

        self.input_size = input_size

        self.block0 = nn.Sequential(
            nn.Linear(input_size, hidden_sizes[0]),
            nn.BatchNorm1d(hidden_sizes[0], affine=False),
            activation_fn()
        )

        self.segments0 = DendriteSegments(
            num_units=hidden_sizes[0],
            num_segments=num_segments[0],
            dim_context=dim_context,
            sparsity=dendrite_sparsity[0],
            bias=dendrite_bias[0],
        )

        self.block1 = nn.Sequential(
            nn.Linear(hidden_sizes[0], hidden_sizes[1]),
            nn.BatchNorm1d(hidden_sizes[1], affine=False),
            activation_fn()
        )

        self.segments1 = DendriteSegments(
            num_units=hidden_sizes[1],
            num_segments=num_segments[1],
            dim_context=dim_context,
            sparsity=dendrite_sparsity[1],
            bias=dendrite_bias[1],
        )

        self.block2 = nn.Sequential(
            nn.Linear(hidden_sizes[1], hidden_sizes[2]),
            nn.BatchNorm1d(hidden_sizes[2], affine=False),
            activation_fn()
        )

        self.segments2 = DendriteSegments(
            num_units=hidden_sizes[2],
            num_segments=num_segments[2],
            dim_context=dim_context,
            sparsity=dendrite_sparsity[2],
            bias=dendrite_bias[2],
        )

        self.classifier = nn.Linear(hidden_sizes[-1], output_size)
Пример #3
0
    def test_forward(self):
        """Validate that forward output is (batch, num_units, num_segments)."""
        num_units, num_segments = 10, 20
        batch_size = 8

        dendrite_segments = DendriteSegments(
            num_units=num_units,
            num_segments=num_segments,
            dim_context=15,
            sparsity=0.7,
            bias=True,
        )
        dendrite_segments.rezero_weights()

        context = torch.rand(batch_size, dendrite_segments.dim_context)
        out = dendrite_segments(context)

        self.assertEqual(out.shape, (batch_size, num_units, num_segments))
Пример #4
0
    def test_equivalent_forward(self):
        """
        Compare the batched forward pass with a per-unit reference.

        For every unit, its slice of the module output must match
        ``torch.nn.functional.linear`` applied to the context with that
        unit's segment weights and biases.
        """
        dendrite_segments = DendriteSegments(
            num_units=10,
            num_segments=20,
            dim_context=15,
            sparsity=0.7,
            bias=True,
        )

        batch_size = 8
        context = torch.rand(batch_size, dendrite_segments.dim_context)
        # shape batch_size x num_units x num_segments
        out = dendrite_segments(context)

        seg_weights = dendrite_segments.weights
        seg_biases = dendrite_segments.biases
        for idx in range(dendrite_segments.num_units):
            # Reference result: this unit's segments as an ordinary linear layer.
            reference = torch.nn.functional.linear(
                context, seg_weights[idx, ...], seg_biases[idx, ...]
            )
            observed = out[:, idx, :]
            self.assertTrue(
                torch.allclose(observed, reference, atol=1e-7),
                f"Didn't observe the expected output for unit {idx}: "
                f"actual_out - expected_out = {observed - reference}",
            )
    def __init__(self,
                 num_classes,
                 num_segments=10,
                 dendrite_sparsity=0.856,
                 dendrite_bias=True):
        """
        Build a conv "prediction" stack, a conv "modulation" stack, dendrite
        segments linking them, a linear classifier and a dendritic gate.

        :param num_classes: number of outputs of the linear classifier
        :param num_segments: number of dendrite segments per unit
        :param dendrite_sparsity: sparsity passed to ``DendriteSegments``
        :param dendrite_bias: whether ``DendriteSegments`` uses a bias
        """

        super().__init__()

        # Both conv stacks end in AdaptiveAvgPool2d(pooled_hw) + Flatten, so
        # their flattened width is (last conv channels) * pooled_h * pooled_w.
        # The sizes are derived here rather than hard-coded (2304 / 1008) so
        # the channel widths, pooled size and linear sizes cannot drift apart.
        # Assumes conv_block's second argument is its output channel count.
        pooled_hw = (3, 3)
        prediction_channels = 256
        modulation_channels = 112
        pooled_area = pooled_hw[0] * pooled_hw[1]
        prediction_features = prediction_channels * pooled_area  # 2304
        modulation_features = modulation_channels * pooled_area  # 1008

        # Submodules are created in the same order as before; registration
        # order determines parameter order, so keep it stable.
        self.segments = DendriteSegments(
            num_units=prediction_features,
            num_segments=num_segments,
            dim_context=modulation_features,
            sparsity=dendrite_sparsity,
            bias=dendrite_bias,
        )
        self.classifier = nn.Linear(prediction_features, num_classes)

        self.prediction = nn.Sequential(
            *self.conv_block(3, prediction_channels, 3, 1, 0),
            nn.MaxPool2d(kernel_size=2, stride=2),
            *self.conv_block(prediction_channels, prediction_channels, 3, 1, 0),
            nn.MaxPool2d(kernel_size=2, stride=2),
            *self.conv_block(prediction_channels, prediction_channels, 3, 1, 0),
            nn.AdaptiveAvgPool2d(output_size=pooled_hw),
            nn.Flatten(),
        )

        self.modulation = nn.Sequential(
            *self.conv_block(3, modulation_channels, 3, 1, 0),
            nn.MaxPool2d(kernel_size=2, stride=2),
            *self.conv_block(modulation_channels, modulation_channels, 3, 1, 0),
            nn.MaxPool2d(kernel_size=2, stride=2),
            *self.conv_block(modulation_channels, modulation_channels, 3, 1, 0),
            nn.AdaptiveAvgPool2d(output_size=pooled_hw),
            nn.Flatten(),
        )

        self.dendritic_gate = DendriticAbsoluteMaxGate1d()

        # Apply Kaiming initialization
        self.reset_params()