예제 #1
0
def _get_service(client):
    """Look up the Civis service that ``client`` refers to.

    A new :class:`APIClient` is created for the lookup: it is given
    ``client._api_key`` when that value is truthy, and no arguments
    otherwise (where credentials then come from is decided inside
    APIClient, not here). The service fetched is
    ``client._service_id``.
    """
    api = APIClient(client._api_key) if client._api_key else APIClient()
    return api.services.get(client._service_id)
예제 #2
0
def create_client_mock(cache=TEST_SPEC_ALL):
    """Return an autospecced mock of APIClient built from a cached API spec.

    Parameters
    ----------
    cache : str, optional
        Filesystem path of the cached API spec that the template client
        is built from.

    Returns
    -------
    mock.Mock
        A mock shaped like an APIClient. Because it is created with
        ``spec_set=True`` autospeccing, calling a method with a
        non-existent or misspelled parameter raises an error.
    """
    # The real client only serves as the template whose attributes and
    # signatures create_autospec copies.
    template = APIClient(local_api_spec=cache,
                         api_key='none',
                         resources='all')
    template._feature_flags = {'noflag': None}
    if hasattr(template, 'channels'):
        # Without "channels" the client falls back on regular polling for
        # completion, which is far easier to exercise in tests.
        del template.channels

    # Stub out requests.Session so autospeccing cannot reach the real API.
    with mock.patch('requests.Session', mock.MagicMock):
        return mock.create_autospec(template, spec_set=True)
예제 #3
0
def query_civis(sql,
                database,
                api_key=None,
                client=None,
                credential_id=None,
                preview_rows=10,
                polling_interval=None,
                hidden=True):
    """Run a SQL statement as a Civis query and return a future for it.

    Intended for statements that return nothing or only need a small
    preview. For queries returning many rows, use
    :func:`~civis.io.read_civis_sql` instead.

    Parameters
    ----------
    sql : str
        The SQL statement to execute.
    database : str or int
        Name or ID of the database to run against.
    api_key : DEPRECATED str, optional
        Your Civis API key. When omitted, the :envvar:`CIVIS_API_KEY`
        environment variable is used.
    client : :class:`civis.APIClient`, optional
        Client to use. When omitted, an :class:`civis.APIClient` is
        created from the :envvar:`CIVIS_API_KEY`.
    credential_id : str or int, optional
        Database credential ID. When ``None`` (or otherwise falsy), the
        client's default credential is used.
    preview_rows : int, optional
        Maximum number of rows to return; at most 100 rows can be returned
        at once.
    polling_interval : int or float, optional
        Seconds to wait between checks for query completion.
    hidden : bool, optional
        When ``True`` (the default), the job does not appear in the Civis UI.

    Returns
    -------
    results : :class:`~civis.futures.CivisFuture`
        A `CivisFuture` object.

    Examples
    --------
    >>> run = query_civis(sql="DELETE schema.table", database='database')
    >>> run.result()  # Wait for query to complete
    """
    if client is None:
        client = APIClient(api_key=api_key)
    db_id = client.get_database_id(database)
    credential = credential_id or client.default_credential
    posted = client.queries.post(db_id,
                                 sql,
                                 preview_rows,
                                 credential=credential,
                                 hidden=hidden)
    # The query was just submitted, so skip the immediate first poll.
    return CivisFuture(client.queries.get, (posted.id,),
                       polling_interval,
                       client=client,
                       poll_on_creation=False)
예제 #4
0
def _real_client(local_api_spec):
    """Build an APIClient from a local API spec, configured for testing.

    The client uses the placeholder API key ``'none'``, has its feature
    flags preset to ``{'noflag': None}``, and has its ``channels``
    attribute removed when present.
    """
    client = APIClient(local_api_spec=local_api_spec, api_key='none')
    client._feature_flags = {'noflag': None}
    if hasattr(client, 'channels'):
        # With no "channels" the client polls for completion instead,
        # which keeps tests simple.
        del client.channels
    return client
예제 #5
0
def query_civis(sql,
                database,
                api_key=None,
                credential_id=None,
                preview_rows=10,
                polling_interval=_DEFAULT_POLLING_INTERVAL):
    """Run a SQL statement as a Civis query and return a pollable result.

    Intended for statements that return nothing or only need a small
    preview. For queries returning many rows, use
    :func:`~civis.io.read_civis_sql` instead.

    Parameters
    ----------
    sql : str
        The SQL statement to execute.
    database : str or int
        Name or ID of the database to run against.
    api_key : str, optional
        Your Civis API key. When omitted, the :envvar:`CIVIS_API_KEY`
        environment variable is used.
    credential_id : str or int, optional
        Database credential ID. When ``None`` (or otherwise falsy), the
        client's default credential is used.
    preview_rows : int, optional
        Maximum number of rows to return; at most 100 rows can be returned
        at once.
    polling_interval : int or float, optional
        Seconds to wait between checks for query completion.

    Returns
    -------
    results : :class:`~civis.polling.PollableResult`
        A `PollableResult` object.

    Examples
    --------
    >>> run = query_civis(sql="DELETE schema.table", database='database')
    >>> run.result()  # Wait for query to complete
    """
    client = APIClient(api_key=api_key)
    db_id = client.get_database_id(database)
    credential = credential_id or client.default_credential
    posted = client.queries.post(db_id,
                                 sql,
                                 preview_rows,
                                 credential=credential)
    return PollableResult(client.queries.get, (posted.id,), polling_interval)
