Example #1
    def ls(self, path: str, recursive: bool = False) -> List[File]:
        def _get_file_stats(path: str):
            stat = gfile.stat(path)
            return File(path=path,
                        size=stat.length,
                        # mtime_nsec is in nanoseconds; convert to seconds.
                        mtime=int(stat.mtime_nsec / 1e9))

        if not gfile.exists(path):
            return []
        # If it is a file
        if not gfile.isdir(path):
            return [_get_file_stats(path)]

        files = []
        if recursive:
            # gfile.walk yields (root, subdirs, filenames) triples.
            for root, _, filenames in gfile.walk(path):
                for file in filenames:
                    if not gfile.isdir(os.path.join(root, file)):
                        files.append(_get_file_stats(os.path.join(root, file)))
        else:
            for file in gfile.listdir(path):
                if not gfile.isdir(os.path.join(path, file)):
                    files.append(_get_file_stats(os.path.join(path, file)))
        # Files only
        return files
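
A minimal usage sketch; the enclosing class name, FileManager, is an assumption:

# Hypothetical usage: list every file under a prefix, recursively.
fm = FileManager()
for f in fm.ls('/tmp/data', recursive=True):
    print(f.path, f.size, f.mtime)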
Example #2
    def remove(self, path: str) -> bool:
        try:
            if gfile.isdir(path):
                gfile.rmtree(path)
            else:
                os.remove(path)
            return True
        except Exception as e:  # pylint: disable=broad-except
            logging.error('Error during remove %s', str(e))
        return False
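
Note that os.remove only handles local paths. A remote-safe variant is sketched below, assuming tf.io.gfile is in scope as gfile; gfile.remove deletes a single file on any filesystem gfile supports:

    def remove(self, path: str) -> bool:
        try:
            if gfile.isdir(path):
                gfile.rmtree(path)
            else:
                # gfile.remove works for local and remote paths alike.
                gfile.remove(path)
            return True
        except Exception as e:  # pylint: disable=broad-except
            logging.error('Error during remove %s', str(e))
            return False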
Example #3
def isdir_remote(path):
    """
    Wrapper to check if remote and local paths are directories
    """
    if is_remote_path(path):
        return gfile.isdir(path)
    return os.path.isdir(path)
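
The is_remote_path helper is not shown; a minimal sketch, assuming remote paths are identified by URI scheme (the scheme list is illustrative):

def is_remote_path(path):
    # Hypothetical helper: treat scheme-prefixed URIs as remote.
    return path.startswith(('gs://', 'hdfs://', 's3://'))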
Example #4
    def copy(self, source: str, destination: str) -> bool:
        if gfile.isdir(destination):
            # gfile requires a file name for the copy destination.
            destination = os.path.join(destination,
                                       os.path.basename(source))
        # gfile.copy returns None, so signal success explicitly.
        gfile.copy(source, destination, overwrite=True)
        return True
Example #5
def maybe_load_checkpoint(logdir, optimizer, clobber_checkpoint=False):
  if not clobber_checkpoint:
    if has_checkpoint(logdir):
      print("Loading checkpoint from %s" % logdir)
      optimizer = checkpoints.restore_checkpoint(logdir, optimizer)
      print("Checkpoint loaded from step %d" % optimizer.state.step)
  else:
    if gfile.isdir(logdir):
      gfile.rmtree(logdir)
  return optimizer
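
maybe_load_checkpoint depends on a has_checkpoint helper (see example #14). A usage sketch, assuming optimizer is a Flax optimizer and logdir a checkpoint directory:

# Restores the latest checkpoint if one exists; with
# clobber_checkpoint=True the logdir is wiped instead.
optimizer = maybe_load_checkpoint('/tmp/logdir', optimizer)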
Example #6
    def __init__(self, output_path):
        """Creates AtomicInputOutputter.

        Args:
          output_path: directory to write output files to.
        """
        self.output_path = output_path
        if output_path and not gfile.isdir(self.output_path):
            raise ValueError(
                'Atomic input requires directory as output path, got {}'.format(
                    self.output_path))
        self.atomic_writer = smu_writer_lib.AtomicInputWriter()
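
A usage sketch; when an output path is given, it must be an existing directory, otherwise the constructor raises ValueError:

# Assumes /tmp/outdir already exists as a directory.
outputter = AtomicInputOutputter('/tmp/outdir')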
Example #7
    def __init__(self, log_dir):
        """Create a new SummaryWriter.

        Args:
          log_dir: path to record tfevents files in.
        """
        # If needed, create log_dir directory as well as missing parent directories.
        if not gfile.isdir(log_dir):
            gfile.makedirs(log_dir)

        # Queue up to 10 pending events, flush every 120 seconds.
        self._event_writer = EventFileWriter(log_dir, 10, 120, None)
        self._step = 0
        self._closed = False
Example #8
    def __init__(self, log_dir):
        """Create a new SummaryWriter.

        Args:
          log_dir: path to record tfevents files in.
        """
        # If needed, create log_dir directory as well as missing parent directories.
        if not gfile.isdir(log_dir):
            gfile.makedirs(log_dir)

        self.writer = tf.summary.FileWriter(log_dir, graph=None)
        self.end_summaries = []
        self.step = 0
        self.closed = False
Example #9
  def __init__(self, log_dir, enable=True):
    """Create a new SummaryWriter.

    Args:
      log_dir: path to record tfevents files in.
      enable: bool: if False don't actually write or flush data.  Used in
        multihost training.
    """
    # If needed, create log_dir directory as well as missing parent directories.
    if not gfile.isdir(log_dir):
      gfile.makedirs(log_dir)

    self._event_writer = EventFileWriter(log_dir, 10, 120, None)
    self._step = 0
    self._closed = False
    self._enabled = enable
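
A usage sketch for the enable flag in multihost training, assuming JAX is the host framework (the process-index check is an assumption, not part of the original):

import jax

# Only the first host writes events; other hosts create no-op writers.
writer = SummaryWriter('/tmp/logs', enable=(jax.process_index() == 0))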
Example #10
def copy(src: str, dst: str, overwrite: bool = False) -> str:
    # When dst is an existing directory, copy into it under src's basename.
    if gfile.isdir(dst):
        dst = os.path.join(dst, os.path.basename(src))
    gfile.copy(src, dst, overwrite=overwrite)

    return dst
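
A usage sketch: when dst is an existing directory, the resolved destination file path is returned:

# Assuming /tmp/out exists and is a directory:
final = copy('/tmp/in/model.bin', '/tmp/out')
# final == '/tmp/out/model.bin'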
Example #11
def restore_checkpoint(ckpt_dir,
                       target,
                       step=None,
                       prefix='checkpoint_',
                       parallel=True):
    """Restore last/best checkpoint from checkpoints in path.

  Sorts the checkpoint files naturally, returning the highest-valued
  file, e.g.:
    ckpt_1, ckpt_2, ckpt_3 --> ckpt_3
    ckpt_0.01, ckpt_0.1, ckpt_0.001 --> ckpt_0.1
    ckpt_-1.0, ckpt_1.0, ckpt_1e5 --> ckpt_1e5

  Args:
    ckpt_dir: str: checkpoint file or directory of checkpoints to restore from.
    target: matching object to rebuild via deserialized state-dict. If None,
      the deserialized state-dict is returned as-is.
    step: int: step number to load or None to load latest. If specified,
      ckpt_dir must be a directory.
    prefix: str: name prefix of checkpoint files.
    parallel: bool: whether to load seekable checkpoints in parallel, for speed.

  Returns:
    Restored `target` updated from checkpoint file, or if no step specified and
    no checkpoint files present, returns the passed-in `target` unchanged.
    If a file path is specified and is not found, the passed-in `target` will be
    returned. This is to match the behavior of the case where a directory path
    is specified but the directory has not yet been created.
  """
    if step:
        ckpt_path = _checkpoint_path(ckpt_dir, step, prefix)
        if not gfile.exists(ckpt_path):
            raise ValueError(f'Matching checkpoint not found: {ckpt_path}')
    else:
        if gfile.isdir(ckpt_dir):
            ckpt_path = latest_checkpoint(ckpt_dir, prefix)
            if not ckpt_path:
                logging.info(f'Found no checkpoint files in {ckpt_dir}')
                return target
        else:
            ckpt_path = ckpt_dir
            if not gfile.exists(ckpt_path):
                logging.info(f'Found no checkpoint file at {ckpt_path}')
                return target

    logging.info('Restoring checkpoint from %s', ckpt_path)
    with gfile.GFile(ckpt_path, 'rb') as fp:
        if parallel and fp.seekable():
            buf_size = 128 << 20  # 128M buffer.
            num_bufs = fp.size() / buf_size
            logging.debug('num_bufs: %d', num_bufs)
            checkpoint_contents = bytearray(fp.size())

            def read_chunk(i):
                # NOTE: We have to re-open the file to read each chunk, otherwise the
                # parallelism has no effect. But we could reuse the file pointers
                # within each thread.
                with gfile.GFile(ckpt_path, 'rb') as f:
                    f.seek(i * buf_size)
                    buf = f.read(buf_size)
                    if buf:
                        checkpoint_contents[i * buf_size:i * buf_size +
                                            len(buf)] = buf
                    return len(buf) / buf_size

            pool_size = 32
            pool = thread.ThreadPoolExecutor(pool_size)
            results = pool.map(read_chunk, range(int(num_bufs) + 1))
            results = list(results)
            pool.shutdown(wait=False)
            logging.debug('results: %s', results)
        else:
            checkpoint_contents = fp.read()

        if target is None:
            return serialization.msgpack_restore(checkpoint_contents)
        else:
            return serialization.from_bytes(target, checkpoint_contents)
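
A usage sketch following the docstring, with target as any object Flax's serialization understands (the names here are illustrative):

# Restore the newest 'checkpoint_*' file under ckpt_dir into state;
# returns state unchanged if the directory has no checkpoints yet.
state = restore_checkpoint('/tmp/ckpts', target=state)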
Example #12
    def remove(self, path: str) -> bool:
        # os.remove and gfile.rmtree return None, so report success explicitly.
        if gfile.isdir(path):
            gfile.rmtree(path)
        else:
            os.remove(path)
        return True
Example #13
def main(unused_argv):
    workdir = FLAGS.workdir

    if not gfile.isdir(workdir):
        gfile.makedirs(workdir)

    tf.random.set_seed(FLAGS.random_seed)
    np.random.seed(FLAGS.random_seed)
    data = get_dataset(FLAGS.dataset,
                       FLAGS.batchsize,
                       to_grayscale=FLAGS.grayscale,
                       train_fraction=FLAGS.train_fraction,
                       random_seed=FLAGS.random_seed,
                       augment=FLAGS.augment_traindata)

    # Figure out TPU related stuff and create distribution strategy
    use_remote_eager = FLAGS.master and FLAGS.master != 'local'
    if FLAGS.use_tpu:
        logging.info("Use TPU at %s with job name '%s'.", FLAGS.master,
                     FLAGS.tpu_job_name)
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=FLAGS.master, job_name=FLAGS.tpu_job_name)
        if use_remote_eager:
            tf.config.experimental_connect_to_cluster(resolver)
            logging.warning(
                'Remote eager configured. Remote eager can be slow.')
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.experimental.TPUStrategy(resolver)
    else:
        if use_remote_eager:
            tf.config.experimental_connect_to_host(FLAGS.master,
                                                   job_name='gpu_worker')
            logging.warning(
                'Remote eager configured. Remote eager can be slow.')
        gpus = tf.config.experimental.list_logical_devices(device_type='GPU')
        if gpus:
            logging.info('Found GPUs: %s', gpus)
            strategy = tf.distribute.MirroredStrategy()
        else:
            logging.info('Devices: %s', tf.config.list_logical_devices())
            strategy = tf.distribute.OneDeviceStrategy('CPU')
    logging.info('Devices: %s', tf.config.list_logical_devices())
    logging.info('Distribution strategy: %s', strategy)
    logging.info('Model directory: %s', workdir)

    run(workdir,
        data,
        strategy,
        architecture=FLAGS.dnn_architecture,
        n_layers=FLAGS.num_layers,
        n_hiddens=FLAGS.num_units,
        activation=FLAGS.activation,
        dropout_rate=FLAGS.dropout,
        l2_penalty=FLAGS.l2reg,
        w_init_name=FLAGS.w_init,
        b_init_name=FLAGS.b_init,
        optimizer_name=FLAGS.optimizer,
        learning_rate=FLAGS.learning_rate,
        n_epochs=FLAGS.epochs,
        epochs_between_checkpoints=FLAGS.epochs_between_checkpoints,
        init_stddev=FLAGS.init_std,
        cnn_stride=FLAGS.cnn_stride,
        reduce_learningrate=FLAGS.reduce_learningrate,
        verbosity=FLAGS.verbose)
Example #14
def has_checkpoint(logdir):
  # gfile.glob returns a list; coerce to a real bool for callers.
  return bool(gfile.isdir(logdir) and
              gfile.glob(os.path.join(logdir, "checkpoint_*")))
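
A usage sketch pairing this with example #5's maybe_load_checkpoint:

# True once at least one checkpoint_* file exists under the directory.
if has_checkpoint('/tmp/logdir'):
  optimizer = maybe_load_checkpoint('/tmp/logdir', optimizer)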