def create_checkpoint(model, path):
  """Initializes a registry model by shape and saves its params to `path`.

  The model named `model` is looked up in `models.KNOWN_MODELS` and
  specialized to a single output class. It is then initialized (legacy
  `init_by_shape` API, PRNG seed 0) for a float32 input of shape
  (1, 16, 16, 3). The resulting parameters are written with `checkpoint.save`.

  NOTE(review): a later `def create_checkpoint(model_config, path)` in this
  file rebinds this name, so this version appears to be unreachable at import
  time. Confirm before relying on it.

  Args:
    model: Key into `models.KNOWN_MODELS`.
    path: Destination passed to `checkpoint.save`.
  """
  model_def = models.KNOWN_MODELS[model].partial(num_classes=1)
  init_rng = jax.random.PRNGKey(0)
  input_specs = [((1, 16, 16, 3), jnp.float32)]
  _, params = model_def.init_by_shape(init_rng, input_specs)
  checkpoint.save(params, path)
def create_checkpoint(model_config, path):
  """Initializes model and stores weights in specified path.

  A `models.VisionTransformer` with one output class is built from
  `model_config`. It is initialized with PRNG seed 0 on an all-ones float32
  batch of shape (1, 16, 16, 3) in inference mode (`train=False`). Only the
  'params' collection of the initialized variables is saved.

  Args:
    model_config: Keyword arguments forwarded to `models.VisionTransformer`.
    path: Destination passed to `checkpoint.save`.
  """
  vit = models.VisionTransformer(num_classes=1, **model_config)
  init_rng = jax.random.PRNGKey(0)
  dummy_images = jnp.ones([1, 16, 16, 3], jnp.float32)
  initialized = vit.init(init_rng, dummy_images, train=False)
  checkpoint.save(initialized['params'], path)
