def main(_):
  """Configures the process and launches the experiment via `run`.

  Hides GPUs from TensorFlow, logs the hyperparameters found in
  `FLAGS.config`, seeds TensorFlow with a host-dependent seed, creates a
  TensorBoard summary writer on host 0 when `hparams.write_summary` is set,
  and finally calls `run`.

  Args:
    _: Unused positional argument.
  """
  is_master = jax.host_id() == 0

  # Make sure TF does not allocate GPU memory.
  tf.config.experimental.set_visible_devices([], 'GPU')

  # The pool is meant for misc operations such as logging in an async way.
  pool = multiprocessing.pool.ThreadPool()

  # Hyperparameters come from the config flag.
  hparams = FLAGS.config
  logging.info('=========== Hyperparameters ============')
  logging.info(hparams)
  if hparams.get('debug'):
    logging.warning('DEBUG MODE IS ENABLED!')

  # Offset the TensorFlow seed by the host id so hosts differ.
  tf.random.set_seed(jax.host_id() + hparams.rng_seed)

  experiment_dir = FLAGS.experiment_dir
  logging.info('Experiment directory: %s', experiment_dir)

  # Only host 0 writes summaries, and only when requested.
  summary_writer = None
  if is_master and hparams.write_summary:
    tensorboard_dir = os.path.join(experiment_dir, 'tb_summaries')
    gfile.makedirs(tensorboard_dir)
    summary_writer = tensorboard.SummaryWriter(tensorboard_dir)

  run(hparams, experiment_dir, summary_writer)

  pool.close()
  pool.join()
def create_split(dataset_builder: tfds.core.DatasetBuilder,
                 batch_size: int,
                 train: bool,
                 dtype: tf.DType = tf.float32,
                 image_size: int = IMAGE_SIZE,
                 cache: bool = False):
  """Creates a split from the ImageNet dataset using TensorFlow Datasets.

  Each host reads an equally sized, contiguous slice of the split; examples
  left over after the integer division by `jax.host_count()` are not read.

  Args:
    dataset_builder: TFDS dataset builder for ImageNet.
    batch_size: the batch size returned by the data pipeline.
    train: Whether to load the train or evaluation split.
    dtype: data type of the image (default: float32).
    image_size: The target size of the images (default: IMAGE_SIZE).
    cache: Whether to cache the dataset (default: False).

  Returns:
    A `tf.data.Dataset`.
  """
  split_name = 'train' if train else 'validation'
  num_examples = dataset_builder.info.splits[split_name].num_examples
  split_size = num_examples // jax.host_count()
  start = jax.host_id() * split_size
  split = '{}[{}:{}]'.format(split_name, start, start + split_size)

  def _decode_example(example):
    if train:
      image = preprocess_for_train(example['image'], dtype, image_size)
    else:
      image = preprocess_for_eval(example['image'], dtype, image_size)
    return {'image': image, 'label': example['label']}

  # Images are decoded inside `_decode_example`, so skip TFDS decoding.
  ds = dataset_builder.as_dataset(
      split=split, decoders={'image': tfds.decode.SkipDecoding()})

  # Threading options must be attached with `with_options`; assigning to the
  # object returned by `ds.options()` does not reliably change the dataset.
  options = tf.data.Options()
  options.experimental_threading.private_threadpool_size = 48
  options.experimental_threading.max_intra_op_parallelism = 1
  ds = ds.with_options(options)

  if cache:
    ds = ds.cache()

  if train:
    ds = ds.repeat()
    ds = ds.shuffle(16 * batch_size, seed=0)

  ds = ds.map(_decode_example,
              num_parallel_calls=tf.data.experimental.AUTOTUNE)
  ds = ds.batch(batch_size, drop_remainder=True)

  # Evaluation data is repeated after batching.
  if not train:
    ds = ds.repeat()

  ds = ds.prefetch(10)
  return ds
def process_iterator(tag: str,
                     item_ids: Sequence[str],
                     iterator,
                     rng: types.PRNGKey,
                     state: model_utils.TrainState,
                     step: int,
                     render_fn: Any,
                     summary_writer: tensorboard.SummaryWriter,
                     save_dir: Optional[gpath.GPath],
                     datasource: datasets.DataSource):
  """Runs `process_batch` over an iterator and logs mean metrics on host 0.

  For the 'test' tag, a metadata dict is attached to every batch: for each of
  appearance/warp/camera whose `use_*_id` flag is set on `datasource`, one id
  is drawn from the matching id list with a key built from `step`.

  Args:
    tag: Name of the split being processed; used in logs and summary tags.
    item_ids: Ids paired one-to-one with the batches from `iterator`.
    iterator: Iterable of batches.
    rng: PRNG key forwarded to `process_batch`.
    state: Train state forwarded to `process_batch`.
    step: Current step; names the save directory and indexes the summaries.
    render_fn: Render function forwarded to `process_batch`.
    summary_writer: Writer receiving the averaged metrics.
    save_dir: Base output directory, or a falsy value to disable saving.
    datasource: Data source providing the id flags and id lists.
  """
  save_dir = save_dir / f'{step:08d}' / tag if save_dir else None
  meters = collections.defaultdict(utils.ValueMeter)
  is_host_zero = jax.host_id() == 0

  for position, (item_id, batch) in enumerate(zip(item_ids, iterator),
                                              start=1):
    logging.info('[%s:%d/%d] Processing %s ', tag, position, len(item_ids),
                 item_id)
    if tag == 'test':
      test_rng = random.PRNGKey(step)
      shape = batch['origins'][..., :1].shape
      metadata = {}
      # The same key is used for every kind of id, as in the original code.
      for kind in ('appearance', 'warp', 'camera'):
        if not getattr(datasource, f'use_{kind}_id'):
          continue
        candidates = jnp.asarray(getattr(datasource, f'{kind}_ids'))
        chosen_id = random.choice(test_rng, candidates)
        logging.info(f'\tUsing {kind}_id = %d', chosen_id)
        metadata[kind] = jnp.full(shape, fill_value=chosen_id,
                                  dtype=jnp.uint32)
      batch['metadata'] = metadata

    stats = process_batch(batch=batch,
                          rng=rng,
                          state=state,
                          tag=tag,
                          item_id=item_id,
                          step=step,
                          render_fn=render_fn,
                          summary_writer=summary_writer,
                          save_dir=save_dir,
                          datasource=datasource)
    if is_host_zero:
      for stat_name, stat_value in stats.items():
        meters[stat_name].update(stat_value)

  if is_host_zero:
    for meter_name, meter in meters.items():
      summary_writer.scalar(tag=f'metrics-eval/{meter_name}/{tag}',
                            value=meter.reduce('mean'),
                            step=step)
