示例#1
0
    def __init__(self, db_config, blocking_auth=True, verbose=10, store=None):
        """Connect to Firebase and set up auth, artifact store and caches.

        :param db_config: pyrebase configuration; the optional keys 'guest',
            'serviceAccount', 'use_email_auth', 'email' and 'password'
            control whether and how a user is authenticated.
        :param blocking_auth: forwarded to FirebaseAuth and to the default
            FirebaseArtifactStore.
        :param verbose: logging level of the 'FirebaseProvider' logger.
        :param store: artifact store to use; when falsy, a
            FirebaseArtifactStore is built from ``db_config``.
        """
        is_guest = db_config.get('guest')

        self.app = pyrebase.initialize_app(db_config)
        self.logger = logging.getLogger('FirebaseProvider')
        self.logger.setLevel(verbose)

        # Guests and service-account setups skip user authentication.
        wants_user_auth = (not is_guest and
                           'serviceAccount' not in db_config.keys())
        if wants_user_auth:
            self.auth = FirebaseAuth(self.app,
                                     db_config.get("use_email_auth"),
                                     db_config.get("email"),
                                     db_config.get("password"),
                                     blocking_auth)
        else:
            self.auth = None

        if store:
            self.store = store
        else:
            self.store = FirebaseArtifactStore(
                db_config, verbose=verbose, blocking_auth=blocking_auth)

        self._experiment_info_cache = {}
        self._experiment_cache = {}

        io_threads = 10
        # ThreadPool is tested for truthiness, presumably because its import
        # can fail in some environments -- confirm at the import site.
        self.pool = ThreadPool(io_threads) if ThreadPool else None

        if self.auth and not self.auth.expired:
            email_key = self._get_user_keybase() + "email"
            self.__setitem__(email_key, self.auth.get_user_email())
示例#2
0
    def __init__(self, config, verbose=10, blocking_auth=True):
        """Configure an API-server client that authenticates via Firebase.

        :param config: must provide 'serverUrl' plus the pyrebase keys; the
            optional 'guest', 'serviceAccount', 'use_email_auth', 'email'
            and 'password' keys control authentication.
        :param verbose: logging level of the 'HTTPProvider' logger.
        :param blocking_auth: forwarded to FirebaseAuth.
        """
        # TODO: implement connection
        self.url = config.get('serverUrl')
        self.verbose = verbose
        self.logger = logging.getLogger('HTTPProvider')
        self.logger.setLevel(verbose)

        self.auth = None
        self.app = pyrebase.initialize_app(config)
        is_guest = config.get('guest')
        # Guests and service-account setups skip user authentication.
        if is_guest or 'serviceAccount' in config.keys():
            return
        self.auth = FirebaseAuth(self.app,
                                 config.get("use_email_auth"),
                                 config.get("email"),
                                 config.get("password"),
                                 blocking_auth)
示例#3
0
    def __init__(self, db_config, measure_timestamp_diff=True,
                 blocking_auth=True, verbose=10):
        """Initialise a Firebase Storage-backed artifact store.

        :param db_config: pyrebase configuration; the optional 'guest',
            'serviceAccount', 'use_email_auth', 'email' and 'password' keys
            control authentication.
        :param measure_timestamp_diff: forwarded to TartifactStore.
        :param blocking_auth: forwarded to FirebaseAuth.
        :param verbose: logging level of the 'FirebaseArtifactStore' logger.
        """
        is_guest = db_config.get('guest')

        self.app = pyrebase.initialize_app(db_config)

        # No per-user auth for guests or service-account configurations.
        skip_auth = is_guest or 'serviceAccount' in db_config.keys()
        self.auth = None if skip_auth else FirebaseAuth(
            self.app,
            db_config.get("use_email_auth"),
            db_config.get("email"),
            db_config.get("password"),
            blocking_auth)

        self.logger = logging.getLogger('FirebaseArtifactStore')
        self.logger.setLevel(verbose)
        super(FirebaseArtifactStore, self).__init__(measure_timestamp_diff)
示例#4
0
文件: model.py 项目: Mistobaan/studio
    def __init__(self, db_config, blocking_auth=True, verbose=10, store=None):
        """Connect to Firebase, set up auth/store and sync the user's email.

        :param db_config: pyrebase configuration; optional keys 'guest',
            'serviceAccount', 'use_email_auth', 'email', 'password' control
            authentication and 'max_keys' (default 100) is stored on the
            instance.
        :param blocking_auth: forwarded to FirebaseAuth and to the default
            FirebaseArtifactStore.
        :param verbose: logging level of the 'FirebaseProvider' logger.
        :param store: artifact store to use; when falsy, a
            FirebaseArtifactStore is built from ``db_config``.
        """
        is_guest = db_config.get('guest')

        self.app = pyrebase.initialize_app(db_config)
        self.logger = logging.getLogger('FirebaseProvider')
        self.logger.setLevel(verbose)

        self.auth = None
        # Guests and service-account setups skip user authentication.
        if not (is_guest or 'serviceAccount' in db_config.keys()):
            self.auth = FirebaseAuth(self.app,
                                     db_config.get("use_email_auth"),
                                     db_config.get("email"),
                                     db_config.get("password"),
                                     blocking_auth)

        if store:
            self.store = store
        else:
            self.store = FirebaseArtifactStore(
                db_config, verbose=verbose, blocking_auth=blocking_auth)

        if self.auth and not self.auth.expired:
            stored_email = self._get(self._get_user_keybase() + "email")
            # Write the email only when it is missing or out of date.
            email_is_current = stored_email and \
                stored_email == self.auth.get_user_email()
            if not email_is_current:
                self.__setitem__(self._get_user_keybase() + "email",
                                 self.auth.get_user_email())

        self.max_keys = db_config.get('max_keys', 100)
示例#5
0
    def __init__(self, config, verbose=10, measure_timestamp_diff=True):
        """Create the Cloud Storage client and open (or create) the bucket.

        :param config: must contain 'bucket'; may contain 'auth', a Firebase
            auth configuration whose 'type' must be 'firebase'.
        :param verbose: logging level of the 'GCloudArtifactStore' logger.
        :param measure_timestamp_diff: forwarded to TartifactStore.
        :raises AssertionError: if 'auth' is given with a type other than
            'firebase'.
        """
        self.logger = logging.getLogger('GCloudArtifactStore')
        self.logger.setLevel(verbose)

        # Defined on every path; previously only the auth branch set it.
        self.auth = None
        auth_config = config.get('auth')
        if not auth_config:
            self.client = storage.Client()
        else:
            assert auth_config['type'].lower() == 'firebase'
            app = pyrebase.initialize_app(auth_config)
            self.auth = FirebaseAuth(app,
                                     auth_config.get("use_email_auth"),
                                     auth_config.get("email"),
                                     auth_config.get("password"))
            # NOTE(review): storage.Client documents ``credentials`` as a
            # google.auth credentials object, but a raw Firebase token is
            # passed here -- confirm this actually authenticates.
            self.client = storage.Client(credentials=self.auth.get_token())

        bucket_name = config['bucket']
        try:
            self.bucket = self.client.get_bucket(bucket_name)
        except Exception as e:
            # Presumably the bucket does not exist yet; try to create it.
            self.logger.exception(e)
            self.bucket = self.client.create_bucket(bucket_name)

        super(GCloudArtifactStore, self).__init__(measure_timestamp_diff)