예제 #6
0
def run_job(job_id, api_key=None, client=None, polling_interval=None):
    """Run a job.

    Parameters
    ----------
    job_id : str or int
        The ID of the job.
    api_key : DEPRECATED str, optional
        Your Civis API key. If not given, the :envvar:`CIVIS_API_KEY`
        environment variable will be used.
    client : :class:`civis.APIClient`, optional
        If not provided, an :class:`civis.APIClient` object will be
        created from the :envvar:`CIVIS_API_KEY`.
    polling_interval : int or float, optional
        The number of seconds between API requests to check whether a result
        is ready. ``None`` (the default) leaves the choice to
        :class:`~civis.futures.CivisFuture`, as before this parameter existed.

    Returns
    -------
    results : :class:`~civis.futures.CivisFuture`
        A `CivisFuture` object.
    """
    if client is None:
        client = APIClient(api_key=api_key)
    run = client.jobs.post_runs(job_id)
    # The run was just started, so skip the immediate first poll.
    return CivisFuture(client.jobs.get_runs, (job_id, run['id']),
                       client=client,
                       polling_interval=polling_interval,
                       poll_on_creation=False)
예제 #7
0
 def __setstate__(self, state):
     """Restore this object from pickled ``state`` and rebuild live resources.

     ``__dict__`` is replaced wholesale by ``state``. After that, a new
     ``threading.Condition`` and a new ``APIClient(resources='all')`` are
     created, the poller is pointed at ``scripts.get_containers_runs``,
     ``_begin_tracking()`` is called, and ``_set_model_exception`` is
     registered as a done-callback.
     """
     self.__dict__ = state
     # NOTE(review): the condition and client are rebuilt rather than read
     # from ``state``, presumably because they cannot be pickled. Confirm
     # against the matching __getstate__.
     self._condition = threading.Condition()
     # No api_key is passed here; how APIClient finds credentials is decided
     # inside APIClient (not visible here).
     self.client = APIClient(resources='all')
     self.poller = self.client.scripts.get_containers_runs
     self._begin_tracking()
     self.add_done_callback(self._set_model_exception)
예제 #8
0
def _load_table_from_outputs(job_id, run_id, filename, client=None,
                             **table_kwargs):
    """Read a run-output file of a job run straight into a ``DataFrame``.

    ``filename`` is matched as a regex against the outputs of run
    ``run_id`` of job ``job_id``. Extra ``table_kwargs`` are passed through
    to ``cio.file_to_dataframe``.
    """
    if client is None:
        client = APIClient(resources='all')
    output_file_id = cio.file_id_from_run_output(
        filename, job_id, run_id, client=client, regex=True)
    return cio.file_to_dataframe(output_file_id, client=client,
                                 **table_kwargs)
예제 #9
0
def _unshare_model(job_id, entity_id, entity_type, client=None):
    """Revoke permissions on a container job and all run outputs
    for the requested entity (singular).

    Parameters
    ----------
    job_id : int or str
        ID of the container job whose shares are revoked.
    entity_id : int
        ID of the single user or group losing access.
    entity_type : {'groups', 'users'}
        Kind of entity that ``entity_id`` refers to.
    client : :class:`civis.APIClient`, optional
        If not provided, an :class:`civis.APIClient` object will be created.

    Returns
    -------
    The response of ``client.scripts.delete_containers_shares_<entity_type>``
    for the job itself.

    Raises
    ------
    ValueError
        If ``entity_type`` is not ``'groups'`` or ``'users'``.
    """
    # Validate before building a client, so a bad argument is reported as
    # this ValueError rather than whatever constructing APIClient() raises.
    if entity_type not in ['groups', 'users']:
        raise ValueError("'entity_type' must be one of ['groups', 'users']. "
                         "Got '{0}'.".format(entity_type))
    if client is None:
        client = APIClient()

    # %s rather than %d: IDs may arrive as strings, which %d cannot format.
    log.debug("Revoking permissions on object %s for %s %s.",
              job_id, entity_type, entity_id)
    unshare_job = getattr(client.scripts,
                          "delete_containers_shares_" + entity_type)
    result = unshare_job(job_id, entity_id)

    # CivisML relies on several run outputs attached to each model run.
    # Go through and revoke permissions for outputs on each run.
    endpoint_name = "delete_shares_" + entity_type
    # Client resource (by attribute name) owning the share endpoints for
    # each supported output type. Resolved lazily so a client lacking one
    # of these resources only fails if that output type actually appears.
    resource_names = {'File': 'files',
                      'Project': 'projects',
                      'JSONValue': 'json_values'}
    runs = client.scripts.list_containers_runs(job_id, iterator=True)
    for run in runs:
        log.debug("Unsharing outputs on %s, run %s.", job_id, run.id)
        outputs = client.scripts.list_containers_runs_outputs(job_id, run.id)
        for _output in outputs:
            resource_name = resource_names.get(_output['object_type'])
            if resource_name is None:
                log.debug("Found run output of type %s, ID %s; not unsharing "
                          "it.", _output['object_type'], _output['object_id'])
                continue
            unshare_output = getattr(getattr(client, resource_name),
                                     endpoint_name)
            unshare_output(_output['object_id'], entity_id)

    return result
예제 #10
0
    def __init__(self,
                 script_name=None,
                 hidden=True,
                 max_n_retries=0,
                 client=None,
                 polling_interval=None,
                 inc_script_names=False):
        """Store submission settings and set up shutdown and job tracking.

        Parameters
        ----------
        script_name : str, optional
            Stored as ``self.script_name``.
        hidden : bool, optional
            Stored as ``self.hidden`` (presumably controls whether submitted
            jobs appear in the Civis UI; confirm at the submit call site).
        max_n_retries : int, optional
            Stored as ``self.max_n_retries``.
        client : :class:`civis.APIClient`, optional
            Defaults to ``APIClient(resources='all')``.
        polling_interval : int or float, optional
            Stored as ``self.polling_interval``.
        inc_script_names : bool, optional
            Stored as ``self.inc_script_names``.
        """
        self.max_n_retries = max_n_retries
        self.hidden = hidden
        # Assigned once here. The original also re-assigned this same value
        # a few lines below, which was redundant.
        self.script_name = script_name
        self.polling_interval = polling_interval
        self.inc_script_names = inc_script_names
        # NOTE(review): presumably used to number script names when
        # ``inc_script_names`` is set; confirm in the submit path.
        self._script_name_counter = 0

        # NOTE(review): lock and flag presumably used together to
        # coordinate shutdown; confirm against the shutdown method.
        self._shutdown_lock = threading.Lock()
        self._shutdown_thread = False

        if client is None:
            client = APIClient(resources='all')
        self.client = client

        # A list of ContainerFuture objects for submitted jobs.
        self._futures = []
