def _delete_checkpoint(
    self, checkpoint: checkpoint_pb2.CheckpointMetadata) -> None:
  """Removes the on-disk files that make up `checkpoint`.

  What gets removed depends on `self._checkpoint_history.checkpoint_type`:
    * CHECKPOINT_MULTI_HOST_FLAX: every process deletes `<prefix><step>`
      under its own `<root>/<process_index:03d>` directory.
    * CHECKPOINT_FLAX: process 0 deletes `<prefix><step>` under the root.
    * CHECKPOINT_PERSISTENCE / CHECKPOINT_GDA: process 0 deletes the
      zero-padded `<prefix><step:08d>` under the root.
  Any other checkpoint type only emits the log line.

  Args:
    checkpoint: Metadata of the checkpoint to delete. `global_step_id`
      selects the files; `timestamp_sec` is only logged.
  """
  logging.info('Deleting checkpoint: %s %s', checkpoint.timestamp_sec,
               self._root_dir)
  ckpt_type = self._checkpoint_history.checkpoint_type
  step_id = checkpoint.global_step_id

  if ckpt_type == CheckpointType.CHECKPOINT_FLAX:
    zero_pad = False
  elif ckpt_type == CheckpointType.CHECKPOINT_MULTI_HOST_FLAX:
    # Per-process layout: each process cleans up its own sub-directory.
    host_dir = os.path.join(self._root_dir, f'{jax.process_index():03d}')
    self._delete_pattern_if_exists(host_dir, f'{CHECKPOINT_PREFIX}{step_id}')
    return
  elif ckpt_type in {
      CheckpointType.CHECKPOINT_PERSISTENCE,
      CheckpointType.CHECKPOINT_GDA,
  }:
    zero_pad = True
  else:
    return

  # Single-root layouts: only process 0 performs the deletion.
  if jax.process_index() != 0:
    return
  suffix = f'{step_id:08d}' if zero_pad else f'{step_id}'
  self._delete_pattern_if_exists(self._root_dir,
                                 f'{CHECKPOINT_PREFIX}{suffix}')
def process_iterator(tag: str,
                     item_ids: Sequence[str],
                     iterator,
                     rng: types.PRNGKey,
                     state: model_utils.TrainState,
                     step: int,
                     render_fn: Any,
                     summary_writer: tensorboard.SummaryWriter,
                     save_dir: Optional[gpath.GPath],
                     datasource: datasets.DataSource):
  """Process a dataset iterator and compute metrics.

  Runs `process_batch` on every (item_id, batch) pair. On process 0 it
  averages the per-item stats returned by `process_batch` and writes them to
  `summary_writer` under `metrics-eval/<stat>/<tag>`.

  Args:
    tag: Split name. When `tag == 'test'`, a `metadata` dict is synthesized
      for each batch by sampling ids/time from `datasource`.
    item_ids: Item ids; iteration stops at the shorter of `item_ids` and
      `iterator`.
    iterator: Yields batch dicts (must contain 'origins' when tag == 'test').
    rng: PRNG key forwarded to `process_batch`.
    state: Model train state forwarded to `process_batch`.
    step: Training step; seeds the test-metadata sampling, names the output
      sub-directory, and is the summary step.
    render_fn: Forwarded to `process_batch`.
    summary_writer: Receives the averaged metrics (process 0 only).
    save_dir: If set, outputs go to `save_dir/<step:08d>/<tag>`.
    datasource: Provides the id lists and `use_*` switches.
  """
  save_dir = save_dir / f'{step:08d}' / tag if save_dir else None
  meters = collections.defaultdict(utils.ValueMeter)
  for i, (item_id, batch) in enumerate(zip(item_ids, iterator)):
    logging.info('[%s:%d/%d] Processing %s ', tag, i + 1, len(item_ids),
                 item_id)
    if tag == 'test':
      # NOTE(review): the same key is reused for every draw below, so the
      # sampled ids/time are correlated -- confirm this is intended.
      test_rng = random.PRNGKey(step)
      shape = batch['origins'][..., :1].shape
      metadata = {}
      if datasource.use_appearance_id:
        appearance_id = random.choice(
            test_rng, jnp.asarray(datasource.appearance_ids))
        logging.info('\tUsing appearance_id = %d', appearance_id)
        metadata['appearance'] = jnp.full(shape, fill_value=appearance_id,
                                          dtype=jnp.uint32)
      if datasource.use_warp_id:
        warp_id = random.choice(test_rng, jnp.asarray(datasource.warp_ids))
        logging.info('\tUsing warp_id = %d', warp_id)
        metadata['warp'] = jnp.full(shape, fill_value=warp_id,
                                    dtype=jnp.uint32)
      if datasource.use_camera_id:
        camera_id = random.choice(test_rng, jnp.asarray(datasource.camera_ids))
        logging.info('\tUsing camera_id = %d', camera_id)
        metadata['camera'] = jnp.full(shape, fill_value=camera_id,
                                      dtype=jnp.uint32)
      if datasource.use_time:
        timestamp = random.uniform(test_rng, minval=0.0, maxval=1.0)
        logging.info('\tUsing time = %f', timestamp)
        # The timestamp is a continuous value in [0, 1): an integer dtype
        # would truncate every sample to 0.
        metadata['time'] = jnp.full(
            shape, fill_value=timestamp, dtype=jnp.float32)
      batch['metadata'] = metadata

    stats = process_batch(batch=batch,
                          rng=rng,
                          state=state,
                          tag=tag,
                          item_id=item_id,
                          step=step,
                          render_fn=render_fn,
                          summary_writer=summary_writer,
                          save_dir=save_dir,
                          datasource=datasource)
    if jax.process_index() == 0:
      for k, v in stats.items():
        meters[k].update(v)

  if jax.process_index() == 0:
    for meter_name, meter in meters.items():
      summary_writer.scalar(tag=f'metrics-eval/{meter_name}/{tag}',
                            value=meter.reduce('mean'), step=step)
