Code example #1
File: db_interface.py, Project: gweiying/tfutils
import logging

import pkg_resources

log = logging.getLogger(__name__)  # the snippets on this page assume a module-level logger


def version_info(module):
    """Get version of a standard python module.

    Args:
        module (module): python module object to get version info for.

    Returns:
        dict: dictionary of version info.

    """
    if hasattr(module, '__version__'):
        version = module.__version__
    elif hasattr(module, 'VERSION'):
        version = module.VERSION
    else:
        pkgname = module.__name__.split('.')[0]
        try:
            info = pkg_resources.get_distribution(pkgname)
        except (pkg_resources.DistributionNotFound, pkg_resources.RequirementParseError):
            version = None
            log.warning(
                'version information not found for %s -- what package is this from?' % module.__name__)
        else:
            version = info.version

    return {'version': version}
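
A minimal usage sketch (not part of the original project; numpy is only an assumed example of an installed module):

import numpy as np

print(version_info(np))         # e.g. {'version': '1.24.0'}
print(version_info(np.linalg))  # no __version__ attribute; falls back to the
                                # 'numpy' distribution via pkg_resources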
Code example #2
File: db_interface.py, Project: donnate/tfutils-1
    def sync_with_host(self):
        if self.checkpoint_thread is not None:
            try:
                self.checkpoint_coord.join([self.checkpoint_thread])
            except Exception as error:
                log.warning('A checkpoint thread raised an exception '
                            'while saving a checkpoint.')
                log.error(error)
                raise
            else:
                self.checkpoint_thread = None
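
For context, a minimal sketch of the tf.train.Coordinator pattern this method relies on (TF1 API; the checkpoint_coord / checkpoint_thread names mirror the attributes above, while the worker body is assumed):

import threading
import tensorflow as tf

checkpoint_coord = tf.train.Coordinator()

def _save_checkpoint():
    # stop_on_exception() records any exception so that join() re-raises it
    with checkpoint_coord.stop_on_exception():
        # saver.save(...) would run here
        checkpoint_coord.request_stop()

checkpoint_thread = threading.Thread(target=_save_checkpoint)
checkpoint_thread.start()
checkpoint_coord.join([checkpoint_thread])  # re-raises the worker's exception, if any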
Code example #3
import logging

log = logging.getLogger(__name__)  # as above, a module-level logger is assumed


def git_info(repo):
    """Return information about a git repo.

    Args:
        repo (git.Repo): The git repo to be investigated.

    Returns:
        dict: Git repo information

    """
    if repo.is_dirty():
        log.warning('repo %s is dirty -- having commitment issues?' %
                    repo.git_dir)
        clean = False
    else:
        clean = True
    branchname = repo.active_branch.name
    commit = repo.active_branch.commit.hexsha
    origin = repo.remote('origin')
    urls = [str(u) for u in origin.urls]  # materialize; map() is lazy in Python 3
    remote_ref = [
        _r for _r in origin.refs if _r.name == 'origin/' + branchname
    ]
    if not remote_ref:
        log.warning('Active branch %s not in origin ref' % branchname)
        active_branch_in_origin = False
        commit_in_log = False
    else:
        active_branch_in_origin = True
        remote_ref = remote_ref[0]
        gitlog = remote_ref.log()
        shas = [_r.oldhexsha for _r in gitlog] + \
            [_r.newhexsha for _r in gitlog]
        if commit not in shas:
            log.warning('Commit %s not in remote origin log for branch %s' %
                        (commit, branchname))
            commit_in_log = False
        else:
            commit_in_log = True
    info = {
        'git_dir': repo.git_dir,
        'active_branch': branchname,
        'commit': commit,
        'remote_urls': urls,
        'clean': clean,
        'active_branch_in_origin': active_branch_in_origin,
        'commit_in_log': commit_in_log
    }
    return info
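
A usage sketch (assumed, not from the project); GitPython's git.Repo provides the object this function expects:

import git  # GitPython

repo = git.Repo('.')  # any local git checkout
info = git_info(repo)
print(info['active_branch'], info['commit'][:8],
      'clean' if info['clean'] else 'dirty')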
Code example #4
File: db_interface.py, Project: donnate/tfutils-1
    def load_from_db(self,
                     query,
                     cache_filters=False,
                     collfs=None,
                     collfs_recent=None):
        """Load checkpoint from the database.

        Checks the recent and regular checkpoint fs to find the latest one
        matching the query. Returns the GridOut obj corresponding to the
        record.

        Args:
            query: dict expressing MongoDB query
        """
        if collfs is None:
            collfs = self.collfs
        # reach into GridFS internals for the underlying '<name>.files' collection
        coll = collfs._GridFS__files
        if collfs_recent is None:
            collfs_recent = self.collfs_recent
        coll_recent = collfs_recent._GridFS__files

        query['saved_filters'] = True  # only consider records whose filters were saved
        count = collfs.find(query).count()
        if count > 0:  # get latest that matches query
            ckpt_record = coll.find(query, sort=[('uploadDate', -1)])[0]
            loading_from = coll
        else:
            ckpt_record = None

        try:
            count_recent = collfs_recent.find(query).count()
        except Exception as inst:
            raise er.OperationFailure(
                inst.args[0] +
                "\n Is your dbname too long? Mongo requires that dbnames be no longer than 64 characters."
            )
        if count_recent > 0:  # get latest that matches query
            ckpt_record_recent = coll_recent.find(query,
                                                  sort=[('uploadDate', -1)])[0]
            # use the record with latest timestamp
            if ckpt_record is None or ckpt_record_recent[
                    'uploadDate'] > ckpt_record['uploadDate']:
                loading_from = coll_recent
                ckpt_record = ckpt_record_recent

        if count + count_recent == 0:  # no matches for query
            log.warning('No matching checkpoint for query "{}"'.format(
                repr(query)))
            return

        database = loading_from._Collection__database
        log.info('Loading checkpoint from %s' % loading_from.full_name)

        if cache_filters:
            filename = os.path.basename(ckpt_record['filename'])
            cache_filename = os.path.join(self.cache_dir, filename)

            # check if there is no local copy
            if not os.path.isfile(cache_filename):
                log.info('No cache file at %s, loading from DB' %
                         cache_filename)
                # create a local file and stream the checkpoint from GridFS into it
                load_dest = open(cache_filename, 'wb+')
                fsbucket = gridfs.GridFSBucket(
                    database, bucket_name=loading_from.name.split('.')[0])
                fsbucket.download_to_stream(ckpt_record['_id'], load_dest)
                load_dest.close()
                if ckpt_record[
                        '_saver_write_version'] == saver_pb2.SaverDef.V2:
                    assert cache_filename.endswith('.tar')
                    tar = tarfile.open(cache_filename)
                    tar.extractall(path=self.cache_dir)
                    tar.close()
                    cache_filename = os.path.splitext(cache_filename)[0]
                    verify_pb2_v2_files(cache_filename, ckpt_record)
            else:
                if ckpt_record[
                        '_saver_write_version'] == saver_pb2.SaverDef.V2:
                    cache_filename = os.path.splitext(cache_filename)[0]
                    verify_pb2_v2_files(cache_filename, ckpt_record)
                log.info('Cache file found at %s, using that to load' %
                         cache_filename)
        else:
            cache_filename = None
        return ckpt_record, cache_filename
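
A hedged calling sketch: dbinterface stands for an already-constructed instance of the class (see the constructors below), and the query keys are illustrative:

result = dbinterface.load_from_db(
    {'exp_id': 'my_experiment'},  # 'saved_filters': True is added internally
    cache_filters=True)
if result is None:
    print('no matching checkpoint')
else:
    ckpt_record, cache_filename = result
    print(ckpt_record['uploadDate'], cache_filename)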
Code example #5
File: db_interface.py, Project: donnate/tfutils-1
    def __init__(self,
                 params=None,
                 save_params=None,
                 load_params=None,
                 sess=None,
                 global_step=None,
                 cache_dir=None,
                 *tfsaver_args,
                 **tfsaver_kwargs):
        """
        :Kwargs:
            - params (dict)
                Describing all parameters of experiment
            - save_params (dict)
            - load_params (dict)
            - sess (tesorflow.Session)
                Object in which to run calculations.  This is required if actual loading/
                saving is going to be done (as opposed to just e.g. getting elements from
                the MongoDB).
            - global_step (tensorflow.Variable)
                Global step variable, the one that is updated by apply_gradients.  This
                is required if being using in a training context.
            - *tfsaver_args, **tsaver_kwargs
                Additional arguments to be passed onto base Saver class constructor

        """
        self.params = params
        self._skip_check = params.get('skip_check', False)
        if self._skip_check:
            log.warning('Skipping version check and info...')
        self.sonified_params = sonify(self.params, skip=self._skip_check)
        self.save_params = save_params
        self.load_params = load_params
        self.sess = sess
        self.global_step = global_step
        self.tfsaver_args = tfsaver_args
        self.tfsaver_kwargs = tfsaver_kwargs
        self.var_list = tfsaver_kwargs.get('var_list', None)

        if save_params is None:
            save_params = {}
        if load_params is None:
            load_params = {}
        location_variables = ['host', 'port', 'dbname', 'collname', 'exp_id']
        for _k in location_variables:
            if _k in save_params:
                sv = save_params[_k]
            else:
                sv = load_params[_k]
            if _k in load_params:
                lv = load_params[_k]
            else:
                lv = save_params[_k]
            setattr(self, _k, sv)
            setattr(self, 'load_' + _k, lv)
        self.sameloc = all([
            getattr(self, _k) == getattr(self, 'load_' + _k)
            for _k in location_variables
        ])
        if ('query' in load_params
                and load_params['query'] is not None
                and 'exp_id' in load_params['query']):
            self.sameloc = self.sameloc and (
                load_params['query']['exp_id'] == self.exp_id)

        for _k in [
                'do_save', 'save_metrics_freq', 'save_valid_freq',
                'cache_filters_freq', 'cache_max_num', 'save_filters_freq',
                'save_initial_filters', 'save_to_gfs'
        ]:
            setattr(self, _k, save_params.get(_k, DEFAULT_SAVE_PARAMS[_k]))

        for _k in ['do_restore', 'from_ckpt', 'to_restore', 'load_param_dict']:
            setattr(self, _k, load_params.get(_k, DEFAULT_LOAD_PARAMS[_k]))

        self.rec_to_save = None
        self.checkpoint_thread = None
        self.outrecs = []

        self.conn = pymongo.MongoClient(host=self.host, port=self.port)
        self.conn.server_info()
        self.collfs = gridfs.GridFS(self.conn[self.dbname], self.collname)

        recent_name = '_'.join(
            [self.dbname, self.collname, self.exp_id, '__RECENT'])
        self.collfs_recent = gridfs.GridFS(self.conn[recent_name])

        self.load_data = None
        load_query = load_params.get('query')
        if load_query is None:
            load_query = {}
        else:
            if self.sameloc and save_params != {}:
                raise Exception('Loading pointlessly')
            else:
                self.sameloc = False

        if 'exp_id' not in load_query:
            load_query.update({'exp_id': self.load_exp_id})

        self.load_query = load_query
        if self.load_host != self.host or self.port != self.load_port:
            self.load_conn = pymongo.MongoClient(host=self.load_host,
                                                 port=self.load_port)
            self.load_conn.server_info()
        else:
            self.load_conn = self.conn
        self.load_collfs = gridfs.GridFS(self.load_conn[self.load_dbname],
                                         self.load_collname)
        load_recent_name = '_'.join([
            self.load_dbname, self.load_collname, self.load_exp_id, '__RECENT'
        ])
        self.load_collfs_recent = gridfs.GridFS(
            self.load_conn[load_recent_name])

        # use cache_dir from load_params if save_params was not given
        if (save_params == {}) and ('cache_dir' in load_params):
            cache_dir = load_params['cache_dir']
        elif 'cache_dir' in save_params:
            cache_dir = save_params['cache_dir']
        else:
            cache_dir = None

        if not cache_dir:
            self.cache_dir = os.path.join(TFUTILS_HOME,
                                          '%s:%d' % (self.host, self.port),
                                          self.dbname, self.collname,
                                          self.exp_id)
        else:
            self.cache_dir = cache_dir
        if not os.path.isdir(self.cache_dir):
            os.makedirs(self.cache_dir)
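
A construction sketch under stated assumptions: the class is tfutils' DBInterface, a MongoDB server is listening on localhost:27017, and all five location variables are supplied via save_params (for each one the constructor falls back between save_params and load_params):

params = {'exp_id': 'my_experiment', 'skip_check': True}
save_params = {
    'host': 'localhost',
    'port': 27017,
    'dbname': 'my_db',
    'collname': 'my_coll',
    'exp_id': 'my_experiment',
}

dbinterface = DBInterface(
    params=params,
    save_params=save_params,
    load_params=None)  # load location then defaults to the save location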
Code example #6
    def __init__(self,
                 var_manager=None,
                 params=None,
                 save_params=None,
                 load_params=None,
                 sess=None,
                 global_step=None,
                 cache_dir=None,
                 tfsaver_args=None,
                 tfsaver_kwargs=None):
        """
        :Kwargs:
            - params (dict)
                Describing all parameters of experiment
            - save_params (dict)
            - load_params (dict)
            - sess (tesorflow.Session)
                Object in which to run calculations.  This is required if actual loading/
                saving is going to be done (as opposed to just e.g. getting elements from
                the MongoDB).
            - global_step (tensorflow.Variable)
                Global step variable, the one that is updated by apply_gradients.  This
                is required if being using in a training context.
            - tfsaver_args, tsaver_kwargs
                Additional arguments to be passed onto base Saver class constructor

        """
        self.params = params
        self._skip_check = params.get('skip_check', False)
        if self._skip_check:
            log.warning('Skipping version check and info...')
        self.sonified_params = sonify(self.params, skip=self._skip_check)
        self.save_params = save_params
        self.load_params = load_params
        self.sess = sess
        self.global_step = global_step
        # guard against shared mutable defaults by falling back to fresh containers
        self.tfsaver_args = tfsaver_args or []
        self.tfsaver_kwargs = tfsaver_kwargs or {}
        self.var_manager = var_manager
        if self.var_manager:
            self.var_list = get_var_list_wo_prefix(params, var_manager)
        else:
            all_vars = tf.global_variables()
            self.var_list = {v.op.name: v for v in all_vars}

        # Set save_params and load_params:
        #   And set these parameters as attributes in this instance
        if save_params is None:
            save_params = {}
        if load_params is None:
            load_params = {}
        location_variables = ['host', 'port', 'dbname', 'collname', 'exp_id']
        for _k in location_variables:
            if _k in save_params:
                sv = save_params[_k]
            else:
                sv = load_params[_k]
            if _k in load_params:
                lv = load_params[_k]
            else:
                lv = save_params[_k]
            setattr(self, _k, sv)
            setattr(self, 'load_' + _k, lv)

        # Determine whether this loading is from the same location as saving
        self.sameloc = all([
            getattr(self, _k) == getattr(self, 'load_' + _k)
            for _k in location_variables
        ])
        if ('query' in load_params
                and load_params['query'] is not None
                and 'exp_id' in load_params['query']):
            self.sameloc = self.sameloc and (
                load_params['query']['exp_id'] == self.exp_id)

        # Set some attributes only in save_params
        for _k in [
                'do_save', 'save_metrics_freq',
                'save_valid_freq', 'cache_filters_freq', 'cache_max_num',
                'save_filters_freq', 'save_initial_filters', 'save_to_gfs']:
            setattr(self, _k, save_params.get(_k, DEFAULT_SAVE_PARAMS[_k]))

        # Set some attributes only in load_params
        for _k in [
                'do_restore', 'from_ckpt', 'to_restore', 'load_param_dict',
                'restore_global_step'
        ]:
            setattr(self, _k, load_params.get(_k, DEFAULT_LOAD_PARAMS[_k]))

        self.rec_to_save = None
        self.checkpoint_thread = None
        self.outrecs = []

        # Set the save mongo client
        self.conn = pymongo.MongoClient(host=self.host, port=self.port)
        self.conn.server_info()
        self.collfs = gridfs.GridFS(self.conn[self.dbname], self.collname)

        # Set the cache mongo client
        recent_name = '_'.join(
            [self.dbname, self.collname, self.exp_id, '__RECENT'])
        self.collfs_recent = gridfs.GridFS(self.conn[recent_name])

        self.load_data = None
        load_query = load_params.get('query')
        if load_query is None:
            load_query = {}
        else:
            # Special situation here
            # Users try to load from the same place they try to save through
            # setting the load_query
            # This is not allowed
            if self.sameloc and save_params != {}:
                raise Exception(
                    'Loading pointlessly! '
                    'If you want to continue your training, '
                    'please set your load_query to None!')
            else:
                self.sameloc = False
        if 'exp_id' not in load_query:
            load_query.update({'exp_id': self.load_exp_id})
        self.load_query = load_query

        # Set the load mongo client
        if self.load_host != self.host or self.port != self.load_port:
            self.load_conn = pymongo.MongoClient(host=self.load_host,
                                                 port=self.load_port)
            self.load_conn.server_info()
        else:
            self.load_conn = self.conn
        self.load_collfs = gridfs.GridFS(self.load_conn[self.load_dbname],
                                         self.load_collname)
        # Set the cache mongo client for loading
        load_recent_name = '_'.join([
            self.load_dbname, self.load_collname, self.load_exp_id, '__RECENT'
        ])
        self.load_collfs_recent = gridfs.GridFS(
            self.load_conn[load_recent_name])

        # Set the cache_dir: where to put local cache files
        # use cache_dir from load params if save_params not given
        if (save_params == {}) and ('cache_dir' in load_params):
            cache_dir = load_params['cache_dir']
        elif 'cache_dir' in save_params:
            cache_dir = save_params['cache_dir']
        else:
            cache_dir = None

        if not cache_dir:
            self.cache_dir = os.path.join(TFUTILS_HOME,
                                          '%s:%d' % (self.host, self.port),
                                          self.dbname, self.collname,
                                          self.exp_id)
        else:
            self.cache_dir = cache_dir
        if not os.path.isdir(self.cache_dir):
            os.makedirs(self.cache_dir)
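
The later variant differs mainly in var_manager (when it is absent, every global variable is saved, keyed by op name) and in taking tfsaver_args / tfsaver_kwargs as ordinary parameters. A hedged sketch reusing params and save_params from the previous example; per the docstring, the kwargs are forwarded to the underlying tf.train.Saver constructor elsewhere in the class:

dbinterface = DBInterface(
    params=params,
    save_params=save_params,
    load_params=None,
    var_manager=None,                   # fall back to tf.global_variables()
    tfsaver_kwargs={'max_to_keep': 5})  # a standard tf.train.Saver option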