Ejemplo n.º 1
0
def create_checkpoint(model, path):
    """Initialize a registered model with a dummy input and save its params.

    Args:
      model: Key into ``models.KNOWN_MODELS`` selecting the architecture; the
        model is built with a single output class.
      path: Destination handed to ``checkpoint.save``.
    """
    module_def = models.KNOWN_MODELS[model].partial(num_classes=1)
    # The input spec (one 16x16x3 float32 example) only drives shape
    # inference during init; the first element of the init result is dropped
    # and only the parameters are kept.
    input_spec = [((1, 16, 16, 3), jnp.float32)]
    init_rng = jax.random.PRNGKey(0)
    _, initial_params = module_def.init_by_shape(init_rng, input_spec)
    checkpoint.save(initial_params, path)
Ejemplo n.º 2
0
def create_checkpoint(model_config, path):
    """Build a one-class VisionTransformer, initialize it and save its params.

    Args:
      model_config: Keyword arguments forwarded to ``models.VisionTransformer``.
      path: Destination handed to ``checkpoint.save``.
    """
    vit = models.VisionTransformer(num_classes=1, **model_config)
    init_key = jax.random.PRNGKey(0)
    dummy_images = jnp.ones([1, 16, 16, 3], jnp.float32)
    # Initialize in inference mode; only the 'params' collection is persisted.
    all_variables = vit.init(init_key, dummy_images, train=False)
    checkpoint.save(all_variables['params'], path)