def meta_init(loss_fn,
              flax_module,
              params,
              hps,
              input_shape,
              output_shape,
              rng_key,
              metrics_logger=None,
              log_every=10):
  """Implements the MetaInit initializer.

  Every parameter leaf is rescaled to unit norm. Per-leaf scales are then
  meta-learned with `meta_optimize_scales`, starting from the original norms,
  and applied to the normalized parameters.

  Args:
    loss_fn: Loss function.
    flax_module: Flax nn.Module class; its `apply` is used during
      meta-optimization.
    params: The dict of model parameters.
    hps: HParam object. Required hparams are meta_learning_rate,
      meta_batch_size, meta_steps, and epsilon.
    input_shape: Must agree with batch[0].shape[1:].
    output_shape: Must agree with batch[1].shape[1:].
    rng_key: jax.PRNGKey, used to seed all randomness.
    metrics_logger: Instance of utils.MetricsLogger.
    log_every: Print meta loss every k steps.

  Returns:
    The parameter pytree produced by `scale_params` with the learned norms
    (a params pytree, not a Flax module).
  """
  is_chief = jax.process_index() == 0

  if is_chief:
    # Show the starting norms next to each variable's shape.
    logging.info('Preinitialized norms:')
    _log_shape_and_norms(params, metrics_logger, key='init_norms')

  logging.info('Running meta init')

  def _flat_norm(leaf):
    return jnp.linalg.norm(leaf.reshape(-1))

  # The original norms seed the scale search; the params themselves become unit-norm.
  initial_norms = jax.tree_map(_flat_norm, params)
  unit_params = jax.tree_map(normalize, params)

  learned_norms, _ = meta_optimize_scales(
      loss_fn,
      flax_module.apply,
      unit_params,
      initial_norms,
      hps,
      input_shape,
      output_shape,
      rng_key,
      metrics_logger=metrics_logger,
      log_every=log_every)
  rescaled_params = scale_params(unit_params, learned_norms)

  if is_chief:
    logging.info('Learned norms from meta_init:')
    _log_shape_and_norms(rescaled_params, metrics_logger,
                         key='meta_init_norms')
  return rescaled_params
def main(unused_argv):
  """Runs BLEU evaluation for the checkpoints in FLAGS.checkpoint_dir.

  Dataset/model (and, when available, initializer) names come from
  FLAGS.experiment_config_filename if given, otherwise from FLAGS.dataset /
  FLAGS.model.

  Raises:
    NotImplementedError: if run with more than one JAX process.
  """
  # Necessary to use the tfds loader.
  tf.enable_v2_behavior()
  if jax.process_count() > 1:
    # TODO(ankugarg): Add support for multihost inference.
    raise NotImplementedError(
        'BLEU eval does not support multihost inference.')

  rng = jax.random.PRNGKey(FLAGS.seed)

  mt_eval_config = json.loads(FLAGS.mt_eval_config)

  if FLAGS.experiment_config_filename:
    with tf.io.gfile.GFile(FLAGS.experiment_config_filename) as f:
      experiment_config = json.load(f)
    if jax.process_index() == 0:
      logging.info('experiment_config: %r', experiment_config)
    dataset_name = experiment_config['dataset']
    model_name = experiment_config['model']
    initializer_name = experiment_config['initializer']
  else:
    assert FLAGS.dataset and FLAGS.model
    dataset_name = FLAGS.dataset
    model_name = FLAGS.model
    # There is no experiment config to read the initializer from, and the
    # old code raised NameError here. Weights are presumably restored from
    # FLAGS.checkpoint_dir, so the initializer should not matter -- confirm.
    initializer_name = 'noop'

  if jax.process_index() == 0:
    logging.info('argv:\n%s', ' '.join(sys.argv))
    logging.info('device_count: %d', jax.device_count())
    # process_* replaces the deprecated host_count()/host_id() aliases.
    logging.info('num_hosts : %d', jax.process_count())
    logging.info('host_id : %d', jax.process_index())

  model_class = models.get_model(model_name)
  dataset_builder = datasets.get_dataset(dataset_name)
  dataset_meta_data = datasets.get_dataset_meta_data(dataset_name)

  hparam_overrides = None
  if FLAGS.hparam_overrides:
    hparam_overrides = FLAGS.hparam_overrides
    # Accept a JSON string or an already-parsed mapping. Previously a
    # non-string value was silently discarded.
    if isinstance(hparam_overrides, str):
      hparam_overrides = json.loads(hparam_overrides)

  merged_hps = hyperparameters.build_hparams(
      model_name=model_name,
      initializer_name=initializer_name,
      dataset_name=dataset_name,
      hparam_file=FLAGS.trial_hparams_filename,
      hparam_overrides=hparam_overrides)

  if jax.process_index() == 0:
    logging.info('Merged hps are: %s', json.dumps(merged_hps.to_json()))

  evaluator = bleu_evaluator.BLEUEvaluator(FLAGS.checkpoint_dir, merged_hps,
                                           rng, model_class, dataset_builder,
                                           dataset_meta_data, mt_eval_config)
  evaluator.translate_and_calculate_bleu()
def compute_interpolations(self, params, gdirs, udirs, hvex, cvex, step):
  """Compute the linear interpolation along directions of gdirs or udirs.

  For every direction, `self._full_batch_eval(params, direction, eta)` is
  evaluated at each eta on an evenly spaced grid from
  eval_config['lower_thresh'] to eval_config['upper_thresh']
  (eval_config['num_points'] points, endpoints included). Presumably this is
  the full-batch loss at params + eta * direction -- confirm.

  Args:
    params: Model parameters.
    gdirs: Gradient directions (pytrees), each normalized via
      `_tree_normalize` before use.
    udirs: Optimizer-update directions (pytrees), normalized likewise.
    hvex: Flat vectors, unflattened to the pytree structure of `gdirs[0]`
      (not normalized).
    cvex: Flat vectors, handled like `hvex`.
    step: Training step, stored in the returned row.

  Returns:
    A dict that always contains 'step'. When eval_config['compute_interps']
    is set it also contains 'step_size' (the eta grid) and one float array of
    shape (num_points,) per direction under 'loss%d', 'loss_u%d',
    'loss_hvec%d' and 'loss_cvec%d'.
  """
  row = {'step': step}
  if not self.eval_config['compute_interps']:
    return row

  num_points = self.eval_config['num_points']
  etas = np.linspace(self.eval_config['lower_thresh'],
                     self.eval_config['upper_thresh'],
                     num=num_points,
                     endpoint=True)
  # Add to the row rather than replacing it, so 'step' is kept in this path too.
  row['step_size'] = etas

  def _interpolate(direction):
    # One full-batch evaluation per grid point; a fresh float64 array each call.
    losses = np.zeros(shape=(num_points,))
    for j, eta in enumerate(etas):
      losses[j] = self._full_batch_eval(params, direction, eta)
    return losses

  for i, g_dir in enumerate(gdirs):
    row['loss%d' % i] = _interpolate(_tree_normalize(g_dir))
  if jax.process_index() == 0:
    logging.info('Loss interpolation along gradients finished.')

  for i, u_dir in enumerate(udirs):
    row['loss_u%d' % i] = _interpolate(_tree_normalize(u_dir))
  if jax.process_index() == 0:
    logging.info('Loss interpolation along optimizer directions finished.')

  # Only needed for the flat vectors; skipping it avoids indexing gdirs[0]
  # when there is nothing to unflatten.
  if len(hvex) or len(cvex):
    _, unflatten = ravel_pytree(gdirs[0])
    for i, h_vec in enumerate(hvex):
      row['loss_hvec%d' % i] = _interpolate(unflatten(h_vec))
    for i, c_vec in enumerate(cvex):
      row['loss_cvec%d' % i] = _interpolate(unflatten(c_vec))

  if jax.process_index() == 0:
    logging.info('Loss interpolations finished. Statistics captured:')
    logging.info(row.keys())
  return row
