# Example #1
class SegmentedForwardModel(object):
  """
  A forward model built from dendrite segments. Each cell owns a set of
  segments and each segment owns a set of synapses. A cell fires when any one
  of its segments accumulates at least `threshold` active synapses.
  """

  def __init__(self, cellCount, inputSize, threshold):
    # Proximal connections relate each cell's segments to the input space.
    self.proximalConnections = SparseMatrixConnections(cellCount, inputSize)
    self.threshold = threshold
    self.activeCells = np.empty(0, dtype='uint32')
    self.activeSegments = np.empty(0, dtype='uint32')

  def associate(self, activeCells, activeInput):
    """
    Learn an association: grow one segment on each given cell and connect it
    to the given input bits.
    """
    self.activeCells = activeCells
    grownSegments = self.proximalConnections.createSegments(activeCells)
    self.activeSegments = grownSegments
    # Presumably sets the currently-zero entries of the (segments x input)
    # submatrix to full permanence -- TODO confirm setZerosOnOuter semantics.
    self.proximalConnections.matrix.setZerosOnOuter(
      grownSegments, activeInput, 1.0)

  def infer(self, activeInput):
    """
    Activate every cell that has a segment whose overlap with the input
    reaches the threshold.
    """
    overlaps = self.proximalConnections.computeActivity(activeInput)
    firingSegments = np.flatnonzero(overlaps >= self.threshold)
    firingCells = self.proximalConnections.mapSegmentsToCells(firingSegments)
    firingCells.sort()
    self.activeSegments = firingSegments
    self.activeCells = firingCells
class SensorToSpecificObjectModule(object):
  """
  Represents the sensor location relative to a specific object. Typically
  these modules are arranged in an array, and the combined population SDR is
  used to predict a feature-location pair.

  This class has two sets of connections. Both of them compute the "sensor's
  location relative to a specific object" in different ways.

  The "metric connections" compute it from the
    "body's location relative to a specific object"
  and the
    "sensor's location relative to body"
  These connections are learned once and then never need to be updated. They
  might be genetically hardcoded. They're initialized externally, e.g. in
  BodyToSpecificObjectModule2D.

  The "anchor connections" compute it from the sensory input. Whenever a
  cortical column learns a feature-location pair, this layer forms reciprocal
  connections with the feature-location pair layer.

  These segments receive input at different times. The metric connections
  receive input first, and they activate a set of cells. This set of cells is
  used externally to predict a feature-location pair. Then this feature-location
  pair is the input to the anchor connections.
  """

  def __init__(self, cellDimensions, anchorInputSize,
               activationThreshold=10,
               initialPermanence=0.21,
               connectedPermanence=0.50,
               learningThreshold=10,
               sampleSize=20,
               permanenceIncrement=0.1,
               permanenceDecrement=0.0,
               maxSynapsesPerSegment=-1,
               seed=42):
    """
    @param cellDimensions (sequence of ints)
    @param anchorInputSize (int)
    @param activationThreshold (int)
    Number of active connected synapses required to activate a segment.
    """
    self.activationThreshold = activationThreshold
    self.initialPermanence = initialPermanence
    self.connectedPermanence = connectedPermanence
    self.learningThreshold = learningThreshold
    self.sampleSize = sampleSize
    self.permanenceIncrement = permanenceIncrement
    self.permanenceDecrement = permanenceDecrement
    self.maxSynapsesPerSegment = maxSynapsesPerSegment

    self.rng = Random(seed)

    self.cellCount = np.prod(cellDimensions)
    # The metric connections combine two equally-sized input populations.
    cellCountBySource = {
      "bodyToSpecificObject": self.cellCount,
      "sensorToBody": self.cellCount,
    }
    self.metricConnections = Multiconnections(self.cellCount,
                                              cellCountBySource)
    self.anchorConnections = SparseMatrixConnections(self.cellCount,
                                                     anchorInputSize)

    # Initialize activity state so that the compute methods and
    # getActiveCells() are safe to call before the first reset().
    self.activeCells = np.empty(0, dtype="int")
    self.activeSegments = np.empty(0, dtype="uint32")
    self.activeMetricSegments = np.empty(0, dtype="uint32")


  def reset(self):
    """
    Clear the active cells.
    """
    self.activeCells = np.empty(0, dtype="int")


  def metricCompute(self, sensorToBody, bodyToSpecificObject):
    """
    Compute the
      "sensor's location relative to a specific object"
    from the
      "body's location relative to a specific object"
    and the
      "sensor's location relative to body"

    @param sensorToBody (numpy array)
    Active cells of a single module that represents the sensor's location
    relative to the body

    @param bodyToSpecificObject (numpy array)
    Active cells of a single module that represents the body's location relative
    to a specific object
    """
    overlaps = self.metricConnections.computeActivity({
      "bodyToSpecificObject": bodyToSpecificObject,
      "sensorToBody": sensorToBody,
    })

    # A metric segment is active only when both of its sources overlap it.
    self.activeMetricSegments = np.where(overlaps >= 2)[0]
    self.activeCells = np.unique(
      self.metricConnections.mapSegmentsToCells(
        self.activeMetricSegments))


  def anchorCompute(self, anchorInput, learn):
    """
    Compute the
      "sensor's location relative to a specific object"
    from the feature-location pair.

    @param anchorInput (numpy array)
    Active cells in the feature-location pair layer

    @param learn (bool)
    If true, maintain current cell activity and learn this input on the
    currently active cells
    """
    if learn:
      self._anchorComputeLearningMode(anchorInput)
    else:
      overlaps = self.anchorConnections.computeActivity(
        anchorInput, self.connectedPermanence)

      self.activeSegments = np.where(overlaps >= self.activationThreshold)[0]
      self.activeCells = np.unique(
        self.anchorConnections.mapSegmentsToCells(self.activeSegments))


  def _anchorComputeLearningMode(self, anchorInput):
    """
    Associate this location with a sensory input. Subsequently, anchorInput will
    activate the current location during anchor().

    @param anchorInput (numpy array)
    A sensory input. This will often come from a feature-location pair layer.
    """
    # Overlaps counting only connected synapses.
    overlaps = self.anchorConnections.computeActivity(
      anchorInput, self.connectedPermanence)

    activeSegments = np.where(overlaps >= self.activationThreshold)[0]

    # Overlaps counting all potential synapses.
    potentialOverlaps = self.anchorConnections.computeActivity(anchorInput)
    matchingSegments = np.where(potentialOverlaps >=
                                self.learningThreshold)[0]

    # Cells with an active segment: reinforce the segment.
    cellsForActiveSegments = self.anchorConnections.mapSegmentsToCells(
      activeSegments)
    learningActiveSegments = activeSegments[
      np.in1d(cellsForActiveSegments, self.activeCells)]
    remainingCells = np.setdiff1d(self.activeCells, cellsForActiveSegments)

    # Remaining cells with a matching segment: reinforce the best
    # matching segment.
    candidateSegments = self.anchorConnections.filterSegmentsByCell(
      matchingSegments, remainingCells)
    cellsForCandidateSegments = (
      self.anchorConnections.mapSegmentsToCells(candidateSegments))
    candidateSegments = candidateSegments[
      np.in1d(cellsForCandidateSegments, remainingCells)]
    # Keep at most one segment per cell: the one with the highest potential
    # overlap.
    onePerCellFilter = np2.argmaxMulti(potentialOverlaps[candidateSegments],
                                       cellsForCandidateSegments)
    learningMatchingSegments = candidateSegments[onePerCellFilter]

    newSegmentCells = np.setdiff1d(remainingCells, cellsForCandidateSegments)

    for learningSegments in (learningActiveSegments,
                             learningMatchingSegments):
      self._learn(self.anchorConnections, self.rng, learningSegments,
                  anchorInput, potentialOverlaps,
                  self.initialPermanence, self.sampleSize,
                  self.permanenceIncrement, self.permanenceDecrement,
                  self.maxSynapsesPerSegment)

    # Remaining cells without a matching segment: grow one.
    numNewSynapses = len(anchorInput)

    if self.sampleSize != -1:
      numNewSynapses = min(numNewSynapses, self.sampleSize)

    if self.maxSynapsesPerSegment != -1:
      numNewSynapses = min(numNewSynapses, self.maxSynapsesPerSegment)

    newSegments = self.anchorConnections.createSegments(newSegmentCells)

    self.anchorConnections.growSynapsesToSample(
      newSegments, anchorInput, numNewSynapses,
      self.initialPermanence, self.rng)

    self.activeSegments = activeSegments


  @staticmethod
  def _learn(connections, rng, learningSegments, activeInput,
             potentialOverlaps, initialPermanence, sampleSize,
             permanenceIncrement, permanenceDecrement, maxSynapsesPerSegment):
    """
    Adjust synapse permanences and grow new synapses on the given segments.

    @param connections (SparseMatrixConnections)
    @param rng (Random)
    @param learningSegments (numpy array)
    Segments to reinforce and grow synapses on
    @param activeInput (numpy array)
    @param potentialOverlaps (numpy array)
    Potential overlap of every segment with activeInput
    """
    # Learn on existing segments
    connections.adjustSynapses(learningSegments, activeInput,
                               permanenceIncrement, -permanenceDecrement)

    # Grow new synapses. Calculate "maxNew", the maximum number of synapses to
    # grow per segment. "maxNew" might be a number or it might be a list of
    # numbers.
    if sampleSize == -1:
      maxNew = len(activeInput)
    else:
      maxNew = sampleSize - potentialOverlaps[learningSegments]

    if maxSynapsesPerSegment != -1:
      synapseCounts = connections.mapSegmentsToSynapseCounts(
        learningSegments)
      numSynapsesToReachMax = maxSynapsesPerSegment - synapseCounts
      maxNew = np.where(maxNew <= numSynapsesToReachMax,
                        maxNew, numSynapsesToReachMax)

    connections.growSynapsesToSample(learningSegments, activeInput,
                                     maxNew, initialPermanence, rng)


  def getActiveCells(self):
    """
    Return the currently active cells.
    """
    return self.activeCells
class Superficial2DLocationModule(object):
  """
  A model of a location module. It's similar to a grid cell module, but it uses
  squares rather than triangles.

  The cells are arranged into a m*n rectangle which is tiled onto 2D space.
  Each cell represents a small rectangle in each tile.

  +------+------+------++------+------+------+
  | Cell | Cell | Cell || Cell | Cell | Cell |
  |  #1  |  #2  |  #3  ||  #1  |  #2  |  #3  |
  |      |      |      ||      |      |      |
  +--------------------++--------------------+
  | Cell | Cell | Cell || Cell | Cell | Cell |
  |  #4  |  #5  |  #6  ||  #4  |  #5  |  #6  |
  |      |      |      ||      |      |      |
  +--------------------++--------------------+
  | Cell | Cell | Cell || Cell | Cell | Cell |
  |  #7  |  #8  |  #9  ||  #7  |  #8  |  #9  |
  |      |      |      ||      |      |      |
  +------+------+------++------+------+------+

  We assume that path integration works *somehow*. This model receives a "delta
  location" vector, and it shifts the active cells accordingly. The model stores
  intermediate coordinates of active cells. Whenever sensory cues activate a
  cell, the model adds this cell to the list of coordinates being shifted.
  Whenever sensory cues cause a cell to become inactive, that cell is removed
  from the list of coordinates.

  (This model doesn't attempt to propose how "path integration" works. It
  attempts to show how locations are anchored to sensory cues.)

  When orientation is set to 0 degrees, the displacement is a [di, dj],
  moving di cells "down" and dj cells "right".

  When orientation is set to 90 degrees, the displacement is essentially a
  [dx, dy], applied in typical x,y coordinates with the origin on the bottom
  left.

  Usage:
  - When the sensor moves, call movementCompute.
  - When the sensor senses something, call sensoryCompute.

  The "anchor input" is typically a feature-location pair SDR.

  To specify how points are tracked, pass anchoringMethod = "corners",
  "narrowing" or "discrete".  "discrete" will cause the network to operate in a
  fully discrete space, where uncertainty is impossible as long as movements are
  integers.  "narrowing" is designed to narrow down uncertainty of initial
  locations of sensory stimuli.  "corners" is designed for noise-tolerance, and
  will activate all cells that are possible outcomes of path integration.
  """

  def __init__(self,
               cellDimensions,
               moduleMapDimensions,
               orientation,
               anchorInputSize,
               cellCoordinateOffsets=(0.5,),
               activationThreshold=10,
               initialPermanence=0.21,
               connectedPermanence=0.50,
               learningThreshold=10,
               sampleSize=20,
               permanenceIncrement=0.1,
               permanenceDecrement=0.0,
               maxSynapsesPerSegment=-1,
               anchoringMethod="narrowing",
               rotationMatrix = None,
               seed=42):
    """
    @param cellDimensions (tuple(int, int))
    Determines the number of cells. Determines how space is divided between the
    cells.

    @param moduleMapDimensions (tuple(float, float))
    Determines the amount of world space covered by all of the cells combined.
    In grid cell terminology, this is equivalent to the "scale" of a module.
    A module with a scale of "5cm" would have moduleMapDimensions=(5.0, 5.0).

    @param orientation (float)
    The rotation of this map, measured in radians.

    @param anchorInputSize (int)
    The number of input bits in the anchor input.

    @param cellCoordinateOffsets (list of floats)
    These must each be between 0.0 and 1.0. Every time a cell is activated by
    anchor input, this class adds a "phase" which is shifted in subsequent
    motions. By default, this phase is placed at the center of the cell. This
    parameter allows you to control where the point is placed and whether multiple
    are placed. For example, with value [0.2, 0.8], when cell [2, 3] is activated
    it will place 4 phases, corresponding to the following points in cell
    coordinates: [2.2, 3.2], [2.2, 3.8], [2.8, 3.2], [2.8, 3.8]

    @param anchoringMethod (string)
    "corners", "narrowing", or "discrete" -- see the class docstring for the
    tracking behavior each one selects.

    @param rotationMatrix (numpy array or None)
    If provided, it is used in place of the matrix computed from `orientation`.
    NOTE(review): in that case self.orientation is never assigned -- confirm no
    caller reads it when passing an explicit matrix.
    """

    self.cellDimensions = np.asarray(cellDimensions, dtype="int")
    self.moduleMapDimensions = np.asarray(moduleMapDimensions, dtype="float")
    # Conversion factor from world distance to phase (per axis).
    self.phasesPerUnitDistance = 1.0 / self.moduleMapDimensions

    if rotationMatrix is None:
      self.orientation = orientation
      # Standard 2D rotation matrix for `orientation` radians.
      self.rotationMatrix = np.array(
        [[math.cos(orientation), -math.sin(orientation)],
         [math.sin(orientation), math.cos(orientation)]])
      if anchoringMethod == "discrete":
        # Need to convert matrix to have integer values
        nonzeros = self.rotationMatrix[np.where(np.abs(self.rotationMatrix)>0)]
        smallestValue = np.amin(nonzeros)
        self.rotationMatrix /= smallestValue
        self.rotationMatrix = np.ceil(self.rotationMatrix)
    else:
      self.rotationMatrix = rotationMatrix

    self.cellCoordinateOffsets = cellCoordinateOffsets

    # Phase is measured as a number in the range [0.0, 1.0)
    self.activePhases = np.empty((0,2), dtype="float")
    self.cellsForActivePhases = np.empty(0, dtype="int")
    self.phaseDisplacement = np.empty((0,2), dtype="float")

    self.activeCells = np.empty(0, dtype="int")
    self.activeSegments = np.empty(0, dtype="uint32")

    # Anchor connections: cells <- anchor input bits.
    self.connections = SparseMatrixConnections(np.prod(cellDimensions),
                                               anchorInputSize)

    self.initialPermanence = initialPermanence
    self.connectedPermanence = connectedPermanence
    self.learningThreshold = learningThreshold
    self.sampleSize = sampleSize
    self.permanenceIncrement = permanenceIncrement
    self.permanenceDecrement = permanenceDecrement
    self.activationThreshold = activationThreshold
    self.maxSynapsesPerSegment = maxSynapsesPerSegment

    self.anchoringMethod = anchoringMethod

    self.rng = Random(seed)


  def reset(self):
    """
    Clear the active cells.
    """
    self.activePhases = np.empty((0,2), dtype="float")
    self.phaseDisplacement = np.empty((0,2), dtype="float")
    self.cellsForActivePhases = np.empty(0, dtype="int")
    self.activeCells = np.empty(0, dtype="int")


  def _computeActiveCells(self):
    """
    Derive self.cellsForActivePhases and self.activeCells from
    self.activePhases.
    """
    # Round each coordinate to the nearest cell.
    activeCellCoordinates = np.floor(
      self.activePhases * self.cellDimensions).astype("int")

    # Convert coordinates to cell numbers.
    self.cellsForActivePhases = (
      np.ravel_multi_index(activeCellCoordinates.T, self.cellDimensions))
    self.activeCells = np.unique(self.cellsForActivePhases)


  def activateRandomLocation(self):
    """
    Set the location to a random point.
    """
    self.activePhases = np.array([np.random.random(2)])
    if self.anchoringMethod == "discrete":
      # Need to place the phase in the middle of a cell
      self.activePhases = np.floor(
        self.activePhases * self.cellDimensions)/self.cellDimensions
    self._computeActiveCells()


  def movementCompute(self, displacement, noiseFactor = 0):
    """
    Shift the current active cells by a vector.

    @param displacement (pair of floats)
    A translation vector [di, dj].

    @param noiseFactor (float)
    When nonzero, Gaussian noise with this standard deviation is added to each
    component of the displacement.
    """

    if noiseFactor != 0:
      # Copy before mutating so the caller's displacement is untouched.
      displacement = copy.deepcopy(displacement)
      xnoise = np.random.normal(0, noiseFactor)
      ynoise = np.random.normal(0, noiseFactor)
      displacement[0] += xnoise
      displacement[1] += ynoise


    # Calculate delta in the module's coordinates.
    phaseDisplacement = (np.matmul(self.rotationMatrix, displacement) *
                         self.phasesPerUnitDistance)

    # Shift the active coordinates.
    np.add(self.activePhases, phaseDisplacement, out=self.activePhases)

    # In Python, (x % 1.0) can return 1.0 because of floating point goofiness.
    # Generally this doesn't cause problems, it's just confusing when you're
    # debugging.
    np.round(self.activePhases, decimals=9, out=self.activePhases)
    np.mod(self.activePhases, 1.0, out=self.activePhases)

    self._computeActiveCells()
    self.phaseDisplacement = phaseDisplacement


  def _sensoryComputeInferenceMode(self, anchorInput):
    """
    Infer the location from sensory input. Activate any cells with enough active
    synapses to this sensory input. Deactivate all other cells.

    @param anchorInput (numpy array)
    A sensory input. This will often come from a feature-location pair layer.
    """
    # No sensory input: leave the current location estimate unchanged.
    if len(anchorInput) == 0:
      return

    overlaps = self.connections.computeActivity(anchorInput,
                                                self.connectedPermanence)
    activeSegments = np.where(overlaps >= self.activationThreshold)[0]

    sensorySupportedCells = np.unique(
      self.connections.mapSegmentsToCells(activeSegments))

    # Drop the phases of cells that the sensory input does not support.
    inactivated = np.setdiff1d(self.activeCells, sensorySupportedCells)
    inactivatedIndices = np.in1d(self.cellsForActivePhases,
                                 inactivated).nonzero()[0]
    if inactivatedIndices.size > 0:
      self.activePhases = np.delete(self.activePhases, inactivatedIndices,
                                    axis=0)

    activated = np.setdiff1d(sensorySupportedCells, self.activeCells)

    # Find centers of point clouds
    if "corners" in self.anchoringMethod:
      # "corners": rebuild phases from *all* sensory-supported cells.
      activatedCoordsBase = np.transpose(
        np.unravel_index(sensorySupportedCells,
                         self.cellDimensions)).astype('float')
    else:
      activatedCoordsBase = np.transpose(
        np.unravel_index(activated, self.cellDimensions)).astype('float')

    # Generate points to add
    # One point per cell per (iOffset, jOffset) pair of coordinate offsets.
    activatedCoords = np.concatenate(
      [activatedCoordsBase + [iOffset, jOffset]
       for iOffset in self.cellCoordinateOffsets
       for jOffset in self.cellCoordinateOffsets]
    )
    if "corners" in self.anchoringMethod:
      # Replace the phase list wholesale.
      self.activePhases = activatedCoords / self.cellDimensions

    else:
      if activatedCoords.size > 0:
        # Append new phases for the newly activated cells.
        self.activePhases = np.append(self.activePhases,
                                      activatedCoords / self.cellDimensions,
                                      axis=0)

    self._computeActiveCells()
    self.activeSegments = activeSegments


  def _sensoryComputeLearningMode(self, anchorInput):
    """
    Associate this location with a sensory input. Subsequently, anchorInput will
    activate the current location during anchor().

    @param anchorInput (numpy array)
    A sensory input. This will often come from a feature-location pair layer.
    """
    # Overlaps counting only connected synapses.
    overlaps = self.connections.computeActivity(anchorInput,
                                                self.connectedPermanence)
    activeSegments = np.where(overlaps >= self.activationThreshold)[0]

    # Overlaps counting all potential synapses.
    potentialOverlaps = self.connections.computeActivity(anchorInput)
    matchingSegments = np.where(potentialOverlaps >=
                                self.learningThreshold)[0]

    # Cells with a active segment: reinforce the segment
    cellsForActiveSegments = self.connections.mapSegmentsToCells(
      activeSegments)
    learningActiveSegments = activeSegments[
      np.in1d(cellsForActiveSegments, self.activeCells)]
    remainingCells = np.setdiff1d(self.activeCells, cellsForActiveSegments)

    # Remaining cells with a matching segment: reinforce the best
    # matching segment.
    candidateSegments = self.connections.filterSegmentsByCell(
      matchingSegments, remainingCells)
    cellsForCandidateSegments = (
      self.connections.mapSegmentsToCells(candidateSegments))
    candidateSegments = candidateSegments[
      np.in1d(cellsForCandidateSegments, remainingCells)]
    # Keep at most one segment per cell: the one with the highest potential
    # overlap.
    onePerCellFilter = np2.argmaxMulti(potentialOverlaps[candidateSegments],
                                       cellsForCandidateSegments)
    learningMatchingSegments = candidateSegments[onePerCellFilter]

    newSegmentCells = np.setdiff1d(remainingCells, cellsForCandidateSegments)

    for learningSegments in (learningActiveSegments,
                             learningMatchingSegments):
      self._learn(self.connections, self.rng, learningSegments,
                  anchorInput, potentialOverlaps,
                  self.initialPermanence, self.sampleSize,
                  self.permanenceIncrement, self.permanenceDecrement,
                  self.maxSynapsesPerSegment)

    # Remaining cells without a matching segment: grow one.
    numNewSynapses = len(anchorInput)

    if self.sampleSize != -1:
      numNewSynapses = min(numNewSynapses, self.sampleSize)

    if self.maxSynapsesPerSegment != -1:
      numNewSynapses = min(numNewSynapses, self.maxSynapsesPerSegment)

    newSegments = self.connections.createSegments(newSegmentCells)

    self.connections.growSynapsesToSample(
      newSegments, anchorInput, numNewSynapses,
      self.initialPermanence, self.rng)
    self.activeSegments = activeSegments


  def sensoryCompute(self, anchorInput, anchorGrowthCandidates, learn):
    """
    Infer from or learn on a sensory input.

    @param anchorInput (numpy array)
    Active bits used for inference.

    @param anchorGrowthCandidates (numpy array)
    Active bits used for learning.

    @param learn (bool)
    """
    if learn:
      self._sensoryComputeLearningMode(anchorGrowthCandidates)
    else:
      self._sensoryComputeInferenceMode(anchorInput)


  @staticmethod
  def _learn(connections, rng, learningSegments, activeInput,
             potentialOverlaps, initialPermanence, sampleSize,
             permanenceIncrement, permanenceDecrement, maxSynapsesPerSegment):
    """
    Adjust synapse permanences and grow new synapses on the given segments.

    @param connections (SparseMatrixConnections)
    @param rng (Random)
    @param learningSegments (numpy array)
    Segments to reinforce and grow synapses on
    @param activeInput (numpy array)
    @param potentialOverlaps (numpy array)
    Potential overlap of every segment with activeInput
    """
    # Learn on existing segments
    connections.adjustSynapses(learningSegments, activeInput,
                               permanenceIncrement, -permanenceDecrement)

    # Grow new synapses. Calculate "maxNew", the maximum number of synapses to
    # grow per segment. "maxNew" might be a number or it might be a list of
    # numbers.
    if sampleSize == -1:
      maxNew = len(activeInput)
    else:
      maxNew = sampleSize - potentialOverlaps[learningSegments]

    if maxSynapsesPerSegment != -1:
      synapseCounts = connections.mapSegmentsToSynapseCounts(
        learningSegments)
      numSynapsesToReachMax = maxSynapsesPerSegment - synapseCounts
      # Clip per-segment growth so no segment exceeds maxSynapsesPerSegment.
      maxNew = np.where(maxNew <= numSynapsesToReachMax,
                        maxNew, numSynapsesToReachMax)

    connections.growSynapsesToSample(learningSegments, activeInput,
                                     maxNew, initialPermanence, rng)


  def getActiveCells(self):
    """
    Return the currently active cells.
    """
    return self.activeCells


  def numberOfCells(self):
    """
    Return the total number of cells in this module.
    """
    return np.prod(self.cellDimensions)
class SensorToSpecificObjectModule(object):
    """
  Represents the sensor location relative to a specific object. Typically
  these modules are arranged in an array, and the combined population SDR is
  used to predict a feature-location pair.

  This class has two sets of connections. Both of them compute the "sensor's
  location relative to a specific object" in different ways.

  The "metric connections" compute it from the
    "body's location relative to a specific object"
  and the
    "sensor's location relative to body"
  These connections are learned once and then never need to be updated. They
  might be genetically hardcoded. They're initialized externally, e.g. in
  BodyToSpecificObjectModule2D.

  The "anchor connections" compute it from the sensory input. Whenever a
  cortical column learns a feature-location pair, this layer forms reciprocal
  connections with the feature-location pair layer.

  These segments receive input at different times. The metric connections
  receive input first, and they activate a set of cells. This set of cells is
  used externally to predict a feature-location pair. Then this feature-location
  pair is the input to the anchor connections.
  """
    def __init__(self,
                 cellDimensions,
                 anchorInputSize,
                 activationThreshold=10,
                 initialPermanence=0.21,
                 connectedPermanence=0.50,
                 learningThreshold=10,
                 sampleSize=20,
                 permanenceIncrement=0.1,
                 permanenceDecrement=0.0,
                 maxSynapsesPerSegment=-1,
                 seed=42):
        """
    @param cellDimensions (sequence of ints)
    @param anchorInputSize (int)
    @param activationThreshold (int)
    Number of active connected synapses required to activate a segment.
    """
        self.activationThreshold = activationThreshold
        self.initialPermanence = initialPermanence
        self.connectedPermanence = connectedPermanence
        self.learningThreshold = learningThreshold
        self.sampleSize = sampleSize
        self.permanenceIncrement = permanenceIncrement
        self.permanenceDecrement = permanenceDecrement
        self.maxSynapsesPerSegment = maxSynapsesPerSegment

        self.rng = Random(seed)

        self.cellCount = np.prod(cellDimensions)
        # The metric connections combine two equally-sized input populations.
        cellCountBySource = {
            "bodyToSpecificObject": self.cellCount,
            "sensorToBody": self.cellCount,
        }
        self.metricConnections = Multiconnections(self.cellCount,
                                                  cellCountBySource)
        self.anchorConnections = SparseMatrixConnections(
            self.cellCount, anchorInputSize)

        # Initialize activity state so that the compute methods and
        # getActiveCells() are safe to call before the first reset().
        self.activeCells = np.empty(0, dtype="int")
        self.activeSegments = np.empty(0, dtype="uint32")
        self.activeMetricSegments = np.empty(0, dtype="uint32")

    def reset(self):
        """
    Clear the active cells.
    """
        self.activeCells = np.empty(0, dtype="int")

    def metricCompute(self, sensorToBody, bodyToSpecificObject):
        """
    Compute the
      "sensor's location relative to a specific object"
    from the
      "body's location relative to a specific object"
    and the
      "sensor's location relative to body"

    @param sensorToBody (numpy array)
    Active cells of a single module that represents the sensor's location
    relative to the body

    @param bodyToSpecificObject (numpy array)
    Active cells of a single module that represents the body's location relative
    to a specific object
    """
        overlaps = self.metricConnections.computeActivity({
            "bodyToSpecificObject":
            bodyToSpecificObject,
            "sensorToBody":
            sensorToBody,
        })

        # A metric segment is active only when both of its sources overlap it.
        self.activeMetricSegments = np.where(overlaps >= 2)[0]
        self.activeCells = np.unique(
            self.metricConnections.mapSegmentsToCells(
                self.activeMetricSegments))

    def anchorCompute(self, anchorInput, learn):
        """
    Compute the
      "sensor's location relative to a specific object"
    from the feature-location pair.

    @param anchorInput (numpy array)
    Active cells in the feature-location pair layer

    @param learn (bool)
    If true, maintain current cell activity and learn this input on the
    currently active cells
    """
        if learn:
            self._anchorComputeLearningMode(anchorInput)
        else:
            overlaps = self.anchorConnections.computeActivity(
                anchorInput, self.connectedPermanence)

            self.activeSegments = np.where(
                overlaps >= self.activationThreshold)[0]
            self.activeCells = np.unique(
                self.anchorConnections.mapSegmentsToCells(self.activeSegments))

    def _anchorComputeLearningMode(self, anchorInput):
        """
    Associate this location with a sensory input. Subsequently, anchorInput will
    activate the current location during anchor().

    @param anchorInput (numpy array)
    A sensory input. This will often come from a feature-location pair layer.
    """
        # Overlaps counting only connected synapses.
        overlaps = self.anchorConnections.computeActivity(
            anchorInput, self.connectedPermanence)

        activeSegments = np.where(overlaps >= self.activationThreshold)[0]

        # Overlaps counting all potential synapses.
        potentialOverlaps = self.anchorConnections.computeActivity(anchorInput)
        matchingSegments = np.where(
            potentialOverlaps >= self.learningThreshold)[0]

        # Cells with an active segment: reinforce the segment.
        cellsForActiveSegments = self.anchorConnections.mapSegmentsToCells(
            activeSegments)
        learningActiveSegments = activeSegments[np.in1d(
            cellsForActiveSegments, self.activeCells)]
        remainingCells = np.setdiff1d(self.activeCells, cellsForActiveSegments)

        # Remaining cells with a matching segment: reinforce the best
        # matching segment.
        candidateSegments = self.anchorConnections.filterSegmentsByCell(
            matchingSegments, remainingCells)
        cellsForCandidateSegments = (
            self.anchorConnections.mapSegmentsToCells(candidateSegments))
        candidateSegments = candidateSegments[np.in1d(
            cellsForCandidateSegments, remainingCells)]
        # Keep at most one segment per cell: the one with the highest
        # potential overlap.
        onePerCellFilter = np2.argmaxMulti(
            potentialOverlaps[candidateSegments], cellsForCandidateSegments)
        learningMatchingSegments = candidateSegments[onePerCellFilter]

        newSegmentCells = np.setdiff1d(remainingCells,
                                       cellsForCandidateSegments)

        for learningSegments in (learningActiveSegments,
                                 learningMatchingSegments):
            self._learn(self.anchorConnections, self.rng, learningSegments,
                        anchorInput, potentialOverlaps, self.initialPermanence,
                        self.sampleSize, self.permanenceIncrement,
                        self.permanenceDecrement, self.maxSynapsesPerSegment)

        # Remaining cells without a matching segment: grow one.
        numNewSynapses = len(anchorInput)

        if self.sampleSize != -1:
            numNewSynapses = min(numNewSynapses, self.sampleSize)

        if self.maxSynapsesPerSegment != -1:
            numNewSynapses = min(numNewSynapses, self.maxSynapsesPerSegment)

        newSegments = self.anchorConnections.createSegments(newSegmentCells)

        self.anchorConnections.growSynapsesToSample(newSegments, anchorInput,
                                                    numNewSynapses,
                                                    self.initialPermanence,
                                                    self.rng)

        self.activeSegments = activeSegments

    @staticmethod
    def _learn(connections, rng, learningSegments, activeInput,
               potentialOverlaps, initialPermanence, sampleSize,
               permanenceIncrement, permanenceDecrement,
               maxSynapsesPerSegment):
        """
    Adjust synapse permanences and grow new synapses on the given segments.

    @param connections (SparseMatrixConnections)
    @param rng (Random)
    @param learningSegments (numpy array)
    Segments to reinforce and grow synapses on
    @param activeInput (numpy array)
    @param potentialOverlaps (numpy array)
    Potential overlap of every segment with activeInput
    """
        # Learn on existing segments
        connections.adjustSynapses(learningSegments, activeInput,
                                   permanenceIncrement, -permanenceDecrement)

        # Grow new synapses. Calculate "maxNew", the maximum number of synapses
        # to grow per segment. "maxNew" might be a number or it might be a list
        # of numbers.
        if sampleSize == -1:
            maxNew = len(activeInput)
        else:
            maxNew = sampleSize - potentialOverlaps[learningSegments]

        if maxSynapsesPerSegment != -1:
            synapseCounts = connections.mapSegmentsToSynapseCounts(
                learningSegments)
            numSynapsesToReachMax = maxSynapsesPerSegment - synapseCounts
            # Clip per-segment growth so no segment exceeds the maximum.
            maxNew = np.where(maxNew <= numSynapsesToReachMax, maxNew,
                              numSynapsesToReachMax)

        connections.growSynapsesToSample(learningSegments, activeInput, maxNew,
                                         initialPermanence, rng)

    def getActiveCells(self):
        """
    Return the currently active cells.
    """
        return self.activeCells
