def create_val_dataset(config, split, batch_size, pad_last_batch):
  """Create validation dataset.

  Args:
    config: ml_collections.ConfigDict to use.
    split: The validation split.
    batch_size: The batch size.
    pad_last_batch: Bool to indicate whether to pad the last batch or not.

  Returns:
    The validation dataset.
  """
  dataset_builder = gscan_dataset.GSCANDataset(**config)
  num_batches = None
  cardinality = None
  if pad_last_batch:
    num_examples = dataset_builder.num_examples[split]
    val_batch_size = jax.local_device_count() * batch_size
    num_batches = int(
        np.ceil(num_examples / val_batch_size / jax.process_count()))
    cardinality = int(np.ceil(num_examples / jax.process_count()))
  ds = deterministic_data.create_dataset(
      dataset_builder,
      split=split,
      preprocess_fn=dataset_builder.preprocess,
      cache=jax.process_count() > 1,
      batch_dims=[jax.local_device_count(), batch_size],
      num_epochs=1,
      shuffle=False,
      pad_up_to_batches=num_batches,
      cardinality=cardinality)
  return ds
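# Editor's note: a minimal, self-contained sketch of the padding arithmetic in
# create_val_dataset above. The numbers (10,000 examples, 8 local devices,
# batch size 16, 4 processes) are assumed example values, and
# _example_val_padding is a hypothetical helper, not part of the original code.
import numpy as np


def _example_val_padding(num_examples=10_000, local_device_count=8,
                         batch_size=16, process_count=4):
  # Each host consumes local_device_count * batch_size examples per step.
  val_batch_size = local_device_count * batch_size
  # Batches per host, rounded up so no validation example is dropped.
  num_batches = int(np.ceil(num_examples / val_batch_size / process_count))
  # Examples per host, also rounded up.
  cardinality = int(np.ceil(num_examples / process_count))
  return num_batches, cardinality


# With the assumed values this yields 20 padded batches and 2500 examples per
# host.
assert _example_val_padding() == (20, 2500)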
def setup_jax(globally_use_hardware_rng: bool, jax_use_gda: bool,
              jax_backend_target: Optional[str],
              jax_xla_backend: Optional[str]) -> None:
  """Sets up JAX and logs information about this job."""
  # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
  # it unavailable to JAX.
  tf.config.experimental.set_visible_devices([], 'GPU')
  if globally_use_hardware_rng:
    py_utils.set_globally_use_rbg_prng_key()
  # We use xmap only with SPMD.
  jax.config.update('experimental_xmap_spmd_lowering', True)
  # Use the manual partitioning lowering of xmap to avoid vectorization.
  jax.config.update('experimental_xmap_spmd_lowering_manual', True)
  if jax_use_gda:
    logging.info('Using JAX GSDA for pjit and checkpointing')
  if jax_backend_target:
    logging.info('Using JAX backend target %s', jax_backend_target)
    jax_xla_backend = 'None' if jax_xla_backend is None else jax_xla_backend
    logging.info('Using JAX XLA backend %s', jax_xla_backend)
  logging.info('JAX process: %d / %d', jax.process_index(),
               jax.process_count())
  logging.info('JAX devices: %r', jax.devices())
  logging.info('jax.device_count(): %d', jax.device_count())
  logging.info('jax.local_device_count(): %d', jax.local_device_count())
  logging.info('jax.process_count(): %d', jax.process_count())
def get_translate_wmt(shuffle_rng, batch_size, eval_batch_size=None, hps=None):
  """Wrapper to conform to the general dataset API."""
  per_host_batch_size = batch_size // jax.process_count()
  per_host_eval_batch_size = eval_batch_size // jax.process_count()
  return _get_translate_wmt(per_host_batch_size, per_host_eval_batch_size, hps,
                            shuffle_rng)
def _write_lock_values(self):
  write_count = 0
  for p in range(jax.process_count()):
    # TODO(yashkatariya): Make the key value store writes safe if checkpoint
    # managers are created concurrently.
    self._client.key_value_set(_get_key(str(p)),
                               f'Lock value for process {str(p)}')
    write_count += 1
  if write_count != jax.process_count():
    raise ValueError("Process 0 couldn't write all the lock values.")
  logging.info('Lock values for all processes have been written by process 0.')
def get_imagenet(shuffle_rng, batch_size, eval_batch_size, hps):
  """Data generators for imagenet."""
  per_host_batch_size = batch_size // jax.process_count()
  per_host_eval_batch_size = eval_batch_size // jax.process_count()

  image_size = hps.input_shape[0]
  # TODO(gilmer) Currently the training data is not deterministic.
  logging.info('Loading train split')
  train_ds = load_split(
      per_host_batch_size,
      'train',
      hps=hps,
      image_size=image_size,
      shuffle_rng=shuffle_rng)
  train_ds = tfds.as_numpy(train_ds)
  logging.info('Loading eval_train split')
  eval_train_ds = load_split(
      per_host_eval_batch_size, 'eval_train', hps=hps, image_size=image_size)
  eval_train_ds = tfds.as_numpy(eval_train_ds)
  logging.info('Loading eval split')
  eval_ds = load_split(
      per_host_eval_batch_size, 'valid', hps=hps, image_size=image_size)
  eval_ds = tfds.as_numpy(eval_ds)

  def train_iterator_fn():
    return train_ds

  def eval_train_epoch(num_batches=None):
    # Note: eval_train batches use per_host_eval_batch_size, like the valid
    # split.
    for batch in itertools.islice(eval_train_ds, num_batches):
      yield data_utils.maybe_pad_batch(batch, per_host_eval_batch_size)

  def valid_epoch(num_batches=None):
    for batch in itertools.islice(eval_ds, num_batches):
      yield data_utils.maybe_pad_batch(batch, per_host_eval_batch_size)

  # pylint: disable=unreachable
  def test_epoch(*args, **kwargs):
    del args
    del kwargs
    return
    yield  # This yield is needed to make this a valid (null) iterator.
  # pylint: enable=unreachable

  return data_utils.Dataset(train_iterator_fn, eval_train_epoch, valid_epoch,
                            test_epoch)
def get_total_steps(config):
  """Get the total number of training steps.

  Args:
    config: The config of the experiment.

  Returns:
    The total number of training steps.
  """
  local_batch_size = config.batch_size // jax.process_count()
  ntrain_img = input_utils.get_num_examples(
      config.dataset,
      split=config.train_split,
      process_batch_size=local_batch_size,
      data_dir=config.get('data_dir'))
  steps_per_epoch = ntrain_img // config.batch_size

  if config.get('num_epochs'):
    total_steps = int(config.num_epochs * steps_per_epoch)
    assert not config.get('total_steps'), 'Set either num_epochs or total_steps'
  else:
    total_steps = config.total_steps

  logging.info('Total train data points: %d', ntrain_img)
  logging.info(
      'Running for %d steps, that means %f epochs and %d steps per epoch',
      total_steps, total_steps * config.batch_size / ntrain_img,
      steps_per_epoch)
  return total_steps
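# Editor's note: a small illustrative sketch of the step accounting in
# get_total_steps above, using made-up, ImageNet-scale numbers (1,281,167
# training examples, global batch 4096, 90 epochs). _example_total_steps is a
# hypothetical helper, not part of the original code.
def _example_total_steps(num_train_examples=1_281_167, global_batch_size=4096,
                         num_epochs=90):
  steps_per_epoch = num_train_examples // global_batch_size  # 312
  total_steps = int(num_epochs * steps_per_epoch)  # 28080
  return steps_per_epoch, total_steps


assert _example_total_steps() == (312, 28080)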
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
  # it unavailable to JAX.
  tf.config.experimental.set_visible_devices([], 'GPU')

  logging.info('JAX process: %d / %d', jax.process_index(),
               jax.process_count())
  logging.info('JAX local devices: %r', jax.local_devices())

  # Add a note so that we can tell which task is which JAX host.
  # (Depending on the platform task 0 is not guaranteed to be host 0)
  platform.work_unit().set_task_status(
      f'process_index: {jax.process_index()}, '
      f'process_count: {jax.process_count()}')
  platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                       FLAGS.workdir, 'workdir')

  if FLAGS.mode == 'train':
    train.train_and_evaluate(FLAGS.config, FLAGS.workdir)
  else:
    predict.predict_and_evaluate(FLAGS.config, FLAGS.workdir, FLAGS.ckpt_path)
def main(argv): del argv # Hide any GPUs form TensorFlow. Otherwise TF might reserve memory and make # it unavailable to JAX. tf.config.experimental.set_visible_devices([], "GPU") logging.info("JAX process: %d / %d", jax.process_index(), jax.process_count()) logging.info("JAX devices: %r", jax.devices()) # Add a note so that we can tell which task is which JAX process. platform.work_unit().set_task_status( f"process_index: {jax.process_index()}, process_count: {jax.process_count()}" ) platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY, FLAGS.workdir, "workdir") train_mode = FLAGS.config.mode if train_mode == TrainingMode.PRETRAINING: train_lib = run_pretraining elif train_mode == TrainingMode.CLASSIFICATION: train_lib = run_classifier else: raise ValueError("Unknown training mode: %s" % train_mode) train_lib.train_and_evaluate(FLAGS.config, FLAGS.workdir, FLAGS.vocab_filepath)
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  utils.add_gfile_logger(_WORKDIR.value)

  # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
  # it unavailable to JAX.
  tf.config.experimental.set_visible_devices([], 'GPU')

  jax.config.update('jax_log_compiles', True)

  logging.info('JAX process: %d / %d', jax.process_index(),
               jax.process_count())
  logging.info('JAX local devices: %r', jax.local_devices())
  jax_xla_backend = ('None' if FLAGS.jax_xla_backend is None else
                     FLAGS.jax_xla_backend)
  logging.info('Using JAX XLA backend %s', jax_xla_backend)

  logging.info('Config: %s', FLAGS.config)

  # Add a note so that we can tell which task is which JAX host.
  # (Depending on the platform task 0 is not guaranteed to be host 0)
  platform.work_unit().set_task_status(
      f'process_index: {jax.process_index()}, '
      f'process_count: {jax.process_count()}')
  platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                       _WORKDIR.value, 'workdir')

  if FLAGS.config.trainer == 'train':
    train.train_and_evaluate(FLAGS.config, _WORKDIR.value)
  elif FLAGS.config.trainer == 'inference_time':
    inference_time.inference_time(FLAGS.config, _WORKDIR.value)
  else:
    raise app.UsageError(f'Unknown trainer: {FLAGS.config.trainer}')
def evaluate(
    model_name: str,
    job_log_dir: Optional[str],
    multi_host_checkpointing: Optional[bool],
    maybe_use_persistence_checkpointing: bool,
) -> None:
  """Runs the evaluation loop on the entire eval data set.

  Args:
    model_name: The name of the model from the registry to evaluate.
    job_log_dir: The directory for the job logs.
    multi_host_checkpointing: Whether to use multi-host checkpointing.
    maybe_use_persistence_checkpointing: If set, it will try to use
      persistence-based checkpointing if suitable.
  """
  model_config = model_utils.get_model(model_name)()
  task_p = model_config.task()
  model_p = task_p.model
  eval_input_p = [v for v in model_config.datasets() if not v.is_training]
  for inp in eval_input_p:
    inp.num_infeed_hosts = jax.process_count()
    inp.infeed_host_index = jax.process_index()
  if model_p.device_mesh is not None:
    checkpoint_type = checkpoints.retrieve_checkpoint_type(
        multi_host_checkpointing, maybe_use_persistence_checkpointing, task_p)
    evaluate_spmd_model(task_p, eval_input_p, job_log_dir, checkpoint_type)
  else:
    evaluate_pmap_model(task_p, eval_input_p, job_log_dir)
def create_init(model, config, train_ds):
  """Create the initialization function for model parameters.

  Args:
    model: The model to be used in updates.
    config: The config of the experiment.
    train_ds: tf.data.Dataset.

  Returns:
    Function that returns initialized model parameters.
  """
  local_batch_size = config.batch_size // jax.process_count()

  # We want all parameters to be created in host RAM, not on any device; they
  # will be sent to the devices later as needed. Creating them on device has
  # previously led to them being allocated twice.
  @functools.partial(jax.jit, backend='cpu')
  def init(rng):
    image_size = tuple(train_ds.element_spec['image'].shape[2:])
    logging.info('image_size = %s', image_size)
    dummy_input = jnp.zeros((local_batch_size,) + image_size, jnp.float32)
    params = flax.core.unfreeze(
        model.init(rng, dummy_input, train=False))['params']

    # Set bias in the head to a low value, such that loss is small initially.
    params['batchensemble_head']['bias'] = jnp.full_like(
        params['batchensemble_head']['bias'], config.get('init_head_bias', 0))

    # Initialize the head kernel to all zeros for fine-tuning.
    if config.get('model_init'):
      params['batchensemble_head']['kernel'] = jnp.full_like(
          params['batchensemble_head']['kernel'], 0)

    return params

  return init
def main(argv):
  del argv

  # Hide any GPUs from TensorFlow. Otherwise, TF might reserve memory and make
  # it unavailable to JAX.
  tf.config.experimental.set_visible_devices([], "GPU")

  logging.info("JAX host: %d / %d", jax.process_index(), jax.process_count())
  logging.info("JAX devices: %r", jax.devices())

  # Add a note so that we can tell which task is which JAX host. (Task 0 is not
  # guaranteed to be host 0)
  platform.work_unit().set_task_status(
      f"host_id: {jax.process_index()}, host_count: {jax.process_count()}")
  platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                       _WORKDIR.value, "workdir")

  train_mode = _CONFIG.value.mode
  if train_mode == TrainingMode.PRETRAINING:
    train_lib = run_pretraining
  elif train_mode == TrainingMode.CLASSIFICATION:
    train_lib = run_classifier
  else:
    raise ValueError("Unknown mode: %s" % train_mode)

  train_lib.train_and_evaluate(_CONFIG.value, _WORKDIR.value,
                               _VOCAB_FILEPATH.value)
def get_data_iterator(config):
  """Get iterator over the dataset."""
  task = memory_generation_task.MemoryGenerationTask

  # Establish host information
  local_device_count = jax.local_device_count()
  host_count = jax.process_count()
  host_id = jax.process_index()

  # Load datasets
  logging.info('Loading dataset.')
  decode_fn = data_utils.make_decode_fn(
      name_to_features=task.get_name_to_features(config),
      samples_per_example=config.samples_per_example,
  )

  preprocess_fn = task.make_preprocess_fn(config)
  collater_fn = task.make_collater_fn(config)

  data = data_utils.load_dataset(
      patterns=config.data_patterns,
      decode_fn=decode_fn,
      preprocess_fn=preprocess_fn,
      collater_fn=collater_fn,
      is_training=False,
      per_device_batch_size=config.per_device_batch_size,
      local_device_count=local_device_count,
      host_count=host_count,
      host_id=host_id,
      seed=0,
  )
  return iter(data)
def main(argv): del argv # Hide any GPUs form TensorFlow. Otherwise TF might reserve memory and make # it unavailable to JAX. tf.config.experimental.set_visible_devices([], "GPU") if FLAGS.jax_backend_target: logging.info("Using JAX backend target %s", FLAGS.jax_backend_target) jax_xla_backend = ("None" if FLAGS.jax_xla_backend is None else FLAGS.jax_xla_backend) logging.info("Using JAX XLA backend %s", jax_xla_backend) logging.info("JAX process: %d / %d", jax.process_index(), jax.process_count()) logging.info("JAX devices: %r", jax.devices()) if FLAGS.is_train: # Add a note so that we can tell which task is which JAX host. # (Depending on platform task 0 is not guaranteed to be host 0) platform.work_unit().set_task_status( f"process_index: {jax.process_index()}, " f"process_count: {jax.process_count()}") platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY, FLAGS.workdir, "workdir") train_lib.train_and_evaluate(FLAGS.ml_config, FLAGS.workdir) else: eval_lib.evaluate(FLAGS.ml_config, FLAGS.workdir)
def update_fn(updates, state, params):
  del params

  key, subkey = jax.random.split(state.key)

  # Compute updates based on inner optimizer.
  updates, inner_state = optimizer.update(updates, state.inner_state)

  prob = state.expert_weights.sum(axis=0) / state.expert_weights.sum()

  # NOTE(dsuo): we rely on jax determinism for each host to behave the same.
  current_expert = jax.random.choice(subkey, jnp.arange(prob.size), p=prob)

  # Synchronize train_losses across hosts.
  # NOTE(dsuo): since we are already inside a pmap, we can't use
  # jax.experimental.multihost_utils.
  # NOTE(dsuo): train_losses is of shape (jax.process_count(),).
  train_losses = jax.lax.all_gather(train_loss, 'batch').reshape(
      jax.process_count(), jax.local_device_count())[:, 0]

  # Compute loss regret and update expert weights.
  loss_regret = train_losses.at[current_expert].get() - train_losses
  expert_weights = state.expert_weights * jnp.exp(mw_etas * loss_regret)

  state = SamuelState(
      inner_state=inner_state,
      expert_weights=expert_weights,
      key=key,
      current_expert=current_expert,
      step=state.step + 1,
  )
  return updates, state
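# Editor's note: a NumPy-only sketch of the reshape trick used in update_fn
# above. After all-gathering one loss per device along the 'batch' axis,
# reshaping to (process_count, local_device_count) and taking column 0 keeps
# exactly one loss per host. The counts and _example_per_host_losses helper are
# assumptions for illustration only.
import numpy as np


def _example_per_host_losses(process_count=2, local_device_count=4):
  # Pretend every device reported its host's loss (devices on a host agree).
  per_host_loss = np.array([0.5, 0.7])
  gathered = np.repeat(per_host_loss, local_device_count)  # shape (8,)
  # One entry per host after the reshape-and-slice.
  return gathered.reshape(process_count, local_device_count)[:, 0]


assert np.allclose(_example_per_host_losses(), [0.5, 0.7])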
def _thread_func(self, temp_checkpoint_dir, final_checkpoint_dir):
  try:
    for future in self._commit_futures:
      for f in future:
        f.result()
    logging.info('Commit to storage layer has completed.')

    current_process = jax.process_index()
    lockfiles_dir = os.path.join(temp_checkpoint_dir, 'lockfiles')
    all_lockfile_paths = [
        os.path.join(lockfiles_dir, f'lockfile_{p}')
        for p in range(jax.process_count())
    ]
    current_process_lockfile = os.path.join(lockfiles_dir,
                                            f'lockfile_{current_process}')

    with _RetryWithTimeout(self._timeout_secs) as t:
      while not tf.io.gfile.exists(current_process_lockfile):
        if t.timed_out:
          raise RuntimeError(
              'Terminating after waiting for '
              f'{self._timeout_secs} secs for lockfile to appear')
        logging.info('Waiting for current process %s lockfile to appear.',
                     current_process)
        time.sleep(60)

    tf.io.gfile.remove(current_process_lockfile)
    logging.info('Lockfile removed for process %s', current_process)

    # This while loop will not trigger until all commits have finished.
    if current_process == 0:
      with _RetryWithTimeout(self._timeout_secs) as t:
        while True:
          if t.timed_out:
            raise RuntimeError('Terminating after waiting for '
                               f'{self._timeout_secs} secs for '
                               'finishing the serialization.')
          # Mark as done when no lockfiles exist.
          if no_lockfiles_exists(all_lockfile_paths):
            tf.io.gfile.rmtree(lockfiles_dir)
            logging.info('Lockfiles directory removed.')

            logging.info('Renaming %s to %s', temp_checkpoint_dir,
                         final_checkpoint_dir)
            tf.io.gfile.rename(temp_checkpoint_dir, final_checkpoint_dir)
            logging.info('Finished saving GDA checkpoint to `%s`.',
                         final_checkpoint_dir)
            break
          else:
            logging.info('Thread sleeping for 60 seconds.')
            time.sleep(60)
  except Exception as e:
    self._exception = e
def get_fastmri(shuffle_rng, batch_size, eval_batch_size, hps):
  """FastMRI dataset.

  Args:
    shuffle_rng: rng for shuffling.
    batch_size: batch size.
    eval_batch_size: batch size for eval.
    hps: hyperparameters.

  Returns:
    An init2winit Dataset.
  """
  per_host_batch_size = batch_size // jax.process_count()
  per_host_eval_batch_size = eval_batch_size // jax.process_count()

  train_ds = load_split(per_host_batch_size, 'train', hps, shuffle_rng)
  train_ds = tfds.as_numpy(train_ds)

  # NOTE(dsuo): fastMRI has fixed randomness for eval.
  eval_train_ds = load_split(per_host_eval_batch_size, 'eval_train', hps,
                             shuffle_rng)
  eval_train_ds = tfds.as_numpy(eval_train_ds)

  eval_ds = load_split(per_host_eval_batch_size, 'val', hps, shuffle_rng)
  eval_ds = tfds.as_numpy(eval_ds)

  def train_iterator_fn():
    return train_ds

  def eval_train_epoch(num_batches=None):
    for batch in itertools.islice(eval_train_ds, num_batches):
      yield data_utils.maybe_pad_batch(batch, per_host_eval_batch_size)

  def valid_epoch(num_batches=None):
    for batch in itertools.islice(eval_ds, num_batches):
      yield data_utils.maybe_pad_batch(batch, per_host_eval_batch_size)

  # pylint: disable=unreachable
  def test_epoch(*args, **kwargs):
    del args
    del kwargs
    return
    yield  # This yield is needed to make this a valid (null) iterator.
  # pylint: enable=unreachable

  return data_utils.Dataset(train_iterator_fn, eval_train_epoch, valid_epoch,
                            test_epoch)
def _write_lockfiles(self, temp_checkpoint_dir):
  lockfiles_dir = os.path.join(temp_checkpoint_dir, 'lockfiles')
  tf.io.gfile.mkdir(lockfiles_dir)

  write_count = 0
  for p in range(jax.process_count()):
    with tf.io.gfile.GFile(
        os.path.join(lockfiles_dir, f'lockfile_{p}'), mode='w') as f:
      f.write('File to track if all chunks have been written.')
    write_count += 1

  if write_count != jax.process_count():
    raise ValueError("Process 0 couldn't write all the lockfiles.")

  logging.info(
      'Lock files for all processes have been written by process 0.')
def _get_uniref(per_host_batch_size, per_host_eval_batch_size, hps, data_rng):
  """Data generators for Uniref50 clustered protein dataset."""
  # TODO(gilmer) Currently uniref drops the last partial batch on eval.
  logging.warning(
      'Currently the Protein dataset drops the last partial batch on eval')

  if jax.process_count() > 1:
    raise NotImplementedError('Proteins does not support multihost training')

  n_devices = jax.local_device_count()
  if per_host_batch_size % n_devices != 0:
    raise ValueError('n_devices={} must divide per_host_batch_size={}.'.format(
        n_devices, per_host_batch_size))
  if per_host_eval_batch_size % n_devices != 0:
    raise ValueError(
        'n_devices={} must divide per_host_eval_batch_size={}.'.format(
            n_devices, per_host_eval_batch_size))

  train_ds, eval_ds, vocab = load_dataset(
      hps.data_name,
      batch_size=per_host_batch_size,
      eval_batch_size=per_host_eval_batch_size,
      length=hps.max_target_length)

  masker = BertMasker(vocab=vocab)

  def train_iterator_fn():
    for batch_index, batch in enumerate(iter(train_ds)):
      batch_rng = jax.random.fold_in(data_rng, batch_index)
      yield _batch_to_dict(batch, masker, 'train', batch_rng)

  def eval_train_epoch(num_batches=None):
    eval_train_iter = iter(train_ds)
    for batch_index, batch in enumerate(
        itertools.islice(eval_train_iter, num_batches)):
      batch_rng = jax.random.fold_in(data_rng, batch_index)
      yield _batch_to_dict(batch, masker, 'eval', batch_rng)

  def valid_epoch(num_batches=None):
    valid_iter = iter(eval_ds)
    for batch_index, batch in enumerate(
        itertools.islice(valid_iter, num_batches)):
      batch_rng = jax.random.fold_in(data_rng, batch_index)
      yield _batch_to_dict(batch, masker, 'eval', batch_rng)

  # pylint: disable=unreachable
  def test_epoch(*args, **kwargs):
    del args
    del kwargs
    return
    yield  # This yield is needed to make this a valid (null) iterator.
  # pylint: enable=unreachable

  return Dataset(train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch)
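# Editor's note: a short sketch of the per-batch RNG pattern used in
# _get_uniref above. Folding the batch index into a fixed data_rng yields a
# deterministic, distinct key per batch, so masking noise is reproducible
# across epochs. The seed and loop bound are arbitrary example values.
import jax

_example_data_rng = jax.random.PRNGKey(0)
_example_keys = [
    jax.random.fold_in(_example_data_rng, batch_index)
    for batch_index in range(3)
]
# Same index -> same key, so re-iterating a split reproduces the same masks.
assert (_example_keys[1] == jax.random.fold_in(_example_data_rng, 1)).all()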
def get_split(rng,
              builder,
              split,
              batch_size,
              num_epochs=None,
              shuffle_buffer_size=None,
              repeat_after=False,
              cache=False):
  """Loads an audio dataset and shifts audio values to be positive.

  Args:
    rng: JAX PRNGKey random number generator state.
    builder: TFDS dataset builder instance.
    split: TFDS split to load.
    batch_size: Global batch size.
    num_epochs: Number of epochs. None to repeat forever.
    shuffle_buffer_size: Size of the shuffle buffer. If None, data is not
      shuffled.
    repeat_after: If True, the dataset is repeated infinitely *after* CLU.
    cache: If True, the dataset is cached prior to pre-processing.

  Returns:
    Audio datasets with `inputs` and `label` features. The former is shifted to
    be non-negative.
  """
  host_count = jax.process_count()
  if batch_size % host_count != 0:
    raise ValueError(f'Batch size ({batch_size}) must be divisible by the host'
                     f' count ({host_count}).')
  batch_size = batch_size // host_count
  device_count = jax.local_device_count()
  if batch_size % device_count != 0:
    raise ValueError(f'Local batch size ({batch_size}) must be divisible by'
                     f' the local device count ({device_count}).')
  batch_dims = [device_count, batch_size // device_count]

  host_split = data.get_read_instruction_for_host(
      split,
      dataset_info=builder.info,
      remainder_options=data.RemainderOptions.BALANCE_ON_PROCESSES)
  ds = data.create_dataset(
      builder,
      split=host_split,
      preprocess_fn=PrepareAudio(),
      cache=cache,
      batch_dims=batch_dims,
      rng=rng,
      num_epochs=num_epochs,
      pad_up_to_batches='auto',
      shuffle=shuffle_buffer_size and (shuffle_buffer_size > 0),
      shuffle_buffer_size=shuffle_buffer_size or 0)
  if repeat_after:
    ds = ds.repeat()
  return ds
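# Editor's note: a sketch, with assumed device counts, of how get_split above
# derives the per-host batch_dims: global batch -> per-host batch ->
# [local devices, per-device batch]. _example_batch_dims and its default values
# are illustrative assumptions, not part of the original code.
def _example_batch_dims(global_batch_size=512, host_count=4,
                        local_device_count=8):
  assert global_batch_size % host_count == 0
  per_host = global_batch_size // host_count  # 128
  assert per_host % local_device_count == 0
  return [local_device_count, per_host // local_device_count]  # [8, 16]


assert _example_batch_dims() == [8, 16]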
def get_fake(shuffle_rng, batch_size, eval_batch_size, hps=None):
  """Data generators for a fake dataset."""
  del shuffle_rng
  per_host_batch_size = batch_size // jax.process_count()
  per_host_eval_batch_size = eval_batch_size // jax.process_count()

  fake_train_batch = get_fake_batch(per_host_batch_size, hps.input_shape,
                                    hps.output_shape[0])
  fake_test_batch = get_fake_batch(per_host_eval_batch_size, hps.input_shape,
                                   hps.output_shape[0])

  def train_iterator_fn():
    while True:
      yield fake_train_batch

  def valid_epoch(epoch, num_batches=None):
    del num_batches
    del epoch
    # Note that we use // because we do not support partial batching for the
    # fake dataset.
    for _ in range(hps.valid_size // eval_batch_size):
      yield fake_test_batch

  # pylint: disable=unreachable
  def eval_train_epoch(*args, **kwargs):
    del args
    del kwargs
    return
    yield  # This yield is needed to make this a valid (null) iterator.
  # pylint: enable=unreachable

  # pylint: disable=unreachable
  def test_epoch(*args, **kwargs):
    del args
    del kwargs
    return
    yield  # This yield is needed to make this a valid (null) iterator.
  # pylint: enable=unreachable

  return data_utils.Dataset(train_iterator_fn, eval_train_epoch, valid_epoch,
                            test_epoch)
def main(unused_argv):
  # Necessary to use the tfds loader.
  tf.enable_v2_behavior()

  if jax.process_count() > 1:
    # TODO(ankugarg): Add support for multihost inference.
    raise NotImplementedError('BLEU eval does not support multihost inference.')

  rng = jax.random.PRNGKey(FLAGS.seed)

  mt_eval_config = json.loads(FLAGS.mt_eval_config)

  if FLAGS.experiment_config_filename:
    with tf.io.gfile.GFile(FLAGS.experiment_config_filename) as f:
      experiment_config = json.load(f)
    if jax.process_index() == 0:
      logging.info('experiment_config: %r', experiment_config)
    dataset_name = experiment_config['dataset']
    model_name = experiment_config['model']
  else:
    assert FLAGS.dataset and FLAGS.model
    dataset_name = FLAGS.dataset
    model_name = FLAGS.model

  if jax.process_index() == 0:
    logging.info('argv:\n%s', ' '.join(sys.argv))
    logging.info('device_count: %d', jax.device_count())
    logging.info('num_hosts : %d', jax.process_count())
    logging.info('host_id : %d', jax.process_index())

  model_class = models.get_model(model_name)
  dataset_builder = datasets.get_dataset(dataset_name)
  dataset_meta_data = datasets.get_dataset_meta_data(dataset_name)

  hparam_overrides = None
  if FLAGS.hparam_overrides:
    if isinstance(FLAGS.hparam_overrides, str):
      hparam_overrides = json.loads(FLAGS.hparam_overrides)

  merged_hps = hyperparameters.build_hparams(
      model_name=model_name,
      initializer_name=experiment_config['initializer'],
      dataset_name=dataset_name,
      hparam_file=FLAGS.trial_hparams_filename,
      hparam_overrides=hparam_overrides)

  if jax.process_index() == 0:
    logging.info('Merged hps are: %s', json.dumps(merged_hps.to_json()))

  evaluator = bleu_evaluator.BLEUEvaluator(FLAGS.checkpoint_dir, merged_hps,
                                           rng, model_class, dataset_builder,
                                           dataset_meta_data, mt_eval_config)
  evaluator.translate_and_calculate_bleu()
def main(unused_argv):
  # Necessary to use the tfds imagenet loader.
  tf.enable_v2_behavior()

  rng = jax.random.PRNGKey(FLAGS.seed)

  if FLAGS.hessian_eval_config:
    hessian_eval_config = json.loads(FLAGS.hessian_eval_config)
  else:
    hessian_eval_config = hessian_eval.DEFAULT_EVAL_CONFIG

  if FLAGS.experiment_config_filename:
    with tf.io.gfile.GFile(FLAGS.experiment_config_filename, 'r') as f:
      experiment_config = json.load(f)
    if jax.process_index() == 0:
      logging.info('experiment_config: %r', experiment_config)
    dataset_name = experiment_config['dataset']
    model_name = experiment_config['model']
  else:
    assert FLAGS.dataset and FLAGS.model
    dataset_name = FLAGS.dataset
    model_name = FLAGS.model

  if jax.process_index() == 0:
    logging.info('argv:\n%s', ' '.join(sys.argv))
    logging.info('device_count: %d', jax.device_count())
    logging.info('num_hosts : %d', jax.process_count())
    logging.info('host_id : %d', jax.process_index())

  model = models.get_model(model_name)
  dataset_builder = datasets.get_dataset(dataset_name)
  dataset_meta_data = datasets.get_dataset_meta_data(dataset_name)

  with tf.io.gfile.GFile(FLAGS.trial_hparams_filename, 'r') as f:
    hps = config_dict.ConfigDict(json.load(f))

  if FLAGS.hparam_overrides:
    if isinstance(FLAGS.hparam_overrides, str):
      hparam_overrides = json.loads(FLAGS.hparam_overrides)
    hps.update_from_flattened_dict(hparam_overrides)

  run_lanczos.eval_checkpoints(FLAGS.checkpoint_dir, hps, rng,
                               FLAGS.eval_num_batches, model, dataset_builder,
                               dataset_meta_data, hessian_eval_config,
                               FLAGS.min_global_step, FLAGS.max_global_step)
def load_array(self, pattern: str):
  """Load sharded array as if it was loaded from multiple processes."""
  process_count = jax.process_count()
  arrays = []
  for process_index in range(process_count):
    arrays.append(
        data_utils.load_sharded_array(
            pattern, process_count * self.memory_reduction, process_index))
  array = np.stack(arrays, axis=0)
  shape = (-1,) + arrays[0].shape[1:]
  array = array.reshape(shape)
  return array
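# Editor's note: a self-contained sketch of the stack-and-reshape step in
# load_array above: per-process shards of shape (n, d) are stacked along a new
# leading axis and then flattened back into one array of shape
# (process_count * n, d). The shapes and _example_merge_shards helper are
# assumptions for illustration.
import numpy as np


def _example_merge_shards(process_count=3, shard_rows=4, dim=2):
  shards = [
      np.full((shard_rows, dim), p, dtype=np.float32)
      for p in range(process_count)
  ]
  array = np.stack(shards, axis=0)  # (3, 4, 2)
  return array.reshape((-1,) + shards[0].shape[1:])  # (12, 2)


assert _example_merge_shards().shape == (12, 2)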
def main(argv): if len(argv) > 1: raise app.UsageError("Too many command-line arguments.") # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make # it unavailable to JAX. tf.config.experimental.set_visible_devices([], "GPU") # Log available devices. This code only supports a single host currently. logging.info("JAX host: %d / %d", jax.process_index(), jax.process_count()) logging.info("JAX local devices: %r", jax.local_devices()) render_and_save()
def _thread_func(self, temp_checkpoint_dir, final_checkpoint_dir):
  try:
    for future in self._commit_futures:
      for f in future:
        f.result()
    logging.info('Commit to storage layer has completed.')

    current_process = jax.process_index()

    # TODO(yashkatariya): Add a method to the distributed system to wait for
    # the key's value to change.
    # Value already exists -- wait until value is NOT _REMOVED_VALUE.
    with _RetryWithTimeout(self._timeout_secs) as t:
      while self._client.blocking_key_value_get(
          _get_key(str(current_process)),
          self._timeout_in_ms) == _REMOVED_VALUE:
        if t.timed_out:
          raise TimeoutError(
              'Terminating after waiting for '
              f'{self._timeout_secs} secs for lock value to appear.')
        logging.info('Waiting for current process %s lock value to appear.',
                     current_process)
        time.sleep(60)

    self._client.key_value_set(_get_key(str(current_process)), _REMOVED_VALUE)
    logging.info('Lock value removed for process %s', current_process)

    # This while loop will not trigger until all commits have finished.
    if current_process == 0:
      with _RetryWithTimeout(self._timeout_secs) as t:
        while True:
          if t.timed_out:
            raise TimeoutError('Terminating after waiting for '
                               f'{self._timeout_secs} secs for '
                               'finishing the serialization.')
          # Mark as done when no lock values exist.
          if all(
              self._client.blocking_key_value_get(
                  _get_key(str(p)), self._timeout_in_ms) == _REMOVED_VALUE
              for p in range(jax.process_count())):
            logging.info('Renaming %s to %s', temp_checkpoint_dir,
                         final_checkpoint_dir)
            tf.io.gfile.rename(temp_checkpoint_dir, final_checkpoint_dir)
            logging.info('Finished saving GDA checkpoint to `%s`.',
                         final_checkpoint_dir)
            self._client.key_value_set(_get_key(self._final_ckpt_dir),
                                       _CHECKPOINT_SUCCESS)
            break
          else:
            logging.info('Thread sleeping for 60 seconds.')
            time.sleep(60)
  except Exception as e:
    self._exception = e
def create_eval_dataset(
    config,
    dataset_builder,
    split,
    preprocess_fn=None,
):
  """Create evaluation dataset (validation or test sets)."""
  # This ensures the correct number of elements in the validation sets.
  num_validation_examples = (
      dataset_builder.info.splits[split].num_examples)
  eval_split = deterministic_data.get_read_instruction_for_host(
      split, dataset_info=dataset_builder.info, drop_remainder=False)

  eval_num_batches = None
  if config.eval_pad_last_batch:
    # This is doing some extra work to get exactly all examples in the
    # validation split. Without this the dataset would first be split between
    # the different hosts and then into batches (both times dropping the
    # remainder). If you don't mind dropping a few extra examples you can omit
    # the `pad_up_to_batches` argument.
    eval_batch_size = jax.local_device_count() * config.per_device_batch_size
    eval_num_batches = int(
        np.ceil(num_validation_examples / eval_batch_size /
                jax.process_count()))

  return deterministic_data.create_dataset(
      dataset_builder,
      split=eval_split,
      # Only cache dataset in distributed setup to avoid consuming a lot of
      # memory in Colab and unit tests.
      cache=jax.process_count() > 1,
      batch_dims=[jax.local_device_count(), config.per_device_batch_size],
      num_epochs=1,
      shuffle=False,
      preprocess_fn=preprocess_fn,
      pad_up_to_batches=eval_num_batches,
  )
def get_criteo1tb(unused_shuffle_rng, batch_size, eval_batch_size, hps):
  """Get the Criteo 1TB train and eval iterators."""
  process_count = jax.process_count()
  if batch_size % process_count != 0:
    raise ValueError('process_count={} must divide batch_size={}.'.format(
        process_count, batch_size))
  if eval_batch_size is None:
    eval_batch_size = batch_size
  if eval_batch_size % process_count != 0:
    raise ValueError('process_count={} must divide eval_batch_size={}.'.format(
        process_count, eval_batch_size))
  per_host_eval_batch_size = eval_batch_size // process_count
  per_host_batch_size = batch_size // process_count

  train_dataset = _criteo_tsv_reader(
      file_path=hps.train_file_path,
      num_dense_features=hps.num_dense_features,
      vocab_sizes=hps.vocab_sizes,
      batch_size=per_host_batch_size,
      is_training=True)
  train_iterator_fn = lambda: tfds.as_numpy(train_dataset)

  eval_train_dataset = _criteo_tsv_reader(
      file_path=hps.train_file_path,
      num_dense_features=hps.num_dense_features,
      vocab_sizes=hps.vocab_sizes,
      batch_size=per_host_eval_batch_size,
      is_training=False)
  eval_train_epoch = functools.partial(
      convert_to_numpy_iterator_fn, tf_dataset=eval_train_dataset)

  eval_dataset = _criteo_tsv_reader(
      file_path=hps.eval_file_path,
      num_dense_features=hps.num_dense_features,
      vocab_sizes=hps.vocab_sizes,
      batch_size=per_host_eval_batch_size,
      is_training=False)
  eval_iterator_fn = functools.partial(
      convert_to_numpy_iterator_fn, tf_dataset=eval_dataset)

  # pylint: disable=unreachable
  def test_epoch(*args, **kwargs):
    del args
    del kwargs
    return
    yield  # This yield is needed to make this a valid (null) iterator.
  # pylint: enable=unreachable

  return Dataset(train_iterator_fn, eval_train_epoch, eval_iterator_fn,
                 test_epoch)
def decode(
    model_name: str,
    job_log_dir: Optional[str],
    multi_host_checkpointing: Optional[bool],
    maybe_use_persistence_checkpointing: bool,
    restore_checkpoint_dir: Optional[str],
    restore_checkpoint_step: Optional[int],
    continuous_decode: bool,
) -> None:
  """Runs decoding once on the decoder datasets.

  Args:
    model_name: The name of the model from the registry to evaluate.
    job_log_dir: The directory for the job logs.
    multi_host_checkpointing: Whether to use multi-host checkpointing.
    maybe_use_persistence_checkpointing: If set, it will try to use
      persistence-based checkpointing if suitable.
    restore_checkpoint_dir: The directory from which to restore checkpoint.
    restore_checkpoint_step: If set, the checkpoint step to restore. If unset,
      try to restore from the latest checkpoint if any.
    continuous_decode: whether to continuously decode on the latest ckpt.
  """
  logging.info('running decode_once on model %s restored from %s', model_name,
               restore_checkpoint_dir)
  model_config = model_utils.get_model(model_name)()
  task_p = model_config.task()
  model_p = task_p.model
  decoder_inputs = model_config.decoder_datasets()
  if not decoder_inputs:
    return
  for inp in decoder_inputs:
    inp.num_infeed_hosts = jax.process_count()
    inp.infeed_host_index = jax.process_index()

  if model_p.device_mesh is not None:
    if continuous_decode:
      raise NotImplementedError('http://b/214589358: not supported')
    checkpoint_type = checkpoints.retrieve_checkpoint_type(
        multi_host_checkpointing, maybe_use_persistence_checkpointing, task_p)
    decode_once_spmd_model(task_p, decoder_inputs, job_log_dir,
                           checkpoint_type, restore_checkpoint_dir,
                           restore_checkpoint_step)
  else:
    decode_pmap_model(task_p, decoder_inputs, job_log_dir,
                      restore_checkpoint_dir, restore_checkpoint_step,
                      continuous_decode)
def __init__(self, loss_fn, optimizer, devices=None, has_graph=False):
  self._net_init_fn, self._apply_fn = hk.transform_with_state(
      functools.partial(loss_fn, is_training=True))
  _, self._eval_apply_fn = hk.transform_with_state(
      functools.partial(loss_fn, is_training=False))

  if optimizer is None:
    optimizer = optax.identity()
  self._optimizer = optimizer
  self._num_devices = jax.local_device_count()
  if devices is None:
    devices = []
    for host_id in range(jax.process_count()):
      for device_id in jax.local_devices(host_id):
        devices.append(device_id)
  else:
    self._num_devices = min(self._num_devices, len(devices))

  def _pmap(f, static_broadcasted_argnums=()):
    return jax.pmap(
        f,
        axis_name='i',
        devices=devices,
        static_broadcasted_argnums=static_broadcasted_argnums)

  def handle_graph_size(fn):

    def _fn(*args):
      batch = args[-1].copy()
      max_graph_size = batch['max_graph_size']
      del batch['max_graph_size']
      args = args[:-1] + (batch, max_graph_size)
      return fn(*args)

    return _fn

  # Try to jit.
  if has_graph:
    # If the model contains full graphs, we need to set the max_graph_size
    # as a statically broadcasted argument.
    self._init_fn = handle_graph_size(_pmap(self._init, 4))
    self._update_fn = handle_graph_size(_pmap(self._update, 2))
    self._eval_fn = handle_graph_size(_pmap(self._eval, 2))
  else:
    self._init_fn = _pmap(self._init)
    self._update_fn = _pmap(self._update)
    self._eval_fn = _pmap(self._eval)