Example #1
class PatternMachineTest(unittest.TestCase):
    def setUp(self):
        self.patternMachine = PatternMachine(10000, 5, num=50)

    def testGet(self):
        patternA = self.patternMachine.get(48)
        self.assertEqual(len(patternA), 5)

        patternB = self.patternMachine.get(49)
        self.assertEqual(len(patternB), 5)

        self.assertEqual(patternA & patternB, set())

    def testGetOutOfBounds(self):
        args = [50]
        self.assertRaises(IndexError, self.patternMachine.get, *args)

    def testAddNoise(self):
        patternMachine = PatternMachine(10000, 1000, num=1)
        pattern = patternMachine.get(0)

        noisy = patternMachine.addNoise(pattern, 0.0)
        self.assertEqual(len(pattern & noisy), 1000)

        noisy = patternMachine.addNoise(pattern, 0.5)
        self.assertTrue(400 < len(pattern & noisy) < 600)

        noisy = patternMachine.addNoise(pattern, 1.0)
        self.assertTrue(50 < len(pattern & noisy) < 150)

    def testNumbersForBit(self):
        pattern = self.patternMachine.get(49)

        for bit in pattern:
            self.assertEqual(self.patternMachine.numbersForBit(bit), set([49]))

    def testNumbersForBitOutOfBounds(self):
        args = [10000]
        self.assertRaises(IndexError, self.patternMachine.numbersForBit, *args)

    def testNumberMapForBits(self):
        pattern = self.patternMachine.get(49)
        numberMap = self.patternMachine.numberMapForBits(pattern)

        self.assertEqual(list(numberMap.keys()), [49])
        self.assertEqual(numberMap[49], pattern)

    def testWList(self):
        w = [4, 7, 11]
        patternMachine = PatternMachine(100, w, num=50)
        widths = dict((el, 0) for el in w)

        for i in range(50):
            pattern = patternMachine.get(i)
            width = len(pattern)
            self.assertTrue(width in w)
            widths[len(pattern)] += 1

        for i in w:
            self.assertTrue(widths[i] > 0)
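
A minimal usage sketch of PatternMachine outside the test harness, mirroring the calls exercised above; the import path is an assumption based on the usual NuPIC package layout.

# Hypothetical sketch: only the import path is assumed; the calls mirror the tests above.
from nupic.data.generators.pattern_machine import PatternMachine  # assumed path

patternMachine = PatternMachine(10000, 5, num=50)  # 10000 bits, 5 active bits, 50 stored patterns
pattern = patternMachine.get(0)                    # a set of 5 active bit indices
noisy = patternMachine.addNoise(pattern, 0.5)      # replace roughly half of the active bits
print(len(pattern & noisy))                        # overlap with the original pattern
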
Example #2
  def testWriteRead(self):
    tm1 = TemporalMemory(
      columnDimensions=(100,),
      cellsPerColumn=4,
      activationThreshold=7,
      initialPermanence=0.37,
      connectedPermanence=0.58,
      minThreshold=4,
      maxNewSynapseCount=18,
      permanenceIncrement=0.23,
      permanenceDecrement=0.08,
      seed=91
    )

    # Run some data through before serializing
    patternMachine = PatternMachine(100, 4)
    sequenceMachine = SequenceMachine(patternMachine)
    sequence = sequenceMachine.generateFromNumbers(range(5))
    for _ in range(3):
      for pattern in sequence:
        tm1.compute(pattern)

    proto1 = TemporalMemoryProto_capnp.TemporalMemoryProto.new_message()
    tm1.write(proto1)

    # Write the proto to a temp file and read it back into a new proto
    with tempfile.TemporaryFile() as f:
      proto1.write(f)
      f.seek(0)
      proto2 = TemporalMemoryProto_capnp.TemporalMemoryProto.read(f)

    # Load the deserialized proto
    tm2 = TemporalMemory.read(proto2)

    # Check that the two temporal memory objects have the same attributes
    self.assertEqual(tm1, tm2)
    # Run a couple records through after deserializing and check results match
    tm1.compute(patternMachine.get(0))
    tm2.compute(patternMachine.get(0))
    self.assertEqual(set(tm1.getActiveCells()), set(tm2.getActiveCells()))
    self.assertEqual(set(tm1.getPredictiveCells()),
                     set(tm2.getPredictiveCells()))
    self.assertEqual(set(tm1.getWinnerCells()), set(tm2.getWinnerCells()))
    self.assertEqual(tm1.connections, tm2.connections)

    tm1.compute(patternMachine.get(3))
    tm2.compute(patternMachine.get(3))
    self.assertEqual(set(tm1.getActiveCells()), set(tm2.getActiveCells()))
    self.assertEqual(set(tm1.getPredictiveCells()),
                     set(tm2.getPredictiveCells()))
    self.assertEqual(set(tm1.getWinnerCells()), set(tm2.getWinnerCells()))
    self.assertEqual(tm1.connections, tm2.connections)
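
The write/read round trip above can be factored into a small reusable helper; this sketch uses only the calls that already appear in the test.

def roundTripTemporalMemory(tm):
  # Serialize tm into a cap'n proto message, pass it through a temporary file,
  # and return a new TemporalMemory deserialized from that message.
  proto = TemporalMemoryProto_capnp.TemporalMemoryProto.new_message()
  tm.write(proto)
  with tempfile.TemporaryFile() as f:
    proto.write(f)
    f.seek(0)
    restored = TemporalMemoryProto_capnp.TemporalMemoryProto.read(f)
  return TemporalMemory.read(restored)
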
Example #3
    def testWriteRead(self):
        tm1 = TemporalMemory(columnDimensions=[100],
                             cellsPerColumn=4,
                             activationThreshold=7,
                             initialPermanence=0.37,
                             connectedPermanence=0.58,
                             minThreshold=4,
                             maxNewSynapseCount=18,
                             permanenceIncrement=0.23,
                             permanenceDecrement=0.08,
                             seed=91)

        # Run some data through before serializing
        patternMachine = PatternMachine(100, 4)
        sequenceMachine = SequenceMachine(patternMachine)
        sequence = sequenceMachine.generateFromNumbers(range(5))
        for _ in range(3):
            for pattern in sequence:
                tm1.compute(pattern)

        proto1 = TemporalMemoryProto_capnp.TemporalMemoryProto.new_message()
        tm1.write(proto1)

        # Write the proto to a temp file and read it back into a new proto
        with tempfile.TemporaryFile() as f:
            proto1.write(f)
            f.seek(0)
            proto2 = TemporalMemoryProto_capnp.TemporalMemoryProto.read(f)

        # Load the deserialized proto
        tm2 = TemporalMemory.read(proto2)

        # Check that the two temporal memory objects have the same attributes
        self.assertEqual(tm1, tm2)
        # Run a couple records through after deserializing and check results match
        tm1.compute(patternMachine.get(0))
        tm2.compute(patternMachine.get(0))
        self.assertEqual(set(tm1.getActiveCells()), set(tm2.getActiveCells()))
        self.assertEqual(set(tm1.getPredictiveCells()),
                         set(tm2.getPredictiveCells()))
        self.assertEqual(set(tm1.getWinnerCells()), set(tm2.getWinnerCells()))
        self.assertEqual(tm1.connections, tm2.connections)

        tm1.compute(patternMachine.get(3))
        tm2.compute(patternMachine.get(3))
        self.assertEqual(set(tm1.getActiveCells()), set(tm2.getActiveCells()))
        self.assertEqual(set(tm1.getPredictiveCells()),
                         set(tm2.getPredictiveCells()))
        self.assertEqual(set(tm1.getWinnerCells()), set(tm2.getWinnerCells()))
        self.assertEqual(tm1.connections, tm2.connections)
Example #4
  def testAddSpatialNoise(self):
    patternMachine = PatternMachine(10000, 1000, num=100)
    sequenceMachine = SequenceMachine(patternMachine)
    numbers = list(range(0, 100))
    numbers.append(None)

    sequence = sequenceMachine.generateFromNumbers(numbers)
    noisy = sequenceMachine.addSpatialNoise(sequence, 0.5)

    overlap = len(noisy[0] & patternMachine.get(0))
    self.assertTrue(400 < overlap < 600)

    sequence = sequenceMachine.generateFromNumbers(numbers)
    noisy = sequenceMachine.addSpatialNoise(sequence, 0.0)

    overlap = len(noisy[0] & patternMachine.get(0))
    self.assertEqual(overlap, 1000)
Example #5
    def testAddSpatialNoise(self):
        patternMachine = PatternMachine(10000, 1000, num=100)
        sequenceMachine = SequenceMachine(patternMachine)
        numbers = list(range(0, 100))
        numbers.append(None)

        sequence = sequenceMachine.generateFromNumbers(numbers)
        noisy = sequenceMachine.addSpatialNoise(sequence, 0.5)

        overlap = len(noisy[0] & patternMachine.get(0))
        self.assertTrue(400 < overlap < 600)

        sequence = sequenceMachine.generateFromNumbers(numbers)
        noisy = sequenceMachine.addSpatialNoise(sequence, 0.0)

        overlap = len(noisy[0] & patternMachine.get(0))
        self.assertEqual(overlap, 1000)
Example #6
  def testAddNoise(self):
    patternMachine = PatternMachine(10000, 1000, num=1)
    pattern = patternMachine.get(0)

    noisy = patternMachine.addNoise(pattern, 0.0)
    self.assertEqual(len(pattern & noisy), 1000)

    noisy = patternMachine.addNoise(pattern, 0.5)
    self.assertTrue(400 < len(pattern & noisy) < 600)

    noisy = patternMachine.addNoise(pattern, 1.0)
    self.assertTrue(50 < len(pattern & noisy) < 150)
Example #7
    def testAddNoise(self):
        patternMachine = PatternMachine(10000, 1000, num=1)
        pattern = patternMachine.get(0)

        noisy = patternMachine.addNoise(pattern, 0.0)
        self.assertEqual(len(pattern & noisy), 1000)

        noisy = patternMachine.addNoise(pattern, 0.5)
        self.assertTrue(400 < len(pattern & noisy) < 600)

        noisy = patternMachine.addNoise(pattern, 1.0)
        self.assertTrue(50 < len(pattern & noisy) < 150)
Example #8
  def testWList(self):
    w = [4, 7, 11]
    patternMachine = PatternMachine(100, w, num=50)
    widths = dict((el, 0) for el in w)

    for i in range(50):
      pattern = patternMachine.get(i)
      width = len(pattern)
      self.assertTrue(width in w)
      widths[len(pattern)] += 1

    for i in w:
      self.assertTrue(widths[i] > 0)
Example #9
    def testWList(self):
        w = [4, 7, 11]
        patternMachine = PatternMachine(100, w, num=50)
        widths = dict((el, 0) for el in w)

        for i in range(50):
            pattern = patternMachine.get(i)
            width = len(pattern)
            self.assertTrue(width in w)
            widths[len(pattern)] += 1

        for i in w:
            self.assertTrue(widths[i] > 0)
Example #10
class ExtendedTemporalMemoryTest(unittest.TestCase):
    def setUp(self):
        self.tm = ExtendedTemporalMemory(learnOnOneCell=False)

    def testInitInvalidParams(self):
        # Invalid columnDimensions
        kwargs = {"columnDimensions": [], "cellsPerColumn": 32}
        self.assertRaises(ValueError, ExtendedTemporalMemory, **kwargs)

        # Invalid cellsPerColumn
        kwargs = {"columnDimensions": [2048], "cellsPerColumn": 0}
        self.assertRaises(ValueError, ExtendedTemporalMemory, **kwargs)
        kwargs = {"columnDimensions": [2048], "cellsPerColumn": -10}
        self.assertRaises(ValueError, ExtendedTemporalMemory, **kwargs)

    def testLearnOnOneCellParam(self):
        tm = self.tm
        self.assertFalse(tm.learnOnOneCell)

        tm = ExtendedTemporalMemory(learnOnOneCell=True)
        self.assertTrue(tm.learnOnOneCell)

    def testActivateCorrectlyPredictiveCells(self):
        tm = self.tm

        prevPredictiveCells = set([0, 237, 1026, 26337, 26339, 55536])
        activeColumns = set([32, 47, 823])
        prevMatchingCells = set()

        (activeCells, winnerCells, predictedColumns, predictedInactiveCells) = tm.activateCorrectlyPredictiveCells(
            prevPredictiveCells, prevMatchingCells, activeColumns
        )

        self.assertEqual(activeCells, set([1026, 26337, 26339]))
        self.assertEqual(winnerCells, set([1026, 26337, 26339]))
        self.assertEqual(predictedColumns, set([32, 823]))
        self.assertEqual(predictedInactiveCells, set())

    def testActivateCorrectlyPredictiveCellsEmpty(self):
        tm = self.tm

        # No previous predictive cells, no active columns
        prevPredictiveCells = set()
        activeColumns = set()
        prevMatchingCells = set()

        (activeCells, winnerCells, predictedColumns, predictedInactiveCells) = tm.activateCorrectlyPredictiveCells(
            prevPredictiveCells, prevMatchingCells, activeColumns
        )

        self.assertEqual(activeCells, set())
        self.assertEqual(winnerCells, set())
        self.assertEqual(predictedColumns, set())
        self.assertEqual(predictedInactiveCells, set())

        # No previous predictive cells, with active columns

        prevPredictiveCells = set()
        activeColumns = set([32, 47, 823])
        prevMatchingCells = set()

        (activeCells, winnerCells, predictedColumns, predictedInactiveCells) = tm.activateCorrectlyPredictiveCells(
            prevPredictiveCells, prevMatchingCells, activeColumns
        )

        self.assertEqual(activeCells, set())
        self.assertEqual(winnerCells, set())
        self.assertEqual(predictedColumns, set())
        self.assertEqual(predictedInactiveCells, set())

        # No active columns, with previously predictive cells
        prevPredictiveCells = set([0, 237, 1026, 26337, 26339, 55536])
        activeColumns = set()
        prevMatchingCells = set()

        (activeCells, winnerCells, predictedColumns, predictedInactiveCells) = tm.activateCorrectlyPredictiveCells(
            prevPredictiveCells, prevMatchingCells, activeColumns
        )

        self.assertEqual(activeCells, set())
        self.assertEqual(winnerCells, set())
        self.assertEqual(predictedColumns, set())
        self.assertEqual(predictedInactiveCells, set())

    def testActivateCorrectlyPredictiveCellsOrphan(self):
        tm = self.tm
        tm.predictedSegmentDecrement = 0.001
        prevPredictiveCells = set([])
        activeColumns = set([32, 47, 823])
        prevMatchingCells = set([32, 47])

        (activeCells, winnerCells, predictedColumns, predictedInactiveCells) = tm.activateCorrectlyPredictiveCells(
            prevPredictiveCells, prevMatchingCells, activeColumns
        )

        self.assertEqual(activeCells, set([]))
        self.assertEqual(winnerCells, set([]))
        self.assertEqual(predictedColumns, set([]))
        self.assertEqual(predictedInactiveCells, set([32, 47]))

    def testBurstColumns(self):
        tm = ExtendedTemporalMemory(cellsPerColumn=4, connectedPermanence=0.50, minThreshold=1, seed=42)

        connections = tm.connections
        connections.createSegment(0)
        connections.createSynapse(0, 23, 0.6)
        connections.createSynapse(0, 37, 0.4)
        connections.createSynapse(0, 477, 0.9)

        connections.createSegment(0)
        connections.createSynapse(1, 49, 0.9)
        connections.createSynapse(1, 3, 0.8)

        connections.createSegment(1)
        connections.createSynapse(2, 733, 0.7)

        connections.createSegment(108)
        connections.createSynapse(3, 486, 0.9)

        activeColumns = set([0, 1, 26])
        predictedColumns = set([26])
        prevActiveCells = set([23, 37, 49, 733])
        prevWinnerCells = set([23, 37, 49, 733])

        prevActiveApicalCells = set()
        learnOnOneCell = False
        chosenCellForColumn = {}

        (activeCells, winnerCells, learningSegments, apicalLearningSegments, chosenCellForColumn) = tm.burstColumns(
            activeColumns,
            predictedColumns,
            prevActiveCells,
            prevActiveApicalCells,
            prevWinnerCells,
            learnOnOneCell,
            chosenCellForColumn,
            connections,
            tm.apicalConnections,
        )

        self.assertEqual(activeCells, set([0, 1, 2, 3, 4, 5, 6, 7]))
        randomWinner = 4  # cell 4 is expected to be the randomly chosen winner
        self.assertEqual(winnerCells, set([0, randomWinner]))
        self.assertEqual(learningSegments, set([0, 4]))  # segment 4 is the newly created segment

        # Check that the new segment was added to the winner cell in column 1
        self.assertEqual(connections.segmentsForCell(randomWinner), set([4]))

    def testBurstColumnsEmpty(self):
        tm = self.tm

        activeColumns = set()
        predictedColumns = set()
        prevActiveCells = set()
        prevWinnerCells = set()
        connections = tm.connections

        prevActiveApicalCells = set()
        learnOnOneCell = False
        chosenCellForColumn = {}

        (activeCells, winnerCells, learningSegments, apicalLearningSegments, chosenCellForColumn) = tm.burstColumns(
            activeColumns,
            predictedColumns,
            prevActiveCells,
            prevActiveApicalCells,
            prevWinnerCells,
            learnOnOneCell,
            chosenCellForColumn,
            connections,
            tm.apicalConnections,
        )

        self.assertEqual(activeCells, set())
        self.assertEqual(winnerCells, set())
        self.assertEqual(learningSegments, set())
        self.assertEqual(apicalLearningSegments, set())

    def testLearnOnSegments(self):
        tm = ExtendedTemporalMemory(maxNewSynapseCount=2)

        connections = tm.connections
        connections.createSegment(0)
        connections.createSynapse(0, 23, 0.6)
        connections.createSynapse(0, 37, 0.4)
        connections.createSynapse(0, 477, 0.9)

        connections.createSegment(1)
        connections.createSynapse(1, 733, 0.7)

        connections.createSegment(8)
        connections.createSynapse(2, 486, 0.9)

        connections.createSegment(100)

        prevActiveSegments = set([0, 2])
        learningSegments = set([1, 3])
        prevActiveCells = set([23, 37, 733])
        winnerCells = set([0])
        prevWinnerCells = set([10, 11, 12, 13, 14])
        predictedInactiveCells = set()
        prevMatchingSegments = set()
        tm.learnOnSegments(
            prevActiveSegments,
            learningSegments,
            prevActiveCells,
            winnerCells,
            prevWinnerCells,
            connections,
            predictedInactiveCells,
            prevMatchingSegments,
        )

        # Check segment 0
        synapseData = connections.dataForSynapse(0)
        self.assertAlmostEqual(synapseData.permanence, 0.7)

        synapseData = connections.dataForSynapse(1)
        self.assertAlmostEqual(synapseData.permanence, 0.5)

        synapseData = connections.dataForSynapse(2)
        self.assertAlmostEqual(synapseData.permanence, 0.8)

        # Check segment 1
        synapseData = connections.dataForSynapse(3)
        self.assertAlmostEqual(synapseData.permanence, 0.8)

        self.assertEqual(len(connections.synapsesForSegment(1)), 2)

        # Check segment 2
        synapseData = connections.dataForSynapse(4)
        self.assertAlmostEqual(synapseData.permanence, 0.9)

        self.assertEqual(len(connections.synapsesForSegment(2)), 1)

        # Check segment 3
        self.assertEqual(len(connections.synapsesForSegment(3)), 2)

    def testComputePredictiveCells(self):
        tm = ExtendedTemporalMemory(activationThreshold=2, minThreshold=2, predictedSegmentDecrement=0.004)

        connections = tm.connections
        connections.createSegment(0)
        connections.createSynapse(0, 23, 0.6)
        connections.createSynapse(0, 37, 0.5)
        connections.createSynapse(0, 477, 0.9)

        connections.createSegment(1)
        connections.createSynapse(1, 733, 0.7)
        connections.createSynapse(1, 733, 0.4)

        connections.createSegment(1)
        connections.createSynapse(2, 974, 0.9)

        connections.createSegment(8)
        connections.createSynapse(3, 486, 0.9)

        connections.createSegment(100)

        activeCells = set([23, 37, 733, 974])

        (activeSegments, predictiveCells, matchingSegments, matchingCells) = tm.computePredictiveCells(
            activeCells, connections
        )
        self.assertEqual(activeSegments, set([0]))
        self.assertEqual(predictiveCells, set([0]))
        self.assertEqual(matchingSegments, set([0, 1]))
        self.assertEqual(matchingCells, set([0, 1]))

    def testBestMatchingCell(self):
        tm = ExtendedTemporalMemory(connectedPermanence=0.50, minThreshold=1, seed=42)

        connections = tm.connections
        connections.createSegment(0)
        connections.createSynapse(0, 23, 0.6)
        connections.createSynapse(0, 37, 0.4)
        connections.createSynapse(0, 477, 0.9)

        connections.createSegment(0)
        connections.createSynapse(1, 49, 0.9)
        connections.createSynapse(1, 3, 0.8)

        connections.createSegment(1)
        connections.createSynapse(2, 733, 0.7)

        connections.createSegment(108)
        connections.createSynapse(3, 486, 0.9)

        activeCells = set([23, 37, 49, 733])
        activeApicalCells = set()

        self.assertEqual(
            tm.bestMatchingCell(
                tm.cellsForColumn(0), activeCells, activeApicalCells, connections, tm.apicalConnections
            ),
            (0, 0, None),
        )

        self.assertEqual(
            tm.bestMatchingCell(
                tm.cellsForColumn(3), activeCells, activeApicalCells, connections, tm.apicalConnections
            ),
            (103, None, None),
        )  # Random cell from column

        self.assertEqual(
            tm.bestMatchingCell(
                tm.cellsForColumn(999), activeCells, activeApicalCells, connections, tm.apicalConnections
            ),
            (31979, None, None),
        )  # Random cell from column

    def testBestMatchingCellFewestSegments(self):
        tm = ExtendedTemporalMemory(
            columnDimensions=[2], cellsPerColumn=2, connectedPermanence=0.50, minThreshold=1, seed=42
        )

        connections = tm.connections
        connections.createSegment(0)
        connections.createSynapse(0, 3, 0.3)

        activeSynapsesForSegment = set([])
        activeApicalCells = set()

        for _ in range(100):
            # Never pick cell 0, always pick cell 1
            (cell, _, _) = tm.bestMatchingCell(
                tm.cellsForColumn(0), activeSynapsesForSegment, activeApicalCells, connections, tm.apicalConnections
            )
            self.assertEqual(cell, 1)

    def testBestMatchingSegment(self):
        tm = ExtendedTemporalMemory(connectedPermanence=0.50, minThreshold=1)

        connections = tm.connections
        connections.createSegment(0)
        connections.createSynapse(0, 23, 0.6)
        connections.createSynapse(0, 37, 0.4)
        connections.createSynapse(0, 477, 0.9)

        connections.createSegment(0)
        connections.createSynapse(1, 49, 0.9)
        connections.createSynapse(1, 3, 0.8)

        connections.createSegment(1)
        connections.createSynapse(2, 733, 0.7)

        connections.createSegment(8)
        connections.createSynapse(3, 486, 0.9)

        activeCells = set([23, 37, 49, 733])

        self.assertEqual(tm.bestMatchingSegment(0, activeCells, connections), (0, 2))

        self.assertEqual(tm.bestMatchingSegment(1, activeCells, connections), (2, 1))

        self.assertEqual(tm.bestMatchingSegment(8, activeCells, connections), (None, None))

        self.assertEqual(tm.bestMatchingSegment(100, activeCells, connections), (None, None))

    def testLeastUsedCell(self):
        tm = ExtendedTemporalMemory(columnDimensions=[2], cellsPerColumn=2, seed=42)

        connections = tm.connections
        connections.createSegment(0)
        connections.createSynapse(0, 3, 0.3)

        for _ in range(100):
            # Never pick cell 0, always pick cell 1
            self.assertEqual(tm.leastUsedCell(tm.cellsForColumn(0), connections), 1)

    def testAdaptSegment(self):
        tm = self.tm

        connections = tm.connections
        connections.createSegment(0)
        connections.createSynapse(0, 23, 0.6)
        connections.createSynapse(0, 37, 0.4)
        connections.createSynapse(0, 477, 0.9)

        tm.adaptSegment(0, set([0, 1]), connections, tm.permanenceIncrement, tm.permanenceDecrement)

        synapseData = connections.dataForSynapse(0)
        self.assertAlmostEqual(synapseData.permanence, 0.7)

        synapseData = connections.dataForSynapse(1)
        self.assertAlmostEqual(synapseData.permanence, 0.5)

        synapseData = connections.dataForSynapse(2)
        self.assertAlmostEqual(synapseData.permanence, 0.8)

    def testAdaptSegmentToMax(self):
        tm = self.tm

        connections = tm.connections
        connections.createSegment(0)
        connections.createSynapse(0, 23, 0.9)

        tm.adaptSegment(0, set([0]), connections, tm.permanenceIncrement, tm.permanenceDecrement)
        synapseData = connections.dataForSynapse(0)
        self.assertAlmostEqual(synapseData.permanence, 1.0)

        # Now permanence should be at max
        tm.adaptSegment(0, set([0]), connections, tm.permanenceIncrement, tm.permanenceDecrement)
        synapseData = connections.dataForSynapse(0)
        self.assertAlmostEqual(synapseData.permanence, 1.0)

    def testAdaptSegmentToMin(self):
        tm = self.tm

        connections = tm.connections
        connections.createSegment(0)
        connections.createSynapse(0, 23, 0.1)

        tm.adaptSegment(0, set(), connections, tm.permanenceIncrement, tm.permanenceDecrement)

        synapses = connections.synapsesForSegment(0)
        self.assertFalse(0 in synapses)

    def testPickCellsToLearnOn(self):
        tm = ExtendedTemporalMemory(seed=42)

        connections = tm.connections
        connections.createSegment(0)

        winnerCells = set([4, 47, 58, 93])

        self.assertEqual(tm.pickCellsToLearnOn(2, 0, winnerCells, connections), set([4, 93]))  # randomly picked

        self.assertEqual(tm.pickCellsToLearnOn(100, 0, winnerCells, connections), set([4, 47, 58, 93]))

        self.assertEqual(tm.pickCellsToLearnOn(0, 0, winnerCells, connections), set())

    def testPickCellsToLearnOnAvoidDuplicates(self):
        tm = ExtendedTemporalMemory(seed=42)

        connections = tm.connections
        connections.createSegment(0)
        connections.createSynapse(0, 23, 0.6)

        winnerCells = set([23])

        # Ensure that no additional (duplicate) cells were picked
        self.assertEqual(tm.pickCellsToLearnOn(2, 0, winnerCells, connections), set())

    def testColumnForCell1D(self):
        tm = ExtendedTemporalMemory(columnDimensions=[2048], cellsPerColumn=5)
        self.assertEqual(tm.columnForCell(0), 0)
        self.assertEqual(tm.columnForCell(4), 0)
        self.assertEqual(tm.columnForCell(5), 1)
        self.assertEqual(tm.columnForCell(10239), 2047)

    def testColumnForCell2D(self):
        tm = ExtendedTemporalMemory(columnDimensions=[64, 64], cellsPerColumn=4)
        self.assertEqual(tm.columnForCell(0), 0)
        self.assertEqual(tm.columnForCell(3), 0)
        self.assertEqual(tm.columnForCell(4), 1)
        self.assertEqual(tm.columnForCell(16383), 4095)

    def testColumnForCellInvalidCell(self):
        tm = ExtendedTemporalMemory(columnDimensions=[64, 64], cellsPerColumn=4)

        try:
            tm.columnForCell(16383)
        except IndexError:
            self.fail("IndexError raised unexpectedly")

        args = [16384]
        self.assertRaises(IndexError, tm.columnForCell, *args)

        args = [-1]
        self.assertRaises(IndexError, tm.columnForCell, *args)

    def testCellsForColumn1D(self):
        tm = ExtendedTemporalMemory(columnDimensions=[2048], cellsPerColumn=5)
        expectedCells = set([5, 6, 7, 8, 9])
        self.assertEqual(tm.cellsForColumn(1), expectedCells)

    def testCellsForColumn2D(self):
        tm = ExtendedTemporalMemory(columnDimensions=[64, 64], cellsPerColumn=4)
        expectedCells = set([256, 257, 258, 259])
        self.assertEqual(tm.cellsForColumn(64), expectedCells)

    def testCellsForColumnInvalidColumn(self):
        tm = ExtendedTemporalMemory(columnDimensions=[64, 64], cellsPerColumn=4)

        try:
            tm.cellsForColumn(4095)
        except IndexError:
            self.fail("IndexError raised unexpectedly")

        args = [4096]
        self.assertRaises(IndexError, tm.cellsForColumn, *args)

        args = [-1]
        self.assertRaises(IndexError, tm.cellsForColumn, *args)

    def testNumberOfColumns(self):
        tm = ExtendedTemporalMemory(columnDimensions=[64, 64], cellsPerColumn=32)
        self.assertEqual(tm.numberOfColumns(), 64 * 64)

    def testNumberOfCells(self):
        tm = ExtendedTemporalMemory(columnDimensions=[64, 64], cellsPerColumn=32)
        self.assertEqual(tm.numberOfCells(), 64 * 64 * 32)

    def testMapCellsToColumns(self):
        tm = ExtendedTemporalMemory(columnDimensions=[100], cellsPerColumn=4)
        columnsForCells = tm.mapCellsToColumns(set([0, 1, 2, 5, 399]))
        self.assertEqual(columnsForCells[0], set([0, 1, 2]))
        self.assertEqual(columnsForCells[1], set([5]))
        self.assertEqual(columnsForCells[99], set([399]))

    def testCalculatePredictiveCells(self):
        tm = ExtendedTemporalMemory(columnDimensions=[4], cellsPerColumn=5)
        predictiveDistalCells = set([2, 3, 5, 8, 10, 12, 13, 14])
        predictiveApicalCells = set([1, 5, 7, 11, 14, 15, 17])
        self.assertEqual(tm.calculatePredictiveCells(predictiveDistalCells, predictiveApicalCells), set([2, 3, 5, 14]))

    def testCompute(self):
        tm = ExtendedTemporalMemory(
            columnDimensions=[4],
            cellsPerColumn=10,
            learnOnOneCell=False,
            initialPermanence=0.2,
            connectedPermanence=0.7,
            activationThreshold=1,
        )

        seg1 = tm.connections.createSegment(0)
        seg2 = tm.connections.createSegment(20)
        seg3 = tm.connections.createSegment(25)
        try:
            tm.connections.createSynapse(seg1, 15, 0.9)
            tm.connections.createSynapse(seg2, 35, 0.9)
            tm.connections.createSynapse(seg2, 45, 0.9)  # external cell
            tm.connections.createSynapse(seg3, 35, 0.9)
            tm.connections.createSynapse(seg3, 50, 0.9)  # external cell
        except IndexError:
            self.fail("IndexError raised unexpectedly for distal segments")

        aSeg1 = tm.apicalConnections.createSegment(1)
        aSeg2 = tm.apicalConnections.createSegment(25)
        try:
            tm.apicalConnections.createSynapse(aSeg1, 3, 0.9)
            tm.apicalConnections.createSynapse(aSeg2, 1, 0.9)
        except IndexError:
            self.fail("IndexError raised unexpectedly for apical segments")

        activeColumns = set([1, 3])
        activeExternalCells = set([5, 10, 15])
        activeApicalCells = set([1, 2, 3, 4])

        tm.compute(
            activeColumns, activeExternalCells=activeExternalCells, activeApicalCells=activeApicalCells, learn=False
        )

        activeColumns = set([0, 2])
        tm.compute(activeColumns, activeExternalCells=set(), activeApicalCells=set())

        self.assertEqual(tm.activeCells, set([0, 20, 25]))

    def testLearning(self):
        tm = ExtendedTemporalMemory(
            columnDimensions=[4],
            cellsPerColumn=10,
            learnOnOneCell=False,
            initialPermanence=0.5,
            connectedPermanence=0.6,
            activationThreshold=1,
            minThreshold=1,
            maxNewSynapseCount=2,
            permanenceDecrement=0.05,
            permanenceIncrement=0.2,
        )

        seg1 = tm.connections.createSegment(0)
        seg2 = tm.connections.createSegment(10)
        seg3 = tm.connections.createSegment(20)
        seg4 = tm.connections.createSegment(30)
        try:
            tm.connections.createSynapse(seg1, 10, 0.9)
            tm.connections.createSynapse(seg2, 20, 0.9)
            tm.connections.createSynapse(seg3, 30, 0.9)
            tm.connections.createSynapse(seg3, 41, 0.9)
            tm.connections.createSynapse(seg3, 25, 0.9)
            tm.connections.createSynapse(seg4, 0, 0.9)
        except IndexError:
            self.fail("IndexError raised unexpectedly for distal segments")

        aSeg1 = tm.apicalConnections.createSegment(0)
        aSeg2 = tm.apicalConnections.createSegment(20)
        try:
            tm.apicalConnections.createSynapse(aSeg1, 42, 0.8)
            tm.apicalConnections.createSynapse(aSeg2, 43, 0.8)
        except IndexError:
            self.fail("IndexError raised unexpectedly for apical segments")

        activeColumns = set([1, 3])
        activeExternalCells = set([1])  # will be re-indexed to 41
        activeApicalCells = set([2, 3])  # will be re-indexed to 42, 43

        tm.compute(
            activeColumns, activeExternalCells=activeExternalCells, activeApicalCells=activeApicalCells, learn=False
        )

        activeColumns = set([0, 2])
        tm.compute(activeColumns, activeExternalCells=None, activeApicalCells=None, learn=True)

        self.assertEqual(tm.activeCells, set([0, 20]))

        # distal learning
        synapse = list(tm.connections.synapsesForSegment(seg1))[0]
        self.assertEqual(tm.connections.dataForSynapse(synapse).permanence, 1.0)

        synapse = list(tm.connections.synapsesForSegment(seg2))[0]
        self.assertEqual(tm.connections.dataForSynapse(synapse).permanence, 0.9)

        synapse = list(tm.connections.synapsesForSegment(seg3))[0]
        self.assertEqual(tm.connections.dataForSynapse(synapse).permanence, 1.0)
        synapse = list(tm.connections.synapsesForSegment(seg3))[1]
        self.assertEqual(tm.connections.dataForSynapse(synapse).permanence, 1.0)
        synapse = list(tm.connections.synapsesForSegment(seg3))[2]
        self.assertEqual(tm.connections.dataForSynapse(synapse).permanence, 0.85)

        synapse = list(tm.connections.synapsesForSegment(seg4))[0]
        self.assertEqual(tm.connections.dataForSynapse(synapse).permanence, 0.9)

        # apical learning
        synapse = list(tm.apicalConnections.synapsesForSegment(aSeg1))[0]
        self.assertEqual(tm.apicalConnections.dataForSynapse(synapse).permanence, 1.0)

        synapse = list(tm.apicalConnections.synapsesForSegment(aSeg2))[0]
        self.assertEqual(tm.apicalConnections.dataForSynapse(synapse).permanence, 1.0)

    @unittest.skipUnless(capnp is not None, "No serialization available for ETM")
    def testWriteRead(self):
        tm1 = ExtendedTemporalMemory(
            columnDimensions=[100],
            cellsPerColumn=4,
            activationThreshold=7,
            initialPermanence=0.37,
            connectedPermanence=0.58,
            minThreshold=4,
            maxNewSynapseCount=18,
            permanenceIncrement=0.23,
            permanenceDecrement=0.08,
            seed=91,
        )

        # Run some data through before serializing
        self.patternMachine = PatternMachine(100, 4)
        self.sequenceMachine = SequenceMachine(self.patternMachine)
        sequence = self.sequenceMachine.generateFromNumbers(range(5))
        for _ in range(3):
            for pattern in sequence:
                tm1.compute(pattern)

        proto1 = TemporalMemoryProto_capnp.TemporalMemoryProto.new_message()
        tm1.write(proto1)

        # Write the proto to a temp file and read it back into a new proto
        with tempfile.TemporaryFile() as f:
            proto1.write(f)
            f.seek(0)
            proto2 = TemporalMemoryProto_capnp.TemporalMemoryProto.read(f)

        # Load the deserialized proto
        tm2 = ExtendedTemporalMemory.read(proto2)

        # Check that the two temporal memory objects have the same attributes
        self.assertEqual(tm1, tm2)

        # Run a couple records through after deserializing and check results match
        tm1.compute(self.patternMachine.get(0))
        tm2.compute(self.patternMachine.get(0))
        self.assertEqual(set(tm1.getActiveCells()), set(tm2.getActiveCells()))
        self.assertEqual(set(tm1.getPredictiveCells()), set(tm2.getPredictiveCells()))
        self.assertEqual(set(tm1.getWinnerCells()), set(tm2.getWinnerCells()))
        self.assertEqual(tm1.connections, tm2.connections)

        tm1.compute(self.patternMachine.get(3))
        tm2.compute(self.patternMachine.get(3))
        self.assertEqual(set(tm1.getActiveCells()), set(tm2.getActiveCells()))
        self.assertEqual(set(tm1.getPredictiveCells()), set(tm2.getPredictiveCells()))
        self.assertEqual(set(tm1.getWinnerCells()), set(tm2.getWinnerCells()))
        self.assertEqual(tm1.connections, tm2.connections)
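
The @unittest.skipUnless guard on testWriteRead above relies on capnp having been imported defensively at module level; a minimal sketch of that guarded import follows, with the proto module path marked as an assumption.

try:
    import capnp
except ImportError:
    capnp = None  # serialization-dependent tests will be skipped

if capnp is not None:
    from nupic.proto import TemporalMemoryProto_capnp  # assumed module path
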
Example #11
class PatternMachineTest(unittest.TestCase):


  def setUp(self):
    self.patternMachine = PatternMachine(10000, 5, num=50)


  def testGet(self):
    patternA = self.patternMachine.get(48)
    self.assertEqual(len(patternA), 5)

    patternB = self.patternMachine.get(49)
    self.assertEqual(len(patternB), 5)

    self.assertEqual(patternA & patternB, set())


  def testGetOutOfBounds(self):
    args = [50]
    self.assertRaises(IndexError, self.patternMachine.get, *args)


  def testAddNoise(self):
    patternMachine = PatternMachine(10000, 1000, num=1)
    pattern = patternMachine.get(0)

    noisy = patternMachine.addNoise(pattern, 0.0)
    self.assertEqual(len(pattern & noisy), 1000)

    noisy = patternMachine.addNoise(pattern, 0.5)
    self.assertTrue(400 < len(pattern & noisy) < 600)

    noisy = patternMachine.addNoise(pattern, 1.0)
    self.assertTrue(50 < len(pattern & noisy) < 150)


  def testNumbersForBit(self):
    pattern = self.patternMachine.get(49)

    for bit in pattern:
      self.assertEqual(self.patternMachine.numbersForBit(bit), set([49]))


  def testNumbersForBitOutOfBounds(self):
    args = [10000]
    self.assertRaises(IndexError, self.patternMachine.numbersForBit, *args)


  def testNumberMapForBits(self):
    pattern = self.patternMachine.get(49)
    numberMap = self.patternMachine.numberMapForBits(pattern)

    self.assertEqual(list(numberMap.keys()), [49])
    self.assertEqual(numberMap[49], pattern)


  def testWList(self):
    w = [4, 7, 11]
    patternMachine = PatternMachine(100, w, num=50)
    widths = dict((el, 0) for el in w)

    for i in range(50):
      pattern = patternMachine.get(i)
      width = len(pattern)
      self.assertTrue(width in w)
      widths[len(pattern)] += 1

    for i in w:
      self.assertTrue(widths[i] > 0)


class TemporalMemoryTest(unittest.TestCase):
    def setUp(self):
        self.tm = TemporalMemory()

    def testInitInvalidParams(self):
        # Invalid columnDimensions
        kwargs = {"columnDimensions": [], "cellsPerColumn": 32}
        self.assertRaises(ValueError, TemporalMemory, **kwargs)

        # Invalid cellsPerColumn
        kwargs = {"columnDimensions": [2048], "cellsPerColumn": 0}
        self.assertRaises(ValueError, TemporalMemory, **kwargs)
        kwargs = {"columnDimensions": [2048], "cellsPerColumn": -10}
        self.assertRaises(ValueError, TemporalMemory, **kwargs)

    def testActivateCorrectlyPredictiveCells(self):
        tm = self.tm

        prevPredictiveCells = set([0, 237, 1026, 26337, 26339, 55536])
        activeColumns = set([32, 47, 823])
        prevMatchingCells = set()

        (activeCells, winnerCells, predictedColumns,
         predictedInactiveCells) = tm.activateCorrectlyPredictiveCells(
             prevPredictiveCells, prevMatchingCells, activeColumns)

        self.assertEqual(activeCells, set([1026, 26337, 26339]))
        self.assertEqual(winnerCells, set([1026, 26337, 26339]))
        self.assertEqual(predictedColumns, set([32, 823]))
        self.assertEqual(predictedInactiveCells, set())

    def testActivateCorrectlyPredictiveCellsEmpty(self):
        tm = self.tm

        # No previous predictive cells, no active columns
        prevPredictiveCells = set()
        activeColumns = set()
        prevMatchingCells = set()

        (activeCells, winnerCells, predictedColumns,
         predictedInactiveCells) = tm.activateCorrectlyPredictiveCells(
             prevPredictiveCells, prevMatchingCells, activeColumns)

        self.assertEqual(activeCells, set())
        self.assertEqual(winnerCells, set())
        self.assertEqual(predictedColumns, set())
        self.assertEqual(predictedInactiveCells, set())

        # No previous predictive cells, with active columns

        prevPredictiveCells = set()
        activeColumns = set([32, 47, 823])
        prevMatchingCells = set()

        (activeCells, winnerCells, predictedColumns,
         predictedInactiveCells) = tm.activateCorrectlyPredictiveCells(
             prevPredictiveCells, prevMatchingCells, activeColumns)

        self.assertEqual(activeCells, set())
        self.assertEqual(winnerCells, set())
        self.assertEqual(predictedColumns, set())
        self.assertEqual(predictedInactiveCells, set())

        # No active columns, with previously predictive cells

        prevPredictiveCells = set([0, 237, 1026, 26337, 26339, 55536])
        activeColumns = set()
        prevMatchingCells = set()

        (activeCells, winnerCells, predictedColumns,
         predictedInactiveCells) = tm.activateCorrectlyPredictiveCells(
             prevPredictiveCells, prevMatchingCells, activeColumns)

        self.assertEqual(activeCells, set())
        self.assertEqual(winnerCells, set())
        self.assertEqual(predictedColumns, set())
        self.assertEqual(predictedInactiveCells, set())

    def testActivateCorrectlyPredictiveCellsOrphan(self):
        tm = self.tm
        tm.predictedSegmentDecrement = 0.001
        prevPredictiveCells = set([])
        activeColumns = set([32, 47, 823])
        prevMatchingCells = set([32, 47])

        (activeCells, winnerCells, predictedColumns,
         predictedInactiveCells) = tm.activateCorrectlyPredictiveCells(
             prevPredictiveCells, prevMatchingCells, activeColumns)

        self.assertEqual(activeCells, set([]))
        self.assertEqual(winnerCells, set([]))
        self.assertEqual(predictedColumns, set([]))
        self.assertEqual(predictedInactiveCells, set([32, 47]))

    def testBurstColumns(self):
        tm = TemporalMemory(cellsPerColumn=4,
                            connectedPermanence=0.50,
                            minThreshold=1,
                            seed=42)

        connections = tm.connections
        connections.createSegment(0)
        connections.createSynapse(0, 23, 0.6)
        connections.createSynapse(0, 37, 0.4)
        connections.createSynapse(0, 477, 0.9)

        connections.createSegment(0)
        connections.createSynapse(1, 49, 0.9)
        connections.createSynapse(1, 3, 0.8)

        connections.createSegment(1)
        connections.createSynapse(2, 733, 0.7)

        connections.createSegment(108)
        connections.createSynapse(3, 486, 0.9)

        activeColumns = set([0, 1, 26])
        predictedColumns = set([26])
        prevActiveCells = set([23, 37, 49, 733])
        prevWinnerCells = set([23, 37, 49, 733])

        (activeCells, winnerCells,
         learningSegments) = tm.burstColumns(activeColumns, predictedColumns,
                                             prevActiveCells, prevWinnerCells,
                                             connections)

        self.assertEqual(activeCells, set([0, 1, 2, 3, 4, 5, 6, 7]))
        self.assertEqual(winnerCells, set([0, 6]))  # 6 is randomly chosen cell
        self.assertEqual(learningSegments,
                         set([0, 4]))  # 4 is new segment created

        # Check that new segment was added to winner cell (6) in column 1
        self.assertEqual(connections.segmentsForCell(6), set([4]))

    def testBurstColumnsEmpty(self):
        tm = self.tm

        activeColumns = set()
        predictedColumns = set()
        prevActiveCells = set()
        prevWinnerCells = set()
        connections = tm.connections

        (activeCells, winnerCells,
         learningSegments) = tm.burstColumns(activeColumns, predictedColumns,
                                             prevActiveCells, prevWinnerCells,
                                             connections)

        self.assertEqual(activeCells, set())
        self.assertEqual(winnerCells, set())
        self.assertEqual(learningSegments, set())

    def testLearnOnSegments(self):
        tm = TemporalMemory(maxNewSynapseCount=2)

        connections = tm.connections
        connections.createSegment(0)
        connections.createSynapse(0, 23, 0.6)
        connections.createSynapse(0, 37, 0.4)
        connections.createSynapse(0, 477, 0.9)

        connections.createSegment(1)
        connections.createSynapse(1, 733, 0.7)

        connections.createSegment(8)
        connections.createSynapse(2, 486, 0.9)

        connections.createSegment(100)

        prevActiveSegments = set([0, 2])
        learningSegments = set([1, 3])
        prevActiveCells = set([23, 37, 733])
        winnerCells = set([0])
        prevWinnerCells = set([10, 11, 12, 13, 14])
        predictedInactiveCells = set()
        prevMatchingSegments = set()
        tm.learnOnSegments(prevActiveSegments, learningSegments,
                           prevActiveCells, winnerCells, prevWinnerCells,
                           connections, predictedInactiveCells,
                           prevMatchingSegments)

        # Check segment 0
        synapseData = connections.dataForSynapse(0)
        self.assertAlmostEqual(synapseData.permanence, 0.7)

        synapseData = connections.dataForSynapse(1)
        self.assertAlmostEqual(synapseData.permanence, 0.5)

        synapseData = connections.dataForSynapse(2)
        self.assertAlmostEqual(synapseData.permanence, 0.8)

        # Check segment 1
        synapseData = connections.dataForSynapse(3)
        self.assertAlmostEqual(synapseData.permanence, 0.8)

        self.assertEqual(len(connections.synapsesForSegment(1)), 2)

        # Check segment 2
        synapseData = connections.dataForSynapse(4)
        self.assertAlmostEqual(synapseData.permanence, 0.9)

        self.assertEqual(len(connections.synapsesForSegment(2)), 1)

        # Check segment 3
        self.assertEqual(len(connections.synapsesForSegment(3)), 2)

    def testComputePredictiveCells(self):
        tm = TemporalMemory(activationThreshold=2,
                            minThreshold=2,
                            predictedSegmentDecrement=0.004)

        connections = tm.connections
        connections.createSegment(0)
        connections.createSynapse(0, 23, 0.6)
        connections.createSynapse(0, 37, 0.5)
        connections.createSynapse(0, 477, 0.9)

        connections.createSegment(1)
        connections.createSynapse(1, 733, 0.7)
        connections.createSynapse(1, 733, 0.4)

        connections.createSegment(1)
        connections.createSynapse(2, 974, 0.9)

        connections.createSegment(8)
        connections.createSynapse(3, 486, 0.9)

        connections.createSegment(100)

        activeCells = set([23, 37, 733, 974])

        (activeSegments, predictiveCells, matchingSegments,
         matchingCells) = tm.computePredictiveCells(activeCells, connections)
        self.assertEqual(activeSegments, set([0]))
        self.assertEqual(predictiveCells, set([0]))
        self.assertEqual(matchingSegments, set([0, 1]))
        self.assertEqual(matchingCells, set([0, 1]))

    def testBestMatchingCell(self):
        tm = TemporalMemory(connectedPermanence=0.50, minThreshold=1, seed=42)

        connections = tm.connections
        connections.createSegment(0)
        connections.createSynapse(0, 23, 0.6)
        connections.createSynapse(0, 37, 0.4)
        connections.createSynapse(0, 477, 0.9)

        connections.createSegment(0)
        connections.createSynapse(1, 49, 0.9)
        connections.createSynapse(1, 3, 0.8)

        connections.createSegment(1)
        connections.createSynapse(2, 733, 0.7)

        connections.createSegment(108)
        connections.createSynapse(3, 486, 0.9)

        activeCells = set([23, 37, 49, 733])

        self.assertEqual(
            tm.bestMatchingCell(tm.cellsForColumn(0), activeCells,
                                connections), (0, 0))

        self.assertEqual(
            tm.bestMatchingCell(
                tm.cellsForColumn(3),  # column containing cell 108
                activeCells,
                connections),
            (96, None))  # Random cell from column

        self.assertEqual(
            tm.bestMatchingCell(tm.cellsForColumn(999), activeCells,
                                connections),
            (31972, None))  # Random cell from column

    def testBestMatchingCellFewestSegments(self):
        tm = TemporalMemory(columnDimensions=[2],
                            cellsPerColumn=2,
                            connectedPermanence=0.50,
                            minThreshold=1,
                            seed=42)

        connections = tm.connections
        connections.createSegment(0)
        connections.createSynapse(0, 3, 0.3)

        activeSynapsesForSegment = set([])

        for _ in range(100):
            # Never pick cell 0, always pick cell 1
            (cell, _) = tm.bestMatchingCell(tm.cellsForColumn(0),
                                            activeSynapsesForSegment,
                                            connections)
            self.assertEqual(cell, 1)

    def testBestMatchingSegment(self):
        tm = TemporalMemory(connectedPermanence=0.50, minThreshold=1)

        connections = tm.connections
        connections.createSegment(0)
        connections.createSynapse(0, 23, 0.6)
        connections.createSynapse(0, 37, 0.4)
        connections.createSynapse(0, 477, 0.9)

        connections.createSegment(0)
        connections.createSynapse(1, 49, 0.9)
        connections.createSynapse(1, 3, 0.8)

        connections.createSegment(1)
        connections.createSynapse(2, 733, 0.7)

        connections.createSegment(8)
        connections.createSynapse(3, 486, 0.9)

        activeCells = set([23, 37, 49, 733])

        self.assertEqual(tm.bestMatchingSegment(0, activeCells, connections),
                         (0, 2))

        self.assertEqual(tm.bestMatchingSegment(1, activeCells, connections),
                         (2, 1))

        self.assertEqual(tm.bestMatchingSegment(8, activeCells, connections),
                         (None, None))

        self.assertEqual(tm.bestMatchingSegment(100, activeCells, connections),
                         (None, None))

    def testLeastUsedCell(self):
        tm = TemporalMemory(columnDimensions=[2], cellsPerColumn=2, seed=42)

        connections = tm.connections
        connections.createSegment(0)
        connections.createSynapse(0, 3, 0.3)

        for _ in range(100):
            # Never pick cell 0, always pick cell 1
            self.assertEqual(
                tm.leastUsedCell(tm.cellsForColumn(0), connections), 1)

    def testAdaptSegment(self):
        tm = self.tm

        connections = tm.connections
        connections.createSegment(0)
        connections.createSynapse(0, 23, 0.6)
        connections.createSynapse(0, 37, 0.4)
        connections.createSynapse(0, 477, 0.9)

        tm.adaptSegment(0, set([0, 1]), connections, tm.permanenceIncrement,
                        tm.permanenceDecrement)

        synapseData = connections.dataForSynapse(0)
        self.assertAlmostEqual(synapseData.permanence, 0.7)

        synapseData = connections.dataForSynapse(1)
        self.assertAlmostEqual(synapseData.permanence, 0.5)

        synapseData = connections.dataForSynapse(2)
        self.assertAlmostEqual(synapseData.permanence, 0.8)

    def testAdaptSegmentToMax(self):
        tm = self.tm

        connections = tm.connections
        connections.createSegment(0)
        connections.createSynapse(0, 23, 0.9)

        tm.adaptSegment(0, set([0]), connections, tm.permanenceIncrement,
                        tm.permanenceDecrement)
        synapseData = connections.dataForSynapse(0)
        self.assertAlmostEqual(synapseData.permanence, 1.0)

        # Now permanence should be at max
        tm.adaptSegment(0, set([0]), connections, tm.permanenceIncrement,
                        tm.permanenceDecrement)
        synapseData = connections.dataForSynapse(0)
        self.assertAlmostEqual(synapseData.permanence, 1.0)

    def testAdaptSegmentToMin(self):
        tm = self.tm

        connections = tm.connections
        connections.createSegment(0)
        connections.createSynapse(0, 23, 0.1)

        tm.adaptSegment(0, set(), connections, tm.permanenceIncrement,
                        tm.permanenceDecrement)

        synapses = connections.synapsesForSegment(0)
        self.assertFalse(0 in synapses)

    def testPickCellsToLearnOn(self):
        tm = TemporalMemory(seed=42)

        connections = tm.connections
        connections.createSegment(0)

        winnerCells = set([4, 47, 58, 93])

        self.assertEqual(tm.pickCellsToLearnOn(2, 0, winnerCells, connections),
                         set([4, 58]))  # randomly picked

        self.assertEqual(
            tm.pickCellsToLearnOn(100, 0, winnerCells, connections),
            set([4, 47, 58, 93]))

        self.assertEqual(tm.pickCellsToLearnOn(0, 0, winnerCells, connections),
                         set())

    def testPickCellsToLearnOnAvoidDuplicates(self):
        tm = TemporalMemory(seed=42)

        connections = tm.connections
        connections.createSegment(0)
        connections.createSynapse(0, 23, 0.6)

        winnerCells = set([23])

        # Ensure that no additional (duplicate) cells were picked
        self.assertEqual(tm.pickCellsToLearnOn(2, 0, winnerCells, connections),
                         set())

    def testColumnForCell1D(self):
        tm = TemporalMemory(columnDimensions=[2048], cellsPerColumn=5)
        self.assertEqual(tm.columnForCell(0), 0)
        self.assertEqual(tm.columnForCell(4), 0)
        self.assertEqual(tm.columnForCell(5), 1)
        self.assertEqual(tm.columnForCell(10239), 2047)

    def testColumnForCell2D(self):
        tm = TemporalMemory(columnDimensions=[64, 64], cellsPerColumn=4)
        self.assertEqual(tm.columnForCell(0), 0)
        self.assertEqual(tm.columnForCell(3), 0)
        self.assertEqual(tm.columnForCell(4), 1)
        self.assertEqual(tm.columnForCell(16383), 4095)

    def testColumnForCellInvalidCell(self):
        tm = TemporalMemory(columnDimensions=[64, 64], cellsPerColumn=4)

        try:
            tm.columnForCell(16383)
        except IndexError:
            self.fail("IndexError raised unexpectedly")

        args = [16384]
        self.assertRaises(IndexError, tm.columnForCell, *args)

        args = [-1]
        self.assertRaises(IndexError, tm.columnForCell, *args)

    def testCellsForColumn1D(self):
        tm = TemporalMemory(columnDimensions=[2048], cellsPerColumn=5)
        expectedCells = set([5, 6, 7, 8, 9])
        self.assertEqual(tm.cellsForColumn(1), expectedCells)

    def testCellsForColumn2D(self):
        tm = TemporalMemory(columnDimensions=[64, 64], cellsPerColumn=4)
        expectedCells = set([256, 257, 258, 259])
        self.assertEqual(tm.cellsForColumn(64), expectedCells)

    def testCellsForColumnInvalidColumn(self):
        tm = TemporalMemory(columnDimensions=[64, 64], cellsPerColumn=4)

        try:
            tm.cellsForColumn(4095)
        except IndexError:
            self.fail("IndexError raised unexpectedly")

        args = [4096]
        self.assertRaises(IndexError, tm.cellsForColumn, *args)

        args = [-1]
        self.assertRaises(IndexError, tm.cellsForColumn, *args)

    def testNumberOfColumns(self):
        tm = TemporalMemory(columnDimensions=[64, 64], cellsPerColumn=32)
        self.assertEqual(tm.numberOfColumns(), 64 * 64)

    def testNumberOfCells(self):
        tm = TemporalMemory(columnDimensions=[64, 64], cellsPerColumn=32)
        self.assertEqual(tm.numberOfCells(), 64 * 64 * 32)

    def testMapCellsToColumns(self):
        tm = TemporalMemory(columnDimensions=[100], cellsPerColumn=4)
        columnsForCells = tm.mapCellsToColumns(set([0, 1, 2, 5, 399]))
        self.assertEqual(columnsForCells[0], set([0, 1, 2]))
        self.assertEqual(columnsForCells[1], set([5]))
        self.assertEqual(columnsForCells[99], set([399]))

    def testWriteRead(self):
        tm1 = TemporalMemory(columnDimensions=[100],
                             cellsPerColumn=4,
                             activationThreshold=7,
                             initialPermanence=0.37,
                             connectedPermanence=0.58,
                             minThreshold=4,
                             maxNewSynapseCount=18,
                             permanenceIncrement=0.23,
                             permanenceDecrement=0.08,
                             seed=91)

        # Run some data through before serializing
        self.patternMachine = PatternMachine(100, 4)
        self.sequenceMachine = SequenceMachine(self.patternMachine)
        sequence = self.sequenceMachine.generateFromNumbers(range(5))
        for _ in range(3):
            for pattern in sequence:
                tm1.compute(pattern)

        proto1 = TemporalMemoryProto_capnp.TemporalMemoryProto.new_message()
        tm1.write(proto1)

        # Write the proto to a temp file and read it back into a new proto
        with tempfile.TemporaryFile() as f:
            proto1.write(f)
            f.seek(0)
            proto2 = TemporalMemoryProto_capnp.TemporalMemoryProto.read(f)

        # Load the deserialized proto
        tm2 = TemporalMemory.read(proto2)

        # Check that the two temporal memory objects have the same attributes
        self.assertEqual(tm1, tm2)

        # Run a couple records through after deserializing and check results match
        tm1.compute(self.patternMachine.get(0))
        tm2.compute(self.patternMachine.get(0))
        self.assertEqual(tm1.activeCells, tm2.activeCells)
        self.assertEqual(tm1.predictiveCells, tm2.predictiveCells)
        self.assertEqual(tm1.winnerCells, tm2.winnerCells)
        self.assertEqual(tm1.connections, tm2.connections)

        tm1.compute(self.patternMachine.get(3))
        tm2.compute(self.patternMachine.get(3))
        self.assertEqual(tm1.activeCells, tm2.activeCells)
        self.assertEqual(tm1.predictiveCells, tm2.predictiveCells)
        self.assertEqual(tm1.winnerCells, tm2.winnerCells)
        self.assertEqual(tm1.connections, tm2.connections)


class ExtensiveColumnPoolerTest(unittest.TestCase):
  """
  Algorithmic tests for the ColumnPooler region.

  Each test actually tests multiple aspects of the algorithm. For more
  atomic tests refer to column_pooler_unit_test.

  The notation for objects is the following:
    object{patternA, patternB, ...}

  In these tests, the proximally-fed SDRs are simulated as unique (location,
  feature) pairs regardless of actual locations and features, unless stated
  otherwise.
  """

  inputWidth = 2048 * 8
  numInputActiveBits = int(0.02 * inputWidth)
  outputWidth = 2048
  numOutputActiveBits = 40
  seed = 42
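  # (For reference: numInputActiveBits evaluates to int(0.02 * 16384) = 327,
  #  and the pooler output uses numOutputActiveBits = 40 active cells out of
  #  outputWidth = 2048.)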


  def testNewInputs(self):
    """
    Checks that the behavior is correct when facing unseen inputs.
    """
    self.init()

    # feed the first input, a random SDR should be generated
    initialPattern = self.generateObject(1)
    self.learn(initialPattern, numRepetitions=1, newObject=True)
    representation = self._getActiveRepresentation()
    self.assertEqual(
      len(representation),
      self.numOutputActiveBits,
      "The generated representation is incorrect"
    )

    # feed a new input for the same object, the previous SDR should persist
    newPattern = self.generateObject(1)
    self.learn(newPattern, numRepetitions=1, newObject=False)
    newRepresentation = self._getActiveRepresentation()
    self.assertNotEqual(initialPattern, newPattern)
    self.assertEqual(
      newRepresentation,
      representation,
      "The SDR did not persist when learning the same object"
    )

    # without sensory input, the SDR should persist as well
    emptyPattern = [set()]
    self.learn(emptyPattern, numRepetitions=1, newObject=False)
    newRepresentation = self._getActiveRepresentation()
    self.assertEqual(
      newRepresentation,
      representation,
      "The SDR did not persist after an empty input."
    )


  def testLearnSinglePattern(self):
    """
    A single pattern is learnt for a single object.
    Objects: A{X, Y}
    """
    self.init()

    object = self.generateObject(1)
    self.learn(object, numRepetitions=1, newObject=True)
    # check that the active representation is sparse
    representation = self._getActiveRepresentation()
    self.assertEqual(
      len(representation),
      self.numOutputActiveBits,
      "The generated representation is incorrect"
    )

    # check that the pattern was correctly learnt
    self.infer(feedforwardPattern=object[0])
    self.assertEqual(
      self._getActiveRepresentation(),
      representation,
      "The pooled representation is not stable"
    )

    # present new pattern for same object
    # it should be mapped to the same representation
    newPattern = [self.generatePattern()]
    self.learn(newPattern, numRepetitions=1, newObject=False)
    # check that the active representation is sparse
    newRepresentation = self._getActiveRepresentation()
    self.assertEqual(
      newRepresentation,
      representation,
      "The new pattern did not map to the same object representation"
    )

    # check that the pattern was correctly learnt and is stable
    self.infer(feedforwardPattern=object[0])
    self.assertEqual(
      self._getActiveRepresentation(),
      representation,
      "The pooled representation is not stable"
    )


  def testLearnSingleObject(self):
    """
    Many patterns are learnt for a single object.
    Objects: A{P, Q, R, S, T}
    """
    self.init()

    object = self.generateObject(numPatterns=5)
    self.learn(object, numRepetitions=1, randomOrder=True, newObject=True)
    representation = self._getActiveRepresentation()

    # check that all patterns map to the same object
    for pattern in object:
      self.infer(feedforwardPattern=pattern)
      self.assertEqual(
        self._getActiveRepresentation(),
        representation,
        "The pooled representation is not stable"
      )

    # if activity stops, check that the representation persists
    self.infer(feedforwardPattern=set())
    self.assertEqual(
      self._getActiveRepresentation(),
      representation,
      "The pooled representation did not persist"
    )


  def testLearnTwoObjectNoCommonPattern(self):
    """
    Same test as before, using two objects, without common pattern.
    Objects: A{P, Q, R, S, T}   B{V, W, X, Y, Z}
    """
    self.init()

    objectA = self.generateObject(numPatterns=5)
    self.learn(objectA, numRepetitions=3, randomOrder=True, newObject=True)
    representationA = self._getActiveRepresentation()

    objectB = self.generateObject(numPatterns=5)
    self.learn(objectB, numRepetitions=3, randomOrder=True, newObject=True)
    representationB = self._getActiveRepresentation()

    self.assertNotEqual(representationA, representationB)

    # check that all patterns map to the same object
    for pattern in objectA:
      self.infer(feedforwardPattern=pattern)
      self.assertEqual(
        self._getActiveRepresentation(),
        representationA,
        "The pooled representation for the first object is not stable"
      )

    # check that all patterns map to the same object
    for pattern in objectB:
      self.infer(feedforwardPattern=pattern)
      self.assertEqual(
        self._getActiveRepresentation(),
        representationB,
        "The pooled representation for the second object is not stable"
      )

    # feed union of patterns in object A
    pattern = objectA[0] | objectA[1]
    self.infer(feedforwardPattern=pattern)
    self.assertEqual(
      self._getActiveRepresentation(),
      representationA,
      "The active representation is incorrect"
    )

    # feed unions of patterns in objects A and B
    pattern = objectA[0] | objectB[0]
    self.infer(feedforwardPattern=pattern)
    self.assertEqual(
      self._getActiveRepresentation(),
      representationA | representationB,
      "The active representation is incorrect"
    )


  def testLearnTwoObjectsOneCommonPattern(self):
    """
    Same test as before, except the two objects share a pattern
    Objects: A{P, Q, R, S, T}   B{P, W, X, Y, Z}
    """
    self.init()

    objectA = self.generateObject(numPatterns=5)
    self.learn(objectA, numRepetitions=3, randomOrder=True, newObject=True)
    representationA = self._getActiveRepresentation()

    objectB = self.generateObject(numPatterns=5)
    objectB[0] = objectA[0]
    self.learn(objectB, numRepetitions=3, randomOrder=True, newObject=True)
    representationB = self._getActiveRepresentation()

    self.assertNotEqual(representationA, representationB)
    # very small overlap
    self.assertLessEqual(len(representationA & representationB), 3)

    # check that all patterns except the common one map to the same object
    for pattern in objectA[1:]:
      self.infer(feedforwardPattern=pattern)
      self.assertEqual(
        self._getActiveRepresentation(),
        representationA,
        "The pooled representation for the first object is not stable"
      )

    # check that all patterns except the common one map to the same object
    for pattern in objectB[1:]:
      self.infer(feedforwardPattern=pattern)
      self.assertEqual(
        self._getActiveRepresentation(),
        representationB,
        "The pooled representation for the second object is not stable"
      )

    # feed shared pattern
    pattern = objectA[0]
    self.infer(feedforwardPattern=pattern)
    self.assertEqual(
      self._getActiveRepresentation(),
      representationA | representationB,
      "The active representation is incorrect"
    )

    # feed union of patterns in object A
    pattern = objectA[1] | objectA[2]
    self.infer(feedforwardPattern=pattern)
    self.assertEqual(
      self._getActiveRepresentation(),
      representationA,
      "The active representation is incorrect"
    )

    # feed unions of patterns in objects A and B
    pattern = objectA[1] | objectB[1]
    self.infer(feedforwardPattern=pattern)
    self.assertEqual(
      self._getActiveRepresentation(),
      representationA | representationB,
      "The active representation is incorrect"
    )


  def testLearnThreeObjectsOneCommonPattern(self):
    """
    Same test as before, with three objects
    Objects: A{P, Q, R, S, T}   B{P, W, X, Y, Z}   C{W, H, I, K, L}
    """
    self.init()

    objectA = self.generateObject(numPatterns=5)
    self.learn(objectA, numRepetitions=3, randomOrder=True, newObject=True)
    representationA = self._getActiveRepresentation()

    objectB = self.generateObject(numPatterns=5)
    objectB[0] = objectA[0]
    self.learn(objectB, numRepetitions=3, randomOrder=True, newObject=True)
    representationB = self._getActiveRepresentation()

    objectC = self.generateObject(numPatterns=5)
    objectC[0] = objectB[1]
    self.learn(objectC, numRepetitions=3, randomOrder=True, newObject=True)
    representationC = self._getActiveRepresentation()

    self.assertNotEqual(representationA, representationB)
    self.assertNotEqual(representationB, representationC)
    self.assertNotEqual(representationA, representationC)
    # very small overlap
    self.assertLessEqual(len(representationA & representationB), 3)
    self.assertLessEqual(len(representationB & representationC), 3)
    self.assertLessEqual(len(representationA & representationC), 3)


    # check that all patterns except the common one map to the same object
    for pattern in objectA[1:]:
      self.infer(feedforwardPattern=pattern)
      self.assertEqual(
        self._getActiveRepresentation(),
        representationA,
        "The pooled representation for the first object is not stable"
      )

    # check that all patterns except the common one map to the same object
    for pattern in objectB[2:]:
      self.infer(feedforwardPattern=pattern)
      self.assertEqual(
        self._getActiveRepresentation(),
        representationB,
        "The pooled representation for the second object is not stable"
      )

    # check that all patterns except the common one map to the same object
    for pattern in objectC[1:]:
      self.infer(feedforwardPattern=pattern)
      self.assertEqual(
        self._getActiveRepresentation(),
        representationC,
        "The pooled representation for the third object is not stable"
      )

    # feed shared pattern between A and B
    pattern = objectA[0]
    self.infer(feedforwardPattern=pattern)
    self.assertEqual(
      self._getActiveRepresentation(),
      representationA | representationB,
      "The active representation is incorrect"
    )

    # feed shared pattern between B and C
    pattern = objectB[1]
    self.infer(feedforwardPattern=pattern)
    self.assertEqual(
      self._getActiveRepresentation(),
      representationB | representationC,
      "The active representation is incorrect"
    )

    # feed union of patterns in object A
    pattern = objectA[1] | objectA[2]
    self.infer(feedforwardPattern=pattern)
    self.assertEqual(
      self._getActiveRepresentation(),
      representationA,
      "The active representation is incorrect"
    )

    # feed unions of patterns to activate all objects
    pattern = objectA[1] | objectB[1]
    self.infer(feedforwardPattern=pattern)
    self.assertEqual(
      self._getActiveRepresentation(),
      representationA | representationB | representationC,
      "The active representation is incorrect"
    )


  def testLearnThreeObjectsOneCommonPatternSpatialNoise(self):
    """
    Same test as before, with three objects
    Objects: A{P, Q, R, S, T}   B{P, W, X, Y, Z}   C{W, H, I, K, L}
    """
    self.init()

    objectA = self.generateObject(numPatterns=5)
    self.learn(objectA, numRepetitions=3, randomOrder=True, newObject=True)
    representationA = self._getActiveRepresentation()

    objectB = self.generateObject(numPatterns=5)
    objectB[0] = objectA[0]
    self.learn(objectB, numRepetitions=3, randomOrder=True, newObject=True)
    representationB = self._getActiveRepresentation()

    objectC = self.generateObject(numPatterns=5)
    objectC[0] = objectB[1]
    self.learn(objectC, numRepetitions=3, randomOrder=True, newObject=True)
    representationC = self._getActiveRepresentation()

    self.assertNotEqual(representationA, representationB)
    self.assertNotEqual(representationB, representationC)
    self.assertNotEqual(representationA, representationC)
    # very small overlap
    self.assertLessEqual(len(representationA & representationB), 3)
    self.assertLessEqual(len(representationB & representationC), 3)
    self.assertLessEqual(len(representationA & representationC), 3)


    # check that all patterns except the common one map to the same object
    for pattern in objectA[1:]:
      noisyPattern = self.proximalPatternMachine.addNoise(pattern, 0.05)
      self.infer(feedforwardPattern=noisyPattern)
      self.assertEqual(
        self._getActiveRepresentation(),
        representationA,
        "The pooled representation for the first object is not stable"
      )

    # check that all patterns except the common one map to the same object
    for pattern in objectB[2:]:
      noisyPattern = self.proximalPatternMachine.addNoise(pattern, 0.05)
      self.infer(feedforwardPattern=noisyPattern)
      self.assertEqual(
        self._getActiveRepresentation(),
        representationB,
        "The pooled representation for the second object is not stable"
      )

    # check that all patterns except the common one map to the same object
    for pattern in objectC[1:]:
      noisyPattern = self.proximalPatternMachine.addNoise(pattern, 0.05)
      self.infer(feedforwardPattern=noisyPattern)
      self.assertEqual(
        self._getActiveRepresentation(),
        representationC,
        "The pooled representation for the third object is not stable"
      )

    # feed shared pattern between A and B
    pattern = objectA[0]
    noisyPattern = self.proximalPatternMachine.addNoise(pattern, 0.05)
    self.infer(feedforwardPattern=noisyPattern)
    self.assertEqual(
      self._getActiveRepresentation(),
      representationA | representationB,
      "The active representation is incorrect"
    )

    # feed shared pattern between B and C
    pattern = objectB[1]
    noisyPattern = self.proximalPatternMachine.addNoise(pattern, 0.05)
    self.infer(feedforwardPattern=noisyPattern)
    self.assertEqual(
      self._getActiveRepresentation(),
      representationB | representationC,
      "The active representation is incorrect"
    )

    # feed union of patterns in object A
    pattern = objectA[1] | objectA[2]
    noisyPattern = self.proximalPatternMachine.addNoise(pattern, 0.05)
    self.infer(feedforwardPattern=noisyPattern)
    self.assertEqual(
      self._getActiveRepresentation(),
      representationA,
      "The active representation is incorrect"
    )

    # feed unions of patterns to activate all objects
    pattern = objectA[1] | objectB[1]
    noisyPattern = self.proximalPatternMachine.addNoise(pattern, 0.05)
    self.infer(feedforwardPattern=noisyPattern)
    self.assertEqual(
      self._getActiveRepresentation(),
      representationA | representationB | representationC,
      "The active representation is incorrect"
    )


  def testLearnOneObjectInTwoColumns(self):
    """Learns one object in two different columns."""
    self.init(numCols=2)
    neighborsIndices = [[1], [0]]

    objectA = self.generateObject(numPatterns=5, numCols=2)

    # learn object
    self.learnMultipleColumns(
      objectA,
      numRepetitions=3,
      neighborsIndices=neighborsIndices,
      randomOrder=True,
      newObject=True
    )
    objectARepresentations = self._getActiveRepresentations()

    for pooler in self.poolers:
      pooler.reset()

    for patterns in objectA:
      for i in xrange(3):
        activeRepresentations = self._getActiveRepresentations()

        self.inferMultipleColumns(
          feedforwardPatterns=patterns,
          activeRepresentations=activeRepresentations,
          neighborsIndices=neighborsIndices,
        )
        if i > 0:
          self.assertEqual(activeRepresentations,
                           self._getActiveRepresentations())
          self.assertEqual(objectARepresentations,
                           self._getActiveRepresentations())


  def testLearnTwoObjectsInTwoColumnsNoCommonPattern(self):
    """Learns two objects in two different columns."""
    self.init(numCols=2)
    neighborsIndices = [[1], [0]]

    objectA = self.generateObject(numPatterns=5, numCols=2)
    objectB = self.generateObject(numPatterns=5, numCols=2)

    # learn object
    self.learnMultipleColumns(
      objectA,
      numRepetitions=3,
      neighborsIndices=neighborsIndices,
      randomOrder=True,
      newObject=True
    )
    activeRepresentationsA = self._getActiveRepresentations()

    # learn object
    self.learnMultipleColumns(
      objectB,
      numRepetitions=3,
      neighborsIndices=neighborsIndices,
      randomOrder=True,
      newObject=True,
    )
    activeRepresentationsB = self._getActiveRepresentations()

    for pooler in self.poolers:
      pooler.reset()

    # check inference for object A
    # for the first pattern, the distal predictions won't be correct
    firstPattern = True
    for patternsA in objectA:
      for i in xrange(3):
        activeRepresentations = self._getActiveRepresentations()
        self.inferMultipleColumns(
          feedforwardPatterns=patternsA,
          activeRepresentations=activeRepresentations,
          neighborsIndices=neighborsIndices,
        )
        if firstPattern:
          firstPattern = False
        else:
          self.assertEqual(
            activeRepresentationsA,
            self._getPredictedActiveCells()
          )
        self.assertEqual(
          activeRepresentationsA,
          self._getActiveRepresentations()
        )

    for pooler in self.poolers:
      pooler.reset()

    # check inference for object B
    firstPattern = True
    for patternsB in objectB:
      for i in xrange(3):
        activeRepresentations = self._getActiveRepresentations()
        self.inferMultipleColumns(
          feedforwardPatterns=patternsB,
          activeRepresentations=activeRepresentations,
          neighborsIndices=neighborsIndices
        )

        if firstPattern:
          firstPattern = False
        else:
          self.assertEqual(
            activeRepresentationsB,
            self._getPredictedActiveCells()
          )
        self.assertEqual(
          activeRepresentationsB,
          self._getActiveRepresentations()
        )


  def testLearnTwoObjectsInTwoColumnsOneCommonPattern(self):
    """Learns two objects in two different columns, with a common pattern."""
    self.init(numCols=2)
    neighborsIndices = [[1], [0]]

    objectA = self.generateObject(numPatterns=5, numCols=2)
    objectB = self.generateObject(numPatterns=5, numCols=2)

    # second pattern in column 0 is shared
    objectB[1][0] = objectA[1][0]

    # learn object
    self.learnMultipleColumns(
      objectA,
      numRepetitions=3,
      neighborsIndices=neighborsIndices,
      randomOrder=True,
      newObject=True
    )
    activeRepresentationsA = self._getActiveRepresentations()

    # learn object
    self.learnMultipleColumns(
      objectB,
      numRepetitions=3,
      neighborsIndices=neighborsIndices,
      randomOrder=True,
      newObject=True
    )
    activeRepresentationsB = self._getActiveRepresentations()

    # check inference for object A
    # for the first pattern, the distal predictions won't be correct
    # for the second one, the prediction will be unique thanks to the
    # distal predictions from the other column which has no ambiguity
    for pooler in self.poolers:
      pooler.reset()

    firstPattern = True
    for patternsA in objectA:
      for i in xrange(3):
        activeRepresentations = self._getActiveRepresentations()
        self.inferMultipleColumns(
          feedforwardPatterns=patternsA,
          activeRepresentations=activeRepresentations,
          neighborsIndices=neighborsIndices,
        )
        if firstPattern:
          firstPattern = False
        else:
          self.assertEqual(
            activeRepresentationsA,
            self._getPredictedActiveCells()
          )
        self.assertEqual(
          activeRepresentationsA,
          self._getActiveRepresentations()
        )

    for pooler in self.poolers:
      pooler.reset()

    # check inference for object B
    firstPattern = True
    for patternsB in objectB:
      for i in xrange(3):
        activeRepresentations = self._getActiveRepresentations()
        self.inferMultipleColumns(
          feedforwardPatterns=patternsB,
          activeRepresentations=activeRepresentations,
          neighborsIndices=neighborsIndices
        )

        if firstPattern:
          firstPattern = False
        else:
          self.assertEqual(
            activeRepresentationsB,
            self._getPredictedActiveCells()
          )
        self.assertEqual(
          activeRepresentationsB,
          self._getActiveRepresentations()
        )


  def testLearnTwoObjectsInTwoColumnsOneCommonPatternEmptyFirstInput(self):
    """Learns two objects in two different columns, with a common pattern."""
    self.init(numCols=2)
    neighborsIndices = [[1], [0]]

    objectA = self.generateObject(numPatterns=5, numCols=2)
    objectB = self.generateObject(numPatterns=5, numCols=2)

    # second pattern in column 0 is shared
    objectB[1][0] = objectA[1][0]

    # learn object
    self.learnMultipleColumns(
      objectA,
      numRepetitions=3,
      neighborsIndices=neighborsIndices,
      randomOrder=True,
      newObject=True
    )
    activeRepresentationsA = self._getActiveRepresentations()

    # learn object
    self.learnMultipleColumns(
      objectB,
      numRepetitions=3,
      neighborsIndices=neighborsIndices,
      randomOrder=True,
      newObject=True
    )
    activeRepresentationsB = self._getActiveRepresentations()

    # check inference for object A
    for pooler in self.poolers:
      pooler.reset()

    firstPattern = True
    for patternsA in objectA:
      activeRepresentations = self._getActiveRepresentations()
      if firstPattern:
        self.inferMultipleColumns(
          feedforwardPatterns=[set(), patternsA[1]],
          activeRepresentations=activeRepresentations,
          neighborsIndices=neighborsIndices,
        )
        desiredRepresentation = [set(), activeRepresentationsA[1]]
      else:
        self.inferMultipleColumns(
          feedforwardPatterns=patternsA,
          activeRepresentations=activeRepresentations,
          neighborsIndices=neighborsIndices,
        )
        desiredRepresentation = activeRepresentationsA
      self.assertEqual(
        desiredRepresentation,
        self._getActiveRepresentations()
      )


  def testPersistence(self):
    """After learning, representation should persist in L2 without input."""
    self.init(numCols=2)
    neighborsIndices = [[1], [0]]

    objectA = self.generateObject(numPatterns=5, numCols=2)

    # learn object
    self.learnMultipleColumns(
      objectA,
      numRepetitions=3,
      neighborsIndices=neighborsIndices,
      randomOrder=True,
      newObject=True
    )
    objectARepresentations = self._getActiveRepresentations()

    for pooler in self.poolers:
      pooler.reset()

    for patterns in objectA:
      for i in xrange(3):

        # on the third repetition, replace the second column's pattern with an
        # empty pattern
        if i == 2:
          patterns[1] = set()

        activeRepresentations = self._getActiveRepresentations()

        self.inferMultipleColumns(
          feedforwardPatterns=patterns,
          activeRepresentations=activeRepresentations,
          neighborsIndices=neighborsIndices,
        )
        if i > 0:
          self.assertEqual(activeRepresentations,
                           self._getActiveRepresentations())
          self.assertEqual(objectARepresentations,
                           self._getActiveRepresentations())


  def testLateralDisambiguation(self):
    """Lateral disambiguation using a constant simulated distal input."""
    self.init()

    objectA = self.generateObject(numPatterns=5)
    lateralInputA = [None] + [self.generatePattern() for _ in xrange(4)]
    self.learn(objectA,
               lateralPatterns=lateralInputA,
               numRepetitions=3,
               randomOrder=True,
               newObject=True)
    representationA = self._getActiveRepresentation()

    objectB = self.generateObject(numPatterns=5)
    objectB[3] = objectA[3]
    lateralInputB = [None] + [self.generatePattern() for _ in xrange(4)]
    self.learn(objectB,
               lateralPatterns=lateralInputB,
               numRepetitions=3,
               randomOrder=True,
               newObject=True)
    representationB = self._getActiveRepresentation()

    self.assertNotEqual(representationA, representationB)
    # very small overlap
    self.assertLessEqual(len(representationA & representationB), 3)

    # no ambiguity with lateral input
    for pattern in objectA:
      self.infer(feedforwardPattern=pattern, lateralInput=lateralInputA[-1])
      self.assertEqual(
        self._getActiveRepresentation(),
        representationA,
        "The pooled representation for the first object is not stable"
      )

    # no ambiguity with lateral input
    for pattern in objectB:
      self.infer(feedforwardPattern=pattern, lateralInput=lateralInputB[-1])
      self.assertEqual(
        self._getActiveRepresentation(),
        representationB,
        "The pooled representation for the second object is not stable"
      )


  @unittest.skip("Fails, need to discuss")
  def testMultiColumnCompetition(self):
    """Competition between multiple conflicting lateral inputs."""
    self.init(numCols=4)
    neighborsIndices = [[1, 2, 3], [0, 2, 3], [0, 1, 3], [0, 1, 2]]

    objectA = self.generateObject(numPatterns=5, numCols=4)
    objectB = self.generateObject(numPatterns=5, numCols=4)

    # second pattern in column 0 is shared
    objectB[1][0] = objectA[1][0]

    # learn object
    self.learnMultipleColumns(
      objectA,
      numRepetitions=3,
      neighborsIndices=neighborsIndices,
      randomOrder=True,
      newObject=True
    )
    activeRepresentationsA = self._getActiveRepresentations()

    # learn object
    self.learnMultipleColumns(
      objectB,
      numRepetitions=3,
      neighborsIndices=neighborsIndices,
      randomOrder=True,
      newObject=True
    )
    activeRepresentationsB = self._getActiveRepresentations()

    # check inference for object A
    # for the first pattern, the distal predictions won't be correct
    # for the second one, the prediction will be unique thanks to the
    # distal predictions from the other column which has no ambiguity
    for pooler in self.poolers:
      pooler.reset()

    # sensed patterns will be mixed
    sensedPatterns = objectA[1][:-1] + [objectA[1][-1] | objectB[1][-1]]

    # feed sensed patterns first time
    # every one feels the correct object, except first column which feels
    # the union (reminder: lateral input are delayed)
    activeRepresentations = self._getActiveRepresentations()
    self.inferMultipleColumns(
      feedforwardPatterns=sensedPatterns,
      activeRepresentations=activeRepresentations,
      neighborsIndices=neighborsIndices,
    )
    firstSensedRepresentations = [
      activeRepresentationsA[0] | activeRepresentationsB[0],
      activeRepresentationsA[1],
      activeRepresentationsA[2],
      activeRepresentationsA[3] | activeRepresentationsB[3]
    ]
    self.assertEqual(
      firstSensedRepresentations,
      self._getActiveRepresentations()
    )

    # feed sensed patterns second time
    # the distal predictions are still ambiguous in C1, but disambiguated
    # in C4
    activeRepresentations = self._getActiveRepresentations()
    self.inferMultipleColumns(
      feedforwardPatterns=sensedPatterns,
      activeRepresentations=activeRepresentations,
      neighborsIndices=neighborsIndices,
    )
    secondSensedRepresentations = [
      activeRepresentationsA[0] | activeRepresentationsB[0],
      activeRepresentationsA[1],
      activeRepresentationsA[2],
      activeRepresentationsA[3]
    ]
    self.assertEqual(
      secondSensedRepresentations,
      self._getActiveRepresentations()
    )

    # feed sensed patterns third time
    # this time, it is all disambiguated
    activeRepresentations = self._getActiveRepresentations()
    self.inferMultipleColumns(
      feedforwardPatterns=sensedPatterns,
      activeRepresentations=activeRepresentations,
      neighborsIndices=neighborsIndices,
    )
    self.assertEqual(
      activeRepresentationsA,
      self._getActiveRepresentations()
    )


  def testMutualDisambiguationThroughUnions(self):
    """
    Learns three objects in two different columns.

    Feed ambiguous sensations, A u B and B u C. The system should narrow down
    to B.
    """
    self.init(numCols=2)
    neighborsIndices = [[1], [0]]

    objectA = self.generateObject(numPatterns=5, numCols=2)
    objectB = self.generateObject(numPatterns=5, numCols=2)
    objectC = self.generateObject(numPatterns=5, numCols=2)

    # learn object
    self.learnMultipleColumns(
      objectA,
      numRepetitions=3,
      neighborsIndices=neighborsIndices,
      randomOrder=True,
      newObject=True
    )
    activeRepresentationsA = self._getActiveRepresentations()

    # learn object
    self.learnMultipleColumns(
      objectB,
      numRepetitions=3,
      neighborsIndices=neighborsIndices,
      randomOrder=True,
      newObject=True
    )
    activeRepresentationsB = self._getActiveRepresentations()

    # learn object
    self.learnMultipleColumns(
      objectC,
      numRepetitions=3,
      neighborsIndices=neighborsIndices,
      randomOrder=True,
      newObject=True
    )
    activeRepresentationsC = self._getActiveRepresentations()

    # create sensed patterns (ambiguous)
    sensedPatterns = [objectA[1][0] | objectB[1][0],
                      objectB[2][1] | objectC[2][1]]

    for pooler in self.poolers:
      pooler.reset()

    # feed sensed patterns first time
    # the L2 representations should be ambiguous
    activeRepresentations = self._getActiveRepresentations()
    self.inferMultipleColumns(
      feedforwardPatterns=sensedPatterns,
      activeRepresentations=activeRepresentations,
      neighborsIndices=neighborsIndices,
    )
    firstRepresentations = [
      activeRepresentationsA[0] | activeRepresentationsB[0],
      activeRepresentationsB[1] | activeRepresentationsC[1]
    ]
    self.assertEqual(
      firstRepresentations,
      self._getActiveRepresentations()
    )

    # feed a second time
    # the L2 representations should be ambiguous
    activeRepresentations = self._getActiveRepresentations()
    self.inferMultipleColumns(
      feedforwardPatterns=sensedPatterns,
      activeRepresentations=activeRepresentations,
      neighborsIndices=neighborsIndices,
    )
    self.assertEqual(
      firstRepresentations,
      self._getActiveRepresentations()
    )

    # feed a third time, distal predictions should disambiguate
    # we are using the third time because there is an off-by-one in pooler
    activeRepresentations = self._getActiveRepresentations()
    self.inferMultipleColumns(
      feedforwardPatterns=sensedPatterns,
      activeRepresentations=activeRepresentations,
      neighborsIndices=neighborsIndices,
    )

    # check that representations are unique, being slightly tolerant
    self.assertLessEqual(
      len(self._getActiveRepresentations()[0] - activeRepresentationsB[0]),
      5,
    )

    self.assertLessEqual(
      len(self._getActiveRepresentations()[1] - activeRepresentationsB[1]),
      5,
    )

    self.assertGreaterEqual(
      len(self._getActiveRepresentations()[0] & activeRepresentationsB[0]),
      35,
    )

    self.assertGreaterEqual(
      len(self._getActiveRepresentations()[1] & activeRepresentationsB[1]),
      35,
    )

    self.assertEqual(
      self._getActiveRepresentations(),
      self._getPredictedActiveCells(),
    )


  def setUp(self):
    """
    Sets up the test.
    """
    # single column case
    self.pooler = None

    # multi column case
    self.poolers = []

    # create pattern machine
    self.proximalPatternMachine = PatternMachine(
      n=self.inputWidth,
      w=self.numOutputActiveBits,
      num=200,
      seed=self.seed
    )

    self.patternId = 0
    np.random.seed(self.seed)


  # Wrappers around ColumnPooler API

  def learn(self,
            feedforwardPatterns,
            lateralPatterns=None,
            numRepetitions=1,
            randomOrder=True,
            newObject=True):
    """
    Learns a single object, with the provided patterns.

    Parameters:
    ----------------------------
    @param   feedforwardPatterns   (list(set))
             List of proximal input patterns

    @param   lateralPatterns       (list(list(set)))
             List of distal input patterns, or None if no lateral input is
             used. The outer list is expected to have the same length as
             feedforwardPatterns, whereas each inner list's length is the
             number of cortical columns which are distally connected to the
             pooler.

    @param   numRepetitions        (int)
             Number of times the patterns will be fed

    @param   randomOrder           (bool)
             If true, the order of patterns will be shuffled at each
             repetition

    """
    if newObject:
      self.pooler.mmClearHistory()
      self.pooler.reset()

    # set-up
    indices = range(len(feedforwardPatterns))
    if lateralPatterns is None:
      lateralPatterns = [None] * len(feedforwardPatterns)

    for _ in xrange(numRepetitions):
      if randomOrder:
        np.random.shuffle(indices)

      for idx in indices:
        self.pooler.compute(feedforwardPatterns[idx],
                            activeExternalCells=lateralPatterns[idx],
                            learn=True)


  def infer(self,
            feedforwardPattern,
            lateralInput=None,
            printMetrics=False):
    """
    Feeds a single pattern to the column pooler (along with an optional lateral
    pattern).

    Parameters:
    ----------------------------
    @param feedforwardPattern       (set)
           Input proximal pattern to the pooler

    @param lateralInput              (set)
           Distal input pattern to the pooler (active external cells from
           neighboring cortical columns)

    @param printMetrics             (bool)
           If true, will print cell metrics

    """
    self.pooler.compute(feedforwardPattern,
                        activeExternalCells=lateralInput,
                        learn=False)

    if printMetrics:
      print self.pooler.mmPrettyPrintMetrics(
        self.pooler.mmGetDefaultMetrics()
      )
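
  # A minimal single-column usage sketch of the wrappers above (illustrative;
  # assumes self.init() has already been called):
  #
  #   obj = self.generateObject(numPatterns=3)
  #   self.learn(obj, numRepetitions=2, randomOrder=True, newObject=True)
  #   self.infer(feedforwardPattern=obj[0])
  #   representation = self._getActiveRepresentation()  # set of active cells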


  # Helper functions

  def generatePattern(self):
    """
    Returns a random proximal input pattern.
    """
    pattern = self.proximalPatternMachine.get(self.patternId)
    self.patternId += 1
    return pattern


  def generateObject(self, numPatterns, numCols=1):
    """
    Creates a list of patterns, for a given object.

    If numCols > 1 is given, a list of list of patterns will be returned.
    """
    if numCols == 1:
      return [self.generatePattern() for _ in xrange(numPatterns)]

    else:
      patterns = []
      for i in xrange(numPatterns):
        patterns.append([self.generatePattern() for _ in xrange(numCols)])
      return patterns


  def init(self, overrides=None, numCols=1):
    """
    Creates the column pooler with specified parameter overrides.

    Except for the specified overrides and problem-specific parameters, used
    parameters are implementation defaults.
    """
    params = {
      "inputWidth": self.inputWidth,
      "numActiveColumnsPerInhArea": self.numOutputActiveBits,
      "columnDimensions": (self.outputWidth,),
      "seed": self.seed,
      "initialPermanence": 0.51,
      "connectedPermanence": 0.6,
      "permanenceIncrement": 0.1,
      "permanenceDecrement": 0.02,
      "minThreshold": 10,
      "predictedSegmentDecrement": 0.004,
      "activationThreshold": 10,
      "maxNewSynapseCount": 20,
      "maxSegmentsPerCell": 255,
      "maxSynapsesPerSegment": 255,
    }
    if overrides is None:
      overrides = {}
    params.update(overrides)

    if numCols == 1:
      self.pooler = MonitoredColumnPooler(**params)
    else:
      # TODO: We need a different seed for each pooler otherwise each one
      # outputs an identical representation. Use random seed for now but ideally
      # we would set different specific seeds for each pooler
      params['seed']=0
      self.poolers = [MonitoredColumnPooler(**params) for _ in xrange(numCols)]
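      # One way to address the TODO above (not used here, just a sketch) would
      # be to give each pooler a distinct deterministic seed, e.g.:
      #   self.poolers = [MonitoredColumnPooler(**dict(params, seed=self.seed + i))
      #                   for i in xrange(numCols)]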


  def _getActiveRepresentation(self):
    """
    Retrieves the current active representation in the pooler.
    """
    if self.pooler is None:
      raise ValueError("No pooler has been instantiated")

    return set(self.pooler.getActiveCells())


  # Multi-column testing

  def learnMultipleColumns(self,
                           feedforwardPatterns,
                           numRepetitions=1,
                           neighborsIndices=None,
                           randomOrder=True,
                           newObject=True):
    """
    Learns a single object, feeding it through the multiple columns.

    Parameters:
    ----------------------------
    @param   feedforwardPatterns   (list(list(set)))
             List of proximal input patterns; each entry is a list holding one
             pattern per pooler.

    @param   neighborsIndices      (list(list))
             List of column indices each column receives lateral input from.

    @param   numRepetitions        (int)
             Number of times the patterns will be fed

    @param   randomOrder           (bool)
             If true, the order of patterns will be shuffled at each
             repetition

    """
    if newObject:
      for pooler in self.poolers:
        pooler.mmClearHistory()
        pooler.reset()

    # use a different list of pattern indices per column so shuffles are
    # independent across columns
    indices = [range(len(feedforwardPatterns))
               for _ in xrange(len(self.poolers))]
    representations = [set()] * len(self.poolers)

    # by default, all columns are neighbors
    if neighborsIndices is None:
      neighborsIndices = [
        range(i) + range(i+1, len(self.poolers))
        for i in xrange(len(self.poolers))
      ]

    for _ in xrange(numRepetitions):

      # independently shuffle pattern orders if necessary
      if randomOrder:
        for idx in indices:
          np.random.shuffle(idx)

      for i in xrange(len(indices[0])):
        # get union of relevant lateral representations
        lateralInputs = []
        for col in xrange(len(self.poolers)):
          lateralInputsCol = set()
          for idx in neighborsIndices[col]:
            lateralInputsCol = lateralInputsCol.union(representations[idx])
          lateralInputs.append(lateralInputsCol)

        # Train each column
        for col in xrange(len(self.poolers)):
          self.poolers[col].compute(
            feedforwardInput=feedforwardPatterns[indices[col][i]][col],
            activeExternalCells=lateralInputs[col],
            learn=True
          )

        # update active representations, offsetting cell indices by column so
        # they can be used as external (lateral) inputs
        representations = self._getActiveRepresentations()
        for col in xrange(len(representations)):
          representations[col] = set([col * self.outputWidth + k
                                      for k in representations[col]])


  def inferMultipleColumns(self,
                           feedforwardPatterns,
                           activeRepresentations=None,
                           neighborsIndices=None,
                           printMetrics=False,
                           reset=False):
    """
    Feeds a single pattern to each column pooler, along with lateral input
    derived from the other columns' previous active representations.

    Parameters:
    ----------------------------
    @param feedforwardPatterns      (list(set))
           Input proximal patterns to the poolers (one for each column)

    @param activeRepresentations    (list(set))
           Active representations in the columns at the previous step.

    @param neighborsIndices         (list(list))
           List of column indices each column receives lateral input from.

    @param printMetrics             (bool)
           If true, will print cell metrics

    """
    if reset:
      for pooler in self.poolers:
        pooler.reset()

    # build offset copies of activeRepresentations so the originals are not mutated
    representations = [None] * len(self.poolers)

    # by default, all columns are neighbors
    if neighborsIndices is None:
      neighborsIndices = [
        range(i) + range(i+1, len(self.poolers))
        for i in xrange(len(self.poolers))
      ]

    for i in xrange(len(self.poolers)):
      if activeRepresentations[i] is not None:
        representations[i] = set(i * self.outputWidth + k \
                                       for k in activeRepresentations[i])

    for col in range(len(self.poolers)):
      lateralInputs = [representations[idx] for idx in neighborsIndices[col]]
      if len(lateralInputs) > 0:
        lateralInputs = set.union(*lateralInputs)
      else:
        lateralInputs = set()

      self.poolers[col].compute(
        feedforwardPatterns[col],
        activeExternalCells=lateralInputs,
        learn=False
      )

    if printMetrics:
      for pooler in self.poolers:
        print pooler.mmPrettyPrintMetrics(
          pooler.mmGetDefaultMetrics()
        )


  def _getActiveRepresentations(self):
    """
    Retrieves the current active representations in the poolers.
    """
    if len(self.poolers) == 0:
      raise ValueError("No pooler has been instantiated")

    return [set(pooler.getActiveCells()) for pooler in self.poolers]


  def _getPredictedActiveCells(self):
    """
    Retrieves, for each pooler, the cells that are both active and predicted.
    """
    if len(self.poolers) == 0:
      raise ValueError("No pooler has been instantiated")

    return [set(pooler.getActiveCells()) & set(pooler.tm.getPredictiveCells())\
            for pooler in self.poolers]


class ExtensiveColumnPoolerTest(unittest.TestCase):
  """
  Algorithmic tests for the ColumnPooler region.

  Each test actually tests multiple aspects of the algorithm. For more
  atomic tests refer to column_pooler_unit_test.

  The notation for objects is the following:
    object{patternA, patternB, ...}

  In these tests, the proximally-fed SDRs are simulated as unique (location,
  feature) pairs regardless of actual locations and features, unless stated
  otherwise.
  """

  inputWidth = 2048 * 8
  numInputActiveBits = int(0.02 * inputWidth)
  outputWidth = 2048
  numOutputActiveBits = 40
  seed = 42


  def testNewInputs(self):
    """
    Checks that the behavior is correct when facing unseen inputs.
    """
    self.init()

    # feed the first input, a random SDR should be generated
    initialPattern = self.generateObject(1)
    self.learn(initialPattern, numRepetitions=1, newObject=True)
    representation = self._getActiveRepresentation()
    self.assertEqual(
      len(representation),
      self.numOutputActiveBits,
      "The generated representation is incorrect"
    )

    # feed a new input for the same object, the previous SDR should persist
    newPattern = self.generateObject(1)
    self.learn(newPattern, numRepetitions=1, newObject=False)
    newRepresentation = self._getActiveRepresentation()
    self.assertNotEqual(initialPattern, newPattern)
    self.assertEqual(
      newRepresentation,
      representation,
      "The SDR did not persist when learning the same object"
    )

    # without sensory input, the SDR should persist as well
    emptyPattern = [set()]
    self.learn(emptyPattern, numRepetitions=1, newObject=False)
    newRepresentation = self._getActiveRepresentation()
    self.assertEqual(
      newRepresentation,
      representation,
      "The SDR did not persist after an empty input."
    )


  def testLearnSinglePattern(self):
    """
    A single pattern is learnt for a single object.
    Objects: A{X, Y}
    """
    self.init()

    object = self.generateObject(1)
    self.learn(object, numRepetitions=1, newObject=True)
    # check that the active representation is sparse
    representation = self._getActiveRepresentation()
    self.assertEqual(
      len(representation),
      self.numOutputActiveBits,
      "The generated representation is incorrect"
    )

    # check that the pattern was correctly learnt
    self.infer(feedforwardPattern=object[0])
    self.assertEqual(
      self._getActiveRepresentation(),
      representation,
      "The pooled representation is not stable"
    )

    # present new pattern for same object
    # it should be mapped to the same representation
    newPattern = [self.generatePattern()]
    self.learn(newPattern, numRepetitions=1, newObject=False)
    # check that the active representation is sparse
    newRepresentation = self._getActiveRepresentation()
    self.assertEqual(
      newRepresentation,
      representation,
      "The new pattern did not map to the same object representation"
    )

    # check that the pattern was correctly learnt and is stable
    self.infer(feedforwardPattern=object[0])
    self.assertEqual(
      self._getActiveRepresentation(),
      representation,
      "The pooled representation is not stable"
    )


  def testLearnSingleObject(self):
    """
    Many patterns are learnt for a single object.
    Objects: A{P, Q, R, S, T}
    """
    self.init()

    object = self.generateObject(numPatterns=5)
    self.learn(object, numRepetitions=1, randomOrder=True, newObject=True)
    representation = self._getActiveRepresentation()

    # check that all patterns map to the same object
    for pattern in object:
      self.infer(feedforwardPattern=pattern)
      self.assertEqual(
        self._getActiveRepresentation(),
        representation,
        "The pooled representation is not stable"
      )

    # if activity stops, check that the representation persists
    self.infer(feedforwardPattern=set())
    self.assertEqual(
      self._getActiveRepresentation(),
      representation,
      "The pooled representation did not persist"
    )


  def testLearnTwoObjectNoCommonPattern(self):
    """
    Same test as before, using two objects, without common pattern.
    Objects: A{P, Q, R, S, T}   B{V, W, X, Y, Z}
    """
    self.init()

    objectA = self.generateObject(numPatterns=5)
    self.learn(objectA, numRepetitions=1, randomOrder=True, newObject=True)
    representationA = self._getActiveRepresentation()

    objectB = self.generateObject(numPatterns=5)
    self.learn(objectB, numRepetitions=1, randomOrder=True, newObject=True)
    representationB = self._getActiveRepresentation()

    self.assertNotEqual(representationA, representationB)

    # check that all patterns map to the same object
    for pattern in objectA:
      self.infer(feedforwardPattern=pattern)
      self.assertEqual(
        self._getActiveRepresentation(),
        representationA,
        "The pooled representation for the first object is not stable"
      )

    # check that all patterns map to the same object
    for pattern in objectB:
      self.infer(feedforwardPattern=pattern)
      self.assertEqual(
        self._getActiveRepresentation(),
        representationB,
        "The pooled representation for the second object is not stable"
      )

    # feed union of patterns in object A
    pattern = objectA[0] | objectA[1]
    self.infer(feedforwardPattern=pattern)
    self.assertEqual(
      self._getActiveRepresentation(),
      representationA,
      "The active representation is incorrect"
    )

    # feed unions of patterns in objects A and B
    pattern = objectA[0] | objectB[0]
    self.infer(feedforwardPattern=pattern)
    self.assertEqual(
      self._getActiveRepresentation(),
      representationA | representationB,
      "The active representation is incorrect"
    )



  def testLearnTwoObjectsOneCommonPattern(self):
    """
    Same test as before, except the two objects share a pattern
    Objects: A{P, Q, R, S, T}   B{P, W, X, Y, Z}
    """
    self.init()

    objectA = self.generateObject(numPatterns=5)
    self.learn(objectA, numRepetitions=1, randomOrder=True, newObject=True)
    representationA = self._getActiveRepresentation()

    objectB = self.generateObject(numPatterns=5)
    objectB[0] = objectA[0]
    self.learn(objectB, numRepetitions=1, randomOrder=True, newObject=True)
    representationB = self._getActiveRepresentation()

    self.assertNotEqual(representationA, representationB)
    # very small overlap
    self.assertLessEqual(len(representationA & representationB), 3)

    # check that all patterns except the common one map to the same object
    for pattern in objectA[1:]:
      self.infer(feedforwardPattern=pattern)
      self.assertEqual(
        self._getActiveRepresentation(),
        representationA,
        "The pooled representation for the first object is not stable"
      )

    # check that all patterns except the common one map to the same object
    for pattern in objectB[1:]:
      self.infer(feedforwardPattern=pattern)
      self.assertEqual(
        self._getActiveRepresentation(),
        representationB,
        "The pooled representation for the second object is not stable"
      )

    # feed shared pattern
    pattern = objectA[0]
    self.infer(feedforwardPattern=pattern)
    self.assertEqual(
      self._getActiveRepresentation(),
      representationA | representationB,
      "The active representation is incorrect"
    )

    # feed union of patterns in object A
    pattern = objectA[1] | objectA[2]
    self.infer(feedforwardPattern=pattern)
    self.assertEqual(
      self._getActiveRepresentation(),
      representationA,
      "The active representation is incorrect"
    )

    # feed unions of patterns in objects A and B
    pattern = objectA[1] | objectB[1]
    self.infer(feedforwardPattern=pattern)
    self.assertEqual(
      self._getActiveRepresentation(),
      representationA | representationB,
      "The active representation is incorrect"
    )


  def testLearnThreeObjectsOneCommonPattern(self):
    """
    Same test as before, with three objects
    Objects: A{P, Q, R, S, T}   B{P, W, X, Y, Z}   C{W, H, I, K, L}
    """
    self.init()

    objectA = self.generateObject(numPatterns=5)
    self.learn(objectA, numRepetitions=1, randomOrder=True, newObject=True)
    representationA = self._getActiveRepresentation()

    objectB = self.generateObject(numPatterns=5)
    objectB[0] = objectA[0]
    self.learn(objectB, numRepetitions=1, randomOrder=True, newObject=True)
    representationB = self._getActiveRepresentation()

    objectC = self.generateObject(numPatterns=5)
    objectC[0] = objectB[1]
    self.learn(objectC, numRepetitions=1, randomOrder=True, newObject=True)
    representationC = self._getActiveRepresentation()

    self.assertNotEqual(representationA, representationB)
    self.assertNotEqual(representationB, representationC)
    self.assertNotEqual(representationA, representationC)
    # very small overlap
    self.assertLessEqual(len(representationA & representationB), 3)
    self.assertLessEqual(len(representationB & representationC), 3)
    self.assertLessEqual(len(representationA & representationC), 3)


    # check that all patterns except the common one map to the same object
    for pattern in objectA[1:]:
      self.infer(feedforwardPattern=pattern)
      self.assertEqual(
        self._getActiveRepresentation(),
        representationA,
        "The pooled representation for the first object is not stable"
      )

    # check that all patterns except the common one map to the same object
    for pattern in objectB[2:]:
      self.infer(feedforwardPattern=pattern)
      self.assertEqual(
        self._getActiveRepresentation(),
        representationB,
        "The pooled representation for the second object is not stable"
      )

    # check that all patterns except the common one map to the same object
    for pattern in objectC[1:]:
      self.infer(feedforwardPattern=pattern)
      self.assertEqual(
        self._getActiveRepresentation(),
        representationC,
        "The pooled representation for the third object is not stable"
      )

    # feed shared pattern between A and B
    pattern = objectA[0]
    self.infer(feedforwardPattern=pattern)
    self.assertEqual(
      self._getActiveRepresentation(),
      representationA | representationB,
      "The active representation is incorrect"
    )

    # feed shared pattern between B and C
    pattern = objectB[1]
    self.infer(feedforwardPattern=pattern)
    self.assertEqual(
      self._getActiveRepresentation(),
      representationB | representationC,
      "The active representation is incorrect"
    )

    # feed union of patterns in object A
    pattern = objectA[1] | objectA[2]
    self.infer(feedforwardPattern=pattern)
    self.assertEqual(
      self._getActiveRepresentation(),
      representationA,
      "The active representation is incorrect"
    )

    # feed unions of patterns to activate all objects
    pattern = objectA[1] | objectB[1]
    self.infer(feedforwardPattern=pattern)
    self.assertEqual(
      self._getActiveRepresentation(),
      representationA | representationB | representationC,
      "The active representation is incorrect"
    )


  def testLearnThreeObjectsOneCommonPatternSpatialNoise(self):
    """
    Same test as before, with three objects
    Objects: A{P, Q, R, S, T}   B{P, W, X, Y, Z}   C{W, H, I, K, L}
    """
    self.init()

    objectA = self.generateObject(numPatterns=5)
    self.learn(objectA, numRepetitions=1, randomOrder=True, newObject=True)
    representationA = self._getActiveRepresentation()

    objectB = self.generateObject(numPatterns=5)
    objectB[0] = objectA[0]
    self.learn(objectB, numRepetitions=1, randomOrder=True, newObject=True)
    representationB = self._getActiveRepresentation()

    objectC = self.generateObject(numPatterns=5)
    objectC[0] = objectB[1]
    self.learn(objectC, numRepetitions=1, randomOrder=True, newObject=True)
    representationC = self._getActiveRepresentation()

    self.assertNotEqual(representationA, representationB)
    self.assertNotEqual(representationB, representationC)
    self.assertNotEqual(representationA, representationC)
    # very small overlap
    self.assertLessEqual(len(representationA & representationB), 3)
    self.assertLessEqual(len(representationB & representationC), 3)
    self.assertLessEqual(len(representationA & representationC), 3)


    # check that all patterns except the common one map to the same object
    for pattern in objectA[1:]:
      noisyPattern = self.proximalPatternMachine.addNoise(pattern, 0.05)
      self.infer(feedforwardPattern=noisyPattern)
      self.assertEqual(
        self._getActiveRepresentation(),
        representationA,
        "The pooled representation for the first object is not stable"
      )

    # check that all patterns except the common one map to the same object
    for pattern in objectB[2:]:
      noisyPattern = self.proximalPatternMachine.addNoise(pattern, 0.05)
      self.infer(feedforwardPattern=noisyPattern)
      self.assertEqual(
        self._getActiveRepresentation(),
        representationB,
        "The pooled representation for the second object is not stable"
      )

    # check that all patterns except the common one map to the same object
    for pattern in objectC[1:]:
      noisyPattern = self.proximalPatternMachine.addNoise(pattern, 0.05)
      self.infer(feedforwardPattern=noisyPattern)
      self.assertEqual(
        self._getActiveRepresentation(),
        representationC,
        "The pooled representation for the third object is not stable"
      )

    # feed shared pattern between A and B
    pattern = objectA[0]
    noisyPattern = self.proximalPatternMachine.addNoise(pattern, 0.05)
    self.infer(feedforwardPattern=noisyPattern)
    self.assertEqual(
      self._getActiveRepresentation(),
      representationA | representationB,
      "The active representation is incorrect"
    )

    # feed shared pattern between B and C
    pattern = objectB[1]
    noisyPattern = self.proximalPatternMachine.addNoise(pattern, 0.05)
    self.infer(feedforwardPattern=noisyPattern)
    self.assertEqual(
      self._getActiveRepresentation(),
      representationB | representationC,
      "The active representation is incorrect"
    )

    # feed union of patterns in object A
    pattern = objectA[1] | objectA[2]
    noisyPattern = self.proximalPatternMachine.addNoise(pattern, 0.05)
    self.infer(feedforwardPattern=noisyPattern)
    self.assertEqual(
      self._getActiveRepresentation(),
      representationA,
      "The active representation is incorrect"
    )

    # feed unions of patterns to activate all objects
    pattern = objectA[1] | objectB[1]
    noisyPattern = self.proximalPatternMachine.addNoise(pattern, 0.05)
    self.infer(feedforwardPattern=noisyPattern)
    self.assertEqual(
      self._getActiveRepresentation(),
      representationA | representationB | representationC,
      "The active representation is incorrect"
    )


  def setUp(self):
    """
    Sets up the test.
    """
    self.pooler = None
    self.proximalPatternMachine = PatternMachine(
      n=self.inputWidth,
      w=self.numOutputActiveBits,
      num=200,
      seed=self.seed
    )
    self.patternId = 0
    np.random.seed(self.seed)


  # Wrappers around ColumnPooler API

  def learn(self,
            feedforwardPatterns,
            lateralPatterns=None,
            numRepetitions=1,
            randomOrder=True,
            newObject=True):
    """
    Learns a single object, with the provided patterns.

    Parameters:
    ----------------------------

    @param   feedforwardPatterns   (list(set))
             List of proximal input patterns

    @param   lateralPatterns       (list(list(set)))
             List of distal input patterns, or None if no lateral input is
             used. The outer list is expected to have the same length as
             feedforwardPatterns, whereas each inner list's length is the
             number of cortical columns which are distally connected to the
             pooler.

    @param   numRepetitions        (int)
             Number of times the patterns will be fed

    @param   randomOrder           (bool)
             If true, the order of patterns will be shuffled at each
             repetition

    """
    if newObject:
      self.pooler.mmClearHistory()
      self.pooler.reset()

    # set-up
    indices = range(len(feedforwardPatterns))
    if lateralPatterns is None:
      lateralPatterns = [None] * len(feedforwardPatterns)

    for _ in xrange(numRepetitions):
      if randomOrder:
        np.random.shuffle(indices)

      for idx in indices:
        self.pooler.compute(feedforwardPatterns[idx],
                            activeExternalCells=lateralPatterns[idx],
                            learn=True)
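
  # Usage sketch (illustrative only): learning an object of two proximal
  # patterns with no lateral input. The pooler should settle on a stable
  # representation across the shuffled, repeated presentations.
  #
  #   obj = [self.generatePattern(), self.generatePattern()]
  #   self.learn(obj, numRepetitions=2, randomOrder=True, newObject=True)
  #   representation = self._getActiveRepresentation()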


  def infer(self,
            feedforwardPattern,
            lateralPatterns=None,
            printMetrics=False):
    """
    Feeds a single pattern to the column pooler (as well as an optional lateral
    pattern).

    Parameters:
    ----------------------------
    @param feedforwardPattern       (set)
           Input proximal pattern to the pooler

    @param lateralPatterns          (list(set))
           Input distal patterns to the pooler (one for each neighboring CC)

    @param printMetrics             (bool)
           If true, will print cell metrics

    """
    self.pooler.compute(feedforwardPattern,
                        activeExternalCells=lateralPatterns,
                        learn=False)

    if printMetrics:
      print self.pooler.mmPrettyPrintMetrics(
        self.pooler.mmGetDefaultMetrics()
      )


  # Helper functions

  def generatePattern(self):
    """
    Returns a random proximal input pattern.
    """
    pattern = self.proximalPatternMachine.get(self.patternId)
    self.patternId += 1
    return pattern


  def generateObject(self, numPatterns):
    """
    Creates a list of patterns, for a given object.
    """
    return [self.generatePattern() for _ in xrange(numPatterns)]


  def init(self, overrides=None):
    """
    Creates the column pooler with specified parameter overrides.

    Except for the specified overrides and problem-specific parameters, used
    parameters are implementation defaults.
    """
    params = {
      "inputWidth": self.inputWidth,
      "numActivecolumnsPerInhArea": self.numOutputActiveBits,
      "columnDimensions": (self.outputWidth,),
      "seed": self.seed,
      "learnOnOneCell": False
    }
    if overrides is None:
      overrides = {}
    params.update(overrides)

    self.pooler = MonitoredColumnPooler(**params)
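
  # Example override (illustrative only; keys must match the params dict above):
  #
  #   self.init(overrides={"learnOnOneCell": True})
  #
  # Any parameter not overridden keeps the ColumnPooler implementation default.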


  def _getActiveRepresentation(self):
    """
    Retrieves the current active representation in the pooler.
    """
    if self.pooler is None:
      raise ValueError("No pooler has been instantiated")

    return set(self.pooler.getActiveCells())


class ExtensiveColumnPoolerTest(unittest.TestCase):
  """
  Algorithmic tests for the ColumnPooler region.

  Each test actually tests multiple aspects of the algorithm. For more
  atomic tests refer to column_pooler_unit_test.

  The notation for objects is the following:
    object{patternA, patternB, ...}

  In these tests, the proximally-fed SDR's are simulated as unique (location,
  feature) pairs regardless of actual locations and features, unless stated
  otherwise.
  """

  inputWidth = 2048 * 8
  numInputActiveBits = int(0.02 * inputWidth)
  outputWidth = 4096
  numOutputActiveBits = 40
  seed = 42
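
  # Derived sparsity, for reference: inputWidth = 2048 * 8 = 16384 bits, so
  # numInputActiveBits = int(0.02 * 16384) = 327 (~2% of the input), while the
  # output SDR uses 40 of 4096 cells (~1% sparsity).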


  def testNewInputs(self):
    """
    Checks that the behavior is correct when facing unseen inputs.
    """
    self.init()

    # feed the first input, a random SDR should be generated
    initialPattern = self.generateObject(1)
    self.learn(initialPattern, numRepetitions=1, newObject=True)
    representation = self._getActiveRepresentation()
    self.assertEqual(
      len(representation),
      self.numOutputActiveBits,
      "The generated representation is incorrect"
    )

    # feed a new input for the same object, the previous SDR should persist
    newPattern = self.generateObject(1)
    self.learn(newPattern, numRepetitions=1, newObject=False)
    newRepresentation = self._getActiveRepresentation()
    self.assertNotEqual(initialPattern, newPattern)
    self.assertEqual(
      newRepresentation,
      representation,
      "The SDR did not persist when learning the same object"
    )

    # without sensory input, the SDR should persist as well
    emptyPattern = [set()]
    self.learn(emptyPattern, numRepetitions=1, newObject=False)
    newRepresentation = self._getActiveRepresentation()
    self.assertEqual(
      newRepresentation,
      representation,
      "The SDR did not persist after an empty input."
    )


  def testLearnSinglePattern(self):
    """
    A single pattern is learnt for a single object.
    Objects: A{X, Y}
    """
    self.init()

    object = self.generateObject(1)
    self.learn(object, numRepetitions=2, newObject=True)
    # check that the active representation is sparse
    representation = self._getActiveRepresentation()
    self.assertEqual(
      len(representation),
      self.numOutputActiveBits,
      "The generated representation is incorrect"
    )

    # check that the pattern was correctly learnt
    self.pooler.reset()
    self.infer(feedforwardPattern=object[0])
    self.assertEqual(
      self._getActiveRepresentation(),
      representation,
      "The pooled representation is not stable"
    )

    # present new pattern for same object
    # it should be mapped to the same representation
    newPattern = [self.generatePattern()]
    self.learn(newPattern, numRepetitions=2, newObject=False)
    # check that the active representation is sparse
    newRepresentation = self._getActiveRepresentation()
    self.assertEqual(
      newRepresentation,
      representation,
      "The new pattern did not map to the same object representation"
    )

    # check that the pattern was correctly learnt and is stable
    self.pooler.reset()
    self.infer(feedforwardPattern=object[0])
    self.assertEqual(
      self._getActiveRepresentation(),
      representation,
      "The pooled representation is not stable"
    )


  def testLearnSingleObject(self):
    """
    Many patterns are learnt for a single object.
    Objects: A{P, Q, R, S, T}
    """
    self.init()

    object = self.generateObject(numPatterns=5)
    self.learn(object, numRepetitions=2, randomOrder=True, newObject=True)
    representation = self._getActiveRepresentation()

    # check that all patterns map to the same object
    for pattern in object:
      self.pooler.reset()
      self.infer(feedforwardPattern=pattern)
      self.assertEqual(
        self._getActiveRepresentation(),
        representation,
        "The pooled representation is not stable"
      )

    # if activity stops, check that the representation persists
    self.infer(feedforwardPattern=set())
    self.assertEqual(
      self._getActiveRepresentation(),
      representation,
      "The pooled representation did not persist"
    )


  def testLearnTwoObjectNoCommonPattern(self):
    """
    Same test as before, using two objects, without common pattern.
    Objects: A{P, Q, R, S, T}   B{V, W, X, Y, Z}
    """
    self.init()

    objectA = self.generateObject(numPatterns=5)
    self.learn(objectA, numRepetitions=3, randomOrder=True, newObject=True)
    representationA = self._getActiveRepresentation()

    objectB = self.generateObject(numPatterns=5)
    self.learn(objectB, numRepetitions=3, randomOrder=True, newObject=True)
    representationB = self._getActiveRepresentation()

    self.assertNotEqual(representationA, representationB)

    # check that all patterns map to the same object
    for pattern in objectA:
      self.pooler.reset()
      self.infer(feedforwardPattern=pattern)
      self.assertEqual(
        self._getActiveRepresentation(),
        representationA,
        "The pooled representation for the first object is not stable"
      )

    # check that all patterns map to the same object
    for pattern in objectB:
      self.pooler.reset()
      self.infer(feedforwardPattern=pattern)
      self.assertEqual(
        self._getActiveRepresentation(),
        representationB,
        "The pooled representation for the second object is not stable"
      )

    # feed union of patterns in object A
    pattern = objectA[0] | objectA[1]
    self.pooler.reset()
    self.infer(feedforwardPattern=pattern)
    self.assertEqual(
      self._getActiveRepresentation(),
      representationA,
      "The active representation is incorrect"
    )

    # feed unions of patterns in objects A and B
    pattern = objectA[0] | objectB[0]
    self.pooler.reset()
    self.infer(feedforwardPattern=pattern)
    self.assertEqual(
      self._getActiveRepresentation(),
      representationA | representationB,
      "The active representation is incorrect"
    )


  def testLearnTwoObjectsOneCommonPattern(self):
    """
    Same test as before, except the two objects share a pattern
    Objects: A{P, Q, R, S, T}   B{P, W, X, Y, Z}
    """
    self.init()

    objectA = self.generateObject(numPatterns=5)
    self.learn(objectA, numRepetitions=3, randomOrder=True, newObject=True)
    representationA = self._getActiveRepresentation()

    objectB = self.generateObject(numPatterns=5)
    objectB[0] = objectA[0]
    self.learn(objectB, numRepetitions=3, randomOrder=True, newObject=True)
    representationB = self._getActiveRepresentation()

    self.assertNotEqual(representationA, representationB)
    # very small overlap
    self.assertLessEqual(len(representationA & representationB), 3)

    # check that all patterns except the common one map to the same object
    for pattern in objectA[1:]:
      self.pooler.reset()
      self.infer(feedforwardPattern=pattern)
      self.assertEqual(
        self._getActiveRepresentation(),
        representationA,
        "The pooled representation for the first object is not stable"
      )

    # check that all patterns except the common one map to the same object
    for pattern in objectB[1:]:
      self.pooler.reset()
      self.infer(feedforwardPattern=pattern)
      self.assertEqual(
        self._getActiveRepresentation(),
        representationB,
        "The pooled representation for the second object is not stable"
      )

    # feed shared pattern
    pattern = objectA[0]
    self.pooler.reset()
    self.infer(feedforwardPattern=pattern)
    self.assertEqual(
      self._getActiveRepresentation(),
      representationA | representationB,
      "The active representation is incorrect"
    )

    # feed union of patterns in object A
    pattern = objectA[1] | objectA[2]
    self.pooler.reset()
    self.infer(feedforwardPattern=pattern)
    self.assertEqual(
      self._getActiveRepresentation(),
      representationA,
      "The active representation is incorrect"
    )

    # feed unions of patterns in objects A and B
    pattern = objectA[1] | objectB[1]
    self.pooler.reset()
    self.infer(feedforwardPattern=pattern)
    self.assertEqual(
      self._getActiveRepresentation(),
      representationA | representationB,
      "The active representation is incorrect"
    )


  def testLearnThreeObjectsOneCommonPattern(self):
    """
    Same test as before, with three objects
    Objects: A{P, Q, R, S, T}   B{P, W, X, Y, Z}   C{W, H, I, K, L}
    """
    self.init()

    objectA = self.generateObject(numPatterns=5)
    self.learn(objectA, numRepetitions=3, randomOrder=True, newObject=True)
    representationA = self._getActiveRepresentation()

    objectB = self.generateObject(numPatterns=5)
    objectB[0] = objectA[0]
    self.learn(objectB, numRepetitions=3, randomOrder=True, newObject=True)
    representationB = self._getActiveRepresentation()

    objectC = self.generateObject(numPatterns=5)
    objectC[0] = objectB[1]
    self.learn(objectC, numRepetitions=3, randomOrder=True, newObject=True)
    representationC = self._getActiveRepresentation()

    self.assertNotEqual(representationA, representationB)
    self.assertNotEqual(representationB, representationC)
    self.assertNotEqual(representationA, representationC)
    # very small overlap
    self.assertLessEqual(len(representationA & representationB), 3)
    self.assertLessEqual(len(representationB & representationC), 3)
    self.assertLessEqual(len(representationA & representationC), 3)


    # check that all patterns except the common one map to the same object
    for pattern in objectA[1:]:
      self.pooler.reset()
      self.infer(feedforwardPattern=pattern)
      self.assertEqual(
        self._getActiveRepresentation(),
        representationA,
        "The pooled representation for the first object is not stable"
      )

    # check that all patterns except the shared ones map to the same object
    for pattern in objectB[2:]:
      self.pooler.reset()
      self.infer(feedforwardPattern=pattern)
      self.assertEqual(
        self._getActiveRepresentation(),
        representationB,
        "The pooled representation for the second object is not stable"
      )

    # check that all patterns except the common one map to the same object
    for pattern in objectC[1:]:
      self.pooler.reset()
      self.infer(feedforwardPattern=pattern)
      self.assertEqual(
        self._getActiveRepresentation(),
        representationC,
        "The pooled representation for the third object is not stable"
      )

    # feed shared pattern between A and B
    pattern = objectA[0]
    self.pooler.reset()
    self.infer(feedforwardPattern=pattern)
    self.assertEqual(
      self._getActiveRepresentation(),
      representationA | representationB,
      "The active representation is incorrect"
    )

    # feed shared pattern between B and C
    pattern = objectB[1]
    self.pooler.reset()
    self.infer(feedforwardPattern=pattern)
    self.assertEqual(
      self._getActiveRepresentation(),
      representationB | representationC,
      "The active representation is incorrect"
    )

    # feed union of patterns in object A
    pattern = objectA[1] | objectA[2]
    self.pooler.reset()
    self.infer(feedforwardPattern=pattern)
    self.assertEqual(
      self._getActiveRepresentation(),
      representationA,
      "The active representation is incorrect"
    )

    # feed unions of patterns to activate all objects
    pattern = objectA[1] | objectB[1]
    self.pooler.reset()
    self.infer(feedforwardPattern=pattern)
    self.assertEqual(
      self._getActiveRepresentation(),
      representationA | representationB | representationC,
      "The active representation is incorrect"
    )


  def testLearnThreeObjectsOneCommonPatternSpatialNoise(self):
    """
    Same test as before, but the patterns fed during inference are corrupted
    with 5% spatial noise.
    Objects: A{P, Q, R, S, T}   B{P, W, X, Y, Z}   C{W, H, I, K, L}
    """
    self.init()

    objectA = self.generateObject(numPatterns=5)
    self.learn(objectA, numRepetitions=3, randomOrder=True, newObject=True)
    representationA = self._getActiveRepresentation()

    objectB = self.generateObject(numPatterns=5)
    objectB[0] = objectA[0]
    self.learn(objectB, numRepetitions=3, randomOrder=True, newObject=True)
    representationB = self._getActiveRepresentation()

    objectC = self.generateObject(numPatterns=5)
    objectC[0] = objectB[1]
    self.learn(objectC, numRepetitions=3, randomOrder=True, newObject=True)
    representationC = self._getActiveRepresentation()

    self.assertNotEqual(representationA, representationB)
    self.assertNotEqual(representationB, representationC)
    self.assertNotEqual(representationA, representationC)
    # very small overlap
    self.assertLessEqual(len(representationA & representationB), 3)
    self.assertLessEqual(len(representationB & representationC), 3)
    self.assertLessEqual(len(representationA & representationC), 3)


    # check that all patterns except the common one map to the same object
    for pattern in objectA[1:]:
      noisyPattern = self.proximalPatternMachine.addNoise(pattern, 0.05)
      self.pooler.reset()
      self.infer(feedforwardPattern=noisyPattern)
      self.assertEqual(
        self._getActiveRepresentation(),
        representationA,
        "The pooled representation for the first object is not stable"
      )

    # check that all patterns except the shared ones map to the same object
    for pattern in objectB[2:]:
      noisyPattern = self.proximalPatternMachine.addNoise(pattern, 0.05)
      self.pooler.reset()
      self.infer(feedforwardPattern=noisyPattern)
      self.assertEqual(
        self._getActiveRepresentation(),
        representationB,
        "The pooled representation for the second object is not stable"
      )

    # check that all patterns except the common one map to the same object
    for pattern in objectC[1:]:
      noisyPattern = self.proximalPatternMachine.addNoise(pattern, 0.05)
      self.pooler.reset()
      self.infer(feedforwardPattern=noisyPattern)
      self.assertEqual(
        self._getActiveRepresentation(),
        representationC,
        "The pooled representation for the third object is not stable"
      )

    # feed shared pattern between A and B
    pattern = objectA[0]
    noisyPattern = self.proximalPatternMachine.addNoise(pattern, 0.05)
    self.pooler.reset()
    self.infer(feedforwardPattern=noisyPattern)
    self.assertEqual(
      self._getActiveRepresentation(),
      representationA | representationB,
      "The active representation is incorrect"
    )

    # feed shared pattern between B and C
    pattern = objectB[1]
    noisyPattern = self.proximalPatternMachine.addNoise(pattern, 0.05)
    self.pooler.reset()
    self.infer(feedforwardPattern=noisyPattern)
    self.assertEqual(
      self._getActiveRepresentation(),
      representationB | representationC,
      "The active representation is incorrect"
    )

    # feed union of patterns in object A
    pattern = objectA[1] | objectA[2]
    noisyPattern = self.proximalPatternMachine.addNoise(pattern, 0.05)
    self.pooler.reset()
    self.infer(feedforwardPattern=noisyPattern)
    self.assertEqual(
      self._getActiveRepresentation(),
      representationA,
      "The active representation is incorrect"
    )

    # feed unions of patterns to activate all objects
    pattern = objectA[1] | objectB[1]
    noisyPattern = self.proximalPatternMachine.addNoise(pattern, 0.05)
    self.pooler.reset()
    self.infer(feedforwardPattern=noisyPattern)
    self.assertEqual(
      self._getActiveRepresentation(),
      representationA | representationB | representationC,
      "The active representation is incorrect"
    )


  def testInferObjectOverTime(self):
    """Infer an object after touching only ambiguous points."""
    self.init()

    patterns = [self.generatePattern() for _ in xrange(3)]

    objectA = [patterns[0], patterns[1]]
    objectB = [patterns[1], patterns[2]]
    objectC = [patterns[2], patterns[0]]

    self.learn(objectA, numRepetitions=3, newObject=True)
    representationA = set(self.pooler.getActiveCells())
    self.learn(objectB, numRepetitions=3, newObject=True)
    representationB = set(self.pooler.getActiveCells())
    self.learn(objectC, numRepetitions=3, newObject=True)
    representationC = set(self.pooler.getActiveCells())

    self.pooler.reset()
    self.infer(patterns[0])
    self.assertEqual(set(self.pooler.getActiveCells()),
                     representationA | representationC)
    self.infer(patterns[1])
    self.assertEqual(set(self.pooler.getActiveCells()),
                     representationA)


  def testLearnOneObjectInTwoColumns(self):
    """Learns one object in two different columns."""
    self.init(numCols=2)
    neighborsIndices = [[1], [0]]

    objectA = self.generateObject(numPatterns=5, numCols=2)

    # learn object
    self.learnMultipleColumns(
      objectA,
      numRepetitions=3,
      neighborsIndices=neighborsIndices,
      randomOrder=True,
      newObject=True
    )
    objectARepresentations = self._getActiveRepresentations()

    for pooler in self.poolers:
      pooler.reset()

    for patterns in objectA:
      for i in xrange(3):
        activeRepresentations = self._getActiveRepresentations()

        self.inferMultipleColumns(
          feedforwardPatterns=patterns,
          activeRepresentations=activeRepresentations,
          neighborsIndices=neighborsIndices,
        )
        if i > 0:
          self.assertEqual(activeRepresentations,
                           self._getActiveRepresentations())
          self.assertEqual(objectARepresentations,
                           self._getActiveRepresentations())


  def testLearnTwoObjectsInTwoColumnsNoCommonPattern(self):
    """Learns two objects in two different columns."""
    self.init(numCols=2)
    neighborsIndices = [[1], [0]]

    objectA = self.generateObject(numPatterns=5, numCols=2)
    objectB = self.generateObject(numPatterns=5, numCols=2)

    # learn object
    self.learnMultipleColumns(
      objectA,
      numRepetitions=3,
      neighborsIndices=neighborsIndices,
      randomOrder=True,
      newObject=True
    )
    activeRepresentationsA = self._getActiveRepresentations()

    # learn object
    self.learnMultipleColumns(
      objectB,
      numRepetitions=3,
      neighborsIndices=neighborsIndices,
      randomOrder=True,
      newObject=True,
    )
    activeRepresentationsB = self._getActiveRepresentations()

    for pooler in self.poolers:
      pooler.reset()

    # check inference for object A
    for patternsA in objectA:
      for i in xrange(3):
        activeRepresentations = self._getActiveRepresentations()
        self.inferMultipleColumns(
          feedforwardPatterns=patternsA,
          activeRepresentations=activeRepresentations,
          neighborsIndices=neighborsIndices,
        )
        self.assertEqual(
          activeRepresentationsA,
          self._getActiveRepresentations()
        )

    for pooler in self.poolers:
      pooler.reset()

    # check inference for object B
    for patternsB in objectB:
      for i in xrange(3):
        activeRepresentations = self._getActiveRepresentations()
        self.inferMultipleColumns(
          feedforwardPatterns=patternsB,
          activeRepresentations=activeRepresentations,
          neighborsIndices=neighborsIndices
        )

        self.assertEqual(
          activeRepresentationsB,
          self._getActiveRepresentations()
        )


  def testLearnTwoObjectsInTwoColumnsOneCommonPattern(self):
    """Learns two objects in two different columns, with a common pattern."""
    self.init(numCols=2)
    neighborsIndices = [[1], [0]]

    objectA = self.generateObject(numPatterns=5, numCols=2)
    objectB = self.generateObject(numPatterns=5, numCols=2)

    # second pattern in column 0 is shared
    objectB[1][0] = objectA[1][0]

    # learn object
    self.learnMultipleColumns(
      objectA,
      numRepetitions=3,
      neighborsIndices=neighborsIndices,
      randomOrder=True,
      newObject=True
    )
    activeRepresentationsA = self._getActiveRepresentations()

    # learn object
    self.learnMultipleColumns(
      objectB,
      numRepetitions=3,
      neighborsIndices=neighborsIndices,
      randomOrder=True,
      newObject=True
    )
    activeRepresentationsB = self._getActiveRepresentations()

    # check inference for object A
    # for the first pattern, the distal predictions won't be correct
    # for the second one, the prediction will be unique thanks to the
    # distal predictions from the other column which has no ambiguity
    for pooler in self.poolers:
      pooler.reset()

    for patternsA in objectA:
      for i in xrange(3):
        activeRepresentations = self._getActiveRepresentations()
        self.inferMultipleColumns(
          feedforwardPatterns=patternsA,
          activeRepresentations=activeRepresentations,
          neighborsIndices=neighborsIndices,
        )
        self.assertEqual(
          activeRepresentationsA,
          self._getActiveRepresentations()
        )

    for pooler in self.poolers:
      pooler.reset()

    # check inference for object B
    for patternsB in objectB:
      for i in xrange(3):
        activeRepresentations = self._getActiveRepresentations()
        self.inferMultipleColumns(
          feedforwardPatterns=patternsB,
          activeRepresentations=activeRepresentations,
          neighborsIndices=neighborsIndices
        )

        self.assertEqual(
          activeRepresentationsB,
          self._getActiveRepresentations()
        )


  def testLearnTwoObjectsInTwoColumnsOneCommonPatternEmptyFirstInput(self):
    """Learns two objects in two different columns, with a common pattern."""
    self.init(numCols=2)
    neighborsIndices = [[1], [0]]

    objectA = self.generateObject(numPatterns=5, numCols=2)
    objectB = self.generateObject(numPatterns=5, numCols=2)

    # second pattern in column 0 is shared
    objectB[1][0] = objectA[1][0]

    # learn object
    self.learnMultipleColumns(
      objectA,
      numRepetitions=3,
      neighborsIndices=neighborsIndices,
      randomOrder=True,
      newObject=True
    )
    activeRepresentationsA = self._getActiveRepresentations()

    # learn object
    self.learnMultipleColumns(
      objectB,
      numRepetitions=3,
      neighborsIndices=neighborsIndices,
      randomOrder=True,
      newObject=True
    )
    activeRepresentationsB = self._getActiveRepresentations()

    # check inference for object A
    for pooler in self.poolers:
      pooler.reset()

    firstPattern = True
    for patternsA in objectA:
      activeRepresentations = self._getActiveRepresentations()
      if firstPattern:
        self.inferMultipleColumns(
          feedforwardPatterns=[set(), patternsA[1]],
          activeRepresentations=activeRepresentations,
          neighborsIndices=neighborsIndices,
        )
        desiredRepresentation = [set(), activeRepresentationsA[1]]
      else:
        self.inferMultipleColumns(
          feedforwardPatterns=patternsA,
          activeRepresentations=activeRepresentations,
          neighborsIndices=neighborsIndices,
        )
        desiredRepresentation = activeRepresentationsA
      self.assertEqual(
        desiredRepresentation,
        self._getActiveRepresentations()
      )


  def testPersistence(self):
    """After learning, representation should persist in L2 without input."""
    self.init(numCols=2)
    neighborsIndices = [[1], [0]]

    objectA = self.generateObject(numPatterns=5, numCols=2)

    # learn object
    self.learnMultipleColumns(
      objectA,
      numRepetitions=3,
      neighborsIndices=neighborsIndices,
      randomOrder=True,
      newObject=True
    )
    objectARepresentations = self._getActiveRepresentations()

    for pooler in self.poolers:
      pooler.reset()

    for patterns in objectA:
      for i in xrange(3):

        # on the third presentation, replace the second column's pattern with
        # an empty pattern
        if i == 2:
          patterns[1] = set()

        activeRepresentations = self._getActiveRepresentations()

        self.inferMultipleColumns(
          feedforwardPatterns=patterns,
          activeRepresentations=activeRepresentations,
          neighborsIndices=neighborsIndices,
        )
        if i > 0:
          self.assertEqual(activeRepresentations,
                           self._getActiveRepresentations())
          self.assertEqual(objectARepresentations,
                           self._getActiveRepresentations())


  def testLateralDisambiguation(self):
    """Lateral disambiguation using a constant simulated distal input."""
    self.init(overrides={
      "lateralInputWidths": [self.inputWidth],
    })

    objectA = self.generateObject(numPatterns=5)
    lateralInputA = [[set()]] + [[self.generatePattern()] for _ in xrange(4)]
    self.learn(objectA,
               lateralPatterns=lateralInputA,
               numRepetitions=3,
               randomOrder=True,
               newObject=True)
    representationA = self._getActiveRepresentation()

    objectB = self.generateObject(numPatterns=5)
    objectB[3] = objectA[3]
    lateralInputB = [[set()]] + [[self.generatePattern()] for _ in xrange(4)]
    self.learn(objectB,
               lateralPatterns=lateralInputB,
               numRepetitions=3,
               randomOrder=True,
               newObject=True)
    representationB = self._getActiveRepresentation()

    self.assertNotEqual(representationA, representationB)
    # very small overlap
    self.assertLessEqual(len(representationA & representationB), 3)

    # no ambiguity with lateral input
    for pattern in objectA:
      self.pooler.reset()
      self.infer(feedforwardPattern=pattern, lateralInputs=lateralInputA[-1])
      self.assertEqual(
        self._getActiveRepresentation(),
        representationA,
        "The pooled representation for the first object is not stable"
      )

    # no ambiguity with lateral input
    for pattern in objectB:
      self.pooler.reset()
      self.infer(feedforwardPattern=pattern, lateralInputs=lateralInputB[-1])
      self.assertEqual(
        self._getActiveRepresentation(),
        representationB,
        "The pooled representation for the second object is not stable"
      )


  def testLateralContestResolved(self):
    """
    Infer an object via lateral disambiguation even if some other columns have
    similar ambiguity.

    """

    self.init(overrides={"lateralInputWidths": [self.inputWidth,
                                                self.inputWidth]})

    patterns = [self.generatePattern() for _ in xrange(3)]

    objectA = [patterns[0], patterns[1]]
    objectB = [patterns[1], patterns[2]]

    lateralInput1A = self.generatePattern()
    lateralInput2A = self.generatePattern()
    lateralInput1B = self.generatePattern()
    lateralInput2B = self.generatePattern()

    self.learn(objectA, lateralPatterns=[[lateralInput1A, lateralInput2A]]*2,
               numRepetitions=3, newObject=True)
    representationA = set(self.pooler.getActiveCells())
    self.learn(objectB, lateralPatterns=[[lateralInput1B, lateralInput2B]]*2,
               numRepetitions=3, newObject=True)
    representationB = set(self.pooler.getActiveCells())

    self.pooler.reset()

    # This column will say A | B
    # One lateral column says A | B
    # Another lateral column says A
    self.infer(patterns[1], lateralInputs=[(), ()])
    self.infer(patterns[1], lateralInputs=[lateralInput1A | lateralInput1B,
                                           lateralInput2A])


    self.assertEqual(set(self.pooler.getActiveCells()), representationA)


  @unittest.skip("Fails, need to discuss")
  def testMultiColumnCompetition(self):
    """Competition between multiple conflicting lateral inputs."""
    self.init(numCols=4)
    neighborsIndices = [[1, 2, 3], [0, 2, 3], [0, 1, 3], [0, 1, 2]]

    objectA = self.generateObject(numPatterns=5, numCols=4)
    objectB = self.generateObject(numPatterns=5, numCols=4)

    # second pattern in column 0 is shared
    objectB[1][0] = objectA[1][0]

    # learn object
    self.learnMultipleColumns(
      objectA,
      numRepetitions=3,
      neighborsIndices=neighborsIndices,
      randomOrder=True,
      newObject=True
    )
    activeRepresentationsA = self._getActiveRepresentations()

    # learn object
    self.learnMultipleColumns(
      objectB,
      numRepetitions=3,
      neighborsIndices=neighborsIndices,
      randomOrder=True,
      newObject=True
    )
    activeRepresentationsB = self._getActiveRepresentations()

    # check inference for object A
    # for the first pattern, the distal predictions won't be correct
    # for the second one, the prediction will be unique thanks to the
    # distal predictions from the other column which has no ambiguity
    for pooler in self.poolers:
      pooler.reset()

    # sensed patterns will be mixed
    sensedPatterns = objectA[1][:-1] + [objectA[1][-1] | objectB[1][-1]]

    # feed sensed patterns first time
    # every column senses the correct object, except the first column, which
    # senses the union (reminder: lateral inputs are delayed)
    activeRepresentations = self._getActiveRepresentations()
    self.inferMultipleColumns(
      feedforwardPatterns=sensedPatterns,
      activeRepresentations=activeRepresentations,
      neighborsIndices=neighborsIndices,
    )
    firstSensedRepresentations = [
      activeRepresentationsA[0] | activeRepresentationsB[0],
      activeRepresentationsA[1],
      activeRepresentationsA[2],
      activeRepresentationsA[3] | activeRepresentationsB[3]
    ]
    self.assertEqual(
      firstSensedRepresentations,
      self._getActiveRepresentations()
    )

    # feed sensed patterns second time
    # the distal predictions are still ambiguous in C1, but disambiguated
    # in C4
    activeRepresentations = self._getActiveRepresentations()
    self.inferMultipleColumns(
      feedforwardPatterns=sensedPatterns,
      activeRepresentations=activeRepresentations,
      neighborsIndices=neighborsIndices,
    )
    secondSensedRepresentations = [
      activeRepresentationsA[0] | activeRepresentationsB[0],
      activeRepresentationsA[1],
      activeRepresentationsA[2],
      activeRepresentationsA[3]
    ]
    self.assertEqual(
      secondSensedRepresentations,
      self._getActiveRepresentations()
    )

    # feed sensed patterns third time
    # this time, it is all disambiguated
    activeRepresentations = self._getActiveRepresentations()
    self.inferMultipleColumns(
      feedforwardPatterns=sensedPatterns,
      activeRepresentations=activeRepresentations,
      neighborsIndices=neighborsIndices,
    )
    self.assertEqual(
      activeRepresentationsA,
      self._getActiveRepresentations()
    )


  def testMutualDisambiguationThroughUnions(self):
    """
    Learns three objects in two different columns.

    Feed ambiguous sensations, A u B and B u C. The system should narrow down
    to B.
    """
    self.init(numCols=2)
    neighborsIndices = [[1], [0]]

    objectA = self.generateObject(numPatterns=5, numCols=2)
    objectB = self.generateObject(numPatterns=5, numCols=2)
    objectC = self.generateObject(numPatterns=5, numCols=2)

    # learn object
    self.learnMultipleColumns(
      objectA,
      numRepetitions=3,
      neighborsIndices=neighborsIndices,
      randomOrder=True,
      newObject=True
    )
    activeRepresentationsA = self._getActiveRepresentations()

    # learn object
    self.learnMultipleColumns(
      objectB,
      numRepetitions=3,
      neighborsIndices=neighborsIndices,
      randomOrder=True,
      newObject=True
    )
    activeRepresentationsB = self._getActiveRepresentations()

    # learn object
    self.learnMultipleColumns(
      objectC,
      numRepetitions=3,
      neighborsIndices=neighborsIndices,
      randomOrder=True,
      newObject=True
    )
    activeRepresentationsC = self._getActiveRepresentations()

    # create sensed patterns (ambiguous)
    sensedPatterns = [objectA[1][0] | objectB[1][0],
                      objectB[2][1] | objectC[2][1]]

    for pooler in self.poolers:
      pooler.reset()

    # feed sensed patterns first time
    # the L2 representations should be ambiguous
    activeRepresentations = self._getActiveRepresentations()
    self.inferMultipleColumns(
      feedforwardPatterns=sensedPatterns,
      activeRepresentations=activeRepresentations,
      neighborsIndices=neighborsIndices,
    )
    firstRepresentations = [
      activeRepresentationsA[0] | activeRepresentationsB[0],
      activeRepresentationsB[1] | activeRepresentationsC[1]
    ]
    self.assertEqual(
      firstRepresentations,
      self._getActiveRepresentations()
    )

    # feed a second time, distal predictions should disambiguate
    activeRepresentations = self._getActiveRepresentations()
    self.inferMultipleColumns(
      feedforwardPatterns=sensedPatterns,
      activeRepresentations=activeRepresentations,
      neighborsIndices=neighborsIndices,
    )

    # check that representations are unique, being slightly tolerant
    self.assertLessEqual(
      len(self._getActiveRepresentations()[0] - activeRepresentationsB[0]),
      5,
    )

    self.assertLessEqual(
      len(self._getActiveRepresentations()[1] - activeRepresentationsB[1]),
      5,
    )

    self.assertGreaterEqual(
      len(self._getActiveRepresentations()[0] & activeRepresentationsB[0]),
      35,
    )

    self.assertGreaterEqual(
      len(self._getActiveRepresentations()[1] & activeRepresentationsB[1]),
      35,
    )


  def setUp(self):
    """
    Sets up the test.
    """
    # single column case
    self.pooler = None

    # multi column case
    self.poolers = []

    # create pattern machine
    self.proximalPatternMachine = PatternMachine(
      n=self.inputWidth,
      w=self.numOutputActiveBits,
      num=200,
      seed=self.seed
    )

    self.patternId = 0
    np.random.seed(self.seed)


  # Wrappers around ColumnPooler API

  def learn(self,
            feedforwardPatterns,
            lateralPatterns=None,
            numRepetitions=1,
            randomOrder=True,
            newObject=True):
    """
    Learns a single object, with the provided patterns.

    Parameters:
    ----------------------------

    @param   feedforwardPatterns   (list(set))
             List of proximal input patterns

    @param   lateralPatterns       (list(list(iterable)))
             List of distal input patterns, or None if no lateral input is
             used. The outer list is expected to have the same length as
             feedforwardPatterns, whereas each inner list's length is the
             number of cortical columns which are distally connected to the
             pooler.

    @param   numRepetitions        (int)
             Number of times the patterns will be fed

    @param   randomOrder           (bool)
             If true, the order of patterns will be shuffled at each
             repetition

    """
    if newObject:
      self.pooler.mmClearHistory()
      self.pooler.reset()

    # set-up
    indices = range(len(feedforwardPatterns))
    if lateralPatterns is None:
      lateralPatterns = [[] for _ in xrange(len(feedforwardPatterns))]

    for _ in xrange(numRepetitions):
      if randomOrder:
        np.random.shuffle(indices)

      for idx in indices:
        self.pooler.compute(sorted(feedforwardPatterns[idx]),
                            [sorted(lateralPattern)
                             for lateralPattern in lateralPatterns[idx]],
                            learn=True)
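
  # Usage sketch (illustrative only): learning an object of three proximal
  # patterns with one lateral column. The outer list of lateralPatterns has one
  # entry per proximal pattern; each inner list holds one SDR per lateral column.
  #
  #   obj = self.generateObject(numPatterns=3)
  #   lateral = [[self.generatePattern()] for _ in xrange(3)]
  #   self.learn(obj, lateralPatterns=lateral, numRepetitions=3, newObject=True)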


  def infer(self,
            feedforwardPattern,
            lateralInputs=(),
            printMetrics=False):
    """
    Feeds a single pattern to the column pooler (as well as an optional lateral
    pattern).

    Parameters:
    ----------------------------
    @param feedforwardPattern       (set)
           Input proximal pattern to the pooler

    @param lateralInputs            (list(set))
           Input distal patterns to the pooler (one for each neighboring CC)

    @param printMetrics             (bool)
           If true, will print cell metrics

    """
    self.pooler.compute(sorted(feedforwardPattern),
                        [sorted(lateralInput)
                         for lateralInput in lateralInputs],
                        learn=False)

    if printMetrics:
      print self.pooler.mmPrettyPrintMetrics(
        self.pooler.mmGetDefaultMetrics()
      )
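
  # Usage sketch (illustrative only; somePattern and the lateral SDRs are
  # hypothetical sets): inference with two neighboring columns, one distal SDR
  # per column.
  #
  #   self.infer(feedforwardPattern=somePattern,
  #              lateralInputs=[lateralSDR1, lateralSDR2])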


  # Helper functions

  def generatePattern(self):
    """
    Returns a random proximal input pattern.
    """
    pattern = self.proximalPatternMachine.get(self.patternId)
    self.patternId += 1
    return pattern


  def generateObject(self, numPatterns, numCols=1):
    """
    Creates a list of patterns, for a given object.

    If numCols > 1 is given, a list of list of patterns will be returned.
    """
    if numCols == 1:
      return [self.generatePattern() for _ in xrange(numPatterns)]

    else:
      patterns = []
      for i in xrange(numPatterns):
        patterns.append([self.generatePattern() for _ in xrange(numCols)])
      return patterns
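
  # Shape note: generateObject(numPatterns=5) returns a flat list of 5 sets,
  # whereas generateObject(numPatterns=5, numCols=2) returns a list of 5
  # entries, each a list of 2 sets (one proximal SDR per column).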


  def init(self, overrides=None, numCols=1):
    """
    Creates the column pooler with specified parameter overrides.

    Except for the specified overrides and problem-specific parameters, used
    parameters are implementation defaults.
    """
    params = {
      "inputWidth": self.inputWidth,
      "lateralInputWidths": [self.outputWidth]*(numCols-1),
      "cellCount": self.outputWidth,
      "sdrSize": self.numOutputActiveBits,
      "minThresholdProximal": 10,
      "sampleSizeProximal": 20,
      "connectedPermanenceProximal": 0.6,
      "initialDistalPermanence": 0.51,
      "activationThresholdDistal": 10,
      "sampleSizeDistal": 20,
      "connectedPermanenceDistal": 0.6,
      "seed": self.seed,
    }
    if overrides is None:
      overrides = {}
    params.update(overrides)

    if numCols == 1:
      self.pooler = MonitoredColumnPooler(**params)
    else:
      # TODO: We need a different seed for each pooler otherwise each one
      # outputs an identical representation. Use random seed for now but ideally
      # we would set different specific seeds for each pooler
      params['seed']=0
      self.poolers = [MonitoredColumnPooler(**params) for _ in xrange(numCols)]
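
  # Example (illustrative only): a two-column set-up with a larger distal
  # sample, using the overrides mechanism above.
  #
  #   self.init(overrides={"sampleSizeDistal": 30}, numCols=2)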


  def _getActiveRepresentation(self):
    """
    Retrieves the current active representation in the pooler.
    """
    if self.pooler is None:
      raise ValueError("No pooler has been instantiated")

    return set(self.pooler.getActiveCells())


  # Multi-column testing

  def learnMultipleColumns(self,
                           feedforwardPatterns,
                           numRepetitions=1,
                           neighborsIndices=None,
                           randomOrder=True,
                           newObject=True):
    """
    Learns a single object, feeding it through the multiple columns.

    Parameters:
    ----------------------------
    @param   feedforwardPatterns   (list(list(set)))
             List of proximal input patterns (one for each pooler).


    @param   neighborsIndices      (list(list))
             List of column indices each column received input from.

    @param   numRepetitions        (int)
             Number of times the patterns will be fed

    @param   randomOrder           (bool)
             If true, the order of patterns will be shuffled at each
             repetition

    """
    if newObject:
      for pooler in self.poolers:
        pooler.mmClearHistory()
        pooler.reset()

    # use different set of pattern indices to allow random orders
    indices = [range(len(feedforwardPatterns))] * len(self.poolers)
    prevActiveCells = [set() for _ in xrange(len(self.poolers))]

    # by default, all columns are neighbors
    if neighborsIndices is None:
      neighborsIndices = [
        range(i) + range(i+1, len(self.poolers))
        for i in xrange(len(self.poolers))
      ]

    for _ in xrange(numRepetitions):

      # independently shuffle pattern orders if necessary
      if randomOrder:
        for idx in indices:
          np.random.shuffle(idx)

      for i in xrange(len(indices[0])):
        # Train each column
        for col, pooler in enumerate(self.poolers):
          # get union of relevant lateral representations
          lateralInputs =  [sorted(activeCells)
                            for presynapticCol, activeCells
                            in enumerate(prevActiveCells)
                            if col != presynapticCol]

          pooler.compute(sorted(feedforwardPatterns[indices[col][i]][col]),
                         lateralInputs, learn=True)

        prevActiveCells = self._getActiveRepresentations()
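
  # Note on wiring: at step i each pooler receives, as lateral input, the
  # active cells that the other poolers produced at step i - 1
  # (prevActiveCells is refreshed only after every column has computed), so
  # lateral influence is delayed by one step, as in inferMultipleColumns.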


  def inferMultipleColumns(self,
                           feedforwardPatterns,
                           activeRepresentations,
                           neighborsIndices=None,
                           printMetrics=False,
                           reset=False):
    """
    Feeds one proximal pattern to each column pooler, together with the lateral
    inputs from the previous step.

    Parameters:
    ----------------------------
    @param feedforwardPatterns      (list(set))
           Input proximal patterns to the pooler (one for each column)

    @param activeRepresentations    (list(set))
           Active representations in the columns at the previous step.

    @param neighborsIndices         (list(list))
           List of column indices each column received input from.

    @param printMetrics             (bool)
           If true, will print cell metrics

    """
    if reset:
      for pooler in self.poolers:
        pooler.reset()

    # by default, all columns are neighbors
    if neighborsIndices is None:
      neighborsIndices = [
        range(i) + range(i+1, len(self.poolers))
        for i in xrange(len(self.poolers))
      ]

    for col, pooler in enumerate(self.poolers):
      # get union of relevant lateral representations
      lateralInputs = [sorted(activeCells)
                       for presynapticCol, activeCells
                       in enumerate(activeRepresentations)
                       if col != presynapticCol]

      pooler.compute(sorted(feedforwardPatterns[col]),
                     lateralInputs, learn=False)

    if printMetrics:
      for pooler in self.poolers:
        print pooler.mmPrettyPrintMetrics(
          pooler.mmGetDefaultMetrics()
        )


  def _getActiveRepresentations(self):
    """
    Retrieves the current active representations in the poolers.
    """
    if len(self.poolers) == 0:
      raise ValueError("No pooler has been instantiated")

    return [set(pooler.getActiveCells()) for pooler in self.poolers]
示例#16
0
class TemporalMemoryTest(unittest.TestCase):


  def setUp(self):
    self.tm = TemporalMemory()


  def testInitInvalidParams(self):
    # Invalid columnDimensions
    kwargs = {"columnDimensions": [], "cellsPerColumn": 32}
    self.assertRaises(ValueError, TemporalMemory, **kwargs)

    # Invalid cellsPerColumn
    kwargs = {"columnDimensions": [2048], "cellsPerColumn": 0}
    self.assertRaises(ValueError, TemporalMemory, **kwargs)
    kwargs = {"columnDimensions": [2048], "cellsPerColumn": -10}
    self.assertRaises(ValueError, TemporalMemory, **kwargs)


  def testActivateCorrectlyPredictiveCells(self):
    tm = self.tm

    prevPredictiveCells = set([0, 237, 1026, 26337, 26339, 55536])
    activeColumns = set([32, 47, 823])
    prevMatchingCells = set()

    (activeCells,
    winnerCells,
    predictedColumns,
    predictedInactiveCells) = tm.activateCorrectlyPredictiveCells(prevPredictiveCells,
                                                                  prevMatchingCells,
                                                                  activeColumns)
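
    # With the default geometry (2048 columns x 32 cells per column),
    # cell // 32 gives the column: 1026 // 32 == 32 and
    # 26337 // 32 == 26339 // 32 == 823, so only the predictive cells whose
    # columns are active (32, 47, 823) remain active; cells 0, 237 and 55536
    # fall in inactive columns and are dropped.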

    self.assertEqual(activeCells, set([1026, 26337, 26339]))
    self.assertEqual(winnerCells, set([1026, 26337, 26339]))
    self.assertEqual(predictedColumns, set([32, 823]))
    self.assertEqual(predictedInactiveCells, set())


  def testActivateCorrectlyPredictiveCellsEmpty(self):
    tm = self.tm

    # No previous predictive cells, no active columns
    prevPredictiveCells = set()
    activeColumns      = set()
    prevMatchingCells = set()

    (activeCells,
    winnerCells,
    predictedColumns,
    predictedInactiveCells) = tm.activateCorrectlyPredictiveCells(prevPredictiveCells,
                                                                  prevMatchingCells,
                                                                  activeColumns)

    self.assertEqual(activeCells,      set())
    self.assertEqual(winnerCells,      set())
    self.assertEqual(predictedColumns, set())
    self.assertEqual(predictedInactiveCells, set())

    # No previous predictive cells, with active columns

    prevPredictiveCells = set()
    activeColumns = set([32, 47, 823])
    prevMatchingCells = set()

    (activeCells,
    winnerCells,
    predictedColumns,
    predictedInactiveCells) = tm.activateCorrectlyPredictiveCells(prevPredictiveCells,
                                                                  prevMatchingCells,
                                                                  activeColumns)

    self.assertEqual(activeCells,      set())
    self.assertEqual(winnerCells,      set())
    self.assertEqual(predictedColumns, set())
    self.assertEqual(predictedInactiveCells, set())

    # No active columns, with previously predictive cells

    prevPredictiveCells = set([0, 237, 1026, 26337, 26339, 55536])
    activeColumns = set()
    prevMatchingCells = set()

    (activeCells,
    winnerCells,
    predictedColumns,
    predictedInactiveCells) = tm.activateCorrectlyPredictiveCells(prevPredictiveCells,
                                                                  prevMatchingCells,
                                                                  activeColumns)

    self.assertEqual(activeCells,      set())
    self.assertEqual(winnerCells,      set())
    self.assertEqual(predictedColumns, set())
    self.assertEqual(predictedInactiveCells, set())

  def testActivateCorrectlyPredictiveCellsOrphan(self):
    tm = self.tm

    prevPredictiveCells = set([])
    activeColumns = set([32, 47, 823])
    prevMatchingCells = set([32, 47])

    (activeCells,
    winnerCells,
    predictedColumns,
    predictedInactiveCells) = tm.activateCorrectlyPredictiveCells(prevPredictiveCells,
                                                                  prevMatchingCells,
                                                                  activeColumns)

    self.assertEqual(activeCells, set([]))
    self.assertEqual(winnerCells, set([]))
    self.assertEqual(predictedColumns, set([]))
    self.assertEqual(predictedInactiveCells, set([32,47]))

  def testBurstColumns(self):
    tm = TemporalMemory(
      cellsPerColumn=4,
      connectedPermanence=0.50,
      minThreshold=1,
      seed=42
    )

    connections = tm.connections
    connections.createSegment(0)
    connections.createSynapse(0, 23, 0.6)
    connections.createSynapse(0, 37, 0.4)
    connections.createSynapse(0, 477, 0.9)

    connections.createSegment(0)
    connections.createSynapse(1, 49, 0.9)
    connections.createSynapse(1, 3, 0.8)

    connections.createSegment(1)
    connections.createSynapse(2, 733, 0.7)

    connections.createSegment(108)
    connections.createSynapse(3, 486, 0.9)
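
    # Segment-to-cell layout (segment ids are assigned sequentially by
    # createSegment): segments 0 and 1 sit on cell 0 and segment 2 on cell 1
    # (both cells in column 0, since cellsPerColumn=4), while segment 3 sits
    # on cell 108 (column 27). Cells 4-7 (column 1) have no segments yet.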

    activeColumns = set([0, 1, 26])
    predictedColumns = set([26])
    prevActiveCells = set([23, 37, 49, 733])
    prevWinnerCells = set([23, 37, 49, 733])

    (activeCells,
     winnerCells,
     learningSegments) = tm.burstColumns(activeColumns,
                                         predictedColumns,
                                         prevActiveCells,
                                         prevWinnerCells,
                                         connections)

    self.assertEqual(activeCells, set([0, 1, 2, 3, 4, 5, 6, 7]))
    self.assertEqual(winnerCells, set([0, 6]))  # 6 is randomly chosen cell
    self.assertEqual(learningSegments, set([0, 4]))  # 4 is new segment created

    # Check that new segment was added to winner cell (6) in column 1
    self.assertEqual(connections.segmentsForCell(6), set([4]))


  def testBurstColumnsEmpty(self):
    tm = self.tm

    activeColumns    = set()
    predictedColumns = set()
    prevActiveCells = set()
    prevWinnerCells = set()
    connections = tm.connections

    (activeCells,
     winnerCells,
     learningSegments) = tm.burstColumns(activeColumns,
                                         predictedColumns,
                                         prevActiveCells,
                                         prevWinnerCells,
                                         connections)

    self.assertEqual(activeCells,      set())
    self.assertEqual(winnerCells,      set())
    self.assertEqual(learningSegments, set())


  def testLearnOnSegments(self):
    tm = TemporalMemory(maxNewSynapseCount=2)

    connections = tm.connections
    connections.createSegment(0)
    connections.createSynapse(0, 23, 0.6)
    connections.createSynapse(0, 37, 0.4)
    connections.createSynapse(0, 477, 0.9)

    connections.createSegment(1)
    connections.createSynapse(1, 733, 0.7)

    connections.createSegment(8)
    connections.createSynapse(2, 486, 0.9)

    connections.createSegment(100)

    prevActiveSegments = set([0, 2])
    learningSegments = set([1, 3])
    prevActiveCells = set([23, 37, 733])
    winnerCells = set([0])
    prevWinnerCells = set([10, 11, 12, 13, 14])
    predictedInactiveCells = set()
    prevMatchingSegments = set()
    tm.learnOnSegments(prevActiveSegments,
                       learningSegments,
                       prevActiveCells,
                       winnerCells,
                       prevWinnerCells,
                       connections,
                       predictedInactiveCells,
                       prevMatchingSegments)
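
    # Expected permanence updates, assuming the default permanence
    # increment/decrement of 0.10: on previously active segment 0, synapses to
    # the previously active cells 23 and 37 are reinforced (0.6 -> 0.7,
    # 0.4 -> 0.5) and the synapse to inactive cell 477 decays (0.9 -> 0.8);
    # learning segment 1 reinforces its synapse to active cell 733
    # (0.7 -> 0.8) and grows new synapses from the previous winner cells, up
    # to maxNewSynapseCount, ending with 2 synapses (checked below).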

    # Check segment 0
    synapseData = connections.dataForSynapse(0)
    self.assertAlmostEqual(synapseData.permanence, 0.7)

    synapseData = connections.dataForSynapse(1)
    self.assertAlmostEqual(synapseData.permanence, 0.5)

    synapseData = connections.dataForSynapse(2)
    self.assertAlmostEqual(synapseData.permanence, 0.8)

    # Check segment 1
    synapseData = connections.dataForSynapse(3)
    self.assertAlmostEqual(synapseData.permanence, 0.8)

    self.assertEqual(len(connections.synapsesForSegment(1)), 2)

    # Check segment 2
    synapseData = connections.dataForSynapse(4)
    self.assertAlmostEqual(synapseData.permanence, 0.9)

    self.assertEqual(len(connections.synapsesForSegment(2)), 1)

    # Check segment 3
    self.assertEqual(len(connections.synapsesForSegment(3)), 2)


  def testComputePredictiveCells(self):
    tm = TemporalMemory(activationThreshold=2, minThreshold=2)

    connections = tm.connections
    connections.createSegment(0)
    connections.createSynapse(0, 23, 0.6)
    connections.createSynapse(0, 37, 0.5)
    connections.createSynapse(0, 477, 0.9)

    connections.createSegment(1)
    connections.createSynapse(1, 733, 0.7)
    connections.createSynapse(1, 733, 0.4)

    connections.createSegment(1)
    connections.createSynapse(2, 974, 0.9)

    connections.createSegment(8)
    connections.createSynapse(3, 486, 0.9)

    connections.createSegment(100)

    activeCells = set([23, 37, 733, 974])

    (activeSegments,
     predictiveCells,
     matchingSegments,
     matchingCells) = tm.computePredictiveCells(activeCells, connections)
    self.assertEqual(activeSegments, set([0]))
    self.assertEqual(predictiveCells, set([0]))
    self.assertEqual(matchingSegments, set([0,1]))
    self.assertEqual(matchingCells, set([0,1]))


  def testBestMatchingCell(self):
    tm = TemporalMemory(
      connectedPermanence=0.50,
      minThreshold=1,
      seed=42
    )

    connections = tm.connections
    connections.createSegment(0)
    connections.createSynapse(0, 23, 0.6)
    connections.createSynapse(0, 37, 0.4)
    connections.createSynapse(0, 477, 0.9)

    connections.createSegment(0)
    connections.createSynapse(1, 49, 0.9)
    connections.createSynapse(1, 3, 0.8)

    connections.createSegment(1)
    connections.createSynapse(2, 733, 0.7)

    connections.createSegment(108)
    connections.createSynapse(3, 486, 0.9)

    activeCells = set([23, 37, 49, 733])

    self.assertEqual(tm.bestMatchingCell(tm.cellsForColumn(0),
                                         activeCells,
                                         connections),
                     (0, 0))

    self.assertEqual(tm.bestMatchingCell(tm.cellsForColumn(3),  # column containing cell 108
                                         activeCells,
                                         connections),
                     (96, None))  # Random cell from column

    self.assertEqual(tm.bestMatchingCell(tm.cellsForColumn(999),
                                         activeCells,
                                         connections),
                     (31972, None))  # Random cell from column


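  # With no active cells, no segment can match, so bestMatchingCell should
  # consistently fall back to the cell with the fewest segments. Cell 0
  # already has a segment, so cell 1 must be returned on every iteration.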
  def testBestMatchingCellFewestSegments(self):
    tm = TemporalMemory(
      columnDimensions=[2],
      cellsPerColumn=2,
      connectedPermanence=0.50,
      minThreshold=1,
      seed=42
    )

    connections = tm.connections
    connections.createSegment(0)
    connections.createSynapse(0, 3, 0.3)

    activeCells = set()

    for _ in range(100):
      # Never pick cell 0, always pick cell 1
      (cell, _) = tm.bestMatchingCell(tm.cellsForColumn(0),
                                      activeCells,
                                      connections)
      self.assertEqual(cell, 1)


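  # bestMatchingSegment returns (segment, numActiveSynapses) for the segment
  # on the given cell with the most synapses onto active cells (at least
  # minThreshold of them), or (None, None) when no segment qualifies.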
  def testBestMatchingSegment(self):
    tm = TemporalMemory(
      connectedPermanence=0.50,
      minThreshold=1
    )

    connections = tm.connections
    connections.createSegment(0)
    connections.createSynapse(0, 23, 0.6)
    connections.createSynapse(0, 37, 0.4)
    connections.createSynapse(0, 477, 0.9)

    connections.createSegment(0)
    connections.createSynapse(1, 49, 0.9)
    connections.createSynapse(1, 3, 0.8)

    connections.createSegment(1)
    connections.createSynapse(2, 733, 0.7)

    connections.createSegment(8)
    connections.createSynapse(3, 486, 0.9)

    activeCells = set([23, 37, 49, 733])

    self.assertEqual(tm.bestMatchingSegment(0,
                                            activeCells,
                                            connections),
                     (0, 2))

    self.assertEqual(tm.bestMatchingSegment(1,
                                            activeCells,
                                            connections),
                     (2, 1))

    self.assertEqual(tm.bestMatchingSegment(8,
                                            activeCells,
                                            connections),
                     (None, None))

    self.assertEqual(tm.bestMatchingSegment(100,
                                            activeCells,
                                            connections),
                     (None, None))


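  # leastUsedCell picks the cell in the column with the fewest segments,
  # breaking ties randomly. Cell 0 already has one segment, so cell 1 is the
  # unique least-used cell and must be returned every time.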
  def testLeastUsedCell(self):
    tm = TemporalMemory(
      columnDimensions=[2],
      cellsPerColumn=2,
      seed=42
    )

    connections = tm.connections
    connections.createSegment(0)
    connections.createSynapse(0, 3, 0.3)

    for _ in range(100):
      # Never pick cell 0, always pick cell 1
      self.assertEqual(tm.leastUsedCell(tm.cellsForColumn(0),
                                        connections),
                       1)


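  # adaptSegment reinforces the synapses whose indices appear in the given
  # active-synapse set and weakens the rest. The expected values below imply
  # a permanenceIncrement and permanenceDecrement of 0.1 on self.tm, which
  # presumably comes from the defaults used in setUp.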
  def testAdaptSegment(self):
    tm = self.tm

    connections = tm.connections
    connections.createSegment(0)
    connections.createSynapse(0, 23, 0.6)
    connections.createSynapse(0, 37, 0.4)
    connections.createSynapse(0, 477, 0.9)

    tm.adaptSegment(0, set([0, 1]), connections,
                    tm.permanenceIncrement,
                    tm.permanenceDecrement)

    synapseData = connections.dataForSynapse(0)
    self.assertAlmostEqual(synapseData.permanence, 0.7)

    synapseData = connections.dataForSynapse(1)
    self.assertAlmostEqual(synapseData.permanence, 0.5)

    synapseData = connections.dataForSynapse(2)
    self.assertAlmostEqual(synapseData.permanence, 0.8)


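  # This test and the next one check that adaptSegment clamps permanence
  # values to the [0.0, 1.0] range under repeated reinforcement or weakening.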
  def testAdaptSegmentToMax(self):
    tm = self.tm

    connections = tm.connections
    connections.createSegment(0)
    connections.createSynapse(0, 23, 0.9)

    tm.adaptSegment(0, set([0]), connections,
                    tm.permanenceIncrement,
                    tm.permanenceDecrement)
    synapseData = connections.dataForSynapse(0)
    self.assertAlmostEqual(synapseData.permanence, 1.0)

    # Now permanence should be at max
    tm.adaptSegment(0, set([0]), connections,
                    tm.permanenceIncrement,
                    tm.permanenceDecrement)
    synapseData = connections.dataForSynapse(0)
    self.assertAlmostEqual(synapseData.permanence, 1.0)


  def testAdaptSegmentToMin(self):
    tm = self.tm

    connections = tm.connections
    connections.createSegment(0)
    connections.createSynapse(0, 23, 0.1)

    tm.adaptSegment(0, set(), connections,
                    tm.permanenceIncrement,
                    tm.permanenceDecrement)
    synapseData = connections.dataForSynapse(0)
    self.assertAlmostEqual(synapseData.permanence, 0.0)

    # Now permanence should be at min
    tm.adaptSegment(0, set(), connections,
                    tm.permanenceIncrement,
                    tm.permanenceDecrement)
    synapseData = connections.dataForSynapse(0)
    self.assertAlmostEqual(synapseData.permanence, 0.0)


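  # pickCellsToLearnOn samples up to n winner cells to form new synapses on
  # the given segment. Which two of the four winner cells are returned for
  # n=2 depends on seed=42; asking for more cells than exist returns them
  # all, and n=0 returns the empty set.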
  def testPickCellsToLearnOn(self):
    tm = TemporalMemory(seed=42)

    connections = tm.connections
    connections.createSegment(0)

    winnerCells = set([4, 47, 58, 93])

    self.assertEqual(tm.pickCellsToLearnOn(2, 0, winnerCells, connections),
                     set([4, 58]))  # randomly picked

    self.assertEqual(tm.pickCellsToLearnOn(100, 0, winnerCells, connections),
                     set([4, 47, 58, 93]))

    self.assertEqual(tm.pickCellsToLearnOn(0, 0, winnerCells, connections),
                     set())


  def testPickCellsToLearnOnAvoidDuplicates(self):
    tm = TemporalMemory(seed=42)

    connections = tm.connections
    connections.createSegment(0)
    connections.createSynapse(0, 23, 0.6)

    winnerCells = set([23])

    # Ensure that no additional (duplicate) cells were picked
    self.assertEqual(tm.pickCellsToLearnOn(2, 0, winnerCells, connections),
                     set())


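  # columnForCell maps a cell index to its column, apparently by integer
  # division: column = cell // cellsPerColumn (e.g. 10239 // 5 == 2047).
  # For multi-dimensional columnDimensions the columns are flattened in
  # row-major order first, as the 2D test below relies on.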
  def testColumnForCell1D(self):
    tm = TemporalMemory(
      columnDimensions=[2048],
      cellsPerColumn=5
    )
    self.assertEqual(tm.columnForCell(0), 0)
    self.assertEqual(tm.columnForCell(4), 0)
    self.assertEqual(tm.columnForCell(5), 1)
    self.assertEqual(tm.columnForCell(10239), 2047)


  def testColumnForCell2D(self):
    tm = TemporalMemory(
      columnDimensions=[64, 64],
      cellsPerColumn=4
    )
    self.assertEqual(tm.columnForCell(0), 0)
    self.assertEqual(tm.columnForCell(3), 0)
    self.assertEqual(tm.columnForCell(4), 1)
    self.assertEqual(tm.columnForCell(16383), 4095)


  def testColumnForCellInvalidCell(self):
    tm = TemporalMemory(
      columnDimensions=[64, 64],
      cellsPerColumn=4
    )

    try:
      tm.columnForCell(16383)
    except IndexError:
      self.fail("IndexError raised unexpectedly")

    args = [16384]
    self.assertRaises(IndexError, tm.columnForCell, *args)

    args = [-1]
    self.assertRaises(IndexError, tm.columnForCell, *args)


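  # cellsForColumn is the inverse mapping: it returns the contiguous block of
  # cell indices [column * cellsPerColumn, (column + 1) * cellsPerColumn),
  # e.g. column 1 with 5 cells per column yields {5, 6, 7, 8, 9}.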
  def testCellsForColumn1D(self):
    tm = TemporalMemory(
      columnDimensions=[2048],
      cellsPerColumn=5
    )
    expectedCells = set([5, 6, 7, 8, 9])
    self.assertEqual(tm.cellsForColumn(1), expectedCells)


  def testCellsForColumn2D(self):
    tm = TemporalMemory(
      columnDimensions=[64, 64],
      cellsPerColumn=4
    )
    expectedCells = set([256, 257, 258, 259])
    self.assertEqual(tm.cellsForColumn(64), expectedCells)


  def testCellsForColumnInvalidColumn(self):
    tm = TemporalMemory(
      columnDimensions=[64, 64],
      cellsPerColumn=4
    )

    try:
      tm.cellsForColumn(4095)
    except IndexError:
      self.fail("IndexError raised unexpectedly")

    args = [4096]
    self.assertRaises(IndexError, tm.cellsForColumn, *args)

    args = [-1]
    self.assertRaises(IndexError, tm.cellsForColumn, *args)


  def testNumberOfColumns(self):
    tm = TemporalMemory(
      columnDimensions=[64, 64],
      cellsPerColumn=32
    )
    self.assertEqual(tm.numberOfColumns(), 64 * 64)


  def testNumberOfCells(self):
    tm = TemporalMemory(
      columnDimensions=[64, 64],
      cellsPerColumn=32
    )
    self.assertEqual(tm.numberOfCells(), 64 * 64 * 32)


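  # mapCellsToColumns groups a set of cells into a mapping keyed by column,
  # e.g. with 4 cells per column, cell 399 falls in column 99 (399 // 4).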
  def testMapCellsToColumns(self):
    tm = TemporalMemory(
      columnDimensions=[100],
      cellsPerColumn=4
    )
    columnsForCells = tm.mapCellsToColumns(set([0, 1, 2, 5, 399]))
    self.assertEqual(columnsForCells[0], set([0, 1, 2]))
    self.assertEqual(columnsForCells[1], set([5]))
    self.assertEqual(columnsForCells[99], set([399]))


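  # Serialization round-trip: write the temporal memory to a Cap'n Proto
  # message, read it back, and verify that the copy both compares equal and
  # behaves identically on further input.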
  def testWrite(self):
    tm1 = TemporalMemory(
      columnDimensions=[100],
      cellsPerColumn=4,
      activationThreshold=7,
      initialPermanence=0.37,
      connectedPermanence=0.58,
      minThreshold=4,
      maxNewSynapseCount=18,
      permanenceIncrement=0.23,
      permanenceDecrement=0.08,
      seed=91
    )

    # Run some data through before serializing
    self.patternMachine = PatternMachine(100, 4)
    self.sequenceMachine = SequenceMachine(self.patternMachine)
    sequence = self.sequenceMachine.generateFromNumbers(range(5))
    for _ in range(3):
      for pattern in sequence:
        tm1.compute(pattern)

    proto1 = TemporalMemoryProto_capnp.TemporalMemoryProto.new_message()
    tm1.write(proto1)

    # Write the proto to a temp file and read it back into a new proto
    with tempfile.TemporaryFile() as f:
      proto1.write(f)
      f.seek(0)
      proto2 = TemporalMemoryProto_capnp.TemporalMemoryProto.read(f)

    # Load the deserialized proto
    tm2 = TemporalMemory.read(proto2)

    # Check that the two temporal memory objects have the same attributes
    self.assertEqual(tm1, tm2)

    # Run a couple records through after deserializing and check results match
    tm1.compute(self.patternMachine.get(0))
    tm2.compute(self.patternMachine.get(0))
    self.assertEqual(tm1.activeCells, tm2.activeCells)
    self.assertEqual(tm1.predictiveCells, tm2.predictiveCells)
    self.assertEqual(tm1.winnerCells, tm2.winnerCells)
    self.assertEqual(tm1.connections, tm2.connections)

    tm1.compute(self.patternMachine.get(3))
    tm2.compute(self.patternMachine.get(3))
    self.assertEqual(tm1.activeCells, tm2.activeCells)
    self.assertEqual(tm1.predictiveCells, tm2.predictiveCells)
    self.assertEqual(tm1.winnerCells, tm2.winnerCells)
    self.assertEqual(tm1.connections, tm2.connections)