예제 #11
0
def _file_to_civis(buf, name, api_key=None, client=None, **kwargs):
    """Upload the contents of ``buf`` to Civis as a file named ``name``.

    Falls back to a single upload when the size is unknown, the file is
    small (at most ``MIN_MULTIPART_SIZE``), or ``buf`` cannot be seeked.
    Otherwise a multipart upload is used.

    Parameters
    ----------
    buf : file-like
        Source of the upload. Its seekability is probed, not assumed.
    name : str
        Name of the Civis file to create.
    api_key : DEPRECATED str, optional
        Used only when ``client`` is not given.
    client : :class:`civis.APIClient`, optional
        Created from ``api_key`` when not provided.
    **kwargs
        Passed through to the single or multipart upload helper.

    Returns
    -------
    Whatever ``_single_upload`` / ``_multipart_upload`` returns.

    Raises
    ------
    ValueError
        If the file exceeds ``MAX_FILE_SIZE``, or if it is non-seekable and
        larger than ``MAX_PART_SIZE``.
    """
    if client is None:
        client = APIClient(api_key=api_key)

    file_size = _buf_len(buf)
    if not file_size:
        log.warning('Could not determine file size; defaulting to '
                    'single post. Files over 5GB will fail.')

    # Probe seekability by seeking to the current position (a no-op move).
    # Streams such as pipes raise a plain OSError (ESPIPE) from tell()
    # rather than io.UnsupportedOperation. UnsupportedOperation is itself an
    # OSError subclass, so catching OSError covers both cases.
    try:
        buf.seek(buf.tell())
        is_seekable = True
    except OSError:
        is_seekable = False

    if not file_size:
        return _single_upload(buf, name, is_seekable, client, **kwargs)
    if file_size > MAX_FILE_SIZE:
        msg = "File is greater than the maximum allowable file size (5TB)"
        raise ValueError(msg)
    if not is_seekable and file_size > MAX_PART_SIZE:
        msg = "Cannot perform multipart upload on non-seekable files. " \
              "File is greater than the maximum allowable part size (5GB)"
        raise ValueError(msg)
    if file_size <= MIN_MULTIPART_SIZE or not is_seekable:
        return _single_upload(buf, name, is_seekable, client, **kwargs)
    return _multipart_upload(buf, name, file_size, client, **kwargs)
예제 #12
0
def run_job(job_id, api_key=None, client=None, polling_interval=None):
    """Start a new run of a job and return a future that tracks it.

    Parameters
    ----------
    job_id : str or int
        ID of the job to run.
    api_key : DEPRECATED str, optional
        Your Civis API key. When omitted, the :envvar:`CIVIS_API_KEY`
        environment variable is used.
    client : :class:`civis.APIClient`, optional
        Client to use. When omitted, an :class:`civis.APIClient` is created
        from the :envvar:`CIVIS_API_KEY`.
    polling_interval : int or float, optional
        Seconds between API requests that check whether the result is ready.

    Returns
    -------
    results : :class:`~civis.futures.CivisFuture`
        A `CivisFuture` object.
    """
    if client is None:
        client = APIClient(api_key=api_key)
    new_run = client.jobs.post_runs(job_id)
    poller = client.jobs.get_runs
    poller_args = (job_id, new_run["id"])
    # The run was just started, so skip the immediate first poll.
    return CivisFuture(poller, poller_args,
                       client=client,
                       polling_interval=polling_interval,
                       poll_on_creation=False)
예제 #13
0
def _civis_to_file(file_id, buf, api_key=None, client=None):
    """Download Civis file ``file_id`` and write its bytes into ``buf``.

    The download is retried on ``RETRY_EXCEPTIONS``. Each attempt first
    seeks ``buf`` back to the position it had on entry.

    Parameters
    ----------
    file_id : int
        ID of the Civis file to download.
    buf : writable, seekable file-like
        Destination for the file contents.
    api_key : DEPRECATED str, optional
        Used only when ``client`` is not given.
    client : :class:`civis.APIClient`, optional
        Created from ``api_key`` when not provided.

    Raises
    ------
    EmptyResultError
        If the file has no download URL (for example, because it expired).
    requests.HTTPError
        If the download request returns an error status (after retries, if
        that error type is in ``RETRY_EXCEPTIONS``).
    """
    if client is None:
        client = APIClient(api_key=api_key)
    files_response = client.files.get(file_id)
    url = files_response.file_url
    if not url:
        raise EmptyResultError('Unable to locate file {}. If it previously '
                               'existed, it may have '
                               'expired.'.format(file_id))

    # Store the current buffer position in case we need to retry below.
    buf_orig_position = buf.tell()

    @retry(RETRY_EXCEPTIONS)
    def _download_url_to_buf():
        # Reset the buffer in case we had to retry.
        buf.seek(buf_orig_position)

        response = requests.get(url, stream=True)
        try:
            response.raise_for_status()
            for chunk in response.iter_content(CHUNK_SIZE):
                buf.write(chunk)
        finally:
            # With stream=True the connection is not returned to the pool
            # until the body is fully consumed or the response is closed.
            # Close it explicitly so a failed attempt does not leak it.
            response.close()

    _download_url_to_buf()