def load_split(batch_size, train, dtype=tf.float32, image_size=IMAGE_SIZE,
               cache=False):
  """Builds the per-host ImageNet input pipeline with TensorFlow Datasets.

  Args:
    batch_size: the batch size returned by the data pipeline.
    train: Whether to load the train or evaluation split.
    dtype: data type of the image.
    image_size: The target size of the images.
    cache: Whether to cache the dataset.

  Returns:
    A `tf.data.Dataset`.
  """
  # Each host reads its own contiguous slice of the chosen split.
  if train:
    split_name, num_images = 'train', TRAIN_IMAGES
  else:
    split_name, num_images = 'validation', EVAL_IMAGES
  per_host = num_images // jax.host_count()
  first = jax.host_id() * per_host
  split = '{}[{}:{}]'.format(split_name, first, first + per_host)

  def decode_example(example):
    preprocess = preprocess_for_train if train else preprocess_for_eval
    return {
        'image': preprocess(example['image'], dtype, image_size),
        'label': example['label'],
    }

  # Decoding is deferred to `decode_example`.
  ds = tfds.load('imagenet2012:5.*.*', split=split,
                 decoders={'image': tfds.decode.SkipDecoding()})

  options = tf.data.Options()
  options.experimental_threading.private_threadpool_size = 48
  ds = ds.with_options(options)

  if cache:
    ds = ds.cache()

  if train:
    ds = ds.repeat().shuffle(16 * batch_size, seed=0)

  ds = ds.map(decode_example,
              num_parallel_calls=tf.data.experimental.AUTOTUNE)
  ds = ds.batch(batch_size, drop_remainder=True)

  if not train:
    ds = ds.repeat()

  return ds.prefetch(10)
def eval_once(run_configuration, checkpoint_path, optimizer=None):
  """Evaluates a single checkpoint for up to `config.eval_steps` batches.

  Restores the optimizer from `checkpoint_path`, runs the pmapped eval step
  over the dataset, and (on host 0) writes throughput and averaged metrics to
  a TensorBoard summary writer under `<run_dir>/<eval_name>`.

  Args:
    run_configuration: Object providing `config`, `run_dir`, `adapter` and
      `dataset_info`.
    checkpoint_path: Path of the checkpoint to restore.
    optimizer: Optional optimizer used as the restore target. When falsy, one
      is created with `adapter.create_optimizer`.
  """
  config = run_configuration.config
  run_dir = run_configuration.run_dir
  adapter = run_configuration.adapter
  optimizer = optimizer or adapter.create_optimizer(run_configuration)
  dataset = run_configuration.dataset_info.dataset
  info = run_configuration.dataset_info.info

  eval_name = config.eval_name or 'eval'
  log_dir = os.path.join(run_dir, eval_name)

  # Only host 0 creates (and later uses) the summary writer.
  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(log_dir)

  # Restore checkpoint
  optimizer = checkpoint_utils.restore_checkpoint(checkpoint_path, optimizer)
  step = int(optimizer.state.step)

  # Replicate optimizer.
  optimizer = flax.jax_utils.replicate(optimizer)

  eval_step = adapter.make_eval_step()
  eval_step_parallel = jax.pmap(eval_step, axis_name='batch')

  # Perform evaluation
  tick = time.time()
  metrics_all = []

  example = None
  dataset_iter_raw = iter(dataset)
  dataset_iter = adapter.preprocess(dataset_iter_raw)
  for unused_eval_step, example in zip(range(config.eval_steps), dataset_iter):
    train_inputs = adapter.get_train_inputs(example)
    metrics, logits, state = eval_step_parallel(optimizer.target, train_inputs)
    metrics_all.append(metrics)

  # Count the steps while `metrics_all` is still the per-step list; below it
  # is rebound to the aggregated metrics, whose length is not a step count.
  num_eval_steps = len(metrics_all)

  # Write results.
  metrics_all = common_utils.get_metrics(metrics_all)
  metrics_sums = jax.tree_map(jnp.sum, metrics_all)
  denominator = metrics_sums.pop('denominator')
  summary = jax.tree_map(
      lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
      metrics_sums)
  summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)
  logging.info('eval @ train step: %d, loss: %.4f', step, summary['loss'])
  if jax.host_id() == 0:
    tock = time.time()
    steps_per_sec = num_eval_steps / (tock - tick)
    examples_per_sec = denominator / (tock - tick)
    summary_writer.scalar('per-second/steps', steps_per_sec, step)
    summary_writer.scalar('per-second/examples', examples_per_sec, step)
    for key, val in summary.items():
      summary_writer.scalar(key, val, step)

    adapter.write_summaries(example, logits, summary_writer, info, step, state)
    summary_writer.flush()