Ejemplo n.º 3
0
def main(args):
    """Fine-tunes a pretrained ViT model and periodically evaluates it.

    Builds train/test input pipelines, initializes the model selected by
    ``args.model``, loads pretrained weights into it, then runs a training
    loop replicated over the local devices with periodic progress logging,
    test-set evaluation and optional copying of the log directory to
    ``args.copy_to``. If ``args.output`` is set, the final parameters are
    saved there.

    Args:
      args: Parsed command-line namespace. Attributes read here: logdir,
        name, dataset, mixup_alpha, batch, batch_eval, shuffle_buffer,
        tfds_data_dir, tfds_manual_dir, model, vit_pretrained_dir,
        accum_steps, optim_dtype, grad_norm_clip, copy_to, total_steps,
        base_lr, decay_type, warmup_steps, prefetch, progress_every,
        eval_every and output.
    """
    logdir = os.path.join(args.logdir, args.name)
    logger = logging.setup_logger(logdir)
    logger.info(args)

    logger.info(f'Available devices: {jax.devices()}')

    # Setup input pipeline
    dataset_info = input_pipeline.get_dataset_info(args.dataset, 'train')

    ds_train = input_pipeline.get_data(dataset=args.dataset,
                                       mode='train',
                                       repeats=None,
                                       mixup_alpha=args.mixup_alpha,
                                       batch_size=args.batch,
                                       shuffle_buffer=args.shuffle_buffer,
                                       tfds_data_dir=args.tfds_data_dir,
                                       tfds_manual_dir=args.tfds_manual_dir)
    # Peek at one batch only to get the image shape/dtype for initialization
    # below; `batch` is rebound by the training loop later.
    batch = next(iter(ds_train))
    logger.info(ds_train)
    ds_test = input_pipeline.get_data(dataset=args.dataset,
                                      mode='test',
                                      repeats=1,
                                      batch_size=args.batch_eval,
                                      tfds_data_dir=args.tfds_data_dir,
                                      tfds_manual_dir=args.tfds_manual_dir)
    logger.info(ds_test)

    # Build VisionTransformer architecture
    model = models.KNOWN_MODELS[args.model]
    VisionTransformer = model.partial(num_classes=dataset_info['num_classes'])
    _, params = VisionTransformer.init_by_shape(
        jax.random.PRNGKey(0),
        # Discard the "num_local_devices" dimension for initialization.
        [(batch['image'].shape[1:], batch['image'].dtype.name)])

    # Overwrite the freshly initialized params with the pretrained weights
    # stored as <vit_pretrained_dir>/<model>.npz.
    pretrained_path = os.path.join(args.vit_pretrained_dir,
                                   f'{args.model}.npz')
    params = checkpoint.load_pretrained(
        pretrained_path=pretrained_path,
        init_params=params,
        model_config=models.CONFIGS[args.model],
        logger=logger)

    # pmap replicates the models over all TPUs/GPUs
    vit_fn_repl = jax.pmap(VisionTransformer.call)
    update_fn_repl = make_update_fn(VisionTransformer.call, args.accum_steps)

    # Create optimizer and replicate it over all TPUs/GPUs
    opt = momentum_clip.Optimizer(
        dtype=args.optim_dtype,
        grad_norm_clip=args.grad_norm_clip).create(params)
    opt_repl = flax_utils.replicate(opt)

    # Delete references to the objects that are not needed anymore
    del opt
    del params

    def copyfiles(paths):
        """Small helper to copy files to args.copy_to using tf.io.gfile."""
        if not args.copy_to:
            return
        for path in paths:
            to_path = os.path.join(args.copy_to, args.name,
                                   os.path.basename(path))
            tf.io.gfile.makedirs(os.path.dirname(to_path))
            tf.io.gfile.copy(path, to_path, overwrite=True)
            logger.info(f'Copied {path} to {to_path}.')

    # An explicit --total_steps wins; otherwise use the dataset preset.
    total_steps = args.total_steps or (
        input_pipeline.DATASET_PRESETS[args.dataset]['total_steps'])

    # Prepare the learning-rate and pre-fetch it to device to avoid delays.
    lr_fn = hyper.create_learning_rate_schedule(total_steps, args.base_lr,
                                                args.decay_type,
                                                args.warmup_steps)
    lr_iter = hyper.lr_prefetch_iter(lr_fn, 0, total_steps)
    # One RNG key per local device for the replicated update function.
    update_rngs = jax.random.split(jax.random.PRNGKey(0),
                                   jax.local_device_count())

    # Run training loop
    writer = metric_writers.create_default_writer(logdir, asynchronous=False)
    writer.write_hparams(
        {k: v
         for k, v in vars(args).items() if v is not None})
    logger.info('Starting training loop; initial compile can take a while...')
    t0 = time.time()

    # zip stops at the shortest input, so range() caps the loop at
    # total_steps even though ds_train repeats indefinitely (repeats=None).
    for step, batch, lr_repl in zip(
            range(1, total_steps + 1),
            input_pipeline.prefetch(ds_train, args.prefetch), lr_iter):

        opt_repl, loss_repl, update_rngs = update_fn_repl(
            opt_repl, lr_repl, batch, update_rngs)

        # Restart the clock after the first (compile-heavy) step so the ETA
        # below is based on steady-state step time.
        if step == 1:
            logger.info(f'First step took {time.time() - t0:.1f} seconds.')
            t0 = time.time()
        if args.progress_every and step % args.progress_every == 0:
            # Only the first device's copy of the loss is reported.
            writer.write_scalars(step, dict(train_loss=float(loss_repl[0])))
            done = step / total_steps
            logger.info(f'Step: {step}/{total_steps} {100*done:.1f}%, '
                        f'ETA: {(time.time()-t0)/done*(1-done)/3600:.2f}h')
            copyfiles(glob.glob(f'{logdir}/*'))

        # Run eval step
        # (also always runs once on the final step).
        if ((args.eval_every and step % args.eval_every == 0)
                or (step == total_steps)):

            # Fraction of correct top-1 predictions over every test example.
            # argmax over axis=2 assumes logits/labels are laid out as
            # (devices, per-device batch, classes) -- TODO confirm.
            accuracy_test = np.mean([
                c for batch in input_pipeline.prefetch(ds_test, args.prefetch)
                for c in (np.argmax(
                    vit_fn_repl(opt_repl.target, batch['image']), axis=2) ==
                          np.argmax(batch['label'], axis=2)).ravel()
            ])

            lr = float(lr_repl[0])
            logger.info(f'Step: {step} '
                        f'Learning rate: {lr:.7f}, '
                        f'Test accuracy: {accuracy_test:0.5f}')
            writer.write_scalars(step, dict(accuracy_test=accuracy_test,
                                            lr=lr))
            copyfiles(glob.glob(f'{logdir}/*'))

    if args.output:
        # Collapse the per-device replicas to a single copy before saving.
        checkpoint.save(flax_utils.unreplicate(opt_repl.target), args.output)
        logger.info(f'Stored fine tuned checkpoint to {args.output}')
        copyfiles([args.output])