class Superficial2DLocationModule(object):
    """
  A model of a location module. It's similar to a grid cell module, but it uses
  squares rather than triangles.

  The cells are arranged into a m*n rectangle which is tiled onto 2D space.
  Each cell represents a small rectangle in each tile.

  +------+------+------++------+------+------+
  | Cell | Cell | Cell || Cell | Cell | Cell |
  |  #1  |  #2  |  #3  ||  #1  |  #2  |  #3  |
  |      |      |      ||      |      |      |
  +--------------------++--------------------+
  | Cell | Cell | Cell || Cell | Cell | Cell |
  |  #4  |  #5  |  #6  ||  #4  |  #5  |  #6  |
  |      |      |      ||      |      |      |
  +--------------------++--------------------+
  | Cell | Cell | Cell || Cell | Cell | Cell |
  |  #7  |  #8  |  #9  ||  #7  |  #8  |  #9  |
  |      |      |      ||      |      |      |
  +------+------+------++------+------+------+

  We assume that path integration works *somehow*. This model receives a "delta
  location" vector, and it shifts the active cells accordingly. The model stores
  intermediate coordinates of active cells. Whenever sensory cues activate a
  cell, the model adds this cell to the list of coordinates being shifted.
  Whenever sensory cues cause a cell to become inactive, that cell is removed
  from the list of coordinates.

  (This model doesn't attempt to propose how "path integration" works. It
  attempts to show how locations are anchored to sensory cues.)

  When orientation is set to 0 degrees, the displacement is a [di, dj],
  moving di cells "down" and dj cells "right".

  When orientation is set to 90 degrees, the displacement is essentially a
  [dx, dy], applied in typical x,y coordinates with the origin on the bottom
  left.

  Usage:
  - When the sensor moves, call movementCompute.
  - When the sensor senses something, call sensoryCompute.

  The "anchor input" is typically a feature-location pair SDR.

  To specify how points are tracked, pass anchoringMethod = "corners",
  "narrowing" or "discrete".  "discrete" will cause the network to operate in a
  fully discrete space, where uncertainty is impossible as long as movements are
  integers.  "narrowing" is designed to narrow down uncertainty of initial
  locations of sensory stimuli.  "corners" is designed for noise-tolerance, and
  will activate all cells that are possible outcomes of path integration.
  """
    def __init__(self,
                 cellDimensions,
                 moduleMapDimensions,
                 orientation,
                 anchorInputSize,
                 cellCoordinateOffsets=(0.5, ),
                 activationThreshold=10,
                 initialPermanence=0.21,
                 connectedPermanence=0.50,
                 learningThreshold=10,
                 sampleSize=20,
                 permanenceIncrement=0.1,
                 permanenceDecrement=0.0,
                 maxSynapsesPerSegment=-1,
                 anchoringMethod="narrowing",
                 rotationMatrix=None,
                 seed=42):
        """
    @param cellDimensions (tuple(int, int))
    Determines the number of cells. Determines how space is divided between the
    cells.

    @param moduleMapDimensions (tuple(float, float))
    Determines the amount of world space covered by all of the cells combined.
    In grid cell terminology, this is equivalent to the "scale" of a module.
    A module with a scale of "5cm" would have moduleMapDimensions=(5.0, 5.0).

    @param orientation (float)
    The rotation of this map, measured in radians.

    @param anchorInputSize (int)
    The number of input bits in the anchor input.

    @param cellCoordinateOffsets (list of floats)
    These must each be between 0.0 and 1.0. Every time a cell is activated by
    anchor input, this class adds a "phase" which is shifted in subsequent
    motions. By default, this phase is placed at the center of the cell. This
    parameter allows you to control where the point is placed and whether multiple
    are placed. For example, with value [0.2, 0.8], when cell [2, 3] is activated
    it will place 4 phases, corresponding to the following points in cell
    coordinates: [2.2, 3.2], [2.2, 3.8], [2.8, 3.2], [2.8, 3.8]

    @param activationThreshold (int)
    Number of active synapses on an anchor segment required for the segment to
    activate its cell during inference.

    @param initialPermanence (float)
    Initial permanence of newly grown synapses.

    @param connectedPermanence (float)
    Permanence threshold at which a synapse counts as connected when computing
    segment activity.

    @param learningThreshold (int)
    Minimum potential overlap for a segment to be considered "matching" during
    learning.

    @param sampleSize (int)
    Desired number of active synapses per segment. -1 means "grow synapses to
    every active input bit".

    @param permanenceIncrement (float)
    Amount by which permanences of active synapses are incremented during
    learning.

    @param permanenceDecrement (float)
    Amount by which permanences of inactive synapses are decremented during
    learning.

    @param maxSynapsesPerSegment (int)
    Cap on the number of synapses per segment. -1 means no cap.

    @param anchoringMethod (string)
    "corners", "narrowing", or "discrete". See the class docstring.

    @param rotationMatrix (numpy array or None)
    If provided, this matrix is used directly instead of deriving one from
    `orientation`.

    @param seed (int)
    Seed for the random number generator used during learning.
    """

        self.cellDimensions = np.asarray(cellDimensions, dtype="int")
        self.moduleMapDimensions = np.asarray(moduleMapDimensions,
                                              dtype="float")
        # Conversion factor from world-space distance to phase (fraction of a
        # full tile of this module).
        self.phasesPerUnitDistance = 1.0 / self.moduleMapDimensions

        if rotationMatrix is None:
            self.orientation = orientation
            self.rotationMatrix = np.array(
                [[math.cos(orientation), -math.sin(orientation)],
                 [math.sin(orientation),
                  math.cos(orientation)]])
            if anchoringMethod == "discrete":
                # Need to convert matrix to have integer values
                nonzeros = self.rotationMatrix[np.where(
                    np.abs(self.rotationMatrix) > 0)]
                smallestValue = np.amin(nonzeros)
                self.rotationMatrix /= smallestValue
                self.rotationMatrix = np.ceil(self.rotationMatrix)
        else:
            self.rotationMatrix = rotationMatrix

        self.cellCoordinateOffsets = cellCoordinateOffsets

        # Phase is measured as a number in the range [0.0, 1.0)
        self.activePhases = np.empty((0, 2), dtype="float")
        self.cellsForActivePhases = np.empty(0, dtype="int")
        self.phaseDisplacement = np.empty((0, 2), dtype="float")

        self.activeCells = np.empty(0, dtype="int")
        self.activeSegments = np.empty(0, dtype="uint32")

        # Anchor connections: cells learn segments on the anchor input so
        # sensory cues can reactivate this location later.
        self.connections = SparseMatrixConnections(np.prod(cellDimensions),
                                                   anchorInputSize)

        self.initialPermanence = initialPermanence
        self.connectedPermanence = connectedPermanence
        self.learningThreshold = learningThreshold
        self.sampleSize = sampleSize
        self.permanenceIncrement = permanenceIncrement
        self.permanenceDecrement = permanenceDecrement
        self.activationThreshold = activationThreshold
        self.maxSynapsesPerSegment = maxSynapsesPerSegment

        self.anchoringMethod = anchoringMethod

        self.rng = Random(seed)

    def reset(self):
        """
    Clear the active cells.
    """
        # NOTE(review): self.activeSegments is not cleared here — confirm
        # whether that is intentional.
        self.activePhases = np.empty((0, 2), dtype="float")
        self.phaseDisplacement = np.empty((0, 2), dtype="float")
        self.cellsForActivePhases = np.empty(0, dtype="int")
        self.activeCells = np.empty(0, dtype="int")

    def _computeActiveCells(self):
        """
    Recompute self.cellsForActivePhases and self.activeCells from
    self.activePhases.
    """
        # Round each coordinate to the nearest cell.
        activeCellCoordinates = np.floor(self.activePhases *
                                         self.cellDimensions).astype("int")

        # Convert coordinates to cell numbers.
        self.cellsForActivePhases = (np.ravel_multi_index(
            activeCellCoordinates.T, self.cellDimensions))
        self.activeCells = np.unique(self.cellsForActivePhases)

    def activateRandomLocation(self):
        """
    Set the location to a random point.
    """
        # NOTE(review): this uses the global NumPy RNG rather than self.rng,
        # so this step is not controlled by the `seed` parameter — confirm.
        self.activePhases = np.array([np.random.random(2)])
        if self.anchoringMethod == "discrete":
            # Need to place the phase in the middle of a cell
            self.activePhases = np.floor(
                self.activePhases * self.cellDimensions) / self.cellDimensions
        self._computeActiveCells()

    def movementCompute(self, displacement, noiseFactor=0):
        """
    Shift the current active cells by a vector.

    @param displacement (pair of floats)
    A translation vector [di, dj].

    @param noiseFactor (float)
    If nonzero, Gaussian noise with this standard deviation is added to each
    component of the displacement.
    """

        if noiseFactor != 0:
            # Copy before mutating so the caller's vector is untouched.
            displacement = copy.deepcopy(displacement)
            xnoise = np.random.normal(0, noiseFactor)
            ynoise = np.random.normal(0, noiseFactor)
            displacement[0] += xnoise
            displacement[1] += ynoise

        # Calculate delta in the module's coordinates.
        phaseDisplacement = (np.matmul(self.rotationMatrix, displacement) *
                             self.phasesPerUnitDistance)

        # Shift the active coordinates.
        np.add(self.activePhases, phaseDisplacement, out=self.activePhases)

        # In Python, (x % 1.0) can return 1.0 because of floating point goofiness.
        # Generally this doesn't cause problems, it's just confusing when you're
        # debugging.
        np.round(self.activePhases, decimals=9, out=self.activePhases)
        np.mod(self.activePhases, 1.0, out=self.activePhases)

        self._computeActiveCells()
        self.phaseDisplacement = phaseDisplacement

    def _sensoryComputeInferenceMode(self, anchorInput):
        """
    Infer the location from sensory input. Activate any cells with enough active
    synapses to this sensory input. Deactivate all other cells.

    @param anchorInput (numpy array)
    A sensory input. This will often come from a feature-location pair layer.
    """
        if len(anchorInput) == 0:
            return

        overlaps = self.connections.computeActivity(anchorInput,
                                                    self.connectedPermanence)
        activeSegments = np.where(overlaps >= self.activationThreshold)[0]

        sensorySupportedCells = np.unique(
            self.connections.mapSegmentsToCells(activeSegments))

        # Remove the phases of cells that lost their sensory support.
        inactivated = np.setdiff1d(self.activeCells, sensorySupportedCells)
        inactivatedIndices = np.in1d(self.cellsForActivePhases,
                                     inactivated).nonzero()[0]
        if inactivatedIndices.size > 0:
            self.activePhases = np.delete(self.activePhases,
                                          inactivatedIndices,
                                          axis=0)

        activated = np.setdiff1d(sensorySupportedCells, self.activeCells)

        # Find centers of point clouds
        if "corners" in self.anchoringMethod:
            # "corners": rebuild phases from every sensory-supported cell.
            activatedCoordsBase = np.transpose(
                np.unravel_index(sensorySupportedCells,
                                 self.cellDimensions)).astype('float')
        else:
            # Otherwise only add phases for newly activated cells.
            activatedCoordsBase = np.transpose(
                np.unravel_index(activated,
                                 self.cellDimensions)).astype('float')

        # Generate points to add
        activatedCoords = np.concatenate([
            activatedCoordsBase + [iOffset, jOffset]
            for iOffset in self.cellCoordinateOffsets
            for jOffset in self.cellCoordinateOffsets
        ])
        if "corners" in self.anchoringMethod:
            self.activePhases = activatedCoords / self.cellDimensions

        else:
            if activatedCoords.size > 0:
                self.activePhases = np.append(self.activePhases,
                                              activatedCoords /
                                              self.cellDimensions,
                                              axis=0)

        self._computeActiveCells()
        self.activeSegments = activeSegments

    def _sensoryComputeLearningMode(self, anchorInput):
        """
    Associate this location with a sensory input. Subsequently, anchorInput will
    activate the current location during anchor().

    @param anchorInput (numpy array)
    A sensory input. This will often come from a feature-location pair layer.
    """
        overlaps = self.connections.computeActivity(anchorInput,
                                                    self.connectedPermanence)
        activeSegments = np.where(overlaps >= self.activationThreshold)[0]

        potentialOverlaps = self.connections.computeActivity(anchorInput)
        matchingSegments = np.where(
            potentialOverlaps >= self.learningThreshold)[0]

        # Cells with an active segment: reinforce the segment
        cellsForActiveSegments = self.connections.mapSegmentsToCells(
            activeSegments)
        learningActiveSegments = activeSegments[np.in1d(
            cellsForActiveSegments, self.activeCells)]
        remainingCells = np.setdiff1d(self.activeCells, cellsForActiveSegments)

        # Remaining cells with a matching segment: reinforce the best
        # matching segment.
        candidateSegments = self.connections.filterSegmentsByCell(
            matchingSegments, remainingCells)
        cellsForCandidateSegments = (
            self.connections.mapSegmentsToCells(candidateSegments))
        candidateSegments = candidateSegments[np.in1d(
            cellsForCandidateSegments, remainingCells)]
        # Keep at most one segment per cell: the one with the highest
        # potential overlap.
        onePerCellFilter = np2.argmaxMulti(
            potentialOverlaps[candidateSegments], cellsForCandidateSegments)
        learningMatchingSegments = candidateSegments[onePerCellFilter]

        newSegmentCells = np.setdiff1d(remainingCells,
                                       cellsForCandidateSegments)

        for learningSegments in (learningActiveSegments,
                                 learningMatchingSegments):
            self._learn(self.connections, self.rng, learningSegments,
                        anchorInput, potentialOverlaps, self.initialPermanence,
                        self.sampleSize, self.permanenceIncrement,
                        self.permanenceDecrement, self.maxSynapsesPerSegment)

        # Remaining cells without a matching segment: grow one.
        numNewSynapses = len(anchorInput)

        if self.sampleSize != -1:
            numNewSynapses = min(numNewSynapses, self.sampleSize)

        if self.maxSynapsesPerSegment != -1:
            numNewSynapses = min(numNewSynapses, self.maxSynapsesPerSegment)

        newSegments = self.connections.createSegments(newSegmentCells)

        self.connections.growSynapsesToSample(newSegments, anchorInput,
                                              numNewSynapses,
                                              self.initialPermanence, self.rng)
        self.activeSegments = activeSegments

    def sensoryCompute(self, anchorInput, anchorGrowthCandidates, learn):
        """
    Process a sensory input.

    @param anchorInput (numpy array)
    Active bits of the sensory input, used during inference.

    @param anchorGrowthCandidates (numpy array)
    Active bits to learn on, used during learning.

    @param learn (bool)
    If True, associate the current location with anchorGrowthCandidates;
    otherwise infer the location from anchorInput.
    """
        if learn:
            self._sensoryComputeLearningMode(anchorGrowthCandidates)
        else:
            self._sensoryComputeInferenceMode(anchorInput)

    @staticmethod
    def _learn(connections, rng, learningSegments, activeInput,
               potentialOverlaps, initialPermanence, sampleSize,
               permanenceIncrement, permanenceDecrement,
               maxSynapsesPerSegment):
        """
    Adjust synapse permanences, grow new synapses, and grow new segments.

    @param connections (SparseMatrixConnections)
    @param rng (Random)
    @param learningSegments (numpy array)
    @param activeInput (numpy array)
    @param potentialOverlaps (numpy array)
    """
        # Learn on existing segments
        connections.adjustSynapses(learningSegments, activeInput,
                                   permanenceIncrement, -permanenceDecrement)

        # Grow new synapses. Calculate "maxNew", the maximum number of synapses to
        # grow per segment. "maxNew" might be a number or it might be a list of
        # numbers.
        if sampleSize == -1:
            maxNew = len(activeInput)
        else:
            maxNew = sampleSize - potentialOverlaps[learningSegments]

        if maxSynapsesPerSegment != -1:
            # Respect the per-segment synapse cap.
            synapseCounts = connections.mapSegmentsToSynapseCounts(
                learningSegments)
            numSynapsesToReachMax = maxSynapsesPerSegment - synapseCounts
            maxNew = np.where(maxNew <= numSynapsesToReachMax, maxNew,
                              numSynapsesToReachMax)

        connections.growSynapsesToSample(learningSegments, activeInput, maxNew,
                                         initialPermanence, rng)

    def getActiveCells(self):
        """Return the currently active cells."""
        return self.activeCells

    def numberOfCells(self):
        """Return the total number of cells in this module."""
        return np.prod(self.cellDimensions)
class SingleLayerLocationMemory(object):
  """
  A layer of cells which learns how to take a "delta location" (e.g. a motor
  command or a proprioceptive delta) and update its active cells to represent
  the new location.

  Its active cells might represent a union of locations.
  As the location changes, the featureLocationInput causes this union to narrow
  down until the location is inferred.

  This layer receives absolute proprioceptive info as proximal input.
  For now, we assume that there's a one-to-one mapping between absolute
  proprioceptive input and the location SDR. So rather than modeling
  proximal synapses, we'll just relay the proprioceptive SDR. In the future
  we might want to consider a many-to-one mapping of proprioceptive inputs
  to location SDRs.

  After this layer is trained, it no longer needs the proprioceptive input.
  The delta location will drive the layer. The current active cells and the
  other distal connections will work together with this delta location to
  activate a new set of cells.

  When no cells are active, activate a large union of possible locations.
  With subsequent inputs, the union will narrow down to a single location SDR.
  """

  def __init__(self,
               cellCount,
               deltaLocationInputSize,
               featureLocationInputSize,
               activationThreshold=13,
               initialPermanence=0.21,
               connectedPermanence=0.50,
               learningThreshold=10,
               sampleSize=20,
               permanenceIncrement=0.1,
               permanenceDecrement=0.1,
               maxSynapsesPerSegment=-1,
               seed=42):
    """
    @param cellCount (int)
    Number of cells in this layer.

    @param deltaLocationInputSize (int)
    Number of bits in the delta location input.

    @param featureLocationInputSize (int)
    Number of bits in the feature-location input.

    @param activationThreshold (int)
    Number of active connected synapses required for a segment to activate.

    @param initialPermanence (float)
    Initial permanence of newly grown synapses.

    @param connectedPermanence (float)
    Permanence threshold at which a synapse counts as connected.

    @param learningThreshold (int)
    Minimum potential overlap for a segment to be considered "matching"
    during learning.

    @param sampleSize (int)
    Desired number of active synapses per segment; -1 means "grow synapses
    to every active input bit".

    @param maxSynapsesPerSegment (int)
    Cap on synapses per segment; -1 means no cap.

    @param seed (int)
    Seed for the random number generator used during learning.
    """

    # For transition learning, every segment is split into two parts.
    # For the segment to be active, both parts must be active.
    self.internalConnections = SparseMatrixConnections(
      cellCount, cellCount)
    self.deltaConnections = SparseMatrixConnections(
      cellCount, deltaLocationInputSize)

    # Distal segments that receive input from the layer that represents
    # feature-locations.
    self.featureLocationConnections = SparseMatrixConnections(
      cellCount, featureLocationInputSize)

    self.activeCells = np.empty(0, dtype="uint32")
    self.activeDeltaSegments = np.empty(0, dtype="uint32")
    self.activeFeatureLocationSegments = np.empty(0, dtype="uint32")

    self.initialPermanence = initialPermanence
    self.connectedPermanence = connectedPermanence
    self.learningThreshold = learningThreshold
    self.sampleSize = sampleSize
    self.permanenceIncrement = permanenceIncrement
    self.permanenceDecrement = permanenceDecrement
    self.activationThreshold = activationThreshold
    self.maxSynapsesPerSegment = maxSynapsesPerSegment

    self.rng = Random(seed)


  def reset(self):
    """
    Deactivate all cells.
    """

    self.activeCells = np.empty(0, dtype="uint32")
    self.activeDeltaSegments = np.empty(0, dtype="uint32")
    self.activeFeatureLocationSegments = np.empty(0, dtype="uint32")


  def compute(self, deltaLocation=(), newLocation=(),
              featureLocationInput=(), featureLocationGrowthCandidates=(),
              learn=True):
    """
    Run one time step of the Location Memory algorithm.

    @param deltaLocation (sorted numpy array)
    @param newLocation (sorted numpy array)
    @param featureLocationInput (sorted numpy array)
    @param featureLocationGrowthCandidates (sorted numpy array)
    @param learn (bool)
    If True and newLocation is provided, learn the transition and the
    feature-location pair.
    """
    prevActiveCells = self.activeCells

    # A delta segment is active only when both of its parts — the previous
    # location part and the delta part — are active.
    self.activeDeltaSegments = np.where(
      (self.internalConnections.computeActivity(
        prevActiveCells, self.connectedPermanence
      ) >= self.activationThreshold)
      &
      (self.deltaConnections.computeActivity(
        deltaLocation, self.connectedPermanence
      ) >= self.activationThreshold))[0]

    # When we're moving, the feature-location input has no effect.
    if len(deltaLocation) == 0:
      self.activeFeatureLocationSegments = np.where(
        self.featureLocationConnections.computeActivity(
          featureLocationInput, self.connectedPermanence
        ) >= self.activationThreshold)[0]
    else:
      self.activeFeatureLocationSegments = np.empty(0, dtype="uint32")


    if len(newLocation) > 0:
      # Drive activations by relaying this location SDR.
      self.activeCells = newLocation

      if learn:
        # Learn the delta.
        self._learnTransition(prevActiveCells, deltaLocation, newLocation)

        # Learn the featureLocationInput.
        self._learnFeatureLocationPair(newLocation, featureLocationInput,
                                       featureLocationGrowthCandidates)


    elif len(prevActiveCells) > 0:
      if len(deltaLocation) > 0:
        # Drive activations by applying the deltaLocation to the current location.
        # Completely ignore the featureLocationInput. It's outdated, associated
        # with the previous location.

        cellsForDeltaSegments = self.internalConnections.mapSegmentsToCells(
          self.activeDeltaSegments)

        self.activeCells = np.unique(cellsForDeltaSegments)
      else:
        # Keep previous active cells active.
        # Modulate with the featureLocationInput.

        if len(self.activeFeatureLocationSegments) > 0:

          cellsForFeatureLocationSegments = (
            self.featureLocationConnections.mapSegmentsToCells(
              self.activeFeatureLocationSegments))
          self.activeCells = np.intersect1d(prevActiveCells,
                                            cellsForFeatureLocationSegments)
        else:
          self.activeCells = prevActiveCells

    elif len(featureLocationInput) > 0:
      # Drive activations with the featureLocationInput.

      cellsForFeatureLocationSegments = (
        self.featureLocationConnections.mapSegmentsToCells(
          self.activeFeatureLocationSegments))

      self.activeCells = np.unique(cellsForFeatureLocationSegments)


  def _learnTransition(self, prevActiveCells, deltaLocation, newLocation):
    """
    For each cell in the newLocation SDR, learn the transition of prevLocation
    (i.e. prevActiveCells) + deltaLocation.

    The transition might be already known. In that case, just reinforce the
    existing segments.
    """

    prevLocationPotentialOverlaps = self.internalConnections.computeActivity(
      prevActiveCells)
    deltaPotentialOverlaps = self.deltaConnections.computeActivity(
      deltaLocation)

    matchingDeltaSegments = np.where(
      (prevLocationPotentialOverlaps >= self.learningThreshold) &
      (deltaPotentialOverlaps >= self.learningThreshold))[0]

    # Cells with an active segment pair: reinforce the segment
    cellsForActiveSegments = self.internalConnections.mapSegmentsToCells(
      self.activeDeltaSegments)
    learningActiveDeltaSegments = self.activeDeltaSegments[
      np.in1d(cellsForActiveSegments, newLocation)]
    remainingCells = np.setdiff1d(newLocation, cellsForActiveSegments)

    # Remaining cells with a matching segment pair: reinforce the best matching
    # segment pair.
    candidateSegments = self.internalConnections.filterSegmentsByCell(
      matchingDeltaSegments, remainingCells)
    cellsForCandidateSegments = self.internalConnections.mapSegmentsToCells(
      candidateSegments)
    # Bug fix: index candidateSegments, not matchingDeltaSegments — the
    # boolean mask is aligned with candidateSegments. (Matches the parallel
    # logic in _learnFeatureLocationPair.)
    candidateSegments = candidateSegments[
      np.in1d(cellsForCandidateSegments, remainingCells)]
    onePerCellFilter = np2.argmaxMulti(
      prevLocationPotentialOverlaps[candidateSegments] +
      deltaPotentialOverlaps[candidateSegments],
      cellsForCandidateSegments)
    learningMatchingDeltaSegments = candidateSegments[onePerCellFilter]

    newDeltaSegmentCells = np.setdiff1d(remainingCells, cellsForCandidateSegments)

    for learningSegments in (learningActiveDeltaSegments,
                             learningMatchingDeltaSegments):
      self._learn(self.internalConnections, self.rng, learningSegments,
                  prevActiveCells, prevActiveCells,
                  prevLocationPotentialOverlaps,
                  self.initialPermanence, self.sampleSize,
                  self.permanenceIncrement, self.permanenceDecrement,
                  self.maxSynapsesPerSegment)
      self._learn(self.deltaConnections, self.rng, learningSegments,
                  deltaLocation, deltaLocation, deltaPotentialOverlaps,
                  self.initialPermanence, self.sampleSize,
                  self.permanenceIncrement, self.permanenceDecrement,
                  self.maxSynapsesPerSegment)

    numNewLocationSynapses = len(prevActiveCells)
    numNewDeltaSynapses = len(deltaLocation)

    if self.sampleSize != -1:
      numNewLocationSynapses = min(numNewLocationSynapses, self.sampleSize)
      numNewDeltaSynapses = min(numNewDeltaSynapses, self.sampleSize)

    if self.maxSynapsesPerSegment != -1:
      numNewLocationSynapses = min(numNewLocationSynapses,
                                   self.maxSynapsesPerSegment)
      # Bug fix: cap the delta count, not the location count (copy-paste
      # error used numNewLocationSynapses here).
      numNewDeltaSynapses = min(numNewDeltaSynapses,
                                self.maxSynapsesPerSegment)

    newPrevLocationSegments = self.internalConnections.createSegments(
      newDeltaSegmentCells)
    newDeltaSegments = self.deltaConnections.createSegments(
      newDeltaSegmentCells)

    # Both connection structures must assign the same segment numbers so a
    # "segment pair" can be identified by a single index.
    assert np.array_equal(newPrevLocationSegments, newDeltaSegments)

    self.internalConnections.growSynapsesToSample(
      newPrevLocationSegments, prevActiveCells, numNewLocationSynapses,
      self.initialPermanence, self.rng)
    self.deltaConnections.growSynapsesToSample(
      newDeltaSegments, deltaLocation, numNewDeltaSynapses,
      self.initialPermanence, self.rng)


  def _learnFeatureLocationPair(self, newLocation, featureLocationInput,
                                featureLocationGrowthCandidates):
    """
    Grow / reinforce synapses between the location layer's dendrites and the
    input layer's active cells.
    """

    potentialOverlaps = self.featureLocationConnections.computeActivity(
      featureLocationInput)
    # Consistency fix: use >= like every other matching-segment computation
    # in this file (previously ">", which off-by-one excluded segments with
    # exactly learningThreshold overlap).
    matchingSegments = np.where(
      potentialOverlaps >= self.learningThreshold)[0]

    # Cells with an active segment: reinforce the segment
    cellsForActiveSegments = self.featureLocationConnections.mapSegmentsToCells(
      self.activeFeatureLocationSegments)
    learningActiveSegments = self.activeFeatureLocationSegments[
      np.in1d(cellsForActiveSegments, newLocation)]
    remainingCells = np.setdiff1d(newLocation, cellsForActiveSegments)

    # Remaining cells with a matching segment: reinforce the best matching
    # segment.
    candidateSegments = self.featureLocationConnections.filterSegmentsByCell(
      matchingSegments, remainingCells)
    cellsForCandidateSegments = (
      self.featureLocationConnections.mapSegmentsToCells(
        candidateSegments))
    candidateSegments = candidateSegments[
      np.in1d(cellsForCandidateSegments, remainingCells)]
    onePerCellFilter = np2.argmaxMulti(potentialOverlaps[candidateSegments],
                                       cellsForCandidateSegments)
    learningMatchingSegments = candidateSegments[onePerCellFilter]

    newSegmentCells = np.setdiff1d(remainingCells, cellsForCandidateSegments)

    for learningSegments in (learningActiveSegments,
                             learningMatchingSegments):
      self._learn(self.featureLocationConnections, self.rng, learningSegments,
                  featureLocationInput, featureLocationGrowthCandidates,
                  potentialOverlaps,
                  self.initialPermanence, self.sampleSize,
                  self.permanenceIncrement, self.permanenceDecrement,
                  self.maxSynapsesPerSegment)

    numNewSynapses = len(featureLocationInput)

    if self.sampleSize != -1:
      numNewSynapses = min(numNewSynapses, self.sampleSize)

    if self.maxSynapsesPerSegment != -1:
      numNewSynapses = min(numNewSynapses, self.maxSynapsesPerSegment)

    newSegments = self.featureLocationConnections.createSegments(
      newSegmentCells)

    self.featureLocationConnections.growSynapsesToSample(
      newSegments, featureLocationGrowthCandidates, numNewSynapses,
      self.initialPermanence, self.rng)



  @staticmethod
  def _learn(connections, rng, learningSegments, activeInput, growthCandidates,
             potentialOverlaps, initialPermanence, sampleSize,
             permanenceIncrement, permanenceDecrement, maxSynapsesPerSegment):
    """
    Adjust synapse permanences, grow new synapses, and grow new segments.

    @param connections (SparseMatrixConnections)
    @param rng (Random)
    @param learningSegments (numpy array)
    @param activeInput (numpy array)
    @param growthCandidates (numpy array)
    @param potentialOverlaps (numpy array)
    """

    # Learn on existing segments
    connections.adjustSynapses(learningSegments, activeInput,
                               permanenceIncrement, -permanenceDecrement)

    # Grow new synapses. Calculate "maxNew", the maximum number of synapses to
    # grow per segment. "maxNew" might be a number or it might be a list of
    # numbers.
    if sampleSize == -1:
      maxNew = len(growthCandidates)
    else:
      maxNew = sampleSize - potentialOverlaps[learningSegments]

    if maxSynapsesPerSegment != -1:
      # Respect the per-segment synapse cap.
      synapseCounts = connections.mapSegmentsToSynapseCounts(
        learningSegments)
      numSynapsesToReachMax = maxSynapsesPerSegment - synapseCounts
      maxNew = np.where(maxNew <= numSynapsesToReachMax,
                        maxNew, numSynapsesToReachMax)

    connections.growSynapsesToSample(learningSegments, growthCandidates,
                                     maxNew, initialPermanence, rng)


  def getActiveCells(self):
    """Return the currently active cells."""
    return self.activeCells
