コード例 #1
0
ファイル: analysis.py プロジェクト: cthoyt/embeddingdb
def main(id_1: int, id_2: int, model: Optional[str],
         output: Optional[BinaryIO], connection: str):
    """Perform a regression between two collections and print a summary.

    The two collections are looked up by ``id_1`` and ``id_2`` through a
    session opened on ``connection``. ``model`` and ``output`` are forwarded
    unchanged to ``perform_regression``; the values it returns are echoed
    one per line.
    """
    db_session = get_session(connection=connection)
    # The generator is consumed left to right, so id_1 is resolved before id_2.
    first, second = (
        _get_collection(db_session, collection_id)
        for collection_id in (id_1, id_2)
    )

    regression_result = perform_regression(
        first,
        second,
        regression_cls=model,
        output=output,
    )
    clf, r2, intersect, intersect_percent = regression_result

    # Echo line by line so earlier lines still appear if a later one fails.
    echo = click.echo
    echo(f'Model: {clf}')
    echo(f'Dimensions: {clf.coef_.shape}')
    echo(f'R^2: {r2:.3f}')
    echo(f'Intersection: {intersect} ({intersect_percent:.1%})')
コード例 #2
0
def ls(limit: Optional[int], connection: str):
    """List the collections in the database as tab-separated rows.

    A header row is printed first, followed by one row per collection.
    When ``limit`` is not ``None`` the query is capped to that many rows.
    Falsy ``extras`` are rendered as the literal ``{}``; otherwise they are
    serialized with :func:`json.dumps`.
    """
    query = get_session(connection).query(Collection)
    if limit is not None:
        query = query.limit(limit)

    header = (
        'collection_id',
        'dimensions',
        'package_name',
        'package_version',
        'extras',
    )
    click.echo('\t'.join(header))

    for row in query:
        # Columns are read in the same order as they are printed.
        cells = [
            str(row.id),
            str(row.dimensions),
            row.package_name,
            row.package_version,
        ]
        cells.append(json.dumps(row.extras) if row.extras else '{}')
        click.echo('\t'.join(cells))
コード例 #3
0
ファイル: base.py プロジェクト: kantholtz/pykeen
    def to_embeddingdb(self, session=None, use_tqdm: bool = False):
        """Upload the entity embeddings to the embedding database.

        Creates one ``Collection`` describing this model (package name,
        package version and embedding dimension) and one ``Embedding`` per
        entity label, adds them all to the session and commits.

        :param session: Optional SQLAlchemy session. When ``None``, a session
            is obtained from ``embeddingdb.sql.models.get_session()``.
        :param use_tqdm: Use :mod:`tqdm` progress bar?
        :return: The newly created collection, after the commit.
        :rtype: embeddingdb.sql.models.Collection
        """
        # Imported lazily so embeddingdb is only needed when this is called.
        from embeddingdb.sql.models import Embedding, Collection

        if session is None:
            from embeddingdb.sql.models import get_session
            session = get_session()

        collection = Collection(
            package_name='pykeen',
            package_version=get_version(),
            dimensions=self.embedding_dim,
        )

        embeddings = self.entity_embeddings.weight.detach().cpu().numpy()
        # Sort the entity labels by their integer ID so that they line up with
        # the rows of the weight matrix (assumes row i belongs to the entity
        # whose ID is i -- TODO confirm against the triples factory).
        names = sorted(
            self.triples_factory.entity_to_id,
            key=self.triples_factory.entity_to_id.get,
        )

        if use_tqdm:
            names = tqdm(names, desc='Building SQLAlchemy models')
        for name, row in zip(names, embeddings):
            session.add(Embedding(
                collection=collection,
                curie=name,
                # ``tolist()`` converts to built-in Python floats, whereas
                # ``list(row)`` would keep NumPy scalar objects, which
                # database drivers and JSON encoders commonly reject.
                vector=row.tolist(),
            ))
        session.add(collection)
        session.commit()
        return collection