Exemplo n.º 1
0
    def testCheckpointLearned(self):
        """Checkpoint a trained TM and check the restored copy stays in sync.

        A TM is trained on the first three of five generated sequences, then
        copied via a pickle round-trip followed by saveToFile/loadFromFile.
        The remaining two sequences are fed to both the original and the
        restored TM, asserting after every input that their states and
        outputs are equal. A ``None`` entry in a sequence triggers a reset.
        """
        original = BacktrackingTM(numberOfCols=100,
                                  cellsPerColumn=12,
                                  verbosity=VERBOSITY)
        seqs = [self.generateSequence() for _ in xrange(5)]

        # Training phase on the first three sequences only.
        trainingInputs = list(itertools.chain.from_iterable(seqs[:3]))
        for pattern in trainingInputs:
            if pattern is not None:
                original.compute(pattern, True, True)
            else:
                original.reset()

        # Checkpoint: pickle round-trip, then overwrite state from the file.
        checkpointPath = os.path.join(self._tmpDir, 'a')
        original.saveToFile(checkpointPath)
        restored = pickle.loads(pickle.dumps(original))
        restored.loadFromFile(checkpointPath)

        # The restored TM must match immediately after loading.
        self.assertTMsEqual(original, restored)

        # Drive both TMs with the held-out sequences in lockstep.
        heldOutInputs = list(itertools.chain.from_iterable(seqs[3:]))
        for pattern in heldOutInputs:
            if pattern is None:
                original.reset()
                restored.reset()
                continue

            outOriginal = original.compute(pattern, True, True)
            outRestored = restored.compute(pattern, True, True)

            self.assertTMsEqual(original, restored)
            self.assertTrue(numpy.array_equal(outOriginal, outRestored))
Exemplo n.º 2
0
    def testCheckpointMiddleOfSequence2(self):
        """More complex test of checkpointing in the middle of a sequence.

        Two identically configured TMs are fed the first ``splitIndex``
        records from ``data/tm_input.csv`` and must produce identical outputs.
        Each TM is then checkpointed into a copy (pickle round-trip followed
        by saveToFile/loadFromFile). All four TMs are fed the remaining
        records; every output must match that of ``tm1``, and the TMs must
        end in equal states.
        """
        import ast

        # Positional BacktrackingTM constructor arguments shared by both TMs.
        tmArgs = (2048, 32, 0.21, 0.5, 11, 20, 0.1, 0.1, 1.0, 0.0,
                  14, False, 5, 2, False, 1960, 0, False, 3, 10, 5,
                  0, 32, 128, 32, 'normal')
        tm1 = BacktrackingTM(*tmArgs)
        tm2 = BacktrackingTM(*tmArgs)

        # Each line is wrapped in brackets and parsed as a Python list literal
        # (presumably comma-separated integers). ast.literal_eval accepts only
        # literals, unlike the previous eval(), and no unused csv.reader is
        # created.
        with open(resource_filename(__name__, 'data/tm_input.csv'),
                  'r') as fin:
            records = [
                numpy.array(ast.literal_eval('[' + line.strip() + ']'),
                            dtype='int32')
                for line in fin
            ]

        # Number of records fed before checkpointing.
        splitIndex = 250

        for record in records[:splitIndex]:
            output1 = tm1.compute(record, True, True)
            output2 = tm2.compute(record, True, True)
            self.assertTrue(numpy.array_equal(output1, output2))

        print('Serializing and deserializing models.')

        def checkpoint(tm, fileName):
            """Return a copy of *tm* restored via pickle plus loadFromFile."""
            savePath = os.path.join(self._tmpDir, fileName)
            tm.saveToFile(savePath)
            restored = pickle.loads(pickle.dumps(tm))
            restored.loadFromFile(savePath)
            return restored

        tm3 = checkpoint(tm1, 'tm1.bin')
        tm4 = checkpoint(tm2, 'tm2.bin')

        self.assertTMsEqual(tm1, tm3)
        self.assertTMsEqual(tm2, tm4)

        # Continue feeding all four TMs; every output must agree with tm1's.
        for record in records[splitIndex:]:
            out1 = tm1.compute(record, True, True)
            out2 = tm2.compute(record, True, True)
            out3 = tm3.compute(record, True, True)
            out4 = tm4.compute(record, True, True)

            self.assertTrue(numpy.array_equal(out1, out2))
            self.assertTrue(numpy.array_equal(out1, out3))
            self.assertTrue(numpy.array_equal(out1, out4))

        self.assertTMsEqual(tm1, tm2)
        self.assertTMsEqual(tm1, tm3)
        self.assertTMsEqual(tm2, tm4)