class TemporalMemory(object):
  """
  TemporalMemory with basal and apical connections, and with the ability to
  connect to external cells.

  Basal connections are used to implement traditional Temporal Memory.

  The apical connections are used for further disambiguation. If multiple cells
  in a minicolumn have active basal segments, each of those cells is predicted,
  unless one of them also has an active apical segment, in which case only the
  cells with active basal and apical segments are predicted.

  In other words, the apical connections have no effect unless the basal input
  is a union of SDRs (e.g. from bursting minicolumns).

  This TemporalMemory is unaware of whether its basalInput or apicalInput are
  from internal or external cells. They are just cell numbers. The caller knows
  what these cell numbers mean, but the TemporalMemory doesn't. This allows the
  same code to work for various algorithms.

  To implement sequence memory, use

    basalInputDimensions=(numColumns*cellsPerColumn,)

  and call compute like this:

    tm.compute(activeColumns, tm.getActiveCells(), tm.getWinnerCells())

  NOTE: Cells are numbered column-major, so "cell / cellsPerColumn" (floor
  division on Python 2 / numpy integer arrays) maps a cell to its column.
  This idiom is used throughout the class.
  """

  def __init__(self,
               columnDimensions=(2048,),
               basalInputDimensions=(),
               apicalInputDimensions=(),
               cellsPerColumn=32,
               activationThreshold=13,
               initialPermanence=0.21,
               connectedPermanence=0.50,
               minThreshold=10,
               sampleSize=20,
               permanenceIncrement=0.1,
               permanenceDecrement=0.1,
               predictedSegmentDecrement=0.0,
               maxNewSynapseCount=None,
               maxSynapsesPerSegment=-1,
               maxSegmentsPerCell=None,
               seed=42):
    """
    @param columnDimensions (sequence of ints)
    @param basalInputDimensions (sequence of ints)
    @param apicalInputDimensions (sequence of ints)
    @param cellsPerColumn (int)
    @param activationThreshold (int)
      Number of active connected synapses that makes a segment "active"
    @param initialPermanence (float)
    @param connectedPermanence (float)
    @param minThreshold (int)
      Number of active potential synapses that makes a segment "matching"
    @param sampleSize (int)
      How much of the active SDR to sample with synapses. -1 means "no limit".
    @param permanenceIncrement (float)
    @param permanenceDecrement (float)
    @param predictedSegmentDecrement (float)
      Punishment for segments that predicted a column that didn't become
      active. 0.0 disables punishment.
    @param maxNewSynapseCount (int)
      Deprecated alias for sampleSize
    @param maxSynapsesPerSegment (int)
      -1 means "no limit"
    @param maxSegmentsPerCell
      Ignored; accepted only for interface compatibility
    @param seed (int)
      Seed for the random number generator
    """

    self.columnDimensions = columnDimensions
    self.numColumns = self._numPoints(columnDimensions)
    self.basalInputDimensions = basalInputDimensions
    self.apicalInputDimensions = apicalInputDimensions

    self.cellsPerColumn = cellsPerColumn
    self.initialPermanence = initialPermanence
    self.connectedPermanence = connectedPermanence
    self.minThreshold = minThreshold

    self.sampleSize = sampleSize
    if maxNewSynapseCount is not None:
      # Python 2 print statement -- this module predates print_function.
      print "Parameter 'maxNewSynapseCount' is deprecated. Use 'sampleSize'."
      self.sampleSize = maxNewSynapseCount

    if maxSegmentsPerCell is not None:
      print "Warning: ignoring parameter 'maxSegmentsPerCell'"

    self.permanenceIncrement = permanenceIncrement
    self.permanenceDecrement = permanenceDecrement
    self.predictedSegmentDecrement = predictedSegmentDecrement
    self.activationThreshold = activationThreshold
    self.maxSynapsesPerSegment = maxSynapsesPerSegment

    # Cells and input bits are flat indices into these connection matrices.
    self.basalConnections = SparseMatrixConnections(
      self.numColumns*cellsPerColumn, self._numPoints(basalInputDimensions))
    self.apicalConnections = SparseMatrixConnections(
      self.numColumns*cellsPerColumn, self._numPoints(apicalInputDimensions))
    self.rng = Random(seed)
    self.activeCells = EMPTY_UINT_ARRAY
    self.winnerCells = EMPTY_UINT_ARRAY
    self.prevPredictedCells = EMPTY_UINT_ARRAY


  def reset(self):
    """
    Clear all cell activity. Learned synapses are left intact.
    """
    self.activeCells = EMPTY_UINT_ARRAY
    self.winnerCells = EMPTY_UINT_ARRAY
    self.prevPredictedCells = EMPTY_UINT_ARRAY


  def compute(self,
              activeColumns,
              basalInput,
              basalGrowthCandidates,
              apicalInput=EMPTY_UINT_ARRAY,
              apicalGrowthCandidates=EMPTY_UINT_ARRAY,
              learn=True):
    """
    Perform one timestep: compute predictions from the basal and apical input,
    activate cells in the active columns, then (if 'learn') adjust, grow, and
    punish synapses.

    @param activeColumns (numpy array)
    @param basalInput (numpy array)
    @param basalGrowthCandidates (numpy array)
      Bits that new basal synapses are allowed to grow to
    @param apicalInput (numpy array)
    @param apicalGrowthCandidates (numpy array)
      Bits that new apical synapses are allowed to grow to
    @param learn (bool)
    """
    # Calculate predictions for this timestep
    (activeBasalSegments,
     matchingBasalSegments,
     basalPotentialOverlaps) = self._calculateSegmentActivity(
       self.basalConnections, basalInput, self.connectedPermanence,
       self.activationThreshold, self.minThreshold)

    (activeApicalSegments,
     matchingApicalSegments,
     apicalPotentialOverlaps) = self._calculateSegmentActivity(
       self.apicalConnections, apicalInput, self.connectedPermanence,
       self.activationThreshold, self.minThreshold)

    predictedCells = self._calculatePredictedCells(activeBasalSegments,
                                                   activeApicalSegments)

    # Calculate active cells: correctly predicted cells activate, and every
    # cell in each unpredicted active column bursts.
    (correctPredictedCells,
     burstingColumns) = np2.setCompare(predictedCells, activeColumns,
                                       predictedCells / self.cellsPerColumn,
                                       rightMinusLeft=True)
    newActiveCells = np.concatenate((correctPredictedCells,
                                     np2.getAllCellsInColumns(
                                       burstingColumns, self.cellsPerColumn)))

    # Calculate learning
    (learningActiveBasalSegments,
     learningMatchingBasalSegments,
     basalSegmentsToPunish,
     newBasalSegmentCells,
     learningCells) = self._calculateBasalLearning(
       activeColumns, burstingColumns, correctPredictedCells,
       activeBasalSegments, matchingBasalSegments, basalPotentialOverlaps)

    (learningActiveApicalSegments,
     learningMatchingApicalSegments,
     apicalSegmentsToPunish,
     newApicalSegmentCells) = self._calculateApicalLearning(
       learningCells, activeColumns, activeApicalSegments,
       matchingApicalSegments, apicalPotentialOverlaps)

    # Learn
    if learn:
      # Learn on existing segments
      for learningSegments in (learningActiveBasalSegments,
                               learningMatchingBasalSegments):
        self._learn(self.basalConnections, self.rng, learningSegments,
                    basalInput, basalGrowthCandidates, basalPotentialOverlaps,
                    self.initialPermanence, self.sampleSize,
                    self.permanenceIncrement, self.permanenceDecrement,
                    self.maxSynapsesPerSegment)

      for learningSegments in (learningActiveApicalSegments,
                               learningMatchingApicalSegments):

        self._learn(self.apicalConnections, self.rng, learningSegments,
                    apicalInput, apicalGrowthCandidates,
                    apicalPotentialOverlaps, self.initialPermanence,
                    self.sampleSize, self.permanenceIncrement,
                    self.permanenceDecrement, self.maxSynapsesPerSegment)

      # Punish incorrect predictions
      if self.predictedSegmentDecrement != 0.0:
        self.basalConnections.adjustActiveSynapses(
          basalSegmentsToPunish, basalInput, -self.predictedSegmentDecrement)
        self.apicalConnections.adjustActiveSynapses(
          apicalSegmentsToPunish, apicalInput, -self.predictedSegmentDecrement)

      # Grow new segments
      if len(basalGrowthCandidates) > 0:
        self._learnOnNewSegments(self.basalConnections, self.rng,
                                 newBasalSegmentCells, basalGrowthCandidates,
                                 self.initialPermanence, self.sampleSize,
                                 self.maxSynapsesPerSegment)

      if len(apicalGrowthCandidates) > 0:
        self._learnOnNewSegments(self.apicalConnections, self.rng,
                                 newApicalSegmentCells, apicalGrowthCandidates,
                                 self.initialPermanence, self.sampleSize,
                                 self.maxSynapsesPerSegment)

    # Save the results
    self.activeCells = newActiveCells
    self.winnerCells = learningCells
    self.prevPredictedCells = predictedCells


  def _calculateBasalLearning(self,
                              activeColumns,
                              burstingColumns,
                              correctPredictedCells,
                              activeBasalSegments,
                              matchingBasalSegments,
                              basalPotentialOverlaps):
    """
    Basic Temporal Memory learning. Correctly predicted cells always have
    active basal segments, and we learn on these segments. In bursting
    columns, we either learn on an existing basal segment, or we grow a new one.

    The only influence apical dendrites have on basal learning is: the apical
    dendrites influence which cells are considered "predicted". So an active
    apical dendrite can keep some basal segments in active columns from
    learning.

    @param activeColumns (numpy array)
    @param burstingColumns (numpy array)
    @param correctPredictedCells (numpy array)
    @param activeBasalSegments (numpy array)
    @param matchingBasalSegments (numpy array)
    @param basalPotentialOverlaps (numpy array)

    @return (tuple)
    - learningActiveBasalSegments (numpy array)
      Active basal segments on correct predicted cells

    - learningMatchingBasalSegments (numpy array)
      Matching basal segments selected for learning in bursting columns

    - basalSegmentsToPunish (numpy array)
      Basal segments that should be punished for predicting an inactive column

    - newBasalSegmentCells (numpy array)
      Cells in bursting columns that were selected to grow new basal segments

    - learningCells (numpy array)
      Cells that have learning basal segments or are selected to grow a basal
      segment
    """

    # Correctly predicted columns
    learningActiveBasalSegments = self.basalConnections.filterSegmentsByCell(
      activeBasalSegments, correctPredictedCells)

    cellsForMatchingBasal = self.basalConnections.mapSegmentsToCells(
      matchingBasalSegments)
    matchingCells = np.unique(cellsForMatchingBasal)

    # Split bursting columns into those with at least one matching cell
    # (learn on the best matching segment) and those with none (grow a new
    # segment).
    (matchingCellsInBurstingColumns,
     burstingColumnsWithNoMatch) = np2.setCompare(
       matchingCells, burstingColumns, matchingCells / self.cellsPerColumn,
       rightMinusLeft=True)

    learningMatchingBasalSegments = self._chooseBestSegmentPerColumn(
      self.basalConnections, matchingCellsInBurstingColumns,
      matchingBasalSegments, basalPotentialOverlaps, self.cellsPerColumn)
    newBasalSegmentCells = self._getCellsWithFewestSegments(
      self.basalConnections, self.rng, burstingColumnsWithNoMatch,
      self.cellsPerColumn)

    learningCells = np.concatenate(
      (correctPredictedCells,
       self.basalConnections.mapSegmentsToCells(learningMatchingBasalSegments),
       newBasalSegmentCells))

    # Incorrectly predicted columns
    correctMatchingBasalMask = np.in1d(
      cellsForMatchingBasal / self.cellsPerColumn, activeColumns)

    basalSegmentsToPunish = matchingBasalSegments[~correctMatchingBasalMask]

    return (learningActiveBasalSegments,
            learningMatchingBasalSegments,
            basalSegmentsToPunish,
            newBasalSegmentCells,
            learningCells)


  def _calculateApicalLearning(self,
                               learningCells,
                               activeColumns,
                               activeApicalSegments,
                               matchingApicalSegments,
                               apicalPotentialOverlaps):
    """
    Calculate apical learning for each learning cell.

    The set of learning cells was determined completely from basal segments.
    Do all apical learning on the same cells.

    Learn on any active segments on learning cells. For cells without active
    segments, learn on the best matching segment. For cells without a matching
    segment, grow a new segment.

    @param learningCells (numpy array)
    @param activeColumns (numpy array)
    @param activeApicalSegments (numpy array)
    @param matchingApicalSegments (numpy array)
    @param apicalPotentialOverlaps (numpy array)

    @return (tuple)
    - learningActiveApicalSegments (numpy array)
      Active apical segments on correct predicted cells

    - learningMatchingApicalSegments (numpy array)
      Matching apical segments selected for learning in bursting columns

    - apicalSegmentsToPunish (numpy array)
      Apical segments that should be punished for predicting an inactive column

    - newApicalSegmentCells (numpy array)
      Cells in bursting columns that were selected to grow new apical segments
    """

    # Cells with active apical segments
    learningActiveApicalSegments = self.apicalConnections.filterSegmentsByCell(
      activeApicalSegments, learningCells)

    # Cells with matching apical segments
    learningCellsWithoutActiveApical = np.setdiff1d(
      learningCells,
      self.apicalConnections.mapSegmentsToCells(learningActiveApicalSegments))
    cellsForMatchingApical = self.apicalConnections.mapSegmentsToCells(
      matchingApicalSegments)
    learningCellsWithMatchingApical = np.intersect1d(
      learningCellsWithoutActiveApical, cellsForMatchingApical)
    learningMatchingApicalSegments = self._chooseBestSegmentPerCell(
      self.apicalConnections, learningCellsWithMatchingApical,
      matchingApicalSegments, apicalPotentialOverlaps)

    # Cells that need to grow an apical segment
    newApicalSegmentCells = np.setdiff1d(learningCellsWithoutActiveApical,
                                         learningCellsWithMatchingApical)

    # Incorrectly predicted columns
    correctMatchingApicalMask = np.in1d(
      cellsForMatchingApical / self.cellsPerColumn, activeColumns)

    apicalSegmentsToPunish = matchingApicalSegments[~correctMatchingApicalMask]

    return (learningActiveApicalSegments,
            learningMatchingApicalSegments,
            apicalSegmentsToPunish,
            newApicalSegmentCells)


  @staticmethod
  def _calculateSegmentActivity(connections, activeInput, connectedPermanence,
                                activationThreshold, minThreshold):
    """
    Calculate the active and matching segments for this timestep.

    @param connections (SparseMatrixConnections)
    @param activeInput (numpy array)
    @param connectedPermanence (float)
      Synapses at or above this permanence count toward "active" overlaps
    @param activationThreshold (int)
    @param minThreshold (int)

    @return (tuple)
    - activeSegments (numpy array)
      Dendrite segments with enough active connected synapses to cause a
      dendritic spike

    - matchingSegments (numpy array)
      Dendrite segments with enough active potential synapses to be selected for
      learning in a bursting column

    - potentialOverlaps (numpy array)
      The number of active potential synapses for each segment.
      Includes counts for active, matching, and nonmatching segments.
    """

    # Active
    overlaps = connections.computeActivity(activeInput, connectedPermanence)
    activeSegments = np.flatnonzero(overlaps >= activationThreshold)

    # Matching: computeActivity without a permanence counts all potential
    # synapses, connected or not.
    potentialOverlaps = connections.computeActivity(activeInput)
    matchingSegments = np.flatnonzero(potentialOverlaps >= minThreshold)

    return (activeSegments,
            matchingSegments,
            potentialOverlaps)


  def _calculatePredictedCells(self, activeBasalSegments, activeApicalSegments):
    """
    Calculate the predicted cells, given the set of active segments.

    An active basal segment is enough to predict a cell.
    An active apical segment is *not* enough to predict a cell.

    When a cell has both types of segments active, other cells in its minicolumn
    must also have both types of segments to be considered predictive.

    @param activeBasalSegments (numpy array)
    @param activeApicalSegments (numpy array)

    @return (numpy array)
    """

    cellsForBasalSegments = self.basalConnections.mapSegmentsToCells(
      activeBasalSegments)
    cellsForApicalSegments = self.apicalConnections.mapSegmentsToCells(
      activeApicalSegments)

    # Fully depolarized: basal AND apical support. Partly: basal only.
    fullyDepolarizedCells = np.intersect1d(cellsForBasalSegments,
                                           cellsForApicalSegments)
    partlyDepolarizedCells = np.setdiff1d(cellsForBasalSegments,
                                          fullyDepolarizedCells)

    # A basal-only cell is inhibited if its column contains a fully
    # depolarized cell.
    inhibitedMask = np.in1d(partlyDepolarizedCells / self.cellsPerColumn,
                            fullyDepolarizedCells / self.cellsPerColumn)
    predictedCells = np.append(fullyDepolarizedCells,
                               partlyDepolarizedCells[~inhibitedMask])

    return predictedCells


  @staticmethod
  def _learn(connections, rng, learningSegments, activeInput, growthCandidates,
             potentialOverlaps, initialPermanence, sampleSize,
             permanenceIncrement, permanenceDecrement, maxSynapsesPerSegment):
    """
    Adjust synapse permanences and grow new synapses on existing segments.

    @param connections (SparseMatrixConnections)
    @param rng (Random)
    @param learningSegments (numpy array)
    @param activeInput (numpy array)
    @param growthCandidates (numpy array)
    @param potentialOverlaps (numpy array)
      Number of active potential synapses per segment, indexed by segment
    @param initialPermanence (float)
    @param sampleSize (int) -1 means "no limit"
    @param permanenceIncrement (float)
    @param permanenceDecrement (float)
    @param maxSynapsesPerSegment (int) -1 means "no limit"
    """

    # Learn on existing segments
    connections.adjustSynapses(learningSegments, activeInput,
                               permanenceIncrement, -permanenceDecrement)

    # Grow new synapses. Calculate "maxNew", the maximum number of synapses to
    # grow per segment. "maxNew" might be a number or it might be a list of
    # numbers.
    if sampleSize == -1:
      maxNew = len(growthCandidates)
    else:
      maxNew = sampleSize - potentialOverlaps[learningSegments]

    if maxSynapsesPerSegment != -1:
      # Clip so that no segment exceeds maxSynapsesPerSegment total synapses.
      synapseCounts = connections.mapSegmentsToSynapseCounts(
        learningSegments)
      numSynapsesToReachMax = maxSynapsesPerSegment - synapseCounts
      maxNew = np.where(maxNew <= numSynapsesToReachMax,
                        maxNew, numSynapsesToReachMax)

    connections.growSynapsesToSample(learningSegments, growthCandidates,
                                     maxNew, initialPermanence, rng)


  @staticmethod
  def _learnOnNewSegments(connections, rng, newSegmentCells, growthCandidates,
                          initialPermanence, sampleSize, maxSynapsesPerSegment):
    """
    Grow a new segment on each given cell, and grow synapses on it, sampling
    from the growth candidates.

    @param connections (SparseMatrixConnections)
    @param rng (Random)
    @param newSegmentCells (numpy array)
    @param growthCandidates (numpy array)
    @param initialPermanence (float)
    @param sampleSize (int) -1 means "no limit"
    @param maxSynapsesPerSegment (int) -1 means "no limit"
    """

    numNewSynapses = len(growthCandidates)

    if sampleSize != -1:
      numNewSynapses = min(numNewSynapses, sampleSize)

    if maxSynapsesPerSegment != -1:
      numNewSynapses = min(numNewSynapses, maxSynapsesPerSegment)

    newSegments = connections.createSegments(newSegmentCells)
    connections.growSynapsesToSample(newSegments, growthCandidates,
                                     numNewSynapses, initialPermanence,
                                     rng)


  @classmethod
  def _chooseBestSegmentPerCell(cls,
                                connections,
                                cells,
                                allMatchingSegments,
                                potentialOverlaps):
    """
    For each specified cell, choose its matching segment with largest number
    of active potential synapses. When there's a tie, the first segment wins.

    @param connections (SparseMatrixConnections)
    @param cells (numpy array)
    @param allMatchingSegments (numpy array)
    @param potentialOverlaps (numpy array)

    @return (numpy array)
    One segment per cell
    """

    candidateSegments = connections.filterSegmentsByCell(allMatchingSegments,
                                                         cells)

    # Narrow it down to one pair per cell.
    onePerCellFilter = np2.argmaxMulti(potentialOverlaps[candidateSegments],
                                       connections.mapSegmentsToCells(
                                         candidateSegments))
    learningSegments = candidateSegments[onePerCellFilter]

    return learningSegments


  @classmethod
  def _chooseBestSegmentPerColumn(cls, connections, matchingCells,
                                  allMatchingSegments, potentialOverlaps,
                                  cellsPerColumn):
    """
    For all the columns covered by 'matchingCells', choose the column's matching
    segment with largest number of active potential synapses. When there's a
    tie, the first segment wins.

    @param connections (SparseMatrixConnections)
    @param matchingCells (numpy array)
    @param allMatchingSegments (numpy array)
    @param potentialOverlaps (numpy array)
    @param cellsPerColumn (int)

    @return (numpy array)
    One segment per column
    """

    candidateSegments = connections.filterSegmentsByCell(allMatchingSegments,
                                                         matchingCells)

    # Narrow it down to one segment per column.
    cellScores = potentialOverlaps[candidateSegments]
    columnsForCandidates = (connections.mapSegmentsToCells(candidateSegments) /
                            cellsPerColumn)
    onePerColumnFilter = np2.argmaxMulti(cellScores, columnsForCandidates)

    learningSegments = candidateSegments[onePerColumnFilter]

    return learningSegments


  @classmethod
  def _getCellsWithFewestSegments(cls, connections, rng, columns,
                                  cellsPerColumn):
    """
    For each column, get the cell that has the fewest total basal segments.
    Break ties randomly.

    @param connections (SparseMatrixConnections)
    @param rng (Random)
    @param columns (numpy array) Columns to check
    @param cellsPerColumn (int)

    @return (numpy array)
    One cell for each of the provided columns
    """
    candidateCells = np2.getAllCellsInColumns(columns, cellsPerColumn)

    # Arrange the segment counts into one row per minicolumn.
    segmentCounts = np.reshape(connections.getSegmentCounts(candidateCells),
                               newshape=(len(columns),
                                         cellsPerColumn))

    # Filter to just the cells that are tied for fewest in their minicolumn.
    minSegmentCounts = np.amin(segmentCounts, axis=1, keepdims=True)
    candidateCells = candidateCells[np.flatnonzero(segmentCounts ==
                                                   minSegmentCounts)]

    # Filter to one cell per column, choosing randomly from the minimums.
    # To do the random choice, add a random offset to each index in-place, using
    # casting to floor the result.
    (_,
     onePerColumnFilter,
     numCandidatesInColumns) = np.unique(candidateCells / cellsPerColumn,
                                         return_index=True, return_counts=True)

    offsetPercents = np.empty(len(columns), dtype="float32")
    rng.initializeReal32Array(offsetPercents)

    np.add(onePerColumnFilter,
           offsetPercents*numCandidatesInColumns,
           out=onePerColumnFilter,
           casting="unsafe")

    return candidateCells[onePerColumnFilter]


  @staticmethod
  def _numPoints(dimensions):
    """
    Get the number of discrete points in a set of dimensions.

    @param dimensions (sequence of integers)
    @return (int)
    """
    if len(dimensions) == 0:
      return 0
    else:
      # 'reduce' is the Python 2 builtin (aliased to operator.mul product).
      return reduce(operator.mul, dimensions, 1)


  def getActiveCells(self):
    """
    @return (numpy array) Cells that became active in the last compute
    """
    return self.activeCells


  def getWinnerCells(self):
    """
    @return (numpy array) Cells that were chosen for learning in the last
    compute
    """
    return self.winnerCells


  def getPreviouslyPredictedCells(self):
    """
    @return (numpy array) Cells that were predicted going into the last compute
    """
    return self.prevPredictedCells