def main(executable_dict, argv):
  """Sets up the process and dispatches to the selected executable.

  Reads the config (from `FLAGS.config_json` when set, else `FLAGS.config`),
  registers output directories with the work unit, seeds the Python, NumPy
  and JAX random number generators per host, and calls
  `executable_dict[FLAGS.executable_name]`.

  Args:
    executable_dict: Mapping from executable name to a callable accepting
      `config`, `experiment_dir`, `work_unit_dir`, `rng` and extra kwargs.
    argv: Unused.
  """
  del argv

  work_unit = platform.work_unit()
  tf.enable_v2_behavior()
  # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
  # it unavailable to JAX.
  tf.config.experimental.set_visible_devices([], 'GPU')

  logging.info('JAX host: %d / %d', jax.host_id(), jax.host_count())
  logging.info('JAX devices: %r', jax.devices())
  work_unit.set_task_status(
      f'host_id: {jax.host_id()}, host_count: {jax.host_count()}')

  # Read configuration
  if not FLAGS.config_json:
    config = FLAGS.config
  else:
    logging.info('Reading config from JSON: %s', FLAGS.config_json)
    with tf.io.gfile.GFile(FLAGS.config_json, 'r') as f:
      config = ml_collections.ConfigDict(json.loads(f.read()))
  logging.info('config=%s',
               config.to_json_best_effort(indent=4, sort_keys=True))

  # Make output directories
  experiment_dir = FLAGS.experiment_dir
  work_unit_dir = FLAGS.work_unit_dir
  if experiment_dir:
    work_unit.create_artifact(platform.ArtifactType.DIRECTORY,
                              experiment_dir, 'experiment_dir')
  if work_unit_dir:
    work_unit.create_artifact(platform.ArtifactType.DIRECTORY,
                              work_unit_dir, 'work_unit_dir')
  logging.info('experiment_dir=%s work_unit_dir=%s', experiment_dir,
               work_unit_dir)

  # Seeding: every host gets a distinct seed derived from `config.seed`.
  random.seed(config.seed * jax.host_count() + jax.host_id())
  onp.random.seed(config.seed * jax.host_count() + jax.host_id())
  base_key = jax.random.PRNGKey(config.seed)
  rng = utils.RngGen(jax.random.fold_in(base_key, jax.host_id()))

  # Run the main function
  logging.info('Running executable: %s', FLAGS.executable_name)

  extra_args = {}
  if FLAGS.extra_args_json_str:
    extra_args = json.loads(FLAGS.extra_args_json_str)
    logging.info('Extra args passed in: %r', extra_args)

  executable = executable_dict[FLAGS.executable_name]
  executable(config=config,
             experiment_dir=FLAGS.experiment_dir,
             work_unit_dir=FLAGS.work_unit_dir,
             rng=rng,
             **extra_args)

  utils.barrier()
def meta_init(loss_fn, model, hps, input_shape, output_shape, rng_key,
              metrics_logger=None, log_every=10):
  """Implements MetaInit initializer.

  The parameters are normalized, per-parameter scales are learned with
  `meta_optimize_scales` starting from the original norms, and the normalized
  parameters are rescaled by the learned norms.

  Args:
    loss_fn: Loss function.
    model: Flax Model class.
    hps: HParam object. Required hparams are meta_learning_rate,
      meta_batch_size, meta_steps, and epsilon.
    input_shape: Must agree with batch[0].shape[1:].
    output_shape: Must agree with batch[1].shape[1:].
    rng_key: jax.PRNGKey, used to seed all randomness.
    metrics_logger: Instance of utils.MetricsLogger
    log_every: Print meta loss every k steps.

  Returns:
    A Flax model with the learned initialization.
  """
  on_first_host = jax.host_id() == 0

  def _flat_norm(node):
    return jnp.linalg.norm(node.reshape(-1))

  # Pretty print the preinitialized norms with the variable shapes.
  if on_first_host:
    logging.info('Preinitialized norms:')
    _log_shape_and_norms(model.params, metrics_logger, key='init_norms')
    logging.info('Running meta init')

  # First grab the norms of all weights and rescale params to have norm 1.
  norms = jax.tree_map(_flat_norm, model.params)
  normalized_params = jax.tree_map(normalize, model.params)

  learned_norms, _ = meta_optimize_scales(
      loss_fn,
      model.module.call,
      normalized_params,
      norms,
      hps,
      input_shape,
      output_shape,
      rng_key,
      metrics_logger=metrics_logger,
      log_every=log_every)
  new_init = scale_params(normalized_params, learned_norms)

  if on_first_host:
    # Pretty print the meta init norms with the variable shapes.
    logging.info('Learned norms from meta_init:')
    _log_shape_and_norms(new_init, metrics_logger, key='meta_init_norms')

  return nn.Model(model.module, new_init)
def compute_interpolations(self, model, gdirs, udirs, hvex, cvex, step):
  """Compute the linear interpolation along the given directions.

  For every direction, the loss is evaluated with `self._full_batch_eval` at
  `num_points` step sizes evenly spaced between `lower_thresh` and
  `upper_thresh` (both read from `self.eval_config`).

  Args:
    model: Model forwarded to `self._full_batch_eval`.
    gdirs: Directions (pytrees) normalized with `_tree_normalize` before use;
      results are stored under 'loss<i>'. `gdirs[0]` also provides the pytree
      structure used to unflatten `hvex` and `cvex`.
    udirs: Directions (pytrees), normalized like `gdirs`; stored under
      'loss_u<i>'.
    hvex: Flat direction vectors; stored under 'loss_hvec<i>'.
    cvex: Flat direction vectors; stored under 'loss_cvec<i>'.
    step: Training step recorded under the 'step' key.

  Returns:
    A dict with 'step' and, when `eval_config['compute_interps']` is set,
    'step_size' plus one array of loss values per direction.
  """
  row = {'step': step}
  if not self.eval_config['compute_interps']:
    return row

  lower = self.eval_config['lower_thresh']
  upper = self.eval_config['upper_thresh']
  num_points = self.eval_config['num_points']
  etas = np.linspace(lower, upper, num=num_points, endpoint=True)
  # Add to the existing row instead of rebinding it, so 'step' is preserved.
  row['step_size'] = etas

  def _losses_along(direction):
    """Evaluates the loss at every step size in `etas` along `direction`."""
    loss_values = np.zeros(shape=(num_points,))
    for j, eta in enumerate(etas):
      loss_values[j] = self._full_batch_eval(model, direction, eta)
    return loss_values

  for i, u_dir in enumerate(gdirs):
    row['loss%d' % (i,)] = _losses_along(_tree_normalize(u_dir))
  if jax.host_id() == 0:
    logging.info('Loss interpolation along gradients finished.')

  for i, u_dir in enumerate(udirs):
    row['loss_u%d' % (i,)] = _losses_along(_tree_normalize(u_dir))
  if jax.host_id() == 0:
    logging.info('Loss interpolation along optimizer directions finished.')

  # hvex / cvex hold flat vectors; restore the pytree structure of gdirs[0].
  _, unflatten = ravel_pytree(gdirs[0])
  for i, u_dir in enumerate(hvex):
    row['loss_hvec%d' % (i,)] = _losses_along(unflatten(u_dir))
  for i, u_dir in enumerate(cvex):
    row['loss_cvec%d' % (i,)] = _losses_along(unflatten(u_dir))

  if jax.host_id() == 0:
    logging.info('Loss interpolations finished. Statistics captured:')
    logging.info(row.keys())
  return row
