def create_input_iter(batch_size, image_size, dtype, train, cache):
    """Create an endless, device-prefetched iterator that repeats one fake batch.

    A single random image/label batch is generated once and yielded forever via
    ``itertools.repeat``.

    Args:
      batch_size: number of examples in the batch; it is split across
        ``jax.local_device_count()`` devices, so it must be divisible by it.
      image_size: height and width of the square fake RGB images.
      dtype: dtype object exposing ``as_numpy_dtype`` (presumably a TF DType).
      train: unused here; kept for signature parity with the real pipeline.
      cache: unused here; kept for signature parity with the real pipeline.

    Returns:
      The iterator produced by ``jax_utils.prefetch_to_device`` over dicts with
      keys ``'image'`` of shape ``(devices, per_device, H, W, 3)`` and
      ``'label'`` of shape ``(devices, per_device)``.
    """
    n_devices = jax.local_device_count()

    image_shape = (batch_size, image_size, image_size, 3)
    fake_image = np.random.rand(*image_shape).astype(dtype.as_numpy_dtype)
    # Split the leading batch axis into (local_devices, per_device_batch).
    fake_image = fake_image.reshape((n_devices, -1) + fake_image.shape[1:])

    fake_label = np.random.randint(1, 1000, (batch_size,)).astype(np.int32)
    fake_label = fake_label.reshape((n_devices, -1))

    fake_batch = {'image': fake_image, 'label': fake_label}
    it = itertools.repeat(fake_batch)
    # Keep the prefetching iterator: previously its return value was dropped,
    # so callers received the plain (non-prefetched) host iterator.
    it = jax_utils.prefetch_to_device(it, 2)
    return it
Example #2
0
def create_input_iter(split, dataset_builder, rng, global_batch_size, mean_rgb,
                      stddev_rgb, image_size, resize_size, aspect_ratio_range,
                      area_range, train, cache, repeat_final_dataset,
                      num_batches):
    """Build a device-prefetched iterator over sharded batches of one split.

    All preprocessing options are forwarded to ``create_split`` (images are
    requested as ``tf.float32``). Each resulting batch is passed through
    ``shard_numpy_ds`` and then prefetched with ``jax_utils.prefetch_to_device``
    using a buffer of 2.

    Returns:
      The prefetching iterator over sharded batches.
    """
    split_options = dict(
        train=train,
        dtype=tf.float32,
        image_size=image_size,
        resize_size=resize_size,
        mean_rgb=mean_rgb,
        stddev_rgb=stddev_rgb,
        cache=cache,
        repeat_final_dataset=repeat_final_dataset,
        num_batches=num_batches,
        aspect_ratio_range=aspect_ratio_range,
        area_range=area_range,
    )
    split_ds = create_split(split, dataset_builder, rng, global_batch_size,
                            **split_options)
    sharded_batches = map(shard_numpy_ds, split_ds)

    # Note(Dan S): On a Nvidia 2080 Ti GPU, this increased GPU utilization by 10%.
    return jax_utils.prefetch_to_device(sharded_batches, 2)
Example #3
0
def iterator_from_dataset(
    dataset: tf.data.Dataset,
    batch_size: int,
    repeat: bool = True,
    prefetch_size: int = 0,
    devices: Optional[Sequence[Any]] = None,
):
    """Turn a TF dataset into an iterator of JAX-ready batches.

    Args:
      dataset: the ``tf.data`` pipeline to read from.
      batch_size: when positive, the dataset is batched to this size and each
        batch goes through ``prepare_tf_data``; otherwise elements are passed
        through ``prepare_tf_data_unbatched`` without batching.
      repeat: if true, the dataset is repeated indefinitely.
      prefetch_size: when positive, that many batches are prefetched with
        ``jax_utils.prefetch_to_device``; otherwise no device prefetching.
      devices: forwarded to ``jax_utils.prefetch_to_device``.

    Returns:
      An iterator yielding prepared data batches.
    """
    if repeat:
        dataset = dataset.repeat()

    if batch_size > 0:
        dataset = dataset.batch(batch_size)
        convert = prepare_tf_data
    else:
        convert = prepare_tf_data_unbatched
    batches = map(convert, dataset)

    if prefetch_size > 0:
        return jax_utils.prefetch_to_device(batches, prefetch_size, devices)
    return batches
