def train_and_evaluate(config, workdir):
    """Execute model training and evaluation loop.

    Builds a ``ResNet18`` with 10 output classes, loads the ``"cifar10"``
    datasets, and runs ``config.num_epochs`` training epochs. After every
    epoch the current parameters are evaluated on both the test and the
    train split; the resulting loss/accuracy values are printed, the test
    metrics are logged, and all four values are written as TensorBoard
    scalars.

    Args:
      config: Hyperparameter configuration for training and evaluation.
        ``config.num_epochs`` is read here; the object is also forwarded to
        ``get_solver``, ``create_dumpfile`` and ``train_epoch`` and stored
        as the run's hparams via ``dict(config)``.
      workdir: Base directory handed to ``create_dumpfile``, which returns
        the path the TensorBoard summaries are written to.
    """
    # Root PRNG key (fixed seed). Every consumer below gets its own subkey so
    # that no key is both used directly and split again later.
    rng = jax.random.PRNGKey(0)
    rng, init_rng = jax.random.split(rng)

    ## Get data
    train_ds, test_ds = get_datasets("cifar10")

    ## Initializing model and inferring dimensions of layers from one example
    ## batch.
    model = models.ResNet18(num_classes=10)
    # TODO: figure this input shape out automatically?
    params = model.init(init_rng, jnp.ones((1, 32, 32, 3)))

    solver, solver_param_name = get_solver(
        FLAGS, config, loss_fun, losses=loss_fun
    )  # losses is not defined yet!
    params, state = solver.init(params)

    ## Path to dump results
    dumpath = create_dumpfile(config, solver_param_name, workdir, "cifar10")
    summary_writer = tensorboard.SummaryWriter(dumpath)
    summary_writer.hparams(dict(config))

    for epoch in range(1, config.num_epochs + 1):
        # Fresh subkey per epoch; `rng` itself is only ever split.
        rng, epoch_rng = jax.random.split(rng)
        params, state = train_epoch(
            config, solver, params, state, train_ds, epoch_rng
        )

        test_loss, test_accuracy = eval_model(params, test_ds)
        train_loss, train_accuracy = eval_model(params, train_ds)

        # print() does not interpolate printf-style arguments, so the
        # message has to be formatted explicitly with `%`.
        print(
            "eval epoch: %d, loss: %.4f, accuracy: %.2f"
            % (epoch, test_loss, test_accuracy * 100)
        )
        print(
            "train epoch: %d, train_loss: %.4f, train_accuracy: %.2f"
            % (epoch, train_loss, train_accuracy * 100)
        )
        logging.info(
            "eval epoch: %d, loss: %.4f, accuracy: %.2f",
            epoch,
            test_loss,
            test_accuracy * 100,
        )

        summary_writer.scalar("train_loss", train_loss, epoch)
        summary_writer.scalar("test_loss", test_loss, epoch)
        summary_writer.scalar("train_accuracy", train_accuracy, epoch)
        summary_writer.scalar("test_accuracy", test_accuracy, epoch)
        # Flush every epoch so partial runs are already visible on disk.
        summary_writer.flush()

    # Final flush; also covers the case where the loop body never ran and
    # only the hparams were recorded.
    summary_writer.flush()
def loss_fun(params, data):
    """Evaluate the loss of a ``ResNet18`` (10 classes) on one batch.

    Args:
      params: Variables passed as the first argument to the model's
        ``apply`` call.
      data: Mapping with an ``"image"`` entry (model input) and a ``"label"``
        entry (targets handed to ``compute_metrics``).

    Returns:
      A pair ``(loss, (new_batch_stats, preds))``: ``loss`` is the ``"loss"``
      entry of the metrics returned by ``compute_metrics``; the auxiliary
      tuple carries the second and first results of ``apply`` called with
      ``mutable=["batch_stats"]`` (presumably the updated batch statistics
      and the model outputs -- confirm against the model definition).
    """
    network = models.ResNet18(num_classes=10)
    forward_result = network.apply(
        params,
        data["image"],
        mutable=["batch_stats"],
    )
    preds, new_batch_stats = forward_result

    batch_metrics = compute_metrics(preds=preds, labels=data["label"])
    aux = (new_batch_stats, preds)
    return batch_metrics["loss"], aux
def eval_step(params, data):
    """Compute evaluation metrics for one batch.

    Runs a ``ResNet18`` (10 classes) on ``data["image"]`` with
    ``train=False`` and ``mutable=False``, then scores the outputs against
    ``data["label"]``.

    Args:
      params: Variables passed as the first argument to the model's
        ``apply`` call.
      data: Mapping with ``"image"`` and ``"label"`` entries.

    Returns:
      Whatever ``compute_metrics`` returns for this batch.
    """
    network = models.ResNet18(num_classes=10)
    predictions = network.apply(
        params,
        data["image"],
        train=False,
        mutable=False,
    )
    return compute_metrics(preds=predictions, labels=data["label"])