예제 #14
0
 def test_feature_flags_memoized(self, *mocks):
     """Reading ``feature_flags`` twice must call ``users.list_me`` once."""
     client = APIClient()
     real_list_me = client.users.list_me
     with mock.patch.object(client.users, 'list_me',
                            wraps=real_list_me) as list_me_spy:
         for _ in range(2):
             client.feature_flags
         self.assertEqual(list_me_spy.call_count, 1)
예제 #15
0
def _import_bytes(buf, database, table, api_key, max_errors,
                  existing_table_rows, distkey, sortkey1, sortkey2, delimiter,
                  headers, credential_id, polling_interval, archive, hidden):
    """Upload ``buf`` through a Civis file import and start the import run.

    Parameters
    ----------
    buf
        Request body PUT to the import job's upload URI.
    database : str or int
        Name or ID of the destination database.
    table : str
        Destination as ``"schema.table"``, split on the first ``"."``.
    api_key : str or None
        Used to build the APIClient and passed on to the returned future.
    delimiter : str
        Key into ``DELIMITERS``; the mapped value is sent as the column
        delimiter.
    credential_id : int, optional
        Database credential. When falsy, ``client.default_credential`` is
        used.
    archive : bool
        If true, the import job is archived once the returned future is done.
    max_errors, existing_table_rows, distkey, sortkey1, sortkey2, hidden
        Forwarded to ``client.imports.post_files``.
    headers : bool
        Forwarded as ``first_row_is_header``.
    polling_interval : int or float, optional
        Forwarded to the returned future.

    Returns
    -------
    :class:`CivisFuture`
        Polls ``client.imports.get_files_runs`` for the started run.

    Raises
    ------
    ValueError
        If ``table`` contains no ``"."``, or ``delimiter`` is not a key of
        ``DELIMITERS``.
    """
    # Validate arguments before any API traffic. Explicit raises (instead of
    # ``assert``) keep these checks active under ``python -O``.
    if '.' not in table:
        raise ValueError("table must be given as 'schema.table'; "
                         "got {!r}".format(table))
    schema, table_name = table.split(".", 1)
    column_delimiter = DELIMITERS.get(delimiter)
    if not column_delimiter:
        raise ValueError(
            "delimiter must be one of {}".format(DELIMITERS.keys()))

    client = APIClient(api_key=api_key)
    db_id = client.get_database_id(database)
    cred_id = credential_id or client.default_credential

    import_job = client.imports.post_files(
        schema=schema,
        name=table_name,
        remote_host_id=db_id,
        credential_id=cred_id,
        max_errors=max_errors,
        existing_table_rows=existing_table_rows,
        distkey=distkey,
        sortkey1=sortkey1,
        sortkey2=sortkey2,
        column_delimiter=column_delimiter,
        first_row_is_header=headers,
        hidden=hidden)
    put_response = requests.put(import_job.upload_uri, buf)
    put_response.raise_for_status()

    # The run is started by POSTing to the run URI the import job returned.
    run_job_result = client._session.post(import_job.run_uri)
    run_job_result.raise_for_status()
    run_info = run_job_result.json()
    fut = CivisFuture(client.imports.get_files_runs,
                      (run_info['importId'], run_info['id']),
                      polling_interval=polling_interval,
                      api_key=api_key,
                      poll_on_creation=False)
    if archive:

        def _archive_import(_fut):
            # Archive the import job once the run's future is done.
            return client.imports.put_archive(import_job.id, True)

        fut.add_done_callback(_archive_import)
    return fut
예제 #16
0
def list_models(job_type="train", author=SENTINEL, client=None, **kwargs):
    """List a user's CivisML models.

    Parameters
    ----------
    job_type : {"train", "predict", None}
        The type of model job to list. If "train", list training jobs
        only (including registered models trained outside of CivisML).
        If "predict", list prediction jobs only. If None, list both.
    author : int, optional
        User id of the user whose models you want to list. Defaults to
        the current user. Use ``None`` to list models from all users.
    client : :class:`civis.APIClient`, optional
        If not provided, an :class:`civis.APIClient` object will be
        created from the :envvar:`CIVIS_API_KEY`.
    **kwargs : kwargs
        Extra keyword arguments passed to `client.scripts.list_custom()`.
        ``order_dir`` defaults to ``'desc'`` (most recent first).

    Returns
    -------
    The result of ``client.scripts.list_custom`` filtered to the CivisML
    templates for the requested job type(s).

    Raises
    ------
    ValueError
        If ``job_type`` is not ``"train"``, ``"predict"`` or ``None``.

    See Also
    --------
    APIClient.scripts.list_custom
    """
    if job_type == "train":
        job_types = ('training',)
    elif job_type == "predict":
        job_types = ('prediction',)
    elif job_type is None:
        job_types = ('training', 'prediction')
    else:
        raise ValueError("Parameter 'job_type' must be None, 'train', "
                         "or 'predict'.")

    if client is None:
        client = APIClient()

    # Fetch the per-version template IDs once. The inner iterable of a
    # nested comprehension is re-evaluated for every outer item, so calling
    # the helper there ran it once per job type.
    ids_by_version = list(_get_template_ids_all_versions(client).values())
    # A set, because the version-less production alias and the versioned
    # alias can map to the same template ID.
    template_ids = {ids[_job_type]
                    for _job_type in job_types
                    for ids in ids_by_version}
    template_id_str = ', '.join(str(tmp) for tmp in template_ids)

    if author is SENTINEL:
        author = client.users.list_me().id

    # default to showing most recent models first
    kwargs.setdefault('order_dir', 'desc')

    models = client.scripts.list_custom(from_template_id=template_id_str,
                                        author=author,
                                        **kwargs)
    return models