def main(unused_argv):
  """Runs the Lanczos / Hessian evaluation over checkpoints in FLAGS.checkpoint_dir.

  The eval config comes from FLAGS.hessian_eval_config (JSON) or
  hessian_eval.DEFAULT_EVAL_CONFIG. Hyperparameters are loaded from
  FLAGS.trial_hparams_filename and optionally updated from
  FLAGS.hparam_overrides.
  """
  # Necessary to use the tfds imagenet loader.
  tf.enable_v2_behavior()

  rng = jax.random.PRNGKey(FLAGS.seed)

  if FLAGS.hessian_eval_config:
    hessian_eval_config = json.loads(FLAGS.hessian_eval_config)
  else:
    hessian_eval_config = hessian_eval.DEFAULT_EVAL_CONFIG

  if FLAGS.experiment_config_filename:
    with tf.io.gfile.GFile(FLAGS.experiment_config_filename, 'r') as f:
      experiment_config = json.load(f)
    if jax.process_index() == 0:
      logging.info('experiment_config: %r', experiment_config)
    dataset_name = experiment_config['dataset']
    model_name = experiment_config['model']
  else:
    assert FLAGS.dataset and FLAGS.model
    dataset_name = FLAGS.dataset
    model_name = FLAGS.model

  if jax.process_index() == 0:
    logging.info('argv:\n%s', ' '.join(sys.argv))
    logging.info('device_count: %d', jax.device_count())
    logging.info('num_hosts : %d', jax.process_count())
    logging.info('host_id : %d', jax.process_index())

  model = models.get_model(model_name)
  dataset_builder = datasets.get_dataset(dataset_name)
  dataset_meta_data = datasets.get_dataset_meta_data(dataset_name)

  with tf.io.gfile.GFile(FLAGS.trial_hparams_filename, 'r') as f:
    hps = config_dict.ConfigDict(json.load(f))

  if FLAGS.hparam_overrides:
    hparam_overrides = FLAGS.hparam_overrides
    # Accept a JSON string or an already-parsed mapping. Previously a
    # non-string value left `hparam_overrides` unbound (NameError).
    if isinstance(hparam_overrides, str):
      hparam_overrides = json.loads(hparam_overrides)
    hps.update_from_flattened_dict(hparam_overrides)

  run_lanczos.eval_checkpoints(
      FLAGS.checkpoint_dir,
      hps,
      rng,
      FLAGS.eval_num_batches,
      model,
      dataset_builder,
      dataset_meta_data,
      hessian_eval_config,
      FLAGS.min_global_step,
      FLAGS.max_global_step)
def evaluate(
    model_name: str,
    job_log_dir: Optional[str],
    multi_host_checkpointing: Optional[bool],
    maybe_use_persistence_checkpointing: bool,
) -> None:
  """Evaluates a registered model over all of its evaluation datasets.

  Models whose task has a device mesh use the SPMD evaluation path. All other
  models use the pmap path.

  Args:
    model_name: Registry name of the model to evaluate.
    job_log_dir: Directory holding the job's logs.
    multi_host_checkpointing: Whether multi-host checkpointing is in use.
    maybe_use_persistence_checkpointing: Try persistence-based checkpointing
      when it is suitable.
  """
  model_config = model_utils.get_model(model_name)()
  task_p = model_config.task()
  eval_input_p = [
      input_p for input_p in model_config.datasets() if not input_p.is_training
  ]

  # Tell each eval input how many infeed hosts exist and which one this is.
  for input_p in eval_input_p:
    input_p.num_infeed_hosts = jax.process_count()
    input_p.infeed_host_index = jax.process_index()

  if task_p.model.device_mesh is None:
    evaluate_pmap_model(task_p, eval_input_p, job_log_dir)
    return

  checkpoint_type = checkpoints.retrieve_checkpoint_type(
      multi_host_checkpointing, maybe_use_persistence_checkpointing, task_p)
  evaluate_spmd_model(task_p, eval_input_p, job_log_dir, checkpoint_type)
def _thread_func(self, temp_checkpoint_dir, final_checkpoint_dir):
  """Finishes an async checkpoint commit (barrier-based variant).

  Steps:
    1. Block until every future in `self._commit_futures` has resolved.
    2. Wait at a distributed barrier keyed by `self._final_ckpt_dir` (with
       timeout `self._timeout_in_ms`) until all processes arrive.
    3. On process 0 only, rename `temp_checkpoint_dir` to
       `final_checkpoint_dir` and set `_CHECKPOINT_SUCCESS` under
       `_get_key(self._final_ckpt_dir)` in the client's key-value store.
  Any exception is stored in `self._exception` rather than raised
  (presumably so the owning thread can re-raise it later -- confirm).

  Args:
    temp_checkpoint_dir: Directory the checkpoint was written to.
    final_checkpoint_dir: Name the directory is renamed to on success.
  """
  try:
    # Each entry of _commit_futures is itself an iterable of futures.
    for future in self._commit_futures:
      for f in future:
        f.result()
    current_process = jax.process_index()
    logging.info(
        'Commit to storage layer has completed by process: %s',
        current_process)
    # All processes will wait at the barrier. When all processes are at the
    # barrier, the barrier will be satisfied. If not, then it will timeout.
    self._client.wait_at_barrier(self._final_ckpt_dir, self._timeout_in_ms)
    logging.info('Finished waiting at barrier for process %s',
                 current_process)
    # Only process 0 finalizes. It runs after every process has passed the
    # barrier, so every commit is done.
    if current_process == 0:
      logging.info('Renaming %s to %s', temp_checkpoint_dir,
                   final_checkpoint_dir)
      epath.Path(temp_checkpoint_dir).rename(final_checkpoint_dir)
      logging.info('Finished saving GDA checkpoint to `%s`.',
                   final_checkpoint_dir)
      # Record success only after the rename has completed.
      self._client.key_value_set(_get_key(self._final_ckpt_dir),
                                 _CHECKPOINT_SUCCESS)
  except Exception as e:  # pylint: disable=broad-except
    self._exception = e
def main(argv):
  """Entry point: configures JAX/TF and dispatches on `_CONFIG.value.mode`.

  Raises:
    ValueError: if the configured mode is neither PRETRAINING nor
      CLASSIFICATION.
  """
  del argv

  # Keep TensorFlow away from the GPUs so it cannot reserve memory that JAX
  # needs.
  tf.config.experimental.set_visible_devices([], "GPU")

  logging.info("JAX host: %d / %d", jax.process_index(), jax.process_count())
  logging.info("JAX devices: %r", jax.devices())

  # Label this task with its JAX host: task 0 need not be host 0.
  status = f"host_id: {jax.process_index()}, host_count: {jax.process_count()}"
  platform.work_unit().set_task_status(status)
  platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                       _WORKDIR.value, "workdir")

  train_mode = _CONFIG.value.mode
  if train_mode == TrainingMode.PRETRAINING:
    selected_lib = run_pretraining
  elif train_mode == TrainingMode.CLASSIFICATION:
    selected_lib = run_classifier
  else:
    raise ValueError("Unknown mode: %s" % train_mode)

  selected_lib.train_and_evaluate(_CONFIG.value, _WORKDIR.value,
                                  _VOCAB_FILEPATH.value)
