def create_input_iter(batch_size, image_size, dtype, train, cache):
  image_shape = (batch_size, image_size, image_size, 3)
  fake_image = np.random.rand(*image_shape)
  fake_image = fake_image.astype(dtype.as_numpy_dtype)
  fake_image = fake_image.reshape(
      (jax.local_device_count(), -1) + fake_image.shape[1:])
  fake_label = np.random.randint(1, 1000, (batch_size,))
  fake_label = fake_label.astype(np.int32)
  fake_label = fake_label.reshape((jax.local_device_count(), -1))
  fake_batch = {'image': fake_image, 'label': fake_label}
  it = itertools.repeat(fake_batch)
  it = jax_utils.prefetch_to_device(it, 2)
  return it
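# Usage sketch for the synthetic iterator above (not from the original
# source; argument values are hypothetical). `batch_size` must be divisible
# by jax.local_device_count(), and `dtype` must be a TF dtype because the
# function calls `dtype.as_numpy_dtype`.
import itertools

import jax
import numpy as np
import tensorflow as tf
from flax import jax_utils

fake_it = create_input_iter(
    batch_size=8 * jax.local_device_count(),
    image_size=224,
    dtype=tf.float32,
    train=True,
    cache=False)
fake_batch = next(fake_it)
# fake_batch['image']: (local_device_count, 8, 224, 224, 3)
# fake_batch['label']: (local_device_count, 8)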
def create_input_iter(split, dataset_builder, rng, global_batch_size, mean_rgb,
                      stddev_rgb, image_size, resize_size, aspect_ratio_range,
                      area_range, train, cache, repeat_final_dataset,
                      num_batches):
  ds = create_split(
      split,
      dataset_builder,
      rng,
      global_batch_size,
      train=train,
      dtype=tf.float32,
      image_size=image_size,
      resize_size=resize_size,
      mean_rgb=mean_rgb,
      stddev_rgb=stddev_rgb,
      cache=cache,
      repeat_final_dataset=repeat_final_dataset,
      num_batches=num_batches,
      aspect_ratio_range=aspect_ratio_range,
      area_range=area_range)
  it = map(shard_numpy_ds, ds)
  # Note(Dan S): On an Nvidia 2080 Ti GPU, this increased GPU utilization
  # by 10%.
  it = jax_utils.prefetch_to_device(it, 2)
  return it
def iterator_from_dataset(
    dataset: tf.data.Dataset,
    batch_size: int,
    repeat: bool = True,
    prefetch_size: int = 0,
    devices: Optional[Sequence[Any]] = None,
):
  """Create a data iterator that returns JAX arrays from a TF dataset.

  Args:
    dataset: the dataset to iterate over.
    batch_size: the batch size the iterator should return.
    repeat: whether the iterator should repeat the dataset.
    prefetch_size: the number of batches to prefetch to device.
    devices: the devices to prefetch to.

  Returns:
    An iterator that returns data batches.
  """
  if repeat:
    dataset = dataset.repeat()
  if batch_size > 0:
    dataset = dataset.batch(batch_size)
    it = map(prepare_tf_data, dataset)
  else:
    it = map(prepare_tf_data_unbatched, dataset)
  if prefetch_size > 0:
    it = jax_utils.prefetch_to_device(it, prefetch_size, devices)
  return it
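# Usage sketch (not from the original source). `prepare_tf_data` and
# `prepare_tf_data_unbatched` are referenced above but not defined in this
# snippet; they are assumed to convert TF tensors to NumPy arrays and shard
# them across local devices, as the other snippets here do.
toy_ds = tf.data.Dataset.from_tensor_slices(
    {'x': tf.zeros((128, 32), tf.float32)})
toy_iter = iterator_from_dataset(
    toy_ds,
    batch_size=4 * jax.local_device_count(),
    repeat=True,
    prefetch_size=2,
    devices=jax.local_devices())
toy_batch = next(toy_iter)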
def create_input_iter(dataset_builder, batch_size, image_size, dtype, train,
                      cache):
  ds = input_pipeline.create_split(
      dataset_builder,
      batch_size,
      image_size=image_size,
      dtype=dtype,
      train=train,
      cache=cache)
  it = map(prepare_tf_data, ds)
  it = jax_utils.prefetch_to_device(it, 2)
  return it
def prepare_train_data(dataset):
  """Convert input batches from TF tensors to NumPy arrays and shard them."""
  local_device_count = jax.local_device_count()

  def _prepare(x):
    x = x._numpy()  # pylint: disable=protected-access
    # Reshape (host_batch_size, ...) to (local_devices, device_batch_size, ...).
    return x.reshape((local_device_count, -1) + x.shape[1:])

  it = map(lambda x: jax.tree_map(_prepare, x), dataset)
  return jax_utils.prefetch_to_device(it, 2)
def create_input_iter(batch_size, data_dir, image_size, dtype, train, cache):
  """Creates input data iterator."""
  ds = input_pipeline.load_split(
      batch_size,
      data_dir=data_dir,
      image_size=image_size,
      dtype=dtype,
      train=train,
      cache=cache)
  it = map(prepare_tf_data, ds)
  it = jax_utils.prefetch_to_device(it, 2)
  return it
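# `prepare_tf_data`, used by several of the snippets above, is not shown.
# A plausible definition (an assumption, modeled on `prepare_train_data`
# above and `create_data_iter` below) would be:
def prepare_tf_data(xs):
  """Convert TF tensors to NumPy arrays and shard across local devices."""
  local_device_count = jax.local_device_count()

  def _prepare(x):
    x = x._numpy()  # pylint: disable=protected-access
    # (host_batch_size, ...) -> (local_devices, device_batch_size, ...).
    return x.reshape((local_device_count, -1) + x.shape[1:])

  return jax.tree_map(_prepare, xs)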
def prefetch_input_pipeline(ds, n_prefetch=0, devices=None):
  """Modify input pipeline to prefetch from host to device.

  Args:
    ds: tf.data pipeline
    n_prefetch: number of items to prefetch
    devices: devices to prefetch to

  Returns:
    prefetching ds
  """
  it = iter(ds)
  it = (data_utils.shard(x) for x in it)
  if n_prefetch > 0:
    it = jax_utils.prefetch_to_device(it, n_prefetch, devices=devices)
  return it
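# `data_utils.shard` above is assumed to behave like
# flax.training.common_utils.shard, which reshapes each leaf's leading batch
# axis into (local_device_count, -1). A quick standalone check:
from flax.training import common_utils

x = np.zeros((16, 4), np.float32)
sharded = common_utils.shard(x)
# sharded.shape == (jax.local_device_count(),
#                   16 // jax.local_device_count(), 4)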
def create_input_iter(dataset_builder, batch_size, mean_rgb, stddev_rgb,
                      image_size, resize_size, aspect_ratio_range, area_range,
                      train, cache):
  ds = create_split(
      dataset_builder,
      batch_size,
      train=train,
      image_size=image_size,
      resize_size=resize_size,
      mean_rgb=mean_rgb,
      stddev_rgb=stddev_rgb,
      cache=cache,
      aspect_ratio_range=aspect_ratio_range,
      area_range=area_range)
  it = map(shard_numpy_ds, ds)
  # Note(Dan S): On an Nvidia 2080 Ti GPU, this increased GPU utilization
  # by 10%.
  it = jax_utils.prefetch_to_device(it, 2)
  return it
def train(data_path):
  nonlocal iterator
  nonlocal state
  if iterator is None:
    dataset = npz.load_dataset_from_directory(data_path, duration, batch)
    iterator = dataset.make_one_shot_iterator()
    # Shard each batch: (host_batch, ...) -> (local_devices, device_batch, ...).
    iterator = map(
        lambda tree: jax.tree_map(
            lambda x: np.reshape(
                x, (jax.local_device_count(), -1) + x.numpy().shape[1:]),
            tree),
        iterator)
    iterator = jax_utils.prefetch_to_device(iterator, 2)
  for _ in range(train_steps):
    obs = next(iterator)
    state, l = train_step(obs, state)
  local_state = get_first_device(state)
  l = get_first_device(l)
  checkpoints.save_checkpoint(model_dir, local_state, local_state.step)
def main(unused_argv):
  rng = random.PRNGKey(20200823)
  # Shift the numpy random seed by host_id() to shuffle data loaded by
  # different hosts.
  np.random.seed(20201473 + jax.host_id())

  if FLAGS.config is not None:
    utils.update_flags(FLAGS)
  if FLAGS.batch_size % jax.device_count() != 0:
    raise ValueError("Batch size must be divisible by the number of devices.")
  if FLAGS.train_dir is None:
    raise ValueError("train_dir must be set. None set now.")
  if FLAGS.data_dir is None:
    raise ValueError("data_dir must be set. None set now.")

  dataset = datasets.get_dataset("train", FLAGS)
  test_dataset = datasets.get_dataset("test", FLAGS)

  test_render_fn = jax.pmap(
      # Note rng_keys are useless in eval mode since there's no randomness.
      # pylint: disable=g-long-lambda
      lambda key_0, key_1, model, rays: jax.lax.all_gather(
          model(key_0, key_1, *rays), axis_name="batch"),
      in_axes=(None, None, None, 0),  # Only distribute the data input.
      donate_argnums=3,
      axis_name="batch",
  )

  rng, key = random.split(rng)
  init_model, init_state = models.get_model(key, dataset.peek(), FLAGS)
  optimizer_def = optim.Adam(FLAGS.lr_init)
  optimizer = optimizer_def.create(init_model)
  state = model_utils.TrainState(
      step=0, optimizer=optimizer, model_state=init_state)
  if not utils.isdir(FLAGS.train_dir):
    utils.makedirs(FLAGS.train_dir)
  state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
  offset = state.step + 1
  state = jax_utils.replicate(state)
  del init_model, init_state

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(FLAGS.train_dir)

  t_loop_start = time.time()
  learning_rate_fn = functools.partial(
      utils.learning_rate_decay,
      lr_init=FLAGS.lr_init,
      lr_final=FLAGS.lr_final,
      max_steps=FLAGS.max_steps,
      lr_delay_steps=FLAGS.lr_delay_steps,
      lr_delay_mult=FLAGS.lr_delay_mult)
  ptrain_step = jax.pmap(
      train_step,
      axis_name="batch",
      in_axes=(0, 0, 0, None),
      donate_argnums=2)

  # Prefetch_buffer_size = 3 x batch_size.
  pdataset = jax_utils.prefetch_to_device(dataset, 3)
  n_local_devices = jax.local_device_count()
  rng = rng + jax.host_id()  # Make random seed separate across hosts.
  keys = random.split(rng, n_local_devices)  # For pmapping RNG keys.
  gc.disable()  # Disable automatic garbage collection for efficiency.
  stats_trace = []
  for step, batch in zip(range(offset, FLAGS.max_steps + 1), pdataset):
    lr = learning_rate_fn(step)
    state, stats, keys = ptrain_step(keys, state, batch, lr)
    if jax.host_id() == 0:
      stats_trace.append(stats[0])
    if step % FLAGS.gc_every == 0:
      gc.collect()

    # --- Train logs start ---
    # Put the training time visualization before the host_id check as in
    # multi-host evaluation, all hosts need to run inference even though we
    # only use host 0 to record results.
    if FLAGS.render_every > 0 and step % FLAGS.render_every == 0:
      # We reuse the same random number generator from the optimization step
      # here on purpose so that the visualization matches what happened in
      # training.
      state_to_eval = jax.device_get(jax.tree_map(lambda x: x[0], state))
      test_case = next(test_dataset)
      pred_color, pred_disp, pred_acc = utils.render_image(
          state_to_eval,
          test_case["rays"],
          test_render_fn,
          keys[0],
          FLAGS.dataset == "llff",
          chunk=FLAGS.chunk)
      if jax.host_id() == 0:
        psnr = utils.compute_psnr(
            ((pred_color - test_case["pixels"])**2).mean())
        summary_writer.scalar("test_psnr", psnr, step)
        summary_writer.image("test_pred_color", pred_color, step)
        summary_writer.image("test_pred_disp", pred_disp, step)
        summary_writer.image("test_pred_acc", pred_acc, step)
        summary_writer.image("test_target", test_case["pixels"], step)

    if jax.host_id() != 0:
      # Only log via host 0.
      continue
    if step % FLAGS.print_every == 0:
      summary_writer.scalar("train_loss", stats[0].loss[0], step)
      summary_writer.scalar("train_psnr", stats[0].psnr[0], step)
      if len(stats) > 1:
        summary_writer.scalar("train_loss_coarse", stats[1].loss[0], step)
        summary_writer.scalar("train_psnr_coarse", stats[1].psnr[0], step)
      avg_loss = np.mean(np.concatenate([s.loss for s in stats_trace]))
      avg_psnr = np.mean(np.concatenate([s.psnr for s in stats_trace]))
      stats_trace = []
      summary_writer.scalar("train_avg_loss", avg_loss, step)
      summary_writer.scalar("train_avg_psnr", avg_psnr, step)
      summary_writer.scalar("learning_rate", lr, step)
      steps_per_sec = FLAGS.print_every / (time.time() - t_loop_start)
      t_loop_start = time.time()
      rays_per_sec = FLAGS.batch_size * steps_per_sec
      summary_writer.scalar("steps_per_sec", steps_per_sec, step)
      summary_writer.scalar("rays_per_sec", rays_per_sec, step)
      precision = int(np.ceil(np.log10(FLAGS.max_steps))) + 1
      print(("{:" + "{:d}".format(precision) + "d}").format(step) +
            f"/{FLAGS.max_steps:d}: " +
            f"i_loss={stats[0].loss[0]:0.5f}, " +
            f"avg_loss={avg_loss:0.5f}, " +
            f"lr={lr:0.2e}, " +
            f"{rays_per_sec:0.3f} rays/sec")
    if step % FLAGS.save_every == 0:
      state_to_save = jax.device_get(jax.tree_map(lambda x: x[0], state))
      checkpoints.save_checkpoint(
          FLAGS.train_dir, state_to_save, state_to_save.step, keep=100)
    # --- Train logs end ---

  if FLAGS.max_steps % FLAGS.save_every != 0:
    state = jax.device_get(jax.tree_map(lambda x: x[0], state))
    checkpoints.save_checkpoint(
        FLAGS.train_dir, state, int(state.step), keep=100)
def create_data_iter(self, ds, batch_size):
  """Create an iterator from a tf dataset.

  Args:
    ds: tfds dataset; Dataset which we want to build an iterator for.
    batch_size: int; Batch size for the given dataset split.

  Returns:
    Data iter for the given dataset.
  """
  data_iter = iter(ds)

  def prepare_tf_data(xs):
    """Reshapes input batch."""

    def _prepare(x):
      # Reshape (host_batch_size, height, width, 3) to
      # (local_devices, device_batch_size, height, width, 3).
      return x.reshape((self.num_shards, -1) + x.shape[1:])

    return jax.tree_map(_prepare, xs)

  def to_numpy(xs):
    """Convert an input batch from tf Tensors to numpy arrays."""
    return jax.tree_map(
        lambda x: x._numpy(),  # pylint: disable=protected-access
        xs)

  def maybe_pad_batch(batch):
    """Zero pad the batch on the right to the batch_size.

    Args:
      batch: dict; A dictionary mapping keys to arrays. We assume that
        'inputs' is one of the keys.

    Returns:
      A dictionary mapping the same keys to the padded batches. Additionally
      we add a key representing weights, to indicate how the batch was padded.
    """
    batch_pad = batch_size - batch['inputs'].shape[0]
    unpadded_mask_shape = batch['inputs'].shape[0]
    # Most batches will not need padding, so we quickly return to avoid
    # slowdown.
    if batch_pad == 0:
      if 'weights' not in batch:
        batch['weights'] = onp.ones(unpadded_mask_shape, dtype=onp.float32)
      return batch

    def zero_pad(array):
      pad_with = [(0, batch_pad)] + [(0, 0)] * (array.ndim - 1)
      return onp.pad(array, pad_with, mode='constant')

    padded_batch = jax.tree_map(zero_pad, batch)
    padded_batch_mask = zero_pad(
        onp.ones(unpadded_mask_shape, dtype=onp.float32))
    if 'weights' in padded_batch:
      padded_batch['weights'] *= padded_batch_mask
    else:
      padded_batch['weights'] = padded_batch_mask
    return padded_batch

  it = map(to_numpy, data_iter)
  it = map(maybe_pad_batch, it)
  it = map(prepare_tf_data, it)
  it = jax_utils.prefetch_to_device(it, 2)
  return it
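# Worked toy example (hypothetical shapes, not from the original source)
# replaying the `maybe_pad_batch` logic above standalone, to show the zero
# padding and the added `weights` mask:
import numpy as onp

target_batch_size = 4
toy_batch = {'inputs': onp.ones((3, 2), onp.float32)}
batch_pad = target_batch_size - toy_batch['inputs'].shape[0]  # == 1

def zero_pad(array):
  pad_with = [(0, batch_pad)] + [(0, 0)] * (array.ndim - 1)
  return onp.pad(array, pad_with, mode='constant')

padded = {k: zero_pad(v) for k, v in toy_batch.items()}
padded['weights'] = zero_pad(onp.ones(3, onp.float32))
# padded['inputs'].shape == (4, 2)
# padded['weights'] -> array([1., 1., 1., 0.], dtype=float32)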