def test_save_and_load_model(self):
    """Check predictions survive a save, a session reset, and a reload.

    Fits a 10-component TensorRec model, records its predictions and ranks,
    saves it to ``self.test_dir`` and asserts that:

    * the still-live model returns identical outputs after ``save_model``;
    * a model restored with ``TensorRec.load_model`` -- after the session has
      been cleared and the default graph reset -- returns identical outputs.
    """
    # Both prediction entry points take the same feature keyword arguments.
    features = dict(user_features=self.user_features,
                    item_features=self.item_features)

    model = TensorRec(n_components=10)
    model.fit(self.interactions, self.user_features, self.item_features,
              epochs=10)
    predictions = model.predict(**features)
    ranks = model.predict_rank(**features)

    model.save_model(directory_path=self.test_dir)

    # Check that, after saving, the same predictions come back
    predictions_after_save = model.predict(**features)
    ranks_after_save = model.predict_rank(**features)
    self.assertTrue((predictions == predictions_after_save).all())
    self.assertTrue((ranks == ranks_after_save).all())

    # Blow away the session so the reload below cannot lean on live state
    set_session(None)
    tf.reset_default_graph()

    # Reload the model, predict, and check for equal predictions
    new_model = TensorRec.load_model(directory_path=self.test_dir)
    new_predictions = new_model.predict(**features)
    new_ranks = new_model.predict_rank(**features)
    self.assertTrue((predictions == new_predictions).all())
    self.assertTrue((ranks == new_ranks).all())
def test_save_and_load_model(self):
    """Round-trip a fitted model through save/load across a session reset.

    The outputs of ``predict`` and ``predict_rank`` captured right after
    fitting are the baseline. They must be reproduced element-for-element by
    the same model after ``save_model`` and by the model returned from
    ``TensorRec.load_model`` once the session and default graph are reset.
    """

    def predict_both(recommender):
        # Run both prediction entry points on the fixture features,
        # predictions first and ranks second.
        scores = recommender.predict(user_features=self.user_features,
                                     item_features=self.item_features)
        positions = recommender.predict_rank(user_features=self.user_features,
                                             item_features=self.item_features)
        return scores, positions

    model = TensorRec(n_components=10)
    model.fit(self.interactions, self.user_features, self.item_features,
              epochs=10)
    predictions, ranks = predict_both(model)

    model.save_model(directory_path=self.test_dir)

    # Check that, after saving, the same predictions come back
    predictions_after_save, ranks_after_save = predict_both(model)
    self.assertTrue((predictions == predictions_after_save).all())
    self.assertTrue((ranks == ranks_after_save).all())

    # Blow away the session
    set_session(None)
    tf.reset_default_graph()

    # Reload the model, predict, and check for equal predictions
    new_model = TensorRec.load_model(directory_path=self.test_dir)
    new_predictions, new_ranks = predict_both(new_model)
    self.assertTrue((predictions == new_predictions).all())
    self.assertTrue((ranks == new_ranks).all())
def test_save_and_load_model_same_session(self):
    """Check a model reloaded in the same session reproduces its predictions.

    Fits a 10-component TensorRec model, records its predictions and ranks,
    saves it to ``self.test_dir`` and immediately loads it back with
    ``TensorRec.load_model`` without resetting the session or graph. The
    reloaded model must return element-wise identical predictions and ranks.
    """
    model = TensorRec(n_components=10)
    model.fit(self.interactions, self.user_features, self.item_features,
              epochs=10)

    predictions = model.predict(user_features=self.user_features,
                                item_features=self.item_features)
    ranks = model.predict_rank(user_features=self.user_features,
                               item_features=self.item_features)

    model.save_model(directory_path=self.test_dir)

    # Reload the model, predict, and check for equal predictions
    new_model = TensorRec.load_model(directory_path=self.test_dir)
    new_predictions = new_model.predict(user_features=self.user_features,
                                        item_features=self.item_features)
    new_ranks = new_model.predict_rank(user_features=self.user_features,
                                       item_features=self.item_features)

    # Compare element-wise and then reduce. Comparing `a.all()` with
    # `b.all()` only checks that both arrays reduce to the same truthiness,
    # which holds for any two all-non-zero arrays whatever their values.
    self.assertTrue((predictions == new_predictions).all())
    self.assertTrue((ranks == new_ranks).all())
def test_save_and_load_model_same_session(self):
    """Save a fitted model and reload it without resetting the session.

    The predictions and ranks from the original model are the baseline; the
    model returned by ``TensorRec.load_model`` from ``self.test_dir`` must
    match them element-for-element.
    """
    # Both prediction entry points take the same feature keyword arguments.
    features = dict(user_features=self.user_features,
                    item_features=self.item_features)

    model = TensorRec(n_components=10)
    model.fit(self.interactions, self.user_features, self.item_features,
              epochs=10)
    predictions = model.predict(**features)
    ranks = model.predict_rank(**features)

    model.save_model(directory_path=self.test_dir)

    # Reload the model, predict, and check for equal predictions
    new_model = TensorRec.load_model(directory_path=self.test_dir)
    new_predictions = new_model.predict(**features)
    new_ranks = new_model.predict_rank(**features)

    self.assertTrue((predictions == new_predictions).all())
    self.assertTrue((ranks == new_ranks).all())