def main(argv):
  """Entry point: configures JAX/TF, then either trains or evaluates.

  Runs `eval_lib.evaluate` when FLAGS.is_train is false. Otherwise it labels
  the work unit and runs `train_lib.train_and_evaluate`.
  """
  del argv

  # TensorFlow must not see the GPUs, or it may grab memory JAX needs.
  tf.config.experimental.set_visible_devices([], "GPU")

  if FLAGS.jax_backend_target:
    logging.info("Using JAX backend target %s", FLAGS.jax_backend_target)
    if FLAGS.jax_xla_backend is None:
      xla_backend_name = "None"
    else:
      xla_backend_name = FLAGS.jax_xla_backend
    logging.info("Using JAX XLA backend %s", xla_backend_name)

  logging.info("JAX process: %d / %d", jax.process_index(), jax.process_count())
  logging.info("JAX devices: %r", jax.devices())

  if not FLAGS.is_train:
    eval_lib.evaluate(FLAGS.ml_config, FLAGS.workdir)
    return

  # Label the task with its JAX process; on some platforms task 0 is not
  # process 0.
  platform.work_unit().set_task_status(
      f"process_index: {jax.process_index()}, "
      f"process_count: {jax.process_count()}")
  platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                       FLAGS.workdir, "workdir")
  train_lib.train_and_evaluate(FLAGS.ml_config, FLAGS.workdir)
def restore_checkpoint(work_dir):
  """Restores this host's checkpoint under `work_dir` and shards it on devices.

  Each host reads its own sub-directory
  (`multihost_utils.get_host_dir(work_dir, host_id=jax.process_index())`).
  Array leaves are regrouped into `jax.local_device_count()` pieces along a
  new leading axis and placed with `jax.device_put_sharded`. Non-array
  leaves are returned unchanged.

  Args:
    work_dir: Checkpoint root directory shared by all hosts.

  Returns:
    The restored state pytree (presumably an ALSState -- confirm with caller).
  """
  print(f"Attempting restore_checkpoint from dir: {work_dir}.")
  host_dir = multihost_utils.get_host_dir(work_dir, host_id=jax.process_index())
  # Restore onto host memory first; device placement happens per leaf below.
  host_state = checkpoints.restore_checkpoint(host_dir, target=None)

  def _place_on_local_devices(leaf):
    if not isinstance(leaf, (jnp.ndarray, np.ndarray)):
      return leaf
    # Reads leaf.shape[2], so array leaves need at least 3 dimensions.
    grouped = np.reshape(leaf, [jax.local_device_count(), -1, leaf.shape[2]])
    # One array per local device, with the size-1 split axis squeezed away.
    per_device = [
        np.squeeze(piece, axis=0)
        for piece in np.split(grouped, grouped.shape[0], axis=0)
    ]
    return jax.device_put_sharded(per_device, jax.local_devices())

  return jax.tree_map(_place_on_local_devices, host_state)
def train_step(train_state, batch, label_smoothing):
  """Runs one data-parallel training step; must be traced under `jax.pmap`.

  The collectives use the mapped axis name 'batch'. Gradients of the
  per-device mean cross-entropy are averaged across devices with `pmean`,
  which yields the global-batch mean gradient. They are then applied to
  `train_state`. Loss and top-1 accuracy are also averaged across devices and
  logged to wandb from the host through `jax.debug.callback`, because Python
  side effects such as `float(tracer)` or `wandb.log` cannot run on traced
  values.

  Args:
    train_state: A state exposing `apply_fn`, `params`, `step` and
      `apply_gradients` (presumably a flax TrainState -- confirm).
    batch: Dict with 'images' laid out as (H, W, C, N) (it is rearranged to
      N H W C), 'labels', and optionally 'mix_labels' plus a per-example
      'ratio' for label mixing.
    label_smoothing: Smoothing factor passed to `optax.smooth_labels`.

  Returns:
    The updated train state.
  """

  def loss_fn(params):
    images = rearrange(batch['images'], 'H W C N -> N H W C')
    images = images.astype(jnp.bfloat16)
    logits = train_state.apply_fn(params, images, is_training=True)
    y = one_hot(batch['labels'])
    if 'mix_labels' in batch:
      y1 = one_hot(batch['mix_labels'])
      y = batch['ratio'][:, None] * y + (1. - batch['ratio'][:, None]) * y1
    y = optax.smooth_labels(y, label_smoothing)
    logits = logits.astype(jnp.float32)
    # Per-device mean. No extra 1/device_count scaling: the pmean below
    # already averages across devices.
    loss = jnp.mean(optax.softmax_cross_entropy(logits, y))
    return loss, logits

  # jax.value_and_grad has no `axis_name` argument; passing one raises TypeError.
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(train_state.params)
  grads = jax.lax.pmean(grads, axis_name='batch')
  loss = jax.lax.pmean(loss, axis_name='batch')

  top_k_acc = utils.topk_correct(logits, batch['labels'], prefix='train_')
  top_k_acc = jax.tree_map(jnp.mean, top_k_acc)
  top_1_acc = jax.lax.pmean(top_k_acc['train_top_1_acc'], axis_name='batch')

  new_train_state = train_state.apply_gradients(grads=grads)

  def _log_to_wandb(device_index, step, loss_value, top_1):
    # Under pmap the callback fires once per device; log from device 0 only.
    if int(device_index) == 0:
      wandb.log(
          {
              'train/loss': float(loss_value),
              'train/top-1-acc': float(top_1)
          }, int(step))

  if jax.process_index() == 0:
    # Assumes mapped-axis index 0 lives on process 0 (standard pmap device
    # ordering) -- confirm for custom device layouts.
    jax.debug.callback(_log_to_wandb, jax.lax.axis_index('batch'),
                       train_state.step, loss, top_1_acc)
  return new_train_state
