Example #1
def local_replica_groups(inner_group_size: int) -> List[List[int]]:
  """Constructs local nearest-neighbor rings given the JAX device assignment.

  For inner_group_size=8, each inner group is a tray with replica order:

  0/1 2/3
  7/6 5/4

  Args:
    inner_group_size: Number of replicas in each group.

  Returns:
    A list of replica id groups.
  """
  world_size = jax.device_count()
  outer_group_size, ragged = divmod(world_size, inner_group_size)
  assert not ragged, 'inner group size must evenly divide global device count'
  # the last device should have maximal x and y coordinate
  def bounds_from_last_device(device):
    x, y, z = device.coords
    return (x + 1) * (device.core_on_chip + 1), (y + 1) * (z + 1)
  global_x, _ = bounds_from_last_device(jax.devices()[-1])
  per_host_x, per_host_y = bounds_from_last_device(jax.local_devices(0)[-1])
  assert inner_group_size in [2 ** i for i in range(1, 15)], \
      'inner group size must be a power of two'
  if inner_group_size <= 4:
    # inner group is Nx1 (core, chip, 2x1)
    inner_x, inner_y = inner_group_size, 1
    inner_perm = range(inner_group_size)
  else:
    if inner_group_size <= global_x * 2:
      # inner group is Nx2 (2x2 tray, 4x2 DF pod host, row of hosts)
      inner_x, inner_y = inner_group_size // 2, 2
    else:
      # inner group covers the full x dimension and must be >2 in y
      inner_x, inner_y = global_x, inner_group_size // global_x
    p = np.arange(inner_group_size)
    per_group_hosts_x = 1 if inner_x < per_host_x else inner_x // per_host_x
    p = p.reshape(inner_y // per_host_y, per_group_hosts_x,
                  per_host_y, inner_x // per_group_hosts_x)
    p = p.transpose(0, 2, 1, 3)
    p = p.reshape(inner_y // 2, 2, inner_x)
    p[:, 1, :] = p[:, 1, ::-1]
    inner_perm = p.reshape(-1)

  inner_replica_groups = [[o * inner_group_size + i for i in inner_perm]
                          for o in range(outer_group_size)]
  return inner_replica_groups
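The grouping above is typically passed to a collective's axis_index_groups argument so that cross-replica reductions stay inside each ring. A minimal sketch, assuming the function above and groups of eight replicas:

import functools

import jax

# Sketch: average only within each nearest-neighbor ring, not globally.
groups = local_replica_groups(inner_group_size=8)

@functools.partial(jax.pmap, axis_name='i')
def ring_mean(x):
  return jax.lax.pmean(x, axis_name='i', axis_index_groups=groups)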
Example #2
    def _build_train_input(self):
        """See base class."""
        num_devices = jax.device_count()
        global_batch_size = self._batch_size
        per_device_batch_size, ragged = divmod(global_batch_size, num_devices)

        if ragged:
            raise ValueError(
                f'Global batch size {global_batch_size} must be divisible by '
                f'num devices {num_devices}')

        return dataset.load(
            dataset.Split.TRAIN_AND_VALID,
            preprocess_mode=dataset.PreprocessMode.LINEAR_TRAIN,
            transpose=self._should_transpose_images(),
            batch_dims=[jax.local_device_count(), per_device_batch_size])
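Each element produced with batch_dims=[jax.local_device_count(), per_device_batch_size] carries a leading device axis, so it can be fed straight to a pmapped step. A minimal sketch, where train_step and its body are placeholders rather than part of the example above:

import jax

@jax.pmap  # maps over the leading axis of length jax.local_device_count()
def train_step(batch):
    # Each device receives a [per_device_batch_size, ...] slice of the batch.
    return batch.mean()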
Example #3
    def _cross_entropy_loss_fn(self, params, state, images, adv_images, labels,
                               target_probs, rng):
        scalars = {}
        images = self.normalize_fn(images)
        logits, state = self.model.apply(params,
                                         state,
                                         rng,
                                         images,
                                         is_training=True)
        loss = jnp.mean(utils.cross_entropy(logits, target_probs))
        loss += self.config.training.weight_decay * utils.weight_decay(params)
        if not self.config.training.use_cutmix:
            scalars['top_1_acc'] = utils.accuracy(logits, labels)
        scalars['train_loss'] = loss
        scaled_loss = loss / jax.device_count()
        return scaled_loss, (state, scalars)
Example #4
def _replicate(x, devices=None):
    x = jax.numpy.asarray(x)
    if devices is None:
        # match the default device assignments used in pmap:
        # for single-host, that's the XLA default device assignment
        # for multi-host, it's the order of jax.local_devices()
        if jax.host_count() == 1:
            devices = [
                d for d in xb.get_backend().get_default_device_assignment(
                    jax.device_count()) if d.host_id == jax.host_id()
            ]
        else:
            devices = jax.local_devices()
    aval = jax.ShapedArray((len(devices), ) + x.shape, x.dtype)
    buffers = [jax.interpreters.xla.device_put(x, device=d) for d in devices]
    return jax.pxla.ShardedDeviceArray(aval, buffers)
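To replicate a whole parameter pytree, _replicate is usually mapped over every leaf. A usage sketch, where params is assumed to be an existing pytree of arrays:

# `params` is an assumed, pre-existing parameter pytree; every leaf gains a
# leading axis of length len(devices), ready for jax.pmap.
replicated_params = jax.tree_map(_replicate, params)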
Example #5
def lr_schedule(step: jnp.ndarray) -> jnp.ndarray:
  """Cosine learning rate schedule."""
  train_split = dataset.Split.from_string(FLAGS.train_split)

  total_batch_size = FLAGS.train_device_batch_size * jax.device_count()
  steps_per_epoch = train_split.num_examples / total_batch_size
  warmup_steps = FLAGS.train_lr_warmup_epochs * steps_per_epoch
  training_steps = FLAGS.train_epochs * steps_per_epoch

  lr = FLAGS.train_lr_init * total_batch_size / 256
  scaled_step = (jnp.maximum(step - warmup_steps, 0) /
                 (training_steps - warmup_steps))
  lr *= 0.5 * (1.0 + jnp.cos(jnp.pi * scaled_step))
  if warmup_steps:
    lr *= jnp.minimum(step / warmup_steps, 1.0)
  return lr
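Schedules like this are typically handed to the optimizer as a callable so the learning rate is recomputed at every step. A minimal sketch assuming optax (the surrounding code may wire the optimizer differently):

import optax

# optax evaluates the schedule at the step count kept in the optimizer state.
optimizer = optax.sgd(learning_rate=lr_schedule, momentum=0.9)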
Example #6
    def _build_train_input(self) -> Generator[dataset.Batch, None, None]:
        """Loads the (infinitely looping) dataset iterator."""
        num_devices = jax.device_count()
        global_batch_size = self._batch_size
        per_device_batch_size, ragged = divmod(global_batch_size, num_devices)

        if ragged:
            raise ValueError(
                f'Global batch size {global_batch_size} must be divisible by '
                f'num devices {num_devices}')

        return dataset.load(
            dataset.Split.TRAIN_AND_VALID,
            preprocess_mode=dataset.PreprocessMode.PRETRAIN,
            transpose=self._should_transpose_images(),
            batch_dims=[jax.local_device_count(), per_device_batch_size])
Example #7
    def testPmap(self):
        f = jax.pmap(lambda x: 0. / x)

        with self.assertRaisesRegex(
                FloatingPointError,
                r"invalid value \(nan\) encountered in parallel computation"):
            ans = f(jnp.array([0.]))
            ans.block_until_ready()

        if jax.device_count() >= 2:
            with self.assertRaisesRegex(
                    FloatingPointError,
                    r"invalid value \(nan\) encountered in parallel computation"
            ):
                ans = f(jnp.array([1., 0.]))
                ans.block_until_ready()
Example #8
def train_loop(experiment_class, config):
    """The main training loop.

  This loop periodically saves a checkpoint to be evaluated in the eval_loop.

  Args:
    experiment_class: the constructor for the experiment (either byol_experiment
    or eval_experiment).
    config: the experiment config.
  """
    experiment = experiment_class(**config)
    rng = jax.random.PRNGKey(0)
    step = 0

    host_id = jax.host_id()
    last_logging = time.time()
    if config['checkpointing_config']['use_checkpointing']:
        checkpoint_data = experiment.load_checkpoint()
        if checkpoint_data is None:
            step = 0
        else:
            step, rng = checkpoint_data

    local_device_count = jax.local_device_count()
    while step < config['max_steps']:
        step_rng, rng = tuple(jax.random.split(rng))
        # Broadcast the random seeds across the devices
        step_rng_device = jax.random.split(step_rng, num=jax.device_count())
        step_rng_device = step_rng_device[host_id *
                                          local_device_count:(host_id + 1) *
                                          local_device_count]
        step_device = np.broadcast_to(step, [local_device_count])

        # Perform a training step and get scalars to log.
        scalars = experiment.step(global_step=step_device, rng=step_rng_device)

        # Checkpointing and logging.
        if config['checkpointing_config']['use_checkpointing']:
            experiment.save_checkpoint(step, rng)
            current_time = time.time()
            if current_time - last_logging > FLAGS.log_tensors_interval:
                logging.info('Step %d: %s', step, scalars)
                last_logging = current_time
        step += 1
    logging.info('Saving final checkpoint')
    logging.info('Step %d: %s', step, scalars)
    experiment.save_checkpoint(step, rng)
Example #9
def _dataset(load_fn,
             is_training: bool,
             total_batch_size: int,
             ratio: Optional[float] = None,
             one_minus_ratio: Optional[float] = None,
             repeat: int = 1) -> tf.data.Dataset:
    """Creates a dataset."""
    num_devices = jax.device_count()
    per_device_batch_size, ragged = divmod(total_batch_size, num_devices)
    if ragged:
        raise ValueError(
            f'Global batch size {total_batch_size} must be divisible by the '
            f'total number of devices {num_devices}')
    if repeat > 1:
        if per_device_batch_size % repeat:
            raise ValueError(
                f'Per device batch size {per_device_batch_size} must be divisible '
                f'by the number of repeated batches {repeat}')
        per_device_batch_size //= repeat
    if ratio is None and one_minus_ratio is None:
        pass  # Use full batch size.
    elif one_minus_ratio is None:
        per_device_batch_size = max(
            1,
            min(round(per_device_batch_size * ratio),
                per_device_batch_size - 1))
    elif ratio is None:
        batch_size = max(
            1,
            min(round(per_device_batch_size * one_minus_ratio),
                per_device_batch_size - 1))
        per_device_batch_size = per_device_batch_size - batch_size
    else:
        raise ValueError('Only one of `ratio` or `one_minus_ratio` must be '
                         'specified')
    if repeat > 1:
        per_device_batch_size *= repeat
    # When testing, we need to batch data across all devices (not just local
    # devices).
    num_local_devices = jax.local_device_count()
    if is_training:
        batch_sizes = [num_local_devices, per_device_batch_size]
    else:
        num_hosts = jax.host_count()
        assert num_hosts * num_local_devices == num_devices
        batch_sizes = [num_hosts, num_local_devices, per_device_batch_size]
    return load_fn(batch_sizes, is_training=is_training)
Example #10
  def setUp(self):
    super(DistributedTest, self).setUp()
    if JAX_MODE:
      os.environ['XLA_FLAGS'] = (
          '--xla_force_host_platform_device_count={}'.format(NUM_DEVICES))
      assert jax.device_count() == NUM_DEVICES
      self.key = jax.random.PRNGKey(0)
    else:
      physical_devices = tf.config.experimental.list_physical_devices()

      tf.config.experimental.set_virtual_device_configuration(
          physical_devices[0],
          [tf.config.experimental.VirtualDeviceConfiguration()] * NUM_DEVICES)
      self.strategy = tf.distribute.MirroredStrategy(
          devices=tf.config.list_logical_devices())
      self.key = [0, 0]
    self.axis_name = 'i'
Example #11
    def test_fake_data(self):
        model_dir = self.get_tmp_model_dir()
        FLAGS.batch_size = 256 * jax.device_count()
        FLAGS.half_precision = True
        FLAGS.num_epochs = 5
        FLAGS.model_dir = model_dir

        start_time = time.time()
        train.main([])
        benchmark_time = time.time() - start_time

        self.report_wall_time(benchmark_time)
        self.report_extras({
            'description': 'ImageNet ResNet50 with fake data',
            'model_name': 'resnet50',
            'parameters': f'hp=true,bs={FLAGS.batch_size}',
        })
Example #12
def main(unused_argv):
    # Necessary to use the tfds imagenet loader.
    tf.enable_v2_behavior()

    rng = jax.random.PRNGKey(FLAGS.seed)

    if FLAGS.hessian_eval_config:
        hessian_eval_config = json.loads(FLAGS.hessian_eval_config)
    else:
        hessian_eval_config = hessian_eval.DEFAULT_EVAL_CONFIG

    if FLAGS.experiment_config_filename:
        with tf.io.gfile.GFile(FLAGS.experiment_config_filename, 'r') as f:
            experiment_config = json.load(f)
        if jax.host_id() == 0:
            logging.info('experiment_config: %r', experiment_config)
        dataset_name = experiment_config['dataset']
        model_name = experiment_config['model']
    else:
        assert FLAGS.dataset and FLAGS.model
        dataset_name = FLAGS.dataset
        model_name = FLAGS.model

    if jax.host_id() == 0:
        logging.info('argv:\n%s', ' '.join(sys.argv))
        logging.info('device_count: %d', jax.device_count())
        logging.info('num_hosts : %d', jax.host_count())
        logging.info('host_id : %d', jax.host_id())

    model = models.get_model(model_name)
    dataset_builder = datasets.get_dataset(dataset_name)
    dataset_meta_data = datasets.get_dataset_meta_data(dataset_name)

    with tf.io.gfile.GFile(FLAGS.trial_hparams_filename, 'r') as f:
        hps = config_dict.ConfigDict(json.load(f))

    if FLAGS.hparam_overrides:
        hparam_overrides = FLAGS.hparam_overrides
        if isinstance(hparam_overrides, str):
            hparam_overrides = json.loads(hparam_overrides)
        hps.update_from_flattened_dict(hparam_overrides)
    run_lanczos.eval_checkpoints(FLAGS.checkpoint_dir, hps, rng,
                                 FLAGS.eval_num_batches, model,
                                 dataset_builder, dataset_meta_data,
                                 hessian_eval_config, FLAGS.min_global_step,
                                 FLAGS.max_global_step,
                                 FLAGS.use_deprecated_checkpointing)
Example #13
  def testBasic(self):
    if jax.device_count() < 2:
      raise SkipTest

    @partial(sharded_jit, in_parts=(P(2, 1), P(2, 1)), out_parts=None)
    def f(x, y):
      return x + y

    shape = (8, 8)
    x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
    actual = f(x, x + 1)
    expected = x + (x + 1)
    self.assertAllClose(actual, expected, check_dtypes=False)
    self.assertIsInstance(actual, pxla.ShardedDeviceArray)
    self.assertLen(actual.device_buffers, 2)
    self.assertAllClose(actual.device_buffers[0].to_py(), expected,
                        check_dtypes=False)
Example #14
    def _build_train_input(self) -> Generator[dataset.Batch, None, None]:
        """See base class."""
        num_devices = jax.device_count()
        global_batch_size = self.config.training.batch_size
        per_device_batch_size, ragged = divmod(global_batch_size, num_devices)

        if ragged:
            raise ValueError(
                f'Global batch size {global_batch_size} must be divisible by '
                f'num devices {num_devices}')

        split = dataset.Split.TRAIN_AND_VALID

        return self._load_data(
            split=split,
            is_training=True,
            batch_dims=[jax.local_device_count(), per_device_batch_size])
Example #15
def lr_schedule(step: jnp.ndarray) -> jnp.ndarray:
    """Linear scaling rule optimized for 90 epochs."""
    train_split = dataset.Split.from_string(FLAGS.train_split)

    # See Section 5.1 of https://arxiv.org/pdf/1706.02677.pdf.
    total_batch_size = FLAGS.train_device_batch_size * jax.device_count()
    steps_per_epoch = train_split.num_examples / total_batch_size

    current_epoch = step / steps_per_epoch  # type: float
    lr = (0.1 * total_batch_size) / 256
    lr_linear_till = 5
    boundaries = jnp.array((30, 60, 80)) * steps_per_epoch
    values = jnp.array([1., 0.1, 0.01, 0.001]) * lr

    index = jnp.sum(boundaries < step)
    lr = jnp.take(values, index)
    return lr * jnp.minimum(1., current_epoch / lr_linear_till)
Example #16
def build_ensemble_optimizer(ensemble_size, shared_params, ensemble_params,
                             optimizer, optimizer_kwargs):
    num_devices = jax.device_count()
    assert ensemble_size % num_devices == 0
    num_models_per_device = ensemble_size // num_devices

    optim = optimizer(**optimizer_kwargs)
    shared_params_optim_state = optim.init(shared_params)

    ensemble_params_optim_init = jax.pmap(jax.vmap(optim.init,
                                                   in_axes=0,
                                                   out_axes=0),
                                          in_axes=0,
                                          out_axes=0)
    ensemble_params_optim_state = ensemble_params_optim_init(ensemble_params)

    return optim, shared_params_optim_state, ensemble_params_optim_state
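For the pmap(vmap(...)) nesting above to work, every leaf of ensemble_params needs a [num_devices, num_models_per_device, ...] prefix. A shaping sketch, where flat_params is a hypothetical pytree whose leaves carry a leading ensemble_size axis:

# `flat_params` is hypothetical: leaves of shape [ensemble_size, ...].
# Reshape so the outer pmap shards over devices and the inner vmap runs over
# the models assigned to each device.
num_devices = jax.device_count()
ensemble_params = jax.tree_map(
    lambda x: x.reshape((num_devices, ensemble_size // num_devices) + x.shape[1:]),
    flat_params)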
Example #17
def main(unused_argv):

  # TODO(gdahl) Figure out a better way to handle passing more complicated
  # flags to the binary.
  training_metrics_config = None
  if FLAGS.training_metrics_config:
    training_metrics_config = json.loads(FLAGS.training_metrics_config)

  checkpoint_steps = [int(s.strip()) for s in FLAGS.checkpoint_steps]
  eval_steps = [int(s.strip()) for s in FLAGS.eval_steps]
  if jax.host_id() == 0:
    tf.io.gfile.makedirs(FLAGS.experiment_dir)
  log_dir = os.path.join(FLAGS.experiment_dir, 'r=3/')
  tf.io.gfile.makedirs(log_dir)
  log_path = os.path.join(
      log_dir, 'worker{}_{}.log'.format(FLAGS.worker_id, jax.host_id()))
  with tf.io.gfile.GFile(log_path, 'a') as logfile:
    utils.add_log_file(logfile)
    if jax.host_id() == 0:
      logging.info('argv:\n%s', ' '.join(sys.argv))
      logging.info('device_count: %d', jax.device_count())
      logging.info('num_hosts : %d', jax.host_count())
      logging.info('host_id : %d', jax.host_id())
      logging.info('checkpoint_steps: %r', checkpoint_steps)
      logging.info('eval_steps: %r', eval_steps)

    trainer.run(
        dataset_name=FLAGS.dataset,
        eval_batch_size=FLAGS.eval_batch_size,
        eval_num_batches=FLAGS.eval_num_batches,
        eval_train_num_batches=FLAGS.eval_train_num_batches,
        eval_frequency=FLAGS.eval_frequency,
        checkpoint_steps=checkpoint_steps,
        eval_steps=eval_steps,
        hparam_file=FLAGS.hparam_file,
        hparam_overrides=FLAGS.hparam_overrides,
        initializer_name=FLAGS.initializer,
        model_name=FLAGS.model,
        loss_name=FLAGS.loss,
        metrics_name=FLAGS.metrics,
        num_train_steps=FLAGS.num_train_steps,
        experiment_dir=FLAGS.experiment_dir,
        worker_id=FLAGS.worker_id,
        training_metrics_config=training_metrics_config,
        use_deprecated_checkpointing=FLAGS.use_deprecated_checkpointing)
Example #18
    def test_distributed_init(self):
        n_devices = jax.device_count()
        batch_size = 5 * n_devices

        x = np.random.uniform(size=(batch_size, 1))
        y = 1.4 * x + 0.1 * np.random.uniform(size=(batch_size, 2))

        model = eg.Model(
            eg.Linear(2),
            loss=[eg.losses.MeanSquaredError()],
        )

        model = model.distributed()

        model.init_on_batch(x)

        assert model.module.kernel.shape == (n_devices, 1, 2)
        assert model.module.bias.shape == (n_devices, 2)
Example #19
    def from_local(self, model: M) -> M:
        # device_idxs used to inform pmap about the number of devices
        device_idxs = jnp.arange(jax.device_count())

        def device_fn(idx: int, model: M) -> M:
            # fold the RNG state so it's unique for each device
            return model.map(
                lambda key: jax.random.fold_in(key, idx),
                tx.Rng,
                inplace=True,
            )

        model = jax.pmap(
            device_fn,
            in_axes=(0, None),
        )(device_idxs, model)

        return model
Example #20
def build_train_input(data_dir, batch_size, img_size, augmentation):
    num_devices = jax.device_count()
    bs_per_device, ragged = divmod(batch_size, num_devices)
    if ragged:
        raise ValueError(
            f'Batch size {batch_size} must be divisible by num devices {num_devices}'
        )
    return input_pipeline.load(
        input_pipeline.Split.TRAIN_AND_VALID,
        data_dir=data_dir,
        is_training=True,
        batch_dims=[jax.local_device_count(), bs_per_device],
        transpose=True,
        image_size=(img_size, img_size),
        augment_name=augmentation,
        augment_before_mix=True,
        name='imagenet',
        fake_data=False)
Example #21
def get_gpu_count():
    """
    Return the number of available gpus (regardless of whether torch, tf or jax is used)
    """
    if is_torch_available():
        import torch

        return torch.cuda.device_count()
    elif is_tf_available():
        import tensorflow as tf

        return len(tf.config.list_physical_devices("GPU"))
    elif is_flax_available():
        import jax

        return jax.device_count()
    else:
        return 0
Example #22
  def test_counters(self):
    res = unittest.TestResult()
    ts = unittest.makeSuite(self.InnerTest)  # pytype: disable=module-attr
    ts.run(res)

    active_pmap = int(jax.device_count() > 1)
    self.assertEqual(self.InnerTest.test_1_count, 0)
    self.assertEqual(self.InnerTest.test_2_count, 0)
    self.assertEqual(self.InnerTest.test_3_count, 4 + active_pmap)
    self.assertEqual(self.InnerTest.test_4_count, 4)
    self.assertEqual(self.InnerTest.test_5_count, 1)
    self.assertEqual(self.InnerTest.test_6_count, 2)

    # Test methods do not use `self.variant`.
    self.assertLen(res.errors, 1 + 2 + 4 + 4 + active_pmap)
    for _, msg in res.errors:
      self.assertRegex(
          msg, 'RuntimeError: Test is wrapped .+ but never calls self.variant')
Example #23
  def _build_dataset(self, data_rng: spec.RandomState, split: str,
                     data_dir: str, batch_size):
    if batch_size % jax.device_count() > 0:
      raise ValueError('Batch size must be divisible by the number of devices')
    ds_builder = tfds.builder('imagenet2012:5.*.*')
    ds_builder.download_and_prepare()
    ds = input_pipeline.create_input_iter(
        ds_builder,
        batch_size,
        self.train_mean,
        self.train_stddev,
        self.center_crop_size,
        self.resize_size,
        self.aspect_ratio_range,
        self.scale_ratio_range,
        train=True,
        cache=False)
    return ds
Example #24
    def _build_train_input(self):
        num_devices = jax.device_count()
        global_batch_size = self.config.train_batch_size
        bs_per_device, ragged = divmod(global_batch_size, num_devices)
        if ragged:
            raise ValueError(
                f'Global batch size {global_batch_size} must be divisible by '
                f'num devices {num_devices}')
        return input_pipeline.load(
            input_pipeline.Split.TRAIN_AND_VALID,
            is_training=True,
            batch_dims=[jax.local_device_count(), bs_per_device],
            transpose=self.config.get('transpose', False),
            image_size=(self.train_imsize, self.train_imsize),
            augment_name=self.config.augment_name,
            augment_before_mix=self.config.get('augment_before_mix', True),
            name=self.config.which_dataset,
            fake_data=False)
Example #25
def main(unused_argv):
    if jax.process_index() == 0:
        logging.info('argv:\n%s', ' '.join(sys.argv))
        logging.info('device_count: %d', jax.device_count())
        logging.info('num_hosts : %d', jax.process_count())
        logging.info('host_id : %d', jax.process_index())

        if FLAGS.batch_size is None or FLAGS.batch_size <= 0:
            raise ValueError("""FLAGS.batch_size value is invalid,
          expected a positive non-zero integer.""")

        if FLAGS.dataset is None:
            raise ValueError("""FLAGS.dataset value is invalid,
          expected a non-empty string describing dataset name.""")

        batch_size = FLAGS.batch_size
        num_batches = FLAGS.num_batches
        dataset_name = FLAGS.dataset
        model_name = FLAGS.model
        initializer_name = 'noop'

        hparam_overrides = {
            'batch_size': batch_size,
        }

        hps = hyperparameters.build_hparams(model_name=model_name,
                                            initializer_name=initializer_name,
                                            dataset_name=dataset_name,
                                            hparam_file=None,
                                            hparam_overrides=hparam_overrides)

        rng = jax.random.PRNGKey(0)
        rng, data_rng = jax.random.split(rng)

        dataset = datasets.get_dataset(FLAGS.dataset)(data_rng, batch_size,
                                                      batch_size, hps)
        train_iter = dataset.train_iterator_fn()

        for i in range(num_batches):
            batch = next(train_iter)
            logging.info('train batch_num = %d, batch = %r', i, batch)

        for batch in dataset.valid_epoch(num_batches):
            logging.info('validation batch = %r', batch)
Example #26
    def testPyTreeOutputs(self):
        if jax.device_count() < 2:
            raise SkipTest

        def f(x):
            return x + 1, ((x + 2, x + 3), x + 4)

        shape = (4, 4)
        x = np.arange(prod(shape)).reshape(shape)
        in_parts = (P(2, 1), )
        out_parts = (P(2, 1), ((P(1, 2), None), P(2, 1)))

        result = sharded_jit(f, in_parts, out_parts)(x)
        expected = f(x)
        self.assertAllClose(result, expected, check_dtypes=False)

        out_parts = None
        result = sharded_jit(f, in_parts, out_parts)(x)
        self.assertAllClose(result, expected, check_dtypes=False)
Example #27
  def make_dataset_iterator(
      self, replay_client: reverb.Client) -> Iterator[reverb.ReplaySample]:
    """Creates a dataset."""
    batch_size_per_learner = self._config.batch_size // jax.process_count()
    batch_size_per_device, ragged = divmod(self._config.batch_size,
                                           jax.device_count())
    if ragged:
      raise ValueError(
          'Learner batch size must be divisible by total number of devices!')

    dataset = datasets.make_reverb_dataset(
        table=self._config.replay_table_name,
        server_address=replay_client.server_address,
        batch_size=batch_size_per_device,
        num_parallel_calls=None,
        max_in_flight_samples_per_worker=2 * batch_size_per_learner)

    return utils.multi_device_put(dataset.as_numpy_iterator(),
                                  jax.local_devices())
Example #28
    def test_ordered_effect_remains_ordered_across_multiple_devices(self):
        # TODO(sharadmv): remove jaxlib check when minimum version is bumped
        # TODO(sharadmv): enable this test on GPU and TPU when backends are
        # supported
        if jaxlib.version < (0, 3, 8):
            raise unittest.SkipTest(
                "`emit_python_callback` only supported in jaxlib >= 0.3.8")
        if jax.device_count() < 2:
            raise unittest.SkipTest("Test requires >= 2 devices.")
        log = []

        def log_value(x):
            log.append(x)
            return ()

        @functools.partial(jax.jit, device=jax.devices()[0])
        def f(x):
            # Expensive computation
            x = x.dot(x)
            x = jnp.log(x.sum())
            return callback_p.bind(x,
                                   callback=log_value,
                                   effect='log',
                                   out_avals=[])

        @functools.partial(jax.jit, device=jax.devices()[1])
        def g(x):
            return callback_p.bind(x,
                                   callback=log_value,
                                   effect='log',
                                   out_avals=[])

        f(jnp.ones((500, 500)))
        g(3.)
        f(jnp.ones((500, 500)))
        g(3.)
        f(jnp.ones((500, 500)))
        g(3.)
        dispatch.runtime_tokens.block_until_ready()
        x_, y_ = float(jnp.log(1.25e8)), 3.
        expected_log = [x_, y_, x_, y_, x_, y_]
        self.assertListEqual(log, expected_log)
Example #29
def main(argv: Sequence[str]) -> None:
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], 'GPU')

    logging.info('JAX process: %d / %d', jax.process_index(),
                 jax.process_count())
    logging.info('JAX local devices: %r', jax.local_devices())
    logging.info('JAX total devices: %r', jax.device_count())

    # Add a note so that we can tell which task is which JAX host.
    # (Depending on the platform task 0 is not guaranteed to be host 0)
    platform.work_unit().set_task_status(
        f'process_index: {jax.process_index()}, '
        f'process_count: {jax.process_count()}')
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                         FLAGS.output_dir, 'output_dir')

    tf.io.gfile.makedirs(FLAGS.output_dir)

    # Process config here
    if FLAGS.config_file:
        with tf.io.gfile.GFile(FLAGS.config_file, 'r') as reader:
            config = json.load(reader)
    else:
        config = json.loads(FLAGS.config)
        # Save config to workdir if it does not exist yet.
        if jax.process_index() == 0:
            config_file = os.path.join(FLAGS.output_dir, 'config.json')
            with tf.io.gfile.GFile(config_file, 'w') as writer:
                writer.write(json.dumps(config, indent=4))

    config['output_dir'] = FLAGS.output_dir

    if 'num_total_memories' not in config:
        config['num_total_memories'] = get_num_total_memories(
            ml_collections.ConfigDict(config))

    generate(ml_collections.ConfigDict(config))
Example #30
    def _update_func(
        self,
        base_rng,
        state,
        inputs,
    ):
        """Computes loss and updates model parameters."""
        step = state.step
        rng = jax.random.fold_in(base_rng, jax.lax.axis_index('batch'))
        rng = jax.random.fold_in(rng, step)
        grad_loss_fn = jax.value_and_grad(self._loss_fn, has_aux=True)
        (_, loss_dict), scaled_grads = grad_loss_fn(state.params, inputs, rng)
        grads = jax.lax.psum(scaled_grads, axis_name='batch')
        grad_norm = optax.global_norm(grads)
        loss_dict['scalars']['grad_norm'] = grad_norm

        # Compute and apply updates via our optimizer.
        learning_rate = self.learning_rate(state.step)
        loss_dict['scalars']['learning_rate'] = learning_rate
        _, opt_apply = self.optimizer(learning_rate)
        updates, new_opt_state = opt_apply(grads, state.opt_state,
                                           state.params)
        new_params = optax.apply_updates(state.params, updates)

        # Update ema params
        ema_rate = self.config.evaluation.ema_rate
        new_ema_params = jax.tree_multimap(
            lambda x, y: x + (1 - ema_rate) * (y - x),
            state.ema_params,
            new_params,
        )
        new_state = state.replace(step=step + 1,
                                  params=new_params,
                                  ema_params=new_ema_params,
                                  opt_state=new_opt_state)

        # Rescale loss dict and return
        loss_dict['scalars'] = jax.tree_map(
            lambda x: jax.lax.psum(x, axis_name='batch') / jax.device_count(),
            loss_dict['scalars'],
        )
        return new_state, loss_dict
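Because _update_func calls jax.lax.axis_index('batch') and jax.lax.psum(..., axis_name='batch'), it can only run under a pmap that names that axis, typically set up in the experiment's constructor. A minimal wrapping sketch (the attribute name is an assumption):

        # In the constructor (sketch): the mapped axis must carry the same
        # 'batch' name used by the collectives inside _update_func.
        self._update_func = jax.pmap(self._update_func, axis_name='batch')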