class ApicalDependentTemporalMemory(object):
  """
  A generalized Temporal Memory that creates cell SDRs that are specific to both
  the basal and apical input.

  Prediction requires both basal and apical support. For sequence memory, the
  result is that every sequence happens within a "world" which is specified by
  the apical input. Sequences are not shared between worlds.

  This class is generalized in two ways:

  - This class does not specify when a 'timestep' begins and ends. It exposes
    two main methods: 'depolarizeCells' and 'activateCells', and callers or
    subclasses can introduce the notion of a timestep.
  - This class is unaware of whether its 'basalInput' or 'apicalInput' are from
    internal or external cells. They are just cell numbers. The caller knows
    what these cell numbers mean, but the TemporalMemory doesn't.
  """

  def __init__(self,
               columnCount=2048,
               basalInputSize=0,
               apicalInputSize=0,
               cellsPerColumn=32,
               activationThreshold=13,
               reducedBasalThreshold=10,
               initialPermanence=0.21,
               connectedPermanence=0.50,
               minThreshold=10,
               sampleSize=20,
               permanenceIncrement=0.1,
               permanenceDecrement=0.1,
               basalPredictedSegmentDecrement=0.0,
               apicalPredictedSegmentDecrement=0.0,
               maxSynapsesPerSegment=-1,
               seed=42):
    """
    @param columnCount (int)
    The number of minicolumns

    @param basalInputSize (int)
    The number of bits in the basal input

    @param apicalInputSize (int)
    The number of bits in the apical input

    @param cellsPerColumn (int)
    Number of cells per column

    @param activationThreshold (int)
    If the number of active connected synapses on a segment is at least this
    threshold, the segment is said to be active.

    @param reducedBasalThreshold (int)
    The activation threshold of basal (lateral) segments for cells that have
    active apical segments. If equal to activationThreshold (default),
    this parameter has no effect.

    @param initialPermanence (float)
    Initial permanence of a new synapse

    @param connectedPermanence (float)
    If the permanence value for a synapse is greater than this value, it is said
    to be connected.

    @param minThreshold (int)
    If the number of potential synapses active on a segment is at least this
    threshold, it is said to be "matching" and is eligible for learning.

    @param sampleSize (int)
    How much of the active SDR to sample with synapses.

    @param permanenceIncrement (float)
    Amount by which permanences of synapses are incremented during learning.

    @param permanenceDecrement (float)
    Amount by which permanences of synapses are decremented during learning.

    @param basalPredictedSegmentDecrement (float)
    Amount by which segments are punished for incorrect predictions.

    @param apicalPredictedSegmentDecrement (float)
    Amount by which segments are punished for incorrect predictions.

    @param maxSynapsesPerSegment
    The maximum number of synapses per segment.

    @param seed (int)
    Seed for the random number generator.
    """

    self.columnCount = columnCount
    self.cellsPerColumn = cellsPerColumn
    self.initialPermanence = initialPermanence
    self.connectedPermanence = connectedPermanence
    self.minThreshold = minThreshold
    self.sampleSize = sampleSize
    self.permanenceIncrement = permanenceIncrement
    self.permanenceDecrement = permanenceDecrement
    self.basalPredictedSegmentDecrement = basalPredictedSegmentDecrement
    self.apicalPredictedSegmentDecrement = apicalPredictedSegmentDecrement
    self.activationThreshold = activationThreshold
    self.reducedBasalThreshold = reducedBasalThreshold
    self.maxSynapsesPerSegment = maxSynapsesPerSegment
    self.basalConnections = SparseMatrixConnections(columnCount*cellsPerColumn,
                                                    basalInputSize)
    # When True, predictions come from basal activity alone (see
    # depolarizeCells); apical input is still tracked but not required.
    self.disableApicalDependence = False

    self.apicalConnections = SparseMatrixConnections(columnCount*cellsPerColumn,
                                                     apicalInputSize)
    self.rng = Random(seed)
    self.activeCells = np.empty(0, dtype="uint32")
    self.winnerCells = np.empty(0, dtype="uint32")
    self.predictedCells = np.empty(0, dtype="uint32")
    self.predictedActiveCells = np.empty(0, dtype="uint32")
    self.activeBasalSegments = np.empty(0, dtype="uint32")
    self.activeApicalSegments = np.empty(0, dtype="uint32")
    self.matchingBasalSegments = np.empty(0, dtype="uint32")
    self.matchingApicalSegments = np.empty(0, dtype="uint32")
    self.basalPotentialOverlaps = np.empty(0, dtype="int32")
    self.apicalPotentialOverlaps = np.empty(0, dtype="int32")


  def reset(self):
    """
    Clear all cell and segment activity.
    """
    emptyUint32 = ("activeCells", "winnerCells", "predictedCells",
                   "predictedActiveCells", "activeBasalSegments",
                   "activeApicalSegments", "matchingBasalSegments",
                   "matchingApicalSegments")
    for attrName in emptyUint32:
      setattr(self, attrName, np.empty(0, dtype="uint32"))

    for attrName in ("basalPotentialOverlaps", "apicalPotentialOverlaps"):
      setattr(self, attrName, np.empty(0, dtype="int32"))


  def depolarizeCells(self, basalInput, apicalInput, learn):
    """
    Calculate predictions.

    @param basalInput (numpy array)
    List of active input bits for the basal dendrite segments

    @param apicalInput (numpy array)
    List of active input bits for the apical dendrite segments

    @param learn (bool)
    Whether learning is enabled. Some TM implementations may depolarize cells
    differently or do segment activity bookkeeping when learning is enabled.
    (This implementation does not use it.)
    """
    # Calculate predictions for this timestep.
    # NOTE: this class's _calculateSegmentActivity takes an extra
    # reducedBasalThreshold argument (defined outside this view), unlike
    # TemporalMemory's version.
    (activeApicalSegments,
     matchingApicalSegments,
     apicalPotentialOverlaps) = self._calculateSegmentActivity(
       self.apicalConnections, apicalInput, self.connectedPermanence,
       self.activationThreshold, self.minThreshold, self.reducedBasalThreshold)

    # Cells with an active apical segment get the reduced basal activation
    # threshold below (see reducedBasalThreshold in __init__).
    apicallySupportedCells = self.apicalConnections.mapSegmentsToCells(
      activeApicalSegments)
    if not self.disableApicalDependence:
      (activeBasalSegments,
       matchingBasalSegments,
       basalPotentialOverlaps) = self._calculateSegmentActivity(
         self.basalConnections, basalInput,
         self.connectedPermanence, self.activationThreshold,
         self.minThreshold, self.reducedBasalThreshold,
         reducedThresholdCells = apicallySupportedCells,)

      # Prediction requires both basal and apical support.
      predictedCells = np.intersect1d(
        self.basalConnections.mapSegmentsToCells(activeBasalSegments),
        apicallySupportedCells)
    else:
      # Apical dependence disabled: basal activity alone predicts cells.
      (activeBasalSegments,
      matchingBasalSegments,
      basalPotentialOverlaps) = self._calculateSegmentActivity(
        self.basalConnections, basalInput, self.connectedPermanence,
        self.activationThreshold, self.minThreshold, self.reducedBasalThreshold)

      predictedCells = self.basalConnections.mapSegmentsToCells(activeBasalSegments)

    # Save the results for activateCells.
    self.predictedCells = predictedCells
    self.activeBasalSegments = activeBasalSegments
    self.activeApicalSegments = activeApicalSegments
    self.matchingBasalSegments = matchingBasalSegments
    self.matchingApicalSegments = matchingApicalSegments
    self.basalPotentialOverlaps = basalPotentialOverlaps
    self.apicalPotentialOverlaps = apicalPotentialOverlaps


  def activateCells(self,
                    activeColumns,
                    basalReinforceCandidates,
                    apicalReinforceCandidates,
                    basalGrowthCandidates,
                    apicalGrowthCandidates,
                    learn=True):
    """
    Activate cells in the specified columns, using the result of the previous
    'depolarizeCells' as predictions. Then learn.

    @param activeColumns (numpy array)
    List of active columns

    @param basalReinforceCandidates (numpy array)
    List of bits that the active cells may reinforce basal synapses to.

    @param apicalReinforceCandidates (numpy array)
    List of bits that the active cells may reinforce apical synapses to.

    @param basalGrowthCandidates (numpy array or None)
    List of bits that the active cells may grow new basal synapses to.

    @param apicalGrowthCandidates (numpy array or None)
    List of bits that the active cells may grow new apical synapses to

    @param learn (bool)
    Whether to grow / reinforce / punish synapses

    NOTE(review): when learn=True this method calls len() on both growth
    candidate arguments, so passing None (which the doc above suggests is
    allowed) would raise TypeError -- confirm intended usage with callers.
    """

    # Calculate active cells: correctly predicted cells stay active, and every
    # cell in an active-but-unpredicted (bursting) column becomes active.
    # Dividing cell indices by cellsPerColumn maps cells to their columns
    # (integer division under Python 2 for these integer arrays).
    (correctPredictedCells,
     burstingColumns) = np2.setCompare(self.predictedCells, activeColumns,
                                       self.predictedCells / self.cellsPerColumn,
                                       rightMinusLeft=True)

    newActiveCells = np.concatenate((correctPredictedCells,
                                     np2.getAllCellsInColumns(
                                       burstingColumns, self.cellsPerColumn)))

    # Calculate learning: which existing segments to reinforce, which to
    # punish, and which cells grow brand-new segments.
    (learningActiveBasalSegments,
     learningActiveApicalSegments,
     learningMatchingBasalSegments,
     learningMatchingApicalSegments,
     basalSegmentsToPunish,
     apicalSegmentsToPunish,
     newSegmentCells,
     learningCells) = self._calculateLearning(activeColumns,
                                              burstingColumns,
                                              correctPredictedCells,
                                              self.activeBasalSegments,
                                              self.activeApicalSegments,
                                              self.matchingBasalSegments,
                                              self.matchingApicalSegments,
                                              self.basalPotentialOverlaps,
                                              self.apicalPotentialOverlaps)

    if learn:
      # Learn on existing segments: reinforce/punish permanences and grow new
      # synapses on both the active and the chosen matching segments.
      for learningSegments in (learningActiveBasalSegments,
                               learningMatchingBasalSegments):
        self._learn(self.basalConnections, self.rng, learningSegments,
                    basalReinforceCandidates, basalGrowthCandidates,
                    self.basalPotentialOverlaps,
                    self.initialPermanence, self.sampleSize,
                    self.permanenceIncrement, self.permanenceDecrement,
                    self.maxSynapsesPerSegment)

      for learningSegments in (learningActiveApicalSegments,
                               learningMatchingApicalSegments):
        self._learn(self.apicalConnections, self.rng, learningSegments,
                    apicalReinforceCandidates, apicalGrowthCandidates,
                    self.apicalPotentialOverlaps, self.initialPermanence,
                    self.sampleSize, self.permanenceIncrement,
                    self.permanenceDecrement, self.maxSynapsesPerSegment)

      # Punish incorrect predictions.
      # NOTE(review): this guard uses != 0.0 while _calculateLearning builds
      # the punish lists only when the decrement is > 0.0 -- with a negative
      # decrement the punish lists here would be empty tuples.
      if self.basalPredictedSegmentDecrement != 0.0:
        self.basalConnections.adjustActiveSynapses(
          basalSegmentsToPunish, basalReinforceCandidates,
          -self.basalPredictedSegmentDecrement)

      if self.apicalPredictedSegmentDecrement != 0.0:
        self.apicalConnections.adjustActiveSynapses(
          apicalSegmentsToPunish, apicalReinforceCandidates,
          -self.apicalPredictedSegmentDecrement)

      # Only grow segments if there is basal *and* apical input.
      if len(basalGrowthCandidates) > 0 and len(apicalGrowthCandidates) > 0:
        self._learnOnNewSegments(self.basalConnections, self.rng,
                                 newSegmentCells, basalGrowthCandidates,
                                 self.initialPermanence, self.sampleSize,
                                 self.maxSynapsesPerSegment)
        self._learnOnNewSegments(self.apicalConnections, self.rng,
                                 newSegmentCells, apicalGrowthCandidates,
                                 self.initialPermanence, self.sampleSize,
                                 self.maxSynapsesPerSegment)


    # Save the results
    newActiveCells.sort()
    learningCells.sort()
    self.activeCells = newActiveCells
    self.winnerCells = learningCells
    self.predictedActiveCells = correctPredictedCells


  def _calculateLearning(self,
                         activeColumns,
                         burstingColumns,
                         correctPredictedCells,
                         activeBasalSegments,
                         activeApicalSegments,
                         matchingBasalSegments,
                         matchingApicalSegments,
                         basalPotentialOverlaps,
                         apicalPotentialOverlaps):
    """
    Learning occurs on pairs of segments. Correctly predicted cells always have
    active basal and apical segments, and we learn on these segments. In
    bursting columns, we either learn on an existing segment pair, or we grow a
    new pair of segments.

    @param activeColumns (numpy array)
    @param burstingColumns (numpy array)
    @param correctPredictedCells (numpy array)
    @param activeBasalSegments (numpy array)
    @param activeApicalSegments (numpy array)
    @param matchingBasalSegments (numpy array)
    @param matchingApicalSegments (numpy array)
    @param basalPotentialOverlaps (numpy array)
    @param apicalPotentialOverlaps (numpy array)

    @return (tuple)
    - learningActiveBasalSegments (numpy array)
      Active basal segments on correct predicted cells

    - learningActiveApicalSegments (numpy array)
      Active apical segments on correct predicted cells

    - learningMatchingBasalSegments (numpy array)
      Matching basal segments selected for learning in bursting columns

    - learningMatchingApicalSegments (numpy array)
      Matching apical segments selected for learning in bursting columns

    - basalSegmentsToPunish (numpy array)
      Basal segments that should be punished for predicting an inactive column

    - apicalSegmentsToPunish (numpy array)
      Apical segments that should be punished for predicting an inactive column

    - newSegmentCells (numpy array)
      Cells in bursting columns that were selected to grow new segments

    - learningCells (numpy array)
      Every cell that has a learning segment or was selected to grow a segment
    """

    # Correctly predicted columns: learn on every active segment that sits on
    # a correctly predicted cell.
    learningActiveBasalSegments = self.basalConnections.filterSegmentsByCell(
      activeBasalSegments, correctPredictedCells)
    learningActiveApicalSegments = self.apicalConnections.filterSegmentsByCell(
      activeApicalSegments, correctPredictedCells)

    # Bursting columns: candidates are cells that have BOTH a matching basal
    # and a matching apical segment (set intersection below).
    cellsForMatchingBasal = self.basalConnections.mapSegmentsToCells(
      matchingBasalSegments)
    cellsForMatchingApical = self.apicalConnections.mapSegmentsToCells(
      matchingApicalSegments)
    matchingCells = np.intersect1d(
      cellsForMatchingBasal, cellsForMatchingApical)

    # Split bursting columns into those with a matching pair and those
    # without (integer division maps cells to columns, Python 2 semantics).
    (matchingCellsInBurstingColumns,
     burstingColumnsWithNoMatch) = np2.setCompare(
       matchingCells, burstingColumns, matchingCells / self.cellsPerColumn,
       rightMinusLeft=True)

    (learningMatchingBasalSegments,
     learningMatchingApicalSegments) = self._chooseBestSegmentPairPerColumn(
       matchingCellsInBurstingColumns, matchingBasalSegments,
       matchingApicalSegments, basalPotentialOverlaps, apicalPotentialOverlaps)
    newSegmentCells = self._getCellsWithFewestSegments(
      burstingColumnsWithNoMatch)

    # Incorrectly predicted columns: punish matching segments whose cell's
    # column is not active.
    if self.basalPredictedSegmentDecrement > 0.0:
      correctMatchingBasalMask = np.in1d(
        cellsForMatchingBasal / self.cellsPerColumn, activeColumns)
      basalSegmentsToPunish = matchingBasalSegments[~correctMatchingBasalMask]
    else:
      basalSegmentsToPunish = ()

    if self.apicalPredictedSegmentDecrement > 0.0:
      correctMatchingApicalMask = np.in1d(
        cellsForMatchingApical / self.cellsPerColumn, activeColumns)
      apicalSegmentsToPunish = matchingApicalSegments[~correctMatchingApicalMask]
    else:
      apicalSegmentsToPunish = ()

    # Make a list of every cell that is learning. Mapping only the basal
    # learning segments to cells is sufficient: each chosen cell has both a
    # chosen basal and a chosen apical segment, so the basal mapping already
    # yields every chosen cell.
    learningCells =  np.concatenate(
      (correctPredictedCells,
       self.basalConnections.mapSegmentsToCells(learningMatchingBasalSegments),
       newSegmentCells))

    return (learningActiveBasalSegments,
            learningActiveApicalSegments,
            learningMatchingBasalSegments,
            learningMatchingApicalSegments,
            basalSegmentsToPunish,
            apicalSegmentsToPunish,
            newSegmentCells,
            learningCells)


  @staticmethod
  def _calculateSegmentActivity(connections, activeInput, connectedPermanence,
                                activationThreshold, minThreshold,
                                reducedThreshold,
                                reducedThresholdCells=()):
    """
    Compute the active and matching segments for this timestep.

    Segments on cells listed in reducedThresholdCells (e.g. apically
    depolarized cells) may activate at the lower reducedThreshold.

    @param connections (SparseMatrixConnections)
    @param activeInput (numpy array)

    @return (tuple)
    - activeSegments (numpy array)
      Dendrite segments with enough active connected synapses to cause a
      dendritic spike

    - matchingSegments (numpy array)
      Dendrite segments with enough active potential synapses to be selected
      for learning in a bursting column

    - potentialOverlaps (numpy array)
      The number of active potential synapses for each segment.
      Includes counts for active, matching, and nonmatching segments.
    """
    overlaps = connections.computeActivity(activeInput, connectedPermanence)

    # Segments meeting the full threshold are active outright.
    fullyActive = np.flatnonzero(overlaps >= activationThreshold)

    if reducedThreshold != activationThreshold and len(reducedThresholdCells) > 0:
      # Segments between the reduced and full thresholds become active only
      # if they sit on one of the reduced-threshold cells.
      nearThreshold = np.flatnonzero(
        (overlaps < activationThreshold) & (overlaps >= reducedThreshold))
      owners = connections.mapSegmentsToCells(nearThreshold)
      boosted = nearThreshold[np.in1d(owners, reducedThresholdCells)]
      activeSegments = np.concatenate((fullyActive, boosted))
    else:
      activeSegments = fullyActive

    # Matching segments are judged on potential (unconnected) synapses.
    potentialOverlaps = connections.computeActivity(activeInput)
    matchingSegments = np.flatnonzero(potentialOverlaps >= minThreshold)

    return (activeSegments,
            matchingSegments,
            potentialOverlaps)

  @staticmethod
  def _learn(connections, rng, learningSegments, activeInput, growthCandidates,
             potentialOverlaps, initialPermanence, sampleSize,
             permanenceIncrement, permanenceDecrement, maxSynapsesPerSegment):
    """
    Adjust synapse permanences on the given segments, then grow new synapses.

    @param learningSegments (numpy array)
    @param activeInput (numpy array)
    @param growthCandidates (numpy array)
    @param potentialOverlaps (numpy array)
    """
    # Reinforce synapses to active input, punish synapses to inactive input.
    connections.adjustSynapses(learningSegments, activeInput,
                               permanenceIncrement, -permanenceDecrement)

    # "maxNew" is the per-segment cap on new synapses. It is either a scalar
    # (no sample-size limit) or an array (sampleSize minus existing overlap).
    if sampleSize == -1:
      maxNew = len(growthCandidates)
    else:
      maxNew = sampleSize - potentialOverlaps[learningSegments]

    # Also respect the hard per-segment synapse limit, if one is set.
    if maxSynapsesPerSegment != -1:
      headroom = (maxSynapsesPerSegment -
                  connections.mapSegmentsToSynapseCounts(learningSegments))
      maxNew = np.where(maxNew <= headroom, maxNew, headroom)

    connections.growSynapsesToSample(learningSegments, growthCandidates,
                                     maxNew, initialPermanence, rng)


  @staticmethod
  def _learnOnNewSegments(connections, rng, newSegmentCells, growthCandidates,
                          initialPermanence, sampleSize, maxSynapsesPerSegment):
    """
    Create one new segment on each given cell and grow synapses on it.

    @param connections (SparseMatrixConnections)
    @param rng (Random)
    @param newSegmentCells (numpy array)
    @param growthCandidates (numpy array)
    """
    # Cap the synapse count by sampleSize and maxSynapsesPerSegment;
    # -1 means "no limit" for either.
    numNewSynapses = len(growthCandidates)
    for limit in (sampleSize, maxSynapsesPerSegment):
      if limit != -1:
        numNewSynapses = min(numNewSynapses, limit)

    createdSegments = connections.createSegments(newSegmentCells)
    connections.growSynapsesToSample(createdSegments, growthCandidates,
                                     numNewSynapses, initialPermanence, rng)


  def _chooseBestSegmentPairPerColumn(self,
                                      matchingCellsInBurstingColumns,
                                      matchingBasalSegments,
                                      matchingApicalSegments,
                                      basalPotentialOverlaps,
                                      apicalPotentialOverlaps):
    """
    Choose the best pair of matching segments - one basal and one apical - for
    each column. Pairs are ranked by the sum of their potential overlaps.
    When there's a tie, the first pair wins.

    @param matchingCellsInBurstingColumns (numpy array)
    Cells in bursting columns that have at least one matching basal segment and
    at least one matching apical segment

    @param matchingBasalSegments (numpy array)
    @param matchingApicalSegments (numpy array)
    @param basalPotentialOverlaps (numpy array)
    @param apicalPotentialOverlaps (numpy array)

    @return (tuple)
    - learningBasalSegments (numpy array)
      The selected basal segments

    - learningApicalSegments (numpy array)
      The selected apical segments
    """

    basalCandidateSegments = self.basalConnections.filterSegmentsByCell(
      matchingBasalSegments, matchingCellsInBurstingColumns)
    apicalCandidateSegments = self.apicalConnections.filterSegmentsByCell(
      matchingApicalSegments, matchingCellsInBurstingColumns)

    # Sort everything once rather than inside of each call to argmaxMulti.
    # (sortSegmentsByCell sorts in place; argmaxMulti below relies on it via
    # assumeSorted=True.)
    self.basalConnections.sortSegmentsByCell(basalCandidateSegments)
    self.apicalConnections.sortSegmentsByCell(apicalCandidateSegments)

    # Narrow it down to one pair per cell: for each cell keep the basal and
    # apical segment with the highest potential overlap.
    oneBasalPerCellFilter = np2.argmaxMulti(
      basalPotentialOverlaps[basalCandidateSegments],
      self.basalConnections.mapSegmentsToCells(basalCandidateSegments),
      assumeSorted=True)
    basalCandidateSegments = basalCandidateSegments[oneBasalPerCellFilter]
    oneApicalPerCellFilter = np2.argmaxMulti(
      apicalPotentialOverlaps[apicalCandidateSegments],
      self.apicalConnections.mapSegmentsToCells(apicalCandidateSegments),
      assumeSorted=True)
    apicalCandidateSegments = apicalCandidateSegments[oneApicalPerCellFilter]

    # Narrow it down to one pair per column, ranking pairs by the sum of their
    # basal and apical overlaps. Integer division maps cells to columns
    # (Python 2 semantics for these integer arrays).
    cellScores = (basalPotentialOverlaps[basalCandidateSegments] +
                  apicalPotentialOverlaps[apicalCandidateSegments])
    columnsForCandidates = (
      self.basalConnections.mapSegmentsToCells(basalCandidateSegments) /
      self.cellsPerColumn)
    onePerColumnFilter = np2.argmaxMulti(cellScores, columnsForCandidates,
                                         assumeSorted=True)

    learningBasalSegments = basalCandidateSegments[onePerColumnFilter]
    learningApicalSegments = apicalCandidateSegments[onePerColumnFilter]

    return (learningBasalSegments,
            learningApicalSegments)


  def _getCellsWithFewestSegments(self, columns):
    """
    For each column, get the cell that has the fewest total segments (basal or
    apical). Break ties randomly.

    @param columns (numpy array)
    Columns to check

    @return (numpy array)
    One cell for each of the provided columns
    """
    candidateCells = np2.getAllCellsInColumns(columns, self.cellsPerColumn)

    # Arrange the segment counts into one row per minicolumn.
    segmentCounts = np.reshape(
      self.basalConnections.getSegmentCounts(candidateCells) +
      self.apicalConnections.getSegmentCounts(candidateCells),
      newshape=(len(columns),
                self.cellsPerColumn))

    # Filter to just the cells that are tied for fewest in their minicolumn.
    minSegmentCounts = np.amin(segmentCounts, axis=1, keepdims=True)
    candidateCells = candidateCells[np.flatnonzero(segmentCounts ==
                                                   minSegmentCounts)]

    # Filter to one cell per column, choosing randomly from the minimums.
    # To do the random choice, add a random offset to each index in-place, using
    # casting to floor the result.
    # np.unique gives, per column, the index of its first tied candidate and
    # the number of tied candidates in that column.
    (_,
     onePerColumnFilter,
     numCandidatesInColumns) = np.unique(candidateCells / self.cellsPerColumn,
                                         return_index=True, return_counts=True)

    offsetPercents = np.empty(len(columns), dtype="float32")
    self.rng.initializeReal32Array(offsetPercents)

    # offset in [0, numCandidates) per column; casting="unsafe" truncates the
    # float sum back into the integer index array, i.e. floors the offset.
    np.add(onePerColumnFilter,
           offsetPercents*numCandidatesInColumns,
           out=onePerColumnFilter,
           casting="unsafe")

    return candidateCells[onePerColumnFilter]


  def getActiveCells(self):
    """
    Return the currently active cells.

    @return (numpy array)
    Active cells
    """
    return self.activeCells


  def getPredictedActiveCells(self):
    """
    Return the active cells that were also predicted.

    @return (numpy array)
    Active cells that were correctly predicted
    """
    return np.intersect1d(self.activeCells, self.predictedCells)


  def getWinnerCells(self):
    """
    Return the cells chosen for learning this timestep.

    @return (numpy array)
    Cells that were selected for learning
    """
    return self.winnerCells


  def getPredictedCells(self):
    """
    Return the cells predicted for this timestep.

    @return (numpy array)
    Cells that were predicted for this timestep
    """
    return self.predictedCells


  def getActiveBasalSegments(self):
    """
    Return this timestep's active basal segments.

    @return (numpy array)
    Active basal segments for this timestep
    """
    return self.activeBasalSegments


  def getActiveApicalSegments(self):
    """
    @return (numpy array)
    Active apical segments for this timestep
    """
    return self.activeApicalSegments


  def numberOfColumns(self):
    """
    Return the number of minicolumns in this layer.

    @return (int) Number of columns
    """
    return self.columnCount


  def numberOfCells(self):
    """
    Return the total number of cells in this layer.

    @return (int) Number of cells
    """
    return self.cellsPerColumn * self.numberOfColumns()


  def getCellsPerColumn(self):
    """
    Return how many cells each column contains.

    @return (int) The number of cells per column.
    """
    return self.cellsPerColumn


  def getActivationThreshold(self):
    """
    Return the activation threshold.

    @return (int) The activation threshold.
    """
    return self.activationThreshold


  def setActivationThreshold(self, activationThreshold):
    """
    Set the activation threshold.

    @param activationThreshold (int) activation threshold.
    """
    self.activationThreshold = activationThreshold


  def getInitialPermanence(self):
    """
    Return the initial permanence for new synapses.

    @return (float) The initial permanence.
    """
    return self.initialPermanence


  def setInitialPermanence(self, initialPermanence):
    """
    Set the initial permanence for new synapses.

    @param initialPermanence (float) The initial permanence.
    """
    self.initialPermanence = initialPermanence


  def getMinThreshold(self):
    """
    Return the matching (min) threshold.

    @return (int) The min threshold.
    """
    return self.minThreshold


  def setMinThreshold(self, minThreshold):
    """
    Set the matching (min) threshold.

    @param minThreshold (int) min threshold.
    """
    self.minThreshold = minThreshold


  def getSampleSize(self):
    """
    Return the synapse sample size.

    @return (int)
    """
    return self.sampleSize


  def setSampleSize(self, sampleSize):
    """
    Set the synapse sample size.

    @param sampleSize (int)
    """
    self.sampleSize = sampleSize


  def getPermanenceIncrement(self):
    """
    Return the permanence increment.

    @return (float) The permanence increment.
    """
    return self.permanenceIncrement


  def setPermanenceIncrement(self, permanenceIncrement):
    """
    Set the permanence increment.

    @param permanenceIncrement (float) The permanence increment.
    """
    self.permanenceIncrement = permanenceIncrement


  def getPermanenceDecrement(self):
    """
    Return the permanence decrement.

    @return (float) The permanence decrement.
    """
    return self.permanenceDecrement


  def setPermanenceDecrement(self, permanenceDecrement):
    """
    Set the permanence decrement.

    @param permanenceDecrement (float) The permanence decrement.
    """
    self.permanenceDecrement = permanenceDecrement


  def getBasalPredictedSegmentDecrement(self):
    """
    Return the basal predicted segment decrement.

    @return (float) The predicted segment decrement.
    """
    return self.basalPredictedSegmentDecrement


  def setBasalPredictedSegmentDecrement(self, predictedSegmentDecrement):
    """
    Sets the basal predicted segment decrement.

    @param predictedSegmentDecrement (float) The predicted segment decrement.
    """
    # Bug fix: previously assigned the undefined name
    # 'basalPredictedSegmentDecrement', raising NameError whenever called.
    self.basalPredictedSegmentDecrement = predictedSegmentDecrement


  def getApicalPredictedSegmentDecrement(self):
    """
    Return the apical predicted segment decrement.

    @return (float) The predicted segment decrement.
    """
    return self.apicalPredictedSegmentDecrement


  def setApicalPredictedSegmentDecrement(self, predictedSegmentDecrement):
    """
    Sets the apical predicted segment decrement.

    @param predictedSegmentDecrement (float) The predicted segment decrement.
    """
    # Bug fix: previously assigned the undefined name
    # 'apicalPredictedSegmentDecrement', raising NameError whenever called.
    self.apicalPredictedSegmentDecrement = predictedSegmentDecrement


  def getConnectedPermanence(self):
    """
    Return the connected permanence threshold.

    @return (float) The connected permanence.
    """
    return self.connectedPermanence


  def setConnectedPermanence(self, connectedPermanence):
    """
    Set the connected permanence threshold.

    @param connectedPermanence (float) The connected permanence.
    """
    self.connectedPermanence = connectedPermanence
class ApicalDependentTemporalMemory(object):
  """
  An alternate approach to apical dendrites. Every cell SDR is specific to both
  the basal and the apical input. Prediction requires both basal and apical support.

  A normal TemporalMemory trained on the sequences "A B C D" and "A B C E" will
  not assign "B" and "C" SDRs specific to their full sequence. These two
  sequences will use the same B' and C' SDRs. When the sequence reaches D/E,
  the SDRs finally diverge.

  With this algorithm, the SDRs diverge immediately, because the SDRs are
  specific to the apical input. But if there's never any apical input, there
  will never be predictions.
  """

  def __init__(self,
               columnDimensions=(2048,),
               basalInputDimensions=(),
               apicalInputDimensions=(),
               cellsPerColumn=32,
               activationThreshold=13,
               initialPermanence=0.21,
               connectedPermanence=0.50,
               minThreshold=10,
               sampleSize=20,
               permanenceIncrement=0.1,
               permanenceDecrement=0.1,
               predictedSegmentDecrement=0.0,
               maxNewSynapseCount=None,
               maxSynapsesPerSegment=-1,
               maxSegmentsPerCell=None,
               seed=42):
    """
    @param columnDimensions (tuple) Dimensions of the column space
    @param basalInputDimensions (tuple) Dimensions of the basal input space
    @param apicalInputDimensions (tuple) Dimensions of the apical input space
    @param cellsPerColumn (int) Number of cells per column
    @param activationThreshold (int) Segment overlap needed for activation
    @param initialPermanence (float) Permanence of new synapses
    @param connectedPermanence (float) Threshold for a connected synapse
    @param minThreshold (int) Potential overlap needed for a matching segment
    @param sampleSize (int) Desired synapses per segment (-1 = unlimited)
    @param permanenceIncrement (float) Reward for active synapses
    @param permanenceDecrement (float) Penalty for inactive synapses
    @param predictedSegmentDecrement (float) Penalty for wrong predictions
    @param maxNewSynapseCount (int or None) Deprecated alias for sampleSize
    @param maxSynapsesPerSegment (int) Hard per-segment cap (-1 = unlimited)
    @param maxSegmentsPerCell (int or None) Ignored (warning printed)
    @param seed (int) Seed for the random number generator
    """

    self.columnDimensions = columnDimensions
    self.numColumns = self._numPoints(columnDimensions)
    self.basalInputDimensions = basalInputDimensions
    self.apicalInputDimensions = apicalInputDimensions

    self.cellsPerColumn = cellsPerColumn
    self.initialPermanence = initialPermanence
    self.connectedPermanence = connectedPermanence
    self.minThreshold = minThreshold

    # Handle the deprecated parameter name: maxNewSynapseCount overrides
    # sampleSize when supplied.
    self.sampleSize = sampleSize
    if maxNewSynapseCount is not None:
      print "Parameter 'maxNewSynapseCount' is deprecated. Use 'sampleSize'."
      self.sampleSize = maxNewSynapseCount

    if maxSegmentsPerCell is not None:
      print "Warning: ignoring parameter 'maxSegmentsPerCell'"

    self.permanenceIncrement = permanenceIncrement
    self.permanenceDecrement = permanenceDecrement
    self.predictedSegmentDecrement = predictedSegmentDecrement
    self.activationThreshold = activationThreshold
    self.maxSynapsesPerSegment = maxSynapsesPerSegment

    # One row of synapses per cell; basal and apical inputs get separate
    # connection matrices.
    self.basalConnections = SparseMatrixConnections(
      self.numColumns*cellsPerColumn, self._numPoints(basalInputDimensions))
    self.apicalConnections = SparseMatrixConnections(
      self.numColumns*cellsPerColumn, self._numPoints(apicalInputDimensions))
    self.rng = Random(seed)
    # Per-timestep state, cleared by reset().
    self.activeCells = EMPTY_UINT_ARRAY
    self.winnerCells = EMPTY_UINT_ARRAY
    self.prevPredictedCells = EMPTY_UINT_ARRAY


  def reset(self):
    """
    Clear all cell and prediction state, e.g. between sequences.
    """
    self.activeCells = EMPTY_UINT_ARRAY
    self.winnerCells = EMPTY_UINT_ARRAY
    self.prevPredictedCells = EMPTY_UINT_ARRAY


  def compute(self,
              activeColumns,
              basalInput,
              basalGrowthCandidates,
              apicalInput,
              apicalGrowthCandidates,
              learn=True):
    """
    Run one timestep: predict, activate, and (optionally) learn.

    A cell is predicted only when it has both an active basal segment and an
    active apical segment (see the intersect1d below).

    @param activeColumns (numpy array)
    @param basalInput (numpy array)
    @param basalGrowthCandidates (numpy array)
    @param apicalInput (numpy array)
    @param apicalGrowthCandidates (numpy array)
    @param learn (bool)
    """
    # Calculate predictions for this timestep
    (activeBasalSegments,
     matchingBasalSegments,
     basalPotentialOverlaps) = self._calculateSegmentActivity(
       self.basalConnections, basalInput, self.connectedPermanence,
       self.activationThreshold, self.minThreshold)

    (activeApicalSegments,
     matchingApicalSegments,
     apicalPotentialOverlaps) = self._calculateSegmentActivity(
       self.apicalConnections, apicalInput, self.connectedPermanence,
       self.activationThreshold, self.minThreshold)

    # Prediction requires both basal and apical support.
    predictedCells = np.intersect1d(
      self.basalConnections.mapSegmentsToCells(activeBasalSegments),
      self.apicalConnections.mapSegmentsToCells(activeApicalSegments))

    # Calculate active cells: correct predictions plus every cell in each
    # bursting (active but unpredicted) column. Integer division maps cells
    # to columns (Python 2 semantics).
    (correctPredictedCells,
     burstingColumns) = np2.setCompare(predictedCells, activeColumns,
                                       predictedCells / self.cellsPerColumn,
                                       rightMinusLeft=True)
    newActiveCells = np.concatenate((correctPredictedCells,
                                     np2.getAllCellsInColumns(
                                       burstingColumns, self.cellsPerColumn)))

    # Calculate learning
    (learningActiveBasalSegments,
     learningActiveApicalSegments,
     learningMatchingBasalSegments,
     learningMatchingApicalSegments,
     basalSegmentsToPunish,
     apicalSegmentsToPunish,
     newSegmentCells,
     learningCells) = self._calculateLearning(activeColumns,
                                              burstingColumns,
                                              correctPredictedCells,
                                              activeBasalSegments,
                                              activeApicalSegments,
                                              matchingBasalSegments,
                                              matchingApicalSegments,
                                              basalPotentialOverlaps,
                                              apicalPotentialOverlaps)

    if learn:
      # Learn on existing segments
      for learningSegments in (learningActiveBasalSegments,
                               learningMatchingBasalSegments):
        self._learn(self.basalConnections, self.rng, learningSegments,
                    basalInput, basalGrowthCandidates, basalPotentialOverlaps,
                    self.initialPermanence, self.sampleSize,
                    self.permanenceIncrement, self.permanenceDecrement,
                    self.maxSynapsesPerSegment)

      for learningSegments in (learningActiveApicalSegments,
                               learningMatchingApicalSegments):

        self._learn(self.apicalConnections, self.rng, learningSegments,
                    apicalInput, apicalGrowthCandidates,
                    apicalPotentialOverlaps, self.initialPermanence,
                    self.sampleSize, self.permanenceIncrement,
                    self.permanenceDecrement, self.maxSynapsesPerSegment)

      # Punish incorrect predictions
      if self.predictedSegmentDecrement != 0.0:
        self.basalConnections.adjustActiveSynapses(
          basalSegmentsToPunish, basalInput, -self.predictedSegmentDecrement)
        self.apicalConnections.adjustActiveSynapses(
          apicalSegmentsToPunish, apicalInput, -self.predictedSegmentDecrement)

      # Only grow segments if there is basal *and* apical input.
      if len(basalGrowthCandidates) > 0 and len(apicalGrowthCandidates) > 0:
        self._learnOnNewSegments(self.basalConnections, self.rng,
                                 newSegmentCells, basalGrowthCandidates,
                                 self.initialPermanence, self.sampleSize,
                                 self.maxSynapsesPerSegment)
        self._learnOnNewSegments(self.apicalConnections, self.rng,
                                 newSegmentCells, apicalGrowthCandidates,
                                 self.initialPermanence, self.sampleSize,
                                 self.maxSynapsesPerSegment)


    # Save the results
    self.prevPredictedCells = predictedCells
    self.activeCells = newActiveCells
    self.winnerCells = learningCells


  def _calculateLearning(self,
                         activeColumns,
                         burstingColumns,
                         correctPredictedCells,
                         activeBasalSegments,
                         activeApicalSegments,
                         matchingBasalSegments,
                         matchingApicalSegments,
                         basalPotentialOverlaps,
                         apicalPotentialOverlaps):
    """
    Learning occurs on pairs of segments. Correctly predicted cells always have
    active basal and apical segments, and we learn on these segments. In
    bursting columns, we either learn on an existing segment pair, or we grow a
    new pair of segments.

    @param activeColumns (numpy array)
    @param burstingColumns (numpy array)
    @param correctPredictedCells (numpy array)
    @param activeBasalSegments (numpy array)
    @param activeApicalSegments (numpy array)
    @param matchingBasalSegments (numpy array)
    @param matchingApicalSegments (numpy array)
    @param basalPotentialOverlaps (numpy array)
    @param apicalPotentialOverlaps (numpy array)

    @return (tuple)
    - learningActiveBasalSegments (numpy array)
      Active basal segments on correct predicted cells

    - learningActiveApicalSegments (numpy array)
      Active apical segments on correct predicted cells

    - learningMatchingBasalSegments (numpy array)
      Matching basal segments selected for learning in bursting columns

    - learningMatchingApicalSegments (numpy array)
      Matching apical segments selected for learning in bursting columns

    - basalSegmentsToPunish (numpy array)
      Basal segments that should be punished for predicting an inactive column

    - apicalSegmentsToPunish (numpy array)
      Apical segments that should be punished for predicting an inactive column

    - newSegmentCells (numpy array)
      Cells in bursting columns that were selected to grow new segments

    - learningCells (numpy array)
      Every cell that has a learning segment or was selected to grow a segment
    """

    # Correctly predicted columns: learn on every active segment that sits on
    # a correctly predicted cell.
    learningActiveBasalSegments = self.basalConnections.filterSegmentsByCell(
      activeBasalSegments, correctPredictedCells)
    learningActiveApicalSegments = self.apicalConnections.filterSegmentsByCell(
      activeApicalSegments, correctPredictedCells)

    # Bursting columns: candidates are cells that have BOTH a matching basal
    # and a matching apical segment (set intersection below).
    cellsForMatchingBasal = self.basalConnections.mapSegmentsToCells(
      matchingBasalSegments)
    cellsForMatchingApical = self.apicalConnections.mapSegmentsToCells(
      matchingApicalSegments)
    matchingCells = np.intersect1d(
      cellsForMatchingBasal, cellsForMatchingApical)

    # Split bursting columns into those with a matching pair and those
    # without (integer division maps cells to columns, Python 2 semantics).
    (matchingCellsInBurstingColumns,
     burstingColumnsWithNoMatch) = np2.setCompare(
       matchingCells, burstingColumns, matchingCells / self.cellsPerColumn,
       rightMinusLeft=True)

    (learningMatchingBasalSegments,
     learningMatchingApicalSegments) = self._chooseBestSegmentPairPerColumn(
       matchingCellsInBurstingColumns, matchingBasalSegments,
       matchingApicalSegments, basalPotentialOverlaps, apicalPotentialOverlaps)
    newSegmentCells = self._getCellsWithFewestSegments(
      burstingColumnsWithNoMatch)

    # Incorrectly predicted columns: punish matching segments whose cell's
    # column is not active.
    if self.predictedSegmentDecrement > 0.0:
      correctMatchingBasalMask = np.in1d(
        cellsForMatchingBasal / self.cellsPerColumn, activeColumns)
      correctMatchingApicalMask = np.in1d(
        cellsForMatchingApical / self.cellsPerColumn, activeColumns)

      basalSegmentsToPunish = matchingBasalSegments[~correctMatchingBasalMask]
      apicalSegmentsToPunish = matchingApicalSegments[~correctMatchingApicalMask]
    else:
      basalSegmentsToPunish = ()
      apicalSegmentsToPunish = ()

    # Make a list of every cell that is learning. Mapping only the basal
    # learning segments to cells suffices: each chosen cell has both a chosen
    # basal and a chosen apical segment.
    learningCells =  np.concatenate(
      (correctPredictedCells,
       self.basalConnections.mapSegmentsToCells(learningMatchingBasalSegments),
       newSegmentCells))

    return (learningActiveBasalSegments,
            learningActiveApicalSegments,
            learningMatchingBasalSegments,
            learningMatchingApicalSegments,
            basalSegmentsToPunish,
            apicalSegmentsToPunish,
            newSegmentCells,
            learningCells)


  @staticmethod
  def _calculateSegmentActivity(connections, activeInput, connectedPermanence,
                                activationThreshold, minThreshold):
    """
    Compute the active and matching segments for this timestep.

    A segment is "active" when the number of its active *connected* synapses
    reaches activationThreshold, and "matching" when the number of its active
    *potential* synapses reaches minThreshold.

    @param connections (SparseMatrixConnections)
    @param activeInput (numpy array)

    @return (tuple)
    - activeSegments (numpy array)
      Dendrite segments with enough active connected synapses to cause a
      dendritic spike

    - matchingSegments (numpy array)
      Dendrite segments with enough active potential synapses to be selected
      for learning in a bursting column

    - potentialOverlaps (numpy array)
      The number of active potential synapses for each segment, including
      active, matching, and nonmatching segments.
    """
    connectedOverlaps = connections.computeActivity(activeInput,
                                                    connectedPermanence)
    potentialOverlaps = connections.computeActivity(activeInput)

    return (np.flatnonzero(connectedOverlaps >= activationThreshold),
            np.flatnonzero(potentialOverlaps >= minThreshold),
            potentialOverlaps)


  @staticmethod
  def _learn(connections, rng, learningSegments, activeInput, growthCandidates,
             potentialOverlaps, initialPermanence, sampleSize,
             permanenceIncrement, permanenceDecrement, maxSynapsesPerSegment):
    """
    Adjust synapse permanences, and grow new synapses.

    @param connections (SparseMatrixConnections)
    @param rng (Random)

    @param learningSegments (numpy array)
    Segments to reinforce and to grow new synapses on

    @param activeInput (numpy array)
    Active input bits, passed to adjustSynapses together with
    permanenceIncrement / -permanenceDecrement

    @param growthCandidates (numpy array)
    Input bits eligible to receive newly grown synapses

    @param potentialOverlaps (numpy array)
    Number of active potential synapses, indexed by segment

    @param initialPermanence (float)

    @param sampleSize (int)
    -1 means "no sampling limit": grow up to len(growthCandidates) synapses

    @param permanenceIncrement (float)
    @param permanenceDecrement (float)

    @param maxSynapsesPerSegment (int)
    -1 means no per-segment synapse cap
    """

    # Learn on existing segments
    connections.adjustSynapses(learningSegments, activeInput,
                               permanenceIncrement, -permanenceDecrement)

    # Grow new synapses. Calculate "maxNew", the maximum number of synapses to
    # grow per segment. "maxNew" might be a number or it might be a list of
    # numbers.
    if sampleSize == -1:
      maxNew = len(growthCandidates)
    else:
      # Grow only enough synapses to bring each segment's potential overlap up
      # to sampleSize.
      maxNew = sampleSize - potentialOverlaps[learningSegments]

    if maxSynapsesPerSegment != -1:
      # Additionally cap growth so no segment exceeds maxSynapsesPerSegment.
      synapseCounts = connections.mapSegmentsToSynapseCounts(
        learningSegments)
      numSynapsesToReachMax = maxSynapsesPerSegment - synapseCounts
      maxNew = np.where(maxNew <= numSynapsesToReachMax,
                        maxNew, numSynapsesToReachMax)

    connections.growSynapsesToSample(learningSegments, growthCandidates,
                                     maxNew, initialPermanence, rng)


  @staticmethod
  def _learnOnNewSegments(connections, rng, newSegmentCells, growthCandidates,
                          initialPermanence, sampleSize, maxSynapsesPerSegment):
    """
    Create new segments, and grow synapses on them.

    The number of synapses grown per new segment is the number of growth
    candidates, capped by sampleSize and by maxSynapsesPerSegment (a value of
    -1 disables the respective cap).

    @param connections (SparseMatrixConnections)
    @param rng (Random)
    @param newSegmentCells (numpy array)
    @param growthCandidates (numpy array)
    """
    caps = [len(growthCandidates)]
    if sampleSize != -1:
      caps.append(sampleSize)
    if maxSynapsesPerSegment != -1:
      caps.append(maxSynapsesPerSegment)
    numNewSynapses = min(caps)

    createdSegments = connections.createSegments(newSegmentCells)
    connections.growSynapsesToSample(createdSegments, growthCandidates,
                                     numNewSynapses, initialPermanence,
                                     rng)


  def _chooseBestSegmentPairPerColumn(self,
                                      matchingCellsInBurstingColumns,
                                      matchingBasalSegments,
                                      matchingApicalSegments,
                                      basalPotentialOverlaps,
                                      apicalPotentialOverlaps):
    """
    Choose the best pair of matching segments - one basal and one apical - for
    each column. Pairs are ranked by the sum of their potential overlaps.
    When there's a tie, the first pair wins.

    @param matchingCellsInBurstingColumns (numpy array)
    Cells in bursting columns that have at least one matching basal segment and
    at least one matching apical segment

    @param matchingBasalSegments (numpy array)
    @param matchingApicalSegments (numpy array)
    @param basalPotentialOverlaps (numpy array)
    @param apicalPotentialOverlaps (numpy array)

    @return (tuple)
    - learningBasalSegments (numpy array)
      The selected basal segments

    - learningApicalSegments (numpy array)
      The selected apical segments
    """

    # Restrict attention to matching segments on the candidate cells.
    basalCandidateSegments = self.basalConnections.filterSegmentsByCell(
      matchingBasalSegments, matchingCellsInBurstingColumns)
    apicalCandidateSegments = self.apicalConnections.filterSegmentsByCell(
      matchingApicalSegments, matchingCellsInBurstingColumns)

    # Sort everything once rather than inside of each call to argmaxMulti.
    # (sortSegmentsByCell sorts the arrays in place.)
    self.basalConnections.sortSegmentsByCell(basalCandidateSegments)
    self.apicalConnections.sortSegmentsByCell(apicalCandidateSegments)

    # Narrow it down to one pair per cell: within each cell, keep the segment
    # with the highest potential overlap (first one on ties).
    oneBasalPerCellFilter = np2.argmaxMulti(
      basalPotentialOverlaps[basalCandidateSegments],
      self.basalConnections.mapSegmentsToCells(basalCandidateSegments),
      assumeSorted=True)
    basalCandidateSegments = basalCandidateSegments[oneBasalPerCellFilter]
    oneApicalPerCellFilter = np2.argmaxMulti(
      apicalPotentialOverlaps[apicalCandidateSegments],
      self.apicalConnections.mapSegmentsToCells(apicalCandidateSegments),
      assumeSorted=True)
    apicalCandidateSegments = apicalCandidateSegments[oneApicalPerCellFilter]

    # Narrow it down to one pair per column, ranking each cell's pair by the
    # sum of its basal and apical potential overlaps.
    cellScores = (basalPotentialOverlaps[basalCandidateSegments] +
                  apicalPotentialOverlaps[apicalCandidateSegments])
    # NOTE(review): "/" relies on Python 2 integer (floor) division to map
    # cells to columns; under Python 3 this should be "//" -- confirm the
    # target interpreter.
    columnsForCandidates = (
      self.basalConnections.mapSegmentsToCells(basalCandidateSegments) /
      self.cellsPerColumn)
    onePerColumnFilter = np2.argmaxMulti(cellScores, columnsForCandidates,
                                         assumeSorted=True)

    learningBasalSegments = basalCandidateSegments[onePerColumnFilter]
    learningApicalSegments = apicalCandidateSegments[onePerColumnFilter]

    return (learningBasalSegments,
            learningApicalSegments)


  def _getCellsWithFewestSegments(self, columns):
    """
    For each column, get the cell that has the fewest total segments (basal or
    apical). Break ties randomly.

    @param columns (numpy array)
    Columns to check

    @return (numpy array)
    One cell for each of the provided columns
    """
    candidateCells = np2.getAllCellsInColumns(columns, self.cellsPerColumn)

    # Arrange the segment counts into one row per minicolumn.
    segmentCounts = np.reshape(
      self.basalConnections.getSegmentCounts(candidateCells) +
      self.apicalConnections.getSegmentCounts(candidateCells),
      newshape=(len(columns),
                self.cellsPerColumn))

    # Filter to just the cells that are tied for fewest in their minicolumn.
    minSegmentCounts = np.amin(segmentCounts, axis=1, keepdims=True)
    candidateCells = candidateCells[np.flatnonzero(segmentCounts ==
                                                   minSegmentCounts)]

    # Filter to one cell per column, choosing randomly from the minimums.
    # To do the random choice, add a random offset to each index in-place, using
    # casting to floor the result.
    # onePerColumnFilter: index of each column's first tied candidate.
    # numCandidatesInColumns: how many tied candidates each column has.
    (_,
     onePerColumnFilter,
     numCandidatesInColumns) = np.unique(candidateCells / self.cellsPerColumn,
                                         return_index=True, return_counts=True)

    # One random float per column, drawn from the shared RNG so the choice is
    # reproducible for a given seed (presumably uniform in [0, 1) -- confirm
    # initializeReal32Array semantics).
    offsetPercents = np.empty(len(columns), dtype="float32")
    self.rng.initializeReal32Array(offsetPercents)

    # Because onePerColumnFilter has an integer dtype, the float sum is floored
    # on write-back, picking index first + floor(offset * count) within each
    # column's run of tied candidates. casting="unsafe" permits the
    # float->int store.
    np.add(onePerColumnFilter,
           offsetPercents*numCandidatesInColumns,
           out=onePerColumnFilter,
           casting="unsafe")

    return candidateCells[onePerColumnFilter]


  @staticmethod
  def _numPoints(dimensions):
    """
    Get the number of discrete points in a set of dimensions.

    @param dimensions (sequence of integers)
    @return (int)
    """
    if len(dimensions) == 0:
      return 0
    else:
      return reduce(operator.mul, dimensions, 1)


  def getActiveCells(self):
    """
    @return (numpy array)
    The current active cells (self.activeCells)
    """
    return self.activeCells


  def getWinnerCells(self):
    """
    @return (numpy array)
    The current winner cells (self.winnerCells)
    """
    return self.winnerCells


  def getPreviouslyPredictedCells(self):
    """
    @return (numpy array)
    Cells that were predicted on the previous timestep.

    NOTE(review): self.prevPredictedCells is not assigned anywhere in the
    visible code -- presumably maintained by a subclass or by code outside
    this view; confirm before relying on it.
    """
    return self.prevPredictedCells
