import glob
import os
import shutil
import tarfile
import warnings
import zipfile
from functools import partial
from urllib.request import urlretrieve

import numpy as np
from omegaconf import DictConfig, OmegaConf

# NOTE: project-internal helpers referenced below (md5_folder, md5_checksum,
# MPI, LOGGER, _resize, _process_desc, _extract_factors) are assumed to be
# imported from elsewhere in the repo.


def validate_data_dir(path_dir, md5):
  r""" Create `path_dir` if it does not exist; if it exists but its MD5
  does not match the given `md5`, remove it and recreate an empty one. """
  if not os.path.exists(path_dir):
    os.makedirs(path_dir)
  elif md5_folder(path_dir) != md5:
    shutil.rmtree(path_dir)
    print(f"MD5 mismatch for preprocessed data at {path_dir}, "
          f"removing and recreating the directory!")
    os.makedirs(path_dir)
  return path_dir
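
# The md5_folder helper is not shown in this section; a minimal sketch of a
# folder-level MD5 (hashing every file in a deterministic order) could look
# like the following -- an assumption for illustration, not the repo's actual
# implementation:
import hashlib


def _md5_folder_sketch(path_dir):
  md5 = hashlib.md5()
  for root, _, files in sorted(os.walk(path_dir)):
    for fname in sorted(files):
      with open(os.path.join(root, fname), 'rb') as f:
        # read in chunks so large files do not blow up memory
        for chunk in iter(lambda: f.read(8192), b''):
          md5.update(chunk)
  return md5.hexdigest()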

def _run(self, cfg: DictConfig):
  # the cfg is dispatched by hydra.run_job, we cannot change anything here
  logger = LOGGER
  with warnings.catch_warnings():
    warnings.filterwarnings('ignore', category=DeprecationWarning)
    # prepare the paths
    model_path = self.get_model_path(cfg)
    md5_path = model_path + '.md5'
    # save the config
    config_path = self.get_config_path(cfg, datetime=True)
    with open(config_path, 'w') as f:
      OmegaConf.save(cfg, f)
    logger.info("Saved config: %s" % config_path)
    # load data
    self.on_load_data(cfg)
    logger.info("Loaded data")
    # create or load the model
    if os.path.exists(model_path) and len(os.listdir(model_path)) > 0:
      # check that the model being loaded is consistent with the saved model
      if os.path.exists(md5_path) and self.consistent_model:
        md5_loaded = md5_folder(model_path)
        with open(md5_path, 'r') as f:
          md5_saved = f.read().strip()
        assert md5_loaded == md5_saved, \
            "MD5 of the saved model mismatches, the files are probably corrupted"
      model = self.on_load_model(cfg, model_path)
      if model is None:
        raise RuntimeError(
            "The implementation of on_load_model must return the loaded model.")
      logger.info("Loaded model: %s" % model_path)
    else:
      self.on_create_model(cfg)
      logger.info("Created model: %s" % model_path)
    # training
    self.on_train(cfg, model_path)
    logger.info("Finished training")
    # save the model hash for future consistency checks
    if os.path.exists(model_path) and len(os.listdir(model_path)) > 0:
      with open(md5_path, 'w') as f:
        f.write(md5_folder(model_path))
      logger.info("Saved model: %s" % model_path)
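
# A hypothetical minimal subclass wiring up the hooks used by _run above;
# the base-class name (Experimenter) and all method bodies are assumptions
# for illustration, only the hook names come from the code above.
class MyExperiment(Experimenter):

  def on_load_data(self, cfg: DictConfig):
    self.train_data = load_dataset(cfg.dataset)  # hypothetical loader

  def on_create_model(self, cfg: DictConfig):
    self.model = build_model(cfg.model)  # hypothetical factory

  def on_load_model(self, cfg: DictConfig, model_path: str):
    self.model = restore_model(model_path)  # hypothetical restore
    return self.model  # _run requires a non-None return value

  def on_train(self, cfg: DictConfig, model_path: str):
    self.model.fit(self.train_data)
    self.model.save(model_path)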

def __init__(self,
             path="~/tensorflow_datasets/lego_faces",
             image_size=64,
             background_threshold=255):
  super().__init__()
  path = os.path.abspath(os.path.expanduser(path))
  if not os.path.exists(path):
    os.makedirs(path)
  ### download metadata
  meta_path = os.path.join(path, 'meta.csv')
  if not os.path.exists(meta_path):
    print("Downloading lego faces metadata ...")
    meta_path, _ = urlretrieve(url=LegoFaces.METADATA, filename=meta_path)
  import pandas as pd
  metadata = pd.read_csv(meta_path)
  metadata = metadata[metadata["Category Name"] == "Minifigure, Head"]
  ### remove the downloaded images if their MD5 mismatches
  image_folder = os.path.join(path, "dataset")
  if os.path.exists(image_folder):
    if md5_folder(image_folder) != LegoFaces.MD5:
      shutil.rmtree(image_folder)
  ### download and extract the data
  zip_path = os.path.join(path, "dataset.zip")
  if not os.path.exists(zip_path):
    print("Downloading zipped lego faces dataset ...")
    zip_path, _ = urlretrieve(url=LegoFaces.DATASET, filename=zip_path)
  if not os.path.exists(image_folder):
    with zipfile.ZipFile(zip_path, mode="r") as f:
      print("Extracting all lego faces images ...")
      f.extractall(path)
  ### load all images, downsample if necessary
  images = glob.glob(image_folder + '/*.jpg', recursive=True)
  if image_size != 128:
    image_folder = image_folder + '_%d' % int(image_size)
    if not os.path.exists(image_folder):
      os.mkdir(image_folder)
    if len(os.listdir(image_folder)) != len(images):
      # resize in parallel, starting from a clean folder
      shutil.rmtree(image_folder)
      os.mkdir(image_folder)
      from tqdm import tqdm
      images = [
          i for i in tqdm(MPI(jobs=images,
                              func=partial(_resize,
                                           image_size=image_size,
                                           outpath=image_folder),
                              ncpu=3,
                              batch=1),
                          total=len(images),
                          desc="Resizing images to %d" % image_size)
      ]
    else:
      images = glob.glob(image_folder + '/*.jpg', recursive=True)
  ### extract the heuristic factors from the metadata descriptions
  metadata = {
      part_id: desc
      for part_id, desc in zip(metadata["Number"], metadata["Name"])
  }
  images_desc = {}
  for img_path in images:
    name = os.path.basename(img_path)[:-4]
    if name in metadata:
      desc = metadata[name]
    else:  # file names with an '_<index>' suffix map back to the base part id
      name = name.split('_')
      desc = metadata[name[0]]
    images_desc[img_path] = _process_desc(desc)
  ### load the images into memory
  from PIL import Image

  def imread(p):
    img = Image.open(p, mode='r')
    arr = np.array(img, dtype=np.uint8)
    del img
    return arr

  self.image_size = image_size
  self.images = np.stack(
      [i for i in MPI(jobs=images, func=imread, ncpu=2, batch=1)])
  self.factors = _extract_factors(list(images_desc.keys()),
                                  list(images_desc.values()))
  ### remove background-only images: keep an image only if its darkest
  ### pixel is at most `background_threshold`
  ids = np.array(
      [np.min(i) <= int(background_threshold) for i in self.images])
  self.images = self.images[ids]
  self.factors = self.factors[ids]
  ### shuffle, then split the dataset into train/valid/test (80/10/10)
  rand = np.random.RandomState(seed=1)
  n = len(self.images)
  ids = rand.permutation(n)
  self.images = self.images[ids]
  self.factors = self.factors[ids]
  self.train = (self.images[:int(0.8 * n)], self.factors[:int(0.8 * n)])
  self.valid = (self.images[int(0.8 * n):int(0.9 * n)],
                self.factors[int(0.8 * n):int(0.9 * n)])
  self.test = (self.images[int(0.9 * n):], self.factors[int(0.9 * n):])
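
# Hypothetical usage of the dataset above (assuming the enclosing class is
# LegoFaces, matching the LegoFaces.METADATA / LegoFaces.DATASET constants
# it references):
#
#   ds = LegoFaces(path="~/tensorflow_datasets/lego_faces", image_size=64)
#   train_images, train_factors = ds.train   # first 80% of the data
#   valid_images, valid_factors = ds.valid   # next 10%
#   test_images, test_factors = ds.test      # last 10%
#   print(train_images.shape)                # (n_train, 64, 64, channels)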

def download_and_extract(path,
                         url,
                         extract=True,
                         md5_download=None,
                         md5_extract=None):
  r""" Download a file to the given path, then (optionally) extract it

  Arguments:
    path : a String, path to the folder that stores the downloaded file
    url : a String, the download URL
    extract : a Boolean, if True decompress the downloaded file
    md5_download : a String (optional), expected MD5 of the downloaded file;
      on mismatch the file is removed and downloaded again
    md5_extract : a String (optional), expected MD5 of the extracted folder;
      on mismatch the folder is removed and extracted again

  Returns:
    the path to the downloaded file if `extract=False`, otherwise a tuple
    of `(path, extract_path)`
  """
  from tqdm import tqdm
  path = os.path.abspath(os.path.expanduser(path))
  if not os.path.exists(path):
    os.makedirs(path)
  assert os.path.isdir(path), "path to '%s' is not a directory" % path
  ### file name
  filename = url.split('/')[-1]
  filepath = os.path.join(path, filename)
  ### download
  if os.path.exists(filepath) and md5_download is not None:
    md5 = md5_checksum(filepath)
    if md5 != md5_download:
      print("MD5 of downloaded file mismatch! downloaded:%s provided:%s" %
            (md5, md5_download))
      os.remove(filepath)
  if not os.path.exists(filepath):
    prog = tqdm(desc="Download '%s'" % filename, total=-1, unit="MB")

    def _progress(count, block_size, total_size):
      # convert everything to MB
      total_size = total_size / 1024. / 1024.
      block_size = block_size / 1024. / 1024.
      if prog.total < 0:
        prog.total = total_size
      prog.update(block_size)

    filepath, _ = urlretrieve(url, filepath, reporthook=_progress)
    prog.close()
  ### no extraction needed
  if not extract:
    return filepath
  ### extract
  extract_path = os.path.join(path, os.path.basename(filename).split('.')[0])
  if os.path.exists(extract_path) and md5_extract is not None:
    md5 = md5_folder(extract_path)
    if md5 != md5_extract:
      print("MD5 of extracted folder mismatch! extracted:%s provided:%s" %
            (md5, md5_extract))
      shutil.rmtree(extract_path)
  if not os.path.exists(extract_path):
    # .tar.gz
    if '.tar.gz' in filepath:
      with tarfile.open(filepath, 'r:gz') as f:
        print("Extracting files ...")
        f.extractall(path)
    # .zip
    elif '.zip' in filepath:
      # TODO: support zip archives
      raise NotImplementedError
    # unknown extension
    else:
      raise NotImplementedError("Cannot extract file: %s" % filepath)
  ### return
  return path, extract_path
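
# Hypothetical usage; the URL and checksums below are placeholders, not
# real values:
#
#   folder, extract_path = download_and_extract(
#       path="~/datasets/example",
#       url="https://example.com/data/archive.tar.gz",
#       extract=True,
#       md5_download="<md5-of-archive.tar.gz>",
#       md5_extract="<md5-of-extracted-folder>")
#
# extract_path would point at <path>/archive, i.e. the file name truncated
# at its first '.'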