def run_train_single_device(run_configuration):
  """Runs the training workflow without pmap or jit.

  Restores from `original_checkpoint_path` (or, when that is None, from the
  latest checkpoint in the run's checkpoint directory), then runs train steps
  until `config.train.total_steps`, saving checkpoints from host 0.

  Args:
    run_configuration: Object providing `config`, `run_dir`, `adapter`,
      `original_checkpoint_path` and `dataset_info`.
  """
  config = run_configuration.config
  run_dir = run_configuration.run_dir
  adapter = run_configuration.adapter
  checkpoint_path = run_configuration.original_checkpoint_path
  dataset = run_configuration.dataset_info.dataset

  random_seed = 0
  rng = jax.random.PRNGKey(random_seed)
  rng = jax.random.fold_in(rng, jax.host_id())
  dropout_rng, init_rng = jax.random.split(rng)

  # Set up optimizer.
  optimizer = adapter.create_optimizer(run_configuration, rng=init_rng)

  # Set up train step.
  train_step = adapter.make_train_step(single_device=True)

  # Set up checkpointing.
  # TODO(dbieber): Set up phoenix.
  checkpoint_dir = checkpoint_utils.build_checkpoint_dir(run_dir)
  if checkpoint_path is None:
    checkpoint_path = checkpoint_utils.latest_checkpoint(checkpoint_dir)
  optimizer = checkpoint_utils.handle_restart_behavior(
      checkpoint_path, optimizer, config)

  start_step = int(optimizer.state.step)
  num_train_steps = config.train.total_steps

  # Begin training loop.
  dataset_iter_raw = iter(dataset)
  dataset_iter = adapter.preprocess(dataset_iter_raw, single_device=True)

  for step, example in zip(range(start_step, num_train_steps), dataset_iter):
    print(f'Step #{step}')
    train_inputs = adapter.get_train_inputs(example)
    optimizer, metrics, dropout_rng, logits, state = train_step(
        optimizer, train_inputs, dropout_rng)
    del metrics, logits, state  # Unused.

    # Save a Checkpoint.
    # A falsy save_freq disables checkpointing. It is tested before the
    # modulo so that a frequency of 0 (or None) does not raise.
    save_freq = config.logging.save_freq
    if save_freq:
      is_periodic = step % save_freq == 0 and step > 0
      is_last_step = step == num_train_steps - 1
      if (is_periodic or is_last_step) and jax.host_id() == 0:
        # Save unreplicated optimizer + model state.
        checkpoint_utils.save_checkpoint(checkpoint_dir, optimizer, step)
def run_eval(self, flax_module, batch_stats, optimizer_state, global_step):
  """Evaluates the loss Hessian spectrum and returns its max eigenvalue.

  The full metrics returned by the hessian evaluator are appended to
  `self.logger` on host 0.

  Args:
    flax_module: Replicated flax module.
    batch_stats: Replicated batch_stats from the trainer (unused).
    optimizer_state: Replicated optimizer state from the trainer. Only used
      when the 'precondition' callback option is set.
    global_step: Current training step.

  Returns:
    A dict mapping '<self.name>/max_eig' to the 'max_eig_hess' metric.
  """
  del batch_stats

  diag_preconditioner = None
  if self.callback_config.get('precondition'):
    precondition_config = self.callback_config.get(
        'precondition_config', default=FrozenConfigDict())
    diag_preconditioner = precondition.make_diag_preconditioner(
        self.hps.optimizer,
        self.hps.opt_hparams,
        jax_utils.unreplicate(optimizer_state),
        precondition_config)

  spectrum = self.hessian_evaluator.evaluate_spectrum(
      flax_module, global_step, diag_preconditioner=diag_preconditioner)
  hessian_metrics, _, _ = spectrum

  if jax.host_id() == 0:
    self.logger.append_pytree(hessian_metrics)

  return {self.name + '/max_eig': hessian_metrics['max_eig_hess']}
def save_checkpoint(
    self,
    experiment_state: Mapping[str, jnp.ndarray],
    opt_state: Mapping[str, jnp.ndarray],
    step: int,
    extra_checkpoint_info: Optional[Mapping[str, Any]] = None) -> None:
  """Writes a checkpoint to `self._checkpoint_path` (host 0 only).

  Args:
    experiment_state: Experiment params to be stored.
    opt_state: Optimizer state to be stored.
    step: Training iteration step.
    extra_checkpoint_info: Extra key/value pairs stored alongside the state.
  """
  # Only the first host writes the checkpoint.
  if jax.host_id() != 0:
    return

  payload = {
      'experiment_state': jax.tree_map(jax.device_get, experiment_state),
      'opt_state': jax.tree_map(jax.device_get, opt_state),
      'step': step,
  }
  if extra_checkpoint_info is not None:
    payload.update(extra_checkpoint_info)

  with open(self._checkpoint_path, 'wb') as checkpoint_file:
    dill.dump(payload, checkpoint_file, protocol=2)
def __init__(self,
             multihost_base_directory: str,
             tf_state: Optional[Dict[str, Any]] = None,
             *,
             host_id: Optional[int] = None,
             max_to_keep: int = 5,
             checkpoint_name: str = "ckpt"):
  """Initializes a MultihostCheckpoint with a dict of TensorFlow Trackables.

  Args:
    multihost_base_directory: Directory prefix from which the host-specific
      `base_directory` ("<prefix>-<host_id>") is built. Trailing slashes are
      stripped first.
    tf_state: A dictionary of TensorFlow `Trackable` to be serialized, for
      example a dataset iterator.
    host_id: Host ID used to construct the `base_directory`. Taken from
      `jax.host_id()` if not specified.
    max_to_keep: Number of checkpoints to keep in the directory. Must be at
      least 2.
    checkpoint_name: Prefix of the checkpoint files (before `-{number}`).

  Raises:
    ValueError: If `max_to_keep` is smaller than 2.
  """
  if not max_to_keep >= 2:
    raise ValueError("Requires multiple checkpoints (max_to_keep>=2).")

  self.multihost_base_directory = multihost_base_directory.rstrip("/")

  resolved_host_id = jax.host_id() if host_id is None else host_id
  base_directory = f"{self.multihost_base_directory}-{resolved_host_id}"

  super().__init__(base_directory,
                   tf_state,
                   max_to_keep=max_to_keep,
                   checkpoint_name=checkpoint_name)
