def __init__(self, model_dir, model_name, experiment_name):
    """Resolves the experiment directory and builds its list of Jobs.

    The experiment directory is '<model_dir>/<model_name>-<experiment_name>'.
    If that path does not exist, it is treated as a prefix and expanded with a
    glob, which must match exactly one path. In that case `experiment_name` is
    replaced by the matched directory name with its leading '<model_name>-'
    removed.

    Args:
      model_dir: Directory that holds the experiment directories.
      model_name: Model name; the first component of the directory name.
      experiment_name: The experiment name, or a prefix of it.

    Raises:
      ValueError: If the directory does not exist and the prefix glob does not
        match exactly one path.
    """
    self.root = f'{model_dir}/{model_name}-{experiment_name}'
    self.model_name = model_name
    self.experiment_name = experiment_name
    if not file_util.exists(self.root):
        log.verbose('Regex expanding root to find experiment ID')
        # Use the whole name as the glob prefix. self.root has no trailing
        # slash here, so slicing off its last character would drop a real
        # character of the experiment name and could match other experiments.
        options = file_util.glob(self.root.rstrip('/') + '*')
        if len(options) != 1:
            log.verbose(
                "Tried to glob for directory but didn't find one path. Found:")
            log.verbose(options)
            raise ValueError('Directory not found: %s' % self.root)
        self.root = options[0] + '/'
        dir_name = os.path.basename(self.root.strip('/'))
        # Remove only the leading '<model_name>-'; str.replace would also
        # delete later occurrences inside the experiment name itself.
        prefix = self.model_name + '-'
        if dir_name.startswith(prefix):
            dir_name = dir_name[len(prefix):]
        self.experiment_name = dir_name
        log.verbose('Expanded experiment name with regex to root: %s' %
                    self.root)
    # Entries of the experiment directory that are not jobs.
    banned = frozenset(
        ['log', 'mldash_config.txt', 'snapshot', 'mldash_config'])
    # After expansion self.root ends in '/'; strip it so the glob has no '//'.
    job_dir = self.root.rstrip('/')
    job_strs = [os.path.basename(n) for n in file_util.glob(f'{job_dir}/*')]
    job_strs = sorted((p for p in job_strs if p not in banned), key=to_xid)
    log.verbose('Job strings: %s' % repr(job_strs))
    self.all_jobs = [Job(self, job_str) for job_str in job_strs]
    # Starts out as a copy of all jobs.
    self._visible_jobs = self.all_jobs[:]
def all_remote_mesh_names(self, split, ensure_nonempty=True):
    """Returns remote mesh names that are present for all visible xids.

    Mesh names have the form '<synset>-<hash>'. They are derived from paths
    matching '<remote_result_base(xid, split)>/<synset>/<hash>.ply'.

    The result is cached together with the split it was computed for, so a
    call for a different split recomputes instead of returning stale names.

    Args:
      split: The split whose results are inspected.
      ensure_nonempty: If True, raise when an xid has no meshes or when no
        mesh is common to all xids.

    Returns:
      A new, sorted list of the mesh names common to every xid in
      self.experiment.visible_xids (empty if there are no visible xids and
      ensure_nonempty is False).

    Raises:
      ValueError: If ensure_nonempty is True and some xid has no meshes, or
        the intersection across xids is empty.
    """
    no_split = object()
    cached = self._remote_mesh_names
    cached_split = getattr(self, '_remote_mesh_names_split', no_split)
    # Only reuse the cache for the same split. A cached empty result must not
    # bypass the ensure_nonempty check.
    if (cached is not None and cached_split == split and
            (cached or not ensure_nonempty)):
        return list(cached)
    common = None
    for xid in self.experiment.visible_xids:
        base = self.remote_result_base(xid, split)
        mesh_paths = file_util.glob(f'{base}/*/*.ply')
        if not mesh_paths and ensure_nonempty:
            raise ValueError('No meshes present for xid %i with path %s' %
                             (xid, base))
        # TODO(kgenova) Now we are assuming hashes are not replicated in
        # multiple synsets.
        mesh_names = set()
        for mesh_path in mesh_paths:
            mesh_hash = os.path.basename(mesh_path)
            if mesh_hash.endswith('.ply'):
                mesh_hash = mesh_hash[:-len('.ply')]
            synset = mesh_path.split('/')[-2]
            mesh_names.add('%s-%s' % (synset, mesh_hash))
        common = mesh_names if common is None else common & mesh_names
    if common is None:
        # No visible xids: nothing can be common.
        common = set()
    if ensure_nonempty and not common:
        # The xids are formatted with repr, so this must be %s, not %i.
        raise ValueError(
            'There are 0 meshes common to the xids %s for split %s' %
            (repr(self.experiment.visible_xids), split))
    # Sort for a deterministic order; set iteration order is not stable.
    result = sorted(common)
    self._remote_mesh_names = result
    self._remote_mesh_names_split = split
    # Return a copy so callers cannot mutate the cache.
    return list(result)