def main(argv):
  """Entry point: configures JAX/TF and trains or predicts per FLAGS.mode.

  Raises:
    app.UsageError: if positional command-line arguments are given.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Hide the GPUs from TensorFlow so it does not reserve memory JAX needs.
  tf.config.experimental.set_visible_devices([], 'GPU')

  logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count())
  logging.info('JAX local devices: %r', jax.local_devices())

  # Label the task with its JAX process; on some platforms task 0 is not
  # process 0.
  platform.work_unit().set_task_status(
      f'process_index: {jax.process_index()}, '
      f'process_count: {jax.process_count()}')
  platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                       FLAGS.workdir, 'workdir')

  if FLAGS.mode != 'train':
    predict.predict_and_evaluate(FLAGS.config, FLAGS.workdir, FLAGS.ckpt_path)
    return
  train.train_and_evaluate(FLAGS.config, FLAGS.workdir)
def _length_filename(loss_filename):
  """Returns the lengths-CSV path that sits next to `loss_filename`.

  'dir/losses.csv' -> 'dir/losses_length.csv'. Only a trailing '.csv' is
  rewritten. Names without that suffix get '_length' appended, so the
  lengths file can never collide with the loss file.
  """
  suffix = '.csv'
  if loss_filename.endswith(suffix):
    return loss_filename[:-len(suffix)] + '_length' + suffix
  return loss_filename + '_length'


def write_per_example_losses(*, p_eval_step, target, eval_ds, num_eval_steps,
                             loss_filename):
  """Evaluates `target` on `eval_ds` and writes per-example losses to CSV.

  Runs `p_eval_step` on at most `num_eval_steps` sharded batches. On
  process 0 it then writes:
    * `loss_filename`: one CSV row per entry of each host-gathered loss
      batch.
    * the matching '_length' file (see `_length_filename`): one row of ints
      per host-gathered length batch.

  Args:
    p_eval_step: Callable `(target, sharded_batch) -> (loss, length)`.
    target: Passed as the first argument to `p_eval_step`.
    eval_ds: Iterable of batches whose leaves expose `_numpy()` (TF tensors).
    num_eval_steps: Maximum number of batches to evaluate.
    loss_filename: Output path for the loss CSV.

  Returns:
    None. Results are written to files only.
  """
  logging.info('Gathering evaluation metrics.')
  losses = []
  lengths = []
  eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types
  for _, eval_batch in zip(range(num_eval_steps), eval_iter):
    eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
    eval_batch = common_utils.shard(eval_batch)
    loss, length = p_eval_step(target, eval_batch)
    losses.append(common.tohost(loss))
    lengths.append(common.tohost(length))

  if jax.process_index() != 0:
    return

  with tf.io.gfile.GFile(loss_filename, 'w') as f:
    writer = csv.writer(f)
    for pos_losses in losses:
      for val in pos_losses:
        writer.writerow(list(val))
  with tf.io.gfile.GFile(_length_filename(loss_filename), 'w') as f:
    writer = csv.writer(f)
    for val in lengths:
      writer.writerow([int(v) for v in list(val)])
def save_samples_to_json(features: List[Dict[str, Any]],
                         config: ml_collections.ConfigDict, step: int):
  """Appends `features` as JSON lines to a per-step, per-process file.

  A write happens only when both of these hold:
    * `config['save_samples_every_steps']` is truthy and divides `step`;
    * this process is selected by `config['save_samples_process_ids']`
      (default 0). That value may be a list of process indices, -1 for all
      processes, or a single index.
  The file is `<model_dir>/samples/step_<step>.process_<index>.json`, opened
  in append mode with one JSON document per batch per line.
  """
  every_n = config.get('save_samples_every_steps')
  step_selected = every_n and (step % every_n == 0)

  process_index = jax.process_index()
  wanted = config.get('save_samples_process_ids', 0)
  if isinstance(wanted, list):
    process_selected = process_index in wanted
  else:
    process_selected = wanted == -1 or process_index == wanted

  if not (step_selected and process_selected):
    return

  logging.info('Saving samples at step %d, process %d', step, process_index)
  file_name = 'step_%d.process_%d.json' % (step, process_index)
  path = os.path.join(config.model_dir, 'samples', file_name)
  tf.io.gfile.makedirs(os.path.dirname(path))
  with tf.io.gfile.GFile(path, 'ab') as fp:
    for batch in features:
      json.dump(batch, fp)
      fp.write('\n')
def setup_jax(globally_use_hardware_rng: bool, jax_use_gda: bool,
              jax_backend_target: Optional[str],
              jax_xla_backend: Optional[str]) -> None:
  """Configures JAX for this job and logs the process/device topology.

  Hides GPUs from TensorFlow, optionally enables the hardware (RBG) PRNG key
  globally, turns on xmap's SPMD and manual-partitioning lowering, and logs
  backend, process and device information.
  """
  # TensorFlow must not see any GPU, otherwise it may reserve memory that JAX
  # needs.
  tf.config.experimental.set_visible_devices([], 'GPU')

  if globally_use_hardware_rng:
    py_utils.set_globally_use_rbg_prng_key()

  # xmap is only used with SPMD. The manual lowering avoids vectorization.
  for xmap_option in ('experimental_xmap_spmd_lowering',
                      'experimental_xmap_spmd_lowering_manual'):
    jax.config.update(xmap_option, True)

  if jax_use_gda:
    logging.info('Using JAX GSDA for pjit and checkpointing')

  if jax_backend_target:
    logging.info('Using JAX backend target %s', jax_backend_target)
    xla_backend_name = jax_xla_backend if jax_xla_backend is not None else 'None'
    logging.info('Using JAX XLA backend %s', xla_backend_name)

  logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count())
  logging.info('JAX devices: %r', jax.devices())
  logging.info('jax.device_count(): %d', jax.device_count())
  logging.info('jax.local_device_count(): %d', jax.local_device_count())
  logging.info('jax.process_count(): %d', jax.process_count())
def prepare_batches_gen(dataset, eval_config):
  """Builds a re-iterable source of a fixed set of training batches.

  The first `eval_config['num_batches']` batches of
  `dataset.train_iterator_fn()` are materialized once. Every call of the
  returned function replays that same "epoch", so it can be iterated
  repeatedly:

    for batch, rng in batches_gen():
      ...

  Args:
    dataset: An init2winit.dataset_lib.Dataset object (only
      `train_iterator_fn` is used).
    eval_config: Dict with 'num_batches' (how many batches to keep) and
      'rng_key' (seed for the per-batch RNGs).

  Returns:
    A zero-argument generator function. Each call yields
    `(data_utils.shard(batch), replicated_rng)` for the cached batches in
    order. The RNG for batch k is the base key folded with this process
    index and then with k.
  """
  cached_batches = list(
      itertools.islice(dataset.train_iterator_fn(), 0,
                       eval_config['num_batches']))
  base_rng = jax.random.fold_in(
      jax.random.PRNGKey(eval_config['rng_key']), jax.process_index())

  def training_batches_gen():
    for batch_index, raw_batch in enumerate(cached_batches):
      sharded_batch = data_utils.shard(raw_batch)
      batch_rng = jax_utils.replicate(jax.random.fold_in(base_rng, batch_index))
      yield sharded_batch, batch_rng

  return training_batches_gen
def main(argv):
  """Entry point: configures logging/JAX/TF, then runs the selected trainer.

  Raises:
    app.UsageError: on extra positional arguments or an unknown
      `FLAGS.config.trainer`.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  workdir = _WORKDIR.value
  utils.add_gfile_logger(workdir)

  # Keep TensorFlow off the GPUs so it cannot reserve memory JAX needs.
  tf.config.experimental.set_visible_devices([], 'GPU')
  jax.config.update('jax_log_compiles', True)

  logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count())
  logging.info('JAX local devices: %r', jax.local_devices())
  if FLAGS.jax_xla_backend is None:
    xla_backend_name = 'None'
  else:
    xla_backend_name = FLAGS.jax_xla_backend
  logging.info('Using JAX XLA backend %s', xla_backend_name)
  logging.info('Config: %s', FLAGS.config)

  # Label the task with its JAX process; on some platforms task 0 is not
  # process 0.
  platform.work_unit().set_task_status(
      f'process_index: {jax.process_index()}, '
      f'process_count: {jax.process_count()}')
  platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                       workdir, 'workdir')

  trainer = FLAGS.config.trainer
  if trainer == 'train':
    train.train_and_evaluate(FLAGS.config, workdir)
  elif trainer == 'inference_time':
    inference_time.inference_time(FLAGS.config, workdir)
  else:
    raise app.UsageError(f'Unknown trainer: {trainer}')
