Exemplo n.º 1
0
def get_oss_saved_model_type_and_estimator(model_name, project):
    """Get the model type and estimator name of a model saved on OSS.

    The model can be one of:
    1. PAI TensorFlow model: OSS has the meta file ``tensorflow_model_desc``
    2. XGBoost model: OSS has the model file ``xgboost_model_desc``
    3. PAI ML model: neither of the marker files above exists

    Args:
        model_name: the model to get info, used as the OSS key prefix
        project: current odps project (not used by this function, kept for
            interface compatibility)

    Returns:
        A ``(model_type, estimator_name)`` tuple:
        - TensorFlow: ``(EstimatorType.TENSORFLOW, <estimator>)`` where
          ``<estimator>`` is the first line of the
          ``tensorflow_model_desc_estimator`` object exactly as returned by
          ``readline()`` (a trailing newline, if present, is kept).
        - XGBoost: ``(EstimatorType.XGBOOST, "xgboost")``
        - otherwise: ``(EstimatorType.PAIML, "")``
    """
    # Local imports keep this block self-contained.
    import os
    import shutil
    import tempfile

    # FIXME(typhoonzero): if the model not exist on OSS, assume it's a random
    # forest model should use a general method to fetch the model and see the
    # model type.
    bucket = oss.get_models_bucket()

    if bucket.object_exists(model_name + "/tensorflow_model_desc"):
        # Download into a private temporary directory instead of a fixed
        # file name in the current working directory: concurrent calls no
        # longer overwrite each other's file and nothing is left behind.
        tmp_dir = tempfile.mkdtemp()
        try:
            local_path = os.path.join(tmp_dir, "tmp_estimator_name")
            bucket.get_object_to_file(
                model_name + "/tensorflow_model_desc_estimator", local_path)
            with open(local_path) as desc_file:
                estimator = desc_file.readline()
        finally:
            shutil.rmtree(tmp_dir, ignore_errors=True)
        return EstimatorType.TENSORFLOW, estimator

    if bucket.object_exists(model_name + "/xgboost_model_desc"):
        return EstimatorType.XGBOOST, "xgboost"

    return EstimatorType.PAIML, ""
Exemplo n.º 2
0
    def test_save_load_oss(self):
        """Save a model to OSS, load it back and compare the metadata.

        Loading is also attempted before anything has been saved, right
        after the target OSS directory was deleted, and is expected to
        raise.
        """
        models_bucket = oss.get_models_bucket()
        expected_meta = {"model_params": {"n_classes": 3}}
        original = Model(EstimatorType.XGBOOST, expected_meta)

        remote_dir = "unknown/model_test_dnn_classifier/"
        remote_uri = "oss://%s/%s" % (models_bucket.bucket_name, remote_dir)

        # Remove whatever a previous run may have left under the directory.
        oss.delete_oss_dir_recursive(models_bucket, remote_dir)

        # Nothing has been saved yet, so loading is expected to fail.
        with self.assertRaises(Exception):
            with temp_file.TemporaryDirectory() as workdir:
                Model.load_from_oss(remote_uri, workdir)

        # Save the model ...
        with temp_file.TemporaryDirectory() as workdir:
            original.save_to_oss(remote_uri, workdir)

        # ... then load it back and check the metadata round-trips.
        with temp_file.TemporaryDirectory() as workdir:
            restored = Model.load_from_oss(remote_uri, workdir)
        self.assertEqual(restored._meta, expected_meta)
Exemplo n.º 3
0
def clean_oss_model_path(oss_path):
    """Delete ``oss_path`` from the models bucket.

    Delegates to ``oss.delete_oss_dir_recursive`` using the bucket returned
    by ``oss.get_models_bucket()``.

    Args:
        oss_path: path inside the models bucket to delete.
    """
    oss.delete_oss_dir_recursive(oss.get_models_bucket(), oss_path)