예제 #17
0
 def _pubnub_config(self):
     """Fetch channel settings and build the matching PubNub configuration.

     A new ``APIClient(api_key=self.api_key, resources='all')`` is used to
     call ``channels.list()``. Returns a ``(PNConfiguration, channel_names)``
     pair: the configuration carries the subscribe, cipher and auth keys
     from that response, with SSL on and a linear reconnection policy.
     """
     api = APIClient(api_key=self.api_key, resources='all')
     channel_info = api.channels.list()
     channel_names = [entry['name'] for entry in channel_info['channels']]
     config = PNConfiguration()
     for key in ('subscribe_key', 'cipher_key', 'auth_key'):
         setattr(config, key, channel_info[key])
     config.ssl = True
     config.reconnection_policy = PNReconnectionPolicy.LINEAR
     return config, channel_names
예제 #18
0
 def __setstate__(self, state):
     """Restore this object from pickled ``state`` and rebuild live resources.

     ``__dict__`` is replaced by ``state``. After that:

     - a new ``threading.Condition`` and a new ``APIClient(resources='all')``
       are created;
     - the PubNub subscription is re-established when ``self._pubnub`` is
       ``True``;
     - a new ``_ResultPollingThread`` is constructed for
       ``self._check_result``;
     - the poller is pointed at ``scripts.get_containers_runs``;
     - ``_set_model_exception`` is registered as a done-callback.
     """
     self.__dict__ = state
     # NOTE(review): these objects are rebuilt rather than read from
     # ``state``, presumably because they cannot be pickled. Confirm
     # against the matching __getstate__.
     self._condition = threading.Condition()
     self.client = APIClient(resources='all')
     # Compared with ``is True`` rather than for truthiness. Presumably
     # __getstate__ stores True in place of a live subscription; confirm.
     if getattr(self, '_pubnub', None) is True:
         # Re-subscribe to notifications channel
         self._pubnub = self._subscribe(*self._pubnub_config())
     # The thread object is only constructed here; nothing in this method
     # starts it.
     self._polling_thread = _ResultPollingThread(self._check_result, (),
                                                 self.polling_interval)
     self.poller = self.client.scripts.get_containers_runs
     self.add_done_callback(self._set_model_exception)
예제 #19
0
def file_to_dataframe(file_id,
                      compression='infer',
                      client=None,
                      **read_kwargs):
    """Load a :class:`~pandas.DataFrame` from a CSV stored in a Civis File

    The :class:`~pandas.DataFrame` will be read directly from Civis
    without copying the CSV to a local file on disk.

    Parameters
    ----------
    file_id : int
        ID of a Civis File which contains a CSV
    compression : str, optional
        If "infer", set the ``compression`` argument of ``pandas.read_csv``
        based on the (case-insensitive) file extension of the name of the
        Civis File: ``.gz``, ``.xz``, ``.bz2`` or ``.zip``. For any other
        extension, "infer" is passed through to ``pandas.read_csv``.
        Otherwise pass this argument to ``pandas.read_csv``.
    client : :class:`civis.APIClient`, optional
        If not provided, an :class:`civis.APIClient` object will be
        created from the :envvar:`CIVIS_API_KEY`.
    **read_kwargs
        Additional arguments will be passed directly to
        :func:`~pandas.read_csv`.

    Returns
    -------
    :class:`~pandas.DataFrame` containing the contents of the CSV

    Raises
    ------
    ImportError
        If ``pandas`` is not available
    EmptyResultError
        If the Civis File has no download URL (it may have expired).

    See Also
    --------
    pandas.read_csv
    """
    if not HAS_PANDAS:
        raise ImportError('file_to_dataframe requires pandas to be installed.')
    client = APIClient() if client is None else client
    file_info = client.files.get(file_id)
    file_url = file_info.file_url
    if not file_url:
        raise EmptyResultError('Unable to locate file {}. If it previously '
                               'existed, it may have '
                               'expired.'.format(file_id))
    file_name = file_info.name
    if compression == 'infer':
        # Infer from the Civis File's name rather than leaving it to pandas.
        # Lower-case the extension so e.g. "DATA.CSV.GZ" is also detected.
        comp_exts = {'.gz': 'gzip', '.xz': 'xz', '.bz2': 'bz2', '.zip': 'zip'}
        ext = os.path.splitext(file_name)[-1].lower()
        compression = comp_exts.get(ext, compression)

    return pd.read_csv(file_url, compression=compression, **read_kwargs)
예제 #20
0
def run_template(id, arguments, JSONValue=False, client=None):
    """Run a template and return the results.

    Creates a custom script from the template, starts a run of it and
    blocks until that run's future completes.

    Parameters
    ----------
    id : int
        The template id to be run.
    arguments : dict
        Dictionary of arguments to be passed to the template.
    JSONValue : bool, optional
        If True, will return the JSON output of the template.
        If False, will return the file ids associated with the
        output results.
    client : :class:`civis.APIClient`, optional
        If not provided, an :class:`civis.APIClient` object will be
        created from the :envvar:`CIVIS_API_KEY`.

    Returns
    -------
    output : dict or None
        If JSONValue = False, a dictionary mapping each run output's name
        to its object id (every output is included, whatever its type).
        If JSONValue = True, the value of a JSONValue output converted to
        a dict. If there is no JSON output, a warning is logged and
        ``None`` is returned. If there is more than one, a warning is
        logged and only the first one is returned.

    Raises
    ------
    Any exception raised by the run's future when the run fails.
    """
    if client is None:
        client = APIClient()
    job = client.scripts.post_custom(id, arguments=arguments)
    run = client.scripts.post_custom_runs(job.id)
    fut = CivisFuture(client.scripts.get_custom_runs, (job.id, run.id),
                      client=client)
    # Block until the run finishes; raises if the run's future does.
    fut.result()
    outputs = client.scripts.list_custom_runs_outputs(job.id, run.id)
    if not JSONValue:
        return {o.name: o.object_id for o in outputs}

    json_output = [o.value for o in outputs if o.object_type == "JSONValue"]
    if not json_output:
        log.warning("No JSON output for template %s", id)
        return None
    if len(json_output) > 1:
        log.warning("More than 1 JSON output for template %s"
                    " -- returning only the first one.", id)
    # Note that the cast to a dict is to convert
    # an expected Response object.
    return dict(json_output[0])