示例#6
0
class FirebaseArtifactStore(TartifactStore):
    """Artifact store backed by Firebase Storage.

    Uploads and unauthenticated downloads go through pyrebase's storage
    API. Authenticated downloads, deletes and metadata lookups call the
    Firebase Storage REST endpoint directly via ``self.app.requests``,
    because pyrebase's download does not work for files that require
    authentication.
    """

    def __init__(self, db_config, measure_timestamp_diff=True,
                 blocking_auth=True, verbose=10):
        """Initialise the store.

        :param db_config: pyrebase configuration; the optional 'guest',
            'serviceAccount', 'use_email_auth', 'email' and 'password' keys
            control authentication.
        :param measure_timestamp_diff: forwarded to TartifactStore.
        :param blocking_auth: forwarded to FirebaseAuth.
        :param verbose: logging level of the 'FirebaseArtifactStore' logger.
        """
        guest = db_config.get('guest')

        self.app = pyrebase.initialize_app(db_config)

        # Guests and service-account setups skip user authentication.
        self.auth = None
        if not guest and 'serviceAccount' not in db_config.keys():
            self.auth = FirebaseAuth(self.app,
                                     db_config.get("use_email_auth"),
                                     db_config.get("email"),
                                     db_config.get("password"),
                                     blocking_auth)

        self.logger = logging.getLogger('FirebaseArtifactStore')
        self.logger.setLevel(verbose)
        super(FirebaseArtifactStore, self).__init__(measure_timestamp_diff)

    def _auth_headers(self):
        """Return HTTP headers carrying the Firebase auth token.

        Returns an empty dict when there is no auth object or it currently
        yields no token.
        """
        token = self.auth.get_token() if self.auth else None
        if not token:
            return {}
        return {"Authorization": "Firebase " + token}

    def _object_url(self, key):
        """Return the Firebase Storage REST URL of the object at ``key``.

        '/' characters in the key are escaped so the whole key is sent as
        one path component after ``/o/``.
        """
        escaped_key = key.replace('/', '%2f')
        return "{}/o/{}".format(self.app.storage().storage_bucket,
                                escaped_key)

    def _upload_file(self, key, local_file_path):
        """Upload ``local_file_path`` to ``key``; failures are only logged."""
        try:
            storageobj = self.app.storage().child(key)
            if self.auth:
                storageobj.put(local_file_path,
                               self.auth.get_token(),
                               self.auth.get_user_id())
            else:
                storageobj.put(local_file_path)
        except Exception as err:
            self.logger.warning(("Uploading file {} with key {} into storage " +
                                 "raised an exception: {}")
                                .format(local_file_path, key, err))

    def _download_file(self, key, local_file_path):
        """Download the object at ``key`` to ``local_file_path``.

        Failures (including non-200 responses) are logged, not raised.
        """
        self.logger.debug("Downloading file at key {} to local path {}..."
                          .format(key, local_file_path))
        try:
            if self.auth:
                # pyrebase download does not work with files that require
                # authentication, so fetch the media URL directly.
                response = self.app.requests.get(
                    self._object_url(key) + "?alt=media",
                    stream=True,
                    headers=self._auth_headers(),
                    verify=certifi.old_where())
                try:
                    if response.status_code != 200:
                        raise ValueError("Response error with code {}"
                                         .format(response.status_code))
                    with open(local_file_path, 'wb') as f:
                        # Iterating a Response directly yields 128-byte
                        # chunks; use larger ones for throughput.
                        for chunk in response.iter_content(chunk_size=65536):
                            if chunk:
                                f.write(chunk)
                finally:
                    # Release the streamed connection back to the pool.
                    response.close()
            else:
                self.app.storage().child(key).download(local_file_path)
            self.logger.debug("Done")
        except Exception as err:
            self.logger.warning(
                ("Downloading file {} to local path {} from storage " +
                 "raised an exception: {}") .format(
                    key,
                    local_file_path,
                    err))

    def _delete_file(self, key):
        """Delete the object at ``key``; failures are only logged."""
        self.logger.debug("Deleting file at key {}".format(key))
        try:
            response = self.app.requests.delete(
                self._object_url(key) + "?alt=media",
                headers=self._auth_headers(),
                verify=certifi.old_where())
            if response.status_code != 204:
                raise ValueError("Response error with code {}, text {}"
                                 .format(response.status_code, response.text))

            self.logger.debug("Done")
        except Exception as err:
            self.logger.warning(
                ("Deleting file {} from storage " +
                 "raised an exception: {}") .format(key, err))

    def _get_file_url(self, key):
        """Return a tokenised download URL for ``key``, or None.

        None is returned when the metadata lookup fails or the metadata
        carries no download token.
        """
        self.logger.debug("Getting a download url for a file at key {}"
                          .format(key))

        response_dict, url = self._get_file_meta(key)
        if response_dict is None:
            self.logger.debug("Getting file metainfo failed")
            return None

        tokens = response_dict.get('downloadTokens')
        if not tokens:
            self.logger.debug("File metainfo has no download token")
            return None

        self.logger.debug("Done")
        # downloadTokens may hold several comma-separated tokens; use the
        # first one.
        return url + '?alt=media&token=' + tokens.split(',')[0]

    def _get_file_timestamp(self, key):
        """Return the object's 'updated' time as a UTC epoch, or None.

        Accepts the timestamp with or without fractional seconds; returns
        None when the metadata is unavailable or cannot be parsed.
        """
        response, _ = self._get_file_meta(key)
        if response is None or 'updated' not in response:
            return None

        updated = response['updated']
        for fmt in ("%Y-%m-%dT%H:%M:%S.%fZ", "%Y-%m-%dT%H:%M:%SZ"):
            try:
                return calendar.timegm(time.strptime(updated, fmt))
            except ValueError:
                continue

        self.logger.warning("Cannot parse timestamp {} of file {}"
                            .format(updated, key))
        return None

    def _get_file_meta(self, key):
        """Fetch object metadata for ``key``.

        :returns: ``(metadata_dict, object_url)`` on HTTP 200, otherwise
            ``(None, None)``; exceptions are logged, not raised.
        """
        self.logger.debug("Getting metainformation for a file at key {}"
                          .format(key))
        try:
            url = self._object_url(key)
            response = self.app.requests.get(
                url, headers=self._auth_headers(), verify=certifi.old_where())
            if response.status_code != 200:
                self.logger.info("Response error with code {}"
                                 .format(response.status_code))
                return (None, None)

            return (json.loads(response.content), url)

        except Exception as err:
            self.logger.warning(
                ("Getting metainfo of file {} " +
                 "raised an exception: {}") .format(key, err))
            self.logger.exception(err)
            return (None, None)

    def get_qualified_location(self, key):
        """Return the ``gs://<bucket>/<key>`` location of ``key``."""
        return 'gs://' + self.app.storage_bucket + '/' + key

    def get_bucket(self):
        """Return the configured storage bucket name."""
        return self.app.storage_bucket