def main(argv):
  """Prints JAX host/device information and runs `Experiment.train_and_eval`.

  Args:
    argv: Unused.
  """
  del argv

  host_line = 'JAX host: %d / %d' % (jax.host_id(), jax.host_count())
  print(host_line)

  device_lines = '\n'.join(str(d) for d in jax.devices())
  print('JAX devices:\n%s' % device_lines, flush=True)

  Experiment().train_and_eval()
def __init__(self, dataset, tokenizer):
  """Tokenizes a dataset of single sentences or sentence pairs.

  Args:
    dataset: Either a dict of splits containing a 'train' entry, or a single
      split. For a dict, the 'train' split is replaced in place by this
      host's shard before tokenization.
    tokenizer: Callable applied to the string column(s) with
      `truncation=True`.
  """
  self.tokenizer = tokenizer

  if isinstance(dataset, dict):
    # Shard train here already to avoid unnecessary tokenization. This is
    # done inside the dict branch so that a single split, which has no
    # 'train' entry, is not indexed.
    dataset['train'] = dataset['train'].shard(jax.host_count(), jax.host_id())
    single_split = dataset['train']
  else:
    single_split = dataset

  # The string-typed columns are the sentence inputs.
  name_a, *names_other = [
      name for name, feature in single_split.features.items()
      if feature.dtype == 'string']
  assert len(names_other) <= 1, (
      'Only single sentences and sentence pairs allowed.')

  if names_other:
    name_b = names_other[0]
    tokenize = lambda example: self.tokenizer(
        example[name_a], example[name_b], truncation=True)
  else:
    tokenize = lambda example: self.tokenizer(
        example[name_a], truncation=True)

  mapped_dataset = dataset.map(tokenize, batched=True)
  mapped_dataset.set_format('numpy', columns=[
      'idx', 'input_ids', 'token_type_ids', 'attention_mask', 'label'])
  super().__init__(mapped_dataset)
def maybe_save_checkpoint(self,
                          experiment_state: Mapping[Text, jnp.ndarray],
                          step: int,
                          rng: jnp.ndarray,
                          is_final: bool):
  """Saves a checkpoint if enough time has passed since the previous one.

  Nothing is written when checkpointing is disabled, on hosts other than
  host 0, or when the checkpoint is not final and less than
  `self._checkpoint_every` has elapsed since the last save.

  Args:
    experiment_state: State to store; the leading element of every leaf is
      fetched to host memory.
    step: Training step stored in the checkpoint.
    rng: Random key stored in the checkpoint.
    is_final: Whether to save regardless of the elapsed time.
  """
  now = time.time()
  if not self._checkpoint_enabled:
    return
  if jax.host_id() != 0:  # Only checkpoint the first worker.
    return
  if not is_final and now - self._last_checkpoint_time < self._checkpoint_every:
    return

  payload = {
      'experiment_state': jax.tree_map(lambda x: jax.device_get(x[0]),
                                       experiment_state),
      'step': step,
      'rng': rng,
  }

  final_path = self._checkpoint_path
  tmp_path = final_path + '_tmp'
  old_path = final_path + '_old'

  # Write to a temporary file first, then swap it into place.
  with open(tmp_path, 'wb') as checkpoint_file:
    dill.dump(payload, checkpoint_file, protocol=2)

  try:
    os.rename(final_path, old_path)
  except FileNotFoundError:
    had_previous = False  # No previous checkpoint to remove
  else:
    had_previous = True

  os.rename(tmp_path, final_path)
  if had_previous:
    os.remove(old_path)

  self._last_checkpoint_time = now
def prepare_batches_gen(dataset, eval_config):
  """Returns a data iterator factory over a fixed set of training batches.

  The API for the data iterator will be
    for b in batches_gen():
      pass
  The first `eval_config['num_batches']` training batches are materialized
  once, so the same "epoch" is yielded every time the generator is called.

  Args:
    dataset: Object whose `train_iterator_fn()` yields training batches.
    eval_config: A dict specifying the parameters for the hessian eval; the
      keys 'num_batches' and 'rng_key' are read here.

  Returns:
    A function that, when called, yields `(sharded_batch, replicated_rng)`
    tuples.
  """
  batches = list(
      itertools.islice(dataset.train_iterator_fn(), 0,
                       eval_config['num_batches']))

  # Make the base key host-specific.
  base_rng = jax.random.fold_in(
      jax.random.PRNGKey(eval_config['rng_key']), jax.host_id())

  def training_batches_gen():
    for counter, batch in enumerate(batches):
      sharded_batch = data_utils.shard(batch)
      # A distinct key per batch, derived from the batch index.
      batch_rng = jax_utils.replicate(jax.random.fold_in(base_rng, counter))
      yield (sharded_batch, batch_rng)

  return training_batches_gen
def specialize_rng_host_device(rng, axis_name,
                               mode="unique_host_unique_device"):
  """Specializes a rng to the host/device we are on.

  Must be called from within a pmapped function.

  Args:
    rng: a jax.random.PRNGKey.
    axis_name: the axis of the devices we are specializing across.
    mode: str mode. Must be one of "unique_host_unique_device",
      "unique_host_same_device", "same_host_unique_device",
      "same_host_same_device".

  Returns:
    jax.random.PRNGKey specialized to host/device.
  """
  # Will throw an error if mode is not a valid enumeration.
  enum_mode = DistributedRNGMode(mode)

  per_host_modes = (
      DistributedRNGMode.UNIQUE_HOST_UNIQUE_DEVICE,
      DistributedRNGMode.UNIQUE_HOST_SAME_DEVICE,
  )
  per_device_modes = (
      DistributedRNGMode.UNIQUE_HOST_UNIQUE_DEVICE,
      DistributedRNGMode.SAME_HOST_UNIQUE_DEVICE,
  )

  if enum_mode in per_host_modes:
    rng = jax.random.fold_in(rng, jax.host_id())
  if enum_mode in per_device_modes:
    rng = jax.random.fold_in(rng, jax.lax.axis_index(axis_name))
  return rng
