def testWriteRead(self):
    """Round-trip a trained classifier through KNNClassifierProto.

    Four sparse patterns are learned (two without a partition id, two with
    one), the classifier is written to a proto message, the message is
    written to a temporary file and read back, and a second classifier is
    built from it. Both classifiers must then produce the same inference
    result for the first pattern and report the same partition id list.
    """
    original = KNNClassifier(
        distanceMethod="norm",
        numSVDDims=2,
        numSVDSamples=2,
        useSparseMemory=True,
        minSparsity=0.1,
        distThreshold=0.1,
    )

    dimensionality = 40
    patternA = np.array([1, 3, 7, 11, 13, 17, 19, 23, 29], dtype=np.int32)
    patternB = np.array([2, 4, 8, 12, 14, 18, 20, 28, 30], dtype=np.int32)
    patternC = np.array([1, 2, 3, 14, 16, 19, 22, 24, 33], dtype=np.int32)
    patternD = np.array([2, 4, 8, 12, 14, 19, 22, 24, 33], dtype=np.int32)

    # (active indices, category, partition id), learned in this order.
    trainingSet = (
        (patternA, 0, None),
        (patternB, 1, None),
        (patternC, 2, 211),
        (patternD, 1, 405),
    )
    for activeIndices, category, partition in trainingSet:
        original.learn(activeIndices, category, isSparse=dimensionality,
                       partitionId=partition)
    original.finishLearning()

    message = KNNClassifierProto.new_message()
    original.write(message)

    # Serialize to a scratch file, rewind, and rebuild a classifier from it.
    with tempfile.TemporaryFile() as stream:
        message.write(stream)
        stream.seek(0)
        messageReadBack = KNNClassifierProto.read(stream)
        restored = KNNClassifier.read(messageReadBack)

    # Dense encoding of the first pattern, used as the probe input.
    denseA = np.zeros(dimensionality)
    denseA[patternA] = 1.0

    expected = original.infer(denseA)
    actual = restored.infer(denseA)

    # Element 0 is compared for equality; elements 1-3 are compared as
    # collections, ignoring order.
    self.assertEqual(expected[0], actual[0])
    for position in (1, 2, 3):
        self.assertItemsEqual(expected[position], actual[position])

    self.assertItemsEqual(original.getPartitionIdList(),
                          restored.getPartitionIdList())
def testGetPartitionId(self):
    """
    Test a sequence of calls to KNN to ensure we can retrieve partition Id:
      - We first learn on some patterns (including one pattern with no
        partitionId in the middle) and test that we can retrieve Ids.
      - We then invoke inference and then check partitionId again.
      - We check incorrect indices to ensure we get an exception.
      - We check the case where the partitionId to be ignored is not in
        the list.
      - We learn on one more pattern and check partitionIds again
      - We remove rows and ensure partitionIds still work
    """
    params = {"distanceMethod": "rawOverlap"}
    classifier = KNNClassifier(**params)

    dimensionality = 40
    a = np.array([1, 3, 7, 11, 13, 17, 19, 23, 29], dtype=np.int32)
    b = np.array([2, 4, 8, 12, 14, 18, 20, 28, 30], dtype=np.int32)
    c = np.array([1, 2, 3, 14, 16, 19, 22, 24, 33], dtype=np.int32)
    d = np.array([2, 4, 8, 12, 14, 19, 22, 24, 33], dtype=np.int32)
    e = np.array([1, 3, 7, 12, 14, 19, 22, 24, 33], dtype=np.int32)

    # Dense encoding of pattern `a`, used as the inference input below.
    denseA = np.zeros(dimensionality)
    denseA[a] = 1.0

    # Note that patterns 0 and 3 share partition 433 and pattern 2 has none.
    classifier.learn(a, 0, isSparse=dimensionality, partitionId=433)
    classifier.learn(b, 1, isSparse=dimensionality, partitionId=213)
    classifier.learn(c, 1, isSparse=dimensionality, partitionId=None)
    classifier.learn(d, 1, isSparse=dimensionality, partitionId=433)

    self.assertEqual(classifier.getPartitionId(0), 433)
    self.assertEqual(classifier.getPartitionId(1), 213)
    self.assertIsNone(classifier.getPartitionId(2))
    self.assertEqual(classifier.getPartitionId(3), 433)

    cat, _, _, _ = classifier.infer(denseA, partitionId=213)
    self.assertEqual(cat, 0)

    # Test with partitionId not in classifier
    cat, _, _, _ = classifier.infer(denseA, partitionId=666)
    self.assertEqual(cat, 0)

    # Partition Ids should be maintained after inference
    self.assertEqual(classifier.getPartitionId(0), 433)
    self.assertEqual(classifier.getPartitionId(1), 213)
    self.assertIsNone(classifier.getPartitionId(2))
    self.assertEqual(classifier.getPartitionId(3), 433)

    # Should raise an exception if we go out of bounds
    with self.assertRaises(RuntimeError):
        classifier.getPartitionId(4)
    with self.assertRaises(RuntimeError):
        classifier.getPartitionId(-1)

    # Learn again
    classifier.learn(e, 4, isSparse=dimensionality, partitionId=413)
    self.assertEqual(classifier.getPartitionId(4), 413)

    # Test getPatternIndicesWithPartitionId
    self.assertItemsEqual(classifier.getPatternIndicesWithPartitionId(433),
                          [0, 3])
    self.assertItemsEqual(classifier.getPatternIndicesWithPartitionId(666),
                          [])
    self.assertItemsEqual(classifier.getPatternIndicesWithPartitionId(413),
                          [4])

    self.assertEqual(classifier.getNumPartitionIds(), 3)

    # Check that the full set of partition ids is what we expect. The
    # pattern learned without a partition id shows up as np.inf in the list
    # and contributes no key.
    self.assertItemsEqual(classifier.getPartitionIdList(),
                          [433, 213, np.inf, 433, 413])
    self.assertItemsEqual(classifier.getPartitionIdKeys(), [433, 413, 213])

    # Remove two rows - all indices shift down
    self.assertEqual(classifier._removeRows([0, 2]), 2)
    self.assertItemsEqual(classifier.getPatternIndicesWithPartitionId(433),
                          [1])
    self.assertItemsEqual(classifier.getPatternIndicesWithPartitionId(413),
                          [2])

    # Remove another row and check number of partitions have decreased
    classifier._removeRows([0])
    self.assertEqual(classifier.getNumPartitionIds(), 2)

    # Check that the full set of partition ids is what we expect
    self.assertItemsEqual(classifier.getPartitionIdList(), [433, 413])
    self.assertItemsEqual(classifier.getPartitionIdKeys(), [433, 413])