# Example 1
def train():
  """Train the PixelCNN++ model, evaluating and checkpointing periodically.

  Reads configuration from module-level FLAGS. Runs single-host,
  multi-device data-parallel training via jax.pmap, maintaining an
  exponential moving average (EMA) of the parameters that is used for
  evaluation. Writes per-step training metrics and per-epoch eval loss to
  TensorBoard, and saves a checkpoint every 10 epochs and at the end.

  Raises:
    ValueError: if run on more than one host, or if the batch size is not
      divisible by the number of local devices.
  """
  batch_size = FLAGS.batch_size
  n_devices = jax.device_count()
  if jax.host_count() > 1:
    raise ValueError('PixelCNN++ example should not be run on more than 1 host'
                     ' (for now)')
  if batch_size % n_devices > 0:
    raise ValueError('Batch size must be divisible by the number of devices')

  train_summary_writer, eval_summary_writer = get_summary_writers()

  # Load dataset
  data_source = input_pipeline.DataSource(
      train_batch_size=batch_size, eval_batch_size=batch_size)
  train_ds = data_source.train_ds
  eval_ds = data_source.eval_ds

  # Create dataset batch iterators
  train_iter = iter(train_ds)
  eval_iter = iter(eval_ds)

  # Compute steps per epoch and nb of eval steps
  steps_per_epoch = data_source.TRAIN_IMAGES // batch_size
  steps_per_eval = data_source.EVAL_IMAGES // batch_size
  steps_per_checkpoint = steps_per_epoch * 10
  num_steps = steps_per_epoch * FLAGS.num_epochs

  # Create the model using data-dependent initialization. Don't shard the init
  # batch.
  assert FLAGS.init_batch_size <= batch_size
  init_batch = next(train_iter)['image']._numpy()[:FLAGS.init_batch_size]

  rng = random.PRNGKey(FLAGS.rng)
  rng, init_rng = random.split(rng)
  rng, dropout_rng = random.split(rng)

  initial_variables = model().init({
      'params': init_rng,
      'dropout': dropout_rng
  }, init_batch)['params']
  optimizer_def = optim.Adam(beta1=0.95, beta2=0.9995)
  optimizer = optimizer_def.create(initial_variables)

  # Fix: do NOT overwrite the restored EMA with the freshly initialized
  # parameters. The previous `ema = initial_variables` after this call
  # silently discarded the EMA state when resuming from a checkpoint;
  # `restore_checkpoint` already falls back to `initial_variables` when no
  # checkpoint exists.
  optimizer, ema = restore_checkpoint(optimizer, initial_variables)
  step_offset = int(optimizer.state.step)

  # Replicate the optimizer and EMA across devices for pmap.
  optimizer, ema = jax_utils.replicate((optimizer, ema))

  # Learning rate schedule: exponential decay per step.
  learning_rate_fn = lambda step: FLAGS.learning_rate * FLAGS.lr_decay ** step

  # pmap the train and eval functions
  p_train_step = jax.pmap(
      partial(train_step, learning_rate_fn), axis_name='batch')
  p_eval_step = jax.pmap(eval_step, axis_name='batch')

  # Gather metrics
  train_metrics = []

  for step, batch in zip(range(step_offset, num_steps), train_iter):
    # Load and shard the TF batch
    batch = load_and_shard_tf_batch(batch)

    # Generate a PRNG key that will be rolled into the batch.
    rng, step_rng = random.split(rng)
    sharded_rngs = common_utils.shard_prng_key(step_rng)

    # Train step
    optimizer, ema, metrics = p_train_step(optimizer, ema, batch, sharded_rngs)
    train_metrics.append(metrics)

    if (step + 1) % steps_per_epoch == 0:
      epoch = step // steps_per_epoch
      # We've finished an epoch
      train_metrics = common_utils.get_metrics(train_metrics)
      # Get training epoch summary for logging
      train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)
      # Send stats to Tensorboard: one scalar per recorded step, with the
      # step index reconstructed from the position within the epoch.
      for key, vals in train_metrics.items():
        for i, val in enumerate(vals):
          train_summary_writer.scalar(key, val, step - len(vals) + i + 1)
      # Reset train metrics
      train_metrics = []

      # Evaluation (uses the EMA parameters via p_eval_step).
      eval_metrics = []
      for _ in range(steps_per_eval):
        eval_batch = next(eval_iter)
        # Load and shard the TF batch
        eval_batch = load_and_shard_tf_batch(eval_batch)
        # Step
        metrics = p_eval_step(ema, eval_batch)
        eval_metrics.append(metrics)
      eval_metrics = common_utils.get_metrics(eval_metrics)
      # Get eval epoch summary for logging
      eval_summary = jax.tree_map(lambda x: x.mean(), eval_metrics)

      # Log epoch summary
      logging.info('Epoch %d: TRAIN loss=%.6f, EVAL loss=%.6f', epoch,
                   train_summary['loss'], eval_summary['loss'])

      eval_summary_writer.scalar('loss', eval_summary['loss'], step)
      train_summary_writer.flush()
      eval_summary_writer.flush()

    if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
      save_checkpoint(optimizer, ema, step)