예제 #21
0
 def __init__(self, job_id, run_id,
              max_n_retries=0,
              polling_interval=None,
              client=None,
              poll_on_creation=True):
     """Track run ``run_id`` of container job ``job_id``.

     Both IDs are converted with ``int()`` and polled through
     ``client.scripts.get_containers_runs``. ``client`` defaults to
     ``APIClient()``. ``polling_interval``, ``client`` and
     ``poll_on_creation`` go to the parent initializer, and
     ``max_n_retries`` is kept as ``self._max_n_retries``.
     """
     client = APIClient() if client is None else client
     poller = client.scripts.get_containers_runs
     run_ids = [int(job_id), int(run_id)]
     super().__init__(poller,
                      run_ids,
                      polling_interval=polling_interval,
                      client=client,
                      poll_on_creation=poll_on_creation)
     self._max_n_retries = max_n_retries
예제 #22
0
def _civis_to_file(file_id, buf, api_key=None, client=None):
    """Download Civis file ``file_id`` and write its bytes into ``buf``.

    Parameters
    ----------
    file_id : int
        ID of the Civis file to download.
    buf : writable file-like
        Destination for the file contents, written in ``CHUNK_SIZE`` chunks.
    api_key : DEPRECATED str, optional
        Used only when ``client`` is not given.
    client : :class:`civis.APIClient`, optional
        Created from ``api_key`` when not provided.

    Raises
    ------
    EmptyResultError
        If the file has no download URL (for example, because it expired).
    requests.HTTPError
        If the download request returns an error status.
    """
    if client is None:
        client = APIClient(api_key=api_key)
    files_response = client.files.get(file_id)
    url = files_response.file_url
    if not url:
        raise EmptyResultError('Unable to locate file {}. If it previously '
                               'existed, it may have '
                               'expired.'.format(file_id))
    response = requests.get(url, stream=True)
    try:
        response.raise_for_status()
        for chunk in response.iter_content(CHUNK_SIZE):
            buf.write(chunk)
    finally:
        # With stream=True the connection is not returned to the pool until
        # the body is fully consumed or the response is closed. Close it
        # explicitly so an error mid-download does not leak it.
        response.close()
예제 #23
0
def list_models(job_type="train", author=SENTINEL, client=None, **kwargs):
    """List a user's CivisML models.

    Parameters
    ----------
    job_type : {"train", "predict", None}
        Which model jobs to list: "train" lists training jobs only
        (including registered models trained outside of CivisML),
        "predict" lists prediction jobs only, and None lists both.
    author : int, optional
        User id whose models are listed. Defaults to the current user;
        pass ``None`` to list models from all users.
    client : :class:`civis.APIClient`, optional
        Client to use. When omitted, an :class:`civis.APIClient` is created
        from the :envvar:`CIVIS_API_KEY`.
    **kwargs : kwargs
        Extra keyword arguments forwarded to `client.scripts.list_custom()`.

    See Also
    --------
    APIClient.scripts.list_custom
    """
    if job_type == "train":
        template_ids = list(_PRED_TEMPLATES.keys())
    elif job_type == "predict":
        # Prediction template IDs may repeat across entries; keep each once.
        template_ids = list(set(_PRED_TEMPLATES.values()))
    elif job_type is None:
        # Union of both ID sets, so no ID appears twice.
        template_ids = list(set(_PRED_TEMPLATES.keys())
                            | set(_PRED_TEMPLATES.values()))
    else:
        raise ValueError("Parameter 'job_type' must be None, 'train', "
                         "or 'predict'.")
    template_id_str = ', '.join(map(str, template_ids))

    if client is None:
        client = APIClient()

    if author is SENTINEL:
        author = client.users.list_me().id

    # Newest models first unless the caller says otherwise.
    kwargs.setdefault('order_dir', 'desc')

    return client.scripts.list_custom(from_template_id=template_id_str,
                                      author=author,
                                      **kwargs)
예제 #24
0
    def __init__(self, poller, poller_args,
                 polling_interval=None, api_key=None, client=None,
                 poll_on_creation=True):
        """Initialize the future, choosing a long polling interval if possible.

        All arguments are forwarded to the parent initializer, with two
        adjustments. ``client`` defaults to ``APIClient(api_key=api_key)``.
        When ``polling_interval`` is not given and the client has a
        ``channels`` attribute, ``_LONG_POLLING_INTERVAL`` is used instead.
        """
        if client is None:
            client = APIClient(api_key=api_key)

        # NOTE(review): a client with "channels" presumably gets completion
        # notifications, so frequent polling is unnecessary; confirm.
        wants_long_interval = (polling_interval is None
                               and hasattr(client, 'channels'))
        interval = (_LONG_POLLING_INTERVAL if wants_long_interval
                    else polling_interval)

        super().__init__(poller=poller,
                         poller_args=poller_args,
                         polling_interval=interval,
                         api_key=api_key,
                         client=client,
                         poll_on_creation=poll_on_creation)