示例#7
0
class FirebaseProvider(object):
    """Data provider for Firebase.

    Experiment records live in the Firebase realtime database (accessed
    through pyrebase); artifacts are handled by ``self.store``. Database
    paths written by this class:

    * ``experiments/<key>`` -- the experiment record;
    * ``users/<userid>/experiments/<key>`` -- per-user index;
    * ``users/<userid>/email`` -- the user's email;
    * ``projects/<project>/<key>/owner`` -- per-project index.
    """

    def __init__(self, db_config, blocking_auth=True, verbose=10, store=None):
        """Connect to Firebase and set up auth, artifact store and caches.

        :param db_config: pyrebase configuration; the optional keys 'guest',
            'serviceAccount', 'use_email_auth', 'email' and 'password'
            control authentication.
        :param blocking_auth: forwarded to FirebaseAuth and to the default
            artifact store.
        :param verbose: logging level of the 'FirebaseProvider' logger.
        :param store: artifact store to use; when falsy a
            FirebaseArtifactStore is built from ``db_config``.
        """
        guest = db_config.get('guest')

        self.app = pyrebase.initialize_app(db_config)
        self.logger = logging.getLogger('FirebaseProvider')
        self.logger.setLevel(verbose)

        # Guests and service-account setups skip user authentication.
        self.auth = None
        if not guest and 'serviceAccount' not in db_config.keys():
            self.auth = FirebaseAuth(self.app, db_config.get("use_email_auth"),
                                     db_config.get("email"),
                                     db_config.get("password"), blocking_auth)

        self.store = store if store else FirebaseArtifactStore(
            db_config, verbose=verbose, blocking_auth=blocking_auth)

        # experiment key -> (info dict, time the entry was written)
        self._experiment_info_cache = {}
        # experiment key -> Experiment
        self._experiment_cache = {}

        iothreads = 10
        # ThreadPool is tested for truthiness, presumably because its import
        # can fail in some environments -- confirm at the import site.
        self.pool = ThreadPool(iothreads) if ThreadPool else None

        if self.auth and not self.auth.expired:
            self.__setitem__(self._get_user_keybase() + "email",
                             self.auth.get_user_email())

    def __getitem__(self, key):
        """Read the value at database path ``key``.

        The last path component is read as a child of the rest of the path.
        Returns None (after logging a warning) if the read raises.
        """
        try:
            split_key = key.split('/')
            key_path = '/'.join(split_key[:-1])
            key_name = split_key[-1]
            dbobj = self.app.database().child(key_path).child(key_name)
            if self.auth:
                return dbobj.get(self.auth.get_token()).val()
            return dbobj.get().val()
        except Exception as err:
            self.logger.warning(("Getting key {} from a database " +
                                 "raised an exception: {}").format(key, err))
            return None

    def __setitem__(self, key, value):
        """Write ``value`` at database path ``key``; failures are logged."""
        try:
            split_key = key.split('/')
            key_path = '/'.join(split_key[:-1])
            key_name = split_key[-1]
            dbobj = self.app.database().child(key_path)
            if self.auth:
                dbobj.update({key_name: value}, self.auth.get_token())
            else:
                dbobj.update({key_name: value})
        except Exception as err:
            self.logger.warning(
                ("Putting key {}, value {} into a database " +
                 "raised an exception: {}").format(key, value, err))

    def _delete(self, key):
        """Remove database path ``key``; exceptions propagate."""
        dbobj = self.app.database().child(key)

        if self.auth:
            dbobj.remove(self.auth.get_token())
        else:
            dbobj.remove()

    def _get_userid(self):
        """Return the authenticated user id, or 'guest' when there is none."""
        userid = None
        if self.auth:
            userid = self.auth.get_user_id()
        return userid if userid else 'guest'

    def _get_user_keybase(self, userid=None):
        """Return ``users/<userid>/`` (current user when ``userid`` is None)."""
        if userid is None:
            userid = self._get_userid()

        return "users/" + userid + "/"

    def _get_experiments_keybase(self, userid=None):
        """Return the database prefix of experiment records (userid unused)."""
        return "experiments/"

    def _get_projects_keybase(self):
        """Return the database prefix of project indices."""
        return "projects/"

    def add_experiment(self, experiment):
        """Register a new experiment and upload its artifacts.

        Any existing record with the same key is removed first. Mutable
        artifacts get a key under ``experiments/<key>/``; immutable
        artifacts with a local path are uploaded immediately. The
        experiment is then checkpointed synchronously.
        """
        self._delete(self._get_experiments_keybase() + experiment.key)
        experiment.time_added = time.time()
        experiment.status = 'waiting'

        # Tolerate experiments without a 'workspace' artifact.
        workspace = experiment.artifacts.get('workspace', {})
        if 'local' in workspace.keys() and \
                os.path.exists(workspace['local']):
            experiment.git = git_util.get_git_info(workspace['local'])

        for tag, art in experiment.artifacts.iteritems():
            if art['mutable']:
                art['key'] = self._get_experiments_keybase() + \
                    experiment.key + '/' + tag + '.tgz'
            else:
                if 'local' in art.keys():
                    # upload immutable artifacts
                    art['key'] = self.store.put_artifact(art)

            if 'key' in art.keys():
                art['qualified'] = self.store.get_qualified_location(
                    art['key'])

            art['bucket'] = self.store.get_bucket()

        experiment_dict = experiment.__dict__.copy()
        experiment_dict['owner'] = self._get_userid()

        self.__setitem__(self._get_experiments_keybase() + experiment.key,
                         experiment_dict)

        self.__setitem__(
            self._get_user_keybase() + "experiments/" + experiment.key,
            experiment.key)

        if experiment.project and self.auth:
            self.__setitem__(
                self._get_projects_keybase() + experiment.project + "/" +
                experiment.key + "/owner", self.auth.get_user_id())

        self.checkpoint_experiment(experiment, blocking=True)
        self.logger.info("Added experiment " + experiment.key)

    def start_experiment(self, experiment):
        """Mark ``experiment`` as running and checkpoint it (non-blocking)."""
        experiment.time_started = time.time()
        experiment.status = 'running'
        self.__setitem__(
            self._get_experiments_keybase() + experiment.key + "/status",
            "running")

        self.__setitem__(
            self._get_experiments_keybase() + experiment.key + "/time_started",
            experiment.time_started)

        self.checkpoint_experiment(experiment)

    def stop_experiment(self, key):
        """Set the experiment status to 'stopped'.

        Can be called remotely: the assumption is that the remote worker
        checks the status periodically and kills the experiment once it is
        'stopped'. ``key`` may be an Experiment or its key.
        """
        if isinstance(key, Experiment):
            key = key.key

        self.__setitem__(self._get_experiments_keybase() + key + "/status",
                         "stopped")

    def finish_experiment(self, experiment):
        """Mark an experiment (object or key) as finished.

        When an Experiment object is passed it is checkpointed
        synchronously and its local fields are updated as well.
        """
        time_finished = time.time()
        if isinstance(experiment, basestring):
            key = experiment
        else:
            key = experiment.key
            self.checkpoint_experiment(experiment, blocking=True)
            experiment.status = 'finished'
            experiment.time_finished = time_finished

        self.__setitem__(self._get_experiments_keybase() + key + "/status",
                         "finished")

        self.__setitem__(
            self._get_experiments_keybase() + key + "/time_finished",
            time_finished)

    def delete_experiment(self, experiment):
        """Delete an experiment (object or key) and everything it references.

        If a key is given and the record cannot be loaded, the per-user
        index entry, cache entries and the ``experiments/<key>`` record are
        still removed; artifacts and the project index are skipped because
        they are only known from the loaded record.
        """
        if isinstance(experiment, basestring):
            experiment_key = experiment
            try:
                # Info is not needed for deletion; avoid a background download.
                experiment = self.get_experiment(experiment, getinfo=False)
                experiment_key = experiment.key
            except Exception:
                experiment = None
        else:
            experiment_key = experiment.key

        self._delete(self._get_user_keybase() + 'experiments/' +
                     experiment_key)

        self._experiment_cache.pop(experiment_key, None)
        self._experiment_info_cache.pop(experiment_key, None)

        if experiment is not None:
            for tag, art in experiment.artifacts.iteritems():
                if art.get('key') is not None:
                    self.logger.debug(
                        ('Deleting artifact {} from the store, ' +
                         'artifact key {}').format(tag, art['key']))
                    self.store.delete_artifact(art)

            if experiment.project is not None:
                self._delete(self._get_projects_keybase() +
                             experiment.project + "/" + experiment_key)

        self._delete(self._get_experiments_keybase() + experiment_key)

    def checkpoint_experiment(self, experiment, blocking=False):
        """Upload mutable local artifacts and record the checkpoint time.

        :param experiment: Experiment object or its key.
        :param blocking: wait for the uploads when True.
        :returns: the list of upload threads when ``blocking`` is False,
            otherwise None.
        """
        if isinstance(experiment, basestring):
            key = experiment
            experiment = self.get_experiment(key, getinfo=False)
        else:
            key = experiment.key

        checkpoint_threads = [
            Thread(target=self.store.put_artifact, args=(art, ))
            for _, art in experiment.artifacts.iteritems()
            if art['mutable'] and art.get('local')
        ]

        for t in checkpoint_threads:
            t.start()

        self.__setitem__(
            self._get_experiments_keybase() + key + "/time_last_checkpoint",
            time.time())
        if blocking:
            for t in checkpoint_threads:
                t.join()
        else:
            return checkpoint_threads

    @staticmethod
    def _metric_accumulator(metric_type):
        """Return a function folding a new metric sample into a running value.

        'min' / 'max' keep the running minimum / maximum; any other type
        keeps the latest sample. The running value starts as None, so the
        check is ``is None`` rather than truthiness -- a real 0.0 must not
        be treated as "no value yet".
        """
        if metric_type == 'min':
            return lambda acc, value: value if acc is None else min(acc, value)
        if metric_type == 'max':
            return lambda acc, value: value if acc is None else max(acc, value)
        return lambda acc, value: value

    def _get_experiment_info(self, experiment):
        """Collect summary information about an experiment.

        :returns: dict with 'type' (always 'unknown': model-type detection
            is not implemented), 'logtail' (output lines or None) and, when
            ``experiment.metric`` is set, 'metric_value' aggregated from the
            TensorBoard events in the 'tb' artifact (None if none found).
        """
        info = {'type': 'unknown'}
        info['logtail'] = self._get_experiment_logtail(experiment)

        if experiment.metric is not None:
            # metric is "<name>" or "<name>:<min|max>"
            metric_str = experiment.metric.split(':')
            metric_name = metric_str[0]
            metric_type = metric_str[1] if len(metric_str) > 1 else None
            metric_accum = self._metric_accumulator(metric_type)

            metric_value = None
            tbtar = self.store.stream_artifact(experiment.artifacts['tb'])
            if tbtar:
                for f in tbtar:
                    if f.isreg():
                        for e in util.event_reader(tbtar.extractfile(f)):
                            for v in e.summary.value:
                                if v.tag == metric_name:
                                    metric_value = metric_accum(
                                        metric_value, v.simple_value)

            info['metric_value'] = metric_value

        return info

    def _get_experiment_logtail(self, experiment):
        """Return the lines of the experiment's output artifact, or None.

        Reads the first member of the streamed 'output' tarball; any error
        is logged at info level and yields None.
        """
        try:
            tarf = self.store.stream_artifact(experiment.artifacts['output'])
            if not tarf:
                return None

            logdata = tarf.extractfile(tarf.members[0]).read()
            return util.remove_backspaces(logdata).split('\n')
        except Exception as e:
            self.logger.info('Getting experiment logtail raised an exception:')
            self.logger.info(e)
            return None

    def get_experiment(self, key, getinfo=True):
        """Load experiment ``key`` from the database.

        :param getinfo: start a (re)download of experiment info if needed.
        :returns: Experiment built from the record and any cached info.
        :raises AssertionError: if there is no record for ``key``
            (callers such as ``_get_valid_experiments`` rely on this type).
        """
        data = self.__getitem__(self._get_experiments_keybase() + key)
        assert data, "data at path %s not found! " % (
            self._get_experiments_keybase() + key)
        data['key'] = key

        experiment_stub = experiment_from_dict(data)

        if getinfo:
            self._start_info_download(experiment_stub)

        info = self._experiment_info_cache.get(key)[0] \
            if self._experiment_info_cache.get(key) else None

        return experiment_from_dict(data, info)

    def _start_info_download(self, experiment):
        """Refresh the cached info of ``experiment`` when it is stale.

        The cache entry is refreshed when it is empty or older than the
        experiment's last checkpoint. When a pool exists the download runs
        in a separate Thread (not on the pool itself); otherwise inline.
        """
        key = experiment.key
        if key not in self._experiment_info_cache.keys():
            self._experiment_info_cache[key] = ({}, time.time())

        def download_info():
            try:
                self._experiment_info_cache[key] = (
                    self._get_experiment_info(experiment), time.time())

                self.logger.debug("Finished info download for " + key)
            except Exception as e:
                self.logger.info(
                    "Exception {} while info download for {}".format(e, key))

        cached_info, cached_at = self._experiment_info_cache[key]
        if not any(cached_info) or \
                cached_at < experiment.time_last_checkpoint:

            self.logger.debug("Starting info download for " + key)
            if self.pool:
                Thread(target=download_info).start()
            else:
                download_info()

    def get_user_experiments(self, userid=None, blocking=True):
        """Return the experiments indexed under a user.

        :param userid: user id, or an email address to look up among the
            users; defaults to the current user.
        :returns: list of experiments, or None if an email matches no user.
        """
        if userid and '@' in userid:
            users = self.get_users() or {}
            user_ids = [u for u in users if users[u].get('email') == userid]
            if not user_ids:
                return None
            userid = user_ids[0]

        # _get_user_keybase already ends with '/', so no leading slash here
        # (matches the path written by add_experiment).
        experiment_keys = self.__getitem__(
            self._get_user_keybase(userid) + "experiments")
        if not experiment_keys:
            experiment_keys = {}
        return self._get_valid_experiments(experiment_keys.keys(),
                                           getinfo=True,
                                           blocking=blocking)

    def get_project_experiments(self, project):
        """Return the experiments indexed under ``project``."""
        experiment_keys = self.__getitem__(self._get_projects_keybase() +
                                           project)
        if not experiment_keys:
            experiment_keys = {}
        return self._get_valid_experiments(experiment_keys.keys(),
                                           getinfo=True)

    def get_artifacts(self, key):
        """Return ``{tag: url}`` for the artifacts of experiment ``key``
        that have a URL in the store."""
        experiment = self.get_experiment(key, getinfo=False)
        retval = {}
        if experiment.artifacts is not None:
            for tag, art in experiment.artifacts.iteritems():
                url = self.store.get_artifact_url(art)
                if url is not None:
                    retval[tag] = url

        return retval

    def _get_valid_experiments(self,
                               experiment_keys,
                               getinfo=False,
                               blocking=True):
        """Load experiments into the cache and return the cached ones.

        Records that fail to load are deleted (best effort). With
        ``blocking=False`` and a pool, loading continues in the background
        and only already-cached experiments are returned.
        """
        def cache_valid_experiment(key):
            try:
                self._experiment_cache[key] = self.get_experiment(
                    key, getinfo=getinfo)
            except AssertionError:
                self.logger.warning(
                    ("Experiment {} does not exist " +
                     "or is corrupted, try to delete record").format(key))
                try:
                    self.delete_experiment(key)
                except BaseException:
                    pass

        if self.pool:
            if blocking:
                self.pool.map(cache_valid_experiment, experiment_keys)
            else:
                self.pool.map_async(cache_valid_experiment, experiment_keys)
        else:
            for e in experiment_keys:
                cache_valid_experiment(e)

        return [
            self._experiment_cache[key] for key in experiment_keys
            if key in self._experiment_cache
        ]

    def get_projects(self):
        """Return the raw ``projects/`` subtree."""
        return self.__getitem__(self._get_projects_keybase())

    def get_users(self):
        """Return the raw ``users/`` subtree."""
        return self.__getitem__('users/')

    def refresh_auth_token(self, email, refresh_token):
        """Refresh the auth token when authentication is in use."""
        if self.auth:
            self.auth.refresh_token(email, refresh_token)

    def get_auth_domain(self):
        """Return the Firebase auth domain of the app."""
        return self.app.auth_domain

    def is_auth_expired(self):
        """Return whether auth has expired (False when there is no auth)."""
        if self.auth:
            return self.auth.expired
        return False

    def can_write_experiment(self, key=None, user=None):
        """Return True if ``user`` (default: current) may write ``key``.

        Experiments without a recorded owner are writable by anyone.
        """
        assert key is not None
        user = user if user else self._get_userid()

        owner = self.__getitem__(self._get_experiments_keybase() + key +
                                 "/owner")
        if owner is None:
            return True
        return owner == user

    def __enter__(self):
        return self

    def __exit__(self, *args):
        """Close the thread pool and the app's HTTP session."""
        if self.pool:
            self.pool.close()
        if self.app:
            self.app.requests.close()