def _get_tfds_dataset(
    dataset: str,
    rng: np.ndarray) -> Tuple[tf.data.Dataset, tf.data.Dataset, int]:
  """Loads the per-host train and eval splits of a TFDS dataset.

  Args:
    dataset: TFDS dataset name passed to `tfds.builder`.
    rng: Random key from which the file-shuffling seeds are derived.

  Returns:
    A tuple `(train_ds, eval_ds, num_classes)`. `num_classes` is 0 when the
    dataset has no "label" feature.
  """
  dataset_builder = tfds.builder(dataset)
  features = dataset_builder.info.features
  num_classes = features["label"].num_classes if "label" in features else 0

  # Make sure each host uses a different RNG for the training data.
  _, data_rng = jax.random.split(rng)
  data_rng = jax.random.fold_in(data_rng, jax.host_id())
  _, shuffle_rng = jax.random.split(data_rng)

  splits = dataset_builder.info.splits

  train_split = deterministic_data.get_read_instruction_for_host(
      "train", splits["train"].num_examples)
  train_ds = dataset_builder.as_dataset(
      split=train_split,
      shuffle_files=True,
      read_config=tfds.ReadConfig(shuffle_seed=shuffle_rng[0]))

  # Only imagenet2012 evaluates on "validation"; everything else on "test".
  eval_split_name = "validation" if dataset == "imagenet2012" else "test"
  eval_split_size = splits[eval_split_name].num_examples
  eval_split = deterministic_data.get_read_instruction_for_host(
      eval_split_name, eval_split_size)
  eval_ds = dataset_builder.as_dataset(
      split=eval_split,
      shuffle_files=False,
      read_config=tfds.ReadConfig(shuffle_seed=shuffle_rng[1]))

  return train_ds, eval_ds, num_classes
def load_extra(batch_sizes: Sequence[int],
               path_npz: str,
               is_training: bool = True,
               drop_remainder: bool = True) -> tf.data.Dataset:
  """Loads extra data from a given `.npz` path.

  Args:
    batch_sizes: Batch sizes applied innermost-last, i.e. the dataset is
      batched once per entry, iterating over `batch_sizes` in reverse.
    path_npz: Path of the archive holding 'image' and 'label' arrays. If it
      does not exist but is listed in `_ALLOWED_FILES`, it is downloaded.
    is_training: Whether to repeat/shuffle and use 'train' preprocessing.
    drop_remainder: Forwarded to every `batch` call.

  Returns:
    The batched and prefetched `tf.data.Dataset`.

  Raises:
    ValueError: If the file does not exist and is not an allowed download.
  """
  if not tf.io.gfile.exists(path_npz):
    if path_npz not in _ALLOWED_FILES:
      raise ValueError(
          f'Extra data not found ({path_npz}). See {_WEBPAGE} for '
          'more details.')
    path_npz = tf.keras.utils.get_file(path_npz, _DATA_URL + path_npz)

  with tf.io.gfile.GFile(path_npz, 'rb') as fp:
    npzfile = np.load(fp)
    data = {'image': npzfile['image'], 'label': npzfile['label']}
    # Prevent allocation to happen on GPU.
    with tf.device('/device:cpu:0'):
      ds = tf.data.Dataset.from_tensor_slices(data)

  ds = ds.cache()
  if is_training:
    ds = ds.repeat()
    ds = ds.shuffle(buffer_size=50_000, seed=jax.host_id())

  subset = 'train' if is_training else 'test'
  ds = ds.map(cifar10_preprocess(subset),
              num_parallel_calls=tf.data.AUTOTUNE)

  for batch_size in reversed(batch_sizes):
    ds = ds.batch(batch_size, drop_remainder=drop_remainder)
  return ds.prefetch(tf.data.AUTOTUNE)
def parallel_write_images(image_write_fn, img_and_path_list):
  """Parallelizes image writing over JAX hosts and CPU cores.

  Args:
    image_write_fn: A function called with one element of
      `img_and_path_list`; it is expected to write the image to disk.
    img_and_path_list: A list of tuples containing all the images (and their
      paths) that should be written.
  """
  num_hosts = jax.host_count()
  host_id = jax.host_id()

  # First shard the images onto each host: host h takes the elements at
  # positions h, h + num_hosts, h + 2 * num_hosts, ...
  per_host_images_and_paths = list(img_and_path_list[host_id::num_hosts])

  # Now within each JAX host, use a thread pool to save the sharded images.
  with multiprocessing.pool.ThreadPool() as pool:
    pool.map(image_write_fn, per_host_images_and_paths)
    pool.close()
    pool.join()
def load_split(train: bool, cache: bool) -> tf.data.Dataset:
  """Creates a split from the ImageNet dataset using TensorFlow Datasets.

  Args:
    train: Whether to load the train or evaluation split.
    cache: Whether to cache the dataset.

  Returns:
    A `tf.data.Dataset` whose images are left undecoded.
  """
  if train:
    split_size = TRAIN_IMAGES // jax.host_count()
    start = jax.host_id() * split_size
    split = 'train[{}:{}]'.format(start, start + split_size)
  else:
    # For validation, we load up the dataset on each host. This will have the
    # effect of evaluating on the whole dataset num_host times, but will
    # prevent size issues. This makes the performance slightly worse when
    # evaluating often, but spares us the need to pad the datasets and mask
    # the loss accordingly.
    split = 'validation'

  ds = tfds.load('imagenet2012:5.*.*', split=split,
                 decoders={'image': tfds.decode.SkipDecoding()})

  # Threading options must be attached with `with_options`; assigning to the
  # object returned by `ds.options()` does not reliably change the dataset.
  options = tf.data.Options()
  options.experimental_threading.private_threadpool_size = 48
  options.experimental_threading.max_intra_op_parallelism = 1
  ds = ds.with_options(options)

  if cache:
    ds = ds.cache()

  return ds