Example #4
0
def create_input_iter(dataset_builder, batch_size, image_size, dtype, train,
                      cache):
  """Return a device-prefetched iterator over one prepared dataset split.

  The split comes from ``input_pipeline.create_split``; each batch is mapped
  through ``prepare_tf_data`` and prefetched with a buffer of 2.
  """
  split_ds = input_pipeline.create_split(
      dataset_builder,
      batch_size,
      image_size=image_size,
      dtype=dtype,
      train=train,
      cache=cache,
  )
  return jax_utils.prefetch_to_device(map(prepare_tf_data, split_ds), 2)
Example #5
0
def prepare_train_data(dataset):
    """Convert TF-tensor batches to NumPy and prefetch them to local devices.

    Every leaf tensor of each element is converted with ``_numpy()`` and its
    leading axis is reshaped to ``jax.local_device_count()``; the stream is then
    prefetched with a buffer of 2.
    """
    n_devices = jax.local_device_count()

    def _to_device_major(tensor):
        # NOTE(review): this keeps shape[1:] and replaces only the leading axis
        # with n_devices, so it assumes the leading axis already equals the
        # local device count (no -1 split as in sibling helpers) — confirm.
        array = tensor._numpy()  # pylint: disable=protected-access
        return array.reshape((n_devices,) + array.shape[1:])

    def _convert_batch(batch):
        return jax.tree_map(_to_device_major, batch)

    return jax_utils.prefetch_to_device(map(_convert_batch, dataset), 2)
Example #6
0
def create_input_iter(batch_size, data_dir, image_size, dtype, train, cache):
    """Creates input data iterator.

    Loads a split with ``input_pipeline.load_split``, maps every batch through
    ``prepare_tf_data`` and prefetches the result with a buffer of 2.
    """
    loaded = input_pipeline.load_split(
        batch_size,
        data_dir=data_dir,
        image_size=image_size,
        dtype=dtype,
        train=train,
        cache=cache,
    )
    prepared = map(prepare_tf_data, loaded)
    return jax_utils.prefetch_to_device(prepared, 2)
Example #7
0
def prefetch_input_pipeline(ds, n_prefetch=0, devices=None):
    """Wrap an input pipeline so batches are sharded and optionally prefetched.

    Args:
      ds: tf.data pipeline (any iterable works; ``iter(ds)`` is called eagerly).
      n_prefetch: number of items to prefetch to device; values <= 0 disable
        device prefetching.
      devices: devices to prefetch to.

    Returns:
      A lazy iterator of ``data_utils.shard`` outputs, wrapped by
      ``jax_utils.prefetch_to_device`` when ``n_prefetch > 0``.
    """
    sharded = (data_utils.shard(batch) for batch in iter(ds))
    if n_prefetch > 0:
        return jax_utils.prefetch_to_device(sharded, n_prefetch,
                                            devices=devices)
    return sharded
def create_input_iter(dataset_builder, batch_size, mean_rgb, stddev_rgb,
                      image_size, resize_size, aspect_ratio_range, area_range,
                      train, cache):
    """Build a device-prefetched iterator over sharded batches of one split.

    Preprocessing options are forwarded to ``create_split``; every batch is
    passed through ``shard_numpy_ds`` and prefetched with a buffer of 2.
    """
    preprocessing = {
        'image_size': image_size,
        'resize_size': resize_size,
        'mean_rgb': mean_rgb,
        'stddev_rgb': stddev_rgb,
        'aspect_ratio_range': aspect_ratio_range,
        'area_range': area_range,
    }
    split_ds = create_split(dataset_builder, batch_size, train=train,
                            cache=cache, **preprocessing)
    sharded = map(shard_numpy_ds, split_ds)

    # Note(Dan S): On a Nvidia 2080 Ti GPU, this increased GPU utilization by 10%
    return jax_utils.prefetch_to_device(sharded, 2)