示例#8
0
class GCloudArtifactStore(TartifactStore):
    """Artifact store backed by a Google Cloud Storage bucket."""

    def __init__(self, config, verbose=10, measure_timestamp_diff=True):
        """Create the storage client and open (or create) the bucket.

        :param config: must contain 'bucket'; may contain 'auth', a Firebase
            auth configuration whose 'type' must be 'firebase'.
        :param verbose: logging level of the 'GCloudArtifactStore' logger.
        :param measure_timestamp_diff: forwarded to TartifactStore.
        :raises AssertionError: if 'auth' has a type other than 'firebase'.
        """
        self.logger = logging.getLogger('GCloudArtifactStore')
        self.logger.setLevel(verbose)

        # Defined on every path; previously only the auth branch set it.
        self.auth = None
        auth_config = config.get('auth')
        if not auth_config:
            self.client = storage.Client()
        else:
            assert auth_config['type'].lower() == 'firebase'
            app = pyrebase.initialize_app(auth_config)
            self.auth = FirebaseAuth(app,
                                     auth_config.get("use_email_auth"),
                                     auth_config.get("email"),
                                     auth_config.get("password"))
            # NOTE(review): storage.Client documents ``credentials`` as a
            # google.auth credentials object, but a raw Firebase token is
            # passed here -- confirm this actually authenticates.
            self.client = storage.Client(credentials=self.auth.get_token())

        bucket_name = config['bucket']
        try:
            self.bucket = self.client.get_bucket(bucket_name)
        except Exception as e:
            # Presumably the bucket does not exist yet; try to create it.
            self.logger.exception(e)
            self.bucket = self.client.create_bucket(bucket_name)

        super(GCloudArtifactStore, self).__init__(measure_timestamp_diff)

    def _upload_file(self, key, local_path):
        """Upload ``local_path`` as blob ``key``."""
        self.bucket.blob(key).upload_from_filename(local_path)

    def _download_file(self, key, local_path):
        """Download blob ``key`` to ``local_path``.

        NOTE(review): raises AttributeError if the blob does not exist.
        """
        self.bucket.get_blob(key).download_to_filename(local_path)

    def _delete_file(self, key):
        """Delete blob ``key``.

        NOTE(review): raises AttributeError if the blob does not exist.
        """
        self.bucket.get_blob(key).delete()

    def _get_file_url(self, key):
        """Return a signed URL for ``key`` valid for 100000 seconds."""
        # int() instead of the Python-2-only long(); int promotes as needed.
        expiration = int(time.time() + 100000)
        return self.bucket.blob(key).generate_signed_url(expiration)

    def _get_file_timestamp(self, key):
        """Return the blob's last-update time as an epoch, or None.

        None is returned when the blob does not exist or has no update time.
        """
        blob = self.bucket.get_blob(key)
        if blob is None or not blob.updated:
            return None
        return calendar.timegm(blob.updated.timetuple())

    def grant_write(self, key, user):
        """Grant owner access on blob ``key`` to ``user`` (or to everyone).

        A placeholder blob is uploaded first if ``key`` does not exist.
        """
        blob = self.bucket.get_blob(key)
        if not blob:
            blob = self.bucket.blob(key)
            blob.upload_from_string("dummy")

        acl = blob.acl
        if user:
            acl.user(user).grant_owner()
        else:
            acl.all().grant_owner()

        acl.save()

    def get_qualified_location(self, key):
        """Return the ``gs://<bucket>/<key>`` location of ``key``."""
        return 'gs://' + self.bucket.name + '/' + key

    def get_bucket(self):
        """Return the bucket name."""
        return self.bucket.name