Ejemplo n.º 4
0
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
  """Runs training interleaved with evaluation.

  Builds the train/test pipelines, instantiates a ViT (default) or MLP-Mixer
  model, loads pretrained weights into it, and trains with a replicated
  optimizer. It periodically logs training loss/throughput and evaluates
  top-1 test accuracy. The final parameters are written to
  ``<workdir>/model.npz``.

  Args:
    config: Experiment configuration. Fields read here: dataset, mixup_alpha,
      batch, batch_eval, pp, shuffle_buffer, tfds_data_dir, tfds_manual_dir,
      model (including model.name), optional model_type ('ViT' or 'Mixer'),
      pretrained_dir, total_steps, base_lr, decay_type, warmup_steps,
      accum_steps, optim_dtype, grad_norm_clip, prefetch, progress_every and
      eval_every.
    workdir: Directory for metric summaries and the final checkpoint.

  Returns:
    The unreplicated optimizer after the last training step.

  Raises:
    ValueError: If ``<pretrained_dir>/<model.name>.npz`` does not exist.
  """

  # Setup input pipeline
  dataset_info = input_pipeline.get_dataset_info(config.dataset, 'train')

  ds_train = input_pipeline.get_data(
      dataset=config.dataset,
      mode='train',
      repeats=None,
      mixup_alpha=config.mixup_alpha,
      batch_size=config.batch,
      pp_config=config.pp,
      shuffle_buffer=config.shuffle_buffer,
      tfds_data_dir=config.tfds_data_dir,
      tfds_manual_dir=config.tfds_manual_dir)
  # Peek at one batch only to get the image shape/dtype for initialization.
  batch = next(iter(ds_train))
  logging.info(ds_train)
  ds_test = input_pipeline.get_data(
      dataset=config.dataset,
      mode='test',
      repeats=1,
      batch_size=config.batch_eval,
      pp_config=config.pp,
      tfds_data_dir=config.tfds_data_dir,
      tfds_manual_dir=config.tfds_manual_dir)
  logging.info(ds_test)

  # Build the model: VisionTransformer by default, MlpMixer for 'Mixer'.
  model_cls = {'ViT': models.VisionTransformer,
               'Mixer': models.MlpMixer}[config.get('model_type', 'ViT')]
  model = model_cls(num_classes=dataset_info['num_classes'], **config.model)

  def init_model():
    return model.init(
        jax.random.PRNGKey(0),
        # Discard the "num_local_devices" dimension for initialization.
        jnp.ones(batch['image'].shape[1:], batch['image'].dtype.name),
        train=False)

  # Use JIT to make sure params reside in CPU memory.
  variables = jax.jit(init_model, backend='cpu')()

  pretrained_path = os.path.join(config.pretrained_dir,
                                 f'{config.model.name}.npz')
  if not tf.io.gfile.exists(pretrained_path):
    raise ValueError(
        f'Could not find "{pretrained_path}" - you can download models from '
        '"gs://vit_models/imagenet21k" or directly set '
        '--config.pretrained_dir="gs://vit_models/imagenet21k".')
  params = checkpoint.load_pretrained(
      pretrained_path=pretrained_path,
      init_params=variables['params'],
      model_config=config.model)

  total_steps = config.total_steps
  lr_fn = utils.create_learning_rate_schedule(total_steps, config.base_lr,
                                              config.decay_type,
                                              config.warmup_steps)

  # The update function receives the (replicated) step number below and is
  # given lr_fn, presumably to compute the learning rate on device.
  update_fn_repl = make_update_fn(
      apply_fn=model.apply, accum_steps=config.accum_steps, lr_fn=lr_fn)
  infer_fn_repl = jax.pmap(functools.partial(model.apply, train=False))

  # Create optimizer and replicate it over all TPUs/GPUs
  opt = momentum_clip.Optimizer(
      dtype=config.optim_dtype,
      grad_norm_clip=config.grad_norm_clip).create(params)
  opt_repl = flax.jax_utils.replicate(opt)

  # Delete references to the objects that are not needed anymore
  del opt
  del params

  # Replicate the RNG key used by the update function across devices.
  update_rng_repl = flax.jax_utils.replicate(jax.random.PRNGKey(0))

  # Run training loop
  writer = metric_writers.create_default_writer(workdir, asynchronous=False)
  writer.write_hparams(config.to_dict())
  logging.info('Starting training loop; initial compile can take a while...')
  # t0: reference for the overall ETA.
  # lt0/lstep: start of the current training-throughput measurement window.
  t0 = lt0 = time.time()

  for step, batch in zip(
      range(1, total_steps + 1),
      input_pipeline.prefetch(ds_train, config.prefetch)):

    opt_repl, loss_repl, update_rng_repl = update_fn_repl(
        opt_repl, flax.jax_utils.replicate(step), batch, update_rng_repl)

    # Restart all clocks after the first (compile-heavy) step.
    if step == 1:
      logging.info('First step took %.1f seconds.', time.time() - t0)
      t0 = time.time()
      lt0, lstep = time.time(), step

    # Report training metrics
    if config.progress_every and step % config.progress_every == 0:
      img_sec_core_train = (config.batch * (step - lstep) /
                            (time.time() - lt0)) / jax.device_count()
      lt0, lstep = time.time(), step
      writer.write_scalars(
          step,
          dict(
              train_loss=float(flax.jax_utils.unreplicate(loss_repl)),
              img_sec_core_train=img_sec_core_train))
      done = step / total_steps
      logging.info(f'Step: {step}/{total_steps} {100*done:.1f}%, '  # pylint: disable=logging-format-interpolation
                   f'img/sec/core: {img_sec_core_train:.1f}, '
                   f'ETA: {(time.time()-t0)/done*(1-done)/3600:.2f}h')

    # Run evaluation
    if ((config.eval_every and step % config.eval_every == 0) or
        (step == total_steps)):

      # Use a dedicated timer so the training window (lt0/lstep) keeps the
      # training time that elapsed before this evaluation.
      eval_t0 = time.time()
      num_correct = 0
      num_examples = 0
      for test_batch in input_pipeline.prefetch(ds_test, config.prefetch):
        logits = infer_fn_repl(
            dict(params=opt_repl.target), test_batch['image'])
        is_correct = (np.argmax(logits, axis=-1) ==
                      np.argmax(test_batch['label'], axis=-1))
        num_correct += int(is_correct.sum())
        num_examples += is_correct.size
      # Weight every example equally; a mean of per-batch means would skew
      # the result when the last batch is smaller. NaN if nothing was seen.
      accuracy_test = (num_correct / num_examples
                       if num_examples else float('nan'))
      eval_seconds = time.time() - eval_t0
      img_sec_core_test = (
          config.batch_eval * ds_test.cardinality().numpy() /
          eval_seconds / jax.device_count())
      # Exclude the evaluation time from the next training-throughput report.
      lt0 += eval_seconds

      lr = float(lr_fn(step))
      logging.info(f'Step: {step} '  # pylint: disable=logging-format-interpolation
                   f'Learning rate: {lr:.7f}, '
                   f'Test accuracy: {accuracy_test:0.5f}, '
                   f'img/sec/core: {img_sec_core_test:.1f}')
      writer.write_scalars(
          step,
          dict(
              accuracy_test=accuracy_test,
              lr=lr,
              img_sec_core_test=img_sec_core_test))

  # Make sure all buffered summaries reach disk before returning.
  writer.flush()

  opt = flax.jax_utils.unreplicate(opt_repl)
  del opt_repl
  checkpoint_path = f'{workdir}/model.npz'
  checkpoint.save(opt.target, checkpoint_path)
  logging.info('Stored fine tuned checkpoint to %s', checkpoint_path)

  return opt