Example #9
0
    def train(data_path):
        nonlocal iterator
        nonlocal state

        if iterator is None:
            dataset = npz.load_dataset_from_directory(data_path, duration,
                                                      batch)
            iterator = dataset.make_one_shot_iterator()
            iterator = map(
                lambda x: jax.tree_map(
                    lambda x: np.reshape(x, (jax.local_device_count(), -1) + x.
                                         numpy().shape[1:]), x), iterator)
            iterator = jax_utils.prefetch_to_device(iterator, 2)

        for _ in range(train_steps):
            obs = next(iterator)
            state, l = train_step(obs, state)
        local_state = get_first_device(state)
        l = get_first_device(l)
        checkpoints.save_checkpoint(model_dir, local_state, local_state.step)
Example #10
0
def main(unused_argv):
    """Train a model with pmap data parallelism, logging and checkpointing.

    All configuration comes from ``FLAGS``. The function validates the batch
    size and directories, builds the train/test datasets and the model, and
    restores ``state`` via ``checkpoints.restore_checkpoint(FLAGS.train_dir)``.
    It then runs ``ptrain_step`` for steps ``state.step + 1`` through
    ``FLAGS.max_steps`` (inclusive).

    Host behavior:
      - Inside the loop, only host 0 writes TensorBoard summaries, prints
        progress and saves checkpoints.
      - Every host takes part in the periodic test-image render.
      - The final save after the loop has no host guard, so it runs on every
        host.

    Args:
      unused_argv: ignored (presumably the leftover argv from the app launcher).

    Raises:
      ValueError: if ``FLAGS.batch_size`` is not divisible by
        ``jax.device_count()``, or ``FLAGS.train_dir`` / ``FLAGS.data_dir`` is
        not set.
    """
    rng = random.PRNGKey(20200823)
    # Shift the numpy random seed by host_id() to shuffle data loaded by different
    # hosts.
    np.random.seed(20201473 + jax.host_id())

    if FLAGS.config is not None:
        utils.update_flags(FLAGS)
    if FLAGS.batch_size % jax.device_count() != 0:
        raise ValueError(
            "Batch size must be divisible by the number of devices.")
    if FLAGS.train_dir is None:
        raise ValueError("train_dir must be set. None set now.")
    if FLAGS.data_dir is None:
        raise ValueError("data_dir must be set. None set now.")
    dataset = datasets.get_dataset("train", FLAGS)
    test_dataset = datasets.get_dataset("test", FLAGS)
    # Rendering function: the two keys and the model are broadcast (in_axes
    # None); only the ray chunk is split across devices, and the per-device
    # outputs are all-gathered back along the "batch" axis.
    test_render_fn = jax.pmap(
        # Note rng_keys are useless in eval mode since there's no randomness.
        # pylint: disable=g-long-lambda
        lambda key_0, key_1, model, rays: jax.lax.all_gather(
            model(key_0, key_1, *rays), axis_name="batch"),
        in_axes=(None, None, None, 0),  # Only distribute the data input.
        donate_argnums=3,
        axis_name="batch",
    )
    rng, key = random.split(rng)
    # The model is initialized from a peeked (not consumed) training batch.
    init_model, init_state = models.get_model(key, dataset.peek(), FLAGS)
    optimizer_def = optim.Adam(FLAGS.lr_init)
    optimizer = optimizer_def.create(init_model)
    state = model_utils.TrainState(step=0,
                                   optimizer=optimizer,
                                   model_state=init_state)
    if not utils.isdir(FLAGS.train_dir):
        utils.makedirs(FLAGS.train_dir)
    state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
    # Resume at the step after the restored one (step 1 on a fresh start).
    offset = state.step + 1
    # One copy of the state per local device for the pmapped train step.
    state = jax_utils.replicate(state)
    del init_model, init_state

    # summary_writer exists only on host 0; every use below is guarded by a
    # host_id() == 0 check.
    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(FLAGS.train_dir)
    t_loop_start = time.time()
    learning_rate_fn = functools.partial(utils.learning_rate_decay,
                                         lr_init=FLAGS.lr_init,
                                         lr_final=FLAGS.lr_final,
                                         max_steps=FLAGS.max_steps,
                                         lr_delay_steps=FLAGS.lr_delay_steps,
                                         lr_delay_mult=FLAGS.lr_delay_mult)

    # keys, state and batch are mapped over devices; lr (in_axes None) is
    # broadcast. donate_argnums=2 lets XLA reuse the batch buffers.
    ptrain_step = jax.pmap(train_step,
                           axis_name="batch",
                           in_axes=(0, 0, 0, None),
                           donate_argnums=2)
    # Prefetch_buffer_size = 3 x batch_size
    pdataset = jax_utils.prefetch_to_device(dataset, 3)
    # NOTE(review): "deices" is a typo for "devices" (local name only).
    n_local_deices = jax.local_device_count()
    rng = rng + jax.host_id()  # Make random seed separate across hosts.
    keys = random.split(rng, n_local_deices)  # For pmapping RNG keys.
    # Automatic GC stays disabled for the rest of this function; it is run
    # manually every FLAGS.gc_every steps instead.
    gc.disable()  # Disable automatic garbage collection for efficiency.
    # Host-0 per-step stats, averaged and cleared at each print interval.
    stats_trace = []
    # Steps offset..max_steps inclusive; zip stops early if pdataset runs out.
    for step, batch in zip(range(offset, FLAGS.max_steps + 1), pdataset):
        lr = learning_rate_fn(step)
        state, stats, keys = ptrain_step(keys, state, batch, lr)
        if jax.host_id() == 0:
            stats_trace.append(stats[0])
        if step % FLAGS.gc_every == 0:
            gc.collect()
        # --- Train logs start ---
        # Put the training time visualization before the host_id check as in
        # multi-host evaluation, all hosts need to run inference even though we
        # only use host 0 to record results.
        if FLAGS.render_every > 0 and step % FLAGS.render_every == 0:
            # We reuse the same random number generator from the optimization step
            # here on purpose so that the visualization matches what happened in
            # training.
            state_to_eval = jax.device_get(jax.tree_map(lambda x: x[0], state))
            test_case = next(test_dataset)
            pred_color, pred_disp, pred_acc = utils.render_image(
                state_to_eval,
                test_case["rays"],
                test_render_fn,
                keys[0],
                FLAGS.dataset == "llff",
                chunk=FLAGS.chunk)
            if jax.host_id() == 0:
                psnr = utils.compute_psnr(
                    ((pred_color - test_case["pixels"])**2).mean())
                summary_writer.scalar("test_psnr", psnr, step)
                summary_writer.image("test_pred_color", pred_color, step)
                summary_writer.image("test_pred_disp", pred_disp, step)
                summary_writer.image("test_pred_acc", pred_acc, step)
                summary_writer.image("test_target", test_case["pixels"], step)
        if jax.host_id() != 0:  # Only log via host 0.
            continue
        if step % FLAGS.print_every == 0:
            summary_writer.scalar("train_loss", stats[0].loss[0], step)
            summary_writer.scalar("train_psnr", stats[0].psnr[0], step)
            if len(stats) > 1:
                summary_writer.scalar("train_loss_coarse", stats[1].loss[0],
                                      step)
                summary_writer.scalar("train_psnr_coarse", stats[1].psnr[0],
                                      step)
            avg_loss = np.mean(np.concatenate([s.loss for s in stats_trace]))
            avg_psnr = np.mean(np.concatenate([s.psnr for s in stats_trace]))
            stats_trace = []
            summary_writer.scalar("train_avg_loss", avg_loss, step)
            summary_writer.scalar("train_avg_psnr", avg_psnr, step)
            summary_writer.scalar("learning_rate", lr, step)
            # NOTE(review): the first interval after a resume may cover fewer
            # than print_every steps, so this rate is approximate then.
            steps_per_sec = FLAGS.print_every / (time.time() - t_loop_start)
            t_loop_start = time.time()
            rays_per_sec = FLAGS.batch_size * steps_per_sec
            summary_writer.scalar("steps_per_sec", steps_per_sec, step)
            summary_writer.scalar("rays_per_sec", rays_per_sec, step)
            # Width used to right-align the step counter in the console line.
            precision = int(np.ceil(np.log10(FLAGS.max_steps))) + 1
            print(("{:" + "{:d}".format(precision) + "d}").format(step) +
                  f"/{FLAGS.max_steps:d}: " +
                  f"i_loss={stats[0].loss[0]:0.5f}, " +
                  f"avg_loss={avg_loss:0.5f}, " + f"lr={lr:0.2e}, " +
                  f"{rays_per_sec:0.3f} rays/sec")
        if step % FLAGS.save_every == 0:
            # Checkpoint the first device's replica as a host-side copy.
            state_to_save = jax.device_get(jax.tree_map(lambda x: x[0], state))
            checkpoints.save_checkpoint(FLAGS.train_dir,
                                        state_to_save,
                                        state_to_save.step,
                                        keep=100)
        # --- Train logs end ---

    # Final save when max_steps is not a multiple of save_every (otherwise the
    # in-loop save would already have covered the last step). No host guard
    # here, so every host executes it.
    if FLAGS.max_steps % FLAGS.save_every != 0:
        state = jax.device_get(jax.tree_map(lambda x: x[0], state))
        checkpoints.save_checkpoint(FLAGS.train_dir,
                                    state,
                                    int(state.step),
                                    keep=100)
