Example #1
def create_split(dataset_builder: tfds.core.DatasetBuilder,
                 batch_size: int,
                 train: bool,
                 dtype: tf.DType = tf.float32,
                 image_size: int = IMAGE_SIZE,
                 cache: bool = False):
    """Creates a split from the ImageNet dataset using TensorFlow Datasets.

  Args:
    dataset_builder: TFDS dataset builder for ImageNet.
    batch_size: the batch size returned by the data pipeline.
    train: Whether to load the train or evaluation split.
    dtype: data type of the image (default: float32).
    image_size: The target size of the images (default: 224).
    cache: Whether to cache the dataset (default: False).
  Returns:
    A `tf.data.Dataset`.
  """
    if train:
        train_size = dataset_builder.info.splits['train'].num_examples
        split_size = train_size // jax.host_count()
        start = jax.host_id() * split_size
        split = 'train[{}:{}]'.format(start, start + split_size)
    else:
        validation_size = dataset_builder.info.splits[
            'validation'].num_examples
        split_size = validation_size // jax.host_count()
        start = jax.host_id() * split_size
        split = 'validation[{}:{}]'.format(start, start + split_size)

    def _decode_example(example):
        if train:
            image = preprocess_for_train(example['image'], dtype, image_size)
        else:
            image = preprocess_for_eval(example['image'], dtype, image_size)
        return {'image': image, 'label': example['label']}

    ds = dataset_builder.as_dataset(
        split=split, decoders={'image': tfds.decode.SkipDecoding()})
    ds.options().experimental_threading.private_threadpool_size = 48
    ds.options().experimental_threading.max_intra_op_parallelism = 1

    if cache:
        ds = ds.cache()

    if train:
        ds = ds.repeat()
        ds = ds.shuffle(16 * batch_size, seed=0)

    ds = ds.map(_decode_example,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

    ds = ds.batch(batch_size, drop_remainder=True)

    if not train:
        ds = ds.repeat()

    ds = ds.prefetch(10)

    return ds
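A minimal consumption sketch for the pipeline above. It assumes the same module also provides preprocess_for_train/preprocess_for_eval and IMAGE_SIZE, and that the ImageNet TFDS data has already been prepared; the helper name is illustrative.

import jax
import tensorflow_datasets as tfds


def make_train_iterator(per_host_batch_size):
    # Each host builds its own slice of 'train', so the effective global batch
    # is per_host_batch_size * jax.host_count() examples per step.
    builder = tfds.builder('imagenet2012:5.*.*')
    train_ds = create_split(builder, per_host_batch_size, train=True)
    return iter(tfds.as_numpy(train_ds))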
Example #2
def _create_eval_dataset(config, dataset_builder, split):
    """Create evaluation dataset (validation or test sets)."""
    # This ensures the correct number of elements in the validation sets.
    num_validation_examples = (dataset_builder.info.splits[split].num_examples)
    eval_split = deterministic_data.get_read_instruction_for_host(
        split, dataset_info=dataset_builder.info, drop_remainder=False)

    eval_num_batches = None
    if config.eval_pad_last_batch:
        # This is doing some extra work to get exactly all examples in the
        # validation split. Without this the dataset would first be split between
        # the different hosts and then into batches (both times dropping the
        # remainder). If you don't mind dropping a few extra examples you can omit
        # the `pad_up_to_batches` argument.
        eval_batch_size = jax.local_device_count(
        ) * config.per_device_batch_size
        eval_num_batches = int(
            np.ceil(num_validation_examples / eval_batch_size /
                    jax.host_count()))
    return deterministic_data.create_dataset(
        dataset_builder,
        split=eval_split,
        # Only cache dataset in distributed setup to avoid consuming a lot of
        # memory in Colab and unit tests.
        cache=jax.host_count() > 1,
        batch_dims=[jax.local_device_count(), config.per_device_batch_size],
        num_epochs=1,
        shuffle=False,
        preprocess_fn=_preprocess_spherical_mnist,
        pad_up_to_batches=eval_num_batches,
    )
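To make the pad_up_to_batches arithmetic above concrete, here is a worked example with hypothetical numbers (8 local devices, a per-device batch size of 32, two hosts):

import numpy as np

num_validation_examples = 10_000   # dataset_builder.info.splits[split].num_examples
eval_batch_size = 8 * 32           # jax.local_device_count() * config.per_device_batch_size
num_hosts = 2                      # jax.host_count()
eval_num_batches = int(np.ceil(num_validation_examples / eval_batch_size / num_hosts))
# eval_num_batches == 20: each host yields 20 batches, and the final, partially
# filled batch is padded by deterministic_data.create_dataset.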
Example #3
    def set_metadata(self):
        """Set meta information about the dataset."""

        num_classes = self.get_num_classes()
        input_shape = self.get_input_shape()

        self.meta_data = {
            'num_classes':
            num_classes,
            'input_shape':
            input_shape,
            'input_dtype':
            self.dtype.jax_dtype,
            'num_train_examples_per_env': {
                env: self.splits.train[env].num_examples
                for env in self.splits.train
            },
            'num_eval_examples_per_env': {
                env: self.splits.validation[env].num_examples
                for env in self.splits.validation
            },
            # We don't sum across environments because they are processed in
            # parallel batches; we take the number of examples in the first
            # environment, assuming every environment has the same number of
            # training examples. Otherwise, num_train_examples_per_env and
            # num_eval_examples_per_env should be used instead.
            'num_train_examples':
            self.splits.train[str(self.env2id(
                self.train_environments[0]))].num_examples * jax.host_count(),
            'num_eval_examples':
            self.splits.validation[str(self.env2id(
                self.eval_environments[0]))].num_examples * jax.host_count(),
        }
Example #4
def get_translate_wmt(shuffle_rng, batch_size, eval_batch_size=None, hps=None):
    """Wrapper to conform to the general dataset API."""

    per_host_batch_size = batch_size // jax.host_count()
    per_host_eval_batch_size = eval_batch_size // jax.host_count()
    return _get_translate_wmt(per_host_batch_size, per_host_eval_batch_size,
                              hps, shuffle_rng)
Example #5
def load_split(batch_size,
               train,
               dtype=tf.float32,
               image_size=IMAGE_SIZE,
               cache=False):
    """Creates a split from the ImageNet dataset using TensorFlow Datasets.

  Args:
    batch_size: the batch size returned by the data pipeline.
    train: Whether to load the train or evaluation split.
    dtype: data type of the image.
    image_size: The target size of the images.
    cache: Whether to cache the dataset.
  Returns:
    A `tf.data.Dataset`.
  """
    if train:
        split_size = TRAIN_IMAGES // jax.host_count()
        start = jax.host_id() * split_size
        split = 'train[{}:{}]'.format(start, start + split_size)
    else:
        split_size = EVAL_IMAGES // jax.host_count()
        start = jax.host_id() * split_size
        split = 'validation[{}:{}]'.format(start, start + split_size)

    def decode_example(example):
        if train:
            image = preprocess_for_train(example['image'], dtype, image_size)
        else:
            image = preprocess_for_eval(example['image'], dtype, image_size)
        return {'image': image, 'label': example['label']}

    ds = tfds.load('imagenet2012:5.*.*',
                   split=split,
                   decoders={
                       'image': tfds.decode.SkipDecoding(),
                   })
    options = tf.data.Options()
    options.experimental_threading.private_threadpool_size = 48
    ds = ds.with_options(options)

    if cache:
        ds = ds.cache()

    if train:
        ds = ds.repeat()
        ds = ds.shuffle(16 * batch_size, seed=0)

    ds = ds.map(decode_example,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.batch(batch_size, drop_remainder=True)

    if not train:
        ds = ds.repeat()

    ds = ds.prefetch(10)

    return ds
Example #6
def main(executable_dict, argv):
    del argv

    work_unit = platform.work_unit()
    tf.enable_v2_behavior()
    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], 'GPU')

    logging.info('JAX host: %d / %d', jax.host_id(), jax.host_count())
    logging.info('JAX devices: %r', jax.devices())

    work_unit.set_task_status(
        f'host_id: {jax.host_id()}, host_count: {jax.host_count()}')

    # Read configuration
    if FLAGS.config_json:
        logging.info('Reading config from JSON: %s', FLAGS.config_json)
        with tf.io.gfile.GFile(FLAGS.config_json, 'r') as f:
            config = ml_collections.ConfigDict(json.loads(f.read()))
    else:
        config = FLAGS.config
    logging.info('config=%s',
                 config.to_json_best_effort(indent=4, sort_keys=True))

    # Make output directories
    if FLAGS.experiment_dir:
        work_unit.create_artifact(platform.ArtifactType.DIRECTORY,
                                  FLAGS.experiment_dir, 'experiment_dir')
    if FLAGS.work_unit_dir:
        work_unit.create_artifact(platform.ArtifactType.DIRECTORY,
                                  FLAGS.work_unit_dir, 'work_unit_dir')
    logging.info('experiment_dir=%s work_unit_dir=%s', FLAGS.experiment_dir,
                 FLAGS.work_unit_dir)

    # Seeding
    random.seed(config.seed * jax.host_count() + jax.host_id())
    onp.random.seed(config.seed * jax.host_count() + jax.host_id())
    rng = utils.RngGen(
        jax.random.fold_in(jax.random.PRNGKey(config.seed), jax.host_id()))

    # Run the main function
    logging.info('Running executable: %s', FLAGS.executable_name)

    extra_args = {}
    if FLAGS.extra_args_json_str:
        extra_args = json.loads(FLAGS.extra_args_json_str)
        logging.info('Extra args passed in: %r', extra_args)

    executable_dict[FLAGS.executable_name](config=config,
                                           experiment_dir=FLAGS.experiment_dir,
                                           work_unit_dir=FLAGS.work_unit_dir,
                                           rng=rng,
                                           **extra_args)

    utils.barrier()
Example #7
def load(
    split: Split,
    *,
    is_training: bool,
    batch_dims: Sequence[int],
    bfloat16: bool = False,
) -> Generator[Batch, None, None]:
    """Loads the given split of the dataset."""
    if is_training:
        start, end = _shard(split, jax.host_id(), jax.host_count())
    else:
        start, end = _shard(split, 0, 1)
    tfds_split = tfds.core.ReadInstruction(_to_tfds_split(split),
                                           from_=start,
                                           to=end,
                                           unit='abs')
    ds = tfds.load('imagenet2012:5.*.*',
                   split=tfds_split,
                   decoders={'image': tfds.decode.SkipDecoding()})

    total_batch_size = np.prod(batch_dims)

    options = ds.options()
    options.experimental_threading.private_threadpool_size = 48
    options.experimental_threading.max_intra_op_parallelism = 1
    if is_training:
        options.experimental_deterministic = False

    if is_training:
        if jax.host_count() > 1:
            # Only cache if we are reading a subset of the dataset.
            ds = ds.cache()
        ds = ds.repeat()
        ds = ds.shuffle(buffer_size=10 * total_batch_size, seed=0)

    else:
        if split.num_examples % total_batch_size != 0:
            raise ValueError(
                f'Test/valid must be divisible by {total_batch_size}')

    def preprocess(example):
        image = _preprocess_image(example['image'], is_training)
        if bfloat16:
            image = tf.cast(image, tf.bfloat16)
        label = tf.cast(example['label'], tf.int32)
        return {'images': image, 'labels': label}

    ds = ds.map(preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)

    for batch_size in reversed(batch_dims):
        ds = ds.batch(batch_size)

    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

    yield from tfds.as_numpy(ds)
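A sketch of how this generator might be consumed; the Split value and the per-device batch size are illustrative, and the enclosing module's Split/_shard definitions are assumed.

import jax

per_device_batch_size = 32
train_iter = load(
    Split.TRAIN_AND_VALID,
    is_training=True,
    batch_dims=[jax.local_device_count(), per_device_batch_size])
batch = next(train_iter)
# batch['images'] has shape (jax.local_device_count(), 32, H, W, 3), ready to
# be fed to a pmapped training step.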
Example #8
def get_wmt_dataset(batch_size, train, shuffle_size=16384):
    """Get the train or eval split of WMT as a tf.data.Dataset."""
    keys = TRAIN_KEYS if train else EVAL_KEYS

    def parse_function(example_proto):
        return tf.io.parse_single_example(
            example_proto, {
                k: tf.io.FixedLenSequenceFeature(
                    [], tf.int64, allow_missing=True)
                for k in keys
            })

    def cast_to_int32(x):
        return {k: tf.dtypes.cast(x[k], tf.int32) for k in keys}

    def pad(x):
        return {
            k: pad_up_to(x[k], [
                MAX_TRAIN_LEN if train else MAX_EVAL_LEN,
            ])
            for k in keys
        }

    file_pattern = os.path.join(
        FLAGS.train_data_path if train else FLAGS.eval_data_path)
    dataset = tf.data.Dataset.list_files(file_pattern, shuffle=False)
    dataset = dataset.shard(jax.host_count(), jax.host_id())
    concurrent_files = min(10, 1024 // jax.host_count())
    dataset = dataset.interleave(tf.data.TFRecordDataset, concurrent_files, 1,
                                 concurrent_files)

    dataset = dataset.map(parse_function, num_parallel_calls=32)
    dataset = dataset.map(cast_to_int32, num_parallel_calls=32)
    if train:
        # Filter out rare long, unpacked single-examples.
        dataset = dataset.filter(length_filter(MAX_TRAIN_LEN))

    dataset = dataset.map(pad, num_parallel_calls=32)
    if train:
        dataset = dataset.cache().shuffle(shuffle_size).repeat()
    dataset = dataset.batch(batch_size, drop_remainder=train)
    if not train:
        dataset = dataset.cache().repeat()
    dataset = dataset.prefetch(1024)

    options = tf.data.Options()
    options.experimental_deterministic = False
    options.experimental_threading.max_intra_op_parallelism = 1
    options.experimental_threading.private_threadpool_size = 48
    dataset = dataset.with_options(options)

    return dataset
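The snippet assumes a pad_up_to helper; a sketch of what such a helper typically looks like (not the repository's exact definition):

import tensorflow as tf


def pad_up_to(t, max_in_dims, constant_value=0):
    # Pads a variable-length sequence up to a fixed length so that examples
    # can be batched with static shapes.
    paddings = [[0, m - tf.shape(t)[i]] for i, m in enumerate(max_in_dims)]
    return tf.pad(t, paddings, constant_values=constant_value)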
Example #9
def _prepare_dataset(
        dataset: tf.data.Dataset,
        global_batch_size: int,
        shuffle: bool,
        rng: np.ndarray,
        preprocess_fn: Optional[Callable[[Any], Any]] = None,
        num_epochs: Optional[int] = None,
        filter_fn: Optional[Callable[[Any], Any]] = None) -> tf.data.Dataset:
    """Batches, shuffles, prefetches and preprocesses a dataset.

  Args:
    dataset: The dataset to prepare.
    global_batch_size: The global batch size to use.
    shuffle: Whether to shuffle the data at the example level.
    rng: PRNG for seeding the shuffle operations.
    preprocess_fn: Preprocessing function that will be applied to every example.
    num_epochs: Number of epochs to repeat the dataset.
    filter_fn: Function that filters samples according to some criteria.

  Returns:
    The dataset.
  """
    if shuffle and rng is None:
        raise ValueError("Shuffling without RNG is not supported.")

    if global_batch_size % jax.host_count() != 0:
        raise ValueError(
            f"Batch size {global_batch_size} not divisible by number "
            f"of hosts ({jax.host_count()}).")
    local_batch_size = global_batch_size // jax.host_count()
    batch_dims = [jax.local_device_count(), local_batch_size]

    # tf.data uses single integers as seed.
    if rng is not None:
        rng = rng[0]

    ds = dataset.repeat(num_epochs)
    if shuffle:
        ds = ds.shuffle(1024, seed=rng)

    if preprocess_fn is not None:
        ds = ds.map(preprocess_fn,
                    num_parallel_calls=tf.data.experimental.AUTOTUNE)

    if filter_fn is not None:
        ds = ds.filter(filter_fn)

    for batch_size in reversed(batch_dims):
        ds = ds.batch(batch_size, drop_remainder=True)
    return ds.prefetch(tf.data.experimental.AUTOTUNE)
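A hypothetical call, mainly to show how the PRNG key is threaded through; the toy dataset and the preprocessing lambda are placeholders, not part of the original code.

import jax
import numpy as np
import tensorflow as tf

raw_ds = tf.data.Dataset.from_tensor_slices({'x': tf.range(4096)})
rng = np.asarray(jax.random.PRNGKey(42))   # the function uses rng[0] as the tf.data seed
train_ds = _prepare_dataset(
    raw_ds,
    global_batch_size=512,                 # must be divisible by jax.host_count()
    shuffle=True,
    rng=rng,
    preprocess_fn=lambda ex: {'x': tf.cast(ex['x'], tf.float32)},
    num_epochs=None)                       # None repeats the dataset indefinitely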
Example #10
  def __init__(self, dataset, tokenizer):
    self.tokenizer = tokenizer

    # Shard the train split here already to avoid unnecessary tokenization.
    if isinstance(dataset, dict):
      dataset['train'] = dataset['train'].shard(jax.host_count(), jax.host_id())
      single_split = dataset['train']
    else:
      single_split = dataset

    name_a, *names_other = [
      name for name, feature in single_split.features.items()
      if feature.dtype=='string']
    assert len(names_other) <= 1, (
      'Only single sentences and sentence pairs allowed.')
    if names_other:
      name_b = names_other[0]
      tokenize = lambda example: self.tokenizer(
        example[name_a], example[name_b], truncation=True)
    else:
      tokenize = lambda example: self.tokenizer(
        example[name_a], truncation=True)
    
    mapped_dataset = dataset.map(tokenize, batched=True)
    mapped_dataset.set_format('numpy', columns=[
      'idx', 'input_ids', 'token_type_ids', 'attention_mask', 'label'])
    super().__init__(mapped_dataset)
Example #11
def load_split(train: bool,
               cache: bool) -> tf.data.Dataset:
  """Creates a split from the ImageNet dataset using TensorFlow Datasets.

  Args:
    train: Whether to load the train or evaluation split.
    cache: Whether to cache the dataset.
  Returns:
    A `tf.data.Dataset`.
  """
  if train:
    split_size = TRAIN_IMAGES // jax.host_count()
    start = jax.host_id() * split_size
    split = 'train[{}:{}]'.format(start, start + split_size)
  else:
    # For validation, we load up the dataset on each host. This will have the
    # effect of evaluating on the whole dataset num_host times, but will
    # prevent size issues. This makes the performance slightly worse when
    # evaluating often, but spares us the need to pad the datasets and mask the
    # loss accordingly.
    split = 'validation'

  ds = tfds.load('imagenet2012:5.*.*', split=split, decoders={
      'image': tfds.decode.SkipDecoding(),
  })
  ds.options().experimental_threading.private_threadpool_size = 48
  ds.options().experimental_threading.max_intra_op_parallelism = 1

  if cache:
    ds = ds.cache()

  return ds
Example #12
def parallel_write_images(image_write_fn, img_and_path_list):
    """Parallelizes image writing over JAX hosts and CPU cores.

  Args:
    image_write_fn: A function that takes a tuple (image, path) as input and
      writes the image to disk.
    img_and_path_list: A list of tuples (image, path) containing all the images
      that should be written.
  """
    num_hosts = jax.host_count()
    host_id = jax.host_id()
    num_images = len(img_and_path_list)
    num_images_per_batch = math.ceil(num_images / num_hosts)

    # First shard the images onto each host.
    per_host_images_and_paths = []
    for i in range(num_images_per_batch):
        base_index = i * num_hosts
        global_index = base_index + host_id
        if global_index < num_images:
            per_host_images_and_paths.append(img_and_path_list[global_index])

    # Now within each JAX host, use multi-processing to save the sharded images.
    with multiprocessing.pool.ThreadPool() as pool:
        pool.map(image_write_fn, per_host_images_and_paths)
        pool.close()
        pool.join()
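A usage sketch under the assumption that the tuples carry (image, path) pairs, as documented for img_and_path_list; write_png and the frames list are illustrative, not part of the snippet.

import numpy as np
from PIL import Image


def write_png(img_and_path):
    image, path = img_and_path
    Image.fromarray(np.asarray(image)).save(path)

# Hypothetical renders: a list of uint8 HxWx3 arrays produced elsewhere.
frames = [np.zeros((64, 64, 3), dtype=np.uint8) for _ in range(4)]
parallel_write_images(
    write_png,
    [(frame, f'/tmp/frame_{i:04d}.png') for i, frame in enumerate(frames)])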
Example #13
def main(argv):
    del argv
    print('JAX host: %d / %d' % (jax.host_id(), jax.host_count()))
    print('JAX devices:\n%s' % '\n'.join(str(d) for d in jax.devices()),
          flush=True)
    experiment = Experiment()
    experiment.train_and_eval()
Example #14
def load_ckpt_v2(model_state, dir):
    start = time.time()
    with open(dir + "/meta.json", "r") as f:
        meta = json.load(f)

    # TODO: make this work in the general case
    assert meta["total_hosts"] == jax.host_count(
    ), "Must load with same number of hosts as when saved"

    head_print(f"meta loaded in {time.time() - start:.06}s")

    new_state = {
        "step": np.array([meta["step"]]),
    }

    start = time.time()
    new_state["params"] = parallel_read(
        model_state["params"], dir + f"/params/shard_{jax.host_id()}.npz")
    head_print(f"params loaded in {time.time() - start:.06}s")

    start = time.time()
    new_state["opt_state"] = parallel_read(
        model_state["opt_state"],
        dir + f"/opt_state/shard_{jax.host_id()}.npz")
    head_print(f"opt_state loaded in {time.time() - start:.06}s")

    return new_state
Example #15
    def _init_host_and_devices(self, n_devices=None, random_seed=None):
        """Initializes host and device attributes for this trainer.

    Args:
      n_devices: Number of devices this trainer will use. If `None`, get the
          number from the backend.
      random_seed: Random seed as the starting point for all random numbers used
          by the trainer. If `None`, calculate one from system time and host id.

    Returns:
      is_chief: True if this trainer has special chief responsibilities.
      n_devices: The passed in value of n_devices or a computed default.
      random_seed: The passed in value of random_seed or a computed default.
    """
        if math.backend_name() == 'jax':
            host_id = jax.host_id()
            host_count = jax.host_count()
        else:
            host_id = 0
            host_count = 1
        is_chief = (host_id == 0)

        device_count = math.device_count()
        n_devices = n_devices or device_count
        # TODO(lukaszkaiser): remove this restriction when possible.
        if n_devices != device_count and math.backend_name() == 'jax':
            raise ValueError(
                'JAX cannot work yet with n_devices != all devices: '
                '%d != %d' % (n_devices, device_count))

        if random_seed is None and host_count > 1:
            random_seed = int(1e6 * (host_id + time.time())) % 2**32
        return is_chief, n_devices, init_random_number_generators(random_seed)
Example #16
def main():
    args = parser.parse_args()
    logging.set_verbosity(logging.ERROR)
    print('JAX host: %d / %d' % (jax.host_id(), jax.host_count()))
    print('JAX devices:\n%s' % '\n'.join(str(d) for d in jax.devices()), flush=True)

    if get_model_cfg(args.model) is not None:
        validate(args)
    else:
        models = list_models(pretrained=True)
        if args.model != 'all':
            models = fnmatch.filter(models, args.model)
        if not models:
            print(f'ERROR: No models found to validate with pattern ({args.model}).')
            exit(1)

        print('Validating:', ', '.join(models))
        results = []
        for m in models:
            args.model = m
            res = validate(args)
            res.update(dict(model=m))
            results.append(res)
        print('Results:')
        for r in results:
            print(f"Model: {r['model']}, Top1: {r['top1']}, Top5: {r['top5']}")
Example #17
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    FLAGS.log_dir = FLAGS.workdir
    FLAGS.stderrthreshold = 'info'
    logging.get_absl_handler().start_logging_to_file()

    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], 'GPU')

    logging.info('JAX host: %d / %d', jax.host_id(), jax.host_count())
    logging.info('JAX local devices: %r', jax.local_devices())

    # Add a note so that we can tell which task is which JAX host.
    # (Depending on the platform, task 0 is not guaranteed to be host 0.)
    platform.work_unit().set_task_status(
        f'host_id: {jax.host_id()}, host_count: {jax.host_count()}')
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                         FLAGS.workdir, 'workdir')

    if FLAGS.sample:
        sample.save_images(sample.generate_sample(FLAGS.config, FLAGS.workdir),
                           'sample.png')
    else:
        train.train_and_evaluate(FLAGS.config, FLAGS.workdir)
Example #18
def _load_tfds_imagenet(split_name, n_total):
    split_size = float(n_total) // jax.host_count()
    start = split_size * jax.host_id()
    end = start + split_size
    start_index = int(round(start))
    end_index = int(round(end))
    split = '{}[{}:{}]'.format(split_name, start_index, end_index)
    return tfds.load('imagenet2012:5.*.*', split=split)
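A quick check of the split arithmetic above, with illustrative numbers in place of the real jax.host_count()/jax.host_id() values:

n_total, n_hosts = 50_000, 8
split_size = float(n_total) // n_hosts                    # 6250.0 examples per host
bounds = [(int(round(split_size * h)), int(round(split_size * h + split_size)))
          for h in range(n_hosts)]
# host 3 would read split='validation[18750:25000]'.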
Example #19
def _load_custom_imagenet_split(split_path):
  """Load a custom split of the ImageNet dataset."""
  if not tf.io.gfile.exists(split_path):
    raise RuntimeError('Cannot find {}'.format(split_path))
  shard_filenames = tf.io.gfile.listdir(split_path)
  shard_filenames.sort()
  if jax.host_count() > 1:
    n_hosts = jax.host_count()
    host_id = jax.host_id()
    shard_filenames = [f for i, f in enumerate(shard_filenames)
                       if (i % n_hosts) == host_id]
  files_in_split = [os.path.join(split_path, f) for f in shard_filenames]
  ds = tf.data.TFRecordDataset(files_in_split, buffer_size=128 * 1024 * 1024,
                               num_parallel_reads=len(files_in_split))
  # ds = deserialize_and_decode_image_dataset(ds, batch_size=256)
  ds = deserialize_and_decode_image_dataset(ds, batch_size=1)
  return ds
Example #20
def load(
        split: Split,
        is_training: bool,
        batch_dims: Sequence[int],
        image_size: int = IMAGE_SIZE,
        chw: bool = False,
        mean: Optional[Tuple[float]] = None,
        std: Optional[Tuple[float]] = None,
        interpolation: str = 'bicubic',
        tfds_data_dir: Optional[str] = None,
):
    """Loads the given split of the dataset."""
    mean = MEAN_RGB if mean is None else mean
    std = STDDEV_RGB if std is None else std
    if is_training:
        start, end = _shard(split, jax.host_id(), jax.host_count())
    else:
        start, end = _shard(split, 0, 1)
    tfds_split = tfds.core.ReadInstruction(_to_tfds_split(split), from_=start, to=end, unit='abs')
    ds = tfds.load(
        'imagenet2012:5.*.*',
        split=tfds_split,
        decoders={'image': tfds.decode.SkipDecoding()},
        data_dir=tfds_data_dir)

    total_batch_size = np.prod(batch_dims)

    options = ds.options()
    options.experimental_threading.private_threadpool_size = 48
    options.experimental_threading.max_intra_op_parallelism = 1
    if is_training:
        options.experimental_deterministic = False

    if is_training:
        ds = ds.repeat()
        ds = ds.shuffle(buffer_size=10 * total_batch_size, seed=0)
    else:
        if split.num_examples % total_batch_size != 0:
            raise ValueError(f'Test set size must be divisible by {total_batch_size}')
    num_batches = split.num_examples // total_batch_size

    interpolation = (tf.image.ResizeMethod.BILINEAR if 'bilinear' in interpolation
                     else tf.image.ResizeMethod.BICUBIC)
    def preprocess(example):
        image = _preprocess_image(
            example['image'], is_training, image_size=image_size, mean=mean, std=std, interpolation=interpolation)
        if chw:
            image = tf.transpose(image, (2, 0, 1))  # transpose HWC image to CHW format
        label = tf.cast(example['label'], tf.int32)
        return {'images': image, 'labels': label}

    ds = ds.map(preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)

    for batch_size in reversed(batch_dims):
        ds = ds.batch(batch_size)

    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

    return tfds.as_numpy(ds), num_batches
Example #21
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    print('JAX host: %d / %d' % (jax.host_id(), jax.host_count()))
    print('JAX devices:\n%s' % '\n'.join(str(d) for d in jax.devices()),
          flush=True)

    train_and_evaluate(config=FLAGS.config, resume=FLAGS.resume)
Example #22
def render_image(state, rays, render_fn, rng, normalize_disp, chunk=8192):
    """Render all the pixels of an image (in test mode).

  Args:
    state: model_utils.TrainState.
    rays: a `Rays` namedtuple, the rays to be rendered.
    render_fn: function, jit-ed render function.
    rng: jnp.ndarray, random number generator (used in training mode only).
    normalize_disp: bool, if true then normalize `disp` to [0, 1].
    chunk: int, the size of chunks to render sequentially.

  Returns:
    rgb: jnp.ndarray, rendered color image.
    disp: jnp.ndarray, rendered disparity image.
    acc: jnp.ndarray, rendered accumulated weights per pixel.
  """
    height, width = rays[0].shape[:2]
    num_rays = height * width
    rays = datasets.ray_fn(lambda r: r.reshape((num_rays, -1)), rays)

    unused_rng, key_0, key_1 = jax.random.split(rng, 3)
    model = state.optimizer.target
    model_state = state.model_state
    host_id = jax.host_id()
    results = []
    with nn.stateful(model_state, mutable=False):
        for i in range(0, num_rays, chunk):
            # pylint: disable=cell-var-from-loop
            print("  " + "X" * int((i / num_rays) * 78), end="\r")
            chunk_rays = datasets.ray_fn(lambda r: r[i:i + chunk], rays)
            chunk_size = chunk_rays[0].shape[0]
            rays_remaining = chunk_size % jax.device_count()
            rays_per_host = chunk_size // jax.host_count()
            if rays_remaining != 0:
                padding = jax.device_count() - rays_remaining
                chunk_rays = datasets.ray_fn(
                    lambda r: jnp.pad(r, ((0, padding), (0, 0)), mode="edge"),
                    chunk_rays)
            else:
                padding = 0
            # After padding the number of chunk_rays is always divisible by
            # host_count.
            start, stop = host_id * rays_per_host, (host_id +
                                                    1) * rays_per_host
            chunk_rays = datasets.ray_fn(lambda r: shard(r[start:stop]),
                                         chunk_rays)
            chunk_results = render_fn(key_0, key_1, model, chunk_rays)[-1]
            results.append([unshard(x[0], padding) for x in chunk_results])
            # pylint: enable=cell-var-from-loop
        print("")
    rgb, disp, acc = [jnp.concatenate(r, axis=0) for r in zip(*results)]
    # Normalize disp for visualization for ndc_rays in llff front-facing scenes.
    if normalize_disp:
        disp = (disp - disp.min()) / (disp.max() - disp.min())
    return (rgb.reshape((height, width, -1)), disp.reshape(
        (height, width, -1)), acc.reshape((height, width, -1)))
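This render loop relies on shard and unshard helpers that are not shown; a minimal sketch of what they might look like (the project's actual utilities may differ):

import jax
import jax.numpy as jnp


def shard(xs):
    # Reshape the leading axis to (local_devices, per_device, ...) for pmap.
    return jax.tree_util.tree_map(
        lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]), xs)


def unshard(x, padding=0):
    # Collapse the device axis and drop any rows that were added as padding.
    y = x.reshape((-1,) + x.shape[2:])
    return y[:y.shape[0] - padding] if padding > 0 else y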
Example #23
def get_fake(shuffle_rng, batch_size, eval_batch_size, hps=None):
    """Data generators for imagenet."""
    del shuffle_rng
    per_host_batch_size = batch_size // jax.host_count()
    per_host_eval_batch_size = eval_batch_size // jax.host_count()

    fake_train_batch = get_fake_batch(per_host_batch_size, hps.input_shape,
                                      hps.output_shape[0])
    fake_test_batch = get_fake_batch(per_host_eval_batch_size, hps.input_shape,
                                     hps.output_shape[0])

    def train_iterator_fn():
        while True:
            yield fake_train_batch

    def valid_epoch(epoch, num_batches=None):
        del num_batches
        del epoch
        # Note that we do // because we do not support partial batching for the fake
        # dataset.
        for _ in range(hps.valid_size // eval_batch_size):
            yield fake_test_batch

    # pylint: disable=unreachable
    def eval_train_epoch(*args, **kwargs):
        del args
        del kwargs
        return
        yield  # This yield is needed to make this a valid (null) iterator.

    # pylint: enable=unreachable
    # pylint: disable=unreachable

    def test_epoch(*args, **kwargs):
        del args
        del kwargs
        return
        yield  # This yield is needed to make this a valid (null) iterator.

    # pylint: enable=unreachable

    return data_utils.Dataset(train_iterator_fn, eval_train_epoch, valid_epoch,
                              test_epoch)
Example #24
    def _initialize_training(self, rng):
        # Initialize inputs.
        if self.config.emulated_workers > 0:
            per_device_workers, ragged = divmod(self.config.emulated_workers,
                                                jax.host_count())
            if ragged:
                raise ValueError(
                    'Number of emulated workers must be divisible by the '
                    'number of physical workers `jax.host_count()`.')
            self._repeat_batch = per_device_workers
        else:
            self._repeat_batch = 1
        self.supervised_train_input = jl_utils.py_prefetch(
            self._supervised_train_dataset)
        if self.config.training.extra_data_path is None:
            self.extra_train_input = None
        else:
            self.extra_train_input = jl_utils.py_prefetch(
                self._extra_train_dataset)
        self.normalize_fn = datasets.cifar10_normalize

        # Optimizer.
        self.optimizer = utils.sgd_momentum(self.config.training.learning_rate,
                                            momentum=.9,
                                            nesterov=True)

        # Initialize parameters.
        if self._params is None:
            logging.info(
                'Initializing parameters randomly rather than restoring '
                'from checkpoint.')
            # Create inputs to initialize the network state.
            images, _, _ = jax.pmap(self.concatenate)(
                next(self.supervised_train_input), next(self.extra_train_input)
                if self.extra_train_input is not None else None)
            images = jax.pmap(self.normalize_fn)(images)
            # Initialize weights and biases.
            init_net = jax.pmap(
                lambda *a: self.model.init(*a, is_training=True),
                axis_name='i')
            init_rng = jl_utils.bcast_local_devices(rng)
            self._params, self._state = init_net(init_rng, images)
            # Setup weight averaging.
            if self.config.training.swa_decay > 0:
                self._avg_params = self._params
            else:
                self._avg_params = None
            # Initialize optimizer state.
            init_opt = jax.pmap(self.optimizer.init, axis_name='i')
            self._opt_state = init_opt(self._params)

        # Initialize step function.
        self.train_fn = jax.pmap(self._train_fn,
                                 axis_name='i',
                                 donate_argnums=(0, 1, 2, 3))
Example #25
    def load_split_from_tfds(self,
                             name,
                             batch_size,
                             train,
                             split=None,
                             shuffle_seed=1):
        """Loads a split from the dataset using TensorFlow Datasets.

    Args:
      name: str; Name of the environment passed to `tfds.load`.
      batch_size: int; The batch size returned by the data pipeline.
      train: bool; Whether to load the train or evaluation split.
      split: str; Name of the dataset split passed to tfds, if None, the value
        is set with respect to the `train` argument.
      shuffle_seed: int; Seed for shuffling the training data.

    Returns:
      A `tf.data.Dataset`.
    """
        if split is None:
            split = 'train' if train else 'test'

        builder = self.get_builder(name)
        # Each host is responsible for a fixed subset of data
        base_split_name, host_start, host_end = dataset_utils.get_data_range(
            builder, split, jax.host_id(), jax.host_count())
        data_range = tfds.core.ReadInstruction(base_split_name,
                                               unit='abs',
                                               from_=host_start,
                                               to=host_end)

        ds, ds_info = self.get_tfds_ds_and_info(name, data_range)

        # Apply preprocessing before `ds.cache()` so the cached data is reused.
        decode_example = functools.partial(self.preprocess_example,
                                           env_name=name)
        ds = ds.map(decode_example,
                    num_parallel_calls=tf.data.experimental.AUTOTUNE)

        if self.if_cache:
            ds = ds.cache()
        if train:
            ds = ds.repeat()
            ds = ds.shuffle(16 * batch_size, seed=shuffle_seed)
            ds = ds.map(self.process_train_example,
                        num_parallel_calls=tf.data.experimental.AUTOTUNE)
            ds = ds.batch(batch_size, drop_remainder=False)
        else:
            ds = ds.map(self.process_eval_example,
                        num_parallel_calls=tf.data.experimental.AUTOTUNE)
            ds = ds.batch(batch_size, drop_remainder=False)
            ds = ds.repeat()

        ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
        return ds, ds_info.splits[split].num_examples
Example #26
def render_image(state, data, render_fn, rng, chunk=8192):
    """Render all the pixels of an image (in test mode).

  Args:
    state: model_utils.TrainState.
    data: dict, test example.
    render_fn: function, jit-ed render function.
    rng: jnp.ndarray, random number generator (used in training mode only).
    chunk: int, the size of chunks to render sequentially.

  Returns:
    rgb: jnp.ndarray, rendered color image.
    disp: jnp.ndarray, rendered disparity image.
    acc: jnp.ndarray, rendered accumulated weights per pixel.
  """
    rays = data["rays"]
    h, w = rays.shape[:2]
    rays = rays.reshape((h * w, -1))
    unused_rng, key_0, key_1 = jax.random.split(rng, 3)
    model = state.optimizer.target
    model_state = state.model_state
    host_id = jax.host_id()
    rgb = []
    disp = []
    acc = []
    with nn.stateful(model_state, mutable=False):
        for i in range(0, rays.shape[0], chunk):
            print("  " + "X" * int((i / rays.shape[0]) * 78), end="\r")
            chunk_rays = rays[i:i + chunk]
            remainder = chunk_rays.shape[0] % jax.device_count()
            if remainder != 0:
                padding = jax.device_count() - remainder
                chunk_rays = jnp.pad(chunk_rays, ((0, padding), (0, 0)),
                                     mode="edge")
            else:
                padding = 0
            # After padding the number of chunk_rays is always divisible by
            # host_count.
            per_host_rays = chunk_rays.shape[0] // jax.host_count()
            chunk_rays = chunk_rays[(host_id * per_host_rays):((host_id + 1) *
                                                               per_host_rays)]
            chunk_rays = shard(chunk_rays)
            ret = render_fn(key_0, key_1, model, chunk_rays)
            rgb.append(unshard(ret[-1][0][0], padding))
            disp.append(unshard(ret[-1][1][0], padding))
            acc.append(unshard(ret[-1][2][0], padding))
        print("")
    rgb = jnp.concatenate(rgb, axis=0)
    disp = jnp.concatenate(disp, axis=0)
    # Normalize disp for visualization for ndc_rays in llff front-facing scenes.
    if rays.shape[-1] > 6:
        disp = (disp - disp.min()) / (disp.max() - disp.min())
    acc = jnp.concatenate(acc, axis=0)
    return (rgb.reshape((h, w, -1)), disp.reshape(
        (h, w, -1)), acc.reshape((h, w, -1)))
Example #27
def _get_birds200_dataset(
        mode: str,
        rng: np.ndarray) -> Tuple[tf.data.Dataset, tf.data.Dataset, int]:
    """Load the caltech_birds2011 dataset."""
    assert jax.host_count() == 1, (
        "caltech_birds2011 dataset does not support multihost training. "
        "Found {} hosts.".format(jax.host_count()))

    dataset_builder = tfds.builder("caltech_birds2011")
    num_classes = 200

    # Make sure each host uses a different RNG for the training data.
    rng, data_rng = jax.random.split(rng)
    data_rng = jax.random.fold_in(data_rng, jax.host_id())
    data_rng, shuffle_rng = jax.random.split(data_rng)

    if mode == "train-val":
        read_config = tfds.ReadConfig(shuffle_seed=shuffle_rng[0])
        ds = dataset_builder.as_dataset(split="train",
                                        shuffle_files=False,
                                        read_config=read_config)

        train_ds = ds.take(5000).shuffle(5000, seed=shuffle_rng[0])
        eval_ds = ds.skip(5000)

    elif mode == "train-test":
        train_split = "train"
        eval_split = "test"

        train_read_config = tfds.ReadConfig(shuffle_seed=shuffle_rng[0])
        train_ds = dataset_builder.as_dataset(split=train_split,
                                              shuffle_files=True,
                                              read_config=train_read_config)

        eval_read_config = tfds.ReadConfig(shuffle_seed=shuffle_rng[1])
        eval_ds = dataset_builder.as_dataset(split=eval_split,
                                             shuffle_files=False,
                                             read_config=eval_read_config)
    else:
        raise ValueError(f"Unknown mode: {mode}.")

    return train_ds, eval_ds, num_classes
Example #28
def main(unused_argv):
    # Necessary to use the tfds loader.
    tf.enable_v2_behavior()

    if jax.process_count() > 1:
        # TODO(ankugarg): Add support for multihost inference.
        raise NotImplementedError(
            'BLEU eval does not support multihost inference.')

    rng = jax.random.PRNGKey(FLAGS.seed)

    mt_eval_config = json.loads(FLAGS.mt_eval_config)

    if FLAGS.experiment_config_filename:
        with tf.io.gfile.GFile(FLAGS.experiment_config_filename) as f:
            experiment_config = json.load(f)
        if jax.process_index() == 0:
            logging.info('experiment_config: %r', experiment_config)
        dataset_name = experiment_config['dataset']
        model_name = experiment_config['model']
    else:
        assert FLAGS.dataset and FLAGS.model
        dataset_name = FLAGS.dataset
        model_name = FLAGS.model

    if jax.process_index() == 0:
        logging.info('argv:\n%s', ' '.join(sys.argv))
        logging.info('device_count: %d', jax.device_count())
        logging.info('num_hosts : %d', jax.host_count())
        logging.info('host_id : %d', jax.host_id())

    model_class = models.get_model(model_name)
    dataset_builder = datasets.get_dataset(dataset_name)
    dataset_meta_data = datasets.get_dataset_meta_data(dataset_name)

    hparam_overrides = None
    if FLAGS.hparam_overrides:
        if isinstance(FLAGS.hparam_overrides, str):
            hparam_overrides = json.loads(FLAGS.hparam_overrides)

    merged_hps = hyperparameters.build_hparams(
        model_name=model_name,
        initializer_name=experiment_config['initializer'],
        dataset_name=dataset_name,
        hparam_file=FLAGS.trial_hparams_filename,
        hparam_overrides=hparam_overrides)

    if jax.process_index() == 0:
        logging.info('Merged hps are: %s', json.dumps(merged_hps.to_json()))

    evaluator = bleu_evaluator.BLEUEvaluator(FLAGS.checkpoint_dir, merged_hps,
                                             rng, model_class, dataset_builder,
                                             dataset_meta_data, mt_eval_config)
    evaluator.translate_and_calculate_bleu()
Example #29
def bounds_from_last_device(device):
    # Must be passed the device at the highest-coordinate corner of the
    # relevant mesh, which is a requirement we know is satisfied by the last
    # device in jax.devices().
    if hasattr(device, 'coords'):
        x, y, z = device.coords
        return x + 1, y + 1, z + 1, device.id % 2 + 1
    else:
        # On non-TPU platforms, the "mesh" is hosts x devices per host in order
        # to take advantage of faster within-host interconnect.
        return jax.host_count(), jax.local_device_count()
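Usage sketch: the helper expects the last entry of jax.devices(); the shape of the result differs between TPU and other platforms, as the comments note.

import jax

mesh_bounds = bounds_from_last_device(jax.devices()[-1])
# On TPU: (x, y, z, cores-per-chip) bounds; elsewhere: (host_count, local_device_count).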
Example #30
def harmonize_across_hosts(optimizer):
    """Ensure that model and optimizer parameters are identical for all hosts."""
    if jax.host_count() == 1:
        return optimizer
    else:
        selector = jnp.zeros(jax.local_device_count())
        if jax.host_id() == 0:
            selector = jax.ops.index_update(selector, 0, 1.0)
        optimizer = jax.pmap(lambda opt, sel: jax.tree_map(
            lambda x: jax.lax.psum(x * sel.astype(x.dtype), 'i'), opt),
                             axis_name='i')(optimizer, selector)
        return optimizer
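Note that jax.ops.index_update has been removed in later JAX releases (and jax.host_id/jax.host_count were renamed to jax.process_index/jax.process_count); on current versions the selector line can be written with .at[].set() instead. A sketch, behaviour unchanged:

import jax
import jax.numpy as jnp

selector = jnp.zeros(jax.local_device_count())
if jax.process_index() == 0:
    selector = selector.at[0].set(1.0)   # replaces jax.ops.index_update(selector, 0, 1.0)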