def train_and_evaluate(config, workdir):
    """Execute the model training and evaluation loop on CIFAR-10.

    Trains a ResNet-18 with the solver returned by `get_solver`, evaluating on
    both the train and test splits after every epoch, and writes metrics to
    tensorboard summaries.

    Args:
      config: Hyperparameter configuration for training and evaluation.
      workdir: Directory where the tensorboard summaries are written to.
    """
    # Fixed seed so runs are reproducible.
    rng = jax.random.PRNGKey(0)

    # Get data.
    train_ds, test_ds = get_datasets("cifar10")

    # Initialize the model, inferring layer dimensions from one example batch.
    # NOTE(review): the (1, 32, 32, 3) dummy shape is hard-coded to CIFAR-10;
    # it could be derived from `train_ds` to avoid the duplication.
    model = models.ResNet18(num_classes=10)
    init_params = model.init(rng, jnp.ones((1, 32, 32, 3)))
    params = init_params

    # NOTE(review): the original passed `losses=loss_fun` with a comment that
    # `losses` was not defined at this point; call kept as-is — confirm against
    # the MNIST variant, which passes a separate `losses` object.
    solver, solver_param_name = get_solver(
        FLAGS, config, loss_fun, losses=loss_fun)
    params, state = solver.init(params)

    # Full path to dump results.
    dumpath = create_dumpfile(config, solver_param_name, workdir, "cifar10")
    summary_writer = tensorboard.SummaryWriter(dumpath)
    summary_writer.hparams(dict(config))

    # Run solver.
    for epoch in range(1, config.num_epochs + 1):
        # Advance the RNG stream once per epoch; the discarded subkey keeps
        # the per-epoch keys independent.
        rng, _ = jax.random.split(rng)
        params, state = train_epoch(
            config, solver, params, state, train_ds, rng)

        test_loss, test_accuracy = eval_model(params, test_ds)
        train_loss, train_accuracy = eval_model(params, train_ds)

        # Bug fix: `print` does not %-interpolate trailing arguments the way
        # `logging` does, so format explicitly.
        print("eval epoch: %d, loss: %.4f, accuracy: %.2f"
              % (epoch, test_loss, test_accuracy * 100))
        print("train epoch: %d, train_loss: %.4f, train_accuracy: %.2f"
              % (epoch, train_loss, train_accuracy * 100))
        logging.info("eval epoch: %d, loss: %.4f, accuracy: %.2f",
                     epoch, test_loss, test_accuracy * 100)

        summary_writer.scalar("train_loss", train_loss, epoch)
        # Bug fix: the original referenced the undefined name `loss` here,
        # which would raise NameError; `test_loss` is clearly what was meant.
        summary_writer.scalar("test_loss", test_loss, epoch)
        summary_writer.scalar("train_accuracy", train_accuracy, epoch)
        summary_writer.scalar("test_accuracy", test_accuracy, epoch)

    summary_writer.flush()
def train_and_evaluate(config, workdir):
    """Execute the model training and evaluation loop on MNIST.

    Trains a CNN with the solver returned by `get_solver`, evaluating on the
    test split after every epoch, and writes metrics to tensorboard summaries.

    NOTE(review): if this lives in the same module as the CIFAR-10
    `train_and_evaluate` above, this definition shadows it — confirm the two
    are intended to be in separate files.

    Args:
      config: Hyperparameter configuration for training and evaluation.
      workdir: Directory where the tensorboard summaries are written to.
    """
    train_ds, test_ds = get_datasets("mnist")

    # Get solver.
    solver, solver_param_name = get_solver(FLAGS, config, loss_fun, losses)

    # Fixed seed so runs are reproducible; a dedicated subkey initializes
    # the model so later per-epoch keys are independent of it.
    rng = jax.random.PRNGKey(0)
    rng, init_rng = jax.random.split(rng)
    init_params = CNN().init(init_rng, jnp.ones([1, 28, 28, 1]))["params"]
    params, state = solver.init(init_params)

    # Full path to dump results.
    dumpath = create_dumpfile(config, solver_param_name, workdir, "mnist")
    summary_writer = tensorboard.SummaryWriter(dumpath)
    summary_writer.hparams(dict(config))

    # Run solver.
    for epoch in range(1, config.num_epochs + 1):
        rng, input_rng = jax.random.split(rng)
        params, state, train_metrics = train_epoch(
            config, solver, params, state, train_ds, epoch, input_rng)

        test_loss, test_accuracy = eval_model(params, test_ds)

        # Bug fix: `print` does not %-interpolate trailing arguments the way
        # `logging` does, so format explicitly.
        print("eval epoch: %d, loss: %.4f, accuracy: %.2f"
              % (epoch, test_loss, test_accuracy * 100))
        logging.info("eval epoch: %d, loss: %.4f, accuracy: %.2f",
                     epoch, test_loss, test_accuracy * 100)

        summary_writer.scalar("train_loss", train_metrics["loss"], epoch)
        summary_writer.scalar("train_accuracy", train_metrics["accuracy"], epoch)
        summary_writer.scalar("eval_loss", test_loss, epoch)
        summary_writer.scalar("eval_accuracy", test_accuracy, epoch)

    summary_writer.flush()