Example #11
0
    def create_data_iter(self, ds, batch_size):
        """Create an iterator of padded, sharded, device-prefetched batches.

        Args:
          ds: tfds dataset; Dataset which we want to build an iterator for.
          batch_size: int; Batch size for the given dataset split. Batches with
            fewer rows are zero-padded up to this size.

        Returns:
          Data iter for the given dataset. Each element is a dict whose arrays
          are reshaped to ``(self.num_shards, -1, ...)`` and which always has a
          ``'weights'`` entry that is 0 on padded rows.

        Raises:
          ValueError: (while iterating) if a batch has more than ``batch_size``
            rows.
        """
        data_iter = iter(ds)

        def prepare_tf_data(xs):
            """Reshapes input batch."""
            def _prepare(x):
                # reshape (host_batch_size, height, width, 3) to
                # (local_devices, device_batch_size, height, width, 3)
                return x.reshape((self.num_shards, -1) + x.shape[1:])

            return jax.tree_map(_prepare, xs)

        def to_numpy(xs):
            """Convert an input batch from tf Tensors to numpy arrays."""
            return jax.tree_map(
                lambda x: x._numpy(),  # pylint: disable=protected-access
                xs)

        def maybe_pad_batch(batch):
            """Zero pad the batch on the right (along axis 0) to batch_size.

            Args:
              batch: dict; A dictionary mapping keys to arrays. We assume that
                'inputs' is one of the keys.

            Returns:
              A dictionary mapping the same keys to the padded batches, plus a
              'weights' key (added if absent) that is zero on padded rows.

            Raises:
              ValueError: if the batch already has more than batch_size rows.
            """
            n_real = batch['inputs'].shape[0]
            batch_pad = batch_size - n_real
            if batch_pad < 0:
                raise ValueError(
                    f'Batch of size {n_real} exceeds batch_size={batch_size}.')

            # Most batches will not need padding so we quickly return to avoid
            # slowdown.
            if batch_pad == 0:
                if 'weights' not in batch:
                    batch['weights'] = onp.ones(n_real, dtype=onp.float32)
                return batch

            def zero_pad(array):
                pad_with = [(0, batch_pad)] + [(0, 0)] * (array.ndim - 1)
                return onp.pad(array, pad_with, mode='constant')

            padded_batch = jax.tree_map(zero_pad, batch)
            if 'weights' not in padded_batch:
                padded_batch['weights'] = zero_pad(
                    onp.ones(n_real, dtype=onp.float32))
            # An existing 'weights' entry was zero-padded by the tree_map
            # above: padded rows are already 0 and real rows are unchanged.
            # The previous `weights *= 1-D mask` was a no-op for 1-D weights
            # but broadcast against the LAST axis for higher-rank weights
            # (error, or silently wrong when that axis matched batch_size).
            return padded_batch

        it = map(to_numpy, data_iter)
        it = map(maybe_pad_batch, it)
        it = map(prepare_tf_data, it)
        it = jax_utils.prefetch_to_device(it, 2)
        return it