def main():
  """Validates one model, or every pretrained model matching a pattern.

  If `args.model` names a known model config, only that model is validated.
  Otherwise `args.model` is treated as an fnmatch pattern ('all' selects
  everything) over the pretrained model list, and a summary line is printed
  per model.

  Raises:
    SystemExit: With status 1 when no model matches the pattern.
  """
  args = parser.parse_args()
  logging.set_verbosity(logging.ERROR)
  print('JAX host: %d / %d' % (jax.host_id(), jax.host_count()))
  print('JAX devices:\n%s' % '\n'.join(str(d) for d in jax.devices()),
        flush=True)

  if get_model_cfg(args.model) is not None:
    validate(args)
    return

  models = list_models(pretrained=True)
  if args.model != 'all':
    models = fnmatch.filter(models, args.model)
  if not models:
    print(f'ERROR: No models found to validate with pattern ({args.model}).')
    # Raise directly instead of relying on the `exit` helper, which is only
    # installed by the `site` module.
    raise SystemExit(1)

  print('Validating:', ', '.join(models))
  results = []
  for m in models:
    # `validate` reads the model name from `args`, so rebind it per model.
    args.model = m
    res = validate(args)
    res.update(dict(model=m))
    results.append(res)

  print('Results:')
  for r in results:
    print(f"Model: {r['model']}, Top1: {r['top1']}, Top5: {r['top5']}")
def test(optimizer, state, p_eval_step, step, test_ds, summary_writer):
  """Evaluates the flax module held by `optimizer` on `test_ds`.

  Per-batch metrics are summed and divided by the summed 'denominator'
  metric; the result is logged and, on host 0, written as `test_<key>`
  scalars.

  Args:
    optimizer: flax optimizer (contains flax module).
    state: model state, e.g. batch statistics.
    p_eval_step: fn; Pmapped evaluation step function.
    step: int; Number of training steps passed so far.
    test_ds: tf.dataset; Test dataset.
    summary_writer: tensorflow summary writer.
  """

  def _to_sharded_numpy(batch):
    # pylint: disable=protected-access
    return common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))
    # pylint: enable=protected-access

  # Test Metrics
  test_metrics = []
  for test_batch in iter(test_ds):
    sharded_batch = _to_sharded_numpy(test_batch)
    test_metrics.append(p_eval_step(optimizer.target, state, sharded_batch))

  test_metrics = common_utils.get_metrics(test_metrics)
  test_metrics_sums = jax.tree_map(jnp.sum, test_metrics)
  test_denominator = test_metrics_sums.pop('denominator')
  test_summary = jax.tree_map(lambda total: total / test_denominator,
                              test_metrics_sums)

  logging.info('test in step: %d, loss: %.4f, acc: %.4f', step,
               test_summary['loss'], test_summary['accuracy'])

  if jax.host_id() != 0:
    return
  for key, val in test_summary.items():
    summary_writer.scalar(f'test_{key}', val, step)
  summary_writer.flush()
def train(self):
  """Runs the training loop from `start_step + 1` through `total_steps`.

  Returns:
    A `(train_summary, eval_summary)` tuple holding the values last returned
    by `maybe_eval_and_log` (both None if the loop body never ran).
  """
  is_master = jax.host_id() == 0

  train_metrics = []
  train_summary = None
  eval_summary = None
  tick = time.time()

  # Main train loop.
  first_step = self.start_step + 1
  for step in range(first_step, self.total_steps + 1):
    batch = self.get_next_batch(self.task.dataset.data_iters.train)
    self.train_state, step_metrics = self.pmapped_train_step(
        self.train_state, batch)
    train_metrics.append(step_metrics)

    (eval_summary, train_metrics, train_summary,
     tick) = self.maybe_eval_and_log(eval_summary, is_master, step, tick,
                                     train_metrics, train_summary)

    # Sync and save.
    self.train_state = self.checkpoint(self.train_state, step)

  # Wait until computations are done before exiting (for timing!).
  jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()

  # Return the train and eval summary after the last step for regression
  # testing.
  return train_summary, eval_summary
def _init_host_and_devices(self, n_devices=None, random_seed=None):
  """Initializes host and device attributes for this trainer.

  Args:
    n_devices: Number of devices this trainer will use. If `None`, get the
      number from the backend.
    random_seed: Random seed as the starting point for all random numbers
      used by the trainer. If `None` and there is more than one host, one is
      calculated from system time and host id.

  Returns:
    is_chief: True if this trainer has special chief responsibilities
      (i.e. it runs on host 0).
    n_devices: The passed in value of n_devices or a computed default.
    The value returned by `init_random_number_generators(random_seed)`.

  Raises:
    ValueError: If the backend is JAX and `n_devices` differs from the
      backend's device count.
  """
  host_id, host_count = 0, 1
  if math.backend_name() == 'jax':
    host_id = jax.host_id()
    host_count = jax.host_count()
  is_chief = host_id == 0

  device_count = math.device_count()
  if not n_devices:
    n_devices = device_count

  # TODO(lukaszkaiser): remove this restriction when possible.
  if n_devices != device_count and math.backend_name() == 'jax':
    raise ValueError(
        'JAX cannot work yet with n_devices != all devices: '
        '%d != %d' % (n_devices, device_count))

  # With several hosts and no explicit seed, derive a per-host seed.
  if random_seed is None and host_count > 1:
    random_seed = int(1e6 * (host_id + time.time())) % 2**32

  rng_result = init_random_number_generators(random_seed)
  return is_chief, n_devices, rng_result
def checkpoint(self, train_state, step):
  """Saves a checkpoint when `step` calls for one.

  When the 'ckpnt_steps' hparam is set (and checkpointing is enabled), a
  checkpoint is taken only at those steps. Otherwise one is taken every
  `self.checkpoint_frequency` steps and at `self.total_steps`. The model
  state is synced across replicas before host 0 saves it.

  Args:
    train_state: TrainState; A flax struct that keeps model state and
      optimizer state.
    step: int; Number of steps passed so far during training.

  Returns:
    train_state (synced across replicas if a checkpoint was taken).
  """
  explicit_steps = self.hparams.get('ckpnt_steps', None)
  if explicit_steps and self.hparams.checkpoint:
    should_checkpoint = step in explicit_steps
  else:
    on_schedule = ((step % self.checkpoint_frequency == 0) or
                   (step == self.total_steps))
    should_checkpoint = on_schedule and self.hparams.checkpoint

  if not should_checkpoint:
    return train_state

  # Sync model state across replicas.
  train_state = pipeline_utils.sync_model_state_across_replicas(train_state)
  if jax.host_id() == 0:
    pipeline_utils.save_checkpoint(self.experiment_dir,
                                   train_state,
                                   keep=self.hparams.keep_ckpts)
  return train_state
