コード例 #1
0
  def testGetPartitionId(self):
    """
    Check partition id bookkeeping in KNNClassifier across learning,
    inference and row removal.

    Steps:
        - Learn four patterns (the third one with partitionId=None) and check
          that getPartitionId() returns the id given for each stored index.
        - Run inference while passing a partitionId to be ignored, once with
          an id that was learned (213) and once with one that was not (666),
          then check that the stored ids are unchanged.
        - Check that out-of-range indices raise RuntimeError.
        - Learn a fifth pattern and check the per-partition index lookup, the
          number of partition ids and the full partition id lists.
        - Remove rows and check that the lookups follow the shifted indices.
    """
    params = {"distanceMethod": "rawOverlap"}
    classifier = KNNClassifier(**params)

    # Each pattern is a list of "on" indices into a vector of this width.
    dimensionality = 40
    a = np.array([1, 3, 7, 11, 13, 17, 19, 23, 29], dtype=np.int32)
    b = np.array([2, 4, 8, 12, 14, 18, 20, 28, 30], dtype=np.int32)
    c = np.array([1, 2, 3, 14, 16, 19, 22, 24, 33], dtype=np.int32)
    d = np.array([2, 4, 8, 12, 14, 19, 22, 24, 33], dtype=np.int32)
    e = np.array([1, 3, 7, 12, 14, 19, 22, 24, 33], dtype=np.int32)

    # Dense form of pattern `a`, used as the inference input below.
    denseA = np.zeros(dimensionality)
    denseA[a] = 1.0

    # Patterns 0 and 3 share partition 433; pattern 2 has no partition id.
    classifier.learn(a, 0, isSparse=dimensionality, partitionId=433)
    classifier.learn(b, 1, isSparse=dimensionality, partitionId=213)
    classifier.learn(c, 1, isSparse=dimensionality, partitionId=None)
    classifier.learn(d, 1, isSparse=dimensionality, partitionId=433)

    self.assertEqual(classifier.getPartitionId(0), 433)
    self.assertEqual(classifier.getPartitionId(1), 213)
    self.assertIsNone(classifier.getPartitionId(2))
    self.assertEqual(classifier.getPartitionId(3), 433)

    cat, _, _, _ = classifier.infer(denseA, partitionId=213)
    self.assertEqual(cat, 0)

    # Test with a partitionId that was never learned by the classifier
    cat, _, _, _ = classifier.infer(denseA, partitionId=666)
    self.assertEqual(cat, 0)

    # Partition Ids should be maintained after inference
    self.assertEqual(classifier.getPartitionId(0), 433)
    self.assertEqual(classifier.getPartitionId(1), 213)
    self.assertIsNone(classifier.getPartitionId(2))
    self.assertEqual(classifier.getPartitionId(3), 433)

    # Should raise if we go out of bounds (only indices 0..3 exist so far)
    with self.assertRaises(RuntimeError):
      classifier.getPartitionId(4)
    with self.assertRaises(RuntimeError):
      classifier.getPartitionId(-1)

    # Learn again; the new pattern is stored at index 4
    classifier.learn(e, 4, isSparse=dimensionality, partitionId=413)
    self.assertEqual(classifier.getPartitionId(4), 413)

    # Test getPatternIndicesWithPartitionId, including an unknown id
    self.assertItemsEqual(classifier.getPatternIndicesWithPartitionId(433),
                          [0, 3])
    self.assertItemsEqual(classifier.getPatternIndicesWithPartitionId(666),
                          [])
    self.assertItemsEqual(classifier.getPatternIndicesWithPartitionId(413),
                          [4])

    # Three distinct ids were learned: 433, 213 and 413
    self.assertEqual(classifier.getNumPartitionIds(), 3)

    # Check that the full set of partition ids is what we expect. The
    # per-pattern list is expected to hold np.inf for the pattern that was
    # learned with partitionId=None.
    self.assertItemsEqual(classifier.getPartitionIdPerPattern(),
                          [433, 213, np.inf, 433, 413])
    self.assertItemsEqual(classifier.getPartitionIdList(), [433, 413, 213])

    # Remove two rows - all indices shift down
    # (remaining partition ids by index: 0 -> 213, 1 -> 433, 2 -> 413)
    self.assertEqual(classifier._removeRows([0, 2]), 2)
    self.assertItemsEqual(classifier.getPatternIndicesWithPartitionId(433),
                          [1])
    self.assertItemsEqual(classifier.getPatternIndicesWithPartitionId(413),
                          [2])

    # Remove another row (the only one in partition 213) and check that the
    # number of partitions has decreased
    classifier._removeRows([0])
    self.assertEqual(classifier.getNumPartitionIds(), 2)

    # Check that the full set of partition ids is what we expect
    self.assertItemsEqual(classifier.getPartitionIdPerPattern(), [433, 413])
    self.assertItemsEqual(classifier.getPartitionIdList(), [433, 413])