Exemple #1
0
class PartialStore(ABC):
    """A store spawned inside partial-daemon container.

    It tracks a single managed ``object`` (annotated as a Pod, Deployment or
    Flow; presumably created by the subclass's :meth:`add`) together with a
    :class:`PartialStoreItem` describing it.
    """

    def __init__(self):
        self._logger = JinaLogger(self.__class__.__name__, **vars(jinad_args))
        self.item = PartialStoreItem()
        # ``None`` until something is stored
        self.object: Union[Type['BasePod'], Type['BaseDeployment'], 'Flow',
                           None] = None

    @abstractmethod
    def add(self, *args, **kwargs) -> PartialStoreItem:
        """Add a new element to the store. This method needs to be overridden by the subclass


        .. #noqa: DAR101"""
        ...

    def delete(self) -> None:
        """Terminate the stored object (if it can be closed) and reset the store item.

        When ``object`` has no ``close`` method, only a warning is logged.
        ``item`` is replaced by a fresh :class:`PartialStoreItem` only when no
        exception occurred.

        :raises Exception: anything raised while closing is logged and re-raised
        """
        try:
            if hasattr(self.object, 'close'):
                self.object.close()
                # raw arguments are diagnostic detail, not an info-level event
                self._logger.debug(self.item.arguments)
                identity = self.item.arguments.get('identity')
                if identity:
                    self._logger.success(
                        f'{colored(identity, "cyan")} is removed!')
                else:
                    self._logger.success('object is removed!')
            else:
                self._logger.warning('nothing to close. exiting')
        except Exception as e:
            self._logger.error(f'{e!r}')
            raise
        else:
            self.item = PartialStoreItem()
Exemple #2
0
def _register_to_mongodb(logger: JinaLogger, summary: Optional[Dict] = None):
    """Send ``summary`` to the Jina Hub API push endpoint (used by `hub push`).

    The endpoint URL is built from ``hubapi.yml`` in the resources folder and
    the request is authorized with the locally stored access token.

    :param logger: the logger instance
    :param summary: the summary dict object, sent as the JSON request body
    :raises HubLoginRequired: when the API answers with a non-OK status code
    """
    # TODO(Deepankar): implement to jsonschema based validation for summary
    logger.info('registering image to Jina Hub database...')
    config_path = os.path.join(__resources_path__, 'hubapi.yml')
    with open(config_path) as fp:
        api_cfg = JAML.load(fp)['hubapi']
    hubapi_url = api_cfg['url'] + api_cfg['push']

    with ImportExtensions(
        required=True,
        help_text='Missing "requests" dependency, please do pip install "jina[http]"',
    ):
        import requests

    auth_token = _fetch_access_token(logger)
    response = requests.post(
        url=f'{hubapi_url}',
        headers={
            'Accept': 'application/json',
            'authorizationToken': auth_token,
        },
        data=json.dumps(summary),
    )

    if response.status_code != requests.codes.ok:
        raise HubLoginRequired(
            f'❌ Got an error from the API: {response.text.rstrip()}. '
            f'Please login using command: {colored("jina hub login", attrs=["bold"])}'
        )
    logger.success(f'✅ Successfully updated the database. {response.text}')
Exemple #3
0
def log(logger: JinaLogger):
    """Write one sample message through each level of ``logger``.

    Levels go from least to most severe: debug, info, success, warning,
    error, critical. Each message reads ``this is test <level> message``.
    """
    for level in ('debug', 'info', 'success', 'warning', 'error', 'critical'):
        emit = getattr(logger, level)
        emit(f'this is test {level} message')
