def test_apply_gating_dendrites(self):
    """
    Check `apply_dendrites` in isolation against hand-computed values.

    Per the expected values below, every output channel is scaled by the
    sigmoid of that channel's dendrite activation with the largest absolute
    value (e.g. 0.3 * sigmoid(0.9) ~= 0.2133 for the very first element).
    """
    dendrite_layer = AbsoluteMaxGatingDendriticLayer2d(
        module=torch.nn.Conv2d(
            in_channels=1, out_channels=3, kernel_size=3, stride=1, bias=True
        ),
        num_segments=20,
        dim_context=15,
        module_sparsity=0.7,
        dendrite_sparsity=0.9,
        dendrite_bias=False,
    )

    # Stand-in module output: (batch=2, channels=3, height=2, width=2).
    y = torch.tensor(
        [
            [
                [[0.3, 0.4], [-0.2, 0.1]],
                [[-0.3, 0.5], [-0.1, 0.1]],
                [[0.0, 0.1], [0.3, 0.2]],
            ],
            [
                [[0.1, -0.2], [-0.2, 0.1]],
                [[0.0, 0.1], [-0.4, -0.1]],
                [[-0.3, 0.0], [0.2, 0.4]],
            ],
        ]
    )

    # Stand-in dendrite activations: (batch=2, channels=3, segments=3).
    # NOTE(review): 3 segments here vs. num_segments=20 on the layer; the
    # test relies on apply_dendrites accepting these as given — confirm.
    dendrite_activations = torch.tensor(
        [
            [[0.4, 0.9, -0.1], [-0.8, 0.7, 0.0], [0.6, -0.6, -0.7]],
            [[0.2, 0.8, 0.8], [-0.1, -0.4, 0.5], [0.0, 0.0, 0.0]],
        ]
    )

    # Largest-magnitude activation per channel (before the sigmoid):
    #   batch 0: [ 0.9, -0.8, -0.7]
    #   batch 1: [ 0.8,  0.5,  0.0]
    expected_output = torch.tensor(
        [
            [
                [[0.2133, 0.2844], [-0.1422, 0.0711]],
                [[-0.093, 0.155], [-0.031, 0.031]],
                [[0.0, 0.0332], [0.0995, 0.0664]],
            ],
            [
                [[0.069, -0.138], [-0.138, 0.069]],
                [[0.0, 0.0622], [-0.249, -0.0622]],
                [[-0.15, 0.0], [0.1, 0.2]],
            ],
        ]
    )

    actual_output = dendrite_layer.apply_dendrites(y, dendrite_activations)
    self.assertTrue(torch.allclose(expected_output, actual_output, atol=1e-4))
def test_gradients(self):
    """
    Ensure dendrite gradients are flowing through the layer
    `AbsoluteMaxGatingDendriticLayer2d`.

    Only the pattern of zero vs. non-zero entries in the segment-weight
    gradient is checked, not the gradient values. For each output channel,
    the segment whose activation has the largest absolute value is expected
    to receive gradient, and only at the context positions that are non-zero.
    """
    conv_layer = torch.nn.Conv2d(
        in_channels=2, out_channels=3, kernel_size=2, stride=1, bias=True
    )
    dendrite_layer = AbsoluteMaxGatingDendriticLayer2d(
        module=conv_layer,
        num_segments=3,
        dim_context=4,
        module_sparsity=0.7,
        dendrite_sparsity=0.9,
        dendrite_bias=False,
    )

    # Dendrite weights: (num_channels=3, num_segments=3, dim_context=4).
    dendrite_layer.segments.weights.data[:] = torch.tensor(
        [
            [
                [-0.4933, 0.0000, 0.0000, 0.0000],
                [0.0000, 0.0000, 0.3805, 0.0000],
                [0.0000, 0.0000, 0.0000, -0.1641],
            ],
            [
                [0.0000, 0.0000, 0.0000, 0.3555],
                [0.0000, 0.0000, 0.0000, 0.1892],
                [0.0000, 0.0000, -0.4274, 0.0000],
            ],
            [
                [0.0000, 0.0000, 0.0000, 0.0957],
                [0.0000, 0.0000, -0.0689, 0.0000],
                [0.0000, 0.0000, 0.0000, -0.3192],
            ],
        ]
    )

    # Input to the dendrite layer: (batch=1, channels=2, height=3, width=3).
    # Drawn from a private, seeded generator so the test is reproducible
    # without consuming or reseeding the global RNG shared with other tests.
    generator = torch.Generator().manual_seed(0)
    x = torch.randn((1, 2, 3, 3), generator=generator)

    # Context input: (batch=1, dim_context=4).
    context_vectors = torch.tensor([[1.0, 0.0, 1.0, 0.0]])

    # Expected dendrite activations (each row is an output channel):
    #   [[-0.4933  0.3805  zero   ]
    #    [ zero    zero   -0.4274 ]
    #    [ zero   -0.0689  zero   ]]
    # so the absolute-max ("winning") segment per channel is 0, 2 and 1:
    #   [[1 0 0]
    #    [0 0 1]
    #    [0 1 0]]

    output = dendrite_layer(x, context_vectors)
    output.sum().backward()

    # The activations above are plain dot products with the context, so a
    # winning segment's gradient is presumably proportional to the context
    # vector: non-zero exactly where the context is 1.
    expected_grad_mask = torch.tensor(
        [
            [
                [1.0, 0.0, 1.0, 0.0],
                [0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0],
            ],
            [
                [0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0],
                [1.0, 0.0, 1.0, 0.0],
            ],
            [
                [0.0, 0.0, 0.0, 0.0],
                [1.0, 0.0, 1.0, 0.0],
                [0.0, 0.0, 0.0, 0.0],
            ],
        ]
    )

    grad = dendrite_layer.segments.weights.grad
    # Fail loudly if no gradient reached the segment weights at all, rather
    # than via a confusing comparison against `None`.
    self.assertIsNotNone(grad)
    actual_grad_mask = 1.0 * (grad != 0.0)
    self.assertTrue(torch.equal(expected_grad_mask, actual_grad_mask))
