示例#1
0
    def forward(self, features, centres, centre_labels, neighbours=None):
        """Return per-class log-probabilities from weighted Gaussian kernels.

        For every example, each neighbouring kernel centre contributes
        ``exp(-gaussian_constant * ||x - c||^2) * exp(weight[c])`` to the
        class of that centre.  Per-class sums are floored at 1e-10 (to avoid
        division by zero and log(0)), normalised to sum to one per example,
        and returned in log space.

        Args:
            features: Input features; assumed shape (batch, dims) — TODO
                confirm against callers.
            centres: Kernel centres, indexed by the neighbour indices;
                assumed shape (num_centres, dims).
            centre_labels: Class label of each centre (integer tensor).
            neighbours: Optional integer tensor of shape (batch, k) with, for
                each example, the indices of its neighbouring centres.  When
                ``None`` it is computed with ``find_neighbours``.

        Returns:
            Tensor of shape (batch, num_classes) holding log-probabilities.
            Its dtype follows the kernel activations, so float64 inputs are
            supported instead of being forced through a float32 buffer.
        """
        # Tensor with [0, 1, ..., num_classes - 1], created on the target
        # device directly (no host allocation + copy).
        classes = torch.arange(self.num_classes, device=features.device)

        # Per-kernel weights (self.weight is stored in log space).
        kernel_weights = torch.exp(self.weight)

        # If no neighbours were supplied (test phase), find them now.
        if neighbours is None:
            neighbours = find_neighbours(self.num_neighbours, centres,
                                         features)

        # Squared Euclidean distance from each example to each of its
        # neighbours: (B, 1, D) - (B, K, D) -> summed over D -> (B, K).
        d = torch.pow(features.unsqueeze(1) - centres[neighbours], 2).sum(2)

        # Weighted Gaussian activation of every neighbouring kernel, (B, K).
        d = torch.exp(-d * self.gaussian_constant) * kernel_weights[neighbours]

        # Labels of the neighbouring centres, reshaped to match neighbours.
        neighbour_labels = centre_labels[neighbours].view(neighbours.size())

        # (B, K, C) mask: neighbour k of example b belongs to class c.
        is_class = neighbour_labels.unsqueeze(2) == classes.view(1, 1, -1)

        # Sum the per-class influence of the neighbours -> (B, C).  torch.where
        # (rather than multiplying by the mask) keeps non-matching entries at
        # exactly zero with zero gradient, as the original masked assignment did.
        p = torch.where(is_class, d.unsqueeze(2), d.new_zeros(())).sum(1)

        # Avoid divide by zero and log(0) for classes with no support.
        p = p.masked_fill(p == 0, 1e-10)

        # Normalise each row to a distribution, then convert to log-probs.
        p = p / p.sum(1, keepdim=True)
        return torch.log(p)
示例#2
0
import tests
from neighbours import find_neighbours
import unittest

if __name__ == '__main__':
    # Run every test in the `tests` module with verbose (per-test) output.
    suite = unittest.TestLoader().loadTestsFromModule(tests)
    unittest.TextTestRunner(verbosity=2).run(suite)

    # Then demonstrate find_neighbours on a small hand-written example.
    a = [1.5, 2.0, 4.2, 3.0]
    b = [1.0, 4.0, 3.2]
    neighbours, distances = find_neighbours(a, b)
    # Every line uses the same "Label: value" format.
    print("List: " + str(a))
    print("Reference list: " + str(b))
    print("Neighbours: " + str(neighbours))
    print("Distances: " + str(distances))
示例#3
0
 def test_find_neighbours_high_nth_neighbour(self):
     # Request the 10th neighbour; the expectation pins the result for a
     # large n (presumably larger than self.B — confirm against the fixture).
     expected_indices = [2, 0, 0, 0, 0]
     expected_distances = [1.75, 2.0, 2.5, 3.5, 4.2]
     outcome = find_neighbours(self.A, self.B, n=10)
     self.assertEqual(outcome, (expected_indices, expected_distances))
示例#4
0
 def test_find_neighbours_epsilon(self):
     # epsilon=1.0: every expected distance below is at least 1.0, so epsilon
     # presumably acts as a minimum distance — confirm in find_neighbours.
     indices, distances = [2, 0, 1, 2, 2], [1.75, 2.0, 1.5, 1.5, 2.2]
     self.assertEqual(find_neighbours(self.A, self.B, epsilon=1.0),
                      (indices, distances))
示例#5
0
 def test_find_neighbours_normal_nth_neighbour(self):
     # n=2: presumably the second-nearest neighbour in self.B for each
     # element of self.A.
     want = ([1, 1, 1, 1, 1], [0.75, 1.0, 1.5, 2.5, 3.2])
     got = find_neighbours(self.A, self.B, n=2)
     self.assertEqual(got, want)
示例#6
0
 def test_find_neighbours_default(self):
     # Call with all keyword arguments left at their defaults.
     expected = (
         [0, 2, 2, 2, 2],
         [0.25, 0.0, 0.5, 1.5, 2.2],
     )
     actual = find_neighbours(self.A, self.B)
     self.assertEqual(actual, expected)
示例#7
0
 def test_find_neighbours_skip_identical(self):
     # Search self.B against itself with skip_identical=True, so (presumably)
     # no element is matched to an identical value.
     neighbour_idx = [1, 0, 1]
     neighbour_dist = [1.0, 1.0, 1.0]
     actual = find_neighbours(self.B, self.B, skip_identical=True)
     self.assertEqual(actual, (neighbour_idx, neighbour_dist))
示例#8
0
 def test_no_neighbours_with_skip_identical(self):
     # All values identical and identical matches skipped: the expected
     # result is a pair of empty lists.  Two distinct list objects are
     # passed, as before.
     no_match = ([], [])
     self.assertEqual(
         find_neighbours([1.0] * 3, [1.0] * 3, skip_identical=True),
         no_match)
示例#9
0
 def test_no_neighbours_with_epsilon(self):
     # epsilon=10.0 is expected to leave no neighbours at all.
     outcome = find_neighbours(self.A, self.B, epsilon=10.0)
     self.assertEqual(outcome, ([], []))
示例#10
0
 def test_identical_and_epsilon_parameter(self):
     # Passing skip_identical together with epsilon must raise ValueError.
     conflicting = {'skip_identical': True, 'epsilon': self.epsilon}
     with self.assertRaises(ValueError):
         find_neighbours(self.A, self.B, **conflicting)
示例#11
0
 def test_empty_reference_list(self):
     # An empty reference list must raise ValueError.
     self.assertRaises(ValueError, find_neighbours, self.A, [])
示例#12
0

"""
Training
"""
print("Begin training...")
for epoch in range(args.max_epochs):  # loop over the dataset multiple times
	
	#Update stored kernel centres
	if (epoch % args.update_interval) == 0:

		print("Updating kernel centres...")
		centres = update_centres()
		print("Finding training set neighbours...")
		centres = centres.cpu()
		neighbours_tr = find_neighbours( num_neighbours, centres )
		centres = centres.to(device)
		print("Finished update!")

		if epoch > 0:
			save_model()
	
	#Training
	running_loss = 0.0
	running_correct = 0
	for i, data in enumerate(train_loader, 0):
		
		# Get the inputs; data is a list of [inputs, labels]. Send to GPU
		inputs, labels, indices = data
		inputs  = inputs.to(device)
		labels  = labels.to(device).view(-1)