示例#9
0
class HTTPProvider(object):
    """Data provider communicating with API server.

    Each operation POSTs JSON to ``<serverUrl>/api/<endpoint>`` and expects
    an HTTP 200 reply whose JSON body has ``status == 'ok'``.
    """

    def __init__(self, config, verbose=10, blocking_auth=True):
        """Create the provider from ``config``.

        Reads ``serverUrl``, ``guest``, ``serviceAccount`` (presence only),
        ``use_email_auth``, ``email`` and ``password`` from ``config``. A
        FirebaseAuth session is created unless the config is a guest or
        service-account config.
        """
        # TODO: implement connection
        self.url = config.get('serverUrl')
        self.verbose = verbose
        self.logger = logging.getLogger('HTTPProvider')
        self.logger.setLevel(self.verbose)

        self.auth = None
        self.app = pyrebase.initialize_app(config)
        guest = config.get('guest')
        if not guest and 'serviceAccount' not in config.keys():
            self.auth = FirebaseAuth(self.app, config.get("use_email_auth"),
                                     config.get("email"),
                                     config.get("password"), blocking_auth)

    @staticmethod
    def _get_key(experiment):
        """Return the key of ``experiment``, a key string or an object with ``.key``."""
        if isinstance(experiment, basestring):
            return experiment
        return experiment.key

    def _post(self, endpoint, payload=None):
        """POST ``payload`` (JSON-encoded, if given) to ``/api/<endpoint>``.

        Returns the response after ``_raise_detailed_error`` has accepted it.
        """
        kwargs = {'headers': self._get_headers()}
        if payload is not None:
            kwargs['data'] = json.dumps(payload)
        response = requests.post(self.url + '/api/' + endpoint, **kwargs)
        self._raise_detailed_error(response)
        return response

    def add_experiment(self, experiment):
        """Register ``experiment`` on the server and upload its artifacts."""
        response = self._post('add_experiment',
                              {"experiment": experiment.__dict__})
        self._update_artifacts(experiment, response.json()['artifacts'])

    def _update_artifacts(self, experiment, artifacts):
        """Copy server-assigned artifact locations and upload each artifact.

        ``artifacts`` maps artifact tag to a dict with ``key``, ``qualified``,
        ``bucket``, ``url`` and ``timestamp`` entries.
        """
        for tag, art in experiment.artifacts.iteritems():
            server_art = artifacts[tag]
            art['key'] = server_art['key']
            art['qualified'] = server_art['qualified']
            art['bucket'] = server_art['bucket']

            HTTPArtifactStore(server_art['url'],
                              server_art['timestamp'],
                              self.verbose) \
                .put_artifact(art)

    def delete_experiment(self, experiment):
        """Delete an experiment, given as a key string or an object."""
        self._post('delete_experiment', {"key": self._get_key(experiment)})

    def get_experiment(self, experiment, getinfo='True'):
        """Fetch an experiment from the server.

        ``getinfo`` is accepted for interface compatibility but unused.
        """
        response = self._post('get_experiment',
                              {"key": self._get_key(experiment)})
        return model.experiment_from_dict(response.json()['experiment'])

    def start_experiment(self, experiment):
        """Checkpoint, then mark the experiment as started on the server."""
        self.checkpoint_experiment(experiment)
        self._post('start_experiment', {"key": self._get_key(experiment)})

    def stop_experiment(self, experiment):
        """Mark an experiment as stopped (key string or object accepted)."""
        # Accepts key strings too, consistent with the other experiment calls.
        self._post('stop_experiment', {"key": self._get_key(experiment)})

    def finish_experiment(self, experiment):
        """Checkpoint, then mark the experiment as finished on the server."""
        self.checkpoint_experiment(experiment)
        self._post('finish_experiment', {"key": self._get_key(experiment)})

    def get_user_experiments(self, user=None, blocking=True):
        """Return the server's experiment list for ``user``.

        ``user`` defaults to the current user id. ``blocking`` is unused.
        """
        user = user if user else self._get_userid()
        response = self._post('get_user_experiments', {"user": user})
        return response.json()['experiments']

    def get_projects(self):
        """Return the server's ``projects`` payload."""
        return self._post('get_projects').json()['projects']

    def get_project_experiments(self, project):
        """Return the experiments of ``project`` as experiment objects."""
        response = self._post('get_project_experiments',
                              {"project": project})
        return [model.experiment_from_dict(edict)
                for edict in response.json()['experiments']]

    def get_artifacts(self):
        raise NotImplementedError()

    def get_artifact(self, artifact, only_newer='True'):
        """Download ``artifact`` from its ``url``; ``only_newer`` is unused."""
        # NOTE(review): _update_artifacts constructs HTTPArtifactStore as
        # (url, timestamp, verbose); here self.verbose lands in the second
        # positional slot. Confirm against HTTPArtifactStore's signature.
        return HTTPArtifactStore(artifact['url'], self.verbose) \
            .get_artifact(artifact)

    def get_users(self):
        """Return the server's ``users`` payload."""
        return self._post('get_users').json()['users']

    def checkpoint_experiment(self, experiment):
        """Ask the server to checkpoint, then upload the artifacts it lists.

        A key string is first resolved to an experiment via get_experiment.
        """
        if isinstance(experiment, basestring):
            key = experiment
            experiment = self.get_experiment(key)
        else:
            key = experiment.key

        response = self._post('checkpoint_experiment', {"key": key})
        self._update_artifacts(experiment, response.json()['artifacts'])

    def refresh_auth_token(self, email, refresh_token):
        """Forward a token refresh to the auth session, if there is one."""
        if self.auth:
            self.auth.refresh_token(email, refresh_token)

    def _get_headers(self):
        """Return JSON request headers, with a Firebase token when authenticated."""
        headers = {"content-type": "application/json"}
        if self.auth:
            headers["Authorization"] = "Firebase " + self.auth.get_token()
        return headers

    def _get_userid(self):
        """Return the authenticated user's id, or 'guest'."""
        userid = None
        if self.auth:
            userid = self.auth.get_user_id()
        userid = userid if userid else 'guest'
        return userid

    def _raise_detailed_error(self, request):
        """Raise ValueError unless the response ``request`` signals success.

        Success means HTTP 200 and a JSON body whose ``status`` is ``'ok'``.

        Raises:
            ValueError: on a non-200 status (message includes status code,
                reason and body), or with the body's ``status`` value
                otherwise.
        """
        if request.status_code != 200:
            # requests.Response has no ``message`` attribute (the old code
            # raised AttributeError here); report status, reason and body.
            raise ValueError('HTTP {} {}: {}'.format(
                request.status_code, request.reason, request.text))

        data = request.json()
        status = data.get('status') if isinstance(data, dict) else None
        if status == 'ok':
            return

        if status is None:
            raise ValueError('malformed response: {!r}'.format(data))
        raise ValueError(status)

    def __enter__(self):
        return self

    def __exit__(self, *args):
        pass
