Code Example #1
import os
import shutil


def validate_data_dir(path_dir, md5):
  # create the directory if missing; if its content hash differs from
  # the expected MD5, wipe it and recreate it empty
  if not os.path.exists(path_dir):
    os.makedirs(path_dir)
  elif md5_folder(path_dir) != md5:
    print(f"MD5 mismatch for preprocessed data at {path_dir}, "
          "removing and recreating!")
    shutil.rmtree(path_dir)
    os.makedirs(path_dir)
  return path_dir
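
Every example on this page calls a md5_folder helper from the odin-ai utilities; its implementation is not shown here. A minimal sketch of what such a helper might look like follows (the sorted traversal and chunked reads are assumptions, not the library's actual code):

import hashlib
import os


def md5_folder(path):
  # hypothetical sketch: one MD5 digest over every file under `path`;
  # files are visited in sorted order so the digest is deterministic
  digest = hashlib.md5()
  for root, dirs, files in os.walk(path):
    dirs.sort()  # walk subfolders in a stable order
    for name in sorted(files):
      with open(os.path.join(root, name), 'rb') as f:
        # stream each file in chunks to keep memory bounded
        for chunk in iter(lambda: f.read(8192), b''):
          digest.update(chunk)
  return digest.hexdigest()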
Code Example #2
 def _run(self, cfg: DictConfig):
     # the cfg is dispatched by hydra.run_job; we cannot modify it here
     logger = LOGGER
     with warnings.catch_warnings():
         warnings.filterwarnings('ignore', category=DeprecationWarning)
         # prepare the paths
         model_path = self.get_model_path(cfg)
         md5_path = model_path + '.md5'
         # check the configs
         config_path = self.get_config_path(cfg, datetime=True)
         with open(config_path, 'w') as f:
             OmegaConf.save(cfg, f)
         logger.info("Save config: %s" % config_path)
         # load data
         self.on_load_data(cfg)
         logger.info("Loaded data")
         # create or load model
         if os.path.exists(model_path) and len(os.listdir(model_path)) > 0:
             # check that the model about to be loaded matches the saved checksum
             if os.path.exists(md5_path) and self.consistent_model:
                 md5_loaded = md5_folder(model_path)
                 with open(md5_path, 'r') as f:
                     md5_saved = f.read().strip()
                 assert md5_loaded == md5_saved, \
                   "MD5 of saved model mismatch, probably files are corrupted"
             model = self.on_load_model(cfg, model_path)
             if model is None:
                 raise RuntimeError(
                     "The implementation of on_load_model must return the loaded model."
                 )
             logger.info("Loaded model: %s" % model_path)
         else:
             self.on_create_model(cfg)
             logger.info("Create model: %s" % model_path)
         # training
         self.on_train(cfg, model_path)
         logger.info("Finish training")
         # saving the model hash
         if os.path.exists(model_path) and len(os.listdir(model_path)) > 0:
             with open(md5_path, 'w') as f:
                 f.write(md5_folder(model_path))
             logger.info("Save model:%s" % model_path)
Code Example #3
  def __init__(self,
               path="~/tensorflow_datasets/lego_faces",
               image_size=64,
               background_threshold=255):
    super().__init__()
    path = os.path.abspath(os.path.expanduser(path))
    if not os.path.exists(path):
      os.makedirs(path)
    ### download metadata
    meta_path = os.path.join(path, 'meta.csv')
    if not os.path.exists(meta_path):
      print("Download lego faces metadata ...")
      meta_path, _ = urlretrieve(url=LegoFaces.METADATA, filename=meta_path)
    import pandas as pd
    metadata = pd.read_csv(meta_path)
    metadata = metadata[metadata["Category Name"] == "Minifigure, Head"]
    ### check downloaded images
    image_folder = os.path.join(path, "dataset")
    if os.path.exists(image_folder):
      if md5_folder(image_folder) != LegoFaces.MD5:
        shutil.rmtree(image_folder)
    ### download data
    zip_path = os.path.join(path, "dataset.zip")
    if not os.path.exists(zip_path):
      print("Download zip lego faces dataset ...")
      zip_path, _ = urlretrieve(url=LegoFaces.DATASET, filename=zip_path)
    if not os.path.exists(image_folder):
      with zipfile.ZipFile(zip_path, mode="r") as f:
        print("Extract all lego faces images ...")
        f.extractall(path)
    ### load all images, downsample if necessary
    images = glob.glob(image_folder + '/*.jpg', recursive=True)
    if image_size != 128:
      image_folder = image_folder + '_%d' % int(image_size)
      if not os.path.exists(image_folder):
        os.mkdir(image_folder)
      if len(os.listdir(image_folder)) != len(images):
        shutil.rmtree(image_folder)
        os.mkdir(image_folder)
        from tqdm import tqdm
        images = [
            i for i in tqdm(MPI(jobs=images,
                                func=partial(_resize,
                                             image_size=image_size,
                                             outpath=image_folder),
                                ncpu=3,
                                batch=1),
                            total=len(images),
                            desc="Resizing images to %d" % image_size)
        ]
      else:
        images = glob.glob(image_folder + '/*.jpg', recursive=True)
    ### extract the heuristic factors
    metadata = {
        part_id: desc
        for part_id, desc in zip(metadata["Number"], metadata["Name"])
    }
    images_desc = {}
    for img_path in images:  # renamed from `path` to avoid shadowing above
      name = os.path.basename(img_path)[:-4]
      if name in metadata:
        desc = metadata[name]
      else:
        # strip the '_<suffix>' from the filename and look up the base id
        name = name.split('_')
        desc = metadata[name[0]]
      images_desc[img_path] = _process_desc(desc)
    ### load the images and tokenize the descriptions
    from PIL import Image

    def imread(p):
      # read a single image file into a uint8 array
      with Image.open(p) as img:
        return np.array(img, dtype=np.uint8)

    self.image_size = image_size
    self.images = np.stack(
        [i for i in MPI(jobs=images, func=imread, ncpu=2, batch=1)])
    self.factors = _extract_factors(list(images_desc.keys()),
                                    list(images_desc.values()))
    ### remove pure-background images: keep an image only if its darkest
    ### pixel is at or below the threshold
    keep = np.array(
        [np.min(img) <= int(background_threshold) for img in self.images])
    self.images = self.images[keep]
    self.factors = self.factors[keep]
    ### shuffle, then split the dataset 80/10/10
    rand = np.random.RandomState(seed=1)
    n = len(self.images)
    ids = rand.permutation(n)
    # the permutation must be applied before slicing, otherwise the
    # splits are taken in the original (unshuffled) order
    self.images = self.images[ids]
    self.factors = self.factors[ids]
    self.train = (self.images[:int(0.8 * n)], self.factors[:int(0.8 * n)])
    self.valid = (self.images[int(0.8 * n):int(0.9 * n)],
                  self.factors[int(0.8 * n):int(0.9 * n)])
    self.test = (self.images[int(0.9 * n):], self.factors[int(0.9 * n):])
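
Judging from LegoFaces.METADATA and LegoFaces.DATASET above, this constructor belongs to a LegoFaces dataset class. A hedged usage sketch (the shapes in the comments are inferred from the splitting code, not documented by the project):

ds = LegoFaces(path="~/tensorflow_datasets/lego_faces", image_size=64)
train_images, train_factors = ds.train  # ~80% of the data
valid_images, valid_factors = ds.valid  # ~10%
test_images, test_factors = ds.test     # ~10%
print(train_images.shape)  # presumably (N, 64, 64, 3) uint8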
Code Example #4
File: net_utils.py  Project: trungnt13/odin-ai
def download_and_extract(path,
                         url,
                         extract=True,
                         md5_download=None,
                         md5_extract=None):
  r""" Download a file to given path then extract the file

  Arguments:
    path : a String path to a folder
    url : a String of download URL
    extract : a Boolean, if True decompress the file
  """
  from tqdm import tqdm
  path = os.path.abspath(os.path.expanduser(path))
  if not os.path.exists(path):
    os.makedirs(path)
  assert os.path.isdir(path), "path to '%s' is not a directory" % path
  ### file name
  filename = url.split('/')[-1]
  filepath = os.path.join(path, filename)
  ### download
  if os.path.exists(filepath) and md5_download is not None:
    md5 = md5_checksum(filepath)
    if md5 != md5_download:
      print("MD5 of downloaded file mismatch! downloaded:%s  provided:%s" %
            (md5, md5_download))
      os.remove(filepath)
  if not os.path.exists(filepath):
    prog = tqdm(desc="Download '%s'" % filename, total=-1, unit="MB")

    def _progress(count, block_size, total_size):
      # to MB
      total_size = total_size / 1024. / 1024.
      block_size = block_size / 1024. / 1024.
      if prog.total < 0:
        prog.total = total_size
      prog.update(block_size)

    filepath, _ = urlretrieve(url, filepath, reporthook=_progress)
    prog.close()  # finalize the progress bar once the download completes
  ### no extraction needed
  if not extract:
    return filepath
  ### extract
  extract_path = os.path.join(path, os.path.basename(filename).split('.')[0])
  if os.path.exists(extract_path) and md5_extract is not None:
    md5 = md5_folder(extract_path)
    if md5 != md5_extract:
      print("MD5 extracted folder mismatch! extracted:%s provided:%s" %
            (md5, md5_extract))
      shutil.rmtree(extract_path)
  if not os.path.exists(extract_path):
    # .tar.gz
    if filepath.endswith('.tar.gz'):
      with tarfile.open(filepath, 'r:gz') as f:
        print("Extracting files ...")
        f.extractall(path)
    # .zip
    elif filepath.endswith('.zip'):
      # TODO
      raise NotImplementedError
    # unknown extension
    else:
      raise NotImplementedError("Cannot extract file: %s" % filepath)
  ### return
  return path, extract_path
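
A hedged usage sketch of download_and_extract (the URL below is a placeholder, not a real artifact; with extract=True the function returns both the download folder and the extracted folder):

path, extract_path = download_and_extract(
    path='~/datasets/example',              # created if missing
    url='https://example.com/data.tar.gz',  # placeholder URL
    extract=True,
    md5_download=None,  # skip verification of the downloaded archive
    md5_extract=None)   # skip verification of the extracted folder
print("Download folder:", path)
print("Extracted to:", extract_path)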