示例#1
0
def getImageData(numInputVectors):
  """
  Build one batch of input vectors from binarized natural images.

  An ImageSparseNet is constructed only as a helper: it loads the image stack
  from the MATLAB file and draws the batch. No training is performed here.

  :param numInputVectors: passed to ImageSparseNet as its ``batchSize``
  :return: the transpose of what ``network._getDataBatch`` returns for the
           binarized images (presumably one input vector per row; confirm
           against ImageSparseNet._getDataBatch)
  """
  from htmresearch.algorithms.image_sparse_net import ImageSparseNet

  matFile = "../sparse_net/data/IMAGES.mat"
  matVariable = "IMAGES"

  # Only batchSize depends on the caller; every other setting is fixed.
  sparseNetParams = dict(
    filterDim=64,
    outputDim=64,
    batchSize=numInputVectors,
    numLcaIterations=75,
    learningRate=2.0,
    decayCycle=100,
    learningRateDecay=1.0,
    lcaLearningRate=0.1,
    thresholdDecay=0.95,
    minThreshold=1.0,
    thresholdType='soft',
    verbosity=0,  # can be changed to print training loss
    showEvery=500,
    seed=42,
  )
  network = ImageSparseNet(**sparseNetParams)

  print("Loading training data...")
  images = network.loadMatlabImages(matFile, matVariable)

  # The image index is the last of the three axes; binarize slice by slice.
  _, _, numImages = images.shape
  binaryImages = np.zeros(images.shape)
  for idx in range(numImages):
    binaryImages[:, :, idx] = convertToBinaryImage(images[:, :, idx])

  return network._getDataBatch(binaryImages).T
def getImageData(numInputVectors):
    """
    Build one batch of input vectors from binarized natural images.

    An ImageSparseNet is constructed only as a helper: it loads the image
    stack from the MATLAB file and draws the batch. No training is performed
    here.

    :param numInputVectors: passed to ImageSparseNet as its ``batchSize``
    :return: the transpose of what ``network._getDataBatch`` returns for the
             binarized images (presumably one input vector per row; confirm
             against ImageSparseNet._getDataBatch)
    """
    from htmresearch.algorithms.image_sparse_net import ImageSparseNet

    matFile = "../sparse_net/data/IMAGES.mat"
    matVariable = "IMAGES"

    # Only batchSize depends on the caller; every other setting is fixed.
    sparseNetParams = dict(
        filterDim=64,
        outputDim=64,
        batchSize=numInputVectors,
        numLcaIterations=75,
        learningRate=2.0,
        decayCycle=100,
        learningRateDecay=1.0,
        lcaLearningRate=0.1,
        thresholdDecay=0.95,
        minThreshold=1.0,
        thresholdType='soft',
        verbosity=0,  # can be changed to print training loss
        showEvery=500,
        seed=42,
    )
    network = ImageSparseNet(**sparseNetParams)

    print("Loading training data...")
    images = network.loadMatlabImages(matFile, matVariable)

    # The image index is the last of the three axes; binarize slice by slice.
    _, _, numImages = images.shape
    binaryImages = np.zeros(images.shape)
    for idx in range(numImages):
        binaryImages[:, :, idx] = convertToBinaryImage(images[:, :, idx])

    return network._getDataBatch(binaryImages).T