Exemple #4
0
class NameChangeExecutor(Executor):
    """Executor that takes its name from ``runtime_args`` and stamps it on requests.

    Every request gets one extra Document whose ``text`` is the executor name.
    """

    def __init__(self, runtime_args, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.name = runtime_args['name']
        # a dedicated logger named after this executor instance
        self.logger = JinaLogger(self.name)

    @requests
    def foo(self, docs: DocumentArray, **kwargs):
        """Log the incoming count, then append a Document carrying the name.

        :param docs: the request's Documents; modified in place
        :param kwargs: unused
        :return: the same ``docs`` object
        """
        incoming = len(docs)
        self.logger.info(f'doc count {incoming}')
        name_doc = Document(text=self.name)
        docs.append(name_doc)
        return docs
Exemple #5
0
def _list(
        logger: JinaLogger,
        image_name: Optional[str] = None,
        image_kind: Optional[str] = None,
        image_type: Optional[str] = None,
        image_keywords: Sequence = (),
) -> Optional[List[Dict[str, Any]]]:
    """Use Hub api to get the list of filtered images.

    Only non-empty filters are sent. The matching images are also printed as a
    table (merged with the local hub manifest when one is available).

    :param logger: logger to use
    :param image_name: name of hub image
    :param image_kind: kind of hub image (indexer/encoder/segmenter/crafter/evaluator/ranker etc)
    :param image_type: type of hub image (pod/app)
    :param image_keywords: keywords added in the manifest yml
    :return: a list of manifest specifications, one dict per hub image; ``None``
        when no filter is given or the request fails (the reason is logged)
    """
    from urllib.error import URLError

    params = {
        'name': image_name,
        'kind': image_kind,
        'type': image_type,
        'keywords': image_keywords,
    }
    params = {k: v for k, v in params.items() if v}
    if not params:
        # previously this returned None without any feedback
        logger.warning(
            'no filter given. please specify at least one of name, kind, type or keywords.'
        )
        return None

    with open(os.path.join(__resources_path__, 'hubapi.yml')) as fp:
        hubapi_yml = JAML.load(fp)
        hubapi_url = hubapi_yml['hubapi']['url'] + hubapi_yml['hubapi']['list']

    data = urlencode(params, doseq=True)
    request = Request(f'{hubapi_url}?{data}')
    with TimeContext('searching', logger):
        try:
            with urlopen(request) as resp:
                response = json.load(resp)
        except HTTPError as err:
            # HTTPError subclasses URLError, so it must be handled first
            if err.code == 400:
                logger.warning(
                    'no matched executors found. please use different filters and retry.'
                )
            elif err.code == 500:
                logger.error(f'server is down: {err.reason}')
            else:
                logger.error(f'unknown error: {err.reason}')
            return None
        except URLError as err:
            # network-level failure (DNS, refused connection, ...): report it like
            # the HTTP errors above instead of letting a traceback escape
            logger.error(f'cannot reach the hub api: {err.reason}')
            return None

    local_manifest = _load_local_hub_manifest()
    if local_manifest:
        tb = _make_hub_table_with_local(response, local_manifest)
    else:
        tb = _make_hub_table(response)
    logger.info('\n'.join(tb))
    return response
Exemple #6
0
def create(
    template: str,
    params: Dict,
    logger: JinaLogger = default_logger,
    custom_resource_dir: Optional[str] = None,
):
    """Create a resource on Kubernetes based on the `template`. It fills the `template` using the `params`.

    Resources that already exist (HTTP 409 from the apiserver) are reported via
    ``logger`` instead of failing; any other creation error is re-raised.

    :param template: path to the template file.
    :param params: dictionary for replacing the placeholders (keys) with the actual values.
    :param logger: logger to use. Defaults to the default logger.
    :param custom_resource_dir: Path to a folder containing the kubernetes yml template files.
        Defaults to the standard location jina.resources if not specified.
    :raises FailToCreateError: if creating any resource fails for a reason other than 409
    """

    from kubernetes.utils import FailToCreateError
    from kubernetes import utils

    yaml_content = _get_yaml(template, params, custom_resource_dir)
    fd, path = tempfile.mkstemp()
    try:
        with os.fdopen(fd, 'w') as tmp:
            tmp.write(yaml_content)
        try:
            utils.create_from_yaml(__k8s_clients.k8s_client, path)
        except FailToCreateError as e:
            for api_exception in e.api_exceptions:
                if api_exception.status != 409:
                    # bare raise keeps the original traceback
                    raise
                # The exception's body is the error response from the
                # Kubernetes apiserver, it looks like:
                # {..."message": "<resource> <name> already exists"...}
                resp = json.loads(api_exception.body)
                logger.info(f'🔁\t{resp["message"]}')
    finally:
        os.remove(path)
Exemple #7
0
class CrudIndexer(Executor):
    """Simple indexer class.

    Keeps the indexed Documents in an in-memory ``DocumentArray``. The array is
    loaded from ``<workspace>/docs`` at start-up (if present) and saved back
    there on :meth:`close`. Search ranks by cosine similarity of embeddings.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.logger = JinaLogger('CrudIndexer')
        self._docs = DocumentArray()
        self._dump_location = os.path.join(self.metas.workspace, 'docs')
        if os.path.exists(self._dump_location):
            self._docs = DocumentArray.load(self._dump_location)
            self.logger.info(
                f'Loaded {len(self._docs)} from {self._dump_location}')
        else:
            self.logger.info(f'No data found at {self._dump_location}')

    @requests(on='/index')
    def index(self, docs: 'DocumentArray', **kwargs):
        """Append ``docs`` to the in-memory index."""
        self._docs.extend(docs)

    @requests(on='/update')
    def update(self, docs: 'DocumentArray', **kwargs):
        """Replace indexed Documents by id: delete matching ids, then index ``docs``."""
        self.delete(docs)
        self.index(docs)

    def close(self) -> None:
        """Save the in-memory index to ``self._dump_location``."""
        self.logger.info(f'Dumping {len(self._docs)} to {self._dump_location}')
        self._docs.save(self._dump_location)

    @requests(on='/delete')
    def delete(self, docs: 'DocumentArray', **kwargs):
        """Remove every indexed Document whose id appears in ``docs``."""
        # TODO we can do del _docs[d.id] once
        # tests.unit.types.arrays.test_documentarray.test_delete_by_id is fixed
        # a set makes each membership test O(1) instead of scanning a list
        ids_to_delete = {d.id for d in docs}
        idx_to_delete = [
            i for i, doc in enumerate(self._docs) if doc.id in ids_to_delete
        ]
        # delete from the back so the remaining positions stay valid
        for i in reversed(idx_to_delete):
            del self._docs[i]

    @requests(on='/search')
    def search(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
        """Attach the ``top_k`` most similar indexed Documents to each query.

        Matches are deep copies of indexed Documents with ``scores['cosine']``
        set to ``1 - distance``.

        :param docs: query Documents; each is expected to have an ``embedding``
        :param parameters: may contain ``top_k`` (default 1)
        :param kwargs: unused
        """
        # np.stack raises on an empty list: with no queries or an empty index
        # there is simply nothing to match
        if docs is None or len(docs) == 0 or len(self._docs) == 0:
            return
        top_k = int(parameters.get('top_k', 1))
        a = np.stack(docs.get_attributes('embedding'))
        b = np.stack(self._docs.get_attributes('embedding'))
        q_emb = _ext_A(_norm(a))
        d_emb = _ext_B(_norm(b))
        dists = _cosine(q_emb, d_emb)
        idx, dist = self._get_sorted_top_k(dists, top_k)
        for _q, _ids, _dists in zip(docs, idx, dist):
            for _id, _dist in zip(_ids, _dists):
                d = Document(self._docs[int(_id)], copy=True)
                d.scores['cosine'] = 1 - _dist
                _q.matches.append(d)

    @staticmethod
    def _get_sorted_top_k(dist: 'np.array',
                          top_k: int) -> Tuple['np.ndarray', 'np.ndarray']:
        """Return the column indices and values of the ``top_k`` smallest entries per row.

        :param dist: 2-D distance matrix (queries x indexed docs)
        :param top_k: number of entries to keep per row
        :return: ``(idx, dist)``, both sorted ascending by distance per row
        """
        if top_k >= dist.shape[1]:
            idx = dist.argsort(axis=1)[:, :top_k]
            dist = np.take_along_axis(dist, idx, axis=1)
        else:
            # partial selection first, then sort only the selected top_k columns
            idx_ps = dist.argpartition(kth=top_k, axis=1)[:, :top_k]
            dist = np.take_along_axis(dist, idx_ps, axis=1)
            idx_fs = dist.argsort(axis=1)
            idx = np.take_along_axis(idx_ps, idx_fs, axis=1)
            dist = np.take_along_axis(dist, idx_fs, axis=1)

        return idx, dist
class MongoDBHandler:
    """
    Mongodb Handler to connect to the database & insert documents in the collection
    MongoDB has no access control by default, hence can be used without username:password.
    If username & password are passed, we need to create it (can be changed to existing un:pw)

    Errors from find/insert/delete/update are logged and swallowed; connection
    errors are raised as ``MongoDBException``.

    :param hostname: host of the MongoDB server
    :param port: port of the MongoDB server
    :param username: optional user name (only used together with ``password``)
    :param password: optional password (only used together with ``username``)
    :param database: name of the database to use
    :param collection: name of the collection to use
    """

    def __init__(self,
                 hostname: str = '127.0.0.1',
                 port: int = 27017,
                 username: Optional[str] = None,
                 password: Optional[str] = None,
                 database: str = 'defaultdb',
                 collection: str = 'defaultcol'):
        self.logger = JinaLogger(self.__class__.__name__)
        self.hostname = hostname
        self.port = port
        self.username = username
        self.password = password
        self.database_name = database
        self.collection_name = collection
        self.connection_string = self._build_connection_string(
            self.hostname, self.port, self.username, self.password)

    @staticmethod
    def _build_connection_string(hostname: str,
                                 port: int,
                                 username: Optional[str] = None,
                                 password: Optional[str] = None) -> str:
        """Build the ``mongodb://`` URI, percent-escaping credentials as pymongo requires.

        Without escaping, a password containing ``@``, ``:`` or ``/`` yields an
        invalid or misparsed URI.
        """
        from urllib.parse import quote_plus

        if username and password:
            return (f'mongodb://{quote_plus(username)}:{quote_plus(password)}'
                    f'@{hostname}:{port}')
        return f'mongodb://{hostname}:{port}'

    def __enter__(self):
        return self.connect()

    def connect(self) -> 'MongoDBHandler':
        """Open a client and check that the server answers.

        :return: this handler, with ``client`` set
        :raises MongoDBException: if the client cannot be created or the server check fails
        """
        import pymongo
        client = None
        try:
            client = pymongo.MongoClient(self.connection_string)
            client.admin.command('ismaster')
        except Exception as exp:
            # release the client's resources instead of leaking them on failure
            if client is not None:
                client.close()
            if isinstance(exp, pymongo.errors.ConnectionFailure):
                raise MongoDBException('Database server is not available') from exp
            if isinstance(exp, pymongo.errors.ConfigurationError):
                raise MongoDBException('Credentials passed are not correct!') from exp
            raise MongoDBException(exp) from exp
        self.client = client
        self.logger.info('Successfully connected to the database')
        return self

    @property
    def database(self):
        """The ``database_name`` database on the connected client."""
        return self.client[self.database_name]

    @property
    def collection(self):
        """The ``collection_name`` collection inside :attr:`database`."""
        return self.database[self.collection_name]

    def find(self, key: int) -> Optional[Dict]:
        """Return the document whose ``_id`` is ``key``.

        :param key: the document id
        :return: the stored document, or ``None`` if absent or on error (logged)
        """
        import pymongo
        try:
            return self.collection.find_one({'_id': key})
        except pymongo.errors.PyMongoError as exp:
            self.logger.error(
                f'Got an error while finding a document in the db {exp}')
            return None

    def insert(self, documents: Iterator[Dict]) -> Optional[list]:
        """Insert ``documents`` into the collection.

        :param documents: documents to insert
        :return: the inserted ids, or ``None`` on error (logged)
        """
        import pymongo
        try:
            result = self.collection.insert_many(documents)
            self.logger.debug(
                f'inserted {len(result.inserted_ids)} documents in the database'
            )
            return result.inserted_ids
        except pymongo.errors.PyMongoError as exp:
            self.logger.error(
                f'got an error while inserting a document in the db {exp}')
            return None

    def __exit__(self, exc_type, exc_val, exc_tb):
        import pymongo
        try:
            self.client.close()
        except pymongo.errors.PyMongoError as exp:
            raise MongoDBException(exp)

    def delete(self, keys: Iterator[int], *args, **kwargs):
        """Delete all documents whose ``_id`` is in ``keys`` (errors are logged)."""
        import pymongo
        try:
            count = self.collection.delete_many({
                '_id': {
                    '$in': list(keys)
                }
            }).deleted_count
            self.logger.debug(f'deleted {count} documents in the database')
        except pymongo.errors.PyMongoError as exp:
            self.logger.error(
                f'got an error while deleting a document in the db {exp}')

    def update(self, keys: Iterator[int], values: Iterator[bytes], *args,
               **kwargs):
        """Replace each existing document ``{'_id': k}`` by ``{'_id': k, 'values': v}``.

        Keys without a stored document are skipped (no upsert); errors are logged.
        """
        import pymongo
        try:
            # update_many applies one op to all keys; each key needs its own value
            count = 0
            for k, value in zip(keys, values):
                # find_one_and_replace returns the document *before* replacement,
                # so comparing it with the new one under-counted; use the
                # server-reported match count instead
                result = self.collection.replace_one({'_id': k}, {
                    '_id': k,
                    'values': value
                })
                count += result.matched_count
            self.logger.debug(f'updated {count} documents in the database')
        except pymongo.errors.PyMongoError as exp:
            self.logger.error(
                f'got an error while updating documents in the db {exp}')
Exemple #9
0
def deploy_service(
    name: str,
    namespace: str,
    image_name: str,
    container_cmd: str,
    container_args: str,
    logger: JinaLogger,
    replicas: int,
    pull_policy: str,
    init_container: Optional[Dict] = None,
    custom_resource_dir: Optional[str] = None,
    port_expose: Optional[int] = None,
) -> str:
    """Deploy service on Kubernetes.

    Creates a ClusterIP service, a deployment (with an init container when
    ``init_container`` is given) and the connection-pool role and role binding
    in ``namespace``.

    :param name: name of the service and deployment
    :param namespace: k8s namespace of the service and deployment
    :param image_name: image for the k8s deployment
    :param container_cmd: command executed on the k8s pods
    :param container_args: arguments used for the k8s pod
    :param logger: used logger
    :param replicas: number of replicas
    :param pull_policy: pull policy used for fetching the Docker images from the registry.
    :param init_container: additional arguments used for the init container
    :param custom_resource_dir: Path to a folder containing the kubernetes yml template files.
        Defaults to the standard location jina.resources if not specified.
    :param port_expose: port which will be exposed by the deployed containers (default 8080)
    :return: dns name of the created service
    """

    # we can always assume the ports are the same for all executors since they run on different k8s pods
    # port expose can be defined by the user
    if not port_expose:
        port_expose = 8080
    port_in = 8081
    port_out = 8082
    port_ctrl = 8083

    logger.info(
        f'🔋\tCreate Service for "{name}" with exposed port "{port_expose}"')
    kubernetes_tools.create(
        'service',
        {
            'name': name,
            'target': name,
            'namespace': namespace,
            'port_expose': port_expose,
            'port_in': port_in,
            'port_out': port_out,
            'port_ctrl': port_ctrl,
            'type': 'ClusterIP',
        },
        logger=logger,
        custom_resource_dir=custom_resource_dir,
    )

    logger.info(
        f'🐳\tCreate Deployment for "{name}" with image "{image_name}", replicas {replicas} and init_container {init_container is not None}'
    )

    if init_container:
        template_name = 'deployment-init'
    else:
        template_name = 'deployment'
        init_container = {}
    kubernetes_tools.create(
        template_name,
        {
            'name': name,
            'namespace': namespace,
            'image': image_name,
            'replicas': replicas,
            'command': container_cmd,
            'args': container_args,
            'port_expose': port_expose,
            'port_in': port_in,
            'port_out': port_out,
            'port_ctrl': port_ctrl,
            'pull_policy': pull_policy,
            **init_container,
        },
        logger=logger,
        custom_resource_dir=custom_resource_dir,
    )

    logger.info('🔑\tCreate necessary permissions')

    # NOTE(review): unlike the calls above, these do not forward
    # custom_resource_dir, so the role templates always come from the default
    # location — confirm this is intended
    kubernetes_tools.create(
        'connection-pool-role',
        {
            'namespace': namespace,
        },
        logger=logger,
    )

    kubernetes_tools.create(
        'connection-pool-role-binding',
        {
            'namespace': namespace,
        },
        logger=logger,
    )

    return f'{name}.{namespace}.svc'
Exemple #10
0
class DaemonWorker(Thread):
    """Worker Thread for JinaD.

    A daemon thread that starts itself at the end of ``__init__``. Its
    :meth:`run` creates or updates the workspace identified by ``id``: it stores
    the uploaded files, builds the workspace docker image and (re)creates the
    workspace container when the ``.jinad`` daemon file has a ``run`` command.
    Failures are recorded as ``RemoteWorkspaceState.FAILED`` in ``store`` and logged.
    """
    def __init__(self, id: 'DaemonID', files: List[UploadFile], name: str,
                 *args, **kwargs) -> None:
        """Initialize the worker and start the thread immediately.

        :param id: id of the workspace this worker creates/updates
        :param files: uploaded files to store in the workspace (may be empty or None)
        :param name: suffix appended to the class name to form the thread name
        :param args: accepted but not used
        :param kwargs: accepted but not used
        """
        super().__init__(name=f'{self.__class__.__name__}{name}', daemon=True)
        self.id = id
        self.files = files
        # ``self.workdir`` is resolved here already (needs ``self.id`` set above)
        self._logger = JinaLogger(self.name,
                                  workspace_path=self.workdir,
                                  **vars(jinad_args))
        # the constructor starts the thread itself; run() executes in the background
        self.start()

    @cached_property
    def arguments(self) -> WorkspaceArguments:
        """Build the workspace arguments (this property does not write to ``store``).

        Starts from a deep copy of the arguments already stored for ``id``,
        appends the uploaded file names and sets dockerfile/requirements from
        the daemon file.

        :return: pydantic model for workspace arguments
        """
        try:
            _args = store[self.id].arguments.copy(deep=True)
            _args.files.extend([f.filename
                                for f in self.files] if self.files else [])
            _args.jinad.update({
                'dockerfile': self.daemon_file.dockerfile,
            })
            _args.requirements = self.daemon_file.requirements
        except AttributeError:
            # NOTE(review): presumably reached when the stored item has no
            # arguments yet (``arguments`` is None); a fresh model is built — confirm
            _args = WorkspaceArguments(
                files=[f.filename for f in self.files] if self.files else [],
                jinad={
                    'dockerfile': self.daemon_file.dockerfile,
                },
                requirements=self.daemon_file.requirements,
            )
        return _args

    @cached_property
    def metadata(self) -> WorkspaceMetadata:
        """Build the workspace metadata (this property does not write to ``store``).

        Accessing it builds the docker image via :meth:`generate_image`. When a
        new model has to be created, the docker network is created as well
        (through :attr:`network_id`).

        :return: pydantic model for workspace metadata
        """
        image_id = self.generate_image()
        try:
            _metadata = store[self.id].metadata.copy(deep=True)
            _metadata.image_id = image_id
            _metadata.image_name = self.id.tag
        except AttributeError:
            # NOTE(review): presumably reached when the stored item has no
            # metadata yet — confirm
            _metadata = WorkspaceMetadata(
                image_id=image_id,
                image_name=self.id.tag,
                network=id_cleaner(self.network_id),
                workdir=self.workdir,
            )
        return _metadata

    @cached_property
    def workdir(self) -> str:
        """Local workspace directory for ``id``, from ``get_workspace_path``.

        :return: local directory where files would get stored
        """
        return get_workspace_path(self.id)

    @cached_property
    def daemon_file(self) -> DaemonFile:
        """The ``DaemonFile`` read from :attr:`workdir`.

        :return: DaemonFile object representing current workspace
        """
        return DaemonFile(workdir=self.workdir, logger=self._logger)

    @cached_property
    def network_id(self) -> str:
        """Create the docker network for this workspace via ``Dockerizer.network``.

        :return: network id
        """
        return Dockerizer.network(workspace_id=self.id)

    def generate_image(self) -> str:
        """Build the workspace docker image via ``Dockerizer.build``.

        A separate ``JinaLogger`` writing to the workspace is passed to the build.

        :return: image id
        """
        return Dockerizer.build(
            workspace_id=self.id,
            daemon_file=self.daemon_file,
            logger=JinaLogger(
                context=self.name,
                # identity=self.id,
                workspace_path=self.workdir,
            ),
        )

    @cached_property
    def container_id(self) -> Optional[str]:
        """Run the custom workspace container if the daemon file has a ``run`` command.

        :return: cleaned container id, or ``None`` when no ``run`` command is set
        """
        if self.daemon_file.run:
            container, _, _ = Dockerizer.run_custom(
                workspace_id=self.id, daemon_file=self.daemon_file)
            return id_cleaner(container.id)
        else:
            return None

    def run(self) -> None:
        """
        Method representing the worker thread's activity
        DaemonWorker is a daemon thread responsible for the following tasks:
        During create:
        - store uploaded files in a local workspace
        - create a docker network for the workspace which would be used by all child containers
        - build a docker image to be used by all child containers
        - create a container if `run` command is passed
        During update:
        - update files in the local workspace
        - removes the workspace container, if any
        - recreate workspace container, if `run` command is passed

        Every exception is caught here: the workspace is marked
        ``RemoteWorkspaceState.FAILED`` in ``store`` and the error is logged.
        """
        try:
            # existing arguments in the store mean this is an update, otherwise a create
            store.update(
                id=self.id,
                value=RemoteWorkspaceState.UPDATING
                if store[self.id].arguments else RemoteWorkspaceState.CREATING,
            )
            store_files_in_workspace(workspace_id=self.id,
                                     files=self.files,
                                     logger=self._logger)
            # evaluating ``self.metadata`` triggers the docker image build, and it
            # is evaluated before ``self.arguments``
            store.update(
                id=self.id,
                value=WorkspaceItem(
                    state=RemoteWorkspaceState.UPDATING,
                    metadata=self.metadata,
                    arguments=self.arguments,
                ),
            )

            # this needs to be done after the initial update, otherwise run won't find the necessary metadata
            # If a container exists already, kill it before running again
            previous_container = store[self.id].metadata.container_id
            if previous_container:
                self._logger.info(
                    f'Deleting previous container {previous_container}')
                store[self.id].metadata.container_id = None
                # drop the cached value so the access below re-evaluates it.
                # NOTE(review): relies on the ``cached_property`` implementation
                # allowing ``del`` of a value that was never computed — confirm
                del self.container_id
                Dockerizer.rm_container(previous_container)

            # Create a new container if necessary
            # (``container_id`` starts one only when the daemon file has ``run``)
            store[self.id].metadata.container_id = self.container_id
            store[self.id].state = RemoteWorkspaceState.ACTIVE

            self._logger.success(
                f'workspace {colored(str(self.id), "cyan")} is updated')
        except DockerNetworkException as e:
            store.update(id=self.id, value=RemoteWorkspaceState.FAILED)
            self._logger.error(
                f'Error while creating the docker network: {e!r}')
        except DockerImageException as e:
            store.update(id=self.id, value=RemoteWorkspaceState.FAILED)
            self._logger.error(f'Error while building the docker image: {e!r}')
        except Exception as e:
            # TODO: how to communicate errors to users? users track it via logs?
            # TODO: Handle cleanup in case of exception
            store.update(id=self.id, value=RemoteWorkspaceState.FAILED)
            self._logger.error(f'{e!r}')
class MongoDBHandler:
    """Mongodb Handler to connect to the database and can apply add, update, delete and query.
    MongoDB has no access control by default, hence it can be used without username:password.

    :param hostname: host of the MongoDB server
    :param port: port of the MongoDB server
    :param username: optional user name (only used together with ``password``)
    :param password: optional password (only used together with ``username``)
    :param database: name of the database to use
    :param collection: name of the collection to use
    """

    def __init__(self,
                 hostname: str = '127.0.0.1',
                 port: int = 27017,
                 username: Optional[str] = None,
                 password: Optional[str] = None,
                 database: str = 'defaultdb',
                 collection: str = 'defaultcol'):
        self.logger = JinaLogger(self.__class__.__name__)
        self.hostname = hostname
        self.port = port
        self.username = username
        self.password = password
        self.database_name = database
        self.collection_name = collection
        self.connection_string = self._build_connection_string(
            self.hostname, self.port, self.username, self.password)

    @staticmethod
    def _build_connection_string(hostname: str,
                                 port: int,
                                 username: Optional[str] = None,
                                 password: Optional[str] = None) -> str:
        """Build the ``mongodb://`` URI, percent-escaping credentials as pymongo requires.

        Without escaping, a password containing ``@``, ``:`` or ``/`` yields an
        invalid or misparsed URI.
        """
        from urllib.parse import quote_plus

        if username and password:
            return (f'mongodb://{quote_plus(username)}:{quote_plus(password)}'
                    f'@{hostname}:{port}')
        return f'mongodb://{hostname}:{port}'

    def __enter__(self):
        return self.connect()

    def connect(self) -> 'MongoDBHandler':
        """Connect to the database and check that the server answers.

        :return: this handler, with ``client`` set
        :raises MongoDBException: if the client cannot be created or the server check fails
        """
        import pymongo
        client = None
        try:
            client = pymongo.MongoClient(self.connection_string)
            client.admin.command('ismaster')
        except Exception as exp:
            # release the client's resources instead of leaking them on failure
            if client is not None:
                client.close()
            if isinstance(exp, pymongo.errors.ConnectionFailure):
                raise MongoDBException('Database server is not available') from exp
            if isinstance(exp, pymongo.errors.ConfigurationError):
                raise MongoDBException('Credentials passed are not correct!') from exp
            raise MongoDBException(exp) from exp
        self.client = client
        self.logger.info('Successfully connected to the database')
        return self

    @property
    def database(self) -> 'Database':
        """ Get database. """
        return self.client[self.database_name]

    @property
    def collection(self) -> 'Collection':
        """ Get collection. """
        return self.database[self.collection_name]

    def query(self, key: str) -> Optional[Dict]:
        """ Queries the related document for the provided ``key``.

        :param key: id of the document
        :return: the stored document, or ``None`` if there is none
        """
        import pymongo
        try:
            return self.collection.find_one({'_id': key})
        except pymongo.errors.PyMongoError as exp:
            raise Exception(
                f'Got an error while finding a document in the db {exp}') from exp

    def add(self, documents: Iterable[Dict]) -> list:
        """ Insert the documents into the database.

        :param documents: documents to be inserted
        :return: the ids of the inserted documents
        """
        import pymongo
        try:
            result = self.collection.insert_many(documents)
            self.logger.debug(
                f'inserted {len(result.inserted_ids)} documents in the database'
            )
            return result.inserted_ids
        except pymongo.errors.PyMongoError as exp:
            raise Exception(
                f'got an error while inserting a document in the db {exp}') from exp

    def __exit__(self, *args):
        """ Make sure the connection to the database is closed.
        """
        import pymongo
        try:
            self.client.close()
        except pymongo.errors.PyMongoError as exp:
            raise MongoDBException(exp)

    def delete(self, keys: Iterable[str], *args, **kwargs):
        """Delete documents from the indexer.

        :param keys: document ids to delete related documents
        """
        import pymongo
        try:
            count = self.collection.delete_many({
                '_id': {
                    '$in': list(keys)
                }
            }).deleted_count
            self.logger.debug(f'deleted {count} documents in the database')
        except pymongo.errors.PyMongoError as exp:
            raise Exception(
                f'got an error while deleting a document in the db {exp}') from exp

    def update(self, keys: Iterable[str], values: Iterable[bytes]) -> None:
        """ Update the documents on the database.

        Each existing document ``{'_id': k}`` is replaced by
        ``{'_id': k, 'values': v}``; keys without a stored document are skipped.

        :param keys: document ids
        :param values: serialized documents
        """
        import pymongo
        try:
            # update_many applies one op to all keys; each key needs its own value
            count = 0
            for k, value in zip(keys, values):
                # find_one_and_replace returns the document *before* replacement,
                # so comparing it with the new one under-counted; use the
                # server-reported match count instead
                result = self.collection.replace_one({'_id': k}, {
                    '_id': k,
                    'values': value
                })
                count += result.matched_count
            self.logger.debug(f'updated {count} documents in the database')
        except pymongo.errors.PyMongoError as exp:
            raise Exception(
                f'got an error while updating documents in the db {exp}') from exp
class PostgreSQLDBMSHandler:
    """
    Postgres Handler to connect to the database and can apply add, update, delete and query.

    Rows are stored as ``(ID, VECS, METAS)`` in ``table``. Connection errors are
    logged rather than raised.

    :param hostname: hostname of the machine
    :param port: the port
    :param username: the username to authenticate
    :param password: the password to authenticate
    :param database: the database name
    :param table: the table name
    :param args: other arguments
    :param kwargs: other keyword arguments
    """

    def __init__(self,
                 hostname: str = '127.0.0.1',
                 port: int = 5432,
                 username: str = 'default_name',
                 password: str = 'default_pwd',
                 database: str = 'postgres',
                 table: Optional[str] = 'default_table',
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.logger = JinaLogger(self.__class__.__name__)
        self.hostname = hostname
        self.port = port
        self.username = username
        self.password = password
        self.database = database
        self.table = table

    def __enter__(self):
        return self.connect()

    def connect(self) -> 'PostgreSQLDBMSHandler':
        """Connect to the database and make sure the table exists.

        Errors are logged, not raised; the handler is returned either way.

        :return: this handler
        """
        import psycopg2

        try:
            self.connection = psycopg2.connect(user=self.username,
                                               password=self.password,
                                               database=self.database,
                                               host=self.hostname,
                                               port=self.port)
            self.cursor = self.connection.cursor()
            self.logger.info('Successfully connected to the database')
            self.use_table()
            self.connection.commit()
        except Exception as error:
            # the error belongs in the message: passing it as an extra logging
            # argument to a message without placeholders breaks the log record
            self.logger.error(f'Error while connecting to PostgreSQL: {error!r}')
        return self

    def use_table(self):
        """
        Use table if exists or create one if it doesn't.

        Create table if needed with id, vecs and metas.
        """
        self.cursor.execute(
            'select exists(select * from information_schema.tables where table_name=%s)',
            (self.table, ))
        if self.cursor.fetchone()[0]:
            self.logger.info('Using existing table')
            return
        try:
            self.cursor.execute(f'CREATE TABLE {self.table} ('
                                'ID VARCHAR PRIMARY KEY, '
                                'VECS BYTEA, '
                                'METAS BYTEA);')
            self.logger.info('Successfully created table')
        except Exception as error:
            # a failed statement aborts the transaction; roll back so the
            # connection stays usable and later errors are not masked
            self.connection.rollback()
            self.logger.error(f'Error while creating table! {error!r}')

    def add(self, ids, vecs, metas, *args, **kwargs):
        """ Insert the documents into the database.

        :param ids: List of doc ids to be added
        :param vecs: List of vecs to be added (each must provide ``tobytes()``)
        :param metas: List of metas of docs to be added
        :param args: other arguments
        :param kwargs: other keyword arguments
        :return: number of inserted rows
        """
        row_count = 0
        # indexing (not zip) keeps the IndexError on mismatched list lengths
        for i, doc_id in enumerate(ids):
            self.cursor.execute(
                f'INSERT INTO {self.table} (ID, VECS, METAS) VALUES (%s, %s, %s)',
                (doc_id, vecs[i].tobytes(), metas[i]),
            )
            row_count += self.cursor.rowcount
        self.connection.commit()
        return row_count

    def update(self, ids, vecs, metas, *args, **kwargs):
        """ Updated documents from the database.

        :param ids: Ids of Doc to be updated
        :param vecs: List of vecs to be updated (each must provide ``tobytes()``)
        :param metas: List of metas of docs to be updated
        :param args: other arguments
        :param kwargs: other keyword arguments
        :return: number of updated rows
        """
        row_count = 0
        for i, doc_id in enumerate(ids):
            self.cursor.execute(
                f'UPDATE {self.table} SET VECS = %s, METAS = %s WHERE ID = %s',
                (vecs[i].tobytes(), metas[i], doc_id),
            )
            row_count += self.cursor.rowcount
        self.connection.commit()
        return row_count

    def delete(self, ids, *args, **kwargs):
        """ Delete document from the database.

        :param ids: ids of Documents to be removed
        :param args: other arguments
        :param kwargs: other keyword arguments
        :return: number of deleted rows
        """
        row_count = 0
        for doc_id in ids:
            self.cursor.execute(f'DELETE FROM {self.table} where (ID) = (%s);',
                                (doc_id, ))
            row_count += self.cursor.rowcount
        self.connection.commit()
        return row_count

    def __exit__(self, *args):
        """ Make sure the connection to the database is closed."""
        try:
            # close the cursor before the connection it belongs to
            self.cursor.close()
            self.connection.close()
            self.logger.info('PostgreSQL connection is closed')
        except Exception as error:
            self.logger.error(f'Error while closing: {error!r}')
Exemple #13
0
class MongoDBHandler:
    """
    MongoDB handler that connects to a database and finds/inserts documents
    in one collection.

    MongoDB has no access control by default, so the handler can be used
    without a username and password. If both are passed, they are embedded in
    the connection URI; the user must already have been created on the server.

    Intended to be used as a context manager: ``__enter__`` connects and
    ``__exit__`` closes the client.
    """
    def __init__(self,
                 hostname: str = '127.0.0.1',
                 port: int = 27017,
                 username: Optional[str] = None,
                 password: Optional[str] = None,
                 database: str = 'defaultdb',
                 collection: str = 'defaultcol'):
        """
        :param hostname: host of the MongoDB server
        :param port: port of the MongoDB server
        :param username: user to authenticate as; only used if ``password``
            is also given
        :param password: password of ``username``; only used if ``username``
            is also given
        :param database: name of the database to use
        :param collection: name of the collection inside ``database``
        """
        self.logger = JinaLogger(self.__class__.__name__)
        self.hostname = hostname
        self.port = port
        self.username = username
        self.password = password
        self.database_name = database
        self.collection_name = collection
        self.connection_string = self._build_connection_string(
            hostname, port, username, password)

    @staticmethod
    def _build_connection_string(hostname: str,
                                 port: int,
                                 username: Optional[str],
                                 password: Optional[str]) -> str:
        """Build the ``mongodb://`` URI, percent-escaping the credentials.

        pymongo requires reserved characters such as ``@``, ``:`` and ``/`` in
        the username and password to be escaped with ``quote_plus``; unescaped
        they would corrupt the URI.
        """
        from urllib.parse import quote_plus
        if username and password:
            return (f'mongodb://{quote_plus(username)}:{quote_plus(password)}'
                    f'@{hostname}:{port}')
        return f'mongodb://{hostname}:{port}'

    def __enter__(self) -> 'MongoDBHandler':
        """Connect to the server when entering a ``with`` block."""
        return self.connect()

    def connect(self) -> 'MongoDBHandler':
        """Create the client and check that the server answers.

        :raises MongoDBException: if the client cannot be created or the
            ``ismaster`` check fails
        :return: this handler, with ``self.client`` set
        """
        import pymongo
        client = None
        try:
            client = pymongo.MongoClient(self.connection_string)
            client.admin.command('ismaster')
        except Exception as exp:
            if client is not None:
                # Do not leak a client that failed the connectivity check.
                client.close()
            if isinstance(exp, pymongo.errors.ConnectionFailure):
                raise MongoDBException(
                    'Database server is not available') from exp
            if isinstance(exp, pymongo.errors.ConfigurationError):
                raise MongoDBException(
                    'Credentials passed are not correct!') from exp
            raise MongoDBException(exp) from exp
        self.client = client
        self.logger.info('Successfully connected to the database')
        return self

    @property
    def database(self):
        """The database named ``database_name`` on the connected client."""
        return self.client[self.database_name]

    @property
    def collection(self):
        """The collection named ``collection_name`` in :attr:`database`."""
        return self.database[self.collection_name]

    def find(self, query: Dict[str, Union[Dict, List]]
             ) -> Optional['pymongo.cursor.Cursor']:
        """Find documents in the collection matching ``query``.

        :param query: MongoDB filter document
        :return: a cursor over the matching documents, or None if pymongo
            raised an error (the error is logged)
        """
        import pymongo
        try:
            # NOTE: pymongo cursors are lazy; errors raised while iterating
            # the returned cursor are not caught here.
            return self.collection.find(query)
        except pymongo.errors.PyMongoError as exp:
            self.logger.error(
                f'Got an error while finding a document in the db {exp}')
            return None

    def insert(self, documents: Iterator[Dict]) -> Optional[List]:
        """Insert ``documents`` into the collection.

        :param documents: documents to insert
        :return: the ids of the inserted documents, or None if pymongo raised
            an error (the error is logged)
        """
        import pymongo
        try:
            result = self.collection.insert_many(documents)
            self.logger.debug('inserted documents in the database')
            return result.inserted_ids
        except pymongo.errors.PyMongoError as exp:
            self.logger.error(
                f'got an error while inserting a document in the db {exp}')
            return None

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the client.

        :raises MongoDBException: if pymongo raises while closing
        """
        client = getattr(self, 'client', None)
        if client is None:
            # No client was created by connect(), so there is nothing to close.
            return
        import pymongo
        try:
            client.close()
        except pymongo.errors.PyMongoError as exp:
            raise MongoDBException(exp) from exp