def local_replica_groups(inner_group_size: int) -> List[List[int]]:
  """Constructs local nearest-neighbor rings given the JAX device assignment.

  For inner_group_size=8, each inner group is a tray with replica order:

  0/1 2/3
  7/6 5/4

  Args:
    inner_group_size: Number of replicas in each group.

  Returns:
    A list of replica id groups.
  """
  world_size = jax.device_count()
  outer_group_size, ragged = divmod(world_size, inner_group_size)
  assert not ragged, 'inner group size must evenly divide global device count'

  # The last device should have maximal x and y coordinates.
  def bounds_from_last_device(device):
    x, y, z = device.coords
    return (x + 1) * (device.core_on_chip + 1), (y + 1) * (z + 1)

  global_x, _ = bounds_from_last_device(jax.devices()[-1])
  per_host_x, per_host_y = bounds_from_last_device(jax.local_devices(0)[-1])
  assert inner_group_size in [2 ** i for i in range(1, 15)], \
      'inner group size must be a power of two'
  if inner_group_size <= 4:
    # Inner group is Nx1 (core, chip, 2x1).
    inner_x, inner_y = inner_group_size, 1
    inner_perm = range(inner_group_size)
  else:
    if inner_group_size <= global_x * 2:
      # Inner group is Nx2 (2x2 tray, 4x2 DF pod host, row of hosts).
      inner_x, inner_y = inner_group_size // 2, 2
    else:
      # Inner group covers the full x dimension and must be >2 in y.
      inner_x, inner_y = global_x, inner_group_size // global_x
    p = np.arange(inner_group_size)
    per_group_hosts_x = 1 if inner_x < per_host_x else inner_x // per_host_x
    p = p.reshape(inner_y // per_host_y, per_group_hosts_x,
                  per_host_y, inner_x // per_group_hosts_x)
    p = p.transpose(0, 2, 1, 3)
    p = p.reshape(inner_y // 2, 2, inner_x)
    p[:, 1, :] = p[:, 1, ::-1]
    inner_perm = p.reshape(-1)

  inner_replica_groups = [[o * inner_group_size + i for i in inner_perm]
                          for o in range(outer_group_size)]
  return inner_replica_groups
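A minimal usage sketch (hypothetical, assuming a TPU backend where device coordinates exist and the device count is a multiple of 8): the returned groups feed the axis_index_groups argument of jax.lax.psum, restricting the reduction to each nearest-neighbor ring instead of the full replica set.

import functools
import jax

groups = local_replica_groups(inner_group_size=8)

@functools.partial(jax.pmap, axis_name='i')
def ring_sum(x):
  # Reduce only within each 8-replica ring rather than across all devices.
  return jax.lax.psum(x, axis_name='i', axis_index_groups=groups)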
def _build_train_input(self):
  """See base class."""
  num_devices = jax.device_count()
  global_batch_size = self._batch_size
  per_device_batch_size, ragged = divmod(global_batch_size, num_devices)
  if ragged:
    raise ValueError(
        f'Global batch size {global_batch_size} must be divisible by '
        f'num devices {num_devices}')
  return dataset.load(
      dataset.Split.TRAIN_AND_VALID,
      preprocess_mode=dataset.PreprocessMode.LINEAR_TRAIN,
      transpose=self._should_transpose_images(),
      batch_dims=[jax.local_device_count(), per_device_batch_size])
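For reference, a sketch of the batch layout this kind of loader produces (shapes illustrative): the leading axis matches jax.local_device_count() so that jax.pmap can map one slice to each local device.

import jax
import numpy as np

global_batch_size = 128
per_device, ragged = divmod(global_batch_size, jax.device_count())
assert not ragged
# One training batch as seen by pmap: [local_devices, per_device, H, W, C].
batch = np.zeros((jax.local_device_count(), per_device, 224, 224, 3),
                 np.float32)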
def _cross_entropy_loss_fn(self, params, state, images, adv_images, labels,
                           target_probs, rng):
  scalars = {}
  images = self.normalize_fn(images)
  logits, state = self.model.apply(
      params, state, rng, images, is_training=True)
  loss = jnp.mean(utils.cross_entropy(logits, target_probs))
  loss += self.config.training.weight_decay * utils.weight_decay(params)
  if not self.config.training.use_cutmix:
    scalars['top_1_acc'] = utils.accuracy(logits, labels)
  scalars['train_loss'] = loss
  scaled_loss = loss / jax.device_count()
  return scaled_loss, (state, scalars)
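Why the loss is divided by jax.device_count(): under pmap, per-device gradients are combined with jax.lax.psum, so scaling each device's loss by 1/N makes the summed gradient equal to the gradient of the cross-device mean loss. A self-contained sketch of the pattern:

import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.pmap, axis_name='batch')
def mean_grads(w, x):
  scaled_loss = lambda w: jnp.sum((w * x) ** 2) / jax.device_count()
  grads = jax.grad(scaled_loss)(w)
  # psum of the scaled per-device gradients == mean over devices.
  return jax.lax.psum(grads, axis_name='batch')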
def _replicate(x, devices=None):
  x = jax.numpy.asarray(x)
  if devices is None:
    # Match the default device assignments used in pmap:
    # for single-host, that's the XLA default device assignment;
    # for multi-host, it's the order of jax.local_devices().
    if jax.host_count() == 1:
      devices = [
          d for d in xb.get_backend().get_default_device_assignment(
              jax.device_count()) if d.host_id == jax.host_id()
      ]
    else:
      devices = jax.local_devices()
  aval = jax.ShapedArray((len(devices),) + x.shape, x.dtype)
  buffers = [jax.interpreters.xla.device_put(x, device=d) for d in devices]
  return jax.pxla.ShardedDeviceArray(aval, buffers)
def lr_schedule(step: jnp.ndarray) -> jnp.ndarray:
  """Cosine learning rate schedule."""
  train_split = dataset.Split.from_string(FLAGS.train_split)
  total_batch_size = FLAGS.train_device_batch_size * jax.device_count()
  steps_per_epoch = train_split.num_examples / total_batch_size
  warmup_steps = FLAGS.train_lr_warmup_epochs * steps_per_epoch
  training_steps = FLAGS.train_epochs * steps_per_epoch
  lr = FLAGS.train_lr_init * total_batch_size / 256
  scaled_step = (jnp.maximum(step - warmup_steps, 0) /
                 (training_steps - warmup_steps))
  lr *= 0.5 * (1.0 + jnp.cos(jnp.pi * scaled_step))
  if warmup_steps:
    lr *= jnp.minimum(step / warmup_steps, 1.0)
  return lr
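The same warmup-plus-cosine shape with the FLAGS- and dataset-derived constants replaced by illustrative values, in case a self-contained version is useful:

import jax.numpy as jnp

def cosine_lr(step, base_lr=0.4, warmup_steps=500, total_steps=50_000):
  # Linear warmup to base_lr, then cosine decay to zero at total_steps.
  t = jnp.maximum(step - warmup_steps, 0) / (total_steps - warmup_steps)
  lr = base_lr * 0.5 * (1.0 + jnp.cos(jnp.pi * t))
  return lr * jnp.minimum(step / warmup_steps, 1.0)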
def _build_train_input(self) -> Generator[dataset.Batch, None, None]:
  """Loads the (infinitely looping) dataset iterator."""
  num_devices = jax.device_count()
  global_batch_size = self._batch_size
  per_device_batch_size, ragged = divmod(global_batch_size, num_devices)
  if ragged:
    raise ValueError(
        f'Global batch size {global_batch_size} must be divisible by '
        f'num devices {num_devices}')
  return dataset.load(
      dataset.Split.TRAIN_AND_VALID,
      preprocess_mode=dataset.PreprocessMode.PRETRAIN,
      transpose=self._should_transpose_images(),
      batch_dims=[jax.local_device_count(), per_device_batch_size])
def testPmap(self):
  f = jax.pmap(lambda x: 0. / x)
  with self.assertRaisesRegex(
      FloatingPointError,
      r"invalid value \(nan\) encountered in parallel computation"):
    ans = f(jnp.array([0.]))
    ans.block_until_ready()
  if jax.device_count() >= 2:
    with self.assertRaisesRegex(
        FloatingPointError,
        r"invalid value \(nan\) encountered in parallel computation"):
      ans = f(jnp.array([1., 0.]))
      ans.block_until_ready()
def train_loop(experiment_class, config):
  """The main training loop.

  This loop periodically saves a checkpoint to be evaluated in the eval_loop.

  Args:
    experiment_class: the constructor for the experiment (either
      byol_experiment or eval_experiment).
    config: the experiment config.
  """
  experiment = experiment_class(**config)
  rng = jax.random.PRNGKey(0)
  step = 0
  host_id = jax.host_id()
  last_logging = time.time()
  if config['checkpointing_config']['use_checkpointing']:
    checkpoint_data = experiment.load_checkpoint()
    if checkpoint_data is None:
      step = 0
    else:
      step, rng = checkpoint_data
  local_device_count = jax.local_device_count()
  while step < config['max_steps']:
    step_rng, rng = tuple(jax.random.split(rng))
    # Broadcast the random seeds across the devices.
    step_rng_device = jax.random.split(step_rng, num=jax.device_count())
    step_rng_device = step_rng_device[
        host_id * local_device_count:(host_id + 1) * local_device_count]
    step_device = np.broadcast_to(step, [local_device_count])
    # Perform a training step and get scalars to log.
    scalars = experiment.step(global_step=step_device, rng=step_rng_device)
    # Checkpointing and logging.
    if config['checkpointing_config']['use_checkpointing']:
      experiment.save_checkpoint(step, rng)
      current_time = time.time()
      if current_time - last_logging > FLAGS.log_tensors_interval:
        logging.info('Step %d: %s', step, scalars)
        last_logging = current_time
    step += 1
  logging.info('Saving final checkpoint')
  logging.info('Step %d: %s', step, scalars)
  experiment.save_checkpoint(step, rng)
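The key-slicing idiom in the loop above deserves a note: every host derives the same jax.device_count() keys from the shared step_rng and keeps only its own contiguous slice, so each global replica gets a distinct key without any cross-host communication. In isolation:

import jax

step_rng = jax.random.PRNGKey(42)
keys = jax.random.split(step_rng, num=jax.device_count())
start = jax.host_id() * jax.local_device_count()
# Device i worldwide sees keys[i], regardless of which host owns it.
local_keys = keys[start:start + jax.local_device_count()]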
def _dataset(load_fn,
             is_training: bool,
             total_batch_size: int,
             ratio: Optional[float] = None,
             one_minus_ratio: Optional[float] = None,
             repeat: int = 1) -> tf.data.Dataset:
  """Creates a dataset."""
  num_devices = jax.device_count()
  per_device_batch_size, ragged = divmod(total_batch_size, num_devices)
  if ragged:
    raise ValueError(
        f'Global batch size {total_batch_size} must be divisible by the '
        f'total number of devices {num_devices}')
  if repeat > 1:
    if per_device_batch_size % repeat:
      raise ValueError(
          f'Per device batch size {per_device_batch_size} must be divisible '
          f'by the number of repeated batches {repeat}')
    per_device_batch_size //= repeat
  if ratio is None and one_minus_ratio is None:
    pass  # Use full batch size.
  elif one_minus_ratio is None:
    per_device_batch_size = max(
        1,
        min(round(per_device_batch_size * ratio),
            per_device_batch_size - 1))
  elif ratio is None:
    batch_size = max(
        1,
        min(round(per_device_batch_size * one_minus_ratio),
            per_device_batch_size - 1))
    per_device_batch_size = per_device_batch_size - batch_size
  else:
    raise ValueError('Only one of `ratio` or `one_minus_ratio` must be '
                     'specified')
  if repeat > 1:
    per_device_batch_size *= repeat
  # When testing, we need to batch data across all devices (not just local
  # devices).
  num_local_devices = jax.local_device_count()
  if is_training:
    batch_sizes = [num_local_devices, per_device_batch_size]
  else:
    num_hosts = jax.host_count()
    assert num_hosts * num_local_devices == num_devices
    batch_sizes = [num_hosts, num_local_devices, per_device_batch_size]
  return load_fn(batch_sizes, is_training=is_training)
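A worked example of the ratio arithmetic: with per_device_batch_size=8 and ratio=0.25, round(8 * 0.25) = 2, clipped into [1, 7], so two examples per device come from this stream; the clipping guarantees both this stream and its complement stay non-empty.

per_device_batch_size, ratio = 8, 0.25
split = max(1, min(round(per_device_batch_size * ratio),
                   per_device_batch_size - 1))
assert split == 2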
def setUp(self):
  super(DistributedTest, self).setUp()
  if JAX_MODE:
    os.environ['XLA_FLAGS'] = (
        '--xla_force_host_platform_device_count={}'.format(NUM_DEVICES))
    assert jax.device_count() == NUM_DEVICES
    self.key = jax.random.PRNGKey(0)
  else:
    physical_devices = tf.config.experimental.list_physical_devices()
    tf.config.experimental.set_virtual_device_configuration(
        physical_devices[0],
        [tf.config.experimental.VirtualDeviceConfiguration()] * NUM_DEVICES)
    self.strategy = tf.distribute.MirroredStrategy(
        devices=tf.config.list_logical_devices())
    self.key = [0, 0]
  self.axis_name = 'i'
def test_fake_data(self):
  model_dir = self.get_tmp_model_dir()
  FLAGS.batch_size = 256 * jax.device_count()
  FLAGS.half_precision = True
  FLAGS.num_epochs = 5
  FLAGS.model_dir = model_dir
  start_time = time.time()
  train.main([])
  benchmark_time = time.time() - start_time
  self.report_wall_time(benchmark_time)
  self.report_extras({
      'description': 'ImageNet ResNet50 with fake data',
      'model_name': 'resnet50',
      'parameters': f'hp=true,bs={FLAGS.batch_size}',
  })
def main(unused_argv):
  # Necessary to use the tfds imagenet loader.
  tf.enable_v2_behavior()
  rng = jax.random.PRNGKey(FLAGS.seed)
  if FLAGS.hessian_eval_config:
    hessian_eval_config = json.loads(FLAGS.hessian_eval_config)
  else:
    hessian_eval_config = hessian_eval.DEFAULT_EVAL_CONFIG
  if FLAGS.experiment_config_filename:
    with tf.io.gfile.GFile(FLAGS.experiment_config_filename, 'r') as f:
      experiment_config = json.load(f)
    if jax.host_id() == 0:
      logging.info('experiment_config: %r', experiment_config)
    dataset_name = experiment_config['dataset']
    model_name = experiment_config['model']
  else:
    assert FLAGS.dataset and FLAGS.model
    dataset_name = FLAGS.dataset
    model_name = FLAGS.model
  if jax.host_id() == 0:
    logging.info('argv:\n%s', ' '.join(sys.argv))
    logging.info('device_count: %d', jax.device_count())
    logging.info('num_hosts : %d', jax.host_count())
    logging.info('host_id : %d', jax.host_id())
  model = models.get_model(model_name)
  dataset_builder = datasets.get_dataset(dataset_name)
  dataset_meta_data = datasets.get_dataset_meta_data(dataset_name)
  with tf.io.gfile.GFile(FLAGS.trial_hparams_filename, 'r') as f:
    hps = config_dict.ConfigDict(json.load(f))
  if FLAGS.hparam_overrides:
    if isinstance(FLAGS.hparam_overrides, str):
      hparam_overrides = json.loads(FLAGS.hparam_overrides)
    hps.update_from_flattened_dict(hparam_overrides)
  run_lanczos.eval_checkpoints(
      FLAGS.checkpoint_dir,
      hps,
      rng,
      FLAGS.eval_num_batches,
      model,
      dataset_builder,
      dataset_meta_data,
      hessian_eval_config,
      FLAGS.min_global_step,
      FLAGS.max_global_step,
      FLAGS.use_deprecated_checkpointing)
def testBasic(self):
  if jax.device_count() < 2:
    raise SkipTest

  @partial(sharded_jit, in_parts=(P(2, 1), P(2, 1)), out_parts=None)
  def f(x, y):
    return x + y

  shape = (8, 8)
  x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
  actual = f(x, x + 1)
  expected = x + (x + 1)
  self.assertAllClose(actual, expected, check_dtypes=False)
  self.assertIsInstance(actual, pxla.ShardedDeviceArray)
  self.assertLen(actual.device_buffers, 2)
  self.assertAllClose(actual.device_buffers[0].to_py(), expected,
                      check_dtypes=False)
def _build_train_input(self) -> Generator[dataset.Batch, None, None]:
  """See base class."""
  num_devices = jax.device_count()
  global_batch_size = self.config.training.batch_size
  per_device_batch_size, ragged = divmod(global_batch_size, num_devices)
  if ragged:
    raise ValueError(
        f'Global batch size {global_batch_size} must be divisible by '
        f'num devices {num_devices}')
  split = dataset.Split.TRAIN_AND_VALID
  return self._load_data(
      split=split,
      is_training=True,
      batch_dims=[jax.local_device_count(), per_device_batch_size])
def lr_schedule(step: jnp.ndarray) -> jnp.ndarray:
  """Linear scaling rule optimized for 90 epochs."""
  train_split = dataset.Split.from_string(FLAGS.train_split)
  # See Section 5.1 of https://arxiv.org/pdf/1706.02677.pdf.
  total_batch_size = FLAGS.train_device_batch_size * jax.device_count()
  steps_per_epoch = train_split.num_examples / total_batch_size
  current_epoch = step / steps_per_epoch  # type: float
  lr = (0.1 * total_batch_size) / 256
  lr_linear_till = 5
  boundaries = jnp.array((30, 60, 80)) * steps_per_epoch
  values = jnp.array([1., 0.1, 0.01, 0.001]) * lr
  index = jnp.sum(boundaries < step)
  lr = jnp.take(values, index)
  return lr * jnp.minimum(1., current_epoch / lr_linear_till)
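The boundaries/values lookup is a compact, jit-friendly way to express a step schedule: jnp.sum(boundaries < step) counts how many boundaries have passed and indexes the matching value. With illustrative numbers:

import jax.numpy as jnp

step = 6500.
boundaries = jnp.array([3000., 6000., 8000.])
values = jnp.array([0.4, 0.04, 0.004, 0.0004])
# Two boundaries have passed at step 6500, so this selects values[2] == 0.004.
lr = jnp.take(values, jnp.sum(boundaries < step))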
def build_ensemble_optimizer(ensemble_size, shared_params, ensemble_params,
                             optimizer, optimizer_kwargs):
  num_devices = jax.device_count()
  assert ensemble_size % num_devices == 0
  num_models_per_device = ensemble_size // num_devices

  optim = optimizer(**optimizer_kwargs)
  shared_params_optim_state = optim.init(shared_params)
  ensemble_params_optim_init = jax.pmap(
      jax.vmap(optim.init, in_axes=0, out_axes=0),
      in_axes=0, out_axes=0)
  ensemble_params_optim_state = ensemble_params_optim_init(ensemble_params)
  return optim, shared_params_optim_state, ensemble_params_optim_state
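A hypothetical usage sketch with optax (assumed to match the optimizer callable this function expects); the ensemble parameter leaves are stacked as [num_devices, models_per_device, ...] so pmap maps the device axis and vmap the per-device model axis:

import jax
import jax.numpy as jnp
import optax

n_devices = jax.device_count()
shared = {'w': jnp.zeros((3,))}
ensemble = {'w': jnp.zeros((n_devices, 2, 3))}  # 2 models per device
optim, shared_state, ensemble_state = build_ensemble_optimizer(
    ensemble_size=n_devices * 2,
    shared_params=shared,
    ensemble_params=ensemble,
    optimizer=optax.adam,
    optimizer_kwargs={'learning_rate': 1e-3})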
def main(unused_argv):
  # TODO(gdahl) Figure out a better way to handle passing more complicated
  # flags to the binary.
  training_metrics_config = None
  if FLAGS.training_metrics_config:
    training_metrics_config = json.loads(FLAGS.training_metrics_config)
  checkpoint_steps = [int(s.strip()) for s in FLAGS.checkpoint_steps]
  eval_steps = [int(s.strip()) for s in FLAGS.eval_steps]
  if jax.host_id() == 0:
    tf.io.gfile.makedirs(FLAGS.experiment_dir)
  log_dir = os.path.join(FLAGS.experiment_dir, 'r=3/')
  tf.io.gfile.makedirs(log_dir)
  log_path = os.path.join(
      log_dir, 'worker{}_{}.log'.format(FLAGS.worker_id, jax.host_id()))
  with tf.io.gfile.GFile(log_path, 'a') as logfile:
    utils.add_log_file(logfile)
    if jax.host_id() == 0:
      logging.info('argv:\n%s', ' '.join(sys.argv))
      logging.info('device_count: %d', jax.device_count())
      logging.info('num_hosts : %d', jax.host_count())
      logging.info('host_id : %d', jax.host_id())
      logging.info('checkpoint_steps: %r', checkpoint_steps)
      logging.info('eval_steps: %r', eval_steps)
    trainer.run(
        dataset_name=FLAGS.dataset,
        eval_batch_size=FLAGS.eval_batch_size,
        eval_num_batches=FLAGS.eval_num_batches,
        eval_train_num_batches=FLAGS.eval_train_num_batches,
        eval_frequency=FLAGS.eval_frequency,
        checkpoint_steps=checkpoint_steps,
        eval_steps=eval_steps,
        hparam_file=FLAGS.hparam_file,
        hparam_overrides=FLAGS.hparam_overrides,
        initializer_name=FLAGS.initializer,
        model_name=FLAGS.model,
        loss_name=FLAGS.loss,
        metrics_name=FLAGS.metrics,
        num_train_steps=FLAGS.num_train_steps,
        experiment_dir=FLAGS.experiment_dir,
        worker_id=FLAGS.worker_id,
        training_metrics_config=training_metrics_config,
        use_deprecated_checkpointing=FLAGS.use_deprecated_checkpointing)
def test_distributed_init(self):
  n_devices = jax.device_count()
  batch_size = 5 * n_devices

  x = np.random.uniform(size=(batch_size, 1))
  y = 1.4 * x + 0.1 * np.random.uniform(size=(batch_size, 2))

  model = eg.Model(
      eg.Linear(2),
      loss=[eg.losses.MeanSquaredError()],
  )
  model = model.distributed()
  model.init_on_batch(x)

  assert model.module.kernel.shape == (n_devices, 1, 2)
  assert model.module.bias.shape == (n_devices, 2)
def from_local(self, model: M) -> M:
  # device_idxs used to inform pmap about the number of devices
  device_idxs = jnp.arange(jax.device_count())

  def device_fn(idx: int, model: M) -> M:
    # fold rng state so it's unique for each device
    return model.map(
        lambda key: jax.random.fold_in(key, idx),
        tx.Rng,
        inplace=True,
    )

  model = jax.pmap(
      device_fn,
      in_axes=(0, None),
  )(device_idxs, model)
  return model
def build_train_input(data_dir, batch_size, img_size, augmentation):
  num_devices = jax.device_count()
  bs_per_device, ragged = divmod(batch_size, num_devices)
  if ragged:
    raise ValueError(
        f'Batch size {batch_size} must be divisible by num devices '
        f'{num_devices}')
  return input_pipeline.load(
      input_pipeline.Split.TRAIN_AND_VALID,
      data_dir=data_dir,
      is_training=True,
      batch_dims=[jax.local_device_count(), bs_per_device],
      transpose=True,
      image_size=(img_size, img_size),
      augment_name=augmentation,
      augment_before_mix=True,
      name='imagenet',
      fake_data=False)
def get_gpu_count():
  """Return the number of available gpus (regardless of whether torch, tf or
  jax is used)."""
  if is_torch_available():
    import torch
    return torch.cuda.device_count()
  elif is_tf_available():
    import tensorflow as tf
    return len(tf.config.list_physical_devices("GPU"))
  elif is_flax_available():
    import jax
    return jax.device_count()
  else:
    return 0
def test_counters(self):
  res = unittest.TestResult()
  ts = unittest.makeSuite(self.InnerTest)  # pytype: disable=module-attr
  ts.run(res)
  active_pmap = int(jax.device_count() > 1)
  self.assertEqual(self.InnerTest.test_1_count, 0)
  self.assertEqual(self.InnerTest.test_2_count, 0)
  self.assertEqual(self.InnerTest.test_3_count, 4 + active_pmap)
  self.assertEqual(self.InnerTest.test_4_count, 4)
  self.assertEqual(self.InnerTest.test_5_count, 1)
  self.assertEqual(self.InnerTest.test_6_count, 2)
  # Test methods do not use `self.variant`.
  self.assertLen(res.errors, 1 + 2 + 4 + 4 + active_pmap)
  for _, msg in res.errors:
    self.assertRegex(
        msg, 'RuntimeError: Test is wrapped .+ but never calls self.variant')
def _build_dataset(self,
                   data_rng: spec.RandomState,
                   split: str,
                   data_dir: str,
                   batch_size):
  if batch_size % jax.device_count() > 0:
    raise ValueError('Batch size must be divisible by the number of devices')
  ds_builder = tfds.builder('imagenet2012:5.*.*')
  ds_builder.download_and_prepare()
  ds = input_pipeline.create_input_iter(
      ds_builder,
      batch_size,
      self.train_mean,
      self.train_stddev,
      self.center_crop_size,
      self.resize_size,
      self.aspect_ratio_range,
      self.scale_ratio_range,
      train=True,
      cache=False)
  return ds
def _build_train_input(self):
  num_devices = jax.device_count()
  global_batch_size = self.config.train_batch_size
  bs_per_device, ragged = divmod(global_batch_size, num_devices)
  if ragged:
    raise ValueError(
        f'Global batch size {global_batch_size} must be divisible by '
        f'num devices {num_devices}')
  return input_pipeline.load(
      input_pipeline.Split.TRAIN_AND_VALID,
      is_training=True,
      batch_dims=[jax.local_device_count(), bs_per_device],
      transpose=self.config.get('transpose', False),
      image_size=(self.train_imsize, self.train_imsize),
      augment_name=self.config.augment_name,
      augment_before_mix=self.config.get('augment_before_mix', True),
      name=self.config.which_dataset,
      fake_data=False)
def main(unused_argv):
  if jax.process_index() == 0:
    logging.info('argv:\n%s', ' '.join(sys.argv))
    logging.info('device_count: %d', jax.device_count())
    logging.info('num_hosts : %d', jax.process_count())
    logging.info('host_id : %d', jax.process_index())
  if FLAGS.batch_size is None or FLAGS.batch_size <= 0:
    raise ValueError("""FLAGS.batch_size value is invalid,
        expected a positive non-zero integer.""")
  if FLAGS.dataset is None:
    raise ValueError("""FLAGS.dataset value is invalid,
        expected a non-empty string describing dataset name.""")
  batch_size = FLAGS.batch_size
  num_batches = FLAGS.num_batches
  dataset_name = FLAGS.dataset
  model_name = FLAGS.model
  initializer_name = 'noop'
  hparam_overrides = {
      'batch_size': batch_size,
  }
  hps = hyperparameters.build_hparams(
      model_name=model_name,
      initializer_name=initializer_name,
      dataset_name=dataset_name,
      hparam_file=None,
      hparam_overrides=hparam_overrides)
  rng = jax.random.PRNGKey(0)
  rng, data_rng = jax.random.split(rng)
  dataset = datasets.get_dataset(FLAGS.dataset)(
      data_rng, batch_size, batch_size, hps)
  train_iter = dataset.train_iterator_fn()
  for i in range(num_batches):
    batch = next(train_iter)
    logging.info('train batch_num = %d, batch = %r', i, batch)
  for batch in dataset.valid_epoch(num_batches):
    logging.info('validation batch = %r', batch)
def testPyTreeOutputs(self):
  if jax.device_count() < 2:
    raise SkipTest

  def f(x):
    return x + 1, ((x + 2, x + 3), x + 4)

  shape = (4, 4)
  x = np.arange(prod(shape)).reshape(shape)
  in_parts = (P(2, 1),)
  out_parts = (P(2, 1), ((P(1, 2), None), P(2, 1)))

  result = sharded_jit(f, in_parts, out_parts)(x)
  expected = f(x)
  self.assertAllClose(result, expected, check_dtypes=False)

  out_parts = None
  result = sharded_jit(f, in_parts, out_parts)(x)
  self.assertAllClose(result, expected, check_dtypes=False)
def make_dataset_iterator(
    self, replay_client: reverb.Client) -> Iterator[reverb.ReplaySample]:
  """Creates a dataset."""
  batch_size_per_learner = self._config.batch_size // jax.process_count()
  batch_size_per_device, ragged = divmod(self._config.batch_size,
                                         jax.device_count())
  if ragged:
    raise ValueError(
        'Learner batch size must be divisible by total number of devices!')
  dataset = datasets.make_reverb_dataset(
      table=self._config.replay_table_name,
      server_address=replay_client.server_address,
      batch_size=batch_size_per_device,
      num_parallel_calls=None,
      max_in_flight_samples_per_worker=2 * batch_size_per_learner)
  return utils.multi_device_put(dataset.as_numpy_iterator(),
                                jax.local_devices())
def test_ordered_effect_remains_ordered_across_multiple_devices(self):
  # TODO(sharadmv): remove jaxlib check when minimum version is bumped
  # TODO(sharadmv): enable this test on GPU and TPU when backends are
  # supported
  if jaxlib.version < (0, 3, 8):
    raise unittest.SkipTest(
        "`emit_python_callback` only supported in jaxlib >= 0.3.8")
  if jax.device_count() < 2:
    raise unittest.SkipTest("Test requires >= 2 devices.")

  log = []
  def log_value(x):
    log.append(x)
    return ()

  @functools.partial(jax.jit, device=jax.devices()[0])
  def f(x):
    # Expensive computation
    x = x.dot(x)
    x = jnp.log(x.sum())
    return callback_p.bind(x, callback=log_value, effect='log', out_avals=[])

  @functools.partial(jax.jit, device=jax.devices()[1])
  def g(x):
    return callback_p.bind(x, callback=log_value, effect='log', out_avals=[])

  f(jnp.ones((500, 500)))
  g(3.)
  f(jnp.ones((500, 500)))
  g(3.)
  f(jnp.ones((500, 500)))
  g(3.)
  dispatch.runtime_tokens.block_until_ready()
  x_, y_ = float(jnp.log(1.25e8)), 3.
  expected_log = [x_, y_, x_, y_, x_, y_]
  self.assertListEqual(log, expected_log)
def main(argv: Sequence[str]) -> None:
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
  # it unavailable to JAX.
  tf.config.experimental.set_visible_devices([], 'GPU')

  logging.info('JAX process: %d / %d', jax.process_index(),
               jax.process_count())
  logging.info('JAX local devices: %r', jax.local_devices())
  logging.info('JAX total devices: %r', jax.device_count())

  # Add a note so that we can tell which task is which JAX host.
  # (Depending on the platform task 0 is not guaranteed to be host 0)
  platform.work_unit().set_task_status(
      f'process_index: {jax.process_index()}, '
      f'process_count: {jax.process_count()}')
  platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                       FLAGS.output_dir, 'output_dir')
  tf.io.gfile.makedirs(FLAGS.output_dir)

  # Process config here
  if FLAGS.config_file:
    with tf.io.gfile.GFile(FLAGS.config_file, 'r') as reader:
      config = json.load(reader)
  else:
    config = json.loads(FLAGS.config)

  # Save config to workdir if it does not yet exist
  if jax.process_index() == 0:
    config_file = os.path.join(FLAGS.output_dir, 'config.json')
    with tf.io.gfile.GFile(config_file, 'w') as writer:
      writer.write(json.dumps(config, indent=4))

  config['output_dir'] = FLAGS.output_dir

  if 'num_total_memories' not in config:
    config['num_total_memories'] = get_num_total_memories(
        ml_collections.ConfigDict(config))

  generate(ml_collections.ConfigDict(config))
def _update_func(
    self,
    base_rng,
    state,
    inputs,
):
  """Computes loss and updates model parameters."""
  step = state.step
  rng = jax.random.fold_in(base_rng, jax.lax.axis_index('batch'))
  rng = jax.random.fold_in(rng, step)
  grad_loss_fn = jax.value_and_grad(self._loss_fn, has_aux=True)
  (_, loss_dict), scaled_grads = grad_loss_fn(state.params, inputs, rng)
  grads = jax.lax.psum(scaled_grads, axis_name='batch')
  grad_norm = optax.global_norm(grads)
  loss_dict['scalars']['grad_norm'] = grad_norm

  # Compute and apply updates via our optimizer.
  learning_rate = self.learning_rate(state.step)
  loss_dict['scalars']['learning_rate'] = learning_rate
  _, opt_apply = self.optimizer(learning_rate)
  updates, new_opt_state = opt_apply(grads, state.opt_state, state.params)
  new_params = optax.apply_updates(state.params, updates)

  # Update ema params
  ema_rate = self.config.evaluation.ema_rate
  new_ema_params = jax.tree_multimap(
      lambda x, y: x + (1 - ema_rate) * (y - x),
      state.ema_params,
      new_params,
  )
  new_state = state.replace(
      step=step + 1,
      params=new_params,
      ema_params=new_ema_params,
      opt_state=new_opt_state)

  # Rescale loss dict and return
  loss_dict['scalars'] = jax.tree_map(
      lambda x: jax.lax.psum(x, axis_name='batch') / jax.device_count(),
      loss_dict['scalars'],
  )
  return new_state, loss_dict
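The EMA update above is the usual new_ema = ema + (1 - rate) * (params - ema), an exponential moving average with decay ema_rate applied leaf-wise over the parameter tree (jax.tree_multimap is the older alias of jax.tree_map for mapping over multiple trees). An isolated sketch:

import jax
import jax.numpy as jnp

ema_rate = 0.999
ema_params = {'w': jnp.zeros(3)}
new_params = {'w': jnp.ones(3)}
# Each leaf moves a (1 - ema_rate) fraction of the way toward the new value.
ema_params = jax.tree_map(
    lambda e, p: e + (1 - ema_rate) * (p - e), ema_params, new_params)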