def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  train_split = dataset.Split.from_string(FLAGS.train_split)
  eval_split = dataset.Split.from_string(FLAGS.eval_split)

  # The total batch size is the batch size across all hosts and devices. In a
  # multi-host training setup each host will only see a batch size of
  # `total_train_batch_size / jax.host_count()`.
  total_train_batch_size = FLAGS.train_device_batch_size * jax.device_count()
  # Train for 90 epochs over the training split.
  num_train_steps = (train_split.num_examples * 90) // total_train_batch_size

  local_device_count = jax.local_device_count()
  train_dataset = dataset.load(
      train_split,
      batch_dims=[local_device_count, FLAGS.train_device_batch_size])

  # For initialization we need the same random key on each device.
  rng = jax.random.PRNGKey(FLAGS.train_init_random_seed)
  rng = jnp.broadcast_to(rng, (local_device_count,) + rng.shape)

  # Initialization requires an example input.
  batch = next(train_dataset)
  params, state, opt_state = jax.pmap(make_initial_state)(rng, batch)

  eval_every = FLAGS.train_eval_every
  log_every = FLAGS.train_log_every

  with time_activity('train'):
    for step_num in range(num_train_steps):
      # Take a single training step.
      params, state, opt_state, train_scalars = train_step(
          params, state, opt_state, next(train_dataset))

      # By default we do not evaluate during training, but you can configure
      # this with a flag.
      if eval_every > 0 and step_num and step_num % eval_every == 0:
        with time_activity('eval during train'):
          eval_scalars = evaluate(eval_split, params, state)
        logging.info(f'[Eval {step_num}/{num_train_steps}] {eval_scalars}')

      # Log progress at fixed intervals.
      if step_num and step_num % log_every == 0:
        train_scalars = jax.tree_map(lambda v: np.mean(v).item(),
                                     jax.device_get(train_scalars))
        logging.info(f'[Train {step_num}/{num_train_steps}] {train_scalars}')

  # Once training has finished we run eval one more time to get final results.
  with time_activity('final eval'):
    eval_scalars = evaluate(eval_split, params, state)
  logging.info(f'[Eval FINAL]: {eval_scalars}')
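
# For reference, a minimal sketch of the `time_activity` helper used above,
# assuming it is a context manager that logs the wall-clock duration of the
# wrapped block (requires `contextlib` and `timeit` imports at the top of the
# file); the actual helper may record timings differently.
@contextlib.contextmanager
def time_activity(activity_name: str):
  logging.info('[Timing] %s start.', activity_name)
  start = timeit.default_timer()
  yield
  duration = timeit.default_timer() - start
  logging.info('[Timing] %s finished (Took %.2fs).', activity_name, duration)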
def evaluate(
    split: dataset.Split,
    params: hk.Params,
    state: hk.State,
) -> Scalars:
  """Evaluates the model at the given params/state."""
  # Params/state are sharded per-device during training. We just need the copy
  # from the first device (since we do not pmap evaluation at the moment).
  params, state = jax.tree_map(lambda x: x[0], (params, state))
  test_dataset = dataset.load(split, batch_dims=[FLAGS.eval_batch_size])
  correct = total = 0
  for batch in test_dataset:
    # Accumulate the number of correct predictions over the full split. Note
    # we evaluate `batch` itself rather than pulling another element from the
    # iterator, which would skip every other batch.
    correct += eval_batch(params, state, batch)
    total += batch['images'].shape[0]
  return {'top_1_acc': correct.item() / total}
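
# For reference, a minimal sketch of the `eval_batch` helper called above,
# assuming a Haiku-transformed `forward` function defined elsewhere in the
# file and integer class labels in `batch['labels']`; the real implementation
# may differ in how the forward pass is applied.
@jax.jit
def eval_batch(
    params: hk.Params,
    state: hk.State,
    batch: dataset.Batch,
) -> jnp.ndarray:
  """Returns the number of correct top-1 predictions in `batch`."""
  logits, _ = forward.apply(params, state, None, batch, is_training=False)
  predicted_label = jnp.argmax(logits, axis=-1)
  correct = jnp.sum(jnp.equal(predicted_label, batch['labels']))
  return correct.astype(jnp.float32)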