示例#1
0
 def as_bow(nbow: str, id2vec: str) -> BOW:
     """
     Load an nBOW model and call ``become_bow()`` on it with an Id2Vec model.

     :param nbow: source to load the :class:`NBOW` model from.
     :param id2vec: source to load the :class:`Id2Vec` model from. When falsy, \
                    the "uuid" of the nBOW model's "id2vec" dependency is \
                    used instead.
     :return: the loaded model after ``become_bow()`` was called on it \
              (presumably converted in place to a BOW — confirm in NBOW).
     """
     model = NBOW().load(source=nbow)
     # Fall back to the Id2Vec model recorded as the nBOW's dependency.
     id2vec_source = id2vec or model.get_dependency("id2vec")["uuid"]
     vectors = Id2Vec().load(source=id2vec_source)
     model.become_bow(vectors)
     del vectors  # drop the local reference before returning
     return model
示例#2
0
def postprocess(args):
    """
    Merges row and column embeddings produced by Swivel and writes the Id2Vec
    model.

    `row_embedding.tsv` and `col_embedding.tsv` are read line by line in
    lockstep. Each line is "<token><TAB><value><TAB><value>...". The vector
    stored for a token is the mean of its row and column vectors. The token
    itself is truncated to `TokenParser.MAX_TOKEN_LENGTH` characters.

    :param args: :class:`argparse.Namespace` with "swivel_output_directory" \
                 and "result". The text files are read from \
                 `swivel_output_directory` and the model is written to \
                 `result`.
    :return: None
    :raises ValueError: if the two files disagree on the token of some line \
                        or contain a different number of lines.
    """
    import itertools

    log = logging.getLogger("postproc")
    log.info("Parsing the embeddings at %s...", args.swivel_output_directory)
    tokens = []
    embeddings = []
    swd = args.swivel_output_directory
    with open(os.path.join(swd, "row_embedding.tsv")) as frow, \
            open(os.path.join(swd, "col_embedding.tsv")) as fcol:
        # zip_longest (not zip) so that files of different lengths are
        # reported instead of being silently truncated to the shorter one.
        for i, (lrow, lcol) in enumerate(itertools.zip_longest(frow, fcol)):
            if lrow is None or lcol is None:
                raise ValueError(
                    "row_embedding.tsv and col_embedding.tsv have a different "
                    "number of lines (mismatch at line %d)" % (i + 1))
            # Progress counter, rewritten in place every 10000 lines.
            if i % 10000 == (10000 - 1):
                sys.stdout.write("%d\r" % (i + 1))
                sys.stdout.flush()
            prow, pcol = (line.split("\t", 1) for line in (lrow, lcol))
            # Explicit check instead of `assert`, which is stripped under -O.
            if prow[0] != pcol[0]:
                raise ValueError(
                    "Token mismatch at line %d: %r in row_embedding.tsv vs %r "
                    "in col_embedding.tsv" % (i + 1, prow[0], pcol[0]))
            tokens.append(prow[0][:TokenParser.MAX_TOKEN_LENGTH])
            erow, ecol = (
                numpy.fromstring(p[1], dtype=numpy.float32, sep="\t")
                for p in (prow, pcol))
            embeddings.append((erow + ecol) / 2)
    log.info("Generating numpy arrays...")
    embeddings = numpy.array(embeddings, dtype=numpy.float32)
    log.info("Writing %s...", args.result)
    model = Id2Vec()
    model.construct(embeddings=embeddings, tokens=tokens)
    model.save(args.result)
示例#3
0
 def __init__(self, id2vec=None, docfreq=None, gcs_bucket=None, **kwargs):
     """
     Load the Id2Vec and DocumentFrequencies models and initialize the parent.

     :param id2vec: source handed to :meth:`Id2Vec.load`; any falsy value is \
                    passed as None.
     :param docfreq: source handed to :meth:`DocumentFrequencies.load`; any \
                     falsy value is passed as None.
     :param gcs_bucket: optional bucket name; when set, both models are \
                        loaded through a "gcs" backend for that bucket.
     :param kwargs: forwarded to the parent constructor, with "id2vec" and \
                    "docfreq" set to the loaded models.
     """
     backend = (create_backend("gcs", "bucket=" + gcs_bucket)
                if gcs_bucket else None)
     id2vec_model = Id2Vec().load(id2vec or None, backend=backend)
     self._id2vec = id2vec_model
     kwargs["id2vec"] = id2vec_model
     docfreq_model = DocumentFrequencies().load(docfreq or None,
                                                backend=backend)
     self._df = docfreq_model
     kwargs["docfreq"] = docfreq_model
     super(Repo2nBOWTransformer, self).__init__(**kwargs)
示例#4
0
 def __init__(self, id2vec=None, docfreq=None, gcs_bucket=None, **kwargs):
     """
     Load the Id2Vec and DocumentFrequencies models and initialize the parent.

     :param id2vec: source handed to :meth:`Id2Vec.load`; any falsy value is \
                    passed as None.
     :param docfreq: source handed to :meth:`DocumentFrequencies.load`; any \
                     falsy value is passed as None.
     :param gcs_bucket: optional bucket name; when set, both models are \
                        loaded through a "gcs" backend for that bucket.
     :param kwargs: forwarded to the parent constructor, with "id2vec" and \
                    "docfreq" set to the loaded models. An optional \
                    "prune_df" entry (default 1) is consumed here: when it \
                    is greater than 1, the document frequencies are replaced \
                    by ``prune(prune_df)`` of the loaded model.
     """
     # Consume our own option so it never reaches the parent constructor.
     prune_df = kwargs.pop("prune_df", 1)
     if gcs_bucket:
         backend = create_backend("gcs", "bucket=" + gcs_bucket)
     else:
         backend = None
     self._id2vec = kwargs["id2vec"] = Id2Vec().load(id2vec or None, backend=backend)
     df = DocumentFrequencies().load(docfreq or None, backend=backend)
     if prune_df > 1:
         df = df.prune(prune_df)
     # Store AND forward the same (possibly pruned) model; previously only
     # self._df was pruned while the parent still received the unpruned one.
     self._df = kwargs["docfreq"] = df
     super().__init__(**kwargs)
示例#5
0
 def setUp(self):
     """Load the Id2Vec fixture at paths.ID2VEC, relative to this file's directory."""
     fixture_path = os.path.join(os.path.dirname(__file__), paths.ID2VEC)
     self.model = Id2Vec().load(source=fixture_path)