def test_forward(self):
    """
    Check the values produced by `forward` for fixed convolution parameters,
    fixed dendrite segment weights, and two different context vectors.
    """
    # Fixed convolution parameters: weight is (out=3, in=2, 2, 2), bias is (3,).
    conv_weight = torch.tensor(
        [
            [[[0.0000, 0.3105], [-0.1523, 0.0000]],
             [[0.0000, 0.0083], [-0.2167, 0.0483]]],
            [[[0.1621, 0.0000], [-0.3283, 0.0101]],
             [[-0.1045, 0.0261], [0.0000, 0.0000]]],
            [[[0.0000, -0.0968], [0.0499, 0.0000]],
             [[0.0850, 0.0000], [0.2646, -0.3485]]],
        ]
    )
    conv_bias = torch.tensor([-0.2027, -0.1821, 0.2152])

    # Fixed dendrite weights: (num_channels=3, num_segments=3, dim_context=4).
    segment_weights = torch.tensor(
        [
            [
                [-0.4933, 0.0000, 0.0000, 0.0000],
                [0.0000, 0.0000, 0.3805, 0.0000],
                [0.0000, 0.0000, 0.0000, -0.1641],
            ],
            [
                [0.0000, 0.0000, 0.0000, 0.3555],
                [0.0000, 0.0000, 0.0000, 0.1892],
                [0.0000, 0.0000, -0.4274, 0.0000],
            ],
            [
                [0.0000, 0.0000, 0.0000, 0.0957],
                [0.0000, 0.0000, -0.0689, 0.0000],
                [0.0000, 0.0000, 0.0000, -0.3192],
            ],
        ]
    )

    conv_layer = torch.nn.Conv2d(
        in_channels=2, out_channels=3, kernel_size=2, stride=1, bias=True
    )
    dendrite_layer = AbsoluteMaxGatingDendriticLayer2d(
        module=conv_layer,
        num_segments=3,
        dim_context=4,
        module_sparsity=0.7,
        dendrite_sparsity=0.9,
        dendrite_bias=False,
    )

    # Replace the randomly initialized parameters with the fixed values.
    conv_layer.weight.data[:] = conv_weight
    conv_layer.bias.data[:] = conv_bias
    dendrite_layer.segments.weights.data[:] = segment_weights

    # Input: (batch=2, channels=2, height=3, width=3).
    x = torch.tensor(
        [
            [[[0.1553, 0.3405, 0.2367], [0.7661, 0.1383, 0.6675], [0.6464, 0.1559, 0.9777]],
             [[0.4114, 0.6362, 0.7020], [0.2617, 0.2275, 0.4238], [0.6374, 0.8270, 0.7528]]],
            [[[0.8331, 0.7792, 0.4369], [0.7947, 0.2609, 0.1992], [0.1527, 0.3006, 0.5496]],
             [[0.6811, 0.6871, 0.0148], [0.6084, 0.8351, 0.5382], [0.7421, 0.8639, 0.7444]]],
        ]
    )
    # Context: (batch=2, dim_context=4).
    context_vectors = torch.tensor([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 0.0]])

    # Absolute-max dendrite activation per channel (before the sigmoid):
    #   batch 0: [-0.4933, -0.4274, -0.0689]
    #   batch 1: [ zero,    zero,    zero  ]   (every weight at index 1 is 0)
    #
    # Raw convolution output (rows are output channels):
    #   batch 0
    #     [[-0.2541 -0.1733] [-0.3545 -0.1585]]
    #     [[-0.4334 -0.2137] [-0.2900 -0.2137]]
    #     [[ 0.2454  0.1658] [ 0.1368  0.1342]]
    #   batch 1
    #     [[-0.1676 -0.2616] [-0.2571 -0.3334]]
    #     [[-0.3586 -0.2108] [-0.1422 -0.3062]]
    #     [[ 0.1073  0.2777] [ 0.1446  0.2511]]
    #
    # Expected layer output = conv output scaled per channel by the sigmoid
    # of the activation above (a uniform factor of 0.5 for batch 1).
    expected_batch_0 = torch.tensor(
        [
            [[-0.0963335, -0.06570089], [-0.13439679, -0.06008996]],
            [[-0.17108351, -0.08435751], [-0.11447673, -0.08435751]],
            [[0.11847466, 0.08004522], [0.06604455, 0.06478932]],
        ]
    )
    expected_batch_1 = torch.tensor(
        [
            [[-0.0838, -0.1308], [-0.1285, -0.1667]],
            [[-0.1793, -0.1054], [-0.0711, -0.1531]],
            [[0.0536, 0.1389], [0.0723, 0.1256]],
        ]
    )
    expected_output = torch.stack([expected_batch_0, expected_batch_1])

    actual_output = dendrite_layer(x, context_vectors)
    self.assertTrue(torch.allclose(expected_output, actual_output, atol=1e-4))