def runExperiment():
    """
    Train an ImageSparseNet, plot its results and, if possible, round-trip it
    through serialization.

    Steps:
      1. Build the network from DEFAULT_SPARSENET_PARAMS and load the images
         named DATA_NAME from DATA_PATH.
      2. Train for 5000 iterations, then write the loss and basis plots.
      3. Only when ``capnp`` is truthy: write the network to
         SERIALIZATION_PATH, read it back, check that the restored network
         compares equal to the trained one, train the restored network for
         another 5000 iterations and write a second pair of plots.

    :raises ValueError: if the restored network compares unequal to the
                        network that was saved
    """
    print("Creating network...")
    network = ImageSparseNet(**DEFAULT_SPARSENET_PARAMS)

    print("Loading training data...")
    images = network.loadMatlabImages(DATA_PATH, DATA_NAME)
    print("")

    print("Training {0}...".format(network))
    network.train(images, numIterations=5000)

    print("Saving loss history and function basis...")
    network.plotLoss(filename=LOSS_HISTORY_PATH)
    network.plotBasis(filename=BASIS_FUNCTIONS_PATH)

    # Without capnp there is nothing to serialize with, so stop here.
    if not capnp:
        return

    print("")
    print("Saving model...")
    savedProto = SparseNetProto.new_message()
    network.write(savedProto)
    with open(SERIALIZATION_PATH, 'wb') as outFile:
        savedProto.write(outFile)

    print("Loading model...")
    # The file stays open until the end of the run.
    # NOTE(review): presumably the proto reader may still depend on the open
    # file; confirm before narrowing this block.
    with open(SERIALIZATION_PATH, 'rb') as inFile:
        loadedProto = SparseNetProto.read(inFile)
        restoredNetwork = ImageSparseNet.read(loadedProto)

        print("Checking that loaded model is the same as before...")
        if restoredNetwork != network:
            raise ValueError("Model is different!")
        print("Model is the same.")

        print("")
        print("Training {0} again...".format(network))
        restoredNetwork.train(images, numIterations=5000)

        print("Saving loss history and function basis...")
        restoredNetwork.plotLoss(filename=LOSS_HISTORY_PATH2)
        restoredNetwork.plotBasis(filename=BASIS_FUNCTIONS_PATH2)
def runExperiment():
  """
  Train an ImageSparseNet, plot its results and, if possible, round-trip it
  through serialization.

  Steps:
    1. Build the network from DEFAULT_SPARSENET_PARAMS and load the images
       named DATA_NAME from DATA_PATH.
    2. Train for 5000 iterations, then write the loss and basis plots.
    3. Only when ``capnp`` is truthy: write the network to
       SERIALIZATION_PATH, read it back, check that the restored network
       compares equal to the trained one, train the restored network for
       another 5000 iterations and write a second pair of plots.

  :raises ValueError: if the restored network compares unequal to the network
                      that was saved
  """
  print("Creating network...")
  network = ImageSparseNet(**DEFAULT_SPARSENET_PARAMS)

  print("Loading training data...")
  images = network.loadMatlabImages(DATA_PATH, DATA_NAME)
  print("")

  print("Training {0}...".format(network))
  network.train(images, numIterations=5000)

  print("Saving loss history and function basis...")
  network.plotLoss(filename=LOSS_HISTORY_PATH)
  network.plotBasis(filename=BASIS_FUNCTIONS_PATH)

  # Without capnp there is nothing to serialize with, so stop here.
  if not capnp:
    return

  print("")
  print("Saving model...")
  savedProto = SparseNetProto.new_message()
  network.write(savedProto)
  with open(SERIALIZATION_PATH, 'wb') as outFile:
    savedProto.write(outFile)

  print("Loading model...")
  # The file stays open until the end of the run.
  # NOTE(review): presumably the proto reader may still depend on the open
  # file; confirm before narrowing this block.
  with open(SERIALIZATION_PATH, 'rb') as inFile:
    loadedProto = SparseNetProto.read(inFile)
    restoredNetwork = ImageSparseNet.read(loadedProto)

    print("Checking that loaded model is the same as before...")
    if restoredNetwork != network:
      raise ValueError("Model is different!")
    print("Model is the same.")

    print("")
    print("Training {0} again...".format(network))
    restoredNetwork.train(images, numIterations=5000)

    print("Saving loss history and function basis...")
    restoredNetwork.plotLoss(filename=LOSS_HISTORY_PATH2)
    restoredNetwork.plotBasis(filename=BASIS_FUNCTIONS_PATH2)