Example #1
  def test_small_training_job(self):
    experiment_dir = self.create_tempdir().full_path
    work_unit_dir = self.create_tempdir().full_path

    # Disable compiler optimizations for faster compile time.
    jax.config.update('jax_disable_most_optimizations', True)

    # Seed the random number generators.
    random.seed(0)
    onp.random.seed(0)
    rng = utils.RngGen(jax.random.PRNGKey(0))

    # Construct a test config with a small number of steps.
    config = main_test_config.get_config()

    # Patch normalization so that we don't try to apply GroupNorm with more
    # groups than test channels.
    orig_normalize = model.Normalize
    try:
      model.Normalize = nn.LayerNorm

      # Make sure we can train without any exceptions.
      main.run_train(
          config=config,
          experiment_dir=experiment_dir,
          work_unit_dir=work_unit_dir,
          rng=rng)

    finally:
      model.Normalize = orig_normalize
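
Note: utils.RngGen is not shown on this page. A minimal sketch consistent with how it is used in these examples (an iterator that yields fresh PRNG subkeys, plus the split() helper used in Example #6) could look like the following; this is an assumption about the helper, not its actual implementation.

import jax
import jax.numpy as jnp

class RngGen:
  """Iterator over PRNG keys: next(rng) returns a fresh subkey each time."""

  def __init__(self, key):
    self._key = key

  def __iter__(self):
    return self

  def __next__(self):
    # Split the internal key, keep one half, hand out the other.
    self._key, subkey = jax.random.split(self._key)
    return subkey

  def split(self, num):
    # A stacked batch of `num` fresh subkeys (one per device in Example #6).
    self._key, *subkeys = jax.random.split(self._key, num + 1)
    return jnp.stack(subkeys)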
Example #2
def main(executable_dict, argv):
    del argv

    work_unit = platform.work_unit()
    tf.enable_v2_behavior()
    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], 'GPU')

    logging.info('JAX host: %d / %d', jax.host_id(), jax.host_count())
    logging.info('JAX devices: %r', jax.devices())

    work_unit.set_task_status(
        f'host_id: {jax.host_id()}, host_count: {jax.host_count()}')

    # Read configuration
    if FLAGS.config_json:
        logging.info('Reading config from JSON: %s', FLAGS.config_json)
        with tf.io.gfile.GFile(FLAGS.config_json, 'r') as f:
            config = ml_collections.ConfigDict(json.loads(f.read()))
    else:
        config = FLAGS.config
    logging.info('config=%s',
                 config.to_json_best_effort(indent=4, sort_keys=True))

    # Make output directories
    if FLAGS.experiment_dir:
        work_unit.create_artifact(platform.ArtifactType.DIRECTORY,
                                  FLAGS.experiment_dir, 'experiment_dir')
    if FLAGS.work_unit_dir:
        work_unit.create_artifact(platform.ArtifactType.DIRECTORY,
                                  FLAGS.work_unit_dir, 'work_unit_dir')
    logging.info('experiment_dir=%s work_unit_dir=%s', FLAGS.experiment_dir,
                 FLAGS.work_unit_dir)

    # Seeding
    random.seed(config.seed * jax.host_count() + jax.host_id())
    onp.random.seed(config.seed * jax.host_count() + jax.host_id())
    rng = utils.RngGen(
        jax.random.fold_in(jax.random.PRNGKey(config.seed), jax.host_id()))

    # Run the main function
    logging.info('Running executable: %s', FLAGS.executable_name)

    extra_args = {}
    if FLAGS.extra_args_json_str:
        extra_args = json.loads(FLAGS.extra_args_json_str)
        logging.info('Extra args passed in: %r', extra_args)

    executable_dict[FLAGS.executable_name](config=config,
                                           experiment_dir=FLAGS.experiment_dir,
                                           work_unit_dir=FLAGS.work_unit_dir,
                                           rng=rng,
                                           **extra_args)

    utils.barrier()
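
Note: utils.barrier() presumably blocks until every host reaches this point. A plausible sketch, assuming the common JAX idiom of a throwaway all-reduce across hosts; the real helper may differ.

import jax
import jax.numpy as jnp

def barrier():
  """Block until all hosts reach this point, via a dummy cross-host psum."""
  x = jnp.ones([jax.local_device_count()])
  # A psum over every device on every host forces a synchronization point.
  x = jax.pmap(lambda v: jax.lax.psum(v, 'i'), axis_name='i')(x)
  x.block_until_ready()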
Example #3
    def loss_fn(self, rng, train, batch, params):
        rng = utils.RngGen(rng)

        # Input: image
        img = batch['image']
        assert img.dtype == jnp.int32

        # Input: label
        label = batch.get('label', None)
        if label is not None:
            assert label.shape == (img.shape[0],)
            assert label.dtype == jnp.int32

        def model_fn(x, t):
            return self.model.apply(
                {'params': params},
                x=x,
                t=t,
                y=label,
                train=train,
                rngs={'dropout': next(rng)} if train else None)

        dif = make_diffusion(self.config.model, num_bits=self.num_bits)
        loss = dif.training_losses(model_fn, x_start=img, rng=next(rng)).mean()
        if not train:
            loss_dict = dif.calc_bpd_loop(model_fn, x_start=img, rng=next(rng))
            total_bpd = jnp.mean(loss_dict['total'], axis=0)  # scalar
            # vb_terms = jnp.mean(loss_dict['vbterms'], axis=0)  # vec: num_timesteps
            prior_bpd = jnp.mean(loss_dict['prior'], axis=0)
            return loss, {
                'loss': loss,
                'prior_bpd': prior_bpd,
                'total_bpd': total_bpd
            }
        else:
            prior_bpd = dif.prior_bpd(img).mean()
            return loss, {
                'loss': loss,
                'prior_bpd': prior_bpd,
            }
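
Note: the (loss, dict) return value follows the has_aux=True convention of jax.value_and_grad, which Example #5 relies on. A self-contained toy showing just that convention:

import jax
import jax.numpy as jnp

def toy_loss(w):
  loss = jnp.sum(w ** 2)
  return loss, {'loss': loss}  # (value to differentiate, aux metrics)

# With has_aux=True, value_and_grad returns ((loss, aux), grad).
(loss, metrics), grad = jax.value_and_grad(toy_loss, has_aux=True)(jnp.ones(3))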
Example #4
 def get_tf_dataset(self, *, batch_shape, split, global_rng, repeat,
                    shuffle, augment, shard_id, num_shards):
     """Training dataset."""
     if split == 'train':
         split_str = 'train'
     elif split == 'eval':
         split_str = 'test'
     else:
         raise NotImplementedError(f'Unsupported split: {split}')
     if shuffle:
         global_rng = utils.RngGen(global_rng)
     ds = tfds.load('cifar10',
                    split=split_str,
                    shuffle_files=shuffle,
                    read_config=None if not shuffle else tfds.ReadConfig(
                        shuffle_seed=utils.jax_randint(next(global_rng))))
     if repeat:
         ds = ds.repeat()
     if shuffle:
         ds = ds.shuffle(50000, seed=utils.jax_randint(next(global_rng)))
     ds = shard_dataset(ds, shard_id=shard_id, num_shards=num_shards)
     return self._preprocess_and_batch(ds,
                                       batch_shape=batch_shape,
                                       augment=augment)
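
Note: shard_dataset is not defined on this page. Assuming it simply partitions examples across hosts, a minimal sketch using tf.data's built-in sharding:

import tensorflow as tf

def shard_dataset(ds, *, shard_id, num_shards):
  """Keep every num_shards-th example, starting at offset shard_id."""
  return ds.shard(num_shards, shard_id)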
Example #5
  def step_fn(self, base_rng, train, state, batch):
    """One training/eval step."""
    config = self.config

    # RNG for this step on this host
    step = state.step
    rng = jax.random.fold_in(base_rng, jax.lax.axis_index('batch'))
    rng = jax.random.fold_in(rng, step)
    rng = utils.RngGen(rng)

    # Loss and gradient
    loss_fn = functools.partial(self.loss_fn, next(rng), train, batch)

    if train:
      # Training mode
      (_, metrics), grad = jax.value_and_grad(
          loss_fn, has_aux=True)(
              state.optimizer.target)

      # Clip grad by global norm, then average it across shards
      grad_clip = metrics['grad_clip'] = config.train.grad_clip
      grad, metrics['gnorm'] = utils.clip_by_global_norm(
          grad, clip_norm=grad_clip)
      grad = jax.lax.pmean(grad, axis_name='batch')

      # Learning rate
      if config.train.learning_rate_warmup_steps > 0:
        learning_rate = config.train.learning_rate * jnp.minimum(
            jnp.float32(step) / config.train.learning_rate_warmup_steps, 1.0)
      else:
        learning_rate = config.train.learning_rate
      metrics['lr'] = learning_rate

      # Update optimizer and EMA params
      new_optimizer = state.optimizer.apply_gradient(
          grad, learning_rate=learning_rate)
      new_ema_params = utils.apply_ema(
          decay=jnp.where(step == 0, 0.0, config.train.ema_decay),
          avg=state.ema_params,
          new=new_optimizer.target)
      new_state = state.replace(  # pytype: disable=attribute-error
          step=step + 1,
          optimizer=new_optimizer,
          ema_params=new_ema_params)
      if config.train.get('enable_update_skip', True):
        # Apply update if the new optimizer state is all finite
        ok = jnp.all(
            jnp.asarray([
                jnp.all(jnp.isfinite(p)) for p in jax.tree_leaves(new_optimizer)
            ]))
        new_state_no_update = state.replace(step=step + 1)
        state = jax.tree_map(lambda a, b: jnp.where(ok, a, b), new_state,
                             new_state_no_update)
      else:
        logging.info('Update skipping disabled')
        state = new_state

    else:
      # Eval mode with EMA params
      _, metrics = loss_fn(state.ema_params)

    # Average metrics across shards
    metrics = jax.lax.pmean(metrics, axis_name='batch')
    # Check that every metric is a scalar, i.e. v.shape == ()
    assert all(not v.shape for v in metrics.values())
    metrics = {  # prepend prefix to names of metrics
        f"{'train' if train else 'eval'}/{k}": v for k, v in metrics.items()
    }
    return (state, metrics) if train else metrics
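
Note: step_fn only works under jax.pmap with axis_name='batch'; otherwise jax.lax.axis_index and pmean would fail. A self-contained toy of that per-device fold_in / pmean pattern, with hypothetical names:

import jax
import jax.numpy as jnp

base_rng = jax.random.PRNGKey(0)  # the same base key on every device

def toy_step(x):
  # The shared key plus the device index gives each device a distinct
  # stream, exactly as step_fn derives its per-shard rng.
  rng = jax.random.fold_in(base_rng, jax.lax.axis_index('batch'))
  noise = jax.random.normal(rng, x.shape)
  # pmean averages across devices, as step_fn does for grads and metrics.
  return jax.lax.pmean(x + noise, axis_name='batch')

n = jax.local_device_count()
out = jax.pmap(toy_step, axis_name='batch')(jnp.ones([n, 4]))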
Example #6
 def get_model_samples(self, params, rng):
   """Generate one batch of samples."""
   rng = utils.RngGen(rng)
   samples = self.p_gen_samples(params, rng.split(jax.local_device_count()))
   return samples
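
Note: p_gen_samples is presumably a jax.pmap-wrapped sampler, which is why it receives one key per local device from rng.split above. A self-contained sketch of that pattern; toy_sampler is hypothetical and stands in for the real diffusion sampler.

import jax
import jax.numpy as jnp

def toy_sampler(params, rng):
  # One device's batch of "samples".
  return params['scale'] * jax.random.normal(rng, (8,))

n = jax.local_device_count()
p_toy_sampler = jax.pmap(toy_sampler)
params = {'scale': jnp.ones([n])}  # params replicated along the device axis
samples = p_toy_sampler(params, jax.random.split(jax.random.PRNGKey(0), n))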