def version_info(module):
    """Get version of a standard python module.

    Args:
        module (module): python module object to get version info for.

    Returns:
        dict: dictionary of version info.

    """
    if hasattr(module, '__version__'):
        version = module.__version__
    elif hasattr(module, 'VERSION'):
        version = module.VERSION
    else:
        pkgname = module.__name__.split('.')[0]
        try:
            info = pkg_resources.get_distribution(pkgname)
        except (pkg_resources.DistributionNotFound,
                pkg_resources.RequirementParseError):
            version = None
            log.warning('version information not found for %s -- '
                        'what package is this from?' % module.__name__)
        else:
            version = info.version

    return {'version': version}
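# Usage sketch for version_info (hedged: assumes `log` and `pkg_resources`
# are imported at module level, as the function above requires):
#
#   import numpy
#   version_info(numpy)   # -> {'version': numpy.__version__}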
def sync_with_host(self):
    if self.checkpoint_thread is not None:
        try:
            self.checkpoint_coord.join([self.checkpoint_thread])
        except Exception as error:
            log.warning('A checkpoint thread raised an exception '
                        'while saving a checkpoint.')
            log.error(error)
            raise
        else:
            self.checkpoint_thread = None
def git_info(repo):
    """Return information about a git repo.

    Args:
        repo (git.Repo): The git repo to be investigated.

    Returns:
        dict: Git repo information

    """
    if repo.is_dirty():
        log.warning('repo %s is dirty -- having commitment issues?' %
                    repo.git_dir)
        clean = False
    else:
        clean = True
    branchname = repo.active_branch.name
    commit = repo.active_branch.commit.hexsha
    origin = repo.remote('origin')
    urls = list(map(str, origin.urls))
    remote_ref = [_r for _r in origin.refs
                  if _r.name == 'origin/' + branchname]
    if not remote_ref:
        log.warning('Active branch %s not in origin ref' % branchname)
        active_branch_in_origin = False
        commit_in_log = False
    else:
        active_branch_in_origin = True
        remote_ref = remote_ref[0]
        gitlog = remote_ref.log()
        shas = [_r.oldhexsha for _r in gitlog] + \
               [_r.newhexsha for _r in gitlog]
        if commit not in shas:
            log.warning('Commit %s not in remote origin log for branch %s' %
                        (commit, branchname))
            commit_in_log = False
        else:
            commit_in_log = True
    info = {
        'git_dir': repo.git_dir,
        'active_branch': branchname,
        'commit': commit,
        'remote_urls': urls,
        'clean': clean,
        'active_branch_in_origin': active_branch_in_origin,
        'commit_in_log': commit_in_log,
    }
    return info
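# Usage sketch for git_info (hedged: the repo path is hypothetical and
# GitPython must be installed):
#
#   import git
#   repo = git.Repo('/path/to/some/repo')
#   info = git_info(repo)
#   info['active_branch'], info['commit'], info['clean']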
def load_from_db(self,
                 query,
                 cache_filters=False,
                 collfs=None,
                 collfs_recent=None):
    """Load checkpoint from the database.

    Checks the recent and regular checkpoint fs to find the latest one
    matching the query. Returns the GridOut obj corresponding to the
    record.

    Args:
        query: dict expressing MongoDB query

    """
    if collfs is None:
        collfs = self.collfs
    coll = collfs._GridFS__files
    if collfs_recent is None:
        collfs_recent = self.collfs_recent
    coll_recent = collfs_recent._GridFS__files

    query['saved_filters'] = True
    count = collfs.find(query).count()
    if count > 0:  # get latest that matches query
        ckpt_record = coll.find(query, sort=[('uploadDate', -1)])[0]
        loading_from = coll
    else:
        ckpt_record = None

    try:
        count_recent = collfs_recent.find(query).count()
    except Exception as inst:
        raise er.OperationFailure(
            inst.args[0] +
            "\n Is your dbname too long? Mongo requires that dbnames "
            "be no longer than 64 characters.")
    if count_recent > 0:  # get latest that matches query
        ckpt_record_recent = coll_recent.find(query,
                                              sort=[('uploadDate', -1)])[0]
        # use the record with latest timestamp
        if ckpt_record is None or \
                ckpt_record_recent['uploadDate'] > ckpt_record['uploadDate']:
            loading_from = coll_recent
            ckpt_record = ckpt_record_recent

    if count + count_recent == 0:  # no matches for query
        log.warning('No matching checkpoint for query "{}"'.format(
            repr(query)))
        return

    database = loading_from._Collection__database
    log.info('Loading checkpoint from %s' % loading_from.full_name)

    if cache_filters:
        filename = os.path.basename(ckpt_record['filename'])
        cache_filename = os.path.join(self.cache_dir, filename)

        # check if there is no local copy
        if not os.path.isfile(cache_filename):
            log.info('No cache file at %s, loading from DB' % cache_filename)
            # create new file to write from gridfs
            load_dest = open(cache_filename, 'wb')
            fsbucket = gridfs.GridFSBucket(
                database, bucket_name=loading_from.name.split('.')[0])
            fsbucket.download_to_stream(ckpt_record['_id'], load_dest)
            load_dest.close()
            if ckpt_record['_saver_write_version'] == saver_pb2.SaverDef.V2:
                assert cache_filename.endswith('.tar')
                tar = tarfile.open(cache_filename)
                tar.extractall(path=self.cache_dir)
                tar.close()
                cache_filename = os.path.splitext(cache_filename)[0]
                verify_pb2_v2_files(cache_filename, ckpt_record)
        else:
            if ckpt_record['_saver_write_version'] == saver_pb2.SaverDef.V2:
                cache_filename = os.path.splitext(cache_filename)[0]
                verify_pb2_v2_files(cache_filename, ckpt_record)
            log.info('Cache file found at %s, using that to load' %
                     cache_filename)
    else:
        cache_filename = None
    return ckpt_record, cache_filename
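# Usage sketch for load_from_db (hedged: 'my_experiment' is a hypothetical
# exp_id; the method adds 'saved_filters': True to the query itself, and
# returns None when nothing matches):
#
#   result = self.load_from_db({'exp_id': 'my_experiment'},
#                              cache_filters=True)
#   if result is not None:
#       ckpt_record, cache_filename = result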
def __init__(self,
             params=None,
             save_params=None,
             load_params=None,
             sess=None,
             global_step=None,
             cache_dir=None,
             *tfsaver_args,
             **tfsaver_kwargs):
    """
    :Kwargs:
        - params (dict)
            Describing all parameters of experiment
        - save_params (dict)
        - load_params (dict)
        - sess (tensorflow.Session)
            Object in which to run calculations.  This is required if
            actual loading/saving is going to be done (as opposed to
            just e.g. getting elements from the MongoDB).
        - global_step (tensorflow.Variable)
            Global step variable, the one that is updated by
            apply_gradients.  This is required if being used in a
            training context.
        - *tfsaver_args, **tfsaver_kwargs
            Additional arguments to be passed onto base Saver class
            constructor
    """
    self.params = params
    self._skip_check = params.get('skip_check', False)
    if self._skip_check:
        log.warning('Skipping version check and info...')
    self.sonified_params = sonify(self.params, skip=self._skip_check)
    self.save_params = save_params
    self.load_params = load_params
    self.sess = sess
    self.global_step = global_step
    self.tfsaver_args = tfsaver_args
    self.tfsaver_kwargs = tfsaver_kwargs
    self.var_list = tfsaver_kwargs.get('var_list', None)

    if save_params is None:
        save_params = {}
    if load_params is None:
        load_params = {}
    location_variables = ['host', 'port', 'dbname', 'collname', 'exp_id']
    for _k in location_variables:
        if _k in save_params:
            sv = save_params[_k]
        else:
            sv = load_params[_k]
        if _k in load_params:
            lv = load_params[_k]
        else:
            lv = save_params[_k]
        setattr(self, _k, sv)
        setattr(self, 'load_' + _k, lv)
    self.sameloc = all([
        getattr(self, _k) == getattr(self, 'load_' + _k)
        for _k in location_variables
    ])
    if 'query' in load_params \
            and load_params['query'] is not None \
            and 'exp_id' in load_params['query']:
        self.sameloc = self.sameloc & (
            load_params['query']['exp_id'] == self.exp_id)

    for _k in [
            'do_save', 'save_metrics_freq', 'save_valid_freq',
            'cache_filters_freq', 'cache_max_num', 'save_filters_freq',
            'save_initial_filters', 'save_to_gfs'
    ]:
        setattr(self, _k, save_params.get(_k, DEFAULT_SAVE_PARAMS[_k]))

    for _k in ['do_restore', 'from_ckpt', 'to_restore', 'load_param_dict']:
        setattr(self, _k, load_params.get(_k, DEFAULT_LOAD_PARAMS[_k]))

    self.rec_to_save = None
    self.checkpoint_thread = None
    self.outrecs = []

    self.conn = pymongo.MongoClient(host=self.host, port=self.port)
    self.conn.server_info()
    self.collfs = gridfs.GridFS(self.conn[self.dbname], self.collname)

    recent_name = '_'.join(
        [self.dbname, self.collname, self.exp_id, '__RECENT'])
    self.collfs_recent = gridfs.GridFS(self.conn[recent_name])

    self.load_data = None
    load_query = load_params.get('query')
    if load_query is None:
        load_query = {}
    else:
        if self.sameloc and (not save_params == {}):
            raise Exception('Loading pointlessly')
        else:
            self.sameloc = False

    if 'exp_id' not in load_query:
        load_query.update({'exp_id': self.load_exp_id})
    self.load_query = load_query

    if self.load_host != self.host or self.port != self.load_port:
        self.load_conn = pymongo.MongoClient(host=self.load_host,
                                             port=self.load_port)
        self.load_conn.server_info()
    else:
        self.load_conn = self.conn
    self.load_collfs = gridfs.GridFS(self.load_conn[self.load_dbname],
                                     self.load_collname)
    load_recent_name = '_'.join([
        self.load_dbname, self.load_collname, self.load_exp_id, '__RECENT'
    ])
    self.load_collfs_recent = gridfs.GridFS(
        self.load_conn[load_recent_name])

    # use cache_dir from load_params if save_params not given
    if (save_params == {}) and ('cache_dir' in load_params):
        cache_dir = load_params['cache_dir']
    elif 'cache_dir' in save_params:
        cache_dir = save_params['cache_dir']
    else:
        cache_dir = None

    if not cache_dir:
        self.cache_dir = os.path.join(TFUTILS_HOME,
                                      '%s:%d' % (self.host, self.port),
                                      self.dbname, self.collname,
                                      self.exp_id)
    else:
        self.cache_dir = cache_dir
    if not os.path.isdir(self.cache_dir):
        os.makedirs(self.cache_dir)
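# Construction sketch (hedged: `DBInterface` is the assumed name of the
# enclosing class; host/port/dbname values are hypothetical). Location
# variables given only in save_params fall back for loading too, so this
# instance saves to and restores from the same collection:
#
#   dbi = DBInterface(params=params,
#                     save_params={'host': 'localhost', 'port': 27017,
#                                  'dbname': 'mydb', 'collname': 'mycoll',
#                                  'exp_id': 'exp1'},
#                     load_params={'do_restore': True},
#                     sess=sess, global_step=global_step)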
def __init__(self,
             var_manager=None,
             params=None,
             save_params=None,
             load_params=None,
             sess=None,
             global_step=None,
             cache_dir=None,
             tfsaver_args=[],
             tfsaver_kwargs={}):
    """
    :Kwargs:
        - params (dict)
            Describing all parameters of experiment
        - save_params (dict)
        - load_params (dict)
        - sess (tensorflow.Session)
            Object in which to run calculations.  This is required if
            actual loading/saving is going to be done (as opposed to
            just e.g. getting elements from the MongoDB).
        - global_step (tensorflow.Variable)
            Global step variable, the one that is updated by
            apply_gradients.  This is required if being used in a
            training context.
        - tfsaver_args, tfsaver_kwargs
            Additional arguments to be passed onto base Saver class
            constructor
    """
    self.params = params
    self._skip_check = params.get('skip_check', False)
    if self._skip_check:
        log.warning('Skipping version check and info...')
    self.sonified_params = sonify(self.params, skip=self._skip_check)
    self.save_params = save_params
    self.load_params = load_params
    self.sess = sess
    self.global_step = global_step
    self.tfsaver_args = tfsaver_args
    self.tfsaver_kwargs = tfsaver_kwargs

    self.var_manager = var_manager
    if self.var_manager:
        self.var_list = get_var_list_wo_prefix(params, var_manager)
    else:
        all_vars = tf.global_variables()
        self.var_list = {v.op.name: v for v in all_vars}

    # Set save_params and load_params,
    # and set these parameters as attributes in this instance
    if save_params is None:
        save_params = {}
    if load_params is None:
        load_params = {}
    location_variables = ['host', 'port', 'dbname', 'collname', 'exp_id']
    for _k in location_variables:
        if _k in save_params:
            sv = save_params[_k]
        else:
            sv = load_params[_k]
        if _k in load_params:
            lv = load_params[_k]
        else:
            lv = save_params[_k]
        setattr(self, _k, sv)
        setattr(self, 'load_' + _k, lv)

    # Determine whether loading happens from the same location as saving
    self.sameloc = all([
        getattr(self, _k) == getattr(self, 'load_' + _k)
        for _k in location_variables
    ])
    if 'query' in load_params \
            and load_params['query'] is not None \
            and 'exp_id' in load_params['query']:
        self.sameloc = self.sameloc & (
            load_params['query']['exp_id'] == self.exp_id)

    # Set some attributes found only in save_params
    for _k in [
            'do_save', 'save_metrics_freq', 'save_valid_freq',
            'cache_filters_freq', 'cache_max_num', 'save_filters_freq',
            'save_initial_filters', 'save_to_gfs'
    ]:
        setattr(self, _k, save_params.get(_k, DEFAULT_SAVE_PARAMS[_k]))

    # Set some attributes found only in load_params
    for _k in [
            'do_restore', 'from_ckpt', 'to_restore', 'load_param_dict',
            'restore_global_step'
    ]:
        setattr(self, _k, load_params.get(_k, DEFAULT_LOAD_PARAMS[_k]))

    self.rec_to_save = None
    self.checkpoint_thread = None
    self.outrecs = []

    # Set the save mongo client
    self.conn = pymongo.MongoClient(host=self.host, port=self.port)
    self.conn.server_info()
    self.collfs = gridfs.GridFS(self.conn[self.dbname], self.collname)

    # Set the cache mongo client
    recent_name = '_'.join(
        [self.dbname, self.collname, self.exp_id, '__RECENT'])
    self.collfs_recent = gridfs.GridFS(self.conn[recent_name])

    self.load_data = None
    load_query = load_params.get('query')
    if load_query is None:
        load_query = {}
    else:
        # Special situation: users try to load from the same place they
        # save to by setting load_query. This is not allowed.
        if self.sameloc and (not save_params == {}):
            raise Exception(
                'Loading pointlessly! '
                'If you want to continue your training, '
                'please set your load_query to be None!')
        else:
            self.sameloc = False

    if 'exp_id' not in load_query:
        load_query.update({'exp_id': self.load_exp_id})
    self.load_query = load_query

    # Set the load mongo client
    if self.load_host != self.host or self.port != self.load_port:
        self.load_conn = pymongo.MongoClient(host=self.load_host,
                                             port=self.load_port)
        self.load_conn.server_info()
    else:
        self.load_conn = self.conn
    self.load_collfs = gridfs.GridFS(self.load_conn[self.load_dbname],
                                     self.load_collname)

    # Set the cache mongo client for loading
    load_recent_name = '_'.join([
        self.load_dbname, self.load_collname, self.load_exp_id, '__RECENT'
    ])
    self.load_collfs_recent = gridfs.GridFS(
        self.load_conn[load_recent_name])

    # Set the cache_dir: where to put local cache files.
    # Use cache_dir from load_params if save_params is not given.
    if (save_params == {}) and ('cache_dir' in load_params):
        cache_dir = load_params['cache_dir']
    elif 'cache_dir' in save_params:
        cache_dir = save_params['cache_dir']
    else:
        cache_dir = None

    if not cache_dir:
        self.cache_dir = os.path.join(TFUTILS_HOME,
                                      '%s:%d' % (self.host, self.port),
                                      self.dbname, self.collname,
                                      self.exp_id)
    else:
        self.cache_dir = cache_dir
    if not os.path.isdir(self.cache_dir):
        os.makedirs(self.cache_dir)
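# Construction sketch for the var_manager variant (hedged: class and
# variable names are assumptions). With var_manager=None, the saver falls
# back to covering every tf.global_variables() entry keyed by op name:
#
#   dbi = DBInterface(var_manager=None,
#                     params=params,
#                     save_params=save_params,
#                     load_params={'query': None, 'do_restore': True},
#                     sess=sess, global_step=global_step)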