예제 #25
0
def _share_model(job_id,
                 entity_ids,
                 permission_level,
                 entity_type,
                 client=None,
                 **kwargs):
    """Share a container job and all run outputs with requested entities.

    Parameters
    ----------
    job_id : int or str
        ID of the container job to share.
    entity_ids
        IDs of the users or groups to share with, passed straight to the
        share endpoints.
    permission_level : str
        Permission granted on the job and its outputs. Project outputs get
        ``'write'`` when this is ``'read'``, since users must be able to add
        to projects to use the model.
    entity_type : {'groups', 'users'}
        Kind of entity that ``entity_ids`` refers to.
    client : :class:`civis.APIClient`, optional
        If not provided, an :class:`civis.APIClient` object will be created.
    **kwargs
        Extra arguments for the job's share endpoint only. Run outputs are
        always shared with ``send_shared_email=False``.

    Returns
    -------
    The response of ``client.scripts.put_containers_shares_<entity_type>``
    for the job itself.

    Raises
    ------
    ValueError
        If ``entity_type`` is not ``'groups'`` or ``'users'``.
    """
    # Validate before building a client, so a bad argument is reported as
    # this ValueError rather than whatever constructing APIClient() raises.
    if entity_type not in ['groups', 'users']:
        raise ValueError("'entity_type' must be one of ['groups', 'users']. "
                         "Got '{0}'.".format(entity_type))
    if client is None:
        client = APIClient()

    # %s rather than %d: IDs may arrive as strings, which %d cannot format.
    log.debug("Sharing object %s with %s %s at permission level %s.", job_id,
              entity_type, entity_ids, permission_level)
    share_job = getattr(client.scripts, "put_containers_shares_" + entity_type)
    result = share_job(job_id, entity_ids, permission_level, **kwargs)

    # CivisML relies on several run outputs attached to each model run.
    # Go through and share all outputs on each run.
    endpoint_name = "put_shares_" + entity_type
    # Client resource (by attribute name) owning the share endpoints for
    # each supported output type. Resolved lazily so a client lacking one
    # of these resources only fails if that output type actually appears.
    resource_names = {'File': 'files',
                      'Project': 'projects',
                      'JSONValue': 'json_values'}
    runs = client.scripts.list_containers_runs(job_id, iterator=True)
    for run in runs:
        log.debug("Sharing outputs on %s, run %s.", job_id, run.id)
        outputs = client.scripts.list_containers_runs_outputs(job_id, run.id)
        for _output in outputs:
            object_type = _output['object_type']
            resource_name = resource_names.get(object_type)
            if resource_name is None:
                log.debug(
                    "Found a run output of type %s, ID %s; not sharing "
                    "it.", object_type, _output['object_id'])
                continue
            share_output = getattr(getattr(client, resource_name),
                                   endpoint_name)
            if object_type == 'Project' and permission_level == 'read':
                # Users must be able to add to projects to use the model
                obj_permission = 'write'
            else:
                obj_permission = permission_level
            # Don't send share emails for any of the run outputs.
            share_output(_output['object_id'], entity_ids, obj_permission,
                         send_shared_email=False)

    return result
예제 #26
0
def _file_to_civis(buf, name, api_key=None, client=None, **kwargs):
    """Upload ``buf`` to Civis as a file named ``name``.

    A single upload is used when the size of ``buf`` is unknown or at most
    ``MIN_MULTIPART_SIZE``. Anything above ``MAX_FILE_SIZE`` is rejected
    with ``ValueError``. Everything in between goes through a multipart
    upload. Extra ``kwargs`` are passed to the chosen upload helper.
    """
    client = APIClient(api_key=api_key) if client is None else client

    file_size = _buf_len(buf)
    size_unknown = not file_size
    if size_unknown:
        log.warning('Could not determine file size; defaulting to '
                    'single post. Files over 5GB will fail.')

    if size_unknown or file_size <= MIN_MULTIPART_SIZE:
        return _single_upload(buf, name, client, **kwargs)
    if file_size > MAX_FILE_SIZE:
        too_big = "File is greater than the maximum allowable file size (5TB)"
        raise ValueError(too_big)
    return _multipart_upload(buf, name, file_size, client, **kwargs)
예제 #27
0
    def __init__(self, model, dependent_variable,
                 primary_key=None, parameters=None,
                 cross_validation_parameters=None, model_name=None,
                 calibration=None, excluded_columns=None, client=None,
                 cpu_requested=None, memory_requested=None,
                 disk_requested=None, notifications=None,
                 dependencies=None, git_token_name=None, verbose=False,
                 etl=None):
        """Store the model configuration and select the CivisML templates.

        Parameters
        ----------
        model
            Stored as ``self.model``. The original is also kept as
            ``self._input_model`` in case ``self.model`` is modified later.
        dependent_variable : str or list
            A single string is normalized to a one-element list.
        primary_key, calibration, excluded_columns, dependencies, git_token_name, verbose
            Stored as same-named attributes.
        parameters, notifications : dict, optional
            Stored as given, or ``{}`` when falsy.
        cross_validation_parameters : dict, optional
            Stored as ``self.cv_params`` (``{}`` when falsy).
        model_name : str, optional
            ``None`` lets Platform use the template name.
        client : :class:`civis.APIClient`, optional
            Defaults to ``APIClient(resources='all')``.
        cpu_requested, memory_requested, disk_requested : optional
            Collected into ``self.job_resources`` under ``REQUIRED_CPU``,
            ``REQUIRED_MEMORY`` and ``REQUIRED_DISK_SPACE``.
        etl : optional
            Only supported with the newest CivisML version.

        Raises
        ------
        NotImplementedError
            If ``etl`` is given but the newest CivisML version is not in use.
        """
        self.model = model
        self._input_model = model  # In case we need to modify the input
        if isinstance(dependent_variable, str):
            # Standardize the dependent variable as a list.
            dependent_variable = [dependent_variable]
        self.dependent_variable = dependent_variable

        # optional but common parameters
        self.primary_key = primary_key
        self.parameters = parameters or {}
        self.cv_params = cross_validation_parameters or {}
        self.model_name = model_name  # None lets Platform use template name
        self.excluded_columns = excluded_columns
        self.calibration = calibration
        self.job_resources = {'REQUIRED_CPU': cpu_requested,
                              'REQUIRED_MEMORY': memory_requested,
                              'REQUIRED_DISK_SPACE': disk_requested}
        self.notifications = notifications or {}
        self.dependencies = dependencies
        self.git_token_name = git_token_name
        self.verbose = verbose

        if client is None:
            client = APIClient(resources='all')
        self._client = client
        self.train_result_ = None

        # _NEWEST_CIVISML_VERSION is read only after this call (which is
        # presumably what determines it).
        self._set_template_version(client)

        if _NEWEST_CIVISML_VERSION:
            self.etl = etl
        elif etl is not None:
            raise NotImplementedError("The etl argument is not implemented"
                                      " in this version of CivisML.")
        else:
            # fall back to previous version templates
            self.train_template_id = self._train_template_id_fallback
            self.predict_template_id = self._predict_template_id_fallback
            # etl is None on this path. Set the attribute anyway so that
            # ``self.etl`` exists whichever CivisML version is in use.
            self.etl = None