示例#10
0
from fastapi import FastAPI, Depends, status, HTTPException
from fastapi_cloudauth.firebase import FirebaseCurrentUser, FirebaseClaims

from auth import FirebaseAuth
from model import AuthModel
from db import DatabaseManager
from user import User

from dotenv import load_dotenv
# Load variables from a .env file into the process environment before the
# auth helpers below are constructed.
# NOTE(review): presumably FirebaseAuth / FirebaseCurrentUser read their
# configuration from the environment, which is why this call comes first —
# confirm before reordering.
load_dotenv()

app = FastAPI()
# NOTE(review): ``firebase_auth`` is not referenced by any route visible in
# this snippet.
firebase_auth = FirebaseAuth()
# Used with ``Depends`` in the routes below to obtain the caller's
# ``FirebaseClaims``.
get_current_user = FirebaseCurrentUser()


@app.get("/")
def index():
    """Return a static welcome payload for the API root."""
    payload = {'message': 'Thanks for visiting the api'}
    return payload


@app.get("/student/upcomming_classes")
def get_upcomming_classes(user: FirebaseClaims = Depends(get_current_user)):
    """Placeholder endpoint for a student's upcoming classes.

    For now it only echoes the caller's claims back in a greeting string.
    """
    # TODO: build the real response (an earlier draft wrapped the claims in
    # a UserModel).
    claims = user.__dict__
    return f"Hello, {claims}"


