Example #1
# Assumed imports: NumPy and PyTorch, plus Fonduer's built-in discriminative
# models (import path per the Fonduer 0.7.x learning API).
import numpy as np
import torch

from fonduer.learning import LogisticRegression, LSTM, SparseLogisticRegression


def load_model_and_predict(algorithm_chosen, featurizer_output):
    # `config` (providing base_dir) and the positive-label constant `TRUE`
    # are assumed to be defined elsewhere in the source module.
    session = featurizer_output['session']
    cands = featurizer_output['candidate_variable']
    featurizer = featurizer_output['featurizer_variable']

    if algorithm_chosen == 'logistic_regression':
        disc_model = LogisticRegression()
    elif algorithm_chosen == 'sparse_logistic_regression':
        disc_model = SparseLogisticRegression()
    else:
        disc_model = LSTM()

    # Manually load settings and cardinality from a saved trained model.
    checkpoint = torch.load(config.base_dir + '/checkpoints/' +
                            algorithm_chosen)
    disc_model.settings = checkpoint["config"]
    disc_model.cardinality = checkpoint["cardinality"]

    # Build a model using the loaded settings and cardinality.
    disc_model._build_model()

    disc_model.load(model_file=algorithm_chosen,
                    save_dir=config.base_dir + '/checkpoints')

    cand_list = [session.query(cands[0]).all()]
    cand_feature_matrix = featurizer.get_feature_matrices(cand_list)

    test_score = disc_model.predict((cand_list[0], cand_feature_matrix[0]),
                                    b=0.5,
                                    pos_label=TRUE)
    true_pred = [
        cand_list[0][_] for _ in np.nditer(np.where(test_score == TRUE))
    ]
    return true_pred
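
For context, a minimal usage sketch of the helper above, assuming the surrounding pipeline has already parsed documents and extracted candidates; the connection string and the MyCandidate class are placeholders, not names from the original code.

from fonduer import Meta
from fonduer.features import Featurizer

# Placeholder setup; in the original pipeline these objects already exist.
session = Meta.init("postgresql://localhost:5432/my_fonduer_db").Session()
candidate_classes = [MyCandidate]  # hypothetical candidate class built upstream
featurizer = Featurizer(session, candidate_classes)

featurizer_output = {
    "session": session,
    "candidate_variable": candidate_classes,
    "featurizer_variable": featurizer,
}
positives = load_model_and_predict("logistic_regression", featurizer_output)
print(f"{len(positives)} candidates predicted positive")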
Example #2

    # Method of an MLflow pyfunc wrapper (an mlflow.pyfunc.PythonModel
    # subclass); CONN_STRING, PARALLEL, MODEL_TYPE, and logger are
    # module-level names defined elsewhere in the source file.
    def load_context(self, context: PythonModelContext) -> None:
        # Configure logging for Fonduer
        init_logging(log_dir="logs")
        logger.info("loading context")

        pyfunc_conf = _get_flavor_configuration(model_path=self.model_path,
                                                flavor_name=pyfunc.FLAVOR_NAME)
        conn_string = pyfunc_conf.get(CONN_STRING, None)
        if conn_string is None:
            raise RuntimeError("conn_string is missing from MLmodel file.")
        self.parallel = pyfunc_conf.get(PARALLEL, 1)
        session = Meta.init(conn_string).Session()

        logger.info("Getting parser")
        self.corpus_parser = self._get_parser(session)
        logger.info("Getting mention extractor")
        self.mention_extractor = self._get_mention_extractor(session)
        logger.info("Getting candidate extractor")
        self.candidate_extractor = self._get_candidate_extractor(session)
        candidate_classes = self.candidate_extractor.candidate_classes

        self.model_type = pyfunc_conf.get(MODEL_TYPE, "discriminative")
        if self.model_type == "discriminative":
            # Discriminative path: restore the saved feature keys, then rebuild
            # and load the trained discriminative model.
            self.featurizer = Featurizer(session, candidate_classes)
            with open(os.path.join(self.model_path, "feature_keys.pkl"),
                      "rb") as f:
                key_names = pickle.load(f)
            self.featurizer.drop_keys(key_names)
            self.featurizer.upsert_keys(key_names)

            disc_model = LogisticRegression()

            # Workaround to https://github.com/HazyResearch/fonduer/issues/208
            checkpoint = torch.load(
                os.path.join(self.model_path, "best_model.pt"))
            disc_model.settings = checkpoint["config"]
            disc_model.cardinality = checkpoint["cardinality"]
            disc_model._build_model()

            disc_model.load(model_file="best_model.pt",
                            save_dir=self.model_path)
            self.disc_model = disc_model
        else:
            # Generative path: restore the saved labeling-function keys and one
            # label model per candidate class.
            self.labeler = Labeler(session, candidate_classes)
            with open(os.path.join(self.model_path, "labeler_keys.pkl"),
                      "rb") as f:
                key_names = pickle.load(f)
            self.labeler.drop_keys(key_names)
            self.labeler.upsert_keys(key_names)

            self.gen_models = [
                LabelModel.load(
                    os.path.join(self.model_path, _.__name__ + ".pkl"))
                for _ in candidate_classes
            ]
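
To show where a load_context like this lives, here is a hedged skeleton of an mlflow.pyfunc.PythonModel wrapper; the class name, the predict body, and the model URI are illustrative placeholders rather than the actual implementation behind this snippet.

import mlflow.pyfunc


class FonduerWrapper(mlflow.pyfunc.PythonModel):  # hypothetical class name
    def load_context(self, context):
        # One-time heavy setup (DB session, parser, extractors, trained
        # models), as in the method above.
        ...

    def predict(self, context, model_input):
        # Run the extraction pipeline on the incoming documents and return
        # the predictions; body omitted here.
        ...


# Once such a wrapper has been saved or logged, MLflow reloads it generically:
# model = mlflow.pyfunc.load_model("models:/fonduer_model/1")  # placeholder URI
# predictions = model.predict(input_dataframe)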