def create_split(dataset_builder: tfds.core.DatasetBuilder,
                 batch_size: int,
                 train: bool,
                 dtype: tf.DType = tf.float32,
                 image_size: int = IMAGE_SIZE,
                 cache: bool = False):
  """Creates a split from the ImageNet dataset using TensorFlow Datasets.

  Args:
    dataset_builder: TFDS dataset builder for ImageNet.
    batch_size: the batch size returned by the data pipeline.
    train: Whether to load the train or evaluation split.
    dtype: data type of the image (default: float32).
    image_size: The target size of the images (default: 224).
    cache: Whether to cache the dataset (default: False).

  Returns:
    A `tf.data.Dataset`.
  """
  if train:
    train_size = dataset_builder.info.splits['train'].num_examples
    split_size = train_size // jax.host_count()
    start = jax.host_id() * split_size
    split = 'train[{}:{}]'.format(start, start + split_size)
  else:
    validation_size = dataset_builder.info.splits['validation'].num_examples
    split_size = validation_size // jax.host_count()
    start = jax.host_id() * split_size
    split = 'validation[{}:{}]'.format(start, start + split_size)

  def _decode_example(example):
    if train:
      image = preprocess_for_train(example['image'], dtype, image_size)
    else:
      image = preprocess_for_eval(example['image'], dtype, image_size)
    return {'image': image, 'label': example['label']}

  ds = dataset_builder.as_dataset(
      split=split, decoders={'image': tfds.decode.SkipDecoding()})
  ds.options().experimental_threading.private_threadpool_size = 48
  ds.options().experimental_threading.max_intra_op_parallelism = 1

  if cache:
    ds = ds.cache()

  if train:
    ds = ds.repeat()
    ds = ds.shuffle(16 * batch_size, seed=0)

  ds = ds.map(_decode_example,
              num_parallel_calls=tf.data.experimental.AUTOTUNE)
  ds = ds.batch(batch_size, drop_remainder=True)

  if not train:
    ds = ds.repeat()

  ds = ds.prefetch(10)
  return ds
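# Hedged usage sketch (added, not from the original source): driving
# `create_split` in a multi-host job. The global batch size is divided evenly
# across hosts, matching how the function slices the TFDS split per host.
def _example_imagenet_pipelines(global_batch_size: int):
  import jax
  import tensorflow_datasets as tfds
  assert global_batch_size % jax.host_count() == 0
  per_host_batch_size = global_batch_size // jax.host_count()
  builder = tfds.builder('imagenet2012:5.*.*')
  train_ds = create_split(builder, per_host_batch_size, train=True)
  eval_ds = create_split(builder, per_host_batch_size, train=False)
  return train_ds, eval_ds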
def _create_eval_dataset(config, dataset_builder, split):
  """Create evaluation dataset (validation or test sets)."""
  # This ensures the correct number of elements in the validation sets.
  num_validation_examples = (
      dataset_builder.info.splits[split].num_examples)
  eval_split = deterministic_data.get_read_instruction_for_host(
      split, dataset_info=dataset_builder.info, drop_remainder=False)

  eval_num_batches = None
  if config.eval_pad_last_batch:
    # This is doing some extra work to get exactly all examples in the
    # validation split. Without this the dataset would first be split between
    # the different hosts and then into batches (both times dropping the
    # remainder). If you don't mind dropping a few extra examples you can omit
    # the `pad_up_to_batches` argument.
    eval_batch_size = jax.local_device_count() * config.per_device_batch_size
    eval_num_batches = int(
        np.ceil(num_validation_examples / eval_batch_size /
                jax.host_count()))

  return deterministic_data.create_dataset(
      dataset_builder,
      split=eval_split,
      # Only cache dataset in distributed setup to avoid consuming a lot of
      # memory in Colab and unit tests.
      cache=jax.host_count() > 1,
      batch_dims=[jax.local_device_count(), config.per_device_batch_size],
      num_epochs=1,
      shuffle=False,
      preprocess_fn=_preprocess_spherical_mnist,
      pad_up_to_batches=eval_num_batches,
  )
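# Hedged usage sketch (added): consuming the padded eval dataset. When
# `pad_up_to_batches` is set, clu.deterministic_data pads the final batch and
# adds a boolean "mask" feature; metrics should be weighted by it so padded
# examples do not count. `eval_step_fn` is a hypothetical callable.
def _example_eval_loop(config, dataset_builder, split, eval_step_fn):
  eval_ds = _create_eval_dataset(config, dataset_builder, split)
  for batch in eval_ds.as_numpy_iterator():
    # `batch['mask']` is False for padded examples.
    eval_step_fn(batch)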
def set_metadata(self):
  """Set meta information about the dataset."""
  num_classes = self.get_num_classes()
  input_shape = self.get_input_shape()
  self.meta_data = {
      'num_classes': num_classes,
      'input_shape': input_shape,
      'input_dtype': self.dtype.jax_dtype,
      'num_train_examples_per_env': {
          env: self.splits.train[env].num_examples
          for env in self.splits.train
      },
      'num_eval_examples_per_env': {
          env: self.splits.validation[env].num_examples
          for env in self.splits.validation
      },
      # We don't sum across environments: they are processed in parallel
      # batches, so we take the number of examples from the first environment,
      # assuming all environments have the same number of training examples.
      # If they differ, use num_train_examples_per_env and
      # num_eval_examples_per_env instead.
      'num_train_examples':
          self.splits.train[str(self.env2id(
              self.train_environments[0]))].num_examples * jax.host_count(),
      'num_eval_examples':
          self.splits.validation[str(self.env2id(
              self.eval_environments[0]))].num_examples * jax.host_count(),
  }
def get_translate_wmt(shuffle_rng, batch_size, eval_batch_size=None, hps=None):
  """Wrapper to conform to the general dataset API."""
  per_host_batch_size = batch_size // jax.host_count()
  per_host_eval_batch_size = eval_batch_size // jax.host_count()
  return _get_translate_wmt(per_host_batch_size,
                            per_host_eval_batch_size,
                            hps,
                            shuffle_rng)
def load_split(batch_size, train, dtype=tf.float32,
               image_size=IMAGE_SIZE, cache=False):
  """Creates a split from the ImageNet dataset using TensorFlow Datasets.

  Args:
    batch_size: the batch size returned by the data pipeline.
    train: Whether to load the train or evaluation split.
    dtype: data type of the image.
    image_size: The target size of the images.
    cache: Whether to cache the dataset.

  Returns:
    A `tf.data.Dataset`.
  """
  if train:
    split_size = TRAIN_IMAGES // jax.host_count()
    start = jax.host_id() * split_size
    split = 'train[{}:{}]'.format(start, start + split_size)
  else:
    split_size = EVAL_IMAGES // jax.host_count()
    start = jax.host_id() * split_size
    split = 'validation[{}:{}]'.format(start, start + split_size)

  def decode_example(example):
    if train:
      image = preprocess_for_train(example['image'], dtype, image_size)
    else:
      image = preprocess_for_eval(example['image'], dtype, image_size)
    return {'image': image, 'label': example['label']}

  ds = tfds.load('imagenet2012:5.*.*', split=split,
                 decoders={
                     'image': tfds.decode.SkipDecoding(),
                 })
  options = tf.data.Options()
  options.experimental_threading.private_threadpool_size = 48
  ds = ds.with_options(options)

  if cache:
    ds = ds.cache()

  if train:
    ds = ds.repeat()
    ds = ds.shuffle(16 * batch_size, seed=0)

  ds = ds.map(decode_example,
              num_parallel_calls=tf.data.experimental.AUTOTUNE)
  ds = ds.batch(batch_size, drop_remainder=True)

  if not train:
    ds = ds.repeat()

  ds = ds.prefetch(10)
  return ds
def main(executable_dict, argv):
  del argv

  work_unit = platform.work_unit()
  tf.enable_v2_behavior()
  # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
  # it unavailable to JAX.
  tf.config.experimental.set_visible_devices([], 'GPU')

  logging.info('JAX host: %d / %d', jax.host_id(), jax.host_count())
  logging.info('JAX devices: %r', jax.devices())
  work_unit.set_task_status(
      f'host_id: {jax.host_id()}, host_count: {jax.host_count()}')

  # Read configuration
  if FLAGS.config_json:
    logging.info('Reading config from JSON: %s', FLAGS.config_json)
    with tf.io.gfile.GFile(FLAGS.config_json, 'r') as f:
      config = ml_collections.ConfigDict(json.loads(f.read()))
  else:
    config = FLAGS.config
  logging.info('config=%s',
               config.to_json_best_effort(indent=4, sort_keys=True))

  # Make output directories
  if FLAGS.experiment_dir:
    work_unit.create_artifact(platform.ArtifactType.DIRECTORY,
                              FLAGS.experiment_dir, 'experiment_dir')
  if FLAGS.work_unit_dir:
    work_unit.create_artifact(platform.ArtifactType.DIRECTORY,
                              FLAGS.work_unit_dir, 'work_unit_dir')
  logging.info('experiment_dir=%s work_unit_dir=%s',
               FLAGS.experiment_dir, FLAGS.work_unit_dir)

  # Seeding
  random.seed(config.seed * jax.host_count() + jax.host_id())
  onp.random.seed(config.seed * jax.host_count() + jax.host_id())
  rng = utils.RngGen(
      jax.random.fold_in(jax.random.PRNGKey(config.seed), jax.host_id()))

  # Run the main function
  logging.info('Running executable: %s', FLAGS.executable_name)
  extra_args = {}
  if FLAGS.extra_args_json_str:
    extra_args = json.loads(FLAGS.extra_args_json_str)
    logging.info('Extra args passed in: %r', extra_args)
  executable_dict[FLAGS.executable_name](config=config,
                                         experiment_dir=FLAGS.experiment_dir,
                                         work_unit_dir=FLAGS.work_unit_dir,
                                         rng=rng,
                                         **extra_args)

  utils.barrier()
def load(
    split: Split,
    *,
    is_training: bool,
    batch_dims: Sequence[int],
    bfloat16: bool = False,
) -> Generator[Batch, None, None]:
  """Loads the given split of the dataset."""
  if is_training:
    start, end = _shard(split, jax.host_id(), jax.host_count())
  else:
    start, end = _shard(split, 0, 1)
  tfds_split = tfds.core.ReadInstruction(
      _to_tfds_split(split), from_=start, to=end, unit='abs')
  ds = tfds.load('imagenet2012:5.*.*', split=tfds_split,
                 decoders={'image': tfds.decode.SkipDecoding()})

  total_batch_size = np.prod(batch_dims)

  options = ds.options()
  options.experimental_threading.private_threadpool_size = 48
  options.experimental_threading.max_intra_op_parallelism = 1
  if is_training:
    options.experimental_deterministic = False

  if is_training:
    if jax.host_count() > 1:
      # Only cache if we are reading a subset of the dataset.
      ds = ds.cache()
    ds = ds.repeat()
    ds = ds.shuffle(buffer_size=10 * total_batch_size, seed=0)
  else:
    if split.num_examples % total_batch_size != 0:
      raise ValueError(f'Test/valid must be divisible by {total_batch_size}')

  def preprocess(example):
    image = _preprocess_image(example['image'], is_training)
    if bfloat16:
      image = tf.cast(image, tf.bfloat16)
    label = tf.cast(example['label'], tf.int32)
    return {'images': image, 'labels': label}

  ds = ds.map(preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  for batch_size in reversed(batch_dims):
    ds = ds.batch(batch_size)

  ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

  yield from tfds.as_numpy(ds)
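# Hedged usage sketch (added; `split` stands in for a concrete `Split` value
# defined elsewhere in this module, and `train_step` is hypothetical): batches
# arrive with leading [local_device_count, per_device_batch] axes, ready for
# `jax.pmap`.
def _example_pmap_training(split, per_device_batch_size: int, train_step):
  import jax
  batch_dims = [jax.local_device_count(), per_device_batch_size]
  for batch in load(split, is_training=True, batch_dims=batch_dims):
    train_step(batch['images'], batch['labels'])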
def get_wmt_dataset(batch_size, train, shuffle_size=16384):
  """Get the train or eval split of WMT as a tf.data.Dataset."""
  keys = TRAIN_KEYS if train else EVAL_KEYS

  def parse_function(example_proto):
    return tf.io.parse_single_example(
        example_proto, {
            k: tf.io.FixedLenSequenceFeature([], tf.int64, allow_missing=True)
            for k in keys
        })

  def cast_to_int32(x):
    return {k: tf.dtypes.cast(x[k], tf.int32) for k in keys}

  def pad(x):
    return {
        k: pad_up_to(x[k], [MAX_TRAIN_LEN if train else MAX_EVAL_LEN])
        for k in keys
    }

  file_pattern = os.path.join(
      FLAGS.train_data_path if train else FLAGS.eval_data_path)
  dataset = tf.data.Dataset.list_files(file_pattern, shuffle=False)
  dataset = dataset.shard(jax.host_count(), jax.host_id())
  concurrent_files = min(10, 1024 // jax.host_count())
  dataset = dataset.interleave(tf.data.TFRecordDataset, concurrent_files, 1,
                               concurrent_files)

  dataset = dataset.map(parse_function, num_parallel_calls=32)
  dataset = dataset.map(cast_to_int32, num_parallel_calls=32)
  if train:
    # Filter out rare long, unpacked single-examples.
    dataset = dataset.filter(length_filter(MAX_TRAIN_LEN))

  dataset = dataset.map(pad, num_parallel_calls=32)

  if train:
    dataset = dataset.cache().shuffle(shuffle_size).repeat()
  dataset = dataset.batch(batch_size, drop_remainder=train)
  if not train:
    dataset = dataset.cache().repeat()
  dataset = dataset.prefetch(1024)

  options = tf.data.Options()
  options.experimental_deterministic = False
  options.experimental_threading.max_intra_op_parallelism = 1
  options.experimental_threading.private_threadpool_size = 48
  dataset = dataset.with_options(options)

  return dataset
def _prepare_dataset(
    dataset: tf.data.Dataset,
    global_batch_size: int,
    shuffle: bool,
    rng: np.ndarray,
    preprocess_fn: Optional[Callable[[Any], Any]] = None,
    num_epochs: Optional[int] = None,
    filter_fn: Optional[Callable[[Any], Any]] = None) -> tf.data.Dataset:
  """Batches, shuffles, prefetches and preprocesses a dataset.

  Args:
    dataset: The dataset to prepare.
    global_batch_size: The global batch size to use.
    shuffle: Whether to shuffle the data on example level.
    rng: PRNG for seeding the shuffle operations.
    preprocess_fn: Preprocessing function that will be applied to every
      example.
    num_epochs: Number of epochs to repeat the dataset.
    filter_fn: Function that filters samples according to some criteria.

  Returns:
    The dataset.
  """
  if shuffle and rng is None:
    raise ValueError("Shuffling without RNG is not supported.")

  if global_batch_size % jax.host_count() != 0:
    raise ValueError(f"Batch size {global_batch_size} not divisible by number "
                     f"of hosts ({jax.host_count()}).")

  local_batch_size = global_batch_size // jax.host_count()
  batch_dims = [jax.local_device_count(), local_batch_size]

  # tf.data uses single integers as seed.
  if rng is not None:
    rng = rng[0]

  ds = dataset.repeat(num_epochs)
  if shuffle:
    ds = ds.shuffle(1024, seed=rng)

  if preprocess_fn is not None:
    ds = ds.map(preprocess_fn,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

  if filter_fn is not None:
    ds = ds.filter(filter_fn)

  for batch_size in reversed(batch_dims):
    ds = ds.batch(batch_size, drop_remainder=True)

  return ds.prefetch(tf.data.experimental.AUTOTUNE)
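# Hedged usage sketch (added): `_prepare_dataset` expects `rng` as an array
# whose first element seeds tf.data's shuffle, so a JAX PRNG key (folded in
# with the host id for per-host shuffling) converted to numpy works.
def _example_prepare(dataset, global_batch_size: int, preprocess_fn=None):
  import jax
  import numpy as np
  rng = jax.random.fold_in(jax.random.PRNGKey(0), jax.host_id())
  return _prepare_dataset(dataset, global_batch_size, shuffle=True,
                          rng=np.asarray(rng), preprocess_fn=preprocess_fn)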
def __init__(self, dataset, tokenizer):
  self.tokenizer = tokenizer
  # Shard train here already to avoid unnecessary tokenization.
  dataset['train'] = dataset['train'].shard(jax.host_count(), jax.host_id())
  if isinstance(dataset, dict):
    single_split = dataset['train']
  else:
    single_split = dataset
  name_a, *names_other = [
      name for name, feature in single_split.features.items()
      if feature.dtype == 'string'
  ]
  assert len(names_other) <= 1, (
      'Only single sentences and sentence pairs allowed.')
  if names_other:
    name_b = names_other[0]
    tokenize = lambda example: self.tokenizer(
        example[name_a], example[name_b], truncation=True)
  else:
    tokenize = lambda example: self.tokenizer(example[name_a], truncation=True)
  mapped_dataset = dataset.map(tokenize, batched=True)
  mapped_dataset.set_format(
      'numpy',
      columns=['idx', 'input_ids', 'token_type_ids', 'attention_mask',
               'label'])
  super().__init__(mapped_dataset)
def load_split(train: bool, cache: bool) -> tf.data.Dataset:
  """Creates a split from the ImageNet dataset using TensorFlow Datasets.

  Args:
    train: Whether to load the train or evaluation split.
    cache: Whether to cache the dataset.

  Returns:
    A `tf.data.Dataset`.
  """
  if train:
    split_size = TRAIN_IMAGES // jax.host_count()
    start = jax.host_id() * split_size
    split = 'train[{}:{}]'.format(start, start + split_size)
  else:
    # For validation, we load up the dataset on each host. This will have the
    # effect of evaluating on the whole dataset num_host times, but will
    # prevent size issues. This makes the performance slightly worse when
    # evaluating often, but spares us the need to pad the datasets and mask
    # the loss accordingly.
    split = 'validation'
  ds = tfds.load('imagenet2012:5.*.*', split=split,
                 decoders={
                     'image': tfds.decode.SkipDecoding(),
                 })
  ds.options().experimental_threading.private_threadpool_size = 48
  ds.options().experimental_threading.max_intra_op_parallelism = 1
  if cache:
    ds = ds.cache()
  return ds
def parallel_write_images(image_write_fn, img_and_path_list):
  """Parallelizes image writing over JAX hosts and CPU cores.

  Args:
    image_write_fn: A function that takes a tuple (image, path) as input and
      writes the image to disk.
    img_and_path_list: A list of tuples (image, path) containing all the
      images that should be written.
  """
  num_hosts = jax.host_count()
  host_id = jax.host_id()
  num_images = len(img_and_path_list)
  num_images_per_batch = math.ceil(num_images / num_hosts)

  # First shard the images onto each host.
  per_host_images_and_paths = []
  for i in range(num_images_per_batch):
    base_index = i * num_hosts
    global_index = base_index + host_id
    if global_index < num_images:
      per_host_images_and_paths.append(img_and_path_list[global_index])

  # Now within each JAX host, use multi-processing to save the sharded images.
  with multiprocessing.pool.ThreadPool() as pool:
    pool.map(image_write_fn, per_host_images_and_paths)
    pool.close()
    pool.join()
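# Hedged usage sketch (added; `imageio` and the tuple order follow the
# docstring above, not any particular repository): each host writes only its
# strided subset of the full image list.
def _example_write_images(images, paths):
  import imageio

  def write_one(img_and_path):
    image, path = img_and_path
    imageio.imwrite(path, image)

  parallel_write_images(write_one, list(zip(images, paths)))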
def main(argv):
  del argv
  print('JAX host: %d / %d' % (jax.host_id(), jax.host_count()))
  print('JAX devices:\n%s' % '\n'.join(str(d) for d in jax.devices()),
        flush=True)
  experiment = Experiment()
  experiment.train_and_eval()
def load_ckpt_v2(model_state, dir):
  start = time.time()
  with open(dir + "/meta.json", "r") as f:
    meta = json.load(f)

  # TODO: make this work in the general case
  assert meta["total_hosts"] == jax.host_count(), (
      "Must load with same number of hosts as when saved")
  head_print(f"meta loaded in {time.time() - start:.06}s")

  new_state = {
      "step": np.array([meta["step"]]),
  }

  start = time.time()
  new_state["params"] = parallel_read(
      model_state["params"], dir + f"/params/shard_{jax.host_id()}.npz")
  head_print(f"params loaded in {time.time() - start:.06}s")

  start = time.time()
  new_state["opt_state"] = parallel_read(
      model_state["opt_state"], dir + f"/opt_state/shard_{jax.host_id()}.npz")
  head_print(f"opt_state loaded in {time.time() - start:.06}s")

  return new_state
def _init_host_and_devices(self, n_devices=None, random_seed=None):
  """Initializes host and device attributes for this trainer.

  Args:
    n_devices: Number of devices this trainer will use. If `None`, get the
      number from the backend.
    random_seed: Random seed as the starting point for all random numbers used
      by the trainer. If `None`, calculate one from system time and host id.

  Returns:
    is_chief: True if this trainer has special chief responsibilities.
    n_devices: The passed in value of n_devices or a computed default.
    random_seed: The passed in value of random_seed or a computed default.
  """
  if math.backend_name() == 'jax':
    host_id = jax.host_id()
    host_count = jax.host_count()
  else:
    host_id = 0
    host_count = 1
  is_chief = (host_id == 0)

  device_count = math.device_count()
  n_devices = n_devices or device_count
  # TODO(lukaszkaiser): remove this restriction when possible.
  if n_devices != device_count and math.backend_name() == 'jax':
    raise ValueError('JAX cannot work yet with n_devices != all devices: '
                     '%d != %d' % (n_devices, device_count))

  if random_seed is None and host_count > 1:
    random_seed = int(1e6 * (host_id + time.time())) % 2**32
  return is_chief, n_devices, init_random_number_generators(random_seed)
def main():
  args = parser.parse_args()
  logging.set_verbosity(logging.ERROR)

  print('JAX host: %d / %d' % (jax.host_id(), jax.host_count()))
  print('JAX devices:\n%s' % '\n'.join(str(d) for d in jax.devices()),
        flush=True)

  if get_model_cfg(args.model) is not None:
    validate(args)
  else:
    models = list_models(pretrained=True)
    if args.model != 'all':
      models = fnmatch.filter(models, args.model)
    if not models:
      print(f'ERROR: No models found to validate with pattern ({args.model}).')
      exit(1)
    print('Validating:', ', '.join(models))
    results = []
    for m in models:
      args.model = m
      res = validate(args)
      res.update(dict(model=m))
      results.append(res)
    print('Results:')
    for r in results:
      print(f"Model: {r['model']}, Top1: {r['top1']}, Top5: {r['top5']}")
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  FLAGS.log_dir = FLAGS.workdir
  FLAGS.stderrthreshold = 'info'
  logging.get_absl_handler().start_logging_to_file()

  # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
  # it unavailable to JAX.
  tf.config.experimental.set_visible_devices([], 'GPU')

  logging.info('JAX host: %d / %d', jax.host_id(), jax.host_count())
  logging.info('JAX local devices: %r', jax.local_devices())

  # Add a note so that we can tell which task is which JAX host.
  # (Depending on the platform task 0 is not guaranteed to be host 0.)
  platform.work_unit().set_task_status(
      f'host_id: {jax.host_id()}, host_count: {jax.host_count()}')
  platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                       FLAGS.workdir, 'workdir')

  if FLAGS.sample:
    sample.save_images(
        sample.generate_sample(FLAGS.config, FLAGS.workdir), 'sample.png')
  else:
    train.train_and_evaluate(FLAGS.config, FLAGS.workdir)
def _load_tfds_imagenet(split_name, n_total):
  # True (fractional) division: rounding the endpoints below distributes the
  # remainder across hosts instead of dropping it, which is what the
  # int(round(...)) calls are for.
  split_size = float(n_total) / jax.host_count()
  start = split_size * jax.host_id()
  end = start + split_size
  start_index = int(round(start))
  end_index = int(round(end))
  split = '{}[{}:{}]'.format(split_name, start_index, end_index)
  return tfds.load('imagenet2012:5.*.*', split=split)
def _load_custom_imagenet_split(split_path):
  """Load a custom split of the ImageNet dataset."""
  if not tf.io.gfile.exists(split_path):
    raise RuntimeError('Cannot find {}'.format(split_path))
  shard_filenames = tf.io.gfile.listdir(split_path)
  shard_filenames.sort()
  if jax.host_count() > 1:
    n_hosts = jax.host_count()
    host_id = jax.host_id()
    shard_filenames = [
        f for i, f in enumerate(shard_filenames) if (i % n_hosts) == host_id
    ]
  files_in_split = [os.path.join(split_path, f) for f in shard_filenames]
  ds = tf.data.TFRecordDataset(files_in_split,
                               buffer_size=128 * 1024 * 1024,
                               num_parallel_reads=len(files_in_split))
  # ds = deserialize_and_decode_image_dataset(ds, batch_size=256)
  ds = deserialize_and_decode_image_dataset(ds, batch_size=1)
  return ds
def load(
    split: Split,
    is_training: bool,
    batch_dims: Sequence[int],
    image_size: int = IMAGE_SIZE,
    chw: bool = False,
    mean: Optional[Tuple[float]] = None,
    std: Optional[Tuple[float]] = None,
    interpolation: str = 'bicubic',
    tfds_data_dir: Optional[str] = None,
):
  """Loads the given split of the dataset."""
  mean = MEAN_RGB if mean is None else mean
  std = STDDEV_RGB if std is None else std
  if is_training:
    start, end = _shard(split, jax.host_id(), jax.host_count())
  else:
    start, end = _shard(split, 0, 1)
  tfds_split = tfds.core.ReadInstruction(
      _to_tfds_split(split), from_=start, to=end, unit='abs')
  ds = tfds.load(
      'imagenet2012:5.*.*',
      split=tfds_split,
      decoders={'image': tfds.decode.SkipDecoding()},
      data_dir=tfds_data_dir)

  total_batch_size = np.prod(batch_dims)

  options = ds.options()
  options.experimental_threading.private_threadpool_size = 48
  options.experimental_threading.max_intra_op_parallelism = 1
  if is_training:
    options.experimental_deterministic = False

  if is_training:
    ds = ds.repeat()
    ds = ds.shuffle(buffer_size=10 * total_batch_size, seed=0)
  elif split.num_examples % total_batch_size != 0:
    raise ValueError(f'Test set size must be divisible by {total_batch_size}')
  num_batches = split.num_examples // total_batch_size

  interpolation = (tf.image.ResizeMethod.BILINEAR
                   if 'bilinear' in interpolation
                   else tf.image.ResizeMethod.BICUBIC)

  def preprocess(example):
    image = _preprocess_image(
        example['image'], is_training, image_size=image_size,
        mean=mean, std=std, interpolation=interpolation)
    if chw:
      # Transpose HWC image to CHW format.
      image = tf.transpose(image, (2, 0, 1))
    label = tf.cast(example['label'], tf.int32)
    return {'images': image, 'labels': label}

  ds = ds.map(preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  for batch_size in reversed(batch_dims):
    ds = ds.batch(batch_size)

  ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

  return tfds.as_numpy(ds), num_batches
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  print('JAX host: %d / %d' % (jax.host_id(), jax.host_count()))
  print('JAX devices:\n%s' % '\n'.join(str(d) for d in jax.devices()),
        flush=True)
  train_and_evaluate(config=FLAGS.config, resume=FLAGS.resume)
def render_image(state, rays, render_fn, rng, normalize_disp, chunk=8192):
  """Render all the pixels of an image (in test mode).

  Args:
    state: model_utils.TrainState.
    rays: a `Rays` namedtuple, the rays to be rendered.
    render_fn: function, jit-ed render function.
    rng: jnp.ndarray, random number generator (used in training mode only).
    normalize_disp: bool, if true then normalize `disp` to [0, 1].
    chunk: int, the size of chunks to render sequentially.

  Returns:
    rgb: jnp.ndarray, rendered color image.
    disp: jnp.ndarray, rendered disparity image.
    acc: jnp.ndarray, rendered accumulated weights per pixel.
  """
  height, width = rays[0].shape[:2]
  num_rays = height * width
  rays = datasets.ray_fn(lambda r: r.reshape((num_rays, -1)), rays)

  unused_rng, key_0, key_1 = jax.random.split(rng, 3)
  model = state.optimizer.target
  model_state = state.model_state
  host_id = jax.host_id()
  results = []
  with nn.stateful(model_state, mutable=False):
    for i in range(0, num_rays, chunk):
      # pylint: disable=cell-var-from-loop
      print(" " + "X" * int((i / num_rays) * 78), end="\r")
      chunk_rays = datasets.ray_fn(lambda r: r[i:i + chunk], rays)
      chunk_size = chunk_rays[0].shape[0]
      rays_remaining = chunk_size % jax.device_count()
      if rays_remaining != 0:
        padding = jax.device_count() - rays_remaining
        chunk_rays = datasets.ray_fn(
            lambda r: jnp.pad(r, ((0, padding), (0, 0)), mode="edge"),
            chunk_rays)
      else:
        padding = 0
      # After padding, the number of chunk_rays is always divisible by
      # host_count, so compute the per-host slice from the padded size.
      rays_per_host = chunk_rays[0].shape[0] // jax.host_count()
      start, stop = host_id * rays_per_host, (host_id + 1) * rays_per_host
      chunk_rays = datasets.ray_fn(lambda r: shard(r[start:stop]), chunk_rays)
      chunk_results = render_fn(key_0, key_1, model, chunk_rays)[-1]
      results.append([unshard(x[0], padding) for x in chunk_results])
      # pylint: enable=cell-var-from-loop
  print("")
  rgb, disp, acc = [jnp.concatenate(r, axis=0) for r in zip(*results)]
  # Normalize disp for visualization for ndc_rays in llff front-facing scenes.
  if normalize_disp:
    disp = (disp - disp.min()) / (disp.max() - disp.min())
  return (rgb.reshape((height, width, -1)),
          disp.reshape((height, width, -1)),
          acc.reshape((height, width, -1)))
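# Hedged sketch (added) of the `shard`/`unshard` helpers that both
# `render_image` variants in this file assume; in jaxnerf-style code they
# behave roughly like this: `shard` adds a leading device axis for pmap, and
# `unshard` removes it and strips the ray padding added above.
def shard(xs):
  import jax
  # [n, ...] -> [local_device_count, n // local_device_count, ...]
  return jax.tree_map(
      lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]), xs)


def unshard(x, padding=0):
  # [devices, per_device, ...] -> [devices * per_device, ...]
  y = x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])
  if padding > 0:
    y = y[:-padding]
  return y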
def get_fake(shuffle_rng, batch_size, eval_batch_size, hps=None):
  """Data generators for imagenet."""
  del shuffle_rng
  per_host_batch_size = batch_size // jax.host_count()
  per_host_eval_batch_size = eval_batch_size // jax.host_count()

  fake_train_batch = get_fake_batch(per_host_batch_size, hps.input_shape,
                                    hps.output_shape[0])
  fake_test_batch = get_fake_batch(per_host_eval_batch_size, hps.input_shape,
                                   hps.output_shape[0])

  def train_iterator_fn():
    while True:
      yield fake_train_batch

  def valid_epoch(epoch, num_batches=None):
    del num_batches
    del epoch
    # Note that we do // because we do not support partial batching for the
    # fake dataset.
    for _ in range(hps.valid_size // eval_batch_size):
      yield fake_test_batch

  # pylint: disable=unreachable
  def eval_train_epoch(*args, **kwargs):
    del args
    del kwargs
    return
    yield  # This yield is needed to make this a valid (null) iterator.
  # pylint: enable=unreachable

  # pylint: disable=unreachable
  def test_epoch(*args, **kwargs):
    del args
    del kwargs
    return
    yield  # This yield is needed to make this a valid (null) iterator.
  # pylint: enable=unreachable

  return data_utils.Dataset(train_iterator_fn, eval_train_epoch, valid_epoch,
                            test_epoch)
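# Hedged note (added): the `return` followed by an unreachable `yield` in
# `eval_train_epoch` and `test_epoch` above is the standard trick for writing
# a generator that yields nothing; a minimal standalone version:
def _empty_iterator():
  return
  yield  # Unreachable, but its presence makes this function a generator.

assert list(_empty_iterator()) == []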
def _initialize_training(self, rng):
  # Initialize inputs.
  if self.config.emulated_workers > 0:
    per_device_workers, ragged = divmod(self.config.emulated_workers,
                                        jax.host_count())
    if ragged:
      raise ValueError('Number of emulated workers must be divisible by the '
                       'number of physical workers `jax.host_count()`.')
    self._repeat_batch = per_device_workers
  else:
    self._repeat_batch = 1
  self.supervised_train_input = jl_utils.py_prefetch(
      self._supervised_train_dataset)
  if self.config.training.extra_data_path is None:
    self.extra_train_input = None
  else:
    self.extra_train_input = jl_utils.py_prefetch(
        self._extra_train_dataset)
  self.normalize_fn = datasets.cifar10_normalize

  # Optimizer.
  self.optimizer = utils.sgd_momentum(self.config.training.learning_rate,
                                      momentum=.9, nesterov=True)

  # Initialize parameters.
  if self._params is None:
    logging.info('Initializing parameters randomly rather than restoring '
                 'from checkpoint.')
    # Create inputs to initialize the network state.
    images, _, _ = jax.pmap(self.concatenate)(
        next(self.supervised_train_input),
        next(self.extra_train_input)
        if self.extra_train_input is not None else None)
    images = jax.pmap(self.normalize_fn)(images)
    # Initialize weights and biases.
    init_net = jax.pmap(
        lambda *a: self.model.init(*a, is_training=True), axis_name='i')
    init_rng = jl_utils.bcast_local_devices(rng)
    self._params, self._state = init_net(init_rng, images)

  # Setup weight averaging.
  if self.config.training.swa_decay > 0:
    self._avg_params = self._params
  else:
    self._avg_params = None

  # Initialize optimizer state.
  init_opt = jax.pmap(self.optimizer.init, axis_name='i')
  self._opt_state = init_opt(self._params)

  # Initialize step function.
  self.train_fn = jax.pmap(self._train_fn, axis_name='i',
                           donate_argnums=(0, 1, 2, 3))
def load_split_from_tfds(self, name, batch_size, train, split=None,
                         shuffle_seed=1):
  """Loads a split from the dataset using TensorFlow Datasets.

  Args:
    name: str; Name of the environment passed to `tfds.load`.
    batch_size: int; The batch size returned by the data pipeline.
    train: bool; Whether to load the train or evaluation split.
    split: str; Name of the dataset split passed to tfds; if None, the value
      is set with respect to the `train` argument.
    shuffle_seed: int; Seed for shuffling the training data.

  Returns:
    A `tf.data.Dataset`.
  """
  if split is None:
    split = 'train' if train else 'test'

  builder = self.get_builder(name)
  # Each host is responsible for a fixed subset of data.
  base_split_name, host_start, host_end = dataset_utils.get_data_range(
      builder, split, jax.host_id(), jax.host_count())
  data_range = tfds.core.ReadInstruction(base_split_name, unit='abs',
                                         from_=host_start, to=host_end)
  ds, ds_info = self.get_tfds_ds_and_info(name, data_range)

  # Applying preprocessing before `ds.cache()` to re-use it.
  decode_example = functools.partial(self.preprocess_example, env_name=name)
  ds = ds.map(decode_example,
              num_parallel_calls=tf.data.experimental.AUTOTUNE)
  if self.if_cache:
    ds = ds.cache()

  if train:
    ds = ds.repeat()
    ds = ds.shuffle(16 * batch_size, seed=shuffle_seed)
    ds = ds.map(self.process_train_example,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.batch(batch_size, drop_remainder=False)
  else:
    ds = ds.map(self.process_eval_example,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.batch(batch_size, drop_remainder=False)
    ds = ds.repeat()

  ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
  return ds, ds_info.splits[split].num_examples
def render_image(state, data, render_fn, rng, chunk=8192):
  """Render all the pixels of an image (in test mode).

  Args:
    state: model_utils.TrainState.
    data: dict, test example.
    render_fn: function, jit-ed render function.
    rng: jnp.ndarray, random number generator (used in training mode only).
    chunk: int, the size of chunks to render sequentially.

  Returns:
    rgb: jnp.ndarray, rendered color image.
    disp: jnp.ndarray, rendered disparity image.
    acc: jnp.ndarray, rendered accumulated weights per pixel.
  """
  rays = data["rays"]
  h, w = rays.shape[:2]
  rays = rays.reshape((h * w, -1))
  unused_rng, key_0, key_1 = jax.random.split(rng, 3)
  model = state.optimizer.target
  model_state = state.model_state
  host_id = jax.host_id()
  rgb = []
  disp = []
  acc = []
  with nn.stateful(model_state, mutable=False):
    for i in range(0, rays.shape[0], chunk):
      print(" " + "X" * int((i / rays.shape[0]) * 78), end="\r")
      chunk_rays = rays[i:i + chunk]
      remainder = chunk_rays.shape[0] % jax.device_count()
      if remainder != 0:
        padding = jax.device_count() - remainder
        chunk_rays = jnp.pad(chunk_rays, ((0, padding), (0, 0)), mode="edge")
      else:
        padding = 0
      # After padding the number of chunk_rays is always divisible by
      # host_count.
      per_host_rays = chunk_rays.shape[0] // jax.host_count()
      chunk_rays = chunk_rays[(host_id * per_host_rays):
                              ((host_id + 1) * per_host_rays)]
      chunk_rays = shard(chunk_rays)
      ret = render_fn(key_0, key_1, model, chunk_rays)
      rgb.append(unshard(ret[-1][0][0], padding))
      disp.append(unshard(ret[-1][1][0], padding))
      acc.append(unshard(ret[-1][2][0], padding))
  print("")
  rgb = jnp.concatenate(rgb, axis=0)
  disp = jnp.concatenate(disp, axis=0)
  # Normalize disp for visualization for ndc_rays in llff front-facing scenes.
  if rays.shape[-1] > 6:
    disp = (disp - disp.min()) / (disp.max() - disp.min())
  acc = jnp.concatenate(acc, axis=0)
  return (rgb.reshape((h, w, -1)), disp.reshape((h, w, -1)),
          acc.reshape((h, w, -1)))
def _get_birds200_dataset(
    mode: str,
    rng: np.ndarray) -> Tuple[tf.data.Dataset, tf.data.Dataset, int]:
  """Load the caltech_birds2011 dataset."""
  assert jax.host_count() == 1, (
      "caltech_birds2011 dataset does not support multihost training. "
      "Found {} hosts.".format(jax.host_count()))
  dataset_builder = tfds.builder("caltech_birds2011")
  num_classes = 200

  # Make sure each host uses a different RNG for the training data.
  rng, data_rng = jax.random.split(rng)
  data_rng = jax.random.fold_in(data_rng, jax.host_id())
  data_rng, shuffle_rng = jax.random.split(data_rng)
  if mode == "train-val":
    read_config = tfds.ReadConfig(shuffle_seed=shuffle_rng[0])
    ds = dataset_builder.as_dataset(split="train",
                                    shuffle_files=False,
                                    read_config=read_config)
    train_ds = ds.take(5000).shuffle(5000, seed=shuffle_rng[0])
    eval_ds = ds.skip(5000)
  elif mode == "train-test":
    train_split = "train"
    eval_split = "test"
    train_read_config = tfds.ReadConfig(shuffle_seed=shuffle_rng[0])
    train_ds = dataset_builder.as_dataset(split=train_split,
                                          shuffle_files=True,
                                          read_config=train_read_config)
    eval_read_config = tfds.ReadConfig(shuffle_seed=shuffle_rng[1])
    eval_ds = dataset_builder.as_dataset(split=eval_split,
                                         shuffle_files=False,
                                         read_config=eval_read_config)
  else:
    raise ValueError(f"Unknown mode: {mode}.")
  return train_ds, eval_ds, num_classes
def main(unused_argv):
  # Necessary to use the tfds loader.
  tf.enable_v2_behavior()

  if jax.process_count() > 1:
    # TODO(ankugarg): Add support for multihost inference.
    raise NotImplementedError('BLEU eval does not support multihost '
                              'inference.')

  rng = jax.random.PRNGKey(FLAGS.seed)

  mt_eval_config = json.loads(FLAGS.mt_eval_config)

  if FLAGS.experiment_config_filename:
    with tf.io.gfile.GFile(FLAGS.experiment_config_filename) as f:
      experiment_config = json.load(f)
    if jax.process_index() == 0:
      logging.info('experiment_config: %r', experiment_config)
    dataset_name = experiment_config['dataset']
    model_name = experiment_config['model']
  else:
    assert FLAGS.dataset and FLAGS.model
    dataset_name = FLAGS.dataset
    model_name = FLAGS.model

  if jax.process_index() == 0:
    logging.info('argv:\n%s', ' '.join(sys.argv))
    logging.info('device_count: %d', jax.device_count())
    logging.info('num_hosts : %d', jax.host_count())
    logging.info('host_id : %d', jax.host_id())

  model_class = models.get_model(model_name)
  dataset_builder = datasets.get_dataset(dataset_name)
  dataset_meta_data = datasets.get_dataset_meta_data(dataset_name)

  hparam_overrides = None
  if FLAGS.hparam_overrides:
    if isinstance(FLAGS.hparam_overrides, str):
      hparam_overrides = json.loads(FLAGS.hparam_overrides)

  merged_hps = hyperparameters.build_hparams(
      model_name=model_name,
      initializer_name=experiment_config['initializer'],
      dataset_name=dataset_name,
      hparam_file=FLAGS.trial_hparams_filename,
      hparam_overrides=hparam_overrides)

  if jax.process_index() == 0:
    logging.info('Merged hps are: %s', json.dumps(merged_hps.to_json()))

  evaluator = bleu_evaluator.BLEUEvaluator(FLAGS.checkpoint_dir, merged_hps,
                                           rng, model_class, dataset_builder,
                                           dataset_meta_data, mt_eval_config)
  evaluator.translate_and_calculate_bleu()
def bounds_from_last_device(device):
  # Must be passed the device at the highest-coordinate corner of the
  # relevant mesh, which is a requirement we know is satisfied by the last
  # device in jax.devices().
  if hasattr(device, 'coords'):
    x, y, z = device.coords
    return x + 1, y + 1, z + 1, device.id % 2 + 1
  else:
    # On non-TPU platforms, the "mesh" is hosts x devices per host in order
    # to take advantage of faster within-host interconnect.
    return jax.host_count(), jax.local_device_count()
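# Hedged usage sketch (added): computing mesh bounds for the current topology.
# On TPU this yields (x, y, z, cores_per_chip); on other platforms it falls
# back to (host_count, local_device_count).
def _example_mesh_bounds():
  import jax
  return bounds_from_last_device(jax.devices()[-1])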
def harmonize_across_hosts(optimizer):
  """Ensure that model and optimizer parameters are identical for all hosts."""
  if jax.host_count() == 1:
    return optimizer
  else:
    selector = jnp.zeros(jax.local_device_count())
    if jax.host_id() == 0:
      selector = jax.ops.index_update(selector, 0, 1.0)
    optimizer = jax.pmap(
        lambda opt, sel: jax.tree_map(
            lambda x: jax.lax.psum(x * sel.astype(x.dtype), 'i'), opt),
        axis_name='i')(optimizer, selector)
    return optimizer
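# Hedged usage note (added): `harmonize_across_hosts` works by zeroing every
# device's copy except device 0 on host 0 (via `selector`) and then psum-ing
# across all devices, so every host ends up with host 0's values. A minimal
# sketch of calling it on an optimizer already replicated across local
# devices (e.g. via flax.jax_utils.replicate):
#
#   optimizer = harmonize_across_hosts(optimizer)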