def save_checkpoint(optimizer: flax.optim.Optimizer,
                    model_state: Any,
                    directory: str,
                    epoch: int):
  """Saves a model and its state.

  Removes a checkpoint if it already exists for a given epoch. For multi-host
  training, only the first host will save the checkpoint.

  Args:
    optimizer: The optimizer containing the model that we are training.
    model_state: Current state associated with the model.
    directory: Directory where the checkpoints should be saved.
    epoch: Number of epochs the model has been trained for.
  """
  if jax.host_id() != 0:
    return

  # Sync across replicas before saving: keep the first replica's optimizer
  # and average the model state over the leading axis.
  unreplicated_optimizer = jax.tree_map(lambda x: x[0], optimizer)
  averaged_state = jax.tree_map(lambda x: jnp.mean(x, axis=0), model_state)
  train_state = {
      'optimizer': unreplicated_optimizer,
      'model_state': averaged_state,
      'epoch': epoch,
  }

  # Drop a stale checkpoint for this epoch so it can be rewritten.
  existing_path = os.path.join(directory, 'checkpoint_' + str(epoch))
  if gfile.exists(existing_path):
    gfile.remove(existing_path)

  checkpoints.save_checkpoint(directory, train_state, epoch, keep=2)
def eval_loop(experiment_class, config):
  """The main evaluation loop.

  This loop repeatedly loads a checkpoint and evaluates it by calling
  `experiment.evaluate`, sleeping 10 seconds whenever no new checkpoint is
  available. It returns once a checkpoint at or beyond `config['max_steps']`
  has been evaluated.

  Args:
    experiment_class: the constructor for the experiment (either
      byol_experiment or eval_experiment).
    config: the experiment config.
  """
  experiment = experiment_class(**config)
  last_evaluated_step = -1

  while last_evaluated_step < config['max_steps']:
    checkpoint_data = experiment.load_checkpoint()
    if checkpoint_data is None:
      logging.info('No checkpoint found. Waiting for 10s.')
      time.sleep(10)
      continue

    step = checkpoint_data[0]
    _ = checkpoint_data[1]
    if step <= last_evaluated_step:
      logging.info('Checkpoint at step %d already evaluated, waiting.', step)
      time.sleep(10)
      continue

    host_id = jax.host_id()
    # `evaluate` receives the step broadcast to one entry per local device.
    step_device = np.broadcast_to(step, [jax.local_device_count()])
    scalars = experiment.evaluate(global_step=step_device)
    if host_id == 0:  # Only perform logging in one host.
      logging.info('Evaluation at step %d: %s', step, scalars)
    last_evaluated_step = step
def main(argv):
  """Runs SQuAD training and/or prediction according to `FLAGS.mode`.

  Args:
    argv: Command-line arguments; more than one entry is rejected.

  Raises:
    app.UsageError: If `argv` holds more than one entry.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # NOTE(review): this creates `dirname(model_dir)`, which is model_dir itself
  # only when the path ends with a separator — confirm this is intended.
  if not gfile.IsDirectory(FLAGS.model_dir):
    gfile.MakeDirs(os.path.dirname(FLAGS.model_dir))

  logging.info('Number of recognized devices: %d', jax.local_device_count())
  logging.info('Import pretrained weights: %s', FLAGS.load_tf_weights)

  jax_squad_model = get_squad_model()

  with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
    input_meta_data = json.loads(reader.read().decode('utf-8'))

  # Start from step 0 unless a checkpoint is restored below.
  step = 0
  optimizer = create_optimizer(jax_squad_model, FLAGS.learning_rate)
  if FLAGS.load_checkpoint:
    optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer)
    step = optimizer.state[0].step[0]

  if FLAGS.mode in ('train', 'train_and_predict'):
    optimizer = train_squad(optimizer, input_meta_data, step)

  if FLAGS.mode in ('predict', 'train_and_predict'):
    if not FLAGS.use_eval_sharding:
      optimizer = optimizer.unreplicate()
    # Prediction only runs on host 0.
    if jax.host_id() == 0:
      predict_squad(optimizer, input_meta_data)
  global_barrier()
def train(self):
  """Runs the multi-environment training loop.

  Each step draws one batch per training environment, runs the pmapped train
  step, keeps the first replica's metrics, and lets `maybe_eval_and_log` and
  `checkpoint` decide whether to evaluate/log and save.

  Returns:
    A `(train_summary, eval_summary)` tuple holding the values last returned
    by `maybe_eval_and_log` (both None if the loop body never ran).
  """
  is_master = jax.host_id() == 0

  train_metrics = []
  train_summary = None
  eval_summary = None
  tick = time.time()

  data_iters = self.task.dataset.data_iters
  eval_env_ids = [int(env_id) for env_id in data_iters.validation.keys()]

  # Split the train iterators into parallel (env id, iterator) sequences.
  train_items = dict(data_iters['train']).items()
  train_env_ids, train_iters = list(zip(*train_items))
  train_env_ids = [int(env_id) for env_id in train_env_ids]

  for step in range(self.start_step + 1, self.total_steps + 1):
    train_batches = self.get_next_batch(train_iters)
    self.train_state, step_metrics = self.pmapped_train_step(
        self.train_state, train_batches)
    # Keep only the first entry along the leading axis of every metric.
    step_metrics = jax.tree_map(lambda x: x[0], step_metrics)
    train_metrics.append(step_metrics)

    (eval_summary, train_metrics, train_summary,
     tick) = self.maybe_eval_and_log(eval_env_ids, eval_summary, is_master,
                                     step, tick, train_metrics,
                                     train_summary)

    # Sync and save
    self.train_state = self.checkpoint(self.train_state, step)

  # Wait until computations are done before exiting (for timing!).
  jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()

  # Return the train and eval summary after the last step for regression
  # testing.
  return train_summary, eval_summary