# Example 2
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
    """Runs a training and evaluation loop.

    Args:
      config: Configuration to use.
      workdir: Working directory for checkpoints and TF summaries. If this
        contains checkpoint training will be resumed from the latest
        checkpoint.

    Raises:
      ValueError: if run on more than one host, or if the batch size is not
        divisible by the number of local devices.
    """
    tf.io.gfile.makedirs(workdir)

    batch_size = config.batch_size
    n_devices = jax.device_count()
    if jax.host_count() > 1:
        raise ValueError(
            'PixelCNN++ example should not be run on more than 1 host'
            ' (for now)')
    if batch_size % n_devices > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')

    train_summary_writer, eval_summary_writer = get_summary_writers(workdir)
    # Load dataset
    data_source = input_pipeline.DataSource(config)
    train_ds = data_source.train_ds
    eval_ds = data_source.eval_ds
    steps_per_epoch = data_source.ds_info.splits[
        'train'].num_examples // config.batch_size
    # Create dataset batch iterators
    train_iter = iter(train_ds)
    num_train_steps = train_ds.cardinality().numpy()
    steps_per_checkpoint = 1000

    # Create the model using data-dependent initialization. Don't shard the init
    # batch.
    assert config.init_batch_size <= batch_size
    init_batch = next(train_iter)['image']._numpy()[:config.init_batch_size]

    rng = jax.random.PRNGKey(config.seed)
    rng, init_rng, dropout_rng = jax.random.split(rng, 3)

    initial_variables = model(config).init(
        {
            'params': init_rng,
            'dropout': dropout_rng
        }, init_batch)['params']
    optimizer_def = optim.Adam(beta1=0.95, beta2=0.9995)
    optimizer = optimizer_def.create(initial_variables)

    # Fix: do NOT overwrite the restored EMA with the freshly initialized
    # parameters. The previous `ema = initial_variables` after this call
    # silently discarded the EMA state when resuming from a checkpoint;
    # `restore_checkpoint` already falls back to `initial_variables` when no
    # checkpoint exists in `workdir`.
    optimizer, ema = restore_checkpoint(workdir, optimizer, initial_variables)
    step_offset = int(optimizer.state.step)

    # Replicate the optimizer and EMA across devices for pmap.
    optimizer, ema = jax_utils.replicate((optimizer, ema))

    # Learning rate schedule: exponential decay per step.
    learning_rate_fn = lambda step: config.learning_rate * config.lr_decay**step

    # pmap the train and eval functions
    p_train_step = jax.pmap(functools.partial(train_step, config,
                                              learning_rate_fn),
                            axis_name='batch')
    p_eval_step = jax.pmap(functools.partial(eval_step, config=config),
                           axis_name='batch')

    # Gather metrics
    train_metrics = []

    for step, batch in zip(range(step_offset, num_train_steps), train_iter):
        # Load and shard the TF batch
        batch = load_and_shard_tf_batch(batch)

        # Generate a PRNG key that will be rolled into the batch.
        rng, step_rng = jax.random.split(rng)
        sharded_rngs = common_utils.shard_prng_key(step_rng)

        # Train step
        optimizer, ema, metrics = p_train_step(optimizer, ema, batch,
                                               sharded_rngs)
        train_metrics.append(metrics)

        # Quick indication that training is happening.
        logging.log_first_n(logging.INFO, 'Finished training step %d.', 5,
                            step)

        if (step + 1) % steps_per_epoch == 0:
            epoch = step // steps_per_epoch
            # We've finished an epoch
            train_metrics = common_utils.get_metrics(train_metrics)
            # Get training epoch summary for logging
            train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)
            # Send stats to Tensorboard: one scalar per recorded step, with
            # the step index reconstructed from the position within the epoch.
            for key, vals in train_metrics.items():
                for i, val in enumerate(vals):
                    train_summary_writer.scalar(key, val,
                                                step - len(vals) + i + 1)
            # Reset train metrics
            train_metrics = []

            # Evaluation (uses the EMA parameters via p_eval_step).
            eval_metrics = []
            for eval_batch in eval_ds:
                # Load and shard the TF batch
                eval_batch = load_and_shard_tf_batch(eval_batch)
                # Step
                metrics = p_eval_step(ema, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            # Get eval epoch summary for logging
            eval_summary = jax.tree_map(lambda x: x.mean(), eval_metrics)

            # Log epoch summary
            logging.info('Epoch %d: TRAIN loss=%.6f, EVAL loss=%.6f', epoch,
                         train_summary['loss'], eval_summary['loss'])

            eval_summary_writer.scalar('loss', eval_summary['loss'], step)
            train_summary_writer.flush()
            eval_summary_writer.flush()

        if (step +
                1) % steps_per_checkpoint == 0 or step + 1 == num_train_steps:
            save_checkpoint(workdir, optimizer, ema, step)
# Example 3
def train(pcnn_module,
          model_dir,
          batch_size,
          init_batch_size,
          num_epochs,
          learning_rate,
          decay_rate,
          run_seed=0):
    """Run the PixelCNN++ training loop with periodic evaluation.

    Args:
      pcnn_module: module used to build the PixelCNN++ model.
      model_dir: directory receiving TensorBoard logs (under
        ``<model_dir>/log/<timestamp>``).
      batch_size: global batch size; must divide evenly across devices.
      init_batch_size: batch size for data-dependent initialization; must
        not exceed ``batch_size``.
      num_epochs: number of passes over the training set.
      learning_rate: base learning rate.
      decay_rate: per-step multiplicative learning-rate decay factor.
      run_seed: seed for the JAX PRNG.

    Raises:
      ValueError: if run on more than one host, or if ``batch_size`` is not
        divisible by the device count.
    """
    if jax.host_count() > 1:
        raise ValueError(
            'PixelCNN++ example should not be run on more than 1 host'
            ' (for now)')

    # One timestamped log directory per run, with separate train/eval writers.
    timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    run_log_dir = model_dir + '/log/' + timestamp
    train_summary_writer = tensorboard.SummaryWriter(run_log_dir + '/train')
    eval_summary_writer = tensorboard.SummaryWriter(run_log_dir + '/eval')

    rng = random.PRNGKey(run_seed)

    if batch_size % jax.device_count() > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')

    # Input pipeline: the same batch size is used for train and eval.
    data_source = input_pipeline.DataSource(train_batch_size=batch_size,
                                            eval_batch_size=batch_size)
    train_iter = iter(data_source.train_ds)
    eval_iter = iter(data_source.eval_ds)

    # Derived step counts.
    steps_per_epoch = data_source.TRAIN_IMAGES // batch_size
    steps_per_eval = data_source.EVAL_IMAGES // batch_size
    steps_per_checkpoint = steps_per_epoch * 10
    num_steps = steps_per_epoch * num_epochs

    base_learning_rate = learning_rate

    # Data-dependent initialization on an unsharded slice of the first batch.
    assert init_batch_size <= batch_size
    init_batch = next(train_iter)['image']._numpy()[:init_batch_size]
    model = create_model(rng, init_batch, pcnn_module)
    ema = model.params
    optimizer = create_optimizer(model, base_learning_rate)
    del model  # don't keep a copy of the initial model

    # Resume from the latest checkpoint if present, then replicate state
    # across devices for pmap.
    optimizer, ema = restore_checkpoint(optimizer, ema)
    step_offset = int(optimizer.state.step)
    optimizer, ema = jax_utils.replicate((optimizer, ema))

    # Exponentially decaying per-step learning-rate schedule.
    def learning_rate_fn(step):
        return base_learning_rate * decay_rate**step

    # Parallelize the train and eval steps over devices.
    p_train_step = jax.pmap(functools.partial(
        train_step, learning_rate_fn=learning_rate_fn),
                            axis_name='batch')
    p_eval_step = jax.pmap(eval_step, axis_name='batch')

    epoch_metrics = []
    for step, batch in zip(range(step_offset, num_steps), train_iter):
        # Shard the TF batch across devices.
        sharded_batch = load_and_shard_tf_batch(batch)
        # Derive and shard a fresh PRNG key for this step.
        rng, step_key = jax.random.split(rng)
        sharded_keys = common_utils.shard_prng_key(step_key)

        # One parallel optimization step; also advances the parameter EMA.
        optimizer, ema, metrics = p_train_step(optimizer, ema, sharded_batch,
                                               sharded_keys)
        epoch_metrics.append(metrics)

        if (step + 1) % steps_per_epoch == 0:
            epoch = step // steps_per_epoch
            # Epoch finished: stack per-step metrics and summarize.
            stacked = common_utils.get_metrics(epoch_metrics)
            train_summary = jax.tree_map(lambda x: x.mean(), stacked)
            # Emit one TensorBoard scalar per recorded step, reconstructing
            # each step index from its position within the epoch.
            for name, values in stacked.items():
                for offset, value in enumerate(values):
                    train_summary_writer.scalar(
                        name, value, step - len(values) + offset + 1)
            epoch_metrics = []

            # Evaluate using the EMA parameters.
            model_ema = optimizer.target.replace(params=ema)
            eval_results = []
            for _ in range(steps_per_eval):
                sharded_eval_batch = load_and_shard_tf_batch(next(eval_iter))
                eval_results.append(p_eval_step(model_ema, sharded_eval_batch))
            eval_summary = jax.tree_map(
                lambda x: x.mean(), common_utils.get_metrics(eval_results))

            logging.info('Epoch %d: TRAIN loss=%.6f, EVAL loss=%.6f', epoch,
                         train_summary['loss'], eval_summary['loss'])

            eval_summary_writer.scalar('loss', eval_summary['loss'], step)
            train_summary_writer.flush()
            eval_summary_writer.flush()

        if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
            save_checkpoint(optimizer, ema)