def _thread_func(self, temp_checkpoint_dir, final_checkpoint_dir):
  """Finishes an async checkpoint commit (lockfile-based variant).

  Steps:
    1. Block until every future in `self._commit_futures` has resolved.
    2. Poll once a minute until this process's lockfile
       `<temp_checkpoint_dir>/lockfiles/lockfile_<process_index>` exists,
       then delete it to signal that this process's commit is done. (The
       lockfiles are presumably created elsewhere before the commit
       starts -- confirm.)
    3. On process 0 only, poll once a minute until no process lockfile is
       left. Then remove the lockfiles directory and rename
       `temp_checkpoint_dir` to `final_checkpoint_dir`.
  Each polling phase raises RuntimeError after `self._timeout_secs`. Any
  exception is stored in `self._exception` rather than propagated.

  Args:
    temp_checkpoint_dir: Directory the checkpoint was written to.
    final_checkpoint_dir: Name the directory is renamed to on success.
  """
  try:
    # Each entry of _commit_futures is itself an iterable of futures.
    for future in self._commit_futures:
      for f in future:
        f.result()
    logging.info('Commit to storage layer has completed.')

    current_process = jax.process_index()
    lockfiles_dir = os.path.join(temp_checkpoint_dir, 'lockfiles')
    # One lockfile per JAX process; process 0 waits for all of them to vanish.
    all_lockfile_paths = [
        os.path.join(lockfiles_dir, f'lockfile_{p}')
        for p in range(jax.process_count())
    ]
    current_process_lockfile = os.path.join(
        lockfiles_dir, f'lockfile_{current_process}')

    # Wait (bounded by the timeout) for this process's lockfile to exist.
    with _RetryWithTimeout(self._timeout_secs) as t:
      while not tf.io.gfile.exists(current_process_lockfile):
        if t.timed_out:
          raise RuntimeError(
              'Terminating after waiting for '
              f'{self._timeout_secs} secs for lockfile to appear'
          )
        logging.info(
            'Waiting for current process %s lockfile to appear.',
            current_process)
        time.sleep(60)

    # Deleting the lockfile marks this process's commit as finished.
    tf.io.gfile.remove(current_process_lockfile)
    logging.info('Lockfile removed for process %s', current_process)

    # Process 0 finalizes only once every process has removed its lockfile,
    # i.e. after all commits have finished.
    if current_process == 0:
      with _RetryWithTimeout(self._timeout_secs) as t:
        while True:
          if t.timed_out:
            raise RuntimeError(
                'Terminating after waiting for '
                f'{self._timeout_secs} secs for '
                'finishing the serialization.')
          # Mark as done when no lockfiles exist.
          if no_lockfiles_exists(all_lockfile_paths):
            tf.io.gfile.rmtree(lockfiles_dir)
            logging.info('Lockfiles directory removed.')

            logging.info('Renaming %s to %s', temp_checkpoint_dir,
                         final_checkpoint_dir)
            tf.io.gfile.rename(temp_checkpoint_dir, final_checkpoint_dir)
            logging.info(
                'Finished saving GDA checkpoint to `%s`.',
                final_checkpoint_dir)
            break
          else:
            logging.info('Thread sleeping for 60 seconds.')
            time.sleep(60)

  except Exception as e:  # pylint: disable=broad-except
    self._exception = e
def main(argv):
  """Entry point: configures JAX/TF and dispatches on `FLAGS.config.mode`.

  Raises:
    ValueError: if the mode is neither PRETRAINING nor CLASSIFICATION.
  """
  del argv

  # TensorFlow must not see the GPUs, or it may reserve memory JAX needs.
  tf.config.experimental.set_visible_devices([], "GPU")

  logging.info("JAX process: %d / %d", jax.process_index(), jax.process_count())
  logging.info("JAX devices: %r", jax.devices())

  # Label the work unit so each task can be matched to its JAX process.
  task_status = (f"process_index: {jax.process_index()}, "
                 f"process_count: {jax.process_count()}")
  platform.work_unit().set_task_status(task_status)
  platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                       FLAGS.workdir, "workdir")

  train_mode = FLAGS.config.mode
  if train_mode == TrainingMode.PRETRAINING:
    chosen_lib = run_pretraining
  elif train_mode == TrainingMode.CLASSIFICATION:
    chosen_lib = run_classifier
  else:
    raise ValueError("Unknown training mode: %s" % train_mode)

  chosen_lib.train_and_evaluate(FLAGS.config, FLAGS.workdir,
                                FLAGS.vocab_filepath)
def get_data_iterator(config):
  """Returns an iterator over the memory-generation dataset described by `config`.

  The dataset is built with `data_utils.load_dataset` using
  `MemoryGenerationTask`'s decode/preprocess/collate functions,
  `is_training=False` and `seed=0`. This process's local device count and
  its JAX process index/count are passed through as the host information.
  """
  task = memory_generation_task.MemoryGenerationTask

  host_info = dict(
      local_device_count=jax.local_device_count(),
      host_count=jax.process_count(),
      host_id=jax.process_index(),
  )

  logging.info('Loading dataset.')
  decode_fn = data_utils.make_decode_fn(
      name_to_features=task.get_name_to_features(config),
      samples_per_example=config.samples_per_example,
  )
  dataset = data_utils.load_dataset(
      patterns=config.data_patterns,
      decode_fn=decode_fn,
      preprocess_fn=task.make_preprocess_fn(config),
      collater_fn=task.make_collater_fn(config),
      is_training=False,
      per_device_batch_size=config.per_device_batch_size,
      seed=0,
      **host_info,
  )
  return iter(dataset)
