def main():
    p = argparse.ArgumentParser()

    p.add_argument('log_dir_root', help='Path to dir containing log_dirs.')
    p.add_argument('job_ids', help='Comma-separated list of job_ids.')
    p.add_argument('images')
    p.add_argument('--save_ours', '-o', action='store_const', const=True,
                   help='If given, store output images in VAL_OUT/imgs.')
    p.add_argument('--how_many', type=int, help='Number of images to output.')
    p.add_argument('--image_cache_max', '-cache', type=int, default=500,
                   help='Cache max in [MB]. Set to 0 to disable.')
    p.add_argument('--restore_itr', '-i', type=int)
    p.add_argument('--ckpt_step', '-s', type=int, default=2,
                   help='Every CKPT_STEP-th checkpoint will be validated. Set to 1 to validate all of them. '
                        'The last checkpoint will always be validated. Set to -1 to only validate the last one.')
    p.add_argument('--reset', action='store_const', const=True, help='Remove previous output.')
    p.add_argument('--real_bpp', action='store_const', const=True,
                   help='If given, calculate the real bpp using arithmetic encoding. Note: in our experiments, '
                        'this matches the theoretical bpp up to 1%% precision. Note: this is very slow.')

    flags, unknown_flags = p.parse_known_args()
    if unknown_flags:
        print('Unknown flags: {}'.format(unknown_flags))

    image_paths, dataset_name = val_images.get_image_paths(flags.images)

    for ckpt_dir in logdir_helpers.iter_ckpt_dirs(flags.log_dir_root, flags.job_ids):
        try:
            validate(
                ValidationDirs(ckpt_dir, flags.log_dir_root, dataset_name, flags.reset),
                ImagesIterator(image_paths[:flags.how_many], dataset_name, flags.image_cache_max),
                OutputFlags(flags.save_ours, flags.ckpt_step, flags.real_bpp))
        except tf.errors.NotFoundError as e:  # Happens if the checkpoint was deleted while validating.
            print('*** Caught {}'.format(e))
            continue
        finally:
            # Each checkpoint builds its own graph; reset so the next iteration starts from a clean graph,
            # also when a checkpoint disappeared mid-run.
            tf.reset_default_graph()

    print('*** All given job_ids validated.')
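
# Example invocation (illustration only: the script name, paths, job_ids and image set
# identifier below are hypothetical; the positional arguments and flags are the ones
# defined in main() above):
#
#   python val.py /path/to/log_dirs 0815_1234,0815_1235 kodak --ckpt_step 2 --real_bpp -o
#
# This would validate every second checkpoint of the two listed jobs on the given image
# set, compute the real bpp with arithmetic encoding, and store the output images.
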
def get_measures_readers(log_dir_root, job_ids, dataset):
    if job_ids == 'NA':  # TODO
        return []

    missing = []
    measures_readers = []
    for job_id, ckpt_dir in zip(job_ids.split(','),
                                logdir_helpers.iter_ckpt_dirs(log_dir_root, job_ids)):
        val_dirs = val_files.ValidationDirs(ckpt_dir, log_dir_root, dataset)
        try:
            measures_reader = val_files.MeasuresReader(val_dirs.out_dir)
            measures_readers.append(measures_reader)
        except FileNotFoundError:
            missing.append(job_id)

    if missing:
        print('Missing measures files for:\n{}'.format(','.join(missing)))

    # Uniquify: keep one reader per distinct out_dir.
    unique_out_dirs = {reader.out_dir for reader in measures_readers}
    return [val_files.MeasuresReader(out_dir) for out_dir in unique_out_dirs]
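
# Minimal usage sketch for get_measures_readers (hypothetical arguments; only the
# out_dir attribute, which is already used above, is assumed on MeasuresReader):
#
#   readers = get_measures_readers('/path/to/log_dirs', '0815_1234,0815_1235', 'kodak')
#   for reader in readers:
#       print(reader.out_dir)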