def get_result_path(xid):
    """Generates the result path associated with the requested XID.

    This is a template: the TODO(ldif-user) lines must be filled in for the
    local result layout before use. Matches of
    '<FLAGS.input_dir>/ROOT<xid>-*00000-*' are globbed, a checkpoint number
    is parsed from each match, and the largest one is used.

    Args:
      xid: The XID; it is formatted with %i, so it must be a number.

    Returns:
      The glob pattern '<FLAGS.input_dir>/ROOT<xid>-<ckpt>.*'.

    Raises:
      ValueError: If no result matches the XID.
    """
    # TODO(ldif-user) Set up the result path:
    base = FLAGS.input_dir + '/ROOT%i-*00000-*' % xid
    matches = file_util.glob(base)
    # An explicit check (instead of `assert`) still runs under `python -O`;
    # without it, `ckpt` below would be unbound.
    if not matches:
        raise ValueError('No results found matching %s' % base)
    # TODO(ldif-user) Set the file extension
    # With extension None, str.split splits on whitespace; replace it with the
    # real extension so the checkpoint number can be parsed.
    extension = None
    ckpts = [int(match.split(extension)[0].split('-')[-1]) for match in matches]
    if len(ckpts) > 1 and not FLAGS.use_newest:
        # NOTE(review): --nouse_newest only affects this log line; the newest
        # checkpoint is still selected below. Confirm this is intended.
        log.info(
            'Found multiple checkpoint matches for %s and --nouse_newest: %s' %
            (base, repr(ckpts)))
    ckpts.sort()
    ckpt = ckpts[-1]
    if len(ckpts) > 1:
        log.info('Found multiple checkpoint matches %s, using %s' %
                 (repr(ckpts), repr(ckpt)))
    # TODO(ldif-user) Set up the result path:
    path = FLAGS.input_dir + '/ROOT%i-%i.*'
    return path % (xid, ckpt)
def ckpts_for_xid_and_split(self, xid, split):
    """Lists checkpoint numbers that have results for an (xid, split) pair.

    Entries matching '<remote_result_ckpt_dir(xid)>/*/<split>' are globbed.
    The checkpoint number is the directory name just above each entry. The
    parsed list is memoized in `self.available_ckpt_dict` under the key
    '<xid>-<split>'.

    Returns:
      A list of ints in glob order. The cached list object itself is
      returned, so callers should not mutate it.
    """
    cache = self.available_ckpt_dict
    cache_key = f'{xid!s}-{split!s}'
    if cache_key not in cache:
        result_dir = self.remote_result_ckpt_dir(xid)
        found = []
        for entry in file_util.glob(f'{result_dir}/*/{split}'):
            # '<result_dir>/<ckpt>/<split>' -> take the '<ckpt>' component.
            found.append(int(entry.split('/')[-2]))
        cache[cache_key] = found
        return found
    return cache[cache_key]
def all_checkpoints(self):
    """Returns the Checkpoint objects found in `self.ckpt_dir`.

    The list is built on first call and cached in `self._all_checkpoints`.
    Step numbers come from files whose basename contains 'index' but not
    'tempstate', with 'model.ckpt-' and '.index' removed. The five newest
    steps are treated as temporary. How they are handled depends on
    `self._use_temp_ckpts`:

      * True: every step is kept and a summary warning is logged.
      * 'warn': with six or more steps the five newest are dropped; otherwise
        only the newest step is kept and a warning is logged.
      * falsy: with six or more steps the five newest are dropped; otherwise
        a ValueError is raised.
      * any other value: with six or more steps the five newest are dropped;
        otherwise every step is kept.

    Raises:
      ValueError: If the directory has no checkpoints, or only temporary ones
        exist while temporary checkpoints are disabled.
    """
    if self._all_checkpoints is not None:
        return self._all_checkpoints
    steps = []
    for path in file_util.glob(f'{self.ckpt_dir}/*'):
        name = os.path.basename(path)
        if 'index' in name and 'tempstate' not in name:
            steps.append(
                int(name.replace('model.ckpt-', '').replace('.index', '')))
    steps.sort()
    if not steps:
        raise ValueError(
            'There are no checkpoints in the directory %s.' % self.ckpt_dir)
    # The train task may delete the 5 most recent checkpoints periodically,
    # so those are treated as temporary.
    policy = self._use_temp_ckpts
    if policy is True:  # pylint: disable=g-bool-id-comparison
        parts = [
            'Temporary checkpoints are enabled.',
            ' The most recent temporary checkpoint is %i.' % steps[-1],
        ]
        if len(steps) >= 6:
            parts.append(
                ' The most recent permanent checkpoint is %i.' % steps[-6])
        else:
            parts.append(' There are no permanent checkpoints.')
        log.warning(''.join(parts))
    elif len(steps) >= 6:
        steps = steps[:-5]
    elif policy == 'warn':
        log.warning(
            'No permanent checkpoints. Resorting to temporary one: %i' %
            steps[-1])
        steps = [steps[-1]]
    elif not policy:
        raise ValueError(
            'Only temporary checkpoints are available, and they are not enabled.'
        )
    self._all_checkpoints = [Checkpoint(self, step) for step in steps]
    return self._all_checkpoints