def _mkdir_path(name: str, tmp_dir: str) -> str:
  """Returns `tmp_dir/name` without a trailing '/', creating it on process 0.

  Only process 0 creates the directory. Every process gets the same path
  back.
  """
  # Tensorstore rejects directory names that end with '/'.
  path = os.path.join(tmp_dir, name).rstrip('/')
  if jax.process_index() != 0:
    return path
  # Non-recursive mkdir: parent directories are deliberately not created.
  tf.io.gfile.mkdir(path)
  return path
def _f(path_tuple, v):
  """Applies `f(v, arg)` using the first regex rule that matches the variable path.

  The variable name is `path_tuple` joined with '/'. Rules come from the
  enclosing `regex_rules` (an iterable of (pattern, arg) pairs) and are
  tested in order with `re.match`. If none matches, `v` is returned
  unchanged. The update is logged on process 0 only.
  """
  var_name = "/".join(path_tuple)
  first_hit = next(
      ((rule_pattern, rule_arg)
       for rule_pattern, rule_arg in regex_rules
       if re.match(rule_pattern, var_name)),
      None)
  if first_hit is None:
    return v
  pattern, arg = first_hit
  if jax.process_index() == 0:
    logging.info("Updating %s with %s due to `%s`", var_name, arg, pattern)
  return f(v, arg)
def main(unused_argv):
  """Logs `FLAGS.num_batches` training batches, then validation batches.

  Builds the dataset named by FLAGS.dataset with batch size FLAGS.batch_size
  (hparams from FLAGS.model, 'noop' initializer) and logs the batches it
  produces.

  Raises:
    ValueError: if FLAGS.batch_size is missing or non-positive, or
      FLAGS.dataset is missing.
  """
  if jax.process_index() == 0:
    logging.info('argv:\n%s', ' '.join(sys.argv))
    logging.info('device_count: %d', jax.device_count())
    logging.info('num_hosts : %d', jax.process_count())
    logging.info('host_id : %d', jax.process_index())

  batch_size = FLAGS.batch_size
  if batch_size is None or batch_size <= 0:
    raise ValueError("""FLAGS.batch_size value is invalid, expected a positive non-zero integer.""")
  if FLAGS.dataset is None:
    raise ValueError("""FLAGS.dataset value is invalid, expected a non-empty string describing dataset name.""")

  hps = hyperparameters.build_hparams(
      model_name=FLAGS.model,
      initializer_name='noop',
      dataset_name=FLAGS.dataset,
      hparam_file=None,
      hparam_overrides={'batch_size': batch_size})

  _, data_rng = jax.random.split(jax.random.PRNGKey(0))
  dataset = datasets.get_dataset(FLAGS.dataset)(data_rng, batch_size,
                                                batch_size, hps)

  train_iter = dataset.train_iterator_fn()
  for batch_num in range(FLAGS.num_batches):
    train_batch = next(train_iter)
    logging.info('train batch_num = %d, batch = %r', batch_num, train_batch)

  for valid_batch in dataset.valid_epoch(FLAGS.num_batches):
    logging.info('validation batch = %r', valid_batch)
def main(argv: Sequence[str]) -> None:
  """Entry point: loads the JSON config, applies flag overrides, and trains.

  Raises:
    app.UsageError: if positional command-line arguments are given.
  """
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  # Hide GPUs from TensorFlow so it cannot reserve memory that JAX needs.
  tf.config.experimental.set_visible_devices([], "GPU")

  logging.info("JAX process: %d / %d", jax.process_index(), jax.process_count())
  logging.info("JAX local devices: %r", jax.local_devices())

  # Label the task with its JAX process; on some platforms task 0 is not
  # process 0.
  platform.work_unit().set_task_status(f"process_index: {jax.process_index()}, "
                                       f"process_count: {jax.process_count()}")
  platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                       FLAGS.model_dir, "model_dir")
  tf.io.gfile.makedirs(FLAGS.model_dir)

  # The config comes from a JSON file when one is given, else from the inline
  # JSON flag.
  if FLAGS.config_file:
    with tf.io.gfile.GFile(FLAGS.config_file, "r") as reader:
      config = json.load(reader)
  else:
    config = json.loads(FLAGS.config)

  # Process 0 writes config.json (overwriting any existing file). This
  # happens before the flag overrides below, so the file holds the
  # un-overridden config.
  if jax.process_index() == 0:
    config_file = os.path.join(FLAGS.model_dir, "config.json")
    with tf.io.gfile.GFile(config_file, "w") as writer:
      writer.write(json.dumps(config, indent=4))

  config["model_dir"] = FLAGS.model_dir
  flag_overrides = {
      "learning_rate": FLAGS.learning_rate,
      "per_device_batch_size": FLAGS.per_device_batch_size,
      "num_train_steps": FLAGS.num_train_steps,
      "warmup_steps": FLAGS.warmup_steps,
  }
  for key, value in flag_overrides.items():
    if value is not None:
      config[key] = value

  train(ml_collections.ConfigDict(config))
def get_summary_writer(summary_dir: str) -> SummaryWriter:
  """Yields a TF summary writer and closes it when the caller is done.

  Process 0 gets a real file writer for `summary_dir`. Every other process
  gets `tf_summary.create_noop_writer()`, which accepts the same calls but
  writes nothing. The writer is closed in a `finally`, so it is also closed
  when the caller raises. (Written as a generator, so presumably it is
  wrapped with `contextlib.contextmanager` at its definition -- confirm.)
  """
  is_primary = jax.process_index() == 0
  if is_primary:
    logging.info('Opening SummaryWriter `%s`...', summary_dir)
    writer = tf_summary.create_file_writer(summary_dir)
  else:
    logging.info('Opening a mock-like SummaryWriter.')
    writer = tf_summary.create_noop_writer()

  try:
    yield writer
  finally:
    writer.close()
    if is_primary:
      logging.info('Closed SummaryWriter `%s`.', summary_dir)
    else:
      logging.info('Closed a mock-like SummaryWriter.')
