Example #1
File: train.py Project: YAMWD/dm-haiku
def evaluate(
    split: dataset.Split,
    params: hk.Params,
    state: hk.State,
) -> Scalars:
    """Evaluates the model at the given params/state."""
    if split.num_examples % FLAGS.eval_batch_size:
        raise ValueError(
            f'Eval batch size {FLAGS.eval_batch_size} must evenly divide '
            f'{split} num examples {split.num_examples}')

    # Params/state are sharded per-device during training. We just need the copy
    # from the first device (since we do not pmap evaluation at the moment).
    params, state = jax.tree_map(lambda x: x[0], (params, state))
    test_dataset = dataset.load(split,
                                is_training=False,
                                batch_dims=[FLAGS.eval_batch_size],
                                transpose=FLAGS.dataset_transpose)
    correct = jnp.array(0)
    total = 0
    for batch in test_dataset:
        correct += eval_batch(params, state, batch)
        total += batch['labels'].shape[0]
    assert total == split.num_examples, total
    return {'top_1_acc': correct.item() / total}
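
Note: the `jax.tree_map(lambda x: x[0], ...)` call above strips the leading device axis that `jax.pmap` adds to params/state during training, keeping only the first device's copy. A minimal, self-contained sketch of that pattern (toy pytree; names are illustrative):

import jax
import jax.numpy as jnp

# A toy "replicated" pytree: each leaf carries a leading device axis, just as
# params/state do after jax.pmap-based initialization and training.
n_devices = jax.local_device_count()
replicated = {'w': jnp.ones((n_devices, 3, 3)), 'b': jnp.zeros((n_devices, 3))}

# Keep only the first device's copy, as evaluate() does above.
single = jax.tree_map(lambda x: x[0], replicated)
assert single['w'].shape == (3, 3)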
Example #2
File: train.py Project: ssghost/dm-haiku
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  FLAGS.alsologtostderr = True

  train_split = dataset.Split.from_string(FLAGS.train_split)
  eval_split = dataset.Split.from_string(FLAGS.eval_split)

  # The total batch size is the batch size across all hosts and devices. In a
  # multi-host training setup each host will only see a batch size of
  # `total_train_batch_size / jax.host_count()`.
  total_train_batch_size = FLAGS.train_device_batch_size * jax.device_count()
  num_train_steps = (
      (train_split.num_examples * FLAGS.train_epochs) // total_train_batch_size)

  local_device_count = jax.local_device_count()
  train_dataset = dataset.load(
      train_split,
      is_training=True,
      batch_dims=[local_device_count, FLAGS.train_device_batch_size],
      bfloat16=FLAGS.train_bfloat16,
      transpose=FLAGS.dataset_transpose)

  # For initialization we need the same random key on each device.
  rng = jax.random.PRNGKey(FLAGS.train_init_random_seed)
  rng = jnp.broadcast_to(rng, (local_device_count,) + rng.shape)
  # Initialization requires an example input.
  batch = next(train_dataset)
  params, state, opt_state = jax.pmap(make_initial_state)(rng, batch)

  eval_every = FLAGS.train_eval_every
  log_every = FLAGS.train_log_every

  with time_activity('train'):
    for step_num in range(num_train_steps):
      # Take a single training step.
      with jax.profiler.StepTraceContext('train', step_num=step_num):
        params, state, opt_state, train_scalars = (
            train_step(params, state, opt_state, next(train_dataset)))

      # By default we do not evaluate during training, but you can configure
      # this with a flag.
      if eval_every > 0 and step_num and step_num % eval_every == 0:
        with time_activity('eval during train'):
          eval_scalars = evaluate(eval_split, params, state)
        logging.info('[Eval %s/%s] %s', step_num, num_train_steps, eval_scalars)

      # Log progress at fixed intervals.
      if step_num and step_num % log_every == 0:
        train_scalars = jax.tree_map(lambda v: np.mean(v).item(),
                                     jax.device_get(train_scalars))
        logging.info('[Train %s/%s] %s',
                     step_num, num_train_steps, train_scalars)

  # Once training has finished we run eval one more time to get final results.
  with time_activity('final eval'):
    eval_scalars = evaluate(eval_split, params, state)
  logging.info('[Eval FINAL]: %s', eval_scalars)
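
The initialization above broadcasts one PRNG key across all local devices and then pmaps the init function, so every replica starts from identical parameters. A minimal, self-contained sketch of the pattern (toy model and shapes; unlike the repo's `make_initial_state`, this `init_fn` builds no optimizer state):

import jax
import jax.numpy as jnp
import haiku as hk

def forward(x):
  return hk.Linear(10)(x)

net = hk.transform(forward)

def init_fn(rng, batch):
  # Runs once per device under pmap; every device receives the same key.
  return net.init(rng, batch)

n = jax.local_device_count()
rng = jax.random.PRNGKey(0)
rng = jnp.broadcast_to(rng, (n,) + rng.shape)   # one identical key per device
batch = jnp.zeros((n, 8, 32))                   # leading axis = local devices
params = jax.pmap(init_fn)(rng, batch)          # each leaf gains a device axis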
Example #3
File: train.py Project: jenkspt/dm-haiku
def evaluate(
    split: dataset.Split,
    params: hk.Params,
    state: hk.State,
) -> Scalars:
  """Evaluates the model at the given params/state."""
  # Params/state are sharded per-device during training. We just need the copy
  # from the first device (since we do not pmap evaluation at the moment).
  params, state = jax.tree_map(lambda x: x[0], (params, state))
  test_dataset = dataset.load(split, batch_dims=[FLAGS.eval_batch_size])
  correct = jnp.array(0)
  total = 0
  for batch in test_dataset:
    correct += eval_batch(params, state, batch)
    total += batch['images'].shape[0]
  return {'top_1_acc': correct.item() / total}
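
`eval_batch` is defined elsewhere in these train.py files. As a rough sketch only, a per-batch top-1 correctness counter of this shape could look like the following (`net`, the batch keys, and the `is_training` argument are assumptions, not the project's actual code):

def eval_batch(
    params: hk.Params,
    state: hk.State,
    batch,
) -> jnp.ndarray:
  """Returns the number of correctly classified examples in `batch`."""
  logits, _ = net.apply(params, state, None, batch['images'], is_training=False)
  predictions = jnp.argmax(logits, axis=-1)
  return jnp.sum(predictions == batch['labels'])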
Example #4
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    FLAGS.alsologtostderr = True

    train_split = dataset.Split.from_string(FLAGS.train_split)
    eval_split = dataset.Split.from_string(FLAGS.eval_split)

    # The total batch size is the batch size across all hosts and devices. In a
    # multi-host training setup each host will only see a batch size of
    # `total_train_batch_size / jax.host_count()`.
    total_train_batch_size = FLAGS.train_device_batch_size * jax.device_count()
    num_train_steps = ((train_split.num_examples * FLAGS.train_epochs) //
                       total_train_batch_size)

    local_device_count = jax.local_device_count()
    train_dataset = dataset.load(
        train_split,
        is_training=True,
        batch_dims=[local_device_count, FLAGS.train_device_batch_size],
        dtype=get_policy().compute_dtype,
        transpose=FLAGS.dataset_transpose,
        zeros=FLAGS.dataset_zeros)

    # Assign mixed precision policies to modules. Note that when training in f16
    # we keep BatchNorm in full precision. When training with bf16 you can often
    # use bf16 for BatchNorm.
    mp_policy = get_policy()
    bn_policy = get_bn_policy().with_output_dtype(mp_policy.compute_dtype)
    # NOTE: The order in which we call `set_policy` doesn't matter: when a
    # method on a class is called, the policy for that class is applied, or it
    # inherits the policy from its parent module.
    hk.mixed_precision.set_policy(hk.BatchNorm, bn_policy)
    hk.mixed_precision.set_policy(hk.nets.ResNet50, mp_policy)

    if jax.default_backend() == 'gpu':
        # TODO(tomhennigan): This could be removed if XLA:GPU's allocator changes.
        train_dataset = dataset.double_buffer(train_dataset)

    # For initialization we need the same random key on each device.
    rng = jax.random.PRNGKey(FLAGS.train_init_random_seed)
    rng = jnp.broadcast_to(rng, (local_device_count,) + rng.shape)
    # Initialization requires an example input.
    batch = next(train_dataset)
    train_state = jax.pmap(initial_state)(rng, batch)

    # Print a useful summary of the execution of our module.
    summary = hk.experimental.tabulate(train_step)(train_state, batch)
    for line in summary.split('\n'):
        logging.info(line)

    eval_every = FLAGS.train_eval_every
    log_every = FLAGS.train_log_every

    with time_activity('train'):
        for step_num in range(num_train_steps):
            # Take a single training step.
            with jax.profiler.StepTraceAnnotation('train', step_num=step_num):
                batch = next(train_dataset)
                train_state, train_scalars = train_step(train_state, batch)

            # By default we do not evaluate during training, but you can configure
            # this with a flag.
            if eval_every > 0 and step_num and step_num % eval_every == 0:
                with time_activity('eval during train'):
                    eval_scalars = evaluate(eval_split, train_state.params,
                                            train_state.state)
                logging.info('[Eval %s/%s] %s', step_num, num_train_steps,
                             eval_scalars)

            # Log progress at fixed intervals.
            if step_num and step_num % log_every == 0:
                train_scalars = jax.tree_map(lambda v: np.mean(v).item(),
                                             jax.device_get(train_scalars))
                logging.info('[Train %s/%s] %s', step_num, num_train_steps,
                             train_scalars)

    # Once training has finished we run eval one more time to get final results.
    with time_activity('final eval'):
        eval_scalars = evaluate(eval_split, train_state.params,
                                train_state.state)
    logging.info('[Eval FINAL]: %s', eval_scalars)
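
Example #4 relies on `get_policy()` and `get_bn_policy()` helpers that are not shown; presumably they build `jmp.Policy` objects from flags. A minimal sketch of constructing and assigning such policies (the policy strings here are illustrative defaults, not necessarily the repo's):

import jmp
import haiku as hk

# Compute in f16 but keep parameters and outputs in full precision.
mp_policy = jmp.get_policy('p=f32,c=f16,o=f32')
# Keep BatchNorm entirely in full precision, casting only its output to match
# the compute dtype of the surrounding network.
bn_policy = jmp.get_policy('p=f32,c=f32,o=f32').with_output_dtype(
    mp_policy.compute_dtype)

hk.mixed_precision.set_policy(hk.BatchNorm, bn_policy)
hk.mixed_precision.set_policy(hk.nets.ResNet50, mp_policy)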