def train_and_evaluate(config, workdir):
  """Execute model training and evaluation loop.

  Trains a ResNet-18 on CIFAR-10 for ``config.num_epochs`` epochs, logging
  train/test loss and accuracy to stdout, absl logging, and tensorboard.

  Args:
      config: Hyperparameter configuration for training and evaluation.
      workdir: Directory where the tensorboard summaries are written to.
  """
  ## Fixed seed so runs are reproducible.
  rng = jax.random.PRNGKey(0)
  ## Get data.
  train_ds, test_ds = get_datasets("cifar10")

  ## Initialize the model; one dummy batch infers the layer dimensions.
  model = models.ResNet18(num_classes=10)
  init_params = model.init(
      rng, jnp.ones((1, 32, 32, 3))
  )  # TODO: derive the (1, 32, 32, 3) input shape from the dataset.
  params = init_params

  # NOTE(review): `loss_fun` is passed both positionally and as `losses=`;
  # confirm against get_solver's signature which argument it actually uses.
  solver, solver_param_name = get_solver(
      FLAGS, config, loss_fun, losses=loss_fun)
  params, state = solver.init(params)

  ## Path to dump results.
  dumpath = create_dumpfile(config, solver_param_name, workdir, "cifar10")

  summary_writer = tensorboard.SummaryWriter(dumpath)
  summary_writer.hparams(dict(config))

  for epoch in range(1, config.num_epochs + 1):
    # Hand train_epoch its own subkey: the parent key is split again next
    # epoch, and reusing a key that gets split later correlates the random
    # streams across epochs (JAX keys must not be reused).
    rng, epoch_rng = jax.random.split(rng)
    params, state = train_epoch(
        config, solver, params, state, train_ds, epoch_rng
    )
    test_loss, test_accuracy = eval_model(params, test_ds)
    train_loss, train_accuracy = eval_model(params, train_ds)
    # print() does not interpolate %-style lazy args the way logging does,
    # so format explicitly.
    print("eval epoch: %d, loss: %.4f, accuracy: %.2f" %
          (epoch, test_loss, test_accuracy * 100))
    print("train epoch: %d, train_loss: %.4f, train_accuracy: %.2f" %
          (epoch, train_loss, train_accuracy * 100))
    logging.info("eval epoch: %d, loss: %.4f, accuracy: %.2f", epoch, test_loss,
                 test_accuracy * 100)
    summary_writer.scalar("train_loss", train_loss, epoch)
    # Bug fix: previously logged the undefined name `loss` (NameError).
    summary_writer.scalar("test_loss", test_loss, epoch)
    summary_writer.scalar("train_accuracy", train_accuracy, epoch)
    summary_writer.scalar("test_accuracy", test_accuracy, epoch)

  summary_writer.flush()
def loss_fun(params, data):
  """Compute the training loss for one batch.

  Args:
      params: Model parameters (including batch statistics) for ResNet-18.
      data: Batch dict with "image" and "label" entries.

  Returns:
      A ``(loss, (new_batch_stats, preds))`` pair: the scalar loss plus the
      auxiliary updated batch statistics and raw model outputs.
  """
  network = models.ResNet18(num_classes=10)
  # Running the forward pass with mutable batch_stats yields the updated
  # statistics alongside the predictions.
  preds, new_batch_stats = network.apply(
      params, data["image"], mutable=["batch_stats"])
  batch_metrics = compute_metrics(preds=preds, labels=data["label"])
  return batch_metrics["loss"], (new_batch_stats, preds)
def eval_step(params, data):
  """Evaluate the model on one batch and return its metrics.

  Args:
      params: Model parameters for ResNet-18.
      data: Batch dict with "image" and "label" entries.

  Returns:
      The metrics dict produced by ``compute_metrics`` for this batch.
  """
  # Inference mode: train=False and mutable=False so no state is updated.
  network = models.ResNet18(num_classes=10)
  logits = network.apply(
      params, data["image"], train=False, mutable=False)
  return compute_metrics(preds=logits, labels=data["label"])