@app.get("/student/profile")
def get_student_profile(user: FirebaseClaims = Depends(get_current_user)):
    """Return the caller's claims as a plain dict."""
    # NOTE(review): a User is built from the claims but never used; kept in
    # case its constructor has side effects.
    User(user)
    claims = dict(user)
    return claims
示例#11
0
 def __init__(self, data):
     """Initialise the Firebase database helper and build a root from ``data``.

     Whatever ``FirebaseAuth().createdb(data)`` returns is stored as
     ``self.root``.
     """
     firebase = FirebaseAuth()
     firebase.initialisedb()
     self.root = firebase.createdb(data)
示例#12
0
文件: model.py 项目: Mistobaan/studio
class FirebaseProvider(object):
    """Data provider for Firebase."""
    def __init__(self, db_config, blocking_auth=True, verbose=10, store=None):
        """Connect to the Firebase database described by ``db_config``.

        Args:
            db_config: Config mapping, passed whole to
                ``pyrebase.initialize_app``. Keys read here: ``guest``,
                ``serviceAccount`` (presence only), ``use_email_auth``,
                ``email``, ``password`` and ``max_keys`` (default 100).
            blocking_auth: Forwarded to FirebaseAuth and to the default
                artifact store.
            verbose: Logging level for the 'FirebaseProvider' logger.
            store: Artifact store to use; a FirebaseArtifactStore built from
                ``db_config`` is used when falsy.
        """
        is_guest = db_config.get('guest')

        self.app = pyrebase.initialize_app(db_config)
        self.logger = logging.getLogger('FirebaseProvider')
        self.logger.setLevel(verbose)

        # User-level auth only for non-guest configs that are not
        # authenticated through a service account.
        needs_user_auth = (not is_guest and
                           'serviceAccount' not in db_config.keys())
        if needs_user_auth:
            self.auth = FirebaseAuth(self.app,
                                     db_config.get("use_email_auth"),
                                     db_config.get("email"),
                                     db_config.get("password"),
                                     blocking_auth)
        else:
            self.auth = None

        if not store:
            store = FirebaseArtifactStore(
                db_config, verbose=verbose, blocking_auth=blocking_auth)
        self.store = store

        # Keep the user's e-mail in the database in sync with the session.
        if self.auth and not self.auth.expired:
            email_key = self._get_user_keybase() + "email"
            stored_email = self._get(email_key)
            if not (stored_email and
                    stored_email == self.auth.get_user_email()):
                self.__setitem__(email_key, self.auth.get_user_email())

        self.max_keys = db_config.get('max_keys', 100)

    def _get(self, key, shallow=False):
        try:
            splitKey = key.split('/')
            key_path = '/'.join(splitKey[:-1])
            key_name = splitKey[-1]
            dbobj = self.app.database().child(key_path).child(key_name)
            return dbobj.get(self.auth.get_token(), shallow=shallow).val() \
                if self.auth else dbobj.get(shallow=shallow).val()
        except Exception as err:
            self.logger.warn(("Getting key {} from a database " +
                              "raised an exception: {}").format(key, err))
            return None

    def __setitem__(self, key, value):
        try:
            splitKey = key.split('/')
            key_path = '/'.join(splitKey[:-1])
            key_name = splitKey[-1]
            dbobj = self.app.database().child(key_path)
            if self.auth:
                dbobj.update({key_name: value}, self.auth.get_token())
            else:
                dbobj.update({key_name: value})
        except Exception as err:
            self.logger.warn(
                ("Putting key {}, value {} into a database " +
                 "raised an exception: {}").format(key, value, err))

    def _delete(self, key, token=None):
        dbobj = self.app.database().child(key)

        if self.auth:
            dbobj.remove(self.auth.get_token())
        else:
            dbobj.remove()

    def _get_userid(self):
        userid = None
        if self.auth:
            userid = self.auth.get_user_id()
        userid = userid if userid else 'guest'
        return userid

    def _get_user_keybase(self, userid=None):
        if userid is None:
            userid = self._get_userid()

        return "users/" + userid + "/"

    def _get_experiments_keybase(self, userid=None):
        return "experiments/"

    def _get_projects_keybase(self):
        return "projects/"

    def add_experiment(self, experiment, userid=None):
        self._delete(self._get_experiments_keybase() + experiment.key)
        experiment.time_added = time.time()
        experiment.status = 'waiting'

        if 'local' in experiment.artifacts['workspace'].keys() and \
                os.path.exists(experiment.artifacts['workspace']['local']):
            experiment.git = git_util.get_git_info(
                experiment.artifacts['workspace']['local'])

        for tag, art in experiment.artifacts.iteritems():
            if art['mutable']:
                art['key'] = self._get_experiments_keybase() + \
                    experiment.key + '/' + tag + '.tgz'
            else:
                if 'local' in art.keys():
                    # upload immutable artifacts
                    art['key'] = self.store.put_artifact(art)

            if art.get('key') is not None:
                art['qualified'] = self.store.get_qualified_location(
                    art['key'])

            art['bucket'] = self.store.get_bucket()

        userid = userid if userid else self._get_userid()

        experiment_dict = experiment.__dict__.copy()
        experiment_dict['owner'] = userid

        self.__setitem__(self._get_experiments_keybase() + experiment.key,
                         experiment_dict)

        self.__setitem__(
            self._get_user_keybase(userid) + "experiments/" + experiment.key,
            experiment.time_added)

        if experiment.project and self.auth:
            self.__setitem__(
                self._get_projects_keybase() + experiment.project + "/" +
                experiment.key + "/owner", userid)

        self.checkpoint_experiment(experiment, blocking=True)
        self.logger.info("Added experiment " + experiment.key)

    def start_experiment(self, experiment):
        experiment.time_started = time.time()
        experiment.status = 'running'
        self.__setitem__(
            self._get_experiments_keybase() + experiment.key + "/status",
            "running")

        self.__setitem__(
            self._get_experiments_keybase() + experiment.key + "/time_started",
            experiment.time_started)

        self.checkpoint_experiment(experiment)

    def stop_experiment(self, key):
        """Flag an experiment as 'stopped' in the database.

        Accepts an Experiment object or its key. The call can be made
        remotely: a worker is assumed to poll the experiment's status and
        kill it once it reads 'stopped'.
        """
        experiment_key = key.key if isinstance(key, Experiment) else key
        status_path = (self._get_experiments_keybase() + experiment_key +
                       "/status")
        self.__setitem__(status_path, "stopped")

    def finish_experiment(self, experiment):
        """Mark an experiment as finished in the database.

        Accepts a key string or an experiment object. For an object, a
        blocking checkpoint runs first, and its ``status`` and
        ``time_finished`` attributes are updated locally as well.
        """
        finished_at = time.time()
        given_key = isinstance(experiment, basestring)
        key = experiment if given_key else experiment.key
        if not given_key:
            self.checkpoint_experiment(experiment, blocking=True)
            experiment.status = 'finished'
            experiment.time_finished = finished_at

        record = self._get_experiments_keybase() + key
        self.__setitem__(record + "/status", "finished")
        self.__setitem__(record + "/time_finished", finished_at)

    def delete_experiment(self, experiment):
        """Delete an experiment, its index entries and its stored artifacts.

        Accepts a key string or an experiment object. For a key, the record
        is looked up first. If that fails, only the database entries keyed by
        the string are removed; artifacts and project index entries are left.

        Note: the per-user index entry removed is the current user's.
        """
        if isinstance(experiment, basestring):
            experiment_key = experiment
            try:
                # Only artifacts and project are needed here, so skip the
                # costly log/metric download that getinfo=True performs.
                experiment = self.get_experiment(experiment, getinfo=False)
                experiment_key = experiment.key
            except Exception:
                # Record missing or unreadable: still clean up index entries.
                # (Previously BaseException, which also swallowed
                # KeyboardInterrupt / SystemExit.)
                experiment = None
        else:
            experiment_key = experiment.key

        self._delete(self._get_user_keybase() + 'experiments/' +
                     experiment_key)
        if experiment is not None:
            for tag, art in experiment.artifacts.iteritems():
                if art.get('key') is not None:
                    self.logger.debug(
                        ('Deleting artifact {} from the store, ' +
                         'artifact key {}').format(tag, art['key']))
                    self.store.delete_artifact(art)

            if experiment.project is not None:
                self._delete(self._get_projects_keybase() +
                             experiment.project + "/" + experiment_key)

        self._delete(self._get_experiments_keybase() + experiment_key)

    def checkpoint_experiment(self, experiment, blocking=True):
        """Upload local copies of mutable artifacts and stamp the checkpoint time.

        Every artifact that is ``mutable`` and has a truthy ``local`` entry is
        uploaded via ``self.store.put_artifact`` on its own thread. The
        ``time_last_checkpoint`` field is written while uploads run.

        Args:
            experiment: Experiment object, or its key (then fetched without
                info).
            blocking: Wait for all uploads before returning.

        Returns:
            None when ``blocking``; otherwise the list of started threads,
            for the caller to join.
        """
        if isinstance(experiment, basestring):
            key = experiment
            experiment = self.get_experiment(key, getinfo=False)
        else:
            key = experiment.key

        to_upload = [art for _, art in experiment.artifacts.iteritems()
                     if art['mutable'] and art.get('local')]
        uploaders = [Thread(target=self.store.put_artifact, args=(art, ))
                     for art in to_upload]

        for uploader in uploaders:
            uploader.start()

        self.__setitem__(
            self._get_experiments_keybase() + key + "/time_last_checkpoint",
            time.time())

        if not blocking:
            return uploaders

        for uploader in uploaders:
            uploader.join()

    def _get_experiment_info(self, experiment):
        info = {}
        type_found = False

        if not type_found:
            info['type'] = 'unknown'

        info['logtail'] = self._get_experiment_logtail(experiment)

        if experiment.metric is not None:
            metric_str = experiment.metric.split(':')
            metric_name = metric_str[0]
            metric_type = metric_str[1] if len(metric_str) > 1 else None

            tbtar = self.store.stream_artifact(experiment.artifacts['tb'])

            if metric_type == 'min':

                def metric_accum(x, y):
                    return min(x, y) if x else y
            elif metric_type == 'max':

                def metric_accum(x, y):
                    return max(x, y) if x else y
            else:

                def metric_accum(x, y):
                    return y

            metric_value = None
            for f in tbtar:
                if f.isreg():
                    for e in util.event_reader(tbtar.extractfile(f)):
                        for v in e.summary.value:
                            if v.tag == metric_name:
                                metric_value = metric_accum(
                                    metric_value, v.simple_value)

            info['metric_value'] = metric_value

        return info

    def _get_experiment_logtail(self, experiment):
        try:
            tarf = self.store.stream_artifact(experiment.artifacts['output'])
            if not tarf:
                return None

            logdata = tarf.extractfile(tarf.members[0]).read()
            logdata = util.remove_backspaces(logdata).split('\n')
            return logdata

        except BaseException as e:
            self.logger.info('Getting experiment logtail raised an exception:')
            self.logger.info(e)
            return None

        finally:
            if tarf:
                tarf.close()

    def get_experiment(self, key, getinfo=True):
        data = self._get(self._get_experiments_keybase() + key)
        assert data, "data at path %s not found! " % (
            self._get_experiments_keybase() + key)
        data['key'] = key

        experiment_stub = experiment_from_dict(data)

        expinfo = {}
        if getinfo:
            try:
                expinfo = self._get_experiment_info(experiment_stub)

            except Exception as e:
                self.logger.info(
                    "Exception {} while info download for {}".format(e, key))

        return experiment_from_dict(data, expinfo)

    def get_user_experiments(self, userid=None, blocking=True):
        if userid and '@' in userid:
            users = self.get_users()
            user_ids = [u for u in users if users[u].get('email') == userid]
            if len(user_ids) < 1:
                return None
            else:
                userid = user_ids[0]

        experiment_keys = self._get(
            self._get_user_keybase(userid) + "/experiments")
        if not experiment_keys:
            experiment_keys = {}

        keys = sorted(experiment_keys.keys(),
                      key=lambda k: experiment_keys[k],
                      reverse=True)

        return keys

    def get_project_experiments(self, project):
        experiment_keys = self._get(self._get_projects_keybase() + project)
        if not experiment_keys:
            experiment_keys = {}

        return experiment_keys

    def get_artifacts(self, key):
        experiment = self.get_experiment(key, getinfo=False)
        retval = {}
        if experiment.artifacts is not None:
            for tag, art in experiment.artifacts.iteritems():
                url = self.store.get_artifact_url(art)
                if url is not None:
                    retval[tag] = url

        return retval

    def get_artifact(self, artifact, only_newer=True):
        return self.store.get_artifact(artifact, only_newer=only_newer)

    def get_projects(self):
        return self._get(self._get_projects_keybase(), shallow=True)

    def get_users(self):
        user_ids = self._get('users/', shallow=True)
        retval = {}
        for user_id in user_ids.keys():
            retval[user_id] = {
                'email': self._get('users/' + user_id + '/email')
            }
        return retval

    def refresh_auth_token(self, email, refresh_token):
        if self.auth:
            self.auth.refresh_token(email, refresh_token)

    def is_auth_expired(self):
        if self.auth:
            return self.auth.expired
        else:
            return False

    def can_write_experiment(self, key=None, user=None):
        assert key is not None
        user = user if user else self._get_userid()

        owner = self._get(self._get_experiments_keybase() + key + "/owner")
        if owner is None or owner == 'guest':
            return True
        else:
            return (owner == user)

    def __enter__(self):
        return self

    def __exit__(self, *args):
        if self.app:
            self.app.requests.close()

        if self.store:
            self.store.__exit__()