def create_input_iter(dataset_builder, batch_size, image_size, dtype, train, cache):
    """Build a device-prefetching iterator over one dataset split.

    The split produced by ``input_pipeline.create_split`` is converted
    batch-by-batch with ``prepare_tf_data`` and wrapped in
    ``jax_utils.prefetch_to_device`` so that 2 batches are staged on the
    accelerator ahead of consumption.
    """
    split = input_pipeline.create_split(
        dataset_builder,
        batch_size,
        image_size=image_size,
        dtype=dtype,
        train=train,
        cache=cache,
    )
    # Lazily convert each TF batch to host arrays suitable for JAX.
    converted = (prepare_tf_data(batch) for batch in split)
    return jax_utils.prefetch_to_device(converted, 2)
def main(argv):
    """Train a ResNet (from ``elegy.nets.resnet``) on a TFDS dataset.

    All configuration comes from absl-style ``FLAGS``. Builds train/validation
    input pipelines, an SGD optimizer with cosine decay + linear warmup, and
    runs ``model.fit`` with checkpointing, NaN termination, and TensorBoard
    callbacks writing into ``FLAGS.output_dir``.

    Raises:
        ValueError: if positional CLI arguments are passed, the requested
            model name does not exist, or the output directory already exists.
    """
    # Explicit raises instead of `assert`: asserts are stripped under -O,
    # which would silently skip this input validation.
    if len(argv) != 1:
        raise ValueError(
            "Please specify arguments via flags. Use --help for instructions"
        )
    if getattr(elegy.nets.resnet, FLAGS.model, None) is None:
        raise ValueError(f"{FLAGS.model} is not defined in elegy.nets.resnet")
    if os.path.exists(FLAGS.output_dir):
        raise ValueError(
            "Output directory already exists. Delete manually or specify a new one."
        )
    # NOTE(review): exists-then-makedirs is racy if two runs share an
    # output_dir; acceptable for a single-user training script.
    os.makedirs(FLAGS.output_dir)

    # dataset
    dataset_builder = tfds.builder(FLAGS.dataset)
    ds_train = input_pipeline.create_split(
        dataset_builder,
        batch_size=FLAGS.batch_size,
        image_size=FLAGS.image_size,
        dtype=FLAGS.dtype,
        train=True,
        cache=FLAGS.cache,
    )
    ds_valid = input_pipeline.create_split(
        dataset_builder,
        batch_size=FLAGS.batch_size,
        image_size=FLAGS.image_size,
        dtype=FLAGS.dtype,
        train=False,
        cache=FLAGS.cache,
    )
    # Whole batches only: the trailing partial batch is dropped.
    n_batches_train = (
        dataset_builder.info.splits["train"].num_examples // FLAGS.batch_size
    )
    n_batches_valid = (
        dataset_builder.info.splits["validation"].num_examples // FLAGS.batch_size
    )

    # generator that converts tfds dataset batches to jax arrays
    def tfds2jax_generator(tf_ds):
        for batch in tf_ds:
            yield jnp.asarray(batch["image"], dtype=FLAGS.dtype), jax.device_put(
                jnp.asarray(batch["label"])
            )

    # model and optimizer definition
    def build_optimizer(
        lr, momentum, steps_per_epoch, n_epochs, nesterov, warmup_epochs=5
    ):
        """SGD with a schedule that is the pointwise min of cosine decay
        (from 1 down to ~0 over the full run) and a linear warmup
        (0 -> 1 over `warmup_epochs`)."""
        cosine_schedule = optax.cosine_decay_schedule(
            1, decay_steps=n_epochs * steps_per_epoch, alpha=1e-10
        )
        warmup_schedule = optax.polynomial_schedule(
            init_value=0.0,
            end_value=1.0,
            power=1,
            transition_steps=warmup_epochs * steps_per_epoch,
        )
        # min(cosine, warmup): warmup dominates early, cosine afterwards.
        schedule = lambda x: jnp.minimum(cosine_schedule(x), warmup_schedule(x))
        optimizer = optax.sgd(lr, momentum, nesterov=nesterov)
        optimizer = optax.chain(optimizer, optax.scale_by_schedule(schedule))
        return optimizer

    module = getattr(elegy.nets.resnet, FLAGS.model)(dtype=FLAGS.dtype)

    model = elegy.Model(
        module,
        loss=[
            # Loss is scaled up by loss_scale (and the base LR scaled down to
            # compensate) — presumably for mixed-precision stability; confirm.
            elegy.losses.SparseCategoricalCrossentropy(
                from_logits=True, weight=FLAGS.loss_scale
            ),
            elegy.regularizers.GlobalL2(FLAGS.L2_reg / 2 * FLAGS.loss_scale),
        ],
        metrics=elegy.metrics.SparseCategoricalAccuracy(),
        optimizer=build_optimizer(
            FLAGS.base_lr / FLAGS.loss_scale,
            FLAGS.momentum,
            n_batches_train,
            FLAGS.epochs,
            FLAGS.nesterov,
        ),
    )

    # training
    model.fit(
        x=tfds2jax_generator(ds_train),
        validation_data=tfds2jax_generator(ds_valid),
        epochs=FLAGS.epochs,
        verbose=2,
        steps_per_epoch=n_batches_train,
        validation_steps=n_batches_valid,
        callbacks=[
            elegy.callbacks.ModelCheckpoint(FLAGS.output_dir, save_best_only=True),
            elegy.callbacks.TerminateOnNaN(),
            elegy.callbacks.TensorBoard(logdir=FLAGS.output_dir),
        ],
    )