def _train_sentencepiece(dataset: tf.data.Dataset,
                         *,
                         vocab_size: int,
                         maxchars: int = int(1e7),
                         model_path: str,
                         model_type: str = 'unigram',
                         character_coverage: float = 1.0,
                         data_keys=('inputs', 'targets')):
  """Trains a SentencePiece model on text drawn from `dataset`.

  Every process trains a model locally. Process 0 then publishes it to the
  target path, and the other processes poll (once per second) until that path
  exists.

  Args:
    dataset: tf.data.Dataset to read text from.
    vocab_size: Number of vocabulary tokens to train.
    maxchars: Number of characters to dump for SentencePiece training.
    model_path: Destination of the trained `.model` file. 'gs://' paths are
      used as-is; local paths are user-expanded and made absolute.
    model_type: SentencePiece model type.
    character_coverage: Fraction of characters covered by the model. 0.9995
      is a good default for languages with rich character sets such as
      Japanese or Chinese, and 1.0 for languages with small character sets.
    data_keys: Dataset keys whose text is used for training.

  Returns:
    The (absolute or gs://) path of the trained SentencePiece model.
  """
  if model_path.startswith('gs://'):
    abs_model_path = model_path
  else:
    abs_model_path = os.path.abspath(os.path.expanduser(model_path))

  fname, _ = _dump_chars_to_textfile(dataset, maxchars=maxchars,
                                     data_keys=data_keys)
  # Only a unique '/tmp/sp_tmp*' name is needed; the empty file is closed
  # immediately and its name serves as the trainer's output prefix.
  with tempfile.NamedTemporaryFile(delete=False,
                                   prefix='/tmp/sp_tmp') as prefix_file:
    model_prefix = prefix_file.name

  trainer_flags = {
      'input': fname,
      'vocab_size': vocab_size,
      'character_coverage': character_coverage,
      'model_prefix': model_prefix,
      'model_type': model_type,
  }
  SentencePieceTrainer.Train(
      ' '.join(f'--{flag}={value}' for flag, value in trainer_flags.items()))

  if jax.process_index() != 0:
    # Wait until process 0 has published the model.
    while not tf.io.gfile.exists(abs_model_path):
      time.sleep(1)
    time.sleep(1)
    return abs_model_path

  trained_model = model_prefix + '.model'
  # Copy under an intermediate name, then rename to the target, so readers
  # never see a partially written file at the final path.
  staging_path = abs_model_path + '.rntmp'
  tf.io.gfile.copy(trained_model, staging_path, overwrite=True)
  tf.io.gfile.rename(staging_path, abs_model_path, overwrite=True)
  logging.info('copied %s to %s', trained_model, abs_model_path)
  return abs_model_path
def train_sentencepiece(dataset,
                        vocab_size,
                        maxchars=1e9,
                        character_coverage=1.0,
                        model_path="model",
                        model_type="unigram",
                        data_keys=("inputs", "targets")):
  """Train SentencePiece tokenizer from subset of tf dataset.

  Every process trains locally. Process 0 publishes `<model_path>.vocab` and
  `model_path` (in that order); the other processes poll once per second
  until `model_path` exists.

  Args:
    dataset: tf.dataset
    vocab_size: int: size of vocab tokens to train.
    maxchars: int: number of characters to use for sentencepiece training.
    character_coverage: amount of characters covered by the model, good
      defaults are 0.9995 for languages with rich character set like Japanese
      or Chinese and 1.0 for other languages with small character set.
    model_path: str: path of model file to save vocab model to.
    model_type: str: type of sentencepiece vocab to train.
    data_keys: Tuple[str]: keys of dataset to use for training.

  Returns:
    path to the trained sentencepiece vocabulary model.
  """
  fname, _ = dump_chars_to_textfile(dataset, maxchars=maxchars,
                                    data_keys=data_keys)
  with tempfile.NamedTemporaryFile(delete=False,
                                   prefix="/tmp/sp_tmp") as model_fp:
    pass  # we just want a prefix'd tmp-filename
  argstr = " ".join([
      f"--input={fname}", f"--vocab_size={vocab_size}",
      f"--character_coverage={character_coverage}",
      f"--model_prefix={model_fp.name}", f"--model_type={model_type}"
  ])
  SentencePieceTrainer.Train(argstr)
  if jax.process_index() == 0:
    # Use an intermediate filename that is renamed to the target name to
    # address create and fill delays.
    copy_rename_path = model_path + ".rntmp"
    # Publish the .vocab file FIRST. Other processes stop waiting as soon as
    # `model_path` exists, so `model_path` must be the last file to appear.
    # Otherwise they could read a missing .vocab.
    tf.io.gfile.copy(model_fp.name + ".vocab", copy_rename_path + ".vocab",
                     overwrite=True)
    tf.io.gfile.rename(copy_rename_path + ".vocab", model_path + ".vocab",
                       overwrite=True)
    tf.io.gfile.copy(model_fp.name + ".model", copy_rename_path,
                     overwrite=True)
    tf.io.gfile.rename(copy_rename_path, model_path, overwrite=True)
    logging.info("copied %s to %s", model_fp.name + ".model", model_path)
  else:
    # NOTE(review): a stale `model_path` left by an earlier run also ends this
    # wait immediately, and there is no timeout.
    while not tf.io.gfile.exists(model_path):
      time.sleep(1)
    time.sleep(1)
  return model_path
def main(argv: Sequence[str]) -> None:
  """Entry point: loads the JSON config, fills in defaults, and runs `generate`.

  Raises:
    app.UsageError: if positional command-line arguments are given.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # TensorFlow must not see the GPUs, or it may reserve memory JAX needs.
  tf.config.experimental.set_visible_devices([], 'GPU')

  logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count())
  logging.info('JAX local devices: %r', jax.local_devices())
  logging.info('JAX total devices: %r', jax.device_count())

  # Label the task with its JAX process; on some platforms task 0 is not
  # process 0.
  platform.work_unit().set_task_status(
      f'process_index: {jax.process_index()}, '
      f'process_count: {jax.process_count()}')
  output_dir = FLAGS.output_dir
  platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                       output_dir, 'output_dir')
  tf.io.gfile.makedirs(output_dir)

  if FLAGS.config_file:
    with tf.io.gfile.GFile(FLAGS.config_file, 'r') as reader:
      config = json.load(reader)
  else:
    config = json.loads(FLAGS.config)

  # Process 0 writes config.json (overwriting any existing file) before
  # 'output_dir' / 'num_total_memories' are added below.
  if jax.process_index() == 0:
    config_path = os.path.join(output_dir, 'config.json')
    with tf.io.gfile.GFile(config_path, 'w') as writer:
      writer.write(json.dumps(config, indent=4))

  config['output_dir'] = output_dir
  if 'num_total_memories' not in config:
    config['num_total_memories'] = get_num_total_memories(
        ml_collections.ConfigDict(config))

  generate(ml_collections.ConfigDict(config))
def itstime(step,
            every_n_steps,
            total_steps,
            process=None,
            last=True,
            first=True):
  """Determines whether or not it is time to trigger an action.

  Args:
    step: Current step; step 1 counts as the "first" step.
    every_n_steps: Trigger period. If falsy (0 or None), nothing ever
      triggers, including the first/last steps.
    total_steps: The step treated as the last one.
    process: If not None, only the JAX process with this index can trigger.
    last: Also trigger when `step == total_steps`.
    first: Also trigger when `step == 1`.

  Returns:
    A bool. It is a real bool rather than a leaked falsy value such as 0,
    None or a numpy bool.
  """
  if process is not None and jax.process_index() != process:
    return False
  if not every_n_steps:
    return False
  return bool(step % every_n_steps == 0 or
              (last and step == total_steps) or
              (first and step == 1))