def main(args):
  """Fine-tunes a pretrained model from the `models.KNOWN_MODELS` registry.

  Steps:
    1. Builds the train/test input pipelines.
    2. Initializes the model selected by `args.model`.
    3. Loads pretrained weights from `{args.vit_pretrained_dir}/{args.model}.npz`.
    4. Replicates a `momentum_clip` optimizer over all TPUs/GPUs and runs
       `total_steps` update steps.

  During training it logs the train loss every `args.progress_every` steps,
  and the test accuracy every `args.eval_every` steps and at the final step.
  When `args.copy_to` is set, files in the log directory (and the final
  checkpoint) are mirrored there. When `args.output` is set, the fine-tuned
  parameters are saved to it.

  Args:
    args: Parsed command-line arguments. NOTE(review): the attribute set read
      here (logdir, name, dataset, model, batch, batch_eval, ...) is inferred
      from usage; confirm against the argument parser.
  """
  logdir = os.path.join(args.logdir, args.name)
  logger = logging.setup_logger(logdir)
  logger.info(args)

  logger.info(f'Available devices: {jax.devices()}')

  # Setup input pipeline
  dataset_info = input_pipeline.get_dataset_info(args.dataset, 'train')

  ds_train = input_pipeline.get_data(
      dataset=args.dataset,
      mode='train',
      repeats=None,
      mixup_alpha=args.mixup_alpha,
      batch_size=args.batch,
      shuffle_buffer=args.shuffle_buffer,
      tfds_data_dir=args.tfds_data_dir,
      tfds_manual_dir=args.tfds_manual_dir)
  # This eagerly fetched batch is only used for its image shape/dtype in
  # `init_by_shape` below; the training loop rebinds `batch`.
  batch = next(iter(ds_train))
  logger.info(ds_train)
  ds_test = input_pipeline.get_data(
      dataset=args.dataset,
      mode='test',
      repeats=1,
      batch_size=args.batch_eval,
      tfds_data_dir=args.tfds_data_dir,
      tfds_manual_dir=args.tfds_manual_dir)
  logger.info(ds_test)

  # Build VisionTransformer architecture
  model = models.KNOWN_MODELS[args.model]
  VisionTransformer = model.partial(num_classes=dataset_info['num_classes'])
  _, params = VisionTransformer.init_by_shape(
      jax.random.PRNGKey(0),
      # Discard the "num_local_devices" dimension for initialization.
      [(batch['image'].shape[1:], batch['image'].dtype.name)])

  pretrained_path = os.path.join(args.vit_pretrained_dir, f'{args.model}.npz')
  # Freshly initialized params are passed in as `init_params`, presumably so
  # the loader can adapt the pretrained weights to this head/shape — confirm in
  # checkpoint.load_pretrained.
  params = checkpoint.load_pretrained(
      pretrained_path=pretrained_path,
      init_params=params,
      model_config=models.CONFIGS[args.model],
      logger=logger)

  # pmap replicates the models over all TPUs/GPUs
  vit_fn_repl = jax.pmap(VisionTransformer.call)
  update_fn_repl = make_update_fn(VisionTransformer.call, args.accum_steps)

  # Create optimizer and replicate it over all TPUs/GPUs
  opt = momentum_clip.Optimizer(
      dtype=args.optim_dtype,
      grad_norm_clip=args.grad_norm_clip).create(params)
  opt_repl = flax_utils.replicate(opt)

  # Delete references to the objects that are not needed anymore
  del opt
  del params

  def copyfiles(paths):
    """Small helper to copy files to args.copy_to using tf.io.gfile."""
    if not args.copy_to:
      return
    for path in paths:
      to_path = os.path.join(args.copy_to, args.name, os.path.basename(path))
      tf.io.gfile.makedirs(os.path.dirname(to_path))
      tf.io.gfile.copy(path, to_path, overwrite=True)
      logger.info(f'Copied {path} to {to_path}.')

  # Explicit --total_steps wins; otherwise fall back to the dataset preset.
  total_steps = args.total_steps or (
      input_pipeline.DATASET_PRESETS[args.dataset]['total_steps'])

  # Prepare the learning-rate and pre-fetch it to device to avoid delays.
  lr_fn = hyper.create_learning_rate_schedule(total_steps, args.base_lr,
                                              args.decay_type,
                                              args.warmup_steps)
  lr_iter = hyper.lr_prefetch_iter(lr_fn, 0, total_steps)
  # One distinct PRNG key per local device.
  update_rngs = jax.random.split(
      jax.random.PRNGKey(0), jax.local_device_count())

  # Run training loop
  writer = metric_writers.create_default_writer(logdir, asynchronous=False)
  writer.write_hparams(
      {k: v for k, v in vars(args).items() if v is not None})
  logger.info('Starting training loop; initial compile can take a while...')
  t0 = time.time()

  # `range` is the first zip argument, so the loop stops after exactly
  # `total_steps` steps without pulling extra items from the other iterators.
  for step, batch, lr_repl in zip(
      range(1, total_steps + 1),
      input_pipeline.prefetch(ds_train, args.prefetch), lr_iter):

    opt_repl, loss_repl, update_rngs = update_fn_repl(
        opt_repl, lr_repl, batch, update_rngs)

    if step == 1:
      logger.info(f'First step took {time.time() - t0:.1f} seconds.')
      # Restart the clock so the ETA below excludes the first (compiling)
      # step.
      t0 = time.time()
    if args.progress_every and step % args.progress_every == 0:
      # loss_repl[0]: the loss as reported by the first local replica.
      writer.write_scalars(step, dict(train_loss=float(loss_repl[0])))
      done = step / total_steps
      logger.info(f'Step: {step}/{total_steps} {100*done:.1f}%, '
                  f'ETA: {(time.time()-t0)/done*(1-done)/3600:.2f}h')
      copyfiles(glob.glob(f'{logdir}/*'))

    # Run eval step
    if ((args.eval_every and step % args.eval_every == 0) or
        (step == total_steps)):

      # Per-example correctness over the whole test set, flattened across
      # devices and batches; its mean is the test accuracy. The `batch` name
      # here is local to the comprehension and does not rebind the loop
      # variable. NOTE(review): axis=2 assumes logits/labels are laid out as
      # (local_devices, per_device_batch, num_classes) — confirm.
      accuracy_test = np.mean([
          c for batch in input_pipeline.prefetch(ds_test, args.prefetch)
          for c in (np.argmax(
              vit_fn_repl(opt_repl.target, batch['image']), axis=2) ==
                    np.argmax(batch['label'], axis=2)).ravel()
      ])

      # lr_repl appears to hold one copy of the learning rate per device.
      lr = float(lr_repl[0])
      logger.info(f'Step: {step} '
                  f'Learning rate: {lr:.7f}, '
                  f'Test accuracy: {accuracy_test:0.5f}')
      writer.write_scalars(step, dict(accuracy_test=accuracy_test, lr=lr))
      copyfiles(glob.glob(f'{logdir}/*'))

  if args.output:
    checkpoint.save(flax_utils.unreplicate(opt_repl.target), args.output)
    logger.info(f'Stored fine tuned checkpoint to {args.output}')
    copyfiles([args.output])
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
  """Runs training interleaved with evaluation.

  Steps:
    1. Builds the train/test input pipelines described by `config`.
    2. Instantiates the model class selected by `config.model_type`
       ('ViT' by default, or 'Mixer').
    3. Loads pretrained weights from
       `{config.pretrained_dir}/{config.model.name}.npz`.
    4. Fine-tunes for `config.total_steps` steps with a `momentum_clip`
       optimizer replicated over all TPUs/GPUs.

  Metrics are written as follows:
    - Every `config.progress_every` steps: train loss and training throughput.
    - Every `config.eval_every` steps and at the final step: test accuracy,
      learning rate and eval throughput.

  Args:
    config: Experiment configuration.
    workdir: Directory handed to the metric writer; the final parameters are
      saved to `{workdir}/model.npz`.

  Returns:
    The unreplicated optimizer after the last step; its `.target` holds the
    fine-tuned parameters.

  Raises:
    ValueError: If the pretrained checkpoint file does not exist.
    KeyError: If `config.model_type` is neither 'ViT' nor 'Mixer'.
  """

  # Setup input pipeline
  dataset_info = input_pipeline.get_dataset_info(config.dataset, 'train')

  ds_train = input_pipeline.get_data(
      dataset=config.dataset,
      mode='train',
      repeats=None,
      mixup_alpha=config.mixup_alpha,
      batch_size=config.batch,
      pp_config=config.pp,
      shuffle_buffer=config.shuffle_buffer,
      tfds_data_dir=config.tfds_data_dir,
      tfds_manual_dir=config.tfds_manual_dir)
  # This eagerly fetched batch is only used for its image shape/dtype in
  # `init_model` below; the training loop rebinds `batch`.
  batch = next(iter(ds_train))
  logging.info(ds_train)
  ds_test = input_pipeline.get_data(
      dataset=config.dataset,
      mode='test',
      repeats=1,
      batch_size=config.batch_eval,
      pp_config=config.pp,
      tfds_data_dir=config.tfds_data_dir,
      tfds_manual_dir=config.tfds_manual_dir)
  logging.info(ds_test)

  # Build VisionTransformer architecture
  model_cls = {'ViT': models.VisionTransformer,
               'Mixer': models.MlpMixer}[config.get('model_type', 'ViT')]
  model = model_cls(num_classes=dataset_info['num_classes'], **config.model)

  def init_model():
    return model.init(
        jax.random.PRNGKey(0),
        # Discard the "num_local_devices" dimension for initialization.
        jnp.ones(batch['image'].shape[1:], batch['image'].dtype.name),
        train=False)

  # Use JIT to make sure params reside in CPU memory.
  variables = jax.jit(init_model, backend='cpu')()

  pretrained_path = os.path.join(config.pretrained_dir,
                                 f'{config.model.name}.npz')
  if not tf.io.gfile.exists(pretrained_path):
    raise ValueError(
        f'Could not find "{pretrained_path}" - you can download models from '
        '"gs://vit_models/imagenet21k" or directly set '
        '--config.pretrained_dir="gs://vit_models/imagenet21k".')
  params = checkpoint.load_pretrained(
      pretrained_path=pretrained_path,
      init_params=variables['params'],
      model_config=config.model)

  total_steps = config.total_steps
  lr_fn = utils.create_learning_rate_schedule(total_steps, config.base_lr,
                                              config.decay_type,
                                              config.warmup_steps)

  # The schedule is handed to the update function, which receives the
  # replicated step number each iteration (see the training loop below).
  update_fn_repl = make_update_fn(
      apply_fn=model.apply, accum_steps=config.accum_steps, lr_fn=lr_fn)
  infer_fn_repl = jax.pmap(functools.partial(model.apply, train=False))

  # Create optimizer and replicate it over all TPUs/GPUs
  opt = momentum_clip.Optimizer(
      dtype=config.optim_dtype,
      grad_norm_clip=config.grad_norm_clip).create(params)
  opt_repl = flax.jax_utils.replicate(opt)

  # Delete references to the objects that are not needed anymore
  del opt
  del params

  # Replicate the PRNG key consumed by the update function across devices.
  # NOTE(review): every device receives the same key here; presumably
  # make_update_fn derives per-device randomness from it — confirm.
  update_rng_repl = flax.jax_utils.replicate(jax.random.PRNGKey(0))

  # Run training loop
  writer = metric_writers.create_default_writer(workdir, asynchronous=False)
  writer.write_hparams(config.to_dict())
  logging.info('Starting training loop; initial compile can take a while...')
  # t0: reference time for the ETA; lt0/lstep: time and step of the last
  # throughput measurement.
  t0 = lt0 = time.time()

  for step, batch in zip(
      range(1, total_steps + 1),
      input_pipeline.prefetch(ds_train, config.prefetch)):

    opt_repl, loss_repl, update_rng_repl = update_fn_repl(
        opt_repl, flax.jax_utils.replicate(step), batch, update_rng_repl)

    if step == 1:
      logging.info('First step took %.1f seconds.', time.time() - t0)
      # Restart both clocks so the ETA and throughput exclude the first
      # (compiling) step; this also defines `lstep` before its first use.
      t0 = time.time()
      lt0, lstep = time.time(), step

    # Report training metrics
    if config.progress_every and step % config.progress_every == 0:
      # Images per second per device, averaged over the steps since the last
      # measurement.
      img_sec_core_train = (config.batch * (step - lstep) /
                            (time.time() - lt0)) / jax.device_count()
      lt0, lstep = time.time(), step
      writer.write_scalars(
          step,
          dict(
              train_loss=float(flax.jax_utils.unreplicate(loss_repl)),
              img_sec_core_train=img_sec_core_train))
      done = step / total_steps
      logging.info(f'Step: {step}/{total_steps} {100*done:.1f}%, '  # pylint: disable=logging-format-interpolation
                   f'img/sec/core: {img_sec_core_train:.1f}, '
                   f'ETA: {(time.time()-t0)/done*(1-done)/3600:.2f}h')

    # Run evaluation
    if ((config.eval_every and step % config.eval_every == 0) or
        (step == total_steps)):

      accuracies = []
      lt0 = time.time()
      for test_batch in input_pipeline.prefetch(ds_test, config.prefetch):
        logits = infer_fn_repl(
            dict(params=opt_repl.target), test_batch['image'])
        accuracies.append(
            (np.argmax(logits,
                       axis=-1) == np.argmax(test_batch['label'],
                                             axis=-1)).mean())
      # Mean of per-batch accuracies. This equals per-example accuracy only
      # if all test batches have the same size — NOTE(review): confirm the
      # test pipeline guarantees that.
      accuracy_test = np.mean(accuracies)
      # NOTE(review): assumes ds_test has a known, finite cardinality (the
      # number of test batches); an unknown/infinite one would be negative.
      img_sec_core_test = (
          config.batch_eval * ds_test.cardinality().numpy() /
          (time.time() - lt0) / jax.device_count())
      # Reset so evaluation time is not counted in the next train throughput.
      lt0 = time.time()

      lr = float(lr_fn(step))
      logging.info(f'Step: {step} '  # pylint: disable=logging-format-interpolation
                   f'Learning rate: {lr:.7f}, '
                   f'Test accuracy: {accuracy_test:0.5f}, '
                   f'img/sec/core: {img_sec_core_test:.1f}')
      writer.write_scalars(
          step,
          dict(
              accuracy_test=accuracy_test,
              lr=lr,
              img_sec_core_test=img_sec_core_test))

  opt = flax.jax_utils.unreplicate(opt_repl)
  del opt_repl
  checkpoint.save(opt.target, f'{workdir}/model.npz')
  logging.info('Stored fine tuned checkpoint to %s', workdir)

  return opt