def train_and_evaluate(config, workdir):
  """Execute model training and evaluation loop.

  Args:
      config: Hyperparameter configuration for training and evaluation.
      workdir: Directory where TensorBoard summaries are written.
  """
  # Fixed PRNG seed for reproducibility.
  rng = jax.random.PRNGKey(0)
  # Load the CIFAR-10 train and test splits.
  train_ds, test_ds = get_datasets("cifar10")

  # Initialize the model, inferring layer dimensions from a dummy batch with
  # the CIFAR-10 image shape (1, 32, 32, 3).
  model = models.ResNet18(num_classes=10)
  params = model.init(rng, jnp.ones((1, 32, 32, 3)))

  # Build the solver specified by the command-line flags and config.
  solver, solver_param_name = get_solver(
      FLAGS, config, loss_fun, losses=loss_fun)
  params, state = solver.init(params)

  # Full path where results are dumped.
  dumpath = create_dumpfile(config, solver_param_name, workdir, "cifar10")

  summary_writer = tensorboard.SummaryWriter(dumpath)
  summary_writer.hparams(dict(config))

  for epoch in range(1, config.num_epochs + 1):
    # Split the key so each epoch's shuffling uses a fresh, independent key.
    rng, input_rng = jax.random.split(rng)
    params, state = train_epoch(
        config, solver, params, state, train_ds, input_rng
    )
    test_loss, test_accuracy = eval_model(params, test_ds)
    train_loss, train_accuracy = eval_model(params, train_ds)
    logging.info("eval epoch: %d, loss: %.4f, accuracy: %.2f", epoch, test_loss,
                 test_accuracy * 100)
    logging.info("train epoch: %d, train_loss: %.4f, train_accuracy: %.2f",
                 epoch, train_loss, train_accuracy * 100)
    summary_writer.scalar("train_loss", train_loss, epoch)
    summary_writer.scalar("test_loss", test_loss, epoch)
    summary_writer.scalar("train_accuracy", train_accuracy, epoch)
    summary_writer.scalar("test_accuracy", test_accuracy, epoch)

  summary_writer.flush()
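

# A hedged sketch of the hyperparameter configuration the loop above expects.
# Only `num_epochs` is read directly here; the remaining fields (learning rate,
# batch size) are assumptions about what `get_solver` and `train_epoch` might
# consume, not the repository's actual config schema.
import ml_collections


def get_config():
  """Returns a minimal example config for the CIFAR-10 loop above."""
  config = ml_collections.ConfigDict()
  config.num_epochs = 10       # Read by the epoch loop above.
  config.learning_rate = 0.1   # Assumed consumer: get_solver.
  config.batch_size = 128      # Assumed consumer: train_epoch.
  return config
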
def train_and_evaluate(config, workdir):
    """Execute model training and evaluation loop.

    Args:
        config: Hyperparameter configuration for training and evaluation.
        workdir: Directory where TensorBoard summaries are written.
    """
    # Load the MNIST train and test splits.
    train_ds, test_ds = get_datasets("mnist")
    # Get solver
    solver, solver_param_name = get_solver(FLAGS, config, loss_fun, losses)

    rng = jax.random.PRNGKey(0)
    rng, init_rng = jax.random.split(rng)

    # Initialize the CNN, inferring layer shapes from a dummy MNIST batch.
    init_params = CNN().init(init_rng, jnp.ones([1, 28, 28, 1]))["params"]
    params, state = solver.init(init_params)

    # Full path where results are dumped.
    dumpath = create_dumpfile(config, solver_param_name, workdir, "mnist")

    summary_writer = tensorboard.SummaryWriter(dumpath)
    summary_writer.hparams(dict(config))

    # Run solver.
    for epoch in range(1, config.num_epochs + 1):
        rng, input_rng = jax.random.split(rng)

        params, state, train_metrics = train_epoch(config, solver, params,
                                                   state, train_ds, epoch,
                                                   input_rng)
        test_loss, test_accuracy = eval_model(params, test_ds)

        print("eval epoch: %d, loss: %.4f, accuracy: %.2f", epoch, test_loss,
              test_accuracy * 100)
        logging.info("eval epoch: %d, loss: %.4f, accuracy: %.2f", epoch,
                     test_loss, test_accuracy * 100)

        summary_writer.scalar("train_loss", train_metrics["loss"], epoch)
        summary_writer.scalar("train_accuracy", train_metrics["accuracy"],
                              epoch)
        summary_writer.scalar("eval_loss", test_loss, epoch)
        summary_writer.scalar("eval_accuracy", test_accuracy, epoch)

    summary_writer.flush()
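

# A hedged sketch of how `train_and_evaluate` might be wired into an absl
# entry point. The flag names (--workdir, --config) and the use of
# ml_collections.config_flags are assumptions, not the repository's actual CLI.
from absl import app
from absl import flags
from ml_collections import config_flags

FLAGS = flags.FLAGS
flags.DEFINE_string("workdir", None, "Directory for TensorBoard summaries.")
config_flags.DEFINE_config_file("config", None,
                                "Path to the hyperparameter configuration file.")


def main(argv):
    del argv  # Unused.
    train_and_evaluate(FLAGS.config, FLAGS.workdir)


if __name__ == "__main__":
    app.run(main)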