Example #1
from absl import app
from absl import flags
from absl import logging

import jax
import jax.numpy as jnp
import numpy as np

import dataset  # Local module providing Split and load().

# make_initial_state, train_step, evaluate and time_activity are defined
# elsewhere in the module this example is taken from.

FLAGS = flags.FLAGS


def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    train_split = dataset.Split.from_string(FLAGS.train_split)
    eval_split = dataset.Split.from_string(FLAGS.eval_split)

    # The total batch size is the batch size across all hosts and devices. In
    # a multi-host training setup each host will only see a batch size of
    # `total_train_batch_size / jax.process_count()`.
    total_train_batch_size = FLAGS.train_device_batch_size * jax.device_count()
    # Train for 90 epochs over the training split.
    num_train_steps = (train_split.num_examples * 90) // total_train_batch_size
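    # For concreteness (assumed numbers, not from the source): with
    # ImageNet's 1,281,167 training examples, 8 devices and
    # train_device_batch_size=128, the total batch size is 1024 and
    # num_train_steps = 1,281,167 * 90 // 1024 = 112,602.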

    local_device_count = jax.local_device_count()
    train_dataset = dataset.load(
        train_split,
        batch_dims=[local_device_count, FLAGS.train_device_batch_size])

    # For initialization we need the same random key on each device.
    rng = jax.random.PRNGKey(FLAGS.train_init_random_seed)
    rng = jnp.broadcast_to(rng, (local_device_count,) + rng.shape)
    # Initialization requires an example input.
    batch = next(train_dataset)
    params, state, opt_state = jax.pmap(make_initial_state)(rng, batch)
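    # Every leaf of params/state/opt_state now carries a leading device axis
    # of size `local_device_count` (see the sketch after this example).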

    eval_every = FLAGS.train_eval_every
    log_every = FLAGS.train_log_every

    with time_activity('train'):
        for step_num in range(num_train_steps):
            # Take a single training step.
            params, state, opt_state, train_scalars = train_step(
                params, state, opt_state, next(train_dataset))

            # By default we do not evaluate during training, but you can configure
            # this with a flag.
            if eval_every > 0 and step_num and step_num % eval_every == 0:
                with time_activity('eval during train'):
                    eval_scalars = evaluate(eval_split, params, state)
                logging.info(
                    f'[Eval {step_num}/{num_train_steps}] {eval_scalars}')

            # Log progress at fixed intervals.
            if step_num and step_num % log_every == 0:
                train_scalars = jax.tree_util.tree_map(
                    lambda v: np.mean(v).item(),
                    jax.device_get(train_scalars))
                logging.info(
                    f'[Train {step_num}/{num_train_steps}] {train_scalars}')

    # Once training has finished, we run eval one more time to get final
    # results.
    with time_activity('final eval'):
        eval_scalars = evaluate(eval_split, params, state)
    logging.info(f'[Eval FINAL]: {eval_scalars}')
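
The initialization idiom above (broadcast one key, pmap the init function) is easy to isolate. Below is a minimal sketch of the same pattern, with a hypothetical make_params standing in for make_initial_state; the shapes and sizes are made up for illustration.

import jax
import jax.numpy as jnp


def make_params(rng, batch):
    # Hypothetical stand-in for make_initial_state: one weight vector sized
    # to the batch's trailing feature dimension.
    return jax.random.normal(rng, (batch.shape[-1],))


n = jax.local_device_count()

# The same key on every device, so all replicas initialize identically.
rng = jax.random.PRNGKey(0)
rng = jnp.broadcast_to(rng, (n,) + rng.shape)

# One batch per device: [device, batch, features].
batch = jnp.ones((n, 8, 16))

# pmap maps over the leading (device) axis of both arguments; every leaf of
# the result keeps that axis.
params = jax.pmap(make_params)(rng, batch)
assert params.shape == (n, 16)

Broadcasting the key rather than splitting it is deliberate: data-parallel replicas must start from identical weights, and it is the per-step gradients, not the initializations, that differ and get averaged across devices.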
Example #2
import haiku as hk
import jax

# dataset, FLAGS, Scalars and eval_batch are defined elsewhere in the
# module this example is taken from.


def evaluate(
    split: dataset.Split,
    params: hk.Params,
    state: hk.State,
) -> Scalars:
    """Evaluates the model at the given params/state."""
    # Params/state are replicated per-device during training: every device
    # holds an identical copy, so we only need the one from the first device
    # (evaluation is not pmapped at the moment).
    params, state = jax.tree_util.tree_map(lambda x: x[0], (params, state))
    test_dataset = dataset.load(split, batch_dims=[FLAGS.eval_batch_size])
    correct = total = 0
    for batch in test_dataset:
        correct += eval_batch(params, state, batch)
        total += batch['images'].shape[0]
    return {'top_1_acc': correct.item() / total}
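
eval_batch is not shown in this example. One plausible shape for it, sketched under the assumption that the model is wrapped with hk.transform_with_state and that batches carry integer 'labels' (neither is shown in the source), is:

import haiku as hk
import jax
import jax.numpy as jnp


def forward(images, is_training):
    # Toy stand-in for the real network; the actual example would build
    # something like a ResNet here.
    return hk.Linear(10)(images)


model = hk.transform_with_state(forward)


@jax.jit
def eval_batch(params, state, batch):
    """Returns the number of top-1 correct predictions in one batch."""
    logits, _ = model.apply(params, state, None, batch['images'],
                            is_training=False)
    predictions = jnp.argmax(logits, axis=-1)
    return jnp.sum(predictions == batch['labels'])

Returning a raw count instead of a per-batch mean is what keeps the accumulation in evaluate correct when the final batch is smaller than eval_batch_size.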