# NOTE: The imports and `Scalars` alias below are added so this file is
# self-contained; `dataset` is the example's local input pipeline module, and
# the flag definitions (e.g. `eval_batch_size`) are omitted here.
from typing import Mapping

from absl import app
from absl import flags
from absl import logging
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np

import dataset  # Local module providing Split, load(), double_buffer(), etc.

FLAGS = flags.FLAGS
Scalars = Mapping[str, jnp.ndarray]


def evaluate(
    split: dataset.Split,
    params: hk.Params,
    state: hk.State,
) -> Scalars:
  """Evaluates the model at the given params/state."""
  if split.num_examples % FLAGS.eval_batch_size:
    raise ValueError(f'{split} num examples {split.num_examples} must be a '
                     f'multiple of eval batch size {FLAGS.eval_batch_size}')

  # Params/state are replicated per-device during training. We just need the
  # copy from the first device (since we do not pmap evaluation at the
  # moment).
  params, state = jax.tree_map(lambda x: x[0], (params, state))
  test_dataset = dataset.load(split,
                              is_training=False,
                              batch_dims=[FLAGS.eval_batch_size],
                              transpose=FLAGS.dataset_transpose)
  correct = jnp.array(0)
  total = 0
  for batch in test_dataset:
    correct += eval_batch(params, state, batch)
    total += batch['labels'].shape[0]
  assert total == split.num_examples, total
  return {'top_1_acc': correct.item() / total}
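# `eval_batch` is referenced above but not defined in this file. Below is a
# minimal sketch, assuming a `forward = hk.transform_with_state(...)` wrapping
# a ResNet-50 classifier; the `_forward` function, the number of classes, and
# the batch keys ('images'/'labels') are illustrative assumptions.
def _forward(batch, is_training):
  net = hk.nets.ResNet50(1000)
  return net(batch['images'], is_training=is_training)

forward = hk.transform_with_state(_forward)


@jax.jit
def eval_batch(
    params: hk.Params,
    state: hk.State,
    batch,
) -> jnp.ndarray:
  """Returns the number of correctly classified examples in `batch`."""
  logits, _ = forward.apply(params, state, None, batch, is_training=False)
  predicted_label = jnp.argmax(logits, axis=-1)
  # Count exact matches against the integer labels.
  return jnp.sum(jnp.equal(predicted_label, batch['labels']))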
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  FLAGS.alsologtostderr = True

  train_split = dataset.Split.from_string(FLAGS.train_split)
  eval_split = dataset.Split.from_string(FLAGS.eval_split)

  # The total batch size is the batch size across all hosts and devices. In a
  # multi-host training setup each host will only see a batch size of
  # `total_train_batch_size / jax.host_count()`.
  total_train_batch_size = FLAGS.train_device_batch_size * jax.device_count()
  num_train_steps = ((train_split.num_examples * FLAGS.train_epochs) //
                     total_train_batch_size)

  local_device_count = jax.local_device_count()
  train_dataset = dataset.load(
      train_split,
      is_training=True,
      batch_dims=[local_device_count, FLAGS.train_device_batch_size],
      bfloat16=FLAGS.train_bfloat16,
      transpose=FLAGS.dataset_transpose)

  # For initialization we need the same random key on each device.
  rng = jax.random.PRNGKey(FLAGS.train_init_random_seed)
  rng = jnp.broadcast_to(rng, (local_device_count,) + rng.shape)

  # Initialization requires an example input.
  batch = next(train_dataset)
  params, state, opt_state = jax.pmap(make_initial_state)(rng, batch)

  eval_every = FLAGS.train_eval_every
  log_every = FLAGS.train_log_every

  with time_activity('train'):
    for step_num in range(num_train_steps):
      # Take a single training step.
      with jax.profiler.StepTraceAnnotation('train', step_num=step_num):
        params, state, opt_state, train_scalars = (
            train_step(params, state, opt_state, next(train_dataset)))

      # By default we do not evaluate during training, but you can configure
      # this with a flag.
      if eval_every > 0 and step_num and step_num % eval_every == 0:
        with time_activity('eval during train'):
          eval_scalars = evaluate(eval_split, params, state)
        logging.info('[Eval %s/%s] %s', step_num, num_train_steps,
                     eval_scalars)

      # Log progress at fixed intervals.
      if step_num and step_num % log_every == 0:
        train_scalars = jax.tree_map(lambda v: np.mean(v).item(),
                                     jax.device_get(train_scalars))
        logging.info('[Train %s/%s] %s', step_num, num_train_steps,
                     train_scalars)

  # Once training has finished we run eval one more time to get final results.
  with time_activity('final eval'):
    eval_scalars = evaluate(eval_split, params, state)
  logging.info('[Eval FINAL]: %s', eval_scalars)
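# `make_initial_state`, `train_step`, and the loss they share are also not
# defined in this file. The sketch below shows one plausible shape for them,
# matching the first `main` above: it assumes the hypothetical `forward`
# transform, an optax SGD optimiser, and a softmax cross entropy loss; none
# of these choices come from this file.
import functools

import optax


def make_optimizer() -> optax.GradientTransformation:
  # Placeholder optimiser; the real learning-rate schedule is an assumption.
  return optax.sgd(0.1, momentum=0.9)


def loss_fn(params, state, batch):
  """Computes softmax cross entropy; returns (loss, new_state) for has_aux."""
  logits, new_state = forward.apply(params, state, None, batch,
                                    is_training=True)
  labels = jax.nn.one_hot(batch['labels'], logits.shape[-1])
  loss = jnp.mean(optax.softmax_cross_entropy(logits, labels))
  return loss, new_state


def make_initial_state(rng, batch):
  """Creates per-device params/state/opt_state (called under jax.pmap)."""
  params, state = forward.init(rng, batch, is_training=True)
  opt_state = make_optimizer().init(params)
  return params, state, opt_state


@functools.partial(jax.pmap, axis_name='i', donate_argnums=(0, 1, 2))
def train_step(params, state, opt_state, batch):
  """Applies one SGD update, keeping replicas in sync."""
  (loss, new_state), grads = jax.value_and_grad(loss_fn, has_aux=True)(
      params, state, batch)
  # Average gradients over the pmapped axis so every device applies the same
  # update.
  grads = jax.lax.pmean(grads, axis_name='i')
  updates, new_opt_state = make_optimizer().update(grads, opt_state, params)
  new_params = optax.apply_updates(params, updates)
  return new_params, new_state, new_opt_state, {'train_loss': loss}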
def evaluate(
    split: dataset.Split,
    params: hk.Params,
    state: hk.State,
) -> Scalars:
  """Evaluates the model at the given params/state."""
  # Params/state are replicated per-device during training. We just need the
  # copy from the first device (since we do not pmap evaluation at the
  # moment).
  params, state = jax.tree_map(lambda x: x[0], (params, state))
  test_dataset = dataset.load(split, batch_dims=[FLAGS.eval_batch_size])
  correct = jnp.array(0)
  total = 0
  for batch in test_dataset:
    # NOTE: Evaluate the batch yielded by the loop; calling `next` here would
    # skip every other batch and mismatch the `total` count.
    correct += eval_batch(params, state, batch)
    total += batch['images'].shape[0]
  return {'top_1_acc': correct.item() / total}
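# The variant of `main` below packs params, network state, optimiser state
# (and, when training in f16, a loss scale) into a single `TrainState`, and
# builds it with a pmapped `initial_state`. Neither is defined in this file;
# the sketch assumes the `forward` transform and `make_optimizer` helper
# above, plus the `jmp` mixed precision library.
from typing import NamedTuple

import jmp


class TrainState(NamedTuple):
  params: hk.Params
  state: hk.State
  opt_state: optax.OptState
  loss_scale: jmp.LossScale


def initial_state(rng, batch) -> TrainState:
  """Computes the initial network state (called under jax.pmap)."""
  params, state = forward.init(rng, batch, is_training=True)
  opt_state = make_optimizer().init(params)
  # NoOp by default; use e.g. jmp.DynamicLossScale when computing in f16.
  loss_scale = jmp.NoOpLossScale()
  return TrainState(params, state, opt_state, loss_scale)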
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  FLAGS.alsologtostderr = True

  train_split = dataset.Split.from_string(FLAGS.train_split)
  eval_split = dataset.Split.from_string(FLAGS.eval_split)

  # The total batch size is the batch size across all hosts and devices. In a
  # multi-host training setup each host will only see a batch size of
  # `total_train_batch_size / jax.host_count()`.
  total_train_batch_size = FLAGS.train_device_batch_size * jax.device_count()
  num_train_steps = ((train_split.num_examples * FLAGS.train_epochs) //
                     total_train_batch_size)

  local_device_count = jax.local_device_count()
  train_dataset = dataset.load(
      train_split,
      is_training=True,
      batch_dims=[local_device_count, FLAGS.train_device_batch_size],
      dtype=get_policy().compute_dtype,
      transpose=FLAGS.dataset_transpose,
      zeros=FLAGS.dataset_zeros)

  # Assign mixed precision policies to modules. Note that when training in f16
  # we keep BatchNorm in full precision. When training with bf16 you can often
  # use bf16 for BatchNorm.
  mp_policy = get_policy()
  bn_policy = get_bn_policy().with_output_dtype(mp_policy.compute_dtype)
  # NOTE: The order in which we call `set_policy` doesn't matter; when a
  # method on a class is called the policy for that class will be applied, or
  # it will inherit the policy from its parent module.
  hk.mixed_precision.set_policy(hk.BatchNorm, bn_policy)
  hk.mixed_precision.set_policy(hk.nets.ResNet50, mp_policy)

  if jax.default_backend() == 'gpu':
    # TODO(tomhennigan): This could be removed if XLA:GPU's allocator changes.
    train_dataset = dataset.double_buffer(train_dataset)

  # For initialization we need the same random key on each device.
  rng = jax.random.PRNGKey(FLAGS.train_init_random_seed)
  rng = jnp.broadcast_to(rng, (local_device_count,) + rng.shape)

  # Initialization requires an example input.
  batch = next(train_dataset)
  train_state = jax.pmap(initial_state)(rng, batch)

  # Print a useful summary of the execution of our module.
  summary = hk.experimental.tabulate(train_step)(train_state, batch)
  for line in summary.split('\n'):
    logging.info(line)

  eval_every = FLAGS.train_eval_every
  log_every = FLAGS.train_log_every

  with time_activity('train'):
    for step_num in range(num_train_steps):
      # Take a single training step.
      with jax.profiler.StepTraceAnnotation('train', step_num=step_num):
        batch = next(train_dataset)
        train_state, train_scalars = train_step(train_state, batch)

      # By default we do not evaluate during training, but you can configure
      # this with a flag.
      if eval_every > 0 and step_num and step_num % eval_every == 0:
        with time_activity('eval during train'):
          eval_scalars = evaluate(eval_split,
                                  train_state.params, train_state.state)
        logging.info('[Eval %s/%s] %s', step_num, num_train_steps,
                     eval_scalars)

      # Log progress at fixed intervals.
      if step_num and step_num % log_every == 0:
        train_scalars = jax.tree_map(lambda v: np.mean(v).item(),
                                     jax.device_get(train_scalars))
        logging.info('[Train %s/%s] %s', step_num, num_train_steps,
                     train_scalars)

  # Once training has finished we run eval one more time to get final results.
  with time_activity('final eval'):
    eval_scalars = evaluate(eval_split, train_state.params, train_state.state)
  logging.info('[Eval FINAL]: %s', eval_scalars)
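# `get_policy`, `get_bn_policy`, and `time_activity` are used above but not
# defined in this file. One plausible implementation of the policy helpers
# parses `jmp` policy strings from flags; the flag names `mp_policy` and
# `mp_bn_policy` are assumptions.
def get_policy() -> jmp.Policy:
  # E.g. 'params=float32,compute=float16,output=float32'.
  return jmp.get_policy(FLAGS.mp_policy)


def get_bn_policy() -> jmp.Policy:
  # Keep BatchNorm in full precision when computing in f16.
  return jmp.get_policy(FLAGS.mp_bn_policy)


# A simple context manager that logs how long each phase took; a sketch, since
# the real helper is not shown.
import contextlib
import timeit


@contextlib.contextmanager
def time_activity(activity_name: str):
  logging.info('[Timing] %s start.', activity_name)
  start = timeit.default_timer()
  yield
  duration = timeit.default_timer() - start
  logging.info('[Timing] %s finished (Took %.2fs).', activity_name, duration)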