Example 1
File: el.py Project: zxlzr/nordlys
    def link(self, query, qid=""):
        """Performs entity linking for the query.

        :param query: query string
        :param qid: query ID
        :return: annotated query
        """
        PLOGGER.info("Linking query " + qid + " [" + query + "] ")
        q = Query(query, qid)
        linker = self.__get_linker(q)
        if self.__config["step"] == "ranking":
            res = linker.rank_ens()
        else:
            linked_ens = linker.link()
            res = {
                "query": q.raw_query,
                "processed_query": q.query,
                "results": linked_ens
            }
        return res
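For the linking step (the else branch above), the returned dictionary has the shape built in the code; an illustrative instance, with a made-up query and an empty results list (the structure of the "results" entries depends on the linker used):

res = {
    "query": "obama net worth",
    "processed_query": "obama net worth",
    "results": []   # output of linker.link()
}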
Example 2
    def _second_pass_scoring(self, res1, scorer):
        """Returns second-pass scoring of documents.

        :param res1: first pass results
        :param scorer: scorer object
        :return: RetrievalResults object
        """
        PLOGGER.debug("\tSecond pass scoring... ", )
        for field in self.__get_fields():
            self.__elastic.multi_termvector(list(res1.keys()), field)

        res2 = {}
        for doc_id in res1.keys():
            res2[doc_id] = {
                "score": scorer.score_doc(doc_id),
                "fields": res1[doc_id].get("fields", {})
            }
        PLOGGER.debug("done")
        return res2
Example 3
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("collection", help="name of the collection")
    parser.add_argument("doc_id", help="doc_id to be looked up")
    args = parser.parse_args()

    coll = args.collection
    doc_id = args.doc_id

    mongo = Mongo(MONGO_HOST, MONGO_DB, coll)

    # currently, a single operation (lookup) is supported
    res = mongo.find_by_id(doc_id)
    if res is None:
        PLOGGER.info("Document ID " + doc_id + " cannot be found")
    else:
        mongo.print_doc(res)
Example 4
    def to_json(self, file_name=None):
        """Converts instance to the JSON format.

        :param file_name: (string) optional file to write the JSON dump to
        :return: JSON dump of the instance
        """
        json_ins = {
            self.__id: {
                "target": self.target,
                "score": self.score,
                "features": self.__features,
                "properties": self.__properties
            }
        }
        if file_name is not None:
            PLOGGER.info("writing instance \"" + str(self.__id) + "\" to " +
                         file_name + "...")
            with open(file_name, "w") as out:
                json.dump(json_ins, out, indent=4)
        return json_ins
Example 5
    def load_fb2dbp_mapping(self):
        """Checks Freebase IDs that are mapped to more than one entity and keeps only one of them."""
        mappings = defaultdict(list)
        fb2dbp_39 = self.read_fb2dbp_file(is_39=True)
        fb2dbp = self.read_fb2dbp_file()

        for fb_id, dbp_ids in fb2dbp.items():
            if len(dbp_ids) > 1:
                dbp_ids_39 = fb2dbp_39.get(fb_id, None)
                dbp_id_39 = dbp_ids_39.pop() if dbp_ids_39 else None
                if dbp_id_39 in dbp_ids:
                    mappings[fb_id].append(dbp_id_39)
                else:
                    mappings[fb_id] = list(dbp_ids)
                    PLOGGER.info("{}\t3.9: {}\t2015: {}".format(fb_id, dbp_id_39, dbp_ids))
            else:
                mappings[fb_id] = list(dbp_ids)

        PLOGGER.info("Number of mappings: " + str(len(mappings)))
        return mappings
Example 6
    def load_yerd(gt_file):
        """
        Reads the Y-ERD collection and returns a dictionary.

        :param gt_file: Path to the Y-ERD collection
        :return: dictionary {(qid, query, en_id, mention) ...}
        """
        PLOGGER.info("Loading the ground truth ...")
        gt = set()
        with open(gt_file, "r") as tsvfile:
            reader = csv.DictReader(tsvfile,
                                    delimiter="\t",
                                    quoting=csv.QUOTE_NONE)
            for line in reader:
                if line["entity"] == "":
                    continue
                query = Query(line["query"]).query
                mention = Query(line["mention"]).query
                gt.add((line["qid"], query, line["entity"], mention))
        return gt
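A minimal sketch of a gt_file this reader would accept, written as Python for clarity; the qid and entity formats are illustrative, and the real Y-ERD file may contain additional columns (DictReader simply carries extras along):

# creates a tiny tab-separated file with the columns used by load_yerd
with open("y-erd-sample.tsv", "w") as f:
    f.write("qid\tquery\tentity\tmention\n")
    f.write("y-1\tthe music man\t<dbpedia:The_Music_Man>\tthe music man\n")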
Example 7
File: ml.py Project: zxlzr/nordlys
    def apply_model(self, instances, model):
        """Applies model on a given set of instances.

        :param instances: Instances object
        :param model: trained model
        :return: Instances
        """
        PLOGGER.info("Applying model ... ")
        if len(instances.get_all()) > 0:
            features_names = sorted(instances.get_all()[0].features.keys())
            for ins in instances.get_all():
                test_x = numpy.array(
                    [[ins.features[ftr] for ftr in features_names]])
                if self.__config.get("category", "regression") == "regression":
                    ins.score = model.predict(test_x)[0]
                else:  # classification
                    ins.target = str(model.predict(test_x)[0])
                    # "predict_proba" gets class probabilities; an array of probabilities for each class e.g.[0.99, 0.1]
                    ins.score = model.predict_proba(test_x)[0][1]
        return instances
Example 8
    def parse_file(self, filename, triplehandler):
        """Parses file and calls callback function with the parsed triple"""
        PLOGGER.info("Processing " + filename + "...")

        prefix = URIPrefix()
        t = Triple(prefix)
        p = NTriplesParser(t)
        i = 0

        with open(filename) as f:
            for line in f:
                p.parsestring(line)
                if t.subject() is None:  # only if parsed as a triple
                    continue

                # call the handler object with the parsed triple
                triplehandler.triple_parsed(t)

                i += 1
                if i % 10000 == 0:
                    PLOGGER.info(str(i // 1000) + "K lines processed")
Example 9
    def append_set(self, doc_id, field, value):
        """Adds a list of values to a set.
        If the field does not exist yet, it will be created.
        The value should be a list.

        :param doc_id: document id
        :param field: field
        :param value: list, a value to be appended to the current list
        """
        try:
            self.__collection.update(
                {Mongo.ID_FIELD: self.__escape(doc_id)},
                {'$addToSet': {
                    self.__escape(field): {
                        '$each': value
                    }
                }},
                upsert=True)
        except Exception as e:
            PLOGGER.error("\nError (doc_id: " + str(doc_id) + "), field: " +
                          field + "\n" + str(e))
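A hedged usage sketch, assuming a Mongo instance constructed as in Example 3; the document ID, field, and values below are illustrative only:

mongo = Mongo(MONGO_HOST, MONGO_DB, "dbpedia-2015-10")   # collection name is an assumption
mongo.append_set("<dbpedia:Audi_A4>", "<rdf:type>",
                 ["<dbo:MeanOfTransportation>", "<dbo:Automobile>"])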
Example 10
    def gen_train_set(gt, query_file, train_set):
        """Trains LTR model for entity linking."""
        entity, elastic, fcache = Entity(), ElasticCache(
            ELASTIC_INDICES[0]), FeatureCache()
        inss = Instances()
        positive_annots = set()

        # Adds groundtruth instances (positive instances)
        PLOGGER.info("Adding groundtruth instances (positive instances) ....")
        for item in sorted(gt):  # qid, query, en_id, mention
            ltr = LTR(Query(item[1], item[0]), entity, elastic, fcache)
            ins = ltr.__gen_raw_ins(item[2], item[3])
            ins.features = ltr.get_features(ins)
            ins.target = 1
            inss.add_instance(ins)
            positive_annots.add((item[0], item[2]))

        # Adds all other instances
        PLOGGER.info("Adding all other instances (negative instances) ...")
        for qid, q in sorted(json.load(open(query_file, "r")).items()):
            PLOGGER.info("Query [" + qid + "]")
            ltr = LTR(Query(q, qid), entity, elastic, fcache)
            q_inss = ltr.get_candidate_inss()
            for ins in q_inss.get_all():
                if (qid, ins.get_property("en_id")) in positive_annots:
                    continue
                ins.target = 0
                inss.add_instance(ins)
        inss.to_json(train_set)
Example 11
File: el.py Project: zxlzr/nordlys
    def batch_linking(self):
        """Scores queries in a batch and outputs results."""
        results = {}

        if self.__config["step"] == "linking":
            queries = json.load(open(self.__query_file))
            for qid in sorted(queries):
                results[qid] = self.link(queries[qid], qid)
            to_elq_eval(results, self.__output_file)
            # json.dump(results, open(self.__output_file, "w"), indent=4, sort_keys=True)

        # only ranking step
        if self.__config["step"] == "ranking":
            queries = json.load(open(self.__query_file))
            for qid in sorted(queries):
                linker = self.__get_linker(Query(queries[qid], qid))
                results[qid] = linker.rank_ens()
            ranked_inss = Instances(
                sum([inss.get_all() for inss in results.values()], []))
            ranked_inss.to_treceval(self.__output_file)
            if self.__config.get("json_file", None):
                ranked_inss.to_json(self.__config["json_file"])

        # only disambiguation step
        if self.__config["step"] == "disambiguation":
            inss = Instances.from_json(self.__config["test_set"])
            inss_by_query = inss.group_by_property("qid")
            for qid, q_inss in sorted(inss_by_query.items()):
                linker = self.__get_linker("")
                results[qid] = {
                    "results": linker.disambiguate(Instances(q_inss))
                }
            if self.__config.get("json_file", None):
                json.dump(results,
                          open(self.__config["json_file"], "w"),
                          indent=4,
                          sort_keys=True)
            to_elq_eval(results, self.__output_file)

        PLOGGER.info("Output file: " + self.__output_file)
Example 12
    def __load_entity_abstracts(self, filename):
        prefix = URIPrefix()
        t = Triple()
        p = NTriplesParser(t)
        lines_counter = 0
        PLOGGER.info("Loading entity abstracts from {}".format(filename))
        for line in FileUtils.read_file_as_list(filename):
            # basic line parsing
            line = line.decode("utf-8") if isinstance(line, bytes) else line
            try:
                p.parsestring(line)
            except ParseError:  # skip lines that couldn't be parsed
                continue
            if t.subject() is None:  # only if parsed as a triple
                continue

            # Subject and object identification
            subj = prefix.get_prefixed(t.subject())
            obj = ""
            if type(t.object()) is URIRef:
                # PLOGGER.error("Error: it is URIRef the parsed obj")
                pass
            else:
                obj = t.object().encode("utf-8")
                if len(obj) == 0:
                    continue  # skip empty objects
            self.__entity_abstracts[subj] = obj

            lines_counter += 1
            if lines_counter % 10000 == 0:
                PLOGGER.info("\t{}K lines processed".format(lines_counter // 1000))
                pass

        PLOGGER.info("\n### Loading entity abstracts... Done.")
Example 13
    def triple_parsed(self, triple):
        PLOGGER.info("S: " + triple.subject() + " ==> " +
                     triple.subject_prefixed())
        PLOGGER.info("  P: " + triple.predicate() + " ==> " +
                     triple.predicate_prefixed())
        PLOGGER.info("  O: " + triple.object() + " ==> " +
                     triple.object_prefixed())
Example 14
    def __type_centric(self, query):
        """Type-centric TTI.

        :param query: query string
        :type query: str
        """
        types = dict()
        model = self.__config.get("model", TTI_MODEL_BM25)
        elastic = ElasticCache(
            self.__tc_config.get("index", DEFAULT_TTI_TC_INDEX))

        if model == TTI_MODEL_BM25:
            PLOGGER.info("TTI, TC, BM25")
            self.__tc_config["model"] = "bm25"
            # scorer = Scorer.get_scorer(elastic, query, self.__tc_config)
            types = Retrieval(self.__tc_config).retrieve(query)

        elif model == TTI_MODEL_LM:
            PLOGGER.debug("TTI, TC, LM")
            self.__tc_config["model"] = "lm"  # Needed for 2nd-pass
            self.__tc_config["field"] = "content"  # Needed for 2nd-pass
            self.__tc_config["second_pass"] = {"field": "content"}
            for param in ["smoothing_method", "smoothing_param"]:
                if self.__config.get(param, None) is not None:
                    self.__tc_config["second_pass"][param] = self.__config.get(
                        param)

            scorer = Scorer.get_scorer(elastic, query, self.__tc_config)
            types = Retrieval(self.__tc_config).retrieve(query, scorer)

            PLOGGER.info(types)

        return types
Example 15
    def build(self, callback_get_doc_content, bulk_size=1000):
        """Builds the DBpedia index from the mongo collection.

        To speed up indexing, documents are indexed in bulks.
        There is an optimal value for the bulk size; try to figure it out.

        :param callback_get_doc_content: a function that gets a document from Mongo and returns the content for indexing (or None to skip it)
        :param bulk_size: number of documents to be added to the index in a single bulk
        """
        PLOGGER.info("Building " + self.__index_name + " ...")
        elastic = Elastic(self.__index_name)
        elastic.create_index(self.__mappings, model=self.__model, force=True)

        i = 0
        docs = dict()
        for mdoc in self.__mongo.find_all(no_timeout=True):
            docid = Mongo.unescape(mdoc[Mongo.ID_FIELD])

            # get back document from mongo with keys and _id field unescaped
            doc = callback_get_doc_content(Mongo.unescape_doc(mdoc))
            if doc is None:
                continue
            docs[docid] = doc

            i += 1
            if i % bulk_size == 0:
                elastic.add_docs_bulk(docs)
                docs = dict()
                PLOGGER.info(str(i // 1000) + "K documents indexed")
        # indexing the last bulk of documents
        elastic.add_docs_bulk(docs)
        PLOGGER.info("Finished indexing (" + str(i) + " documents in total)")
Example 16
    def read_fb2dbp_file(self, is_39=False):
        """Reads the file and generates an initial mapping of Freebase to DBpedia IDs.
        Only proper DBpedia entities are considered; i.e. redirect and disambiguation pages are ignored.
        """
        fb2dbp_file = self.__fb2dbp_file_39 if is_39 else self.__fb2dbp_file
        PLOGGER.info("Processing " + fb2dbp_file + "...")

        t = Triple()
        p = NTriplesParser(t)
        i = 0
        fb2dbp_mapping = defaultdict(set)
        with FileUtils.open_file_by_type(fb2dbp_file) as f:
            for line in f:
                try:
                    p.parsestring(line.decode("utf-8"))
                except ParseError:  # skip lines that couldn't be parsed
                    continue
                if t.subject() is None:  # only if parsed as a triple
                    continue

                # prefixing
                dbp_id = self.__prefix.get_prefixed(t.subject())
                fb_id = self.__prefix.get_prefixed(t.object())

                # if reading 3.9 file, converts ID to 2015-10 version
                if is_39:
                    dbp_id = EntityUtils.convert_39_to_201510(dbp_id)
                    fb2dbp_mapping[fb_id].add(dbp_id)

                # if reading 2015-10 file, keeps only the proper DBpedia entities
                else:
                    entity_utils = EntityUtils(
                        self.__mongo_dbpedia.find_by_id(dbp_id))
                    if entity_utils.is_entity():
                        fb2dbp_mapping[fb_id].add(dbp_id)
                i += 1
                if i % 1000 == 0:
                    PLOGGER.info(str(i // 1000) + "K lines processed")
        return fb2dbp_mapping
Example 17
    def __load_entity_types(self):
        num_lines = 0
        for types_file in ENTITY_TYPES_FILES:
            filename = os.sep.join([self.__dbpedia_path, types_file])
            PLOGGER.info("Loading entity types from {}".format(filename))
            for line in FileUtils.read_file_as_list(filename):
                entity, entity_type = self.__parse_line(line)
                if type(entity_type) != str:  # Likely result of parsing error
                    continue
                if not entity_type.startswith("<dbo:"):
                    PLOGGER.info("  Non-DBpedia type: {}".format(entity_type))
                    continue
                if not entity.startswith("<dbpedia:"):
                    PLOGGER.info("  Invalid entity: {}".format(entity))
                    continue
                self.__types_entities[entity_type].append(entity)

                num_lines += 1
                if num_lines % 10000 == 0:
                    PLOGGER.info("  {}K lines processed".format(num_lines //
                                                                1000))
            PLOGGER.info("  Done.")
Example 18
def compute_field_counts():
    """Reads all documents in the Mongo collection and calculates field frequencies.
        i.e. For DBpedia collection, it returns all entity fields.

    :return a dictionary of fields and their frequency
    """
    PLOGGER.info("Counting fields ...")
    dbpedia_coll = Mongo(MONGO_HOST, MONGO_DB, MONGO_COLLECTION_DBPEDIA).find_all()
    i = 0
    field_counts = dict()
    for entity in dbpedia_coll:
        for field in entity:
            if field == Mongo.ID_FIELD:
                continue
            if field in field_counts:
                field_counts[field] += 1
            else:
                field_counts[field] = 1
        i += 1
        if i % 1000000 == 0:
            PLOGGER.info("\t" + str(int(i / 1000000)) + "M entity is processed!")
    return field_counts
Example 19
    def __check_config(config):
        """Checks params and sets default values."""
        try:
            if KEY_COLLECTION not in config:
                raise Exception(KEY_COLLECTION + " is missing")
            if KEY_OPERATION not in config:
                config[KEY_OPERATION] = KEY_APPEND
            if KEY_PATH not in config:
                raise Exception(KEY_PATH + " is missing")
            if KEY_FILES not in config:
                raise Exception(KEY_FILES + " is missing")
            # reads all files
            existing_files = set()
            for subdir, dirs, files in os.walk(config[KEY_PATH]):
                for file in files:
                    existing_files.add(os.path.join(subdir, file))
            for file in config[KEY_FILES]:
                dbpedia_file = config[KEY_PATH] + file[KEY_FILE_NAME]
                if dbpedia_file not in existing_files:
                    raise Exception(dbpedia_file + " does not exist.")
        except Exception as e:
            PLOGGER.error("Error in config file: " + str(e))
            sys.exit(1)
Example 20
    def get_top_fields(self):
        """Gets top-n frequent fields from DBpedia
        NOTE: Rank of fields with the same frequency is equal.
              This means that there can more than one field for each rank.
        """
        PLOGGER.info("Getting the top-n frequent DBpedia fields ...")
        sorted_fields = sorted(self.__field_counts.items(), key=lambda item: item[1], reverse=True)
        PLOGGER.info("Number of total fields: " + str(len(sorted_fields)))

        top_fields = []
        rank, prev_count, i = 0, 0, 0
        for field, count in sorted_fields:
            if field in self._config["blacklist"]:
                continue
            # changes the rank if the count number is changed
            i += 1
            if prev_count != count:
                rank = i
            prev_count = count
            if rank > self.__n:
                break
            top_fields.append(field)
        self.__top_fields = top_fields
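A small worked example of the tie handling described in the NOTE above (field names and counts are made up):

# with n = 1 and no blacklisted fields:
field_counts = {"a": 10, "b": 10, "c": 5}
# "a" and "b" share rank 1 and are both kept; "c" gets rank 3 > n and the loop stops,
# so top_fields == ["a", "b"] -- more than n fields are returned when frequencies are tied.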
Example 21
    def __make_type_doc(self, entities, last_type):
        """Gets the document representation of a type to be indexed, from its entity short abstracts."""
        content = ABSTRACTS_SEPARATOR.join([self.__entity_abstracts.get(e, b"").decode("utf-8")
                                            for e in entities])

        if len(content) > MAX_BULKING_DOC_SIZE:

            PLOGGER.info("Type {} has content larger than allowed: {}.".format(last_type, len(content)))

            # we randomly sample a subset of Y entity abstracts, s.t. Y * AVG_SHORT_ABSTRACT_LEN <= MAX_BULKING_DOC_SIZE
            amount_abstracts_to_sample = min(floor(MAX_BULKING_DOC_SIZE / AVG_SHORT_ABSTRACT_LEN), len(entities))
            entities_sample = [entities[i] for i in sample(range(len(entities)), amount_abstracts_to_sample)]
            content = ""  # reset content

            for entity in entities_sample:
                new_content_candidate = (content + ABSTRACTS_SEPARATOR +
                                         self.__entity_abstracts.get(entity, b"").decode("utf-8"))
                # we add an abstract only if by doing so it will not exceed MAX_BULKING_DOC_SIZE
                if len(new_content_candidate) <= MAX_BULKING_DOC_SIZE:
                    content = new_content_candidate
                else:
                    break

        return {CONTENT_KEY: content}
Example 22
def main(args):
    config = FileUtils.load_config(args.config)
    dbpedia_path = config.get("dbpedia_files_path", "")
    # Check DBpedia files
    PLOGGER.info("Checking needed DBpedia files under {}".format(dbpedia_path))
    for fname in [ENTITY_ABSTRACTS_FILE] + ENTITY_TYPES_FILES:
        if os.path.isfile(os.sep.join([dbpedia_path, fname])):
            PLOGGER.info("  - {}: OK".format(fname))
        else:
            PLOGGER.error("  - {}: Missing".format(fname))
            exit(1)

    indexer = IndexerDBpediaTypes(config)
    indexer.build_index(force=True)
Example 23
    def __load_entity_abstracts(self):
        num_lines = 0
        filename = os.sep.join([self.__dbpedia_path, ENTITY_ABSTRACTS_FILE])
        PLOGGER.info("Loading entity abstracts from {}".format(filename))
        for line in FileUtils.read_file_as_list(filename):
            entity, abstract = self.__parse_line(line)
            if abstract and len(abstract) > 0:  # skip empty objects
                self.__entity_abstracts[entity] = abstract

            num_lines += 1
            if num_lines % 10000 == 0:
                PLOGGER.info("  {}K lines processed".format(num_lines // 1000))

        PLOGGER.info("  Done.")
Example 24
    def build_index(self, force=False):
        """Builds the index.

        Note: since DBpedia only has a few hundred types, no bulk indexing is
        needed.

        :param force: if True, the index is recreated even if it already exists;
        False by default.
        :type force: bool
        """
        PLOGGER.info("Building type index {}".format(self.__index_name))
        self.__elastic = Elastic(self.__index_name)
        self.__elastic.create_index(mappings=self.__MAPPINGS, force=force)

        for type_name in self.__types_entities:
            PLOGGER.info("  Adding {} ...".format(type_name))
            contents = self.__make_type_doc(type_name)
            self.__elastic.add_doc(type_name, contents)

        PLOGGER.info("  Done.")
Example 25
    def get_scorer(elastic, query, config):
        """Returns Scorer object (Scorer factory).

        :param elastic: Elastic object
        :param query: raw query (to be analyzed)
        :param config: dict with model parameters
        """
        model = config.get("model", None)
        if model == "lm":
            PLOGGER.debug("\tLM scoring ... ")
            return ScorerLM(elastic, query, config)
        elif model == "mlm":
            PLOGGER.debug("\tMLM scoring ...")
            return ScorerMLM(elastic, query, config)
        elif model == "prms":
            PLOGGER.debug("\tPRMS scoring ...")
            return ScorerPRMS(elastic, query, config)
        elif model is None:
            return None
        else:
            raise Exception("Unknown model " + model)
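A minimal usage sketch of the factory above; the index name, query, and smoothing settings are illustrative (the config keys mirror the ones used in Example 14):

elastic = ElasticCache("dbpedia_2015_10")
config = {"model": "lm", "field": "content",
          "smoothing_method": "dirichlet", "smoothing_param": 2000}
scorer = Scorer.get_scorer(elastic, "obama net worth", config)   # returns a ScorerLM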
Example 26
File: ml.py Project: zxlzr/nordlys
    def train_model(self, instances):
        """Trains model on a given set of instances.

        :param instances: Instances object
        :return: the learned model
        """

        features = instances.get_all()[0].features
        features_names = sorted(features.keys())
        PLOGGER.info("Number of instances:\t" + str(len(instances.get_all())))
        PLOGGER.info("Number of features:\t" + str(len(features_names)))
        # Converts instances to Scikit-learn format : (n_samples, n_features)
        n_samples = len(instances.get_all())
        train_x = numpy.zeros((n_samples, len(features_names)))
        train_y = numpy.empty(n_samples, dtype=object)
        for i, ins in enumerate(instances.get_all()):
            train_x[i] = [ins.features[ftr] for ftr in features_names]
            if self.__config.get("category", "regression") == "regression":
                train_y[i] = float(ins.target)
            else:
                train_y[i] = str(ins.target)
        # training
        model = self.gen_model(len(features))
        model.fit(train_x, train_y)

        # write the trained model to the file
        if "model_file" in self.__config:
            # @todo if CV is used we need to append the fold no. to the filename
            PLOGGER.info("Writing trained model to {} ...".format(
                self.__config["model_file"]))
            pickle.dump(model, open(self.__config["model_file"], "wb"))

        if "feature_imp_file" in self.__config:
            print(self.analyse_features(model, features_names))
        return model
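A hedged sketch of how train_model and apply_model (Example 7) fit together; the class name ML and the file paths are assumptions, only the method signatures and config keys come from the code shown here:

ml = ML({"model": "gbrt", "parameters": {"tree": 1000, "depth": 5},
         "category": "regression"})
model = ml.train_model(Instances.from_json("train_set.json"))
scored = ml.apply_model(Instances.from_json("test_set.json"), model)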
Example 27
    def to_libsvm(self, file_name=None, qid_prop=None):
        """
        Converts all instances to the LibSVM format and writes them to the file.
        - Libsvm format:
            <line> .=. <target> qid:<qid> <feature>:<value> ... # <info>
            <target> .=. <float>
            <qid> .=. <positive integer>
            <feature> .=. <positive integer>
            <value> .=. <float>
            <info> .=. <string>
        - Example: 3 qid:1 1:1 2:1 3:0 4:0.2 5:0 # 1A

        NOTES:
            - The property used for qid(qid_prop) should hold integers
            - For pointwise algorithms, we use instance id for qid
            - Lines in the RankLib input have to be sorted by increasing qid.

        :param file_name: file to write the LibSVM representation of the instances to.
        :param qid_prop: property to be used as qid. If None, the instance ID is used.
        """
        # If no entity matches query
        if len(self.__instances) == 0:
            PLOGGER.info("No instance is created!!")
            open(file_name, "w").write("")
            return ""

        # Getting features
        ins = next(iter(self.__instances.values()))
        features = sorted(list(ins.features.keys()))

        # cleans previous contents
        open(file_name, "w").close()
        out_file = open(file_name, "a")

        # Adding feature names as header of libsvm file
        out = "# target instance_Id"
        for feature in features:
            out += " " + feature
        out += "\n"

        # sort instances by qid
        if qid_prop is None:
            sorted_instances = sorted(self.get_all(),
                                      key=lambda ins: int(ins.id))
        else:
            sorted_instances = sorted(
                self.get_all(),
                key=lambda ins: int(ins.get_property(qid_prop)))

        counter = 0
        PLOGGER.info("Converting instances to ranklib format ...")
        for ins in sorted_instances:
            out += ins.to_libsvm(features, qid_prop) + "\n"
            counter += 1
            # write the instances to the file
            if (counter % 1000) == 0:
                out_file.write(out)
                out = ""
                # print "Converting is done until instance " + str(ins.id)
        out_file.write(out)
        PLOGGER.info("Libsvm output:\t" + file_name)
Example 28
File: ml.py Project: zxlzr/nordlys
    def gen_model(self, num_features=None):
        """ Reads parameters and generates a model to be trained.

        :param num_features: int, number of features
        :return: untrained ranker/classifier
        """
        model = None
        if self.__config["model"].lower() == "gbrt":
            alpha = self.__config["parameters"].get("alpha", 0.1)
            tree = self.__config["parameters"].get("tree", 1000)
            default_depth = round(num_features /
                                  10.0) if num_features is not None else None
            depth = self.__config["parameters"].get("depth", default_depth)

            PLOGGER.info("Training instances using GBRT ...")
            PLOGGER.info("Number of trees: " + str(tree) +
                         "\tDepth of trees: " + str(depth))
            if self.__config.get("category", "regression") == "regression":
                PLOGGER.info("Training regressor")
                model = GradientBoostingRegressor(n_estimators=tree,
                                                  max_depth=depth,
                                                  learning_rate=alpha)
            else:
                PLOGGER.info("Training the classifier")
                model = GradientBoostingClassifier(n_estimators=tree,
                                                   max_depth=depth,
                                                   learning_rate=alpha)

        elif self.__config["model"].lower() == "rf":
            tree = self.__config["parameters"].get("tree", 1000)
            default_maxfeat = round(num_features /
                                    10.0) if num_features is not None else None
            max_feat = self.__config["parameters"].get("maxfeat",
                                                       default_maxfeat)

            PLOGGER.info("Training instances using RF ...")
            PLOGGER.info("Number of trees: " + str(tree) + "\tMax features: " +
                         str(max_feat))
            if self.__config.get("category", "regression") == "regression":
                PLOGGER.info("Training regressor")
                model = RandomForestRegressor(n_estimators=tree,
                                              max_features=max_feat)
            else:
                PLOGGER.info("Training classifier")
                model = RandomForestClassifier(n_estimators=tree,
                                               max_features=max_feat)
        return model
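For reference, a hedged config sketch for the RF branch above; the parameter values and the optional model_file path are illustrative:

config = {
    "model": "rf",
    "parameters": {"tree": 1000, "maxfeat": 10},
    "category": "classification",      # anything but "regression" selects the classifier
    "model_file": "path/to/model.pkl"  # optional; written out by train_model (Example 26)
}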
Example 29
    def print_stat(self):
        """Prints simple statistics."""
        PLOGGER.info("#queries: " + str(len(self.__results)))
        PLOGGER.info("#results: " +
                     str(sum(v.num_docs() for k, v in self.__results.items())))
Example 30
def main():
    ins = Instance(1, {"f1": "0.5", "f2": "0.4"}, "rel")
    ins.q_id = "q1"
    ins.q_content = "test query"
    ins_file = "../../src/output/instance.txt"
    PLOGGER.info(ins.to_json(ins_file))