Beispiel #10
0
class ApicalTiebreakTemporalMemory(object):
    """
  A generalized Temporal Memory with apical dendrites that add a "tiebreak".

  Basal connections are used to implement traditional Temporal Memory.

  The apical connections are used for further disambiguation. If multiple cells
  in a minicolumn have active basal segments, each of those cells is predicted,
  unless one of them also has an active apical segment, in which case only the
  cells with active basal and apical segments are predicted.

  In other words, the apical connections have no effect unless the basal input
  is a union of SDRs (e.g. from bursting minicolumns).

  This class is generalized in two ways:

  - This class does not specify when a 'timestep' begins and ends. It exposes
    two main methods: 'depolarizeCells' and 'activateCells', and callers or
    subclasses can introduce the notion of a timestep.
  - This class is unaware of whether its 'basalInput' or 'apicalInput' are from
    internal or external cells. They are just cell numbers. The caller knows
    what these cell numbers mean, but the TemporalMemory doesn't.
  """
    def __init__(self,
                 columnCount=2048,
                 basalInputSize=0,
                 apicalInputSize=0,
                 cellsPerColumn=32,
                 activationThreshold=13,
                 reducedBasalThreshold=13,
                 initialPermanence=0.21,
                 connectedPermanence=0.50,
                 minThreshold=10,
                 sampleSize=20,
                 permanenceIncrement=0.1,
                 permanenceDecrement=0.1,
                 basalPredictedSegmentDecrement=0.0,
                 apicalPredictedSegmentDecrement=0.0,
                 maxSynapsesPerSegment=-1,
                 seed=42):
        """
    @param columnCount (int)
    The number of minicolumns

    @param basalInputSize (int)
    The number of bits in the basal input

    @param apicalInputSize (int)
    The number of bits in the apical input

    @param cellsPerColumn (int)
    Number of cells per column

    @param activationThreshold (int)
    If the number of active connected synapses on a segment is at least this
    threshold, the segment is said to be active.

    @param reducedBasalThreshold (int)
    The activation threshold of basal (lateral) segments for cells that have
    active apical segments. If equal to activationThreshold (default),
    this parameter has no effect.

    @param initialPermanence (float)
    Initial permanence of a new synapse

    @param connectedPermanence (float)
    If the permanence value for a synapse is greater than this value, it is said
    to be connected.

    @param minThreshold (int)
    If the number of potential synapses active on a segment is at least this
    threshold, it is said to be "matching" and is eligible for learning.

    @param sampleSize (int)
    How much of the active SDR to sample with synapses.

    @param permanenceIncrement (float)
    Amount by which permanences of synapses are incremented during learning.

    @param permanenceDecrement (float)
    Amount by which permanences of synapses are decremented during learning.

    @param basalPredictedSegmentDecrement (float)
    Amount by which segments are punished for incorrect predictions.

    @param apicalPredictedSegmentDecrement (float)
    Amount by which segments are punished for incorrect predictions.

    @param maxSynapsesPerSegment
    The maximum number of synapses per segment.

    @param seed (int)
    Seed for the random number generator.
    """

        self.columnCount = columnCount
        self.cellsPerColumn = cellsPerColumn
        self.initialPermanence = initialPermanence
        self.connectedPermanence = connectedPermanence
        self.reducedBasalThreshold = reducedBasalThreshold
        self.minThreshold = minThreshold
        self.sampleSize = sampleSize
        self.permanenceIncrement = permanenceIncrement
        self.permanenceDecrement = permanenceDecrement
        self.basalPredictedSegmentDecrement = basalPredictedSegmentDecrement
        self.apicalPredictedSegmentDecrement = apicalPredictedSegmentDecrement
        self.activationThreshold = activationThreshold
        self.maxSynapsesPerSegment = maxSynapsesPerSegment

        # Dendrite connection matrices: one row of segments per cell, where the
        # total cell count is columnCount * cellsPerColumn.
        self.basalConnections = SparseMatrixConnections(
            columnCount * cellsPerColumn, basalInputSize)
        self.apicalConnections = SparseMatrixConnections(
            columnCount * cellsPerColumn, apicalInputSize)
        self.rng = Random(seed)

        # Per-timestep state, refreshed by depolarizeCells / activateCells and
        # cleared by reset().
        self.activeCells = np.empty(0, dtype="uint32")
        self.winnerCells = np.empty(0, dtype="uint32")
        self.predictedCells = np.empty(0, dtype="uint32")
        self.predictedActiveCells = np.empty(0, dtype="uint32")
        self.activeBasalSegments = np.empty(0, dtype="uint32")
        self.activeApicalSegments = np.empty(0, dtype="uint32")
        self.matchingBasalSegments = np.empty(0, dtype="uint32")
        self.matchingApicalSegments = np.empty(0, dtype="uint32")
        self.basalPotentialOverlaps = np.empty(0, dtype="int32")
        self.apicalPotentialOverlaps = np.empty(0, dtype="int32")

        # Feature flags. useApicalModulationBasalThreshold is checked in
        # depolarizeCells; useApicalTiebreak is presumably read by a subclass
        # or by code outside this view -- confirm.
        self.useApicalTiebreak = True
        self.useApicalModulationBasalThreshold = True

    def reset(self):
        """
    Clear all cell and segment activity.
    """
        for attribute in ("activeCells", "winnerCells", "predictedCells",
                          "predictedActiveCells", "activeBasalSegments",
                          "activeApicalSegments", "matchingBasalSegments",
                          "matchingApicalSegments"):
            setattr(self, attribute, np.empty(0, dtype="uint32"))

        for attribute in ("basalPotentialOverlaps",
                          "apicalPotentialOverlaps"):
            setattr(self, attribute, np.empty(0, dtype="int32"))

    def depolarizeCells(self, basalInput, apicalInput, learn):
        """
    Calculate predictions.

    @param basalInput (numpy array)
    List of active input bits for the basal dendrite segments

    @param apicalInput (numpy array)
    List of active input bits for the apical dendrite segments

    @param learn (bool)
    Whether learning is enabled. Some TM implementations may depolarize cells
    differently or do segment activity bookkeeping when learning is enabled.
    """
        # Apical segment activity uses the standard thresholds.
        (activeApicalSegments, matchingApicalSegments,
         apicalPotentialOverlaps) = self._calculateApicalSegmentActivity(
             self.apicalConnections, apicalInput, self.connectedPermanence,
             self.activationThreshold, self.minThreshold)

        # While learning (or when the modulation feature is disabled), apical
        # activity does not lower any cell's basal activation threshold.
        if learn or self.useApicalModulationBasalThreshold == False:
            reducedBasalThresholdCells = ()
        else:
            reducedBasalThresholdCells = self.apicalConnections.mapSegmentsToCells(
                activeApicalSegments)

        # Basal segments on the reduced-threshold cells activate at
        # reducedBasalThreshold instead of activationThreshold.
        (activeBasalSegments, matchingBasalSegments,
         basalPotentialOverlaps) = self._calculateBasalSegmentActivity(
             self.basalConnections, basalInput, reducedBasalThresholdCells,
             self.connectedPermanence, self.activationThreshold,
             self.minThreshold, self.reducedBasalThreshold)

        predictedCells = self._calculatePredictedCells(activeBasalSegments,
                                                       activeApicalSegments)

        # Save segment and overlap state for the subsequent activateCells call.
        self.predictedCells = predictedCells
        self.activeBasalSegments = activeBasalSegments
        self.activeApicalSegments = activeApicalSegments
        self.matchingBasalSegments = matchingBasalSegments
        self.matchingApicalSegments = matchingApicalSegments
        self.basalPotentialOverlaps = basalPotentialOverlaps
        self.apicalPotentialOverlaps = apicalPotentialOverlaps

    def activateCells(self,
                      activeColumns,
                      basalReinforceCandidates,
                      apicalReinforceCandidates,
                      basalGrowthCandidates,
                      apicalGrowthCandidates,
                      learn=True):
        """
    Activate cells in the specified columns, using the result of the previous
    'depolarizeCells' as predictions. Then learn.

    @param activeColumns (numpy array)
    List of active columns

    @param basalReinforceCandidates (numpy array)
    List of bits that the active cells may reinforce basal synapses to.

    @param apicalReinforceCandidates (numpy array)
    List of bits that the active cells may reinforce apical synapses to.

    @param basalGrowthCandidates (numpy array)
    List of bits that the active cells may grow new basal synapses to.

    @param apicalGrowthCandidates (numpy array)
    List of bits that the active cells may grow new apical synapses to

    @param learn (bool)
    Whether to grow / reinforce / punish synapses
    """

        # Calculate active cells.
        # correctPredictedCells: predicted cells whose column is active.
        # burstingColumns: active columns with no predicted cell; every cell in
        # these columns becomes active.
        # NOTE(review): "/" relies on Python 2 integer division to map cells to
        # columns; under Python 3 this should be "//" -- confirm target
        # interpreter.
        (correctPredictedCells, burstingColumns) = np2.setCompare(
            self.predictedCells,
            activeColumns,
            self.predictedCells / self.cellsPerColumn,
            rightMinusLeft=True)
        newActiveCells = np.concatenate(
            (correctPredictedCells,
             np2.getAllCellsInColumns(burstingColumns, self.cellsPerColumn)))

        # Calculate learning. Basal learning also determines the full set of
        # learning cells; apical learning is then computed on those same cells.
        (learningActiveBasalSegments, learningMatchingBasalSegments,
         basalSegmentsToPunish, newBasalSegmentCells,
         learningCells) = self._calculateBasalLearning(
             activeColumns, burstingColumns, correctPredictedCells,
             self.activeBasalSegments, self.matchingBasalSegments,
             self.basalPotentialOverlaps)

        (learningActiveApicalSegments, learningMatchingApicalSegments,
         apicalSegmentsToPunish,
         newApicalSegmentCells) = self._calculateApicalLearning(
             learningCells, activeColumns, self.activeApicalSegments,
             self.matchingApicalSegments, self.apicalPotentialOverlaps)

        # Learn
        if learn:
            # Learn on existing segments (both the active segments of correctly
            # predicted cells and the matching segments chosen in bursting
            # columns).
            for learningSegments in (learningActiveBasalSegments,
                                     learningMatchingBasalSegments):
                self._learn(self.basalConnections, self.rng, learningSegments,
                            basalReinforceCandidates, basalGrowthCandidates,
                            self.basalPotentialOverlaps,
                            self.initialPermanence, self.sampleSize,
                            self.permanenceIncrement, self.permanenceDecrement,
                            self.maxSynapsesPerSegment)

            for learningSegments in (learningActiveApicalSegments,
                                     learningMatchingApicalSegments):

                self._learn(self.apicalConnections, self.rng, learningSegments,
                            apicalReinforceCandidates, apicalGrowthCandidates,
                            self.apicalPotentialOverlaps,
                            self.initialPermanence, self.sampleSize,
                            self.permanenceIncrement, self.permanenceDecrement,
                            self.maxSynapsesPerSegment)

            # Punish incorrect predictions
            if self.basalPredictedSegmentDecrement != 0.0:
                self.basalConnections.adjustActiveSynapses(
                    basalSegmentsToPunish, basalReinforceCandidates,
                    -self.basalPredictedSegmentDecrement)

            if self.apicalPredictedSegmentDecrement != 0.0:
                self.apicalConnections.adjustActiveSynapses(
                    apicalSegmentsToPunish, apicalReinforceCandidates,
                    -self.apicalPredictedSegmentDecrement)

            # Grow new segments -- only when there are growth candidates this
            # timestep.
            if len(basalGrowthCandidates) > 0:
                self._learnOnNewSegments(self.basalConnections, self.rng,
                                         newBasalSegmentCells,
                                         basalGrowthCandidates,
                                         self.initialPermanence,
                                         self.sampleSize,
                                         self.maxSynapsesPerSegment)

            if len(apicalGrowthCandidates) > 0:
                self._learnOnNewSegments(self.apicalConnections, self.rng,
                                         newApicalSegmentCells,
                                         apicalGrowthCandidates,
                                         self.initialPermanence,
                                         self.sampleSize,
                                         self.maxSynapsesPerSegment)

        # Save the results (sorted in place for canonical SDR order).
        newActiveCells.sort()
        learningCells.sort()
        self.activeCells = newActiveCells
        self.winnerCells = learningCells
        self.predictedActiveCells = correctPredictedCells

    def _calculateBasalLearning(self, activeColumns, burstingColumns,
                                correctPredictedCells, activeBasalSegments,
                                matchingBasalSegments, basalPotentialOverlaps):
        """
    Basic Temporal Memory learning. Correctly predicted cells always have
    active basal segments, and we learn on these segments. In bursting
    columns, we either learn on an existing basal segment, or we grow a new one.

    The only influence apical dendrites have on basal learning is: the apical
    dendrites influence which cells are considered "predicted". So an active
    apical dendrite can prevent some basal segments in active columns from
    learning.

    @param activeColumns (numpy array)
    @param burstingColumns (numpy array)
    @param correctPredictedCells (numpy array)
    @param activeBasalSegments (numpy array)
    @param matchingBasalSegments (numpy array)
    @param basalPotentialOverlaps (numpy array)

    @return (tuple)
    - learningActiveBasalSegments (numpy array)
      Active basal segments on correct predicted cells

    - learningMatchingBasalSegments (numpy array)
      Matching basal segments selected for learning in bursting columns

    - basalSegmentsToPunish (numpy array)
      Basal segments that should be punished for predicting an inactive column

    - newBasalSegmentCells (numpy array)
      Cells in bursting columns that were selected to grow new basal segments

    - learningCells (numpy array)
      Cells that have learning basal segments or are selected to grow a basal
      segment
    """

        # Correctly predicted columns
        learningActiveBasalSegments = self.basalConnections.filterSegmentsByCell(
            activeBasalSegments, correctPredictedCells)

        cellsForMatchingBasal = self.basalConnections.mapSegmentsToCells(
            matchingBasalSegments)
        matchingCells = np.unique(cellsForMatchingBasal)

        # Split bursting columns into those containing a cell with a matching
        # segment (learn on its best segment) and those with none (grow a new
        # segment on the least-used cell).
        # NOTE(review): "/" relies on Python 2 integer division to map cells to
        # columns; under Python 3 this should be "//" -- confirm target
        # interpreter.
        (matchingCellsInBurstingColumns,
         burstingColumnsWithNoMatch) = np2.setCompare(matchingCells,
                                                      burstingColumns,
                                                      matchingCells /
                                                      self.cellsPerColumn,
                                                      rightMinusLeft=True)

        learningMatchingBasalSegments = self._chooseBestSegmentPerColumn(
            self.basalConnections, matchingCellsInBurstingColumns,
            matchingBasalSegments, basalPotentialOverlaps, self.cellsPerColumn)
        newBasalSegmentCells = self._getCellsWithFewestSegments(
            self.basalConnections, self.rng, burstingColumnsWithNoMatch,
            self.cellsPerColumn)

        # All cells that will do basal learning this timestep.
        learningCells = np.concatenate(
            (correctPredictedCells,
             self.basalConnections.mapSegmentsToCells(
                 learningMatchingBasalSegments), newBasalSegmentCells))

        # Incorrectly predicted columns: punish matching segments whose column
        # did not become active.
        correctMatchingBasalMask = np.in1d(
            cellsForMatchingBasal / self.cellsPerColumn, activeColumns)

        basalSegmentsToPunish = matchingBasalSegments[
            ~correctMatchingBasalMask]

        return (learningActiveBasalSegments, learningMatchingBasalSegments,
                basalSegmentsToPunish, newBasalSegmentCells, learningCells)

    def _calculateApicalLearning(self, learningCells, activeColumns,
                                 activeApicalSegments, matchingApicalSegments,
                                 apicalPotentialOverlaps):
        """
    Calculate apical learning for each learning cell.

    The set of learning cells was determined completely from basal segments.
    Do all apical learning on the same cells.

    Learn on any active segments on learning cells. For cells without active
    segments, learn on the best matching segment. For cells without a matching
    segment, grow a new segment.

    @param learningCells (numpy array)
    @param activeColumns (numpy array)
    @param activeApicalSegments (numpy array)
    @param matchingApicalSegments (numpy array)
    @param apicalPotentialOverlaps (numpy array)

    @return (tuple)
    - learningActiveApicalSegments (numpy array)
      Active apical segments on correct predicted cells

    - learningMatchingApicalSegments (numpy array)
      Matching apical segments selected for learning in bursting columns

    - apicalSegmentsToPunish (numpy array)
      Apical segments that should be punished for predicting an inactive column

    - newApicalSegmentCells (numpy array)
      Cells in bursting columns that were selected to grow new apical segments
    """

        # Cells with active apical segments
        learningActiveApicalSegments = self.apicalConnections.filterSegmentsByCell(
            activeApicalSegments, learningCells)

        # Cells with matching apical segments (but no active ones): learn on
        # their best matching segment.
        learningCellsWithoutActiveApical = np.setdiff1d(
            learningCells,
            self.apicalConnections.mapSegmentsToCells(
                learningActiveApicalSegments))
        cellsForMatchingApical = self.apicalConnections.mapSegmentsToCells(
            matchingApicalSegments)
        learningCellsWithMatchingApical = np.intersect1d(
            learningCellsWithoutActiveApical, cellsForMatchingApical)
        learningMatchingApicalSegments = self._chooseBestSegmentPerCell(
            self.apicalConnections, learningCellsWithMatchingApical,
            matchingApicalSegments, apicalPotentialOverlaps)

        # Cells that need to grow an apical segment
        newApicalSegmentCells = np.setdiff1d(learningCellsWithoutActiveApical,
                                             learningCellsWithMatchingApical)

        # Incorrectly predicted columns: punish matching segments whose column
        # did not become active.
        # NOTE(review): "/" relies on Python 2 integer division to map cells to
        # columns; under Python 3 this should be "//" -- confirm target
        # interpreter.
        correctMatchingApicalMask = np.in1d(
            cellsForMatchingApical / self.cellsPerColumn, activeColumns)

        apicalSegmentsToPunish = matchingApicalSegments[
            ~correctMatchingApicalMask]

        return (learningActiveApicalSegments, learningMatchingApicalSegments,
                apicalSegmentsToPunish, newApicalSegmentCells)

    @staticmethod
    def _calculateApicalSegmentActivity(connections, activeInput,
                                        connectedPermanence,
                                        activationThreshold, minThreshold):
        """
    Calculate the active and matching apical segments for this timestep.

    A segment is "active" when its active connected synapse count reaches
    activationThreshold, and "matching" when its active potential synapse
    count reaches minThreshold.

    @param connections (SparseMatrixConnections)
    @param activeInput (numpy array)

    @return (tuple)
    - activeSegments (numpy array)
      Dendrite segments with enough active connected synapses to cause a
      dendritic spike

    - matchingSegments (numpy array)
      Dendrite segments with enough active potential synapses to be selected
      for learning in a bursting column

    - potentialOverlaps (numpy array)
      The number of active potential synapses for each segment, including
      active, matching, and nonmatching segments.
    """
        connectedOverlaps = connections.computeActivity(activeInput,
                                                        connectedPermanence)
        activeSegments = np.flatnonzero(
            connectedOverlaps >= activationThreshold)

        potentialOverlaps = connections.computeActivity(activeInput)
        matchingSegments = np.flatnonzero(
            potentialOverlaps >= minThreshold)

        return (activeSegments, matchingSegments, potentialOverlaps)

    @staticmethod
    def _calculateBasalSegmentActivity(connections, activeInput,
                                       reducedBasalThresholdCells,
                                       connectedPermanence,
                                       activationThreshold, minThreshold,
                                       reducedBasalThreshold):
        """
    Calculate the active and matching basal segments for this timestep.

    The difference with _calculateApicalSegmentActivity is that cells
    with active apical segments (collected in reducedBasalThresholdCells) have
    a lower activation threshold for their basal segments (set by
    reducedBasalThreshold parameter).

    @param connections (SparseMatrixConnections)
    @param activeInput (numpy array)

    @param reducedBasalThresholdCells (numpy array or empty tuple)
    Cells whose basal segments activate at reducedBasalThreshold rather than
    activationThreshold

    @param connectedPermanence (float)
    @param activationThreshold (int)
    @param minThreshold (int)
    @param reducedBasalThreshold (int)

    @return (tuple)
    - activeSegments (numpy array)
      Dendrite segments with enough active connected synapses to cause a
      dendritic spike

    - matchingSegments (numpy array)
      Dendrite segments with enough active potential synapses to be selected for
      learning in a bursting column

    - potentialOverlaps (numpy array)
      The number of active potential synapses for each segment.
      Includes counts for active, matching, and nonmatching segments.
    """
        # Active apical segments lower the activation threshold for basal (lateral) segments
        overlaps = connections.computeActivity(activeInput,
                                               connectedPermanence)
        outrightActiveSegments = np.flatnonzero(
            overlaps >= activationThreshold)
        # Only do the extra reduced-threshold pass when it can change the
        # result (thresholds differ and some cell qualifies).
        if reducedBasalThreshold != activationThreshold and len(
                reducedBasalThresholdCells) > 0:
            # Segments in the "reduced" band: below the normal threshold but at
            # or above the reduced one.
            potentiallyActiveSegments = np.flatnonzero(
                (overlaps < activationThreshold)
                & (overlaps >= reducedBasalThreshold))
            cellsOfCASegments = connections.mapSegmentsToCells(
                potentiallyActiveSegments)
            # apically active segments are condit. active segments from apically active cells
            conditionallyActiveSegments = potentiallyActiveSegments[np.in1d(
                cellsOfCASegments, reducedBasalThresholdCells)]
            activeSegments = np.concatenate(
                (outrightActiveSegments, conditionallyActiveSegments))
        else:
            activeSegments = outrightActiveSegments

        # Matching
        potentialOverlaps = connections.computeActivity(activeInput)
        matchingSegments = np.flatnonzero(potentialOverlaps >= minThreshold)

        return (activeSegments, matchingSegments, potentialOverlaps)

    def _calculatePredictedCells(self, activeBasalSegments,
                                 activeApicalSegments):
        """
    Calculate the predicted cells, given the set of active segments.

    An active basal segment is enough to predict a cell.
    An active apical segment is *not* enough to predict a cell.

    When a cell has both types of segments active, other cells in its minicolumn
    must also have both types of segments to be considered predictive.

    @param activeBasalSegments (numpy array)
    @param activeApicalSegments (numpy array)

    @return (numpy array)
    """
        cellsForBasalSegments = self.basalConnections.mapSegmentsToCells(
            activeBasalSegments)

        # Without the apical tiebreak, basal support alone predicts a cell.
        if not self.useApicalTiebreak:
            return cellsForBasalSegments

        cellsForApicalSegments = self.apicalConnections.mapSegmentsToCells(
            activeApicalSegments)

        fullyDepolarizedCells = np.intersect1d(cellsForBasalSegments,
                                               cellsForApicalSegments)
        partlyDepolarizedCells = np.setdiff1d(cellsForBasalSegments,
                                              fullyDepolarizedCells)

        # Use floor division to map cells to minicolumn indices. Python 3 true
        # division ('/') yields floats, so cells in the same column would not
        # compare equal and the inhibition mask would be wrong.
        inhibitedMask = np.in1d(partlyDepolarizedCells // self.cellsPerColumn,
                                fullyDepolarizedCells // self.cellsPerColumn)
        return np.append(fullyDepolarizedCells,
                         partlyDepolarizedCells[~inhibitedMask])

    @staticmethod
    def _learn(connections, rng, learningSegments, activeInput,
               growthCandidates, potentialOverlaps, initialPermanence,
               sampleSize, permanenceIncrement, permanenceDecrement,
               maxSynapsesPerSegment):
        """
    Adjust synapse permanences and grow new synapses on existing segments.

    @param learningSegments (numpy array)
    @param activeInput (numpy array)
    @param growthCandidates (numpy array)
    @param potentialOverlaps (numpy array)
    """
        # Reinforce synapses to active input, punish synapses to inactive input.
        connections.adjustSynapses(learningSegments, activeInput,
                                   permanenceIncrement, -permanenceDecrement)

        # "maxNew" is the growth budget per segment: either a single number
        # (no sampling) or a per-segment array.
        if sampleSize == -1:
            maxNew = len(growthCandidates)
        else:
            maxNew = sampleSize - potentialOverlaps[learningSegments]

        if maxSynapsesPerSegment != -1:
            # Cap growth so no segment exceeds maxSynapsesPerSegment.
            roomLeft = (maxSynapsesPerSegment -
                        connections.mapSegmentsToSynapseCounts(
                            learningSegments))
            maxNew = np.where(maxNew <= roomLeft, maxNew, roomLeft)

        connections.growSynapsesToSample(learningSegments, growthCandidates,
                                         maxNew, initialPermanence, rng)

    @staticmethod
    def _learnOnNewSegments(connections, rng, newSegmentCells,
                            growthCandidates, initialPermanence, sampleSize,
                            maxSynapsesPerSegment):
        """
    Grow a new segment on each specified cell and grow synapses on it.

    The synapse count is the candidate count, capped by sampleSize and
    maxSynapsesPerSegment (a value of -1 disables the corresponding cap).
    """
        caps = [len(growthCandidates)]
        if sampleSize != -1:
            caps.append(sampleSize)
        if maxSynapsesPerSegment != -1:
            caps.append(maxSynapsesPerSegment)
        numNewSynapses = min(caps)

        newSegments = connections.createSegments(newSegmentCells)
        connections.growSynapsesToSample(newSegments, growthCandidates,
                                         numNewSynapses, initialPermanence,
                                         rng)

    @classmethod
    def _chooseBestSegmentPerCell(cls, connections, cells, allMatchingSegments,
                                  potentialOverlaps):
        """
    For each specified cell, choose its matching segment with largest number
    of active potential synapses. When there's a tie, the first segment wins.

    @param connections (SparseMatrixConnections)
    @param cells (numpy array)
    @param allMatchingSegments (numpy array)
    @param potentialOverlaps (numpy array)

    @return (numpy array)
    One segment per cell
    """
        # Restrict to segments belonging to the requested cells.
        segmentsForCells = connections.filterSegmentsByCell(
            allMatchingSegments, cells)

        # For each cell, keep the segment with the highest potential overlap;
        # ties go to the first segment.
        bestPerCell = np2.argmaxMulti(
            potentialOverlaps[segmentsForCells],
            connections.mapSegmentsToCells(segmentsForCells))

        return segmentsForCells[bestPerCell]

    @classmethod
    def _chooseBestSegmentPerColumn(cls, connections, matchingCells,
                                    allMatchingSegments, potentialOverlaps,
                                    cellsPerColumn):
        """
    For all the columns covered by 'matchingCells', choose the column's matching
    segment with largest number of active potential synapses. When there's a
    tie, the first segment wins.

    @param connections (SparseMatrixConnections)
    @param matchingCells (numpy array)
    @param allMatchingSegments (numpy array)
    @param potentialOverlaps (numpy array)
    @param cellsPerColumn (int)

    @return (numpy array)
    One segment per column
    """

        candidateSegments = connections.filterSegmentsByCell(
            allMatchingSegments, matchingCells)

        # Narrow it down to one segment per column. Use floor division so each
        # cell maps to an integer column index; Python 3 true division ('/')
        # would produce fractional "columns" and argmaxMulti would group each
        # cell by itself.
        cellScores = potentialOverlaps[candidateSegments]
        columnsForCandidates = (
            connections.mapSegmentsToCells(candidateSegments) // cellsPerColumn)
        onePerColumnFilter = np2.argmaxMulti(cellScores, columnsForCandidates)

        learningSegments = candidateSegments[onePerColumnFilter]

        return learningSegments

    @classmethod
    def _getCellsWithFewestSegments(cls, connections, rng, columns,
                                    cellsPerColumn):
        """
    For each column, get the cell that has the fewest total basal segments.
    Break ties randomly.

    @param connections (SparseMatrixConnections)
    @param rng (Random)
    @param columns (numpy array) Columns to check
    @param cellsPerColumn (int)

    @return (numpy array)
    One cell for each of the provided columns
    """
        candidateCells = np2.getAllCellsInColumns(columns, cellsPerColumn)

        # Arrange the segment counts into one row per minicolumn.
        segmentCounts = np.reshape(
            connections.getSegmentCounts(candidateCells),
            newshape=(len(columns), cellsPerColumn))

        # Filter to just the cells that are tied for fewest in their minicolumn.
        minSegmentCounts = np.amin(segmentCounts, axis=1, keepdims=True)
        candidateCells = candidateCells[np.flatnonzero(
            segmentCounts == minSegmentCounts)]

        # Filter to one cell per column, choosing randomly from the minimums.
        # To do the random choice, add a random offset to each index in-place,
        # using casting to floor the result. Floor division is required here:
        # Python 3 true division ('/') yields floats, making np.unique see
        # every cell as its own "column".
        (_, onePerColumnFilter,
         numCandidatesInColumns) = np.unique(candidateCells // cellsPerColumn,
                                             return_index=True,
                                             return_counts=True)

        offsetPercents = np.empty(len(columns), dtype="float32")
        rng.initializeReal32Array(offsetPercents)

        np.add(onePerColumnFilter,
               offsetPercents * numCandidatesInColumns,
               out=onePerColumnFilter,
               casting="unsafe")

        return candidateCells[onePerColumnFilter]

    def getActiveCells(self):
        """
    @return (numpy array)
    Cells that are currently active
    """
        return self.activeCells

    def getPredictedActiveCells(self):
        """
    @return (numpy array)
    Active cells that had been predicted beforehand
    """
        return self.predictedActiveCells

    def getWinnerCells(self):
        """
    @return (numpy array)
    Cells chosen as learning targets this timestep
    """
        return self.winnerCells

    def getActiveBasalSegments(self):
        """
    @return (numpy array)
    Active basal segments for this timestep
    """
        return self.activeBasalSegments

    def getActiveApicalSegments(self):
        """
    @return (numpy array)
    Active apical segments for this timestep
    """
        return self.activeApicalSegments

    def numberOfColumns(self):
        """ Returns the number of columns in this layer.

    @return (int) Number of columns
    """
        return self.columnCount

    def numberOfCells(self):
        """
    Returns the total number of cells in this layer.

    @return (int) Number of cells
    """
        return self.cellsPerColumn * self.numberOfColumns()

    def getCellsPerColumn(self):
        """
    Returns how many cells each minicolumn contains.

    @return (int) The number of cells per column.
    """
        return self.cellsPerColumn

    def getActivationThreshold(self):
        """
    Returns the current activation threshold.
    @return (int) The activation threshold.
    """
        return self.activationThreshold

    def setActivationThreshold(self, activationThreshold):
        """
    Overwrites the activation threshold.
    @param activationThreshold (int) The new activation threshold.
    """
        self.activationThreshold = activationThreshold

    def getReducedBasalThreshold(self):
        """
    Returns the reduced basal activation threshold used for apically active
    cells.
    @return (int) The reduced activation threshold.
    """
        return self.reducedBasalThreshold

    def setReducedBasalThreshold(self, reducedBasalThreshold):
        """
    Overwrites the reduced basal activation threshold used for apically active
    cells.
    @param reducedBasalThreshold (int) The new reduced activation threshold.
    """
        self.reducedBasalThreshold = reducedBasalThreshold

    def getInitialPermanence(self):
        """
    Returns the permanence assigned to newly grown synapses.
    @return (float) The initial permanence.
    """
        return self.initialPermanence

    def setInitialPermanence(self, initialPermanence):
        """
    Overwrites the permanence assigned to newly grown synapses.
    @param initialPermanence (float) The new initial permanence.
    """
        self.initialPermanence = initialPermanence

    def getMinThreshold(self):
        """
    Returns the minimum threshold for matching segments.
    @return (int) The min threshold.
    """
        return self.minThreshold

    def setMinThreshold(self, minThreshold):
        """
    Overwrites the minimum threshold for matching segments.
    @param minThreshold (int) The new min threshold.
    """
        self.minThreshold = minThreshold

    def getSampleSize(self):
        """
    Returns the synapse sample size.
    @return (int)
    """
        return self.sampleSize

    def setSampleSize(self, sampleSize):
        """
    Overwrites the synapse sample size.
    @param sampleSize (int)
    """
        self.sampleSize = sampleSize

    def getPermanenceIncrement(self):
        """
    Returns the amount by which active synapses are reinforced.
    @return (float) The permanence increment.
    """
        return self.permanenceIncrement

    def setPermanenceIncrement(self, permanenceIncrement):
        """
    Overwrites the amount by which active synapses are reinforced.
    @param permanenceIncrement (float) The new permanence increment.
    """
        self.permanenceIncrement = permanenceIncrement

    def getPermanenceDecrement(self):
        """
    Returns the amount by which inactive synapses are weakened.
    @return (float) The permanence decrement.
    """
        return self.permanenceDecrement

    def setPermanenceDecrement(self, permanenceDecrement):
        """
    Overwrites the amount by which inactive synapses are weakened.
    @param permanenceDecrement (float) The new permanence decrement.
    """
        self.permanenceDecrement = permanenceDecrement

    def getBasalPredictedSegmentDecrement(self):
        """
    Get the basal predicted segment decrement.
    @return (float) The predicted segment decrement.
    """
        return self.basalPredictedSegmentDecrement

    def setBasalPredictedSegmentDecrement(self, predictedSegmentDecrement):
        """
    Sets the basal predicted segment decrement.
    @param predictedSegmentDecrement (float) The predicted segment decrement.
    """
        # Bug fix: previously assigned the undefined name
        # 'basalPredictedSegmentDecrement', raising NameError on every call.
        self.basalPredictedSegmentDecrement = predictedSegmentDecrement

    def getApicalPredictedSegmentDecrement(self):
        """
    Get the apical predicted segment decrement.
    @return (float) The predicted segment decrement.
    """
        return self.apicalPredictedSegmentDecrement

    def setApicalPredictedSegmentDecrement(self, predictedSegmentDecrement):
        """
    Sets the apical predicted segment decrement.
    @param predictedSegmentDecrement (float) The predicted segment decrement.
    """
        # Bug fix: previously assigned the undefined name
        # 'apicalPredictedSegmentDecrement', raising NameError on every call.
        self.apicalPredictedSegmentDecrement = predictedSegmentDecrement

    def getConnectedPermanence(self):
        """
    Returns the permanence at which a synapse counts as connected.
    @return (float) The connected permanence.
    """
        return self.connectedPermanence

    def setConnectedPermanence(self, connectedPermanence):
        """
    Overwrites the permanence at which a synapse counts as connected.
    @param connectedPermanence (float) The new connected permanence.
    """
        self.connectedPermanence = connectedPermanence

    def getUseApicalTieBreak(self):
        """
    Returns whether the apical tie-break is in effect.
    @return (Bool) Whether apical tie-break is used.
    """
        return self.useApicalTiebreak

    def setUseApicalTiebreak(self, useApicalTiebreak):
        """
    Enables or disables the apical tie-break.
    @param useApicalTiebreak (Bool) Whether apical tie-break is used.
    """
        self.useApicalTiebreak = useApicalTiebreak

    def getUseApicalModulationBasalThreshold(self):
        """
    Returns whether apical modulation of the basal threshold is in effect.
    @return (Bool) Whether apical modulation is used.
    """
        return self.useApicalModulationBasalThreshold

    def setUseApicalModulationBasalThreshold(
            self, useApicalModulationBasalThreshold):
        """
    Enables or disables apical modulation of the basal threshold.
    @param useApicalModulationBasalThreshold (Bool) Whether apical modulation is used.
    """
        self.useApicalModulationBasalThreshold = useApicalModulationBasalThreshold
class SuperficialLocationModule2D(object):
    """
  A model of a location module. It's similar to a grid cell module, but it uses
  squares rather than triangles.

  The cells are arranged into a m*n rectangle which is tiled onto 2D space.
  Each cell represents a small rectangle in each tile.

  +------+------+------++------+------+------+
  | Cell | Cell | Cell || Cell | Cell | Cell |
  |  #1  |  #2  |  #3  ||  #1  |  #2  |  #3  |
  |      |      |      ||      |      |      |
  +--------------------++--------------------+
  | Cell | Cell | Cell || Cell | Cell | Cell |
  |  #4  |  #5  |  #6  ||  #4  |  #5  |  #6  |
  |      |      |      ||      |      |      |
  +--------------------++--------------------+
  | Cell | Cell | Cell || Cell | Cell | Cell |
  |  #7  |  #8  |  #9  ||  #7  |  #8  |  #9  |
  |      |      |      ||      |      |      |
  +------+------+------++------+------+------+

  We assume that path integration works *somehow*. This model receives a "delta
  location" vector, and it shifts the active cells accordingly. The model stores
  intermediate coordinates of active cells. Whenever sensory cues activate a
  cell, the model adds this cell to the list of coordinates being shifted.
  Whenever sensory cues cause a cell to become inactive, that cell is removed
  from the list of coordinates.

  (This model doesn't attempt to propose how "path integration" works. It
  attempts to show how locations are anchored to sensory cues.)

  When orientation is set to 0 degrees, the deltaLocation is a [di, dj],
  moving di cells "down" and dj cells "right".

  When orientation is set to 90 degrees, the deltaLocation is essentially a
  [dx, dy], applied in typical x,y coordinates with the origin on the bottom
  left.

  Usage:

  Adjust the location in response to motor input:
    lm.shift([di, dj])

  Adjust the location in response to sensory input:
    lm.anchor(anchorInput)

  Learn an anchor input for the current location:
    lm.learn(anchorInput)

  The "anchor input" is typically a feature-location pair SDR.

  During inference, you'll typically call:
    lm.shift(...)
    # Consume lm.getActiveCells()
    # ...
    lm.anchor(...)

  During learning, you'll do the same, but you'll call lm.learn() instead of
  lm.anchor().
  """
    def __init__(self,
                 cellDimensions,
                 moduleMapDimensions,
                 orientation,
                 anchorInputSize,
                 pointOffsets=(0.5, ),
                 activationThreshold=10,
                 initialPermanence=0.21,
                 connectedPermanence=0.50,
                 learningThreshold=10,
                 sampleSize=20,
                 permanenceIncrement=0.1,
                 permanenceDecrement=0.0,
                 maxSynapsesPerSegment=-1,
                 seed=42):
        """
    @param cellDimensions (tuple(int, int))
    Determines the number of cells. Determines how space is divided between the
    cells.

    @param moduleMapDimensions (tuple(float, float))
    Determines the amount of world space covered by all of the cells combined.
    In grid cell terminology, this is equivalent to the "scale" of a module.
    A module with a scale of "5cm" would have moduleMapDimensions=(5.0, 5.0).

    @param orientation (float)
    The rotation of this map, measured in radians.

    @param anchorInputSize (int)
    The number of input bits in the anchor input.

    @param pointOffsets (list of floats)
    These must each be between 0.0 and 1.0. Every time a cell is activated by
    anchor input, this class adds a "point" which is shifted in subsequent
    motions. By default, this point is placed at the center of the cell. This
    parameter allows you to control where the point is placed and whether multiple
    are placed. For example, With value [0.2, 0.8], it will place 4 points:
    [0.2, 0.2], [0.2, 0.8], [0.8, 0.2], [0.8, 0.8]

    @param activationThreshold (int)
    Minimum number of active connected synapses for a segment to activate its
    cell during anchor().

    @param learningThreshold (int)
    Minimum number of active potential synapses for a segment to be considered
    "matching" during learn().

    The remaining parameters (initialPermanence, connectedPermanence,
    sampleSize, permanenceIncrement, permanenceDecrement,
    maxSynapsesPerSegment, seed) configure synapse learning on the anchor
    connections.
    """

        self.cellDimensions = np.asarray(cellDimensions, dtype="int")
        self.moduleMapDimensions = np.asarray(moduleMapDimensions,
                                              dtype="float")
        # How many cell fields fit in one unit of world distance, per axis.
        self.cellFieldsPerUnitDistance = self.cellDimensions / self.moduleMapDimensions

        self.orientation = orientation
        # Standard 2D rotation matrix; converts world-space deltas into this
        # module's rotated coordinate frame.
        self.rotationMatrix = np.array(
            [[math.cos(orientation), -math.sin(orientation)],
             [math.sin(orientation),
              math.cos(orientation)]])

        self.pointOffsets = pointOffsets

        # These coordinates are in units of "cell fields".
        self.activePoints = np.empty((0, 2), dtype="float")
        self.cellsForActivePoints = np.empty(0, dtype="int")

        self.activeCells = np.empty(0, dtype="int")
        self.activeSegments = np.empty(0, dtype="uint32")

        # Learned associations between anchor inputs and cells.
        self.connections = SparseMatrixConnections(np.prod(cellDimensions),
                                                   anchorInputSize)

        self.initialPermanence = initialPermanence
        self.connectedPermanence = connectedPermanence
        self.learningThreshold = learningThreshold
        self.sampleSize = sampleSize
        self.permanenceIncrement = permanenceIncrement
        self.permanenceDecrement = permanenceDecrement
        self.activationThreshold = activationThreshold
        self.maxSynapsesPerSegment = maxSynapsesPerSegment

        self.rng = Random(seed)

    def reset(self):
        """
    Clear the active cells.
    """
        self.activePoints = np.empty((0, 2), dtype="float")
        self.cellsForActivePoints = np.empty(0, dtype="int")
        self.activeCells = np.empty(0, dtype="int")

    def _computeActiveCells(self):
        # Derive the active cell set from the stored point coordinates.
        # Round each coordinate to the nearest cell.
        flooredActivePoints = np.floor(self.activePoints).astype("int")

        # Convert coordinates to cell numbers.
        self.cellsForActivePoints = (np.ravel_multi_index(
            flooredActivePoints.T, self.cellDimensions))
        self.activeCells = np.unique(self.cellsForActivePoints)

    def activateRandomLocation(self):
        """
    Set the location to a random point.
    """
        # NOTE(review): uses the global np.random rather than self.rng, so this
        # is not controlled by the 'seed' parameter — confirm whether seeded
        # determinism is required here.
        self.activePoints = np.array(
            [np.random.random(2) * self.cellDimensions])
        self._computeActiveCells()

    def shift(self, deltaLocation):
        """
    Shift the current active cells by a vector.

    @param deltaLocation (pair of floats)
    A translation vector [di, dj].
    """
        # Calculate delta in the module's coordinates.
        deltaLocationInCellFields = (
            np.matmul(self.rotationMatrix, deltaLocation) *
            self.cellFieldsPerUnitDistance)

        # Shift the active coordinates.
        np.add(self.activePoints,
               deltaLocationInCellFields,
               out=self.activePoints)
        # Wrap around the module's dimensions, since the map tiles 2D space.
        np.mod(self.activePoints, self.cellDimensions, out=self.activePoints)

        self._computeActiveCells()

    def anchor(self, anchorInput):
        """
    Infer the location from sensory input. Activate any cells with enough active
    synapses to this sensory input. Deactivate all other cells.

    @param anchorInput (numpy array)
    A sensory input. This will often come from a feature-location pair layer.
    """
        # An empty sensory input provides no evidence; leave the state as-is.
        if len(anchorInput) == 0:
            return

        # Segments with enough connected synapses to the anchor input.
        overlaps = self.connections.computeActivity(anchorInput,
                                                    self.connectedPermanence)
        activeSegments = np.where(overlaps >= self.activationThreshold)[0]

        sensorySupportedCells = np.unique(
            self.connections.mapSegmentsToCells(activeSegments))

        # Remove the points of cells that lost sensory support.
        inactivated = np.setdiff1d(self.activeCells, sensorySupportedCells)
        inactivatedIndices = np.in1d(self.cellsForActivePoints,
                                     inactivated).nonzero()[0]
        if inactivatedIndices.size > 0:
            self.activePoints = np.delete(self.activePoints,
                                          inactivatedIndices,
                                          axis=0)

        # Add points for newly supported cells, one point per offset pair.
        activated = np.setdiff1d(sensorySupportedCells, self.activeCells)

        activatedCoordsBase = np.transpose(
            np.unravel_index(activated, self.cellDimensions)).astype('float')

        activatedCoords = np.concatenate([
            activatedCoordsBase + [iOffset, jOffset]
            for iOffset in self.pointOffsets for jOffset in self.pointOffsets
        ])
        if activatedCoords.size > 0:
            self.activePoints = np.append(self.activePoints,
                                          activatedCoords,
                                          axis=0)

        self._computeActiveCells()
        self.activeSegments = activeSegments

    def learn(self, anchorInput):
        """
    Associate this location with a sensory input. Subsequently, anchorInput will
    activate the current location during anchor().

    @param anchorInput (numpy array)
    A sensory input. This will often come from a feature-location pair layer.
    """
        overlaps = self.connections.computeActivity(anchorInput,
                                                    self.connectedPermanence)
        activeSegments = np.where(overlaps >= self.activationThreshold)[0]

        potentialOverlaps = self.connections.computeActivity(anchorInput)
        matchingSegments = np.where(
            potentialOverlaps >= self.learningThreshold)[0]

        # Cells with a active segment: reinforce the segment
        cellsForActiveSegments = self.connections.mapSegmentsToCells(
            activeSegments)
        learningActiveSegments = activeSegments[np.in1d(
            cellsForActiveSegments, self.activeCells)]
        remainingCells = np.setdiff1d(self.activeCells, cellsForActiveSegments)

        # Remaining cells with a matching segment: reinforce the best
        # matching segment.
        candidateSegments = self.connections.filterSegmentsByCell(
            matchingSegments, remainingCells)
        cellsForCandidateSegments = (
            self.connections.mapSegmentsToCells(candidateSegments))
        candidateSegments = candidateSegments[np.in1d(
            cellsForCandidateSegments, remainingCells)]
        # One segment per cell: the one with the largest potential overlap.
        onePerCellFilter = np2.argmaxMulti(
            potentialOverlaps[candidateSegments], cellsForCandidateSegments)
        learningMatchingSegments = candidateSegments[onePerCellFilter]

        newSegmentCells = np.setdiff1d(remainingCells,
                                       cellsForCandidateSegments)

        for learningSegments in (learningActiveSegments,
                                 learningMatchingSegments):
            self._learn(self.connections, self.rng, learningSegments,
                        anchorInput, potentialOverlaps, self.initialPermanence,
                        self.sampleSize, self.permanenceIncrement,
                        self.permanenceDecrement, self.maxSynapsesPerSegment)

        # Remaining cells without a matching segment: grow one.
        numNewSynapses = len(anchorInput)

        if self.sampleSize != -1:
            numNewSynapses = min(numNewSynapses, self.sampleSize)

        if self.maxSynapsesPerSegment != -1:
            numNewSynapses = min(numNewSynapses, self.maxSynapsesPerSegment)

        newSegments = self.connections.createSegments(newSegmentCells)

        self.connections.growSynapsesToSample(newSegments, anchorInput,
                                              numNewSynapses,
                                              self.initialPermanence, self.rng)

        self.activeSegments = activeSegments

    @staticmethod
    def _learn(connections, rng, learningSegments, activeInput,
               potentialOverlaps, initialPermanence, sampleSize,
               permanenceIncrement, permanenceDecrement,
               maxSynapsesPerSegment):
        """
    Adjust synapse permanences, grow new synapses, and grow new segments.

    @param learningSegments (numpy array)
    @param activeInput (numpy array)
    @param potentialOverlaps (numpy array)
    """
        # Learn on existing segments
        connections.adjustSynapses(learningSegments, activeInput,
                                   permanenceIncrement, -permanenceDecrement)

        # Grow new synapses. Calculate "maxNew", the maximum number of synapses to
        # grow per segment. "maxNew" might be a number or it might be a list of
        # numbers.
        if sampleSize == -1:
            maxNew = len(activeInput)
        else:
            maxNew = sampleSize - potentialOverlaps[learningSegments]

        if maxSynapsesPerSegment != -1:
            # Never grow a segment past maxSynapsesPerSegment.
            synapseCounts = connections.mapSegmentsToSynapseCounts(
                learningSegments)
            numSynapsesToReachMax = maxSynapsesPerSegment - synapseCounts
            maxNew = np.where(maxNew <= numSynapsesToReachMax, maxNew,
                              numSynapsesToReachMax)

        connections.growSynapsesToSample(learningSegments, activeInput, maxNew,
                                         initialPermanence, rng)

    def getActiveCells(self):
        """
    @return (numpy array)
    Active cells
    """
        return self.activeCells

    def numberOfCells(self):
        """
    @return (int)
    Total number of cells in this module
    """
        return np.prod(self.cellDimensions)
Beispiel #12
0
class Thalamus(object):
  """

  A simple discrete time thalamus.  This thalamus has a 2D TRN layer and a 2D
  relay cell layer. L6 cells project to the dendrites of TRN cells - these
  connections are learned. TRN cells project to the dendrites of relay cells in
  a fixed fan-out pattern. A 2D feed forward input source projects to the relay
  cells in a fixed fan-out pattern.

  The output of the thalamus is the activity of each relay cell. This activity
  can be in one of three states: inactive, active (tonic), and active (burst).

  TRN cells control whether the relay cells will burst. If any dendrite on a TRN
  cell recognizes the current L6 pattern, it de-inactivates the T-type CA2+
  channels on the dendrites of any relay cell it projects to. These relay cells
  are then in "burst-ready mode".

  Feed forward activity is in the form of a binary vector corresponding to
  active/spiking axons (e.g. from ganglion cells). Any relay cells that receive
  input from an axon will either output tonic or burst activity depending on the
  state of the T-type CA2+ channels on their dendrites. Relay cells that don't
  receive input will remain inactive, regardless of their dendritic state.

  Usage:

    1. Train the TRN cells on a bunch of L6 patterns: learnL6Pattern()

    2. De-inactivate relay cells by sending in an L6 pattern: deInactivateCells()

    3. Compute feed forward activity for an input: computeFeedForwardActivity()

    4. reset()

    5. Goto 2

  """

  def __init__(self,
               trnCellShape=(32, 32),
               relayCellShape=(32, 32),
               inputShape=(32, 32),
               l6CellCount=1024,
               trnThreshold=10,
               relayThreshold=1,
               seed=42):
    """

    :param trnCellShape:
      a 2D shape for the TRN

    :param relayCellShape:
      a 2D shape for the relay cells

    :param inputShape:
      a 2D shape for the feed forward input

    :param l6CellCount:
      number of L6 cells

    :param trnThreshold:
      dendritic threshold for TRN cells. This is the min number of active L6
      cells on a dendrite for the TRN cell to recognize a pattern on that
      dendrite.

    :param relayThreshold:
      dendritic threshold for relay cells. This is the min number of active TRN
      cells on a dendrite for the relay cell to recognize a pattern on that
      dendrite.

    :param seed:
        Seed for the random number generator.
    """

    self.trnCellShape = trnCellShape
    self.trnWidth = trnCellShape[0]
    self.trnHeight = trnCellShape[1]
    self.relayCellShape = relayCellShape
    self.relayWidth = relayCellShape[0]
    self.relayHeight = relayCellShape[1]
    self.l6CellCount = l6CellCount
    self.trnThreshold = trnThreshold
    self.relayThreshold = relayThreshold
    self.inputShape = inputShape
    self.seed = seed
    self.rng = Random(seed)

    # NOTE(review): deInactivateCells() uses this hardcoded value rather than
    # the trnThreshold constructor argument, which is stored but never read --
    # confirm which threshold is intended.
    self.trnActivationThreshold = 5

    # Learned dendritic connections: L6 cells -> TRN cell segments.
    self.trnConnections = SparseMatrixConnections(
      trnCellShape[0]*trnCellShape[1], l6CellCount)

    # Fixed fan-out dendritic connections: TRN cells -> relay cell segments.
    self.relayConnections = SparseMatrixConnections(
      relayCellShape[0]*relayCellShape[1],
      trnCellShape[0]*trnCellShape[1])

    # Initialize/reset variables that are updated with calls to compute
    self.reset()

    self._initializeTRNToRelayCellConnections()


  def learnL6Pattern(self, l6Pattern, cellsToLearnOn):
    """
    Learn the given l6Pattern on TRN cell dendrites. The TRN cells to learn
    are given in cellsToLearnOn. Each of these cells will learn this pattern on
    a single dendritic segment.

    :param l6Pattern:
      An SDR from L6. List of indices corresponding to L6 cells.

    :param cellsToLearnOn:
      Each cell index is (x,y) corresponding to the TRN cells that should learn
      this pattern. For each cell, create a new dendrite that stores this
      pattern. The SDR is stored on this dendrite
    """
    cellIndices = [self.trnCellIndex(x) for x in cellsToLearnOn]
    newSegments = self.trnConnections.createSegments(cellIndices)
    # Store the pattern with fully connected synapses (permanence 1.0).
    self.trnConnections.growSynapses(newSegments, l6Pattern, 1.0)


  def deInactivateCells(self, l6Input):
    """
    Activate TRN cells according to the l6Input. These in turn de-inactivate
    the dendrites of the relay cells they project to, putting those relay
    cells into "burst-ready mode" (recorded in self.burstReadyCells).

    :param l6Input:
      An SDR from L6. List of indices corresponding to active L6 cells.

    :return: nothing
    """

    # Figure out which TRN cells recognize the L6 pattern.
    self.trnOverlaps = self.trnConnections.computeActivity(l6Input, 0.5)
    self.activeTRNSegments = np.flatnonzero(
      self.trnOverlaps >= self.trnActivationThreshold)
    self.activeTRNCellIndices = self.trnConnections.mapSegmentsToCells(
      self.activeTRNSegments)

    # Figure out which relay cells have dendrites in de-inactivated state
    self.relayOverlaps = self.relayConnections.computeActivity(
      self.activeTRNCellIndices, 0.5
    )
    self.activeRelaySegments = np.flatnonzero(
      self.relayOverlaps >= self.relayThreshold)
    self.burstReadyCellIndices = self.relayConnections.mapSegmentsToCells(
      self.activeRelaySegments)

    # reshape(-1) is a view onto the contiguous grid, so this marks the
    # burst-ready cells in place.
    self.burstReadyCells.reshape(-1)[self.burstReadyCellIndices] = 1


  def computeFeedForwardActivity(self, feedForwardInput):
    """
    Given the feedForwardInput, compute which cells will be silent, tonic,
    or bursting.

    :param feedForwardInput:
      a numpy matrix of shape relayCellShape containing 0's and 1's

    :return:
      feedForwardInput is modified in place to contain 0, 1, or 2. A "2"
      indicates bursting cells.
    """
    # A cell bursts (2) only if it receives input (1) AND is burst-ready;
    # cells without input stay 0 regardless of their dendritic state.
    feedForwardInput += self.burstReadyCells * feedForwardInput


  def reset(self):
    """
    Set everything back to zero
    """
    self.trnOverlaps = []
    self.activeTRNSegments = []
    self.activeTRNCellIndices = []
    self.relayOverlaps = []
    self.activeRelaySegments = []
    self.burstReadyCellIndices = []
    self.burstReadyCells = np.zeros((self.relayWidth, self.relayHeight))


  def trnCellIndex(self, coord):
    """
    Map a 2D TRN coordinate to a 1D cell index (row-major on y).

    :param coord: a 2D coordinate (x, y)

    :return: integer index
    """
    return coord[1] * self.trnWidth + coord[0]


  def trnIndextoCoord(self, i):
    """
    Map 1D cell index to a 2D coordinate. Inverse of trnCellIndex().

    :param i: integer 1D cell index

    :return: (x, y), a 2D coordinate
    """
    x = i % self.trnWidth
    # Floor division: true division would make y a float under Python 3 and
    # break the round trip with trnCellIndex().
    y = i // self.trnWidth
    return x, y


  def relayCellIndex(self, coord):
    """
    Map a 2D relay cell coordinate to a 1D cell index (row-major on y).

    :param coord: a 2D coordinate (x, y)

    :return: integer index
    """
    return coord[1] * self.relayWidth + coord[0]


  def relayIndextoCoord(self, i):
    """
    Map 1D cell index to a 2D coordinate. Inverse of relayCellIndex().

    :param i: integer 1D cell index

    :return: (x, y), a 2D coordinate
    """
    x = i % self.relayWidth
    # Floor division, for the same reason as trnIndextoCoord().
    y = i // self.relayWidth
    return x, y


  def _initializeTRNToRelayCellConnections(self):
    """
    Initialize TRN to relay cell connectivity. For each relay cell, create a
    dendritic segment for each TRN cell it connects to.
    """
    # NOTE(review): _preSynapticTRNCells clamps against the TRN shape while we
    # iterate the relay shape, so this assumes the two grids have the same
    # dimensions -- confirm.
    for x in range(self.relayWidth):
      for y in range(self.relayHeight):

        # Create one dendrite for each trn cell that projects to this relay cell
        # This dendrite contains one synapse corresponding to this TRN->relay
        # connection.
        relayCellIndex = self.relayCellIndex((x,y))
        trnCells = self._preSynapticTRNCells(x, y)
        for trnCell in trnCells:
          newSegment = self.relayConnections.createSegments([relayCellIndex])
          self.relayConnections.growSynapses(newSegment,
                                             [self.trnCellIndex(trnCell)], 1.0)


  def _preSynapticTRNCells(self, i, j):
    """
    Given a relay cell at the given coordinate, return a list of the (x,y)
    coordinates of all TRN cells that project to it: its 3x3 neighborhood,
    clipped at the edges of the TRN grid.

    :param i: x coordinate of the relay cell
    :param j: y coordinate of the relay cell

    :return: list of (x, y) TRN coordinates
    """
    xmin = max(i - 1, 0)
    xmax = min(i + 2, self.trnWidth)
    ymin = max(j - 1, 0)
    ymax = min(j + 2, self.trnHeight)
    trnCells = [
      (x, y) for x in range(xmin, xmax) for y in range(ymin, ymax)
    ]

    return trnCells
class SuperficialLocationModule2D(object):
  """
  A model of a location module. It's similar to a grid cell module, but it uses
  squares rather than triangles.

  The cells are arranged into a m*n rectangle which is tiled onto 2D space.
  Each cell represents a small rectangle in each tile.

  +------+------+------++------+------+------+
  | Cell | Cell | Cell || Cell | Cell | Cell |
  |  #1  |  #2  |  #3  ||  #1  |  #2  |  #3  |
  |      |      |      ||      |      |      |
  +--------------------++--------------------+
  | Cell | Cell | Cell || Cell | Cell | Cell |
  |  #4  |  #5  |  #6  ||  #4  |  #5  |  #6  |
  |      |      |      ||      |      |      |
  +--------------------++--------------------+
  | Cell | Cell | Cell || Cell | Cell | Cell |
  |  #7  |  #8  |  #9  ||  #7  |  #8  |  #9  |
  |      |      |      ||      |      |      |
  +------+------+------++------+------+------+

  We assume that path integration works *somehow*. This model receives a "delta
  location" vector, and it shifts the active cells accordingly. The model stores
  intermediate coordinates of active cells. Whenever sensory cues activate a
  cell, the model adds this cell to the list of coordinates being shifted.
  Whenever sensory cues cause a cell to become inactive, that cell is removed
  from the list of coordinates.

  (This model doesn't attempt to propose how "path integration" works. It
  attempts to show how locations are anchored to sensory cues.)

  When orientation is set to 0 degrees, the deltaLocation is a [di, dj],
  moving di cells "down" and dj cells "right".

  When orientation is set to 90 degrees, the deltaLocation is essentially a
  [dx, dy], applied in typical x,y coordinates with the origin on the bottom
  left.

  Usage:

  Adjust the location in response to motor input:
    lm.shift([di, dj])

  Adjust the location in response to sensory input:
    lm.anchor(anchorInput)

  Learn an anchor input for the current location:
    lm.learn(anchorInput)

  The "anchor input" is typically a feature-location pair SDR.

  During inference, you'll typically call:
    lm.shift(...)
    # Consume lm.getActiveCells()
    # ...
    lm.anchor(...)

  During learning, you'll do the same, but you'll call lm.learn() instead of
  lm.anchor().
  """


  def __init__(self,
               cellDimensions,
               moduleMapDimensions,
               orientation,
               anchorInputSize,
               pointOffsets=(0.5,),
               activationThreshold=10,
               initialPermanence=0.21,
               connectedPermanence=0.50,
               learningThreshold=10,
               sampleSize=20,
               permanenceIncrement=0.1,
               permanenceDecrement=0.0,
               maxSynapsesPerSegment=-1,
               seed=42):
    """
    @param cellDimensions (tuple(int, int))
    Determines the number of cells. Determines how space is divided between the
    cells.

    @param moduleMapDimensions (tuple(float, float))
    Determines the amount of world space covered by all of the cells combined.
    In grid cell terminology, this is equivalent to the "scale" of a module.
    A module with a scale of "5cm" would have moduleMapDimensions=(5.0, 5.0).

    @param orientation (float)
    The rotation of this map, measured in radians.

    @param anchorInputSize (int)
    The number of input bits in the anchor input.

    @param pointOffsets (list of floats)
    These must each be between 0.0 and 1.0. Every time a cell is activated by
    anchor input, this class adds a "point" which is shifted in subsequent
    motions. By default, this point is placed at the center of the cell. This
    parameter allows you to control where the point is placed and whether multiple
    are placed. For example, With value [0.2, 0.8], it will place 4 points:
    [0.2, 0.2], [0.2, 0.8], [0.8, 0.2], [0.8, 0.8]

    @param activationThreshold (int)
    Minimum number of active connected synapses for a segment to activate its
    cell during anchor() and learn().

    @param initialPermanence (float)
    Initial permanence of newly grown synapses.

    @param connectedPermanence (float)
    Permanence at or above which a synapse counts as connected.

    @param learningThreshold (int)
    Minimum potential overlap for a segment to be eligible for learning.

    @param sampleSize (int)
    How much of the anchor input to sample with synapses. -1 means sample all.

    @param permanenceIncrement (float)
    Amount by which reinforced synapses are strengthened.

    @param permanenceDecrement (float)
    Amount by which non-reinforced synapses are weakened.

    @param maxSynapsesPerSegment (int)
    Maximum synapses per segment. -1 means no limit.

    @param seed (int)
    Seed for the random number generator.
    """

    self.cellDimensions = np.asarray(cellDimensions, dtype="int")
    self.moduleMapDimensions = np.asarray(moduleMapDimensions, dtype="float")
    # Conversion factor from world distance to "cell field" units.
    self.cellFieldsPerUnitDistance = self.cellDimensions / self.moduleMapDimensions

    self.orientation = orientation
    # Standard 2D rotation matrix; applied to motor deltas in shift().
    self.rotationMatrix = np.array(
      [[math.cos(orientation), -math.sin(orientation)],
       [math.sin(orientation), math.cos(orientation)]])

    self.pointOffsets = pointOffsets

    # These coordinates are in units of "cell fields".
    self.activePoints = np.empty((0,2), dtype="float")
    self.cellsForActivePoints = np.empty(0, dtype="int")

    self.activeCells = np.empty(0, dtype="int")
    self.activeSegments = np.empty(0, dtype="uint32")

    # Anchor connections: cells <- anchor input bits.
    self.connections = SparseMatrixConnections(np.prod(cellDimensions),
                                               anchorInputSize)

    self.initialPermanence = initialPermanence
    self.connectedPermanence = connectedPermanence
    self.learningThreshold = learningThreshold
    self.sampleSize = sampleSize
    self.permanenceIncrement = permanenceIncrement
    self.permanenceDecrement = permanenceDecrement
    self.activationThreshold = activationThreshold
    self.maxSynapsesPerSegment = maxSynapsesPerSegment

    self.rng = Random(seed)


  def reset(self):
    """
    Clear the active cells.
    """
    self.activePoints = np.empty((0,2), dtype="float")
    self.cellsForActivePoints = np.empty(0, dtype="int")
    self.activeCells = np.empty(0, dtype="int")


  def _computeActiveCells(self):
    # Recompute self.activeCells from the current self.activePoints.
    # Round each coordinate to the nearest cell.
    flooredActivePoints = np.floor(self.activePoints).astype("int")

    # Convert coordinates to cell numbers.
    self.cellsForActivePoints = (
      np.ravel_multi_index(flooredActivePoints.T, self.cellDimensions))
    self.activeCells = np.unique(self.cellsForActivePoints)


  def activateRandomLocation(self):
    """
    Set the location to a random point.
    """
    # NOTE(review): this draws from the global np.random rather than the
    # seeded self.rng, so it is not reproducible via the seed parameter --
    # confirm whether that is intended.
    self.activePoints = np.array([np.random.random(2) * self.cellDimensions])
    self._computeActiveCells()


  def shift(self, deltaLocation):
    """
    Shift the current active cells by a vector.

    @param deltaLocation (pair of floats)
    A translation vector [di, dj].
    """
    # Calculate delta in the module's coordinates.
    deltaLocationInCellFields = (np.matmul(self.rotationMatrix, deltaLocation) *
                                 self.cellFieldsPerUnitDistance)

    # Shift the active coordinates.
    np.add(self.activePoints, deltaLocationInCellFields, out=self.activePoints)
    # Wrap around the module's edges (the map tiles 2D space).
    np.mod(self.activePoints, self.cellDimensions, out=self.activePoints)

    self._computeActiveCells()


  def anchor(self, anchorInput):
    """
    Infer the location from sensory input. Activate any cells with enough active
    synapses to this sensory input. Deactivate all other cells.

    @param anchorInput (numpy array)
    A sensory input. This will often come from a feature-location pair layer.
    """
    if len(anchorInput) == 0:
      return

    # Find the cells with a segment of enough connected synapses to the input.
    overlaps = self.connections.computeActivity(anchorInput,
                                                self.connectedPermanence)
    activeSegments = np.where(overlaps >= self.activationThreshold)[0]

    sensorySupportedCells = np.unique(
      self.connections.mapSegmentsToCells(activeSegments))

    # Remove the tracked points of cells that lost sensory support.
    inactivated = np.setdiff1d(self.activeCells, sensorySupportedCells)
    inactivatedIndices = np.in1d(self.cellsForActivePoints,
                                 inactivated).nonzero()[0]
    if inactivatedIndices.size > 0:
      self.activePoints = np.delete(self.activePoints, inactivatedIndices,
                                    axis=0)

    # Add points (at the configured offsets) for newly activated cells.
    activated = np.setdiff1d(sensorySupportedCells, self.activeCells)

    activatedCoordsBase = np.transpose(
      np.unravel_index(activated, self.cellDimensions)).astype('float')

    activatedCoords = np.concatenate(
      [activatedCoordsBase + [iOffset, jOffset]
       for iOffset in self.pointOffsets
       for jOffset in self.pointOffsets]
    )
    if activatedCoords.size > 0:
      self.activePoints = np.append(self.activePoints, activatedCoords, axis=0)

    self._computeActiveCells()
    self.activeSegments = activeSegments


  def learn(self, anchorInput):
    """
    Associate this location with a sensory input. Subsequently, anchorInput will
    activate the current location during anchor().

    @param anchorInput (numpy array)
    A sensory input. This will often come from a feature-location pair layer.
    """
    overlaps = self.connections.computeActivity(anchorInput,
                                                self.connectedPermanence)
    activeSegments = np.where(overlaps >= self.activationThreshold)[0]

    potentialOverlaps = self.connections.computeActivity(anchorInput)
    matchingSegments = np.where(potentialOverlaps >=
                                self.learningThreshold)[0]

    # Cells with an active segment: reinforce the segment
    cellsForActiveSegments = self.connections.mapSegmentsToCells(
      activeSegments)
    learningActiveSegments = activeSegments[
      np.in1d(cellsForActiveSegments, self.activeCells)]
    remainingCells = np.setdiff1d(self.activeCells, cellsForActiveSegments)

    # Remaining cells with a matching segment: reinforce the best
    # matching segment.
    candidateSegments = self.connections.filterSegmentsByCell(
      matchingSegments, remainingCells)
    cellsForCandidateSegments = (
      self.connections.mapSegmentsToCells(candidateSegments))
    candidateSegments = candidateSegments[
      np.in1d(cellsForCandidateSegments, remainingCells)]
    onePerCellFilter = np2.argmaxMulti(potentialOverlaps[candidateSegments],
                                       cellsForCandidateSegments)
    learningMatchingSegments = candidateSegments[onePerCellFilter]

    newSegmentCells = np.setdiff1d(remainingCells, cellsForCandidateSegments)

    for learningSegments in (learningActiveSegments,
                             learningMatchingSegments):
      self._learn(self.connections, self.rng, learningSegments,
                  anchorInput, potentialOverlaps,
                  self.initialPermanence, self.sampleSize,
                  self.permanenceIncrement, self.permanenceDecrement,
                  self.maxSynapsesPerSegment)

    # Remaining cells without a matching segment: grow one.
    numNewSynapses = len(anchorInput)

    if self.sampleSize != -1:
      numNewSynapses = min(numNewSynapses, self.sampleSize)

    if self.maxSynapsesPerSegment != -1:
      numNewSynapses = min(numNewSynapses, self.maxSynapsesPerSegment)

    newSegments = self.connections.createSegments(newSegmentCells)

    self.connections.growSynapsesToSample(
      newSegments, anchorInput, numNewSynapses,
      self.initialPermanence, self.rng)

    self.activeSegments = activeSegments


  @staticmethod
  def _learn(connections, rng, learningSegments, activeInput,
             potentialOverlaps, initialPermanence, sampleSize,
             permanenceIncrement, permanenceDecrement, maxSynapsesPerSegment):
    """
    Adjust synapse permanences and grow new synapses on the given segments.

    @param connections (SparseMatrixConnections)
    @param rng (Random)
    @param learningSegments (numpy array)
    @param activeInput (numpy array)
    @param potentialOverlaps (numpy array)
    @param initialPermanence (float)
    @param sampleSize (int) -1 means "use every active input bit"
    @param permanenceIncrement (float)
    @param permanenceDecrement (float)
    @param maxSynapsesPerSegment (int) -1 means "no limit"
    """
    # Learn on existing segments
    connections.adjustSynapses(learningSegments, activeInput,
                               permanenceIncrement, -permanenceDecrement)

    # Grow new synapses. Calculate "maxNew", the maximum number of synapses to
    # grow per segment. "maxNew" might be a number or it might be a list of
    # numbers.
    if sampleSize == -1:
      maxNew = len(activeInput)
    else:
      maxNew = sampleSize - potentialOverlaps[learningSegments]

    # Clip growth to each segment's remaining synapse capacity.
    if maxSynapsesPerSegment != -1:
      synapseCounts = connections.mapSegmentsToSynapseCounts(
        learningSegments)
      numSynapsesToReachMax = maxSynapsesPerSegment - synapseCounts
      maxNew = np.where(maxNew <= numSynapsesToReachMax,
                        maxNew, numSynapsesToReachMax)

    connections.growSynapsesToSample(learningSegments, activeInput,
                                     maxNew, initialPermanence, rng)


  def getActiveCells(self):
    """
    Return the currently active cells (numpy array of cell indices).
    """
    return self.activeCells


  def numberOfCells(self):
    """
    Return the total number of cells in this module.
    """
    return np.prod(self.cellDimensions)
class ApicalTiebreakTemporalMemory(object):
  """
  A generalized Temporal Memory with apical dendrites that add a "tiebreak".

  Basal connections are used to implement traditional Temporal Memory.

  The apical connections are used for further disambiguation. If multiple cells
  in a minicolumn have active basal segments, each of those cells is predicted,
  unless one of them also has an active apical segment, in which case only the
  cells with active basal and apical segments are predicted.

  In other words, the apical connections have no effect unless the basal input
  is a union of SDRs (e.g. from bursting minicolumns).

  This class is generalized in two ways:

  - This class does not specify when a 'timestep' begins and ends. It exposes
    two main methods: 'depolarizeCells' and 'activateCells', and callers or
    subclasses can introduce the notion of a timestep.
  - This class is unaware of whether its 'basalInput' or 'apicalInput' are from
    internal or external cells. They are just cell numbers. The caller knows
    what these cell numbers mean, but the TemporalMemory doesn't.
  """

  def __init__(self,
               columnCount=2048,
               basalInputSize=0,
               apicalInputSize=0,
               cellsPerColumn=32,
               activationThreshold=13,
               reducedBasalThreshold=13,
               initialPermanence=0.21,
               connectedPermanence=0.50,
               minThreshold=10,
               sampleSize=20,
               permanenceIncrement=0.1,
               permanenceDecrement=0.1,
               basalPredictedSegmentDecrement=0.0,
               apicalPredictedSegmentDecrement=0.0,
               maxSynapsesPerSegment=-1,
               seed=42):
    """
    @param columnCount (int)
    The number of minicolumns

    @param basalInputSize (sequence)
    The number of bits in the basal input

    @param apicalInputSize (int)
    The number of bits in the apical input

    @param cellsPerColumn (int)
    Number of cells per column

    @param activationThreshold (int)
    If the number of active connected synapses on a segment is at least this
    threshold, the segment is said to be active.

    @param reducedBasalThreshold (int)
    The activation threshold of basal (lateral) segments for cells that have
    active apical segments. If equal to activationThreshold (default),
    this parameter has no effect.

    @param initialPermanence (float)
    Initial permanence of a new synapse

    @param connectedPermanence (float)
    If the permanence value for a synapse is greater than this value, it is said
    to be connected.

    @param minThreshold (int)
    If the number of potential synapses active on a segment is at least this
    threshold, it is said to be "matching" and is eligible for learning.

    @param sampleSize (int)
    How much of the active SDR to sample with synapses.

    @param permanenceIncrement (float)
    Amount by which permanences of synapses are incremented during learning.

    @param permanenceDecrement (float)
    Amount by which permanences of synapses are decremented during learning.

    @param basalPredictedSegmentDecrement (float)
    Amount by which segments are punished for incorrect predictions.

    @param apicalPredictedSegmentDecrement (float)
    Amount by which segments are punished for incorrect predictions.

    @param maxSynapsesPerSegment
    The maximum number of synapses per segment.

    @param seed (int)
    Seed for the random number generator.
    """

    self.columnCount = columnCount
    self.cellsPerColumn = cellsPerColumn
    self.initialPermanence = initialPermanence
    self.connectedPermanence = connectedPermanence
    self.reducedBasalThreshold = reducedBasalThreshold
    self.minThreshold = minThreshold
    self.sampleSize = sampleSize
    self.permanenceIncrement = permanenceIncrement
    self.permanenceDecrement = permanenceDecrement
    self.basalPredictedSegmentDecrement = basalPredictedSegmentDecrement
    self.apicalPredictedSegmentDecrement = apicalPredictedSegmentDecrement
    self.activationThreshold = activationThreshold
    self.maxSynapsesPerSegment = maxSynapsesPerSegment

    self.basalConnections = SparseMatrixConnections(columnCount*cellsPerColumn,
                                                    basalInputSize)
    self.apicalConnections = SparseMatrixConnections(columnCount*cellsPerColumn,
                                                     apicalInputSize)
    self.rng = Random(seed)
    self.activeCells = np.empty(0, dtype="uint32")
    self.winnerCells = np.empty(0, dtype="uint32")
    self.predictedCells = np.empty(0, dtype="uint32")
    self.predictedActiveCells = np.empty(0, dtype="uint32")
    self.activeBasalSegments = np.empty(0, dtype="uint32")
    self.activeApicalSegments = np.empty(0, dtype="uint32")
    self.matchingBasalSegments = np.empty(0, dtype="uint32")
    self.matchingApicalSegments = np.empty(0, dtype="uint32")
    self.basalPotentialOverlaps = np.empty(0, dtype="int32")
    self.apicalPotentialOverlaps = np.empty(0, dtype="int32")

    self.useApicalTiebreak=True
    self.useApicalModulationBasalThreshold=True


  def reset(self):
    """
    Clear all cell and segment activity.
    """
    self.activeCells = np.empty(0, dtype="uint32")
    self.winnerCells = np.empty(0, dtype="uint32")
    self.predictedCells = np.empty(0, dtype="uint32")
    self.predictedActiveCells = np.empty(0, dtype="uint32")
    self.activeBasalSegments = np.empty(0, dtype="uint32")
    self.activeApicalSegments = np.empty(0, dtype="uint32")
    self.matchingBasalSegments = np.empty(0, dtype="uint32")
    self.matchingApicalSegments = np.empty(0, dtype="uint32")
    self.basalPotentialOverlaps = np.empty(0, dtype="int32")
    self.apicalPotentialOverlaps = np.empty(0, dtype="int32")


  def depolarizeCells(self, basalInput, apicalInput, learn):
    """
    Calculate predictions.

    @param basalInput (numpy array)
    List of active input bits for the basal dendrite segments

    @param apicalInput (numpy array)
    List of active input bits for the apical dendrite segments

    @param learn (bool)
    Whether learning is enabled. Some TM implementations may depolarize cells
    differently or do segment activity bookkeeping when learning is enabled.
    """
    (activeApicalSegments,
     matchingApicalSegments,
     apicalPotentialOverlaps) = self._calculateApicalSegmentActivity(
       self.apicalConnections, apicalInput, self.connectedPermanence,
       self.activationThreshold, self.minThreshold)

    if learn or self.useApicalModulationBasalThreshold==False:
      reducedBasalThresholdCells = ()
    else:
      reducedBasalThresholdCells = self.apicalConnections.mapSegmentsToCells(
        activeApicalSegments)

    (activeBasalSegments,
     matchingBasalSegments,
     basalPotentialOverlaps) = self._calculateBasalSegmentActivity(
       self.basalConnections, basalInput, reducedBasalThresholdCells,
       self.connectedPermanence,
       self.activationThreshold, self.minThreshold, self.reducedBasalThreshold)

    predictedCells = self._calculatePredictedCells(activeBasalSegments,
                                                   activeApicalSegments)

    self.predictedCells = predictedCells
    self.activeBasalSegments = activeBasalSegments
    self.activeApicalSegments = activeApicalSegments
    self.matchingBasalSegments = matchingBasalSegments
    self.matchingApicalSegments = matchingApicalSegments
    self.basalPotentialOverlaps = basalPotentialOverlaps
    self.apicalPotentialOverlaps = apicalPotentialOverlaps


  def activateCells(self,
                    activeColumns,
                    basalReinforceCandidates,
                    apicalReinforceCandidates,
                    basalGrowthCandidates,
                    apicalGrowthCandidates,
                    learn=True):
    """
    Activate cells in the specified columns, using the result of the previous
    'depolarizeCells' as predictions. Then learn.

    @param activeColumns (numpy array)
    List of active columns

    @param basalReinforceCandidates (numpy array)
    List of bits that the active cells may reinforce basal synapses to.

    @param apicalReinforceCandidates (numpy array)
    List of bits that the active cells may reinforce apical synapses to.

    @param basalGrowthCandidates (numpy array)
    List of bits that the active cells may grow new basal synapses to.

    @param apicalGrowthCandidates (numpy array)
    List of bits that the active cells may grow new apical synapses to

    @param learn (bool)
    Whether to grow / reinforce / punish synapses
    """

    # Calculate active cells
    (correctPredictedCells,
     burstingColumns) = np2.setCompare(self.predictedCells, activeColumns,
                                       self.predictedCells / self.cellsPerColumn,
                                       rightMinusLeft=True)
    newActiveCells = np.concatenate((correctPredictedCells,
                                     np2.getAllCellsInColumns(
                                       burstingColumns, self.cellsPerColumn)))

    # Calculate learning
    (learningActiveBasalSegments,
     learningMatchingBasalSegments,
     basalSegmentsToPunish,
     newBasalSegmentCells,
     learningCells) = self._calculateBasalLearning(
       activeColumns, burstingColumns, correctPredictedCells,
       self.activeBasalSegments, self.matchingBasalSegments,
       self.basalPotentialOverlaps)

    (learningActiveApicalSegments,
     learningMatchingApicalSegments,
     apicalSegmentsToPunish,
     newApicalSegmentCells) = self._calculateApicalLearning(
       learningCells, activeColumns, self.activeApicalSegments,
       self.matchingApicalSegments, self.apicalPotentialOverlaps)

    # Learn
    if learn:
      # Learn on existing segments
      for learningSegments in (learningActiveBasalSegments,
                               learningMatchingBasalSegments):
        self._learn(self.basalConnections, self.rng, learningSegments,
                    basalReinforceCandidates, basalGrowthCandidates,
                    self.basalPotentialOverlaps,
                    self.initialPermanence, self.sampleSize,
                    self.permanenceIncrement, self.permanenceDecrement,
                    self.maxSynapsesPerSegment)

      for learningSegments in (learningActiveApicalSegments,
                               learningMatchingApicalSegments):

        self._learn(self.apicalConnections, self.rng, learningSegments,
                    apicalReinforceCandidates, apicalGrowthCandidates,
                    self.apicalPotentialOverlaps, self.initialPermanence,
                    self.sampleSize, self.permanenceIncrement,
                    self.permanenceDecrement, self.maxSynapsesPerSegment)

      # Punish incorrect predictions
      if self.basalPredictedSegmentDecrement != 0.0:
        self.basalConnections.adjustActiveSynapses(
          basalSegmentsToPunish, basalReinforceCandidates,
          -self.basalPredictedSegmentDecrement)

      if self.apicalPredictedSegmentDecrement != 0.0:
        self.apicalConnections.adjustActiveSynapses(
          apicalSegmentsToPunish, apicalReinforceCandidates,
          -self.apicalPredictedSegmentDecrement)

      # Grow new segments
      if len(basalGrowthCandidates) > 0:
        self._learnOnNewSegments(self.basalConnections, self.rng,
                                 newBasalSegmentCells, basalGrowthCandidates,
                                 self.initialPermanence, self.sampleSize,
                                 self.maxSynapsesPerSegment)

      if len(apicalGrowthCandidates) > 0:
        self._learnOnNewSegments(self.apicalConnections, self.rng,
                                 newApicalSegmentCells, apicalGrowthCandidates,
                                 self.initialPermanence, self.sampleSize,
                                 self.maxSynapsesPerSegment)

    # Save the results
    newActiveCells.sort()
    learningCells.sort()
    self.activeCells = newActiveCells
    self.winnerCells = learningCells
    self.predictedActiveCells = correctPredictedCells


  def _calculateBasalLearning(self,
                              activeColumns,
                              burstingColumns,
                              correctPredictedCells,
                              activeBasalSegments,
                              matchingBasalSegments,
                              basalPotentialOverlaps):
    """
    Basic Temporal Memory learning. Correctly predicted cells always have
    active basal segments, and we learn on these segments. In bursting
    columns, we either learn on an existing basal segment, or we grow a new one.

    The only influence apical dendrites have on basal learning is: the apical
    dendrites influence which cells are considered "predicted". So an active
    apical dendrite can prevent some basal segments in active columns from
    learning.

    @param activeColumns (numpy array)
    @param burstingColumns (numpy array)
    @param correctPredictedCells (numpy array)
    @param activeBasalSegments (numpy array)
    @param matchingBasalSegments (numpy array)
    @param basalPotentialOverlaps (numpy array)

    @return (tuple)
    - learningActiveBasalSegments (numpy array)
      Active basal segments on correct predicted cells

    - learningMatchingBasalSegments (numpy array)
      Matching basal segments selected for learning in bursting columns

    - basalSegmentsToPunish (numpy array)
      Basal segments that should be punished for predicting an inactive column

    - newBasalSegmentCells (numpy array)
      Cells in bursting columns that were selected to grow new basal segments

    - learningCells (numpy array)
      Cells that have learning basal segments or are selected to grow a basal
      segment
    """

    # Correctly predicted columns: reinforce the segments that made the
    # correct prediction.
    learningActiveBasalSegments = self.basalConnections.filterSegmentsByCell(
      activeBasalSegments, correctPredictedCells)

    cellsForMatchingBasal = self.basalConnections.mapSegmentsToCells(
      matchingBasalSegments)
    matchingCells = np.unique(cellsForMatchingBasal)

    # Floor division maps each cell to its column. True division ('/') would
    # produce floats under Python 3 and break the integer set comparison.
    (matchingCellsInBurstingColumns,
     burstingColumnsWithNoMatch) = np2.setCompare(
       matchingCells, burstingColumns, matchingCells // self.cellsPerColumn,
       rightMinusLeft=True)

    learningMatchingBasalSegments = self._chooseBestSegmentPerColumn(
      self.basalConnections, matchingCellsInBurstingColumns,
      matchingBasalSegments, basalPotentialOverlaps, self.cellsPerColumn)
    newBasalSegmentCells = self._getCellsWithFewestSegments(
      self.basalConnections, self.rng, burstingColumnsWithNoMatch,
      self.cellsPerColumn)

    learningCells = np.concatenate(
      (correctPredictedCells,
       self.basalConnections.mapSegmentsToCells(learningMatchingBasalSegments),
       newBasalSegmentCells))

    # Incorrectly predicted columns: matching segments whose cell's column
    # did not become active get punished.
    correctMatchingBasalMask = np.in1d(
      cellsForMatchingBasal // self.cellsPerColumn, activeColumns)

    basalSegmentsToPunish = matchingBasalSegments[~correctMatchingBasalMask]

    return (learningActiveBasalSegments,
            learningMatchingBasalSegments,
            basalSegmentsToPunish,
            newBasalSegmentCells,
            learningCells)


  def _calculateApicalLearning(self,
                               learningCells,
                               activeColumns,
                               activeApicalSegments,
                               matchingApicalSegments,
                               apicalPotentialOverlaps):
    """
    Calculate apical learning for each learning cell.

    The set of learning cells was determined completely from basal segments.
    Do all apical learning on the same cells.

    Learn on any active segments on learning cells. For cells without active
    segments, learn on the best matching segment. For cells without a matching
    segment, grow a new segment.

    @param learningCells (numpy array)
    @param activeColumns (numpy array)
    @param activeApicalSegments (numpy array)
    @param matchingApicalSegments (numpy array)
    @param apicalPotentialOverlaps (numpy array)

    @return (tuple)
    - learningActiveApicalSegments (numpy array)
      Active apical segments on correct predicted cells

    - learningMatchingApicalSegments (numpy array)
      Matching apical segments selected for learning in bursting columns

    - apicalSegmentsToPunish (numpy array)
      Apical segments that should be punished for predicting an inactive column

    - newApicalSegmentCells (numpy array)
      Cells in bursting columns that were selected to grow new apical segments
    """

    # Cells with active apical segments
    learningActiveApicalSegments = self.apicalConnections.filterSegmentsByCell(
      activeApicalSegments, learningCells)

    # Cells with matching apical segments
    learningCellsWithoutActiveApical = np.setdiff1d(
      learningCells,
      self.apicalConnections.mapSegmentsToCells(learningActiveApicalSegments))
    cellsForMatchingApical = self.apicalConnections.mapSegmentsToCells(
      matchingApicalSegments)
    learningCellsWithMatchingApical = np.intersect1d(
      learningCellsWithoutActiveApical, cellsForMatchingApical)
    learningMatchingApicalSegments = self._chooseBestSegmentPerCell(
      self.apicalConnections, learningCellsWithMatchingApical,
      matchingApicalSegments, apicalPotentialOverlaps)

    # Cells that need to grow an apical segment
    newApicalSegmentCells = np.setdiff1d(learningCellsWithoutActiveApical,
                                         learningCellsWithMatchingApical)

    # Incorrectly predicted columns. Floor division maps each cell to its
    # column; true division ('/') would yield floats under Python 3 and make
    # this integer membership test wrong for mid-column cells.
    correctMatchingApicalMask = np.in1d(
      cellsForMatchingApical // self.cellsPerColumn, activeColumns)

    apicalSegmentsToPunish = matchingApicalSegments[~correctMatchingApicalMask]

    return (learningActiveApicalSegments,
            learningMatchingApicalSegments,
            apicalSegmentsToPunish,
            newApicalSegmentCells)


  @staticmethod
  def _calculateApicalSegmentActivity(connections, activeInput, connectedPermanence,
                                activationThreshold, minThreshold):
    """
    Calculate the active and matching apical segments for this timestep.

    @param connections (SparseMatrixConnections)
    @param activeInput (numpy array)

    @return (tuple)
    - activeSegments (numpy array)
      Dendrite segments with enough active connected synapses to cause a
      dendritic spike

    - matchingSegments (numpy array)
      Dendrite segments with enough active potential synapses to be selected for
      learning in a bursting column

    - potentialOverlaps (numpy array)
      The number of active potential synapses for each segment.
      Includes counts for active, matching, and nonmatching segments.
    """

    # Active
    overlaps = connections.computeActivity(activeInput, connectedPermanence)
    activeSegments = np.flatnonzero(overlaps >= activationThreshold)

    # Matching
    potentialOverlaps = connections.computeActivity(activeInput)
    matchingSegments = np.flatnonzero(potentialOverlaps >= minThreshold)

    return (activeSegments,
            matchingSegments,
            potentialOverlaps)


  @staticmethod
  def _calculateBasalSegmentActivity(connections, activeInput,
                                reducedBasalThresholdCells, connectedPermanence,
                                activationThreshold, minThreshold, reducedBasalThreshold):
    """
    Calculate the active and matching basal segments for this timestep.

    The difference with _calculateApicalSegmentActivity is that cells
    with active apical segments (collected in reducedBasalThresholdCells) have
    a lower activation threshold for their basal segments (set by
    reducedBasalThreshold parameter).

    @param connections (SparseMatrixConnections)
    @param activeInput (numpy array)

    @return (tuple)
    - activeSegments (numpy array)
      Dendrite segments with enough active connected synapses to cause a
      dendritic spike

    - matchingSegments (numpy array)
      Dendrite segments with enough active potential synapses to be selected for
      learning in a bursting column

    - potentialOverlaps (numpy array)
      The number of active potential synapses for each segment.
      Includes counts for active, matching, and nonmatching segments.
    """
    # Active apical segments lower the activation threshold for basal (lateral) segments
    overlaps = connections.computeActivity(activeInput, connectedPermanence)
    outrightActiveSegments = np.flatnonzero(overlaps >= activationThreshold)
    if reducedBasalThreshold != activationThreshold and len(reducedBasalThresholdCells) > 0:
        potentiallyActiveSegments = np.flatnonzero((overlaps < activationThreshold)
                                        & (overlaps >= reducedBasalThreshold))
        cellsOfCASegments = connections.mapSegmentsToCells(potentiallyActiveSegments)
        # apically active segments are condit. active segments from apically active cells
        conditionallyActiveSegments = potentiallyActiveSegments[np.in1d(cellsOfCASegments,
                                                reducedBasalThresholdCells)]
        activeSegments = np.concatenate((outrightActiveSegments, conditionallyActiveSegments))
    else:
        activeSegments = outrightActiveSegments



    # Matching
    potentialOverlaps = connections.computeActivity(activeInput)
    matchingSegments = np.flatnonzero(potentialOverlaps >= minThreshold)

    return (activeSegments,
            matchingSegments,
            potentialOverlaps)


  def _calculatePredictedCells(self, activeBasalSegments, activeApicalSegments):
    """
    Calculate the predicted cells, given the set of active segments.

    An active basal segment is enough to predict a cell.
    An active apical segment is *not* enough to predict a cell.

    When a cell has both types of segments active, other cells in its minicolumn
    must also have both types of segments to be considered predictive.

    @param activeBasalSegments (numpy array)
    @param activeApicalSegments (numpy array)

    @return (numpy array)
    """

    cellsForBasalSegments = self.basalConnections.mapSegmentsToCells(
      activeBasalSegments)
    cellsForApicalSegments = self.apicalConnections.mapSegmentsToCells(
      activeApicalSegments)

    fullyDepolarizedCells = np.intersect1d(cellsForBasalSegments,
                                           cellsForApicalSegments)
    partlyDepolarizedCells = np.setdiff1d(cellsForBasalSegments,
                                          fullyDepolarizedCells)

    inhibitedMask = np.in1d(partlyDepolarizedCells / self.cellsPerColumn,
                            fullyDepolarizedCells / self.cellsPerColumn)
    predictedCells = np.append(fullyDepolarizedCells,
                               partlyDepolarizedCells[~inhibitedMask])

    if self.useApicalTiebreak == False:
        predictedCells = cellsForBasalSegments

    return predictedCells


  @staticmethod
  def _learn(connections, rng, learningSegments, activeInput, growthCandidates,
             potentialOverlaps, initialPermanence, sampleSize,
             permanenceIncrement, permanenceDecrement, maxSynapsesPerSegment):
    """
    Adjust synapse permanences, grow new synapses, and grow new segments.

    @param learningActiveSegments (numpy array)
    @param learningMatchingSegments (numpy array)
    @param activeInput (numpy array)
    @param growthCandidates (numpy array)
    @param potentialOverlaps (numpy array)
    """

    # Learn on existing segments
    connections.adjustSynapses(learningSegments, activeInput,
                               permanenceIncrement, -permanenceDecrement)

    # Grow new synapses. Calculate "maxNew", the maximum number of synapses to
    # grow per segment. "maxNew" might be a number or it might be a list of
    # numbers.
    if sampleSize == -1:
      maxNew = len(growthCandidates)
    else:
      maxNew = sampleSize - potentialOverlaps[learningSegments]

    if maxSynapsesPerSegment != -1:
      synapseCounts = connections.mapSegmentsToSynapseCounts(
        learningSegments)
      numSynapsesToReachMax = maxSynapsesPerSegment - synapseCounts
      maxNew = np.where(maxNew <= numSynapsesToReachMax,
                        maxNew, numSynapsesToReachMax)

    connections.growSynapsesToSample(learningSegments, growthCandidates,
                                     maxNew, initialPermanence, rng)


  @staticmethod
  def _learnOnNewSegments(connections, rng, newSegmentCells, growthCandidates,
                          initialPermanence, sampleSize, maxSynapsesPerSegment):

    numNewSynapses = len(growthCandidates)

    if sampleSize != -1:
      numNewSynapses = min(numNewSynapses, sampleSize)

    if maxSynapsesPerSegment != -1:
      numNewSynapses = min(numNewSynapses, maxSynapsesPerSegment)

    newSegments = connections.createSegments(newSegmentCells)
    connections.growSynapsesToSample(newSegments, growthCandidates,
                                     numNewSynapses, initialPermanence,
                                     rng)


  @classmethod
  def _chooseBestSegmentPerCell(cls,
                                connections,
                                cells,
                                allMatchingSegments,
                                potentialOverlaps):
    """
    For each specified cell, pick its matching segment with the largest
    number of active potential synapses; ties go to the first segment.

    @param connections (SparseMatrixConnections)
    @param cells (numpy array)
    @param allMatchingSegments (numpy array)
    @param potentialOverlaps (numpy array)

    @return (numpy array)
    One segment per cell
    """

    segmentsForCells = connections.filterSegmentsByCell(allMatchingSegments,
                                                        cells)

    # One winner per cell: argmax of potential overlap, grouped by owner cell.
    ownerCells = connections.mapSegmentsToCells(segmentsForCells)
    bestPerCell = np2.argmaxMulti(potentialOverlaps[segmentsForCells],
                                  ownerCells)

    return segmentsForCells[bestPerCell]


  @classmethod
  def _chooseBestSegmentPerColumn(cls, connections, matchingCells,
                                  allMatchingSegments, potentialOverlaps,
                                  cellsPerColumn):
    """
    For all the columns covered by 'matchingCells', choose the column's matching
    segment with largest number of active potential synapses. When there's a
    tie, the first segment wins.

    @param connections (SparseMatrixConnections)
    @param matchingCells (numpy array)
    @param allMatchingSegments (numpy array)
    @param potentialOverlaps (numpy array)
    @param cellsPerColumn (int)

    @return (numpy array)
    One segment per column
    """

    candidateSegments = connections.filterSegmentsByCell(allMatchingSegments,
                                                         matchingCells)

    # Narrow it down to one segment per column. Floor division maps each
    # segment's cell to its column; true division ('/') would produce float
    # column ids under Python 3 and split one column into several groups.
    cellScores = potentialOverlaps[candidateSegments]
    columnsForCandidates = (connections.mapSegmentsToCells(candidateSegments) //
                            cellsPerColumn)
    onePerColumnFilter = np2.argmaxMulti(cellScores, columnsForCandidates)

    learningSegments = candidateSegments[onePerColumnFilter]

    return learningSegments


  @classmethod
  def _getCellsWithFewestSegments(cls, connections, rng, columns,
                                  cellsPerColumn):
    """
    For each column, get the cell that has the fewest total basal segments.
    Break ties randomly.

    @param connections (SparseMatrixConnections)
    @param rng (Random)
    @param columns (numpy array) Columns to check
    @param cellsPerColumn (int)

    @return (numpy array)
    One cell for each of the provided columns
    """
    candidateCells = np2.getAllCellsInColumns(columns, cellsPerColumn)

    # Arrange the segment counts into one row per minicolumn.
    segmentCounts = np.reshape(connections.getSegmentCounts(candidateCells),
                               newshape=(len(columns),
                                         cellsPerColumn))

    # Filter to just the cells that are tied for fewest in their minicolumn.
    minSegmentCounts = np.amin(segmentCounts, axis=1, keepdims=True)
    candidateCells = candidateCells[np.flatnonzero(segmentCounts ==
                                                   minSegmentCounts)]

    # Group the tied candidates by column. Floor division keeps column ids
    # integral; true division ('/') would yield floats under Python 3 and
    # split one column's candidates into spurious fractional groups.
    (_,
     onePerColumnFilter,
     numCandidatesInColumns) = np.unique(candidateCells // cellsPerColumn,
                                         return_index=True, return_counts=True)

    # Filter to one cell per column, choosing randomly from the minimums.
    # To do the random choice, add a random offset to each index in-place, using
    # casting to floor the result.
    offsetPercents = np.empty(len(columns), dtype="float32")
    rng.initializeReal32Array(offsetPercents)

    np.add(onePerColumnFilter,
           offsetPercents*numCandidatesInColumns,
           out=onePerColumnFilter,
           casting="unsafe")

    return candidateCells[onePerColumnFilter]


  def getActiveCells(self):
    """
    @return (numpy array)
    Cells that are active in the current timestep
    """
    return self.activeCells


  def getPredictedActiveCells(self):
    """
    @return (numpy array)
    Active cells that were correctly predicted
    """
    return self.predictedActiveCells


  def getWinnerCells(self):
    """
    @return (numpy array)
    Cells that were selected for learning in the current timestep
    """
    return self.winnerCells


  def getActiveBasalSegments(self):
    """
    @return (numpy array)
    Active basal segments for this timestep
    """
    return self.activeBasalSegments


  def getActiveApicalSegments(self):
    """
    @return (numpy array)
    Active apical segments for this timestep
    """
    return self.activeApicalSegments


  def numberOfColumns(self):
    """ Returns the number of columns in this layer.

    @return (int) Number of columns
    """
    return self.columnCount


  def numberOfCells(self):
    """
    Returns the total number of cells in this layer
    (columns * cells per column).

    @return (int) Number of cells
    """
    return self.numberOfColumns() * self.cellsPerColumn


  def getCellsPerColumn(self):
    """
    Returns the number of cells per column.

    @return (int) The number of cells per column.
    """
    return self.cellsPerColumn


  def getActivationThreshold(self):
    """
    Returns the activation threshold: the number of active connected
    synapses a segment needs to become active.
    @return (int) The activation threshold.
    """
    return self.activationThreshold


  def setActivationThreshold(self, activationThreshold):
    """
    Sets the activation threshold.
    @param activationThreshold (int) activation threshold.
    """
    self.activationThreshold = activationThreshold


  def getReducedBasalThreshold(self):
    """
    Returns the reduced basal activation threshold used for the basal
    segments of apically active cells.
    @return (int) The reduced activation threshold.
    """
    return self.reducedBasalThreshold


  def setReducedBasalThreshold(self, reducedBasalThreshold):
    """
    Sets the reduced basal activation threshold for apically active cells.
    @param reducedBasalThreshold (int) activation threshold.
    """
    self.reducedBasalThreshold = reducedBasalThreshold


  def getInitialPermanence(self):
    """
    Get the initial permanence given to newly grown synapses.
    @return (float) The initial permanence.
    """
    return self.initialPermanence


  def setInitialPermanence(self, initialPermanence):
    """
    Sets the initial permanence.
    @param initialPermanence (float) The initial permanence.
    """
    self.initialPermanence = initialPermanence


  def getMinThreshold(self):
    """
    Returns the min threshold: the number of active potential synapses a
    segment needs to be considered "matching".
    @return (int) The min threshold.
    """
    return self.minThreshold


  def setMinThreshold(self, minThreshold):
    """
    Sets the min threshold.
    @param minThreshold (int) min threshold.
    """
    self.minThreshold = minThreshold


  def getSampleSize(self):
    """
    Gets the sampleSize: the desired number of active potential synapses
    per learning segment (-1 for unlimited).
    @return (int)
    """
    return self.sampleSize


  def setSampleSize(self, sampleSize):
    """
    Sets the sampleSize.
    @param sampleSize (int)
    """
    self.sampleSize = sampleSize


  def getPermanenceIncrement(self):
    """
    Get the permanence increment applied to reinforced synapses.
    @return (float) The permanence increment.
    """
    return self.permanenceIncrement


  def setPermanenceIncrement(self, permanenceIncrement):
    """
    Sets the permanence increment.
    @param permanenceIncrement (float) The permanence increment.
    """
    self.permanenceIncrement = permanenceIncrement


  def getPermanenceDecrement(self):
    """
    Get the permanence decrement applied to non-reinforced synapses.
    @return (float) The permanence decrement.
    """
    return self.permanenceDecrement


  def setPermanenceDecrement(self, permanenceDecrement):
    """
    Sets the permanence decrement.
    @param permanenceDecrement (float) The permanence decrement.
    """
    self.permanenceDecrement = permanenceDecrement


  def getBasalPredictedSegmentDecrement(self):
    """
    Get the basal predicted segment decrement: the punishment applied to
    basal segments that predicted an inactive column.
    @return (float) The predicted segment decrement.
    """
    return self.basalPredictedSegmentDecrement


  def setBasalPredictedSegmentDecrement(self, predictedSegmentDecrement):
    """
    Sets the basal predicted segment decrement.
    @param predictedSegmentDecrement (float) The predicted segment decrement.
    """
    # Bug fix: previously assigned the undefined name
    # 'basalPredictedSegmentDecrement', raising NameError on every call.
    self.basalPredictedSegmentDecrement = predictedSegmentDecrement


  def getApicalPredictedSegmentDecrement(self):
    """
    Get the apical predicted segment decrement: the punishment applied to
    apical segments that predicted an inactive column.
    @return (float) The predicted segment decrement.
    """
    return self.apicalPredictedSegmentDecrement


  def setApicalPredictedSegmentDecrement(self, predictedSegmentDecrement):
    """
    Sets the apical predicted segment decrement.
    @param predictedSegmentDecrement (float) The predicted segment decrement.
    """
    # Bug fix: previously assigned the undefined name
    # 'apicalPredictedSegmentDecrement', raising NameError on every call.
    self.apicalPredictedSegmentDecrement = predictedSegmentDecrement


  def getConnectedPermanence(self):
    """
    Get the connected permanence: the permanence at or above which a synapse
    counts as connected.
    @return (float) The connected permanence.
    """
    return self.connectedPermanence


  def setConnectedPermanence(self, connectedPermanence):
    """
    Sets the connected permanence.
    @param connectedPermanence (float) The connected permanence.
    """
    self.connectedPermanence = connectedPermanence


  def getUseApicalTieBreak(self):
    """
    Get whether we actually use apical tie-break.
    @return (Bool) Whether apical tie-break is used.
    """
    return self.useApicalTiebreak


  def setUseApicalTiebreak(self, useApicalTiebreak):
    """
    Sets whether we actually use apical tie-break.
    @param useApicalTiebreak (Bool) Whether apical tie-break is used.
    """
    self.useApicalTiebreak = useApicalTiebreak


  def getUseApicalModulationBasalThreshold(self):
    """
    Get whether we actually use apical modulation of basal threshold.
    @return (Bool) Whether apical modulation is used.
    """
    return self.useApicalModulationBasalThreshold


  def setUseApicalModulationBasalThreshold(self, useApicalModulationBasalThreshold):
    """
    Sets whether we actually use apical modulation of basal threshold.
    @param useApicalModulationBasalThreshold (Bool) Whether apical modulation is used.
    """
    self.useApicalModulationBasalThreshold = useApicalModulationBasalThreshold
Beispiel #15
0
class SingleLayerLocationMemory(object):
    """
  A layer of cells which learns how to take a "delta location" (e.g. a motor
  command or a proprioceptive delta) and update its active cells to represent
  the new location.

  Its active cells might represent a union of locations.
  As the location changes, the featureLocationInput causes this union to narrow
  down until the location is inferred.

  This layer receives absolute proprioceptive info as proximal input.
  For now, we assume that there's a one-to-one mapping between absolute
  proprioceptive input and the location SDR. So rather than modeling
  proximal synapses, we'll just relay the proprioceptive SDR. In the future
  we might want to consider a many-to-one mapping of proprioceptive inputs
  to location SDRs.

  After this layer is trained, it no longer needs the proprioceptive input.
  The delta location will drive the layer. The current active cells and the
  other distal connections will work together with this delta location to
  activate a new set of cells.

  When no cells are active, activate a large union of possible locations.
  With subsequent inputs, the union will narrow down to a single location SDR.
  """
    def __init__(self,
                 cellCount,
                 deltaLocationInputSize,
                 featureLocationInputSize,
                 activationThreshold=13,
                 initialPermanence=0.21,
                 connectedPermanence=0.50,
                 learningThreshold=10,
                 sampleSize=20,
                 permanenceIncrement=0.1,
                 permanenceDecrement=0.1,
                 maxSynapsesPerSegment=-1,
                 seed=42):
        """
        Build the three distal connection groups and initialize state.

        Transition learning splits every segment into two halves — one
        watching this layer's own cells, one watching the delta input —
        and both halves must be active for the segment to be active.
        """
        # The two halves of each transition segment.
        self.internalConnections = SparseMatrixConnections(
            cellCount, cellCount)
        self.deltaConnections = SparseMatrixConnections(
            cellCount, deltaLocationInputSize)

        # Distal input from the layer that represents feature-locations.
        self.featureLocationConnections = SparseMatrixConnections(
            cellCount, featureLocationInputSize)

        # Active state, empty until the first compute().
        self.activeCells = np.empty(0, dtype="uint32")
        self.activeDeltaSegments = np.empty(0, dtype="uint32")
        self.activeFeatureLocationSegments = np.empty(0, dtype="uint32")

        # Learning and activation parameters.
        self.activationThreshold = activationThreshold
        self.initialPermanence = initialPermanence
        self.connectedPermanence = connectedPermanence
        self.learningThreshold = learningThreshold
        self.sampleSize = sampleSize
        self.permanenceIncrement = permanenceIncrement
        self.permanenceDecrement = permanenceDecrement
        self.maxSynapsesPerSegment = maxSynapsesPerSegment

        self.rng = Random(seed)

    def reset(self):
        """
        Deactivate all cells and clear all segment activity.
        """
        for attr in ("activeCells",
                     "activeDeltaSegments",
                     "activeFeatureLocationSegments"):
            setattr(self, attr, np.empty(0, dtype="uint32"))

    def compute(self,
                deltaLocation=(),
                newLocation=(),
                featureLocationInput=(),
                featureLocationGrowthCandidates=(),
                learn=True):
        """
    Run one time step of the Location Memory algorithm.

    @param deltaLocation (sorted numpy array)
    Relative movement signal (e.g. a motor command).

    @param newLocation (sorted numpy array)
    Absolute location SDR (proprioceptive input). When non-empty it drives
    the layer directly and, if 'learn', the transition and the
    feature-location pair are learned.

    @param featureLocationInput (sorted numpy array)
    Active cells of the feature-location pair layer.

    @param featureLocationGrowthCandidates (sorted numpy array)
    Cells eligible for growing new feature-location synapses.

    @param learn (bool)
    Whether to learn when newLocation is provided.
    """
        prevActiveCells = self.activeCells

        # A delta segment is active only when BOTH of its halves are active:
        # the half watching this layer's previous activity AND the half
        # watching the delta input.
        self.activeDeltaSegments = np.where(
            (self.internalConnections.computeActivity(
                prevActiveCells, self.connectedPermanence) >=
             self.activationThreshold)
            & (self.deltaConnections.computeActivity(
                deltaLocation, self.connectedPermanence) >=
               self.activationThreshold))[0]

        # When we're moving, the feature-location input has no effect.
        if len(deltaLocation) == 0:
            self.activeFeatureLocationSegments = np.where(
                self.featureLocationConnections.computeActivity(
                    featureLocationInput, self.connectedPermanence) >=
                self.activationThreshold)[0]
        else:
            self.activeFeatureLocationSegments = np.empty(0, dtype="uint32")

        if len(newLocation) > 0:
            # Drive activations by relaying this location SDR.
            self.activeCells = newLocation

            if learn:
                # Learn the delta.
                self._learnTransition(prevActiveCells, deltaLocation,
                                      newLocation)

                # Learn the featureLocationInput.
                self._learnFeatureLocationPair(
                    newLocation, featureLocationInput,
                    featureLocationGrowthCandidates)

        elif len(prevActiveCells) > 0:
            if len(deltaLocation) > 0:
                # Drive activations by applying the deltaLocation to the current location.
                # Completely ignore the featureLocationInput. It's outdated, associated
                # with the previous location.

                cellsForDeltaSegments = self.internalConnections.mapSegmentsToCells(
                    self.activeDeltaSegments)

                self.activeCells = np.unique(cellsForDeltaSegments)
            else:
                # Keep previous active cells active.
                # Modulate with the featureLocationInput.

                if len(self.activeFeatureLocationSegments) > 0:

                    cellsForFeatureLocationSegments = (
                        self.featureLocationConnections.mapSegmentsToCells(
                            self.activeFeatureLocationSegments))
                    self.activeCells = np.intersect1d(
                        prevActiveCells, cellsForFeatureLocationSegments)
                else:
                    self.activeCells = prevActiveCells

        elif len(featureLocationInput) > 0:
            # No previous activity and no absolute location: bootstrap the
            # activation from the featureLocationInput alone.

            cellsForFeatureLocationSegments = (
                self.featureLocationConnections.mapSegmentsToCells(
                    self.activeFeatureLocationSegments))

            self.activeCells = np.unique(cellsForFeatureLocationSegments)

    def _learnTransition(self, prevActiveCells, deltaLocation, newLocation):
        """
    For each cell in the newLocation SDR, learn the transition of prevLocation
    (i.e. prevActiveCells) + deltaLocation.

    The transition might be already known. In that case, just reinforce the
    existing segments.

    @param prevActiveCells (numpy array)
    Cells of this layer that were active on the previous timestep

    @param deltaLocation (numpy array)
    SDR representing the sensor's movement

    @param newLocation (numpy array)
    Cells of this layer that the transition should lead to
    """

        prevLocationPotentialOverlaps = self.internalConnections.computeActivity(
            prevActiveCells)
        deltaPotentialOverlaps = self.deltaConnections.computeActivity(
            deltaLocation)

        # A segment pair "matches" when both halves have at least
        # learningThreshold potential synapses onto their respective inputs.
        matchingDeltaSegments = np.where(
            (prevLocationPotentialOverlaps >= self.learningThreshold)
            & (deltaPotentialOverlaps >= self.learningThreshold))[0]

        # Cells with an active segment pair: reinforce the segment
        cellsForActiveSegments = self.internalConnections.mapSegmentsToCells(
            self.activeDeltaSegments)
        learningActiveDeltaSegments = self.activeDeltaSegments[np.in1d(
            cellsForActiveSegments, newLocation)]
        remainingCells = np.setdiff1d(newLocation, cellsForActiveSegments)

        # Remaining cells with a matching segment pair: reinforce the best matching
        # segment pair.
        candidateSegments = self.internalConnections.filterSegmentsByCell(
            matchingDeltaSegments, remainingCells)
        cellsForCandidateSegments = self.internalConnections.mapSegmentsToCells(
            candidateSegments)
        # Fix: the in1d mask is aligned with candidateSegments, so index
        # candidateSegments (the original indexed matchingDeltaSegments, which
        # has a different length -- see the same step in
        # _learnFeatureLocationPair for the correct form).
        candidateSegments = candidateSegments[np.in1d(
            cellsForCandidateSegments, remainingCells)]
        # Keep at most one segment pair per cell: the pair with the highest
        # combined potential overlap.
        onePerCellFilter = np2.argmaxMulti(
            prevLocationPotentialOverlaps[candidateSegments] +
            deltaPotentialOverlaps[candidateSegments],
            cellsForCandidateSegments)
        learningMatchingDeltaSegments = candidateSegments[onePerCellFilter]

        newDeltaSegmentCells = np.setdiff1d(remainingCells,
                                            cellsForCandidateSegments)

        # Reinforce / grow synapses on both halves of every learning segment
        # pair. The same segment indices are used in both connection
        # structures.
        for learningSegments in (learningActiveDeltaSegments,
                                 learningMatchingDeltaSegments):
            self._learn(self.internalConnections, self.rng, learningSegments,
                        prevActiveCells, prevActiveCells,
                        prevLocationPotentialOverlaps, self.initialPermanence,
                        self.sampleSize, self.permanenceIncrement,
                        self.permanenceDecrement, self.maxSynapsesPerSegment)
            self._learn(self.deltaConnections, self.rng, learningSegments,
                        deltaLocation, deltaLocation, deltaPotentialOverlaps,
                        self.initialPermanence, self.sampleSize,
                        self.permanenceIncrement, self.permanenceDecrement,
                        self.maxSynapsesPerSegment)

        # Remaining cells without a matching segment pair: grow a brand new
        # segment pair on each.
        numNewLocationSynapses = len(prevActiveCells)
        numNewDeltaSynapses = len(deltaLocation)

        if self.sampleSize != -1:
            numNewLocationSynapses = min(numNewLocationSynapses,
                                         self.sampleSize)
            numNewDeltaSynapses = min(numNewDeltaSynapses, self.sampleSize)

        if self.maxSynapsesPerSegment != -1:
            numNewLocationSynapses = min(numNewLocationSynapses,
                                         self.maxSynapsesPerSegment)
            # Fix: was min(numNewLocationSynapses, ...) -- a copy-paste bug
            # that capped the delta synapse count by the location count.
            numNewDeltaSynapses = min(numNewDeltaSynapses,
                                      self.maxSynapsesPerSegment)

        newPrevLocationSegments = self.internalConnections.createSegments(
            newDeltaSegmentCells)
        newDeltaSegments = self.deltaConnections.createSegments(
            newDeltaSegmentCells)

        # The two connection structures must stay in lockstep: segment i in
        # one is paired with segment i in the other.
        assert np.array_equal(newPrevLocationSegments, newDeltaSegments)

        self.internalConnections.growSynapsesToSample(newPrevLocationSegments,
                                                      prevActiveCells,
                                                      numNewLocationSynapses,
                                                      self.initialPermanence,
                                                      self.rng)
        self.deltaConnections.growSynapsesToSample(newDeltaSegments,
                                                   deltaLocation,
                                                   numNewDeltaSynapses,
                                                   self.initialPermanence,
                                                   self.rng)

    def _learnFeatureLocationPair(self, newLocation, featureLocationInput,
                                  featureLocationGrowthCandidates):
        """
    Grow / reinforce synapses between the location layer's dendrites and the
    input layer's active cells.

    @param newLocation (numpy array)
    Cells of this layer that should learn to be driven by the input

    @param featureLocationInput (numpy array)
    Active cells in the feature-location input layer

    @param featureLocationGrowthCandidates (numpy array)
    Input-layer cells that may receive new synapses
    """

        potentialOverlaps = self.featureLocationConnections.computeActivity(
            featureLocationInput)
        # Fix: use >= so a segment whose overlap exactly equals the threshold
        # counts as matching, consistent with _learnTransition.
        matchingSegments = np.where(
            potentialOverlaps >= self.learningThreshold)[0]

        # Cells with an active segment pair: reinforce the segment
        cellsForActiveSegments = self.featureLocationConnections.mapSegmentsToCells(
            self.activeFeatureLocationSegments)
        learningActiveSegments = self.activeFeatureLocationSegments[np.in1d(
            cellsForActiveSegments, newLocation)]
        remainingCells = np.setdiff1d(newLocation, cellsForActiveSegments)

        # Remaining cells with a matching segment pair: reinforce the best matching
        # segment pair.
        candidateSegments = self.featureLocationConnections.filterSegmentsByCell(
            matchingSegments, remainingCells)
        cellsForCandidateSegments = (self.featureLocationConnections.
                                     mapSegmentsToCells(candidateSegments))
        candidateSegments = candidateSegments[np.in1d(
            cellsForCandidateSegments, remainingCells)]
        # Keep at most one segment per cell: the one with the highest
        # potential overlap.
        onePerCellFilter = np2.argmaxMulti(
            potentialOverlaps[candidateSegments], cellsForCandidateSegments)
        learningMatchingSegments = candidateSegments[onePerCellFilter]

        newSegmentCells = np.setdiff1d(remainingCells,
                                       cellsForCandidateSegments)

        for learningSegments in (learningActiveSegments,
                                 learningMatchingSegments):
            self._learn(self.featureLocationConnections, self.rng,
                        learningSegments, featureLocationInput,
                        featureLocationGrowthCandidates, potentialOverlaps,
                        self.initialPermanence, self.sampleSize,
                        self.permanenceIncrement, self.permanenceDecrement,
                        self.maxSynapsesPerSegment)

        # Remaining cells with no matching segment: grow a new segment on each.
        numNewSynapses = len(featureLocationInput)

        if self.sampleSize != -1:
            numNewSynapses = min(numNewSynapses, self.sampleSize)

        if self.maxSynapsesPerSegment != -1:
            numNewSynapses = min(numNewSynapses, self.maxSynapsesPerSegment)

        newSegments = self.featureLocationConnections.createSegments(
            newSegmentCells)

        self.featureLocationConnections.growSynapsesToSample(
            newSegments, featureLocationGrowthCandidates, numNewSynapses,
            self.initialPermanence, self.rng)

    @staticmethod
    def _learn(connections, rng, learningSegments, activeInput,
               growthCandidates, potentialOverlaps, initialPermanence,
               sampleSize, permanenceIncrement, permanenceDecrement,
               maxSynapsesPerSegment):
        """
    Adjust synapse permanences and grow new synapses on existing segments.

    Note: despite appearances elsewhere in this class, this helper never
    creates segments; callers create segments separately.

    @param connections (SparseMatrixConnections)
    The connection structure to learn on

    @param rng (Random)
    Random number generator used when sampling growth candidates

    @param learningSegments (numpy array)
    Segments that should learn

    @param activeInput (numpy array)
    Active input bits; synapses to these are incremented, synapses to the
    rest are decremented (note the negated permanenceDecrement below)

    @param growthCandidates (numpy array)
    Input bits eligible to receive new synapses

    @param potentialOverlaps (numpy array)
    Per-segment overlap counts, indexable by segment number

    @param initialPermanence (float)
    Permanence assigned to newly grown synapses

    @param sampleSize (int)
    Target number of active synapses per segment, or -1 for unlimited

    @param permanenceIncrement (float)
    @param permanenceDecrement (float)

    @param maxSynapsesPerSegment (int)
    Hard cap on synapses per segment, or -1 for unlimited
    """

        # Learn on existing segments
        connections.adjustSynapses(learningSegments, activeInput,
                                   permanenceIncrement, -permanenceDecrement)

        # Grow new synapses. Calculate "maxNew", the maximum number of synapses to
        # grow per segment. "maxNew" might be a number or it might be a list of
        # numbers.
        if sampleSize == -1:
            maxNew = len(growthCandidates)
        else:
            # Per-segment: only grow enough to bring each segment up to
            # sampleSize active synapses.
            maxNew = sampleSize - potentialOverlaps[learningSegments]

        if maxSynapsesPerSegment != -1:
            synapseCounts = connections.mapSegmentsToSynapseCounts(
                learningSegments)
            numSynapsesToReachMax = maxSynapsesPerSegment - synapseCounts
            # Element-wise min(maxNew, numSynapsesToReachMax); works whether
            # maxNew is a scalar or an array.
            maxNew = np.where(maxNew <= numSynapsesToReachMax, maxNew,
                              numSynapsesToReachMax)

        connections.growSynapsesToSample(learningSegments, growthCandidates,
                                         maxNew, initialPermanence, rng)

    def getActiveCells(self):
        return self.activeCells