예제 #1
0
    def test_save_and_load(self):
        """Check that a model saved to disk and loaded back gives the same loss.

        A model is fitted and scored, saved to a file, and loaded into a
        second ``CatBoostAlgorithm`` instance; the loss of the second
        instance on the same data must match the first one.
        """
        import os  # local import: only needed to build the temporary path

        metric = Metric({"name": "logloss"})
        cat = CatBoostAlgorithm(self.params)
        cat.fit(self.X, self.y)
        y_predicted = cat.predict(self.X)
        loss = metric(self.y, y_predicted)

        # Write into a fresh temporary directory instead of the name of a
        # still-open NamedTemporaryFile: on Windows such a file cannot be
        # opened a second time by name while the first handle is open.
        with tempfile.TemporaryDirectory() as tmp_dir:
            model_path = os.path.join(tmp_dir, "model")
            cat.save(model_path)
            cat2 = CatBoostAlgorithm(self.params)
            # A newly constructed instance gets its own uid and a model object.
            self.assertNotEqual(cat.uid, cat2.uid)
            self.assertIsNotNone(cat2.model)
            cat2.load(model_path)

            y_predicted = cat2.predict(self.X)
            loss2 = metric(self.y, y_predicted)
            assert_almost_equal(loss, loss2)
예제 #2
0
    def test_save_and_load(self):
        """Check that a model saved to disk and loaded back gives the same loss.

        A model is fitted and scored, saved to a randomly named file in the
        system temporary directory, and loaded into a second
        ``CatBoostAlgorithm`` instance; the loss of the second instance on
        the same data must match the first one.
        """
        metric = Metric({"name": "logloss"})
        cat = CatBoostAlgorithm(self.params)
        cat.fit(self.X, self.y)
        y_predicted = cat.predict(self.X)
        loss = metric(self.y, y_predicted)

        # Random name inside the temp dir; the file itself is created by save().
        filename = os.path.join(tempfile.gettempdir(), os.urandom(12).hex())

        try:
            cat.save(filename)
            cat2 = CatBoostAlgorithm(self.params)
            # A newly constructed instance gets its own uid and a model object.
            self.assertNotEqual(cat.uid, cat2.uid)
            self.assertIsNotNone(cat2.model)
            cat2.load(filename)
        finally:
            # Finished with the file: delete it even when saving, loading or
            # an assertion above failed, so no stray file is left behind.
            if os.path.exists(filename):
                os.remove(filename)

        y_predicted = cat2.predict(self.X)
        loss2 = metric(self.y, y_predicted)
        assert_almost_equal(loss, loss2)
    def test_save_and_load(self):
        """Check that a model round-tripped through ``save()``/``load()`` matches.

        A model is fitted and scored, then ``save()`` is called without
        arguments and its return value is passed to ``load()`` on a second
        instance built from empty params. After loading, the second instance
        must carry the same uid and reproduce the same loss on the same data.
        """
        metric = Metric({"name": "logloss"})
        cat = CatBoostAlgorithm(self.params)
        cat.fit(self.X, self.y)
        y_predicted = cat.predict(self.X)
        loss = metric(self.y, y_predicted)

        json_desc = cat.save()
        cat2 = CatBoostAlgorithm({})
        # Before loading: distinct uid, but a model object already exists.
        self.assertNotEqual(cat.uid, cat2.uid)
        self.assertIsNotNone(cat2.model)
        cat2.load(json_desc)
        # Loading the description is expected to restore the original uid.
        self.assertEqual(cat.uid, cat2.uid)

        y_predicted = cat2.predict(self.X)
        loss2 = metric(self.y, y_predicted)
        assert_almost_equal(loss, loss2)