コード例 #1
0
ファイル: test_io_utils.py プロジェクト: junshi15/gdmix
 def testLoadModel(self):
     """Check that linear models loaded from avro match the expected models.

     The model file is loaded twice: once with the full feature file and once
     with the short feature file. Each loaded model is compared element-wise
     against the corresponding expected model.
     """
     models = load_linear_models_from_avro(self.model_file, self.feature_file)
     # Compare counts first: zip() alone silently ignores missing or extra
     # models, so an empty result would make the loop (and the test) pass.
     self.assertEqual(len(models), len(self.expected_models))
     for model, expected in zip(models, self.expected_models):
         self.assertAllEqual(model, expected)
     short_models = load_linear_models_from_avro(self.model_file, self.short_feature_file)
     self.assertEqual(len(short_models), len(self.expected_short_models))
     for model, expected in zip(short_models, self.expected_short_models):
         self.assertAllEqual(model, expected)
コード例 #2
0
 def _check_model(self, coefficients, model_dir, feature_file):
     """
     Assert that the first model stored in a model directory matches the expected coefficients.
     :param coefficients: Expected coefficients
     :param model_dir: Model directory containing the "part-00000.avro" model file
     :param feature_file: Feature file passed to load_linear_models_from_avro together with the model file
     :return: None
     """
     model_file = os.path.join(model_dir, "part-00000.avro")
     # Only the first model loaded from the file is compared.
     model = load_linear_models_from_avro(model_file, feature_file)[0]
     self.assertAllClose(coefficients, model, msg='models mismatch')
コード例 #3
0
 def _load_model(self, catch_exception=False):
     """ Load model from avro file.

     Looks for exactly one ``*.avro`` file directly under ``self.checkpoint_path``
     and loads the first model from it using ``self.feature_file``.

     :param catch_exception: If True, a missing checkpoint path or a directory
         that does not contain exactly one avro file yields None instead of
         raising.
     :return: The loaded model, or None if nothing was loaded and
         catch_exception is True. When ``self.feature_bag_name`` is None the
         model is treated as intercept-only and passed through add_dummy_weight.
     :raises ValueError: If zero or multiple avro files are found and
         catch_exception is False.
     :raises FileNotFoundError: If the checkpoint path is unset/empty or does not
         exist and catch_exception is False.
     """
     model = None
     logging("Loading model from {}".format(self.checkpoint_path))
     # A falsy checkpoint_path (e.g. None or "") is handled like a missing path.
     model_exist = self.checkpoint_path and tf1.io.gfile.exists(self.checkpoint_path)
     if model_exist:
         model_file = tf1.io.gfile.glob("{}/*.avro".format(self.checkpoint_path))
         if len(model_file) == 1:
             model = load_linear_models_from_avro(model_file[0], self.feature_file)[0]
         elif not catch_exception:
             raise ValueError("Load model failed, no model file or multiple model"
                              " files found in the model directory {}".format(self.checkpoint_path))
     elif not catch_exception:
         raise FileNotFoundError("checkpoint path {} doesn't exist".format(self.checkpoint_path))
     if self.feature_bag_name is None and model is not None:
         # intercept only model, add a dummy weight.
         model = add_dummy_weight(model)
     return model