예제 #28
0
    def __init__(self,
                 poller,
                 poller_args,
                 polling_interval=None,
                 api_key=None,
                 poll_on_creation=True):
        """Initialize the future, subscribing to PubNub when it is usable.

        PubNub is usable when ``has_pubnub`` is true and a new
        ``APIClient(api_key=api_key, resources='all')`` has a ``channels``
        attribute. In that case a missing ``polling_interval`` becomes
        ``_LONG_POLLING_INTERVAL``, and after the parent initializer runs,
        a subscription built from ``self._pubnub_config()`` is stored in
        ``self._pubnub``.
        """
        client = APIClient(api_key=api_key, resources='all')
        notifications_available = has_pubnub and hasattr(client, 'channels')
        if polling_interval is None and notifications_available:
            polling_interval = _LONG_POLLING_INTERVAL

        super().__init__(poller, poller_args, polling_interval, api_key,
                         poll_on_creation)

        if notifications_available:
            pubnub_config, channel_names = self._pubnub_config()
            self._pubnub = self._subscribe(pubnub_config, channel_names)
예제 #29
0
    def __init__(self,
                 poller,
                 poller_args,
                 polling_interval=None,
                 api_key=None,
                 client=None,
                 poll_on_creation=True):
        """Initialize the future and register the job-exception callback.

        Parameters
        ----------
        poller : callable
            Forwarded to the parent initializer.
        poller_args
            Forwarded to the parent initializer.
        polling_interval : int or float, optional
            Forwarded to the parent initializer.
        api_key : str, optional
            Forwarded to the parent initializer. Also used to build the
            client when ``client`` is not given.
        client : :class:`civis.APIClient`, optional
            Defaults to ``APIClient(api_key=api_key)``.
        poll_on_creation : bool, optional
            Forwarded to the parent initializer.
        """
        if client is None:
            client = APIClient(api_key=api_key)

        super().__init__(poller=poller,
                         poller_args=poller_args,
                         polling_interval=polling_interval,
                         api_key=api_key,
                         client=client,
                         poll_on_creation=poll_on_creation)

        # Initialized before the callback is registered below. NOTE(review):
        # _set_job_exception presumably reads this flag, and the callback may
        # run right away if the future is already done; confirm.
        self._exception_handled = False
        self.add_done_callback(self._set_job_exception)
예제 #30
0
    def __init__(self,
                 model,
                 dependent_variable,
                 primary_key=None,
                 parameters=None,
                 cross_validation_parameters=None,
                 model_name=None,
                 calibration=None,
                 excluded_columns=None,
                 client=None,
                 cpu_requested=None,
                 memory_requested=None,
                 disk_requested=None,
                 notifications=None,
                 verbose=False):
        """Store the model configuration on the instance.

        Parameters
        ----------
        model
            Stored as ``self.model``. The original is also kept as
            ``self._input_model`` in case ``self.model`` is modified later.
        dependent_variable : str or list
            A single string is normalized to a one-element list.
        primary_key, calibration, excluded_columns, verbose
            Stored as same-named attributes.
        parameters, notifications : dict, optional
            Stored as given, or ``{}`` when falsy.
        cross_validation_parameters : dict, optional
            Stored as ``self.cv_params`` (``{}`` when falsy).
        model_name : str, optional
            ``None`` lets Platform use the template name.
        client : :class:`civis.APIClient`, optional
            Stored as ``self._client``. Defaults to
            ``APIClient(resources='all')``.
        cpu_requested, memory_requested, disk_requested : optional
            Collected into ``self.job_resources`` under ``REQUIRED_CPU``,
            ``REQUIRED_MEMORY`` and ``REQUIRED_DISK_SPACE``.
        """
        self.model = model
        self._input_model = model  # In case we need to modify the input
        if isinstance(dependent_variable, str):
            # Standardize the dependent variable as a list.
            dependent_variable = [dependent_variable]
        self.dependent_variable = dependent_variable

        # optional but common parameters
        self.primary_key = primary_key
        self.parameters = parameters or {}
        self.cv_params = cross_validation_parameters or {}
        self.model_name = model_name  # None lets Platform use template name
        self.excluded_columns = excluded_columns
        self.calibration = calibration
        self.job_resources = {
            'REQUIRED_CPU': cpu_requested,
            'REQUIRED_MEMORY': memory_requested,
            'REQUIRED_DISK_SPACE': disk_requested
        }
        self.notifications = notifications or {}
        self.verbose = verbose

        if client is None:
            client = APIClient(resources='all')
        self._client = client
        # No training result yet.
        self.train_result_ = None