Example #1
# `input_pipeline` and `prepare_tf_data` are helpers defined alongside this
# function (as in the Flax ImageNet example); `jax_utils` comes from Flax.
from flax import jax_utils


def create_input_iter(dataset_builder, batch_size, image_size, dtype, train,
                      cache):
  ds = input_pipeline.create_split(
      dataset_builder, batch_size, image_size=image_size, dtype=dtype,
      train=train, cache=cache)
  # Convert each TF batch via `prepare_tf_data` and keep two batches
  # prefetched on the accelerator.
  it = map(prepare_tf_data, ds)
  it = jax_utils.prefetch_to_device(it, 2)
  return it
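
A minimal usage sketch (not part of the original example): the dataset name, batch size, and image size below are illustrative, and `prepare_tf_data` is assumed to be the sharding helper defined alongside `create_input_iter`.

import jax.numpy as jnp
import tensorflow_datasets as tfds

builder = tfds.builder("imagenet2012")
builder.download_and_prepare()
train_iter = create_input_iter(
    builder, batch_size=128, image_size=224, dtype=jnp.float32,
    train=True, cache=False)
batch = next(train_iter)  # one batch, already prefetched to the devices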
Example #2
import os

import jax
import jax.numpy as jnp
import optax
import tensorflow_datasets as tfds
from absl import flags

import elegy
import input_pipeline  # local module, as in Example #1

FLAGS = flags.FLAGS  # the flag definitions are sketched after this example


def main(argv):
    assert (
        len(argv) == 1
    ), "Please specify arguments via flags. Use --help for instructions"

    assert (
        getattr(elegy.nets.resnet, FLAGS.model, None) is not None
    ), f"{FLAGS.model} is not defined in elegy.nets.resnet"

    assert not os.path.exists(
        FLAGS.output_dir
    ), "Output directory already exists. Delete manually or specify a new one."
    os.makedirs(FLAGS.output_dir)

    # dataset
    dataset_builder = tfds.builder(FLAGS.dataset)
    ds_train = input_pipeline.create_split(
        dataset_builder,
        batch_size=FLAGS.batch_size,
        image_size=FLAGS.image_size,
        dtype=FLAGS.dtype,
        train=True,
        cache=FLAGS.cache,
    )
    ds_valid = input_pipeline.create_split(
        dataset_builder,
        batch_size=FLAGS.batch_size,
        image_size=FLAGS.image_size,
        dtype=FLAGS.dtype,
        train=False,
        cache=FLAGS.cache,
    )
    N_BATCHES_TRAIN = (
        dataset_builder.info.splits["train"].num_examples // FLAGS.batch_size
    )
    N_BATCHES_VALID = (
        dataset_builder.info.splits["validation"].num_examples // FLAGS.batch_size
    )

    # generator that converts tfds dataset batches to jax arrays
    def tfds2jax_generator(tf_ds):
        for batch in tf_ds:
            yield (
                jnp.asarray(batch["image"], dtype=FLAGS.dtype),
                jax.device_put(jnp.asarray(batch["label"])),
            )

    # model and optimizer definition
    def build_optimizer(lr,
                        momentum,
                        steps_per_epoch,
                        n_epochs,
                        nesterov,
                        warmup_epochs=5):
        cosine_schedule = optax.cosine_decay_schedule(
            1.0, decay_steps=n_epochs * steps_per_epoch, alpha=1e-10
        )
        warmup_schedule = optax.polynomial_schedule(
            init_value=0.0,
            end_value=1.0,
            power=1,
            transition_steps=warmup_epochs * steps_per_epoch,
        )

        # Taking the pointwise minimum of the two curves yields a linear
        # warmup followed by cosine decay: the warmup ramp lies below the
        # cosine curve early in training and above it afterwards.
        def schedule(step):
            return jnp.minimum(cosine_schedule(step), warmup_schedule(step))

        # `scale_by_schedule` multiplies the SGD update by the schedule value,
        # so `lr` acts as the peak learning rate.
        optimizer = optax.sgd(lr, momentum, nesterov=nesterov)
        optimizer = optax.chain(optimizer, optax.scale_by_schedule(schedule))
        return optimizer

    module = getattr(elegy.nets.resnet, FLAGS.model)(dtype=FLAGS.dtype)
    # Loss scaling: the losses are multiplied by `loss_scale` while the
    # learning rate is divided by it (see `build_optimizer` above), keeping
    # the effective update unchanged while protecting low-precision gradients
    # from underflow.
    model = elegy.Model(
        module,
        loss=[
            elegy.losses.SparseCategoricalCrossentropy(
                from_logits=True, weight=FLAGS.loss_scale),
            elegy.regularizers.GlobalL2(FLAGS.L2_reg / 2 * FLAGS.loss_scale),
        ],
        metrics=elegy.metrics.SparseCategoricalAccuracy(),
        optimizer=build_optimizer(
            FLAGS.base_lr / FLAGS.loss_scale,
            FLAGS.momentum,
            N_BATCHES_TRAIN,
            FLAGS.epochs,
            FLAGS.nesterov,
        ),
    )

    # training: `fit` consumes the generators above; because generators have
    # no length, `steps_per_epoch` and `validation_steps` delimit each epoch.
    model.fit(
        x=tfds2jax_generator(ds_train),
        validation_data=tfds2jax_generator(ds_valid),
        epochs=FLAGS.epochs,
        verbose=2,
        steps_per_epoch=N_BATCHES_TRAIN,
        validation_steps=N_BATCHES_VALID,
        callbacks=[
            elegy.callbacks.ModelCheckpoint(FLAGS.output_dir,
                                            save_best_only=True),
            elegy.callbacks.TerminateOnNaN(),
            elegy.callbacks.TensorBoard(logdir=FLAGS.output_dir),
        ],
    )
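
Example #2 reads its configuration from absl flags defined elsewhere in the script. Below is a sketch of definitions consistent with how the flags are used above; the flag names come from the code, but every default and help string here is an illustrative assumption.

from absl import app, flags

flags.DEFINE_string("model", "ResNet50", "Class name in elegy.nets.resnet.")
flags.DEFINE_string("dataset", "imagenet2012", "tfds dataset name.")
flags.DEFINE_string("output_dir", "runs/resnet", "Checkpoint and log directory.")
flags.DEFINE_integer("batch_size", 128, "Global batch size.")
flags.DEFINE_integer("image_size", 224, "Image side length after preprocessing.")
flags.DEFINE_string("dtype", "float16", "dtype for inputs and parameters.")
flags.DEFINE_bool("cache", False, "Cache the decoded dataset in memory.")
flags.DEFINE_integer("epochs", 90, "Number of training epochs.")
flags.DEFINE_float("base_lr", 0.1, "Peak learning rate before loss scaling.")
flags.DEFINE_float("momentum", 0.9, "SGD momentum.")
flags.DEFINE_bool("nesterov", True, "Use Nesterov momentum.")
flags.DEFINE_float("L2_reg", 1e-4, "L2 regularization coefficient.")
flags.DEFINE_float("loss_scale", 1.0, "Loss scale (> 1 for float16 training).")

if __name__ == "__main__":
    app.run(main)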