Example #1
    def test_train_one_step(self):
        """Tests training loop over one step."""
        iterator = self._dataset.get_train()
        batch = next(iterator)

        state = jax_utils.replicate(self._state)
        optimizer = jax_utils.replicate(self._optimizer.create(self._model))

        self._rng, step_key = jax.random.split(self._rng)
        batch = training._shard_batch(batch)
        sharded_keys = common_utils.shard_prng_key(step_key)

        p_train_step = jax.pmap(functools.partial(
            training.train_step, learning_rate_fn=self._learning_rate_fn),
                                axis_name='batch')
        _, _, loss, gradient_norm = p_train_step(optimizer, batch,
                                                 sharded_keys, state)

        loss = jnp.mean(loss)
        gradient_norm = jax_utils.unreplicate(gradient_norm)

        with self.subTest(name='test_loss_range'):
            self.assertBetween(loss, self._min_loss, self._max_loss)

        with self.subTest(name='test_gradient_norm'):
            self.assertGreaterEqual(gradient_norm, 0)
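All of these examples revolve around flax.training.common_utils.shard_prng_key, which turns one PRNG key into per-device keys with a leading device axis so it can be passed into a pmapped step next to the sharded batch. A minimal standalone sketch of that behaviour (assuming only jax and flax are installed):

import jax
from flax.training import common_utils

key = jax.random.PRNGKey(0)
sharded_keys = common_utils.shard_prng_key(key)
# The leading axis has length jax.local_device_count(): one key slice per
# local device, ready to be passed alongside a sharded batch into jax.pmap.
assert sharded_keys.shape[0] == jax.local_device_count()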
Example #2
def train_for_one_epoch(
    dataset_source: dataset_source_lib.DatasetSource,
    optimizer: flax.optim.Optimizer, state: flax.nn.Collection,
    prng_key: jnp.ndarray, pmapped_train_step: _TrainStep,
    pmapped_update_ema: Optional[_EMAUpdateStep],
    moving_averages: Optional[efficientnet_optim.ExponentialMovingAverage],
    summary_writer: tensorboard.SummaryWriter
) -> Tuple[flax.optim.Optimizer, flax.nn.Collection,
           Optional[efficientnet_optim.ExponentialMovingAverage]]:
  """Trains the model for one epoch.

  Args:
    dataset_source: Container for the training dataset.
    optimizer: The optimizer targeting the model to train.
    state: Current state associated with the model (contains the batch norm MA).
    prng_key: A PRNG key to use for stochasticity (e.g. for sampling a dropout
      mask, if one is used). Is not used for shuffling the dataset.
    pmapped_train_step: A pmapped version of the `train_step` function (see its
      documentation for more details).
    pmapped_update_ema: Function to update the parameter moving average. Can be
      None if we don't use EMA.
    moving_averages: Parameters moving average if used.
    summary_writer: A Tensorboard SummaryWriter to use to log metrics.

  Returns:
    The updated optimizer (with the associated updated model), the model state
      and the parameter moving averages.
  """
  start_time = time.time()
  cnt = 0
  train_metrics = []
  for batch in dataset_source.get_train(use_augmentations=True):
    # Generate a PRNG key that will be rolled into the batch.
    step_key = jax.random.fold_in(prng_key, optimizer.state.step[0])
    # Load and shard the TF batch.
    batch = tensorflow_to_numpy(batch)
    batch = shard_batch(batch)
    # Shard the step PRNG key.
    sharded_keys = common_utils.shard_prng_key(step_key)

    optimizer, state, metrics, lr = pmapped_train_step(
        optimizer, state, batch, sharded_keys)
    cnt += 1

    if moving_averages is not None:
      moving_averages = pmapped_update_ema(optimizer, state, moving_averages)

    train_metrics.append(metrics)
  train_metrics = common_utils.get_metrics(train_metrics)
  # Get training epoch summary for logging.
  train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)
  train_summary['learning_rate'] = lr[0]
  current_step = int(optimizer.state.step[0])
  info = 'Whole training step done in {} ({} steps)'.format(
      time.time()-start_time, cnt)
  logging.info(info)
  for metric_name, metric_value in train_summary.items():
    summary_writer.scalar(metric_name, metric_value, current_step)
  summary_writer.flush()
  return optimizer, state, moving_averages
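Unlike the simplified variant of this function shown further below, which splits prng_key on every step, this loop derives its per-step key with jax.random.fold_in, so the same base key is reused and the step counter alone selects the sub-key. A small standalone illustration of that pattern (variable names here are purely illustrative):

import jax
from flax.training import common_utils

root_key = jax.random.PRNGKey(0)
# fold_in is deterministic in (key, data): the same pair always yields the same
# sub-key, and root_key itself is never consumed or advanced by the call.
step_keys = [jax.random.fold_in(root_key, step) for step in range(3)]
sharded_keys = [common_utils.shard_prng_key(k) for k in step_keys]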
Example #3
def initial_state(self):
    return TrainState(
        history=self,
        rng=common_utils.shard_prng_key(
            jax.random.PRNGKey(np.random.randint(2 ** 16))
        ),
        step=None,
        metrics=None,
    )
Example #4
def update_preconditioner(config, optimizer, p_update_grad_vars, rng, state,
                          train_iter):
    """Computes preconditioner state using samples from dataloader."""
    # TODO(basv): support multiple hosts.
    values = jax.tree_map(jnp.zeros_like, optimizer.target)

    eps = config.precon_est_eps
    n_batches = config.precon_est_batches
    for _ in range(n_batches):
        rng, est_key = jax.random.split(rng)
        batch = next(train_iter)
        batch = input_pipeline.load_and_shard_tf_batch(config, batch)
        if not config.debug_run:
            # Shard the step PRNG key
            sharded_keys = common_utils.shard_prng_key(est_key)
        else:
            sharded_keys = est_key
        values = p_update_grad_vars(optimizer, state, batch, sharded_keys,
                                    values)
    stds = jax.tree_map(
        lambda v: jnp.sqrt(eps + (1 / n_batches) * jnp.mean(v)), values)
    std_min = jnp.min(jnp.asarray(jax.tree_leaves(stds)))
    # TODO(basv): verify preconditioner estimate.
    new_precon = jax.tree_map(lambda s, x: jnp.ones_like(x) * (s / std_min),
                              stds, optimizer.target)

    def convert_momentum(
        new_precon,
        state,
    ):
        """Converts momenta to new preconditioner."""
        if config.weight_norm == 'learned':
            state = state.direction_state
        old_precon = state.preconditioner
        momentum = state.momentum

        m_c = jnp.power(old_precon, -.5) * momentum
        m = jnp.power(new_precon, .5) * m_c
        return m

    # TODO(basv): verify momentum convert.
    new_momentum = jax.tree_map(convert_momentum, new_precon,
                                optimizer.state.param_states)
    # TODO(basv): verify this is replaced correctly, check replicated.
    optimizer = replace_param_state(config,
                                    optimizer,
                                    preconditioner=new_precon,
                                    momentum=new_momentum)
    return optimizer, rng
Example #5
def train_for_one_epoch(
    dataset_source,
    optimizer, state,
    prng_key, pmapped_train_step,
    summary_writer
):
  """Trains the model for one epoch.

  Args:
    dataset_source: Container for the training dataset.
    optimizer: The optimizer targeting the model to train.
    state: Current state associated with the model (contains the batch norm MA).
    prng_key: A PRNG key to use for stochasticity (e.g. for sampling a dropout
      mask, if one is used). Is not used for shuffling the dataset.
    pmapped_train_step: A pmapped version of the `train_step` function (see its
      documentation for more details).
    summary_writer: A Tensorboard SummaryWriter to use to log metrics.

  Returns:
    The updated optimizer (with the associated updated model), state and PRNG
      key.
  """
  train_metrics = []
  for batch in dataset_source.get_train(use_augmentations=True):
    # Generate a PRNG key that will be rolled into the batch.
    step_key, prng_key = jax.random.split(prng_key)
    # Load and shard the TF batch.
    batch = tensorflow_to_numpy(batch)
    batch = shard_batch(batch)
    # Shard the step PRNG key.
    sharded_keys = common_utils.shard_prng_key(step_key)

    optimizer, state, metrics, lr = pmapped_train_step(
        optimizer, state, batch, sharded_keys)
    train_metrics.append(metrics)
  train_metrics = common_utils.get_metrics(train_metrics)
  # Get training epoch summary for logging.
  train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)
  train_summary['learning_rate'] = lr[0]
  current_step = int(optimizer.state.step[0])
  for metric_name, metric_value in train_summary.items():
    summary_writer.scalar(metric_name, metric_value, current_step)
  summary_writer.flush()
  return optimizer, state, prng_key
Example #6
    async def dalle(self, ctx: commands.Context, *, prompt: str):
        prompts = [
            "sunset over a lake in the mountains",
            "the Eiffel tower landing on the moon"
        ]
        tokenized_prompts = processor(prompts)
        tokenized_prompt = replicate(tokenized_prompts)

        # generate images
        images = []
        for i in trange(max(n_predictions // jax.device_count(), 1)):
            # get a new key
            key, subkey = jax.random.split(key)
            # generate images
            encoded_images = p_generate(
                tokenized_prompt,
                shard_prng_key(subkey),
                params,
                gen_top_k,
                gen_top_p,
                temperature,
                cond_scale,
            )
            # remove BOS
            encoded_images = encoded_images.sequences[..., 1:]
            # decode images
            decoded_images = p_decode(encoded_images, vqgan_params)
            decoded_images = decoded_images.clip(0.0, 1.0).reshape(
                (-1, 256, 256, 3))
            for decoded_img in decoded_images:
                img = Image.fromarray(
                    np.asarray(decoded_img * 255, dtype=np.uint8))
                images.append(img)
                # display(img)
                filename = f"{random.randrange(100, 999)}@{datetime.now()}"
                print(f"Saving picture '{filename}'")
                with open(Path(self.cache_dir, filename), 'wb') as image_file:
                    # A PIL image is not a file-like object, so write it out
                    # with Image.save rather than shutil.copyfileobj.
                    img.save(image_file, format='PNG')
Example #7
def train():
  """Train model."""
  batch_size = FLAGS.batch_size
  n_devices = jax.device_count()
  if jax.host_count() > 1:
    raise ValueError('PixelCNN++ example should not be run on more than 1 host'
                     ' (for now)')
  if batch_size % n_devices > 0:
    raise ValueError('Batch size must be divisible by the number of devices')

  train_summary_writer, eval_summary_writer = get_summary_writers()

  # Load dataset
  data_source = input_pipeline.DataSource(
      train_batch_size=batch_size, eval_batch_size=batch_size)
  train_ds = data_source.train_ds
  eval_ds = data_source.eval_ds

  # Create dataset batch iterators
  train_iter = iter(train_ds)
  eval_iter = iter(eval_ds)

  # Compute steps per epoch and nb of eval steps
  steps_per_epoch = data_source.TRAIN_IMAGES // batch_size
  steps_per_eval = data_source.EVAL_IMAGES // batch_size
  steps_per_checkpoint = steps_per_epoch * 10
  num_steps = steps_per_epoch * FLAGS.num_epochs

  # Create the model using data-dependent initialization. Don't shard the init
  # batch.
  assert FLAGS.init_batch_size <= batch_size
  init_batch = next(train_iter)['image']._numpy()[:FLAGS.init_batch_size]

  rng = random.PRNGKey(FLAGS.rng)
  rng, init_rng = random.split(rng)
  rng, dropout_rng = random.split(rng)

  initial_variables = model().init({
      'params': init_rng,
      'dropout': dropout_rng
  }, init_batch)['params']
  optimizer_def = optim.Adam(beta1=0.95, beta2=0.9995)
  optimizer = optimizer_def.create(initial_variables)

  ema = initial_variables
  optimizer, ema = restore_checkpoint(optimizer, ema)
  step_offset = int(optimizer.state.step)

  optimizer, ema = jax_utils.replicate((optimizer, ema))

  # Learning rate schedule
  learning_rate_fn = lambda step: FLAGS.learning_rate * FLAGS.lr_decay ** step

  # pmap the train and eval functions
  p_train_step = jax.pmap(
      partial(train_step, learning_rate_fn), axis_name='batch')
  p_eval_step = jax.pmap(eval_step, axis_name='batch')

  # Gather metrics
  train_metrics = []

  for step, batch in zip(range(step_offset, num_steps), train_iter):
    # Load and shard the TF batch
    batch = load_and_shard_tf_batch(batch)

    # Generate a PRNG key that will be rolled into the batch.
    rng, step_rng = random.split(rng)
    sharded_rngs = common_utils.shard_prng_key(step_rng)

    # Train step
    optimizer, ema, metrics = p_train_step(optimizer, ema, batch, sharded_rngs)
    train_metrics.append(metrics)

    if (step + 1) % steps_per_epoch == 0:
      epoch = step // steps_per_epoch
      # We've finished an epoch
      train_metrics = common_utils.get_metrics(train_metrics)
      # Get training epoch summary for logging
      train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)
      # Send stats to Tensorboard
      for key, vals in train_metrics.items():
        for i, val in enumerate(vals):
          train_summary_writer.scalar(key, val, step - len(vals) + i + 1)
      # Reset train metrics
      train_metrics = []

      # Evaluation
      eval_metrics = []
      for _ in range(steps_per_eval):
        eval_batch = next(eval_iter)
        # Load and shard the TF batch
        eval_batch = load_and_shard_tf_batch(eval_batch)
        # Step
        metrics = p_eval_step(ema, eval_batch)
        eval_metrics.append(metrics)
      eval_metrics = common_utils.get_metrics(eval_metrics)
      # Get eval epoch summary for logging
      eval_summary = jax.tree_map(lambda x: x.mean(), eval_metrics)

      # Log epoch summary
      logging.info('Epoch %d: TRAIN loss=%.6f, EVAL loss=%.6f', epoch,
                   train_summary['loss'], eval_summary['loss'])

      eval_summary_writer.scalar('loss', eval_summary['loss'], step)
      train_summary_writer.flush()
      eval_summary_writer.flush()

    if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
      save_checkpoint(optimizer, ema, step)
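load_and_shard_tf_batch is defined elsewhere in this example; its core reshaping step presumably matches flax's common_utils.shard (after converting the TF tensors to NumPy). A sketch of shard alone on a fake NumPy batch, assuming the host batch size is divisible by the number of local devices:

import numpy as np
import jax
from flax.training import common_utils

batch = {'image': np.zeros((8, 32, 32, 3), np.float32),
         'label': np.zeros((8,), np.int32)}
sharded = common_utils.shard(batch)
# Every leaf gains a leading device axis, e.g. the image leaf becomes
# (jax.local_device_count(), 8 // jax.local_device_count(), 32, 32, 3).
assert sharded['image'].shape[0] == jax.local_device_count()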
Example #8
def main():
    args = parse_args()

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() ==
                    0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
    # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).

    # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the
    # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named
    # label if at least two columns are provided.

    # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this
    # single column. You can easily tweak this behavior (see below)

    # In distributed training, the load_dataset function guarantee that only one local process can concurrently
    # download the dataset.
    if args.task_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset("glue", args.task_name)
    else:
        # Loading the dataset from local csv or json file.
        data_files = {}
        if args.train_file is not None:
            data_files["train"] = args.train_file
        if args.validation_file is not None:
            data_files["validation"] = args.validation_file
        extension = (args.train_file if args.train_file is not None else
                     args.validation_file).split(".")[-1]
        raw_datasets = load_dataset(extension, data_files=data_files)
    # See more about loading any type of standard or custom dataset at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Labels
    if args.task_name is not None:
        is_regression = args.task_name == "stsb"
        if not is_regression:
            label_list = raw_datasets["train"].features["label"].names
            num_labels = len(label_list)
        else:
            num_labels = 1
    else:
        # Trying to have good defaults here, don't hesitate to tweak to your needs.
        is_regression = raw_datasets["train"].features["label"].dtype in [
            "float32", "float64"
        ]
        if is_regression:
            num_labels = 1
        else:
            # A useful fast method:
            # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique
            label_list = raw_datasets["train"].unique("label")
            label_list.sort()  # Let's sort it for determinism
            num_labels = len(label_list)

    # Load pretrained model and tokenizer
    config = AutoConfig.from_pretrained(args.model_name_or_path,
                                        num_labels=num_labels,
                                        finetuning_task=args.task_name)
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
    model = FlaxAutoModelForSequenceClassification.from_pretrained(
        args.model_name_or_path, config=config)

    # Preprocessing the datasets
    if args.task_name is not None:
        sentence1_key, sentence2_key = task_to_keys[args.task_name]
    else:
        # Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
        non_label_column_names = [
            name for name in raw_datasets["train"].column_names
            if name != "label"
        ]
        if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names:
            sentence1_key, sentence2_key = "sentence1", "sentence2"
        else:
            if len(non_label_column_names) >= 2:
                sentence1_key, sentence2_key = non_label_column_names[:2]
            else:
                sentence1_key, sentence2_key = non_label_column_names[0], None

    # Some models have set the order of the labels to use, so let's make sure we do use it.
    label_to_id = None
    if (model.config.label2id !=
            PretrainedConfig(num_labels=num_labels).label2id
            and args.task_name is not None and not is_regression):
        # Some have all caps in their config, some don't.
        label_name_to_id = {
            k.lower(): v
            for k, v in model.config.label2id.items()
        }
        if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
            logger.info(
                f"The configuration of the model provided the following label correspondence: {label_name_to_id}. "
                "Using it!")
            label_to_id = {
                i: label_name_to_id[label_list[i]]
                for i in range(num_labels)
            }
        else:
            logger.warning(
                "Your model seems to have been trained with labels, but they don't match the dataset: "
                f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
                "\nIgnoring the model labels as a result."
            )
    elif args.task_name is None:
        label_to_id = {v: i for i, v in enumerate(label_list)}

    def preprocess_function(examples):
        # Tokenize the texts
        texts = ((examples[sentence1_key], ) if sentence2_key is None else
                 (examples[sentence1_key], examples[sentence2_key]))
        result = tokenizer(*texts,
                           padding="max_length",
                           max_length=args.max_length,
                           truncation=True)

        if "label" in examples:
            if label_to_id is not None:
                # Map labels to IDs (not necessary for GLUE tasks)
                result["labels"] = [label_to_id[l] for l in examples["label"]]
            else:
                # In all cases, rename the column to labels because the model will expect that.
                result["labels"] = examples["label"]
        return result

    processed_datasets = raw_datasets.map(
        preprocess_function,
        batched=True,
        remove_columns=raw_datasets["train"].column_names)

    train_dataset = processed_datasets["train"]
    eval_dataset = processed_datasets["validation_matched" if args.task_name ==
                                      "mnli" else "validation"]

    # Log a few random samples from the training set:
    for index in random.sample(range(len(train_dataset)), 3):
        logger.info(
            f"Sample {index} of the training set: {train_dataset[index]}.")

    # Define a summary writer
    summary_writer = tensorboard.SummaryWriter(args.output_dir)
    summary_writer.hparams(vars(args))

    def write_metric(train_metrics, eval_metrics, train_time, step):
        summary_writer.scalar("train_time", train_time, step)

        train_metrics = get_metrics(train_metrics)
        for key, vals in train_metrics.items():
            tag = f"train_{key}"
            for i, val in enumerate(vals):
                summary_writer.scalar(tag, val, step - len(vals) + i + 1)

        for metric_name, value in eval_metrics.items():
            summary_writer.scalar(f"eval_{metric_name}", value, step)

    num_epochs = int(args.num_train_epochs)
    rng = jax.random.PRNGKey(args.seed)

    train_batch_size = args.per_device_train_batch_size * jax.local_device_count(
    )
    eval_batch_size = args.per_device_eval_batch_size * jax.local_device_count(
    )

    learning_rate_fn = create_learning_rate_fn(len(train_dataset),
                                               train_batch_size,
                                               args.num_train_epochs,
                                               args.num_warmup_steps,
                                               args.learning_rate)

    state = create_train_state(model,
                               learning_rate_fn,
                               is_regression,
                               num_labels=num_labels)

    # define step functions
    def train_step(
            state: train_state.TrainState, batch: Dict[str, Array],
            dropout_rng: PRNGKey) -> Tuple[train_state.TrainState, float]:
        """Trains model with an optimizer (both in `state`) on `batch`, returning a pair `(new_state, loss)`."""
        targets = batch.pop("labels")

        def loss_fn(params):
            logits = state.apply_fn(**batch,
                                    params=params,
                                    dropout_rng=dropout_rng,
                                    train=True)[0]
            loss = state.loss_fn(logits, targets)
            return loss, logits

        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (loss, logits), grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)
        metrics = jax.lax.pmean(
            {
                "loss": loss,
                "learning_rate": learning_rate_fn(state.step)
            },
            axis_name="batch")
        return new_state, metrics

    p_train_step = jax.pmap(train_step,
                            axis_name="batch",
                            donate_argnums=(0, ))

    def eval_step(state, batch):
        logits = state.apply_fn(**batch, params=state.params, train=False)[0]
        return state.logits_fn(logits)

    p_eval_step = jax.pmap(eval_step, axis_name="batch")

    if args.task_name is not None:
        metric = load_metric("glue", args.task_name)
    else:
        metric = load_metric("accuracy")

    logger.info(f"===== Starting training ({num_epochs} epochs) =====")
    train_time = 0

    for epoch in range(1, num_epochs + 1):
        logger.info(f"Epoch {epoch}")
        logger.info("  Training...")

        # make sure weights are replicated on each device
        state = replicate(state)

        train_start = time.time()
        train_metrics = []
        rng, input_rng, dropout_rng = jax.random.split(rng, 3)

        # train
        for batch in glue_train_data_collator(input_rng, train_dataset,
                                              train_batch_size):
            dropout_rngs = shard_prng_key(dropout_rng)
            state, metrics = p_train_step(state, batch, dropout_rngs)
            train_metrics.append(metrics)
            train_time += time.time() - train_start
            logger.info(f"    Done! Training metrics: {unreplicate(metrics)}")

        logger.info("  Evaluating...")
        rng, input_rng = jax.random.split(rng)

        # evaluate
        for batch in glue_eval_data_collator(eval_dataset, eval_batch_size):
            labels = batch.pop("labels")
            predictions = p_eval_step(state, batch)
            metric.add_batch(predictions=chain(*predictions),
                             references=chain(*labels))

        # evaluate also on leftover examples (not divisible by batch_size)
        num_leftover_samples = len(eval_dataset) % eval_batch_size

        # make sure leftover batch is evaluated on one device
        if num_leftover_samples > 0 and jax.process_index() == 0:
            # put weights on single device
            state = unreplicate(state)

            # take leftover samples
            batch = eval_dataset[-num_leftover_samples:]
            batch = {k: jnp.array(v) for k, v in batch.items()}

            labels = batch.pop("labels")
            predictions = eval_step(state, batch)
            metric.add_batch(predictions=predictions, references=labels)

        eval_metric = metric.compute()
        logger.info(f"    Done! Eval metrics: {eval_metric}")

        cur_step = epoch * (len(train_dataset) // train_batch_size)
        write_metric(train_metrics, eval_metric, train_time, cur_step)

    # save last checkpoint
    if jax.process_index() == 0:
        params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
        model.save_pretrained(args.output_dir, params=params)
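This script replicates the train state across local devices with replicate(state) and pulls a single copy back with unreplicate (or by taking index 0 of each leaf) before evaluating leftover samples and saving. The round trip in isolation, assuming the helpers come from flax.jax_utils, where flax defines them:

import jax.numpy as jnp
from flax.jax_utils import replicate, unreplicate

params = {'w': jnp.ones((3,)), 'b': jnp.zeros((1,))}
replicated = replicate(params)       # each leaf gains a leading device axis
restored = unreplicate(replicated)   # the copy held by the first device; same as
                                     # jax.tree_map(lambda x: x[0], replicated)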
Example #9
def train(pcnn_module,
          model_dir,
          batch_size,
          init_batch_size,
          num_epochs,
          learning_rate,
          decay_rate,
          run_seed=0):
    """Train model."""
    if jax.host_count() > 1:
        raise ValueError(
            'PixelCNN++ example should not be run on more than 1 host'
            ' (for now)')

    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    log_dir = model_dir + '/log/' + current_time
    train_log_dir = log_dir + '/train'
    eval_log_dir = log_dir + '/eval'
    train_summary_writer = tensorboard.SummaryWriter(train_log_dir)
    eval_summary_writer = tensorboard.SummaryWriter(eval_log_dir)

    rng = random.PRNGKey(run_seed)

    if batch_size % jax.device_count() > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')

    # Load dataset
    data_source = input_pipeline.DataSource(train_batch_size=batch_size,
                                            eval_batch_size=batch_size)
    train_ds = data_source.train_ds
    eval_ds = data_source.eval_ds

    # Create dataset batch iterators
    train_iter = iter(train_ds)
    eval_iter = iter(eval_ds)

    # Compute steps per epoch and nb of eval steps
    steps_per_epoch = data_source.TRAIN_IMAGES // batch_size
    steps_per_eval = data_source.EVAL_IMAGES // batch_size
    steps_per_checkpoint = steps_per_epoch * 10
    num_steps = steps_per_epoch * num_epochs

    base_learning_rate = learning_rate

    # Create the model using data-dependent initialization. Don't shard the init
    # batch.
    assert init_batch_size <= batch_size
    init_batch = next(train_iter)['image']._numpy()[:init_batch_size]
    model = create_model(rng, init_batch, pcnn_module)
    ema = model.params
    optimizer = create_optimizer(model, base_learning_rate)
    del model  # don't keep a copy of the initial model

    optimizer, ema = restore_checkpoint(optimizer, ema)
    step_offset = int(optimizer.state.step)
    optimizer, ema = jax_utils.replicate((optimizer, ema))

    # Learning rate schedule
    learning_rate_fn = lambda step: base_learning_rate * decay_rate**step

    # pmap the train and eval functions
    p_train_step = jax.pmap(functools.partial(
        train_step, learning_rate_fn=learning_rate_fn),
                            axis_name='batch')
    p_eval_step = jax.pmap(eval_step, axis_name='batch')

    # Gather metrics
    train_metrics = []
    for step, batch in zip(range(step_offset, num_steps), train_iter):
        # Generate a PRNG key that will be rolled into the batch
        rng, step_key = jax.random.split(rng)
        # Load and shard the TF batch
        batch = load_and_shard_tf_batch(batch)
        # Shard the step PRNG key
        sharded_keys = common_utils.shard_prng_key(step_key)

        # Train step
        optimizer, ema, metrics = p_train_step(optimizer, ema, batch,
                                               sharded_keys)
        train_metrics.append(metrics)

        if (step + 1) % steps_per_epoch == 0:
            epoch = step // steps_per_epoch
            # We've finished an epoch
            train_metrics = common_utils.get_metrics(train_metrics)
            # Get training epoch summary for logging
            train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)
            # Send stats to Tensorboard
            for key, vals in train_metrics.items():
                for i, val in enumerate(vals):
                    train_summary_writer.scalar(key, val,
                                                step - len(vals) + i + 1)
            # Reset train metrics
            train_metrics = []

            # Evaluation
            model_ema = optimizer.target.replace(params=ema)
            eval_metrics = []
            for _ in range(steps_per_eval):
                eval_batch = next(eval_iter)
                # Load and shard the TF batch
                eval_batch = load_and_shard_tf_batch(eval_batch)
                # Step
                metrics = p_eval_step(model_ema, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            # Get eval epoch summary for logging
            eval_summary = jax.tree_map(lambda x: x.mean(), eval_metrics)

            # Log epoch summary
            logging.info('Epoch %d: TRAIN loss=%.6f, EVAL loss=%.6f', epoch,
                         train_summary['loss'], eval_summary['loss'])

            eval_summary_writer.scalar('loss', eval_summary['loss'], step)
            train_summary_writer.flush()
            eval_summary_writer.flush()

        if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
            save_checkpoint(optimizer, ema)
Example #10
def initial_state(self):
  return TrainState(
      rng=common_utils.shard_prng_key(jax.random.PRNGKey(0)),
      step=None,
      metrics=None,
      history=self)
Example #11
def train(model_def,
          model_dir,
          batch_size,
          num_epochs,
          learning_rate,
          sgd_momentum,
          make_lr_fun=None,
          l2_reg=0.0005,
          run_seed=0):
    """Train model."""
    if jax.host_count() > 1:
        raise ValueError('CIFAR-10 example should not be run on '
                         'more than 1 host (for now)')

    if make_lr_fun is None:
        # No learning rate function provided
        # Default to stepped LR schedule for CIFAR-10 and Wide ResNet
        def make_lr_fun(base_lr, steps_per_epoch):  # pylint: disable=function-redefined
            return lr_schedule.create_stepped_learning_rate_schedule(
                base_lr, steps_per_epoch,
                [[60, 0.2], [120, 0.04], [160, 0.008]])

    summary_writer = tensorboard.SummaryWriter(model_dir)

    rng = random.PRNGKey(run_seed)

    if batch_size % jax.device_count() > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')
    device_batch_size = batch_size // jax.device_count()

    # Load dataset
    data_source = input_pipeline.CIFAR10DataSource(train_batch_size=batch_size,
                                                   eval_batch_size=batch_size)
    train_ds = data_source.train_ds
    eval_ds = data_source.eval_ds

    # Compute steps per epoch and nb of eval steps
    steps_per_epoch = data_source.TRAIN_IMAGES // batch_size
    steps_per_eval = data_source.EVAL_IMAGES // batch_size
    num_steps = steps_per_epoch * num_epochs

    base_learning_rate = learning_rate

    # Create the model
    image_size = 32
    model, state = create_model(rng, device_batch_size, image_size, model_def)
    state = jax_utils.replicate(state)
    optimizer = create_optimizer(model, base_learning_rate, sgd_momentum)
    del model  # don't keep a copy of the initial model

    # Learning rate schedule
    learning_rate_fn = make_lr_fun(base_learning_rate, steps_per_epoch)

    # pmap the train and eval functions
    p_train_step = jax.pmap(functools.partial(
        train_step, learning_rate_fn=learning_rate_fn, l2_reg=l2_reg),
                            axis_name='batch')
    p_eval_step = jax.pmap(eval_step, axis_name='batch')

    # Create dataset batch iterators
    train_iter = iter(train_ds)
    eval_iter = iter(eval_ds)

    # Gather metrics
    train_metrics = []
    epoch = 1
    for step, batch in zip(range(num_steps), train_iter):
        # Generate a PRNG key that will be rolled into the batch
        rng, step_key = jax.random.split(rng)
        # Load and shard the TF batch
        batch = load_and_shard_tf_batch(batch)
        # Shard the step PRNG key
        sharded_keys = common_utils.shard_prng_key(step_key)

        # Train step
        optimizer, state, metrics = p_train_step(optimizer, state, batch,
                                                 sharded_keys)
        train_metrics.append(metrics)

        if (step + 1) % steps_per_epoch == 0:
            # We've finished an epoch
            train_metrics = common_utils.get_metrics(train_metrics)
            # Get training epoch summary for logging
            train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)
            # Send stats to Tensorboard
            for key, vals in train_metrics.items():
                tag = 'train_%s' % key
                for i, val in enumerate(vals):
                    summary_writer.scalar(tag, val, step - len(vals) + i + 1)
            # Reset train metrics
            train_metrics = []

            # Evaluation
            eval_metrics = []
            for _ in range(steps_per_eval):
                eval_batch = next(eval_iter)
                # Load and shard the TF batch
                eval_batch = load_and_shard_tf_batch(eval_batch)
                # Step
                metrics = p_eval_step(optimizer.target, state, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            # Get eval epoch summary for logging
            eval_summary = jax.tree_map(lambda x: x.mean(), eval_metrics)

            # Log epoch summary
            logging.info(
                'Epoch %d: TRAIN loss=%.6f, err=%.2f, EVAL loss=%.6f, err=%.2f',
                epoch, train_summary['loss'],
                train_summary['error_rate'] * 100.0, eval_summary['loss'],
                eval_summary['error_rate'] * 100.0)

            summary_writer.scalar('eval_loss', eval_summary['loss'], epoch)
            summary_writer.scalar('eval_error_rate',
                                  eval_summary['error_rate'], epoch)
            summary_writer.flush()

            epoch += 1
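The per-epoch summary pattern used throughout these examples (stack the per-step metrics returned by the pmapped step, then average them) can be reproduced standalone. The fake metric leaves below have a leading device axis of length one, matching a single-device run:

import jax
import jax.numpy as jnp
from flax.training import common_utils

# Two fake per-step metric dicts, shaped like the output of a pmapped train step.
step_metrics = [{'loss': jnp.array([0.5])}, {'loss': jnp.array([0.3])}]

stacked = common_utils.get_metrics(step_metrics)     # {'loss': array([0.5, 0.3])}
summary = jax.tree_map(lambda x: x.mean(), stacked)  # {'loss': ~0.4}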
Example #12
def main(_):
    if FLAGS.jax_backend_target:
        logging.info("Using JAX backend target %s", FLAGS.jax_backend_target)
        jax_config.update("jax_xla_backend", "tpu_driver")
        jax_config.update("jax_backend_target", FLAGS.jax_backend_target)

    logging.info("JAX host: %d / %d", jax.host_id(), jax.host_count())
    logging.info("JAX local devices: %r", jax.local_devices())

    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(FLAGS.model_dir)
        # summary_writer.hparams(dict(FLAGS.config))

    rng = random.PRNGKey(FLAGS.seed)
    rng, init_rng_coarse, init_rng_fine = random.split(rng, 3)
    n_devices = jax.device_count()

    ### Load dataset and data values
    if FLAGS.config.dataset_type == "blender":
        images, poses, render_poses, hwf, counts = load_blender.load_data(
            FLAGS.data_dir,
            half_res=FLAGS.config.half_res,
            testskip=FLAGS.config.testskip,
        )
        logging.info("Loaded blender, total images: %d", images.shape[0])

        near = 2.0
        far = 6.0

        if FLAGS.config.white_bkgd:
            images = images[..., :3] * images[..., -1:] + (1.0 - images[..., -1:])
        else:
            images = images[..., :3]

    elif FLAGS.config.dataset_type == "deepvoxels":
        images, poses, render_poses, hwf, counts = load_deepvoxels.load_dv_data(
            FLAGS.data_dir,
            scene=FLAGS.config.shape,
            testskip=FLAGS.config.testskip,
        )
        hemi_R = np.mean(np.linalg.norm(poses[:, :3, -1], axis=-1))
        near = hemi_R - 1.0
        far = hemi_R + 1.0
        logging.info(
            "Loaded deepvoxels (%s), total images: %d",
            FLAGS.config.shape,
            images.shape[0],
        )
    else:
        raise ValueError(f"Dataset '{FLAGS.config.dataset_type}' is not available.")

    img_h, img_w, focal = hwf
    logging.info("Images splits: %s", counts)
    logging.info("Render poses: %s", render_poses.shape)
    logging.info("Image height: %d, image width: %d, focal: %.5f", img_h, img_w, focal)

    train_imgs, val_imgs, test_imgs, *_ = np.split(images, np.cumsum(counts))
    train_poses, val_poses, test_poses, *_ = np.split(poses, np.cumsum(counts))

    if FLAGS.config.render_factor > 0:
        # render downsampled for speed
        r_img_h = img_h // FLAGS.config.render_factor
        r_img_w = img_w // FLAGS.config.render_factor
        r_focal = focal / FLAGS.config.render_factor
        r_hwf = r_img_h, r_img_w, r_focal
    else:
        r_hwf = hwf

    to_np = lambda x, h=img_h, w=img_w: np.reshape(x, [h, w, -1]).astype(np.float32)
    psnr_fn = lambda x: -10.0 * np.log(x) / np.log(10.0)

    ### Pre-compute rays
    @functools.partial(jax.jit, static_argnums=(0,))
    def prep_rays(hwf, c2w, c2w_sc=None):
        if c2w_sc is not None:
            c2w_sc = c2w_sc[:3, :4]
        return prepare_rays(None, hwf, FLAGS.config, near, far, c2w[:3, :4], c2w_sc)

    rays_render = lax.map(lambda x: prep_rays(r_hwf, x), render_poses)
    render_shape = [-1, n_devices, r_hwf[1], rays_render.shape[-1]]
    rays_render = jnp.reshape(rays_render, render_shape)
    logging.info("Render rays shape: %s", rays_render.shape)

    if FLAGS.config.use_viewdirs:
        rays_render_vdirs = lax.map(
            lambda x: prep_rays(r_hwf, x, render_poses[0]), render_poses
        ).reshape(render_shape)

    if FLAGS.config.batching:
        train_rays = lax.map(lambda pose: prep_rays(hwf, pose), train_poses)
        train_rays = jnp.reshape(train_rays, [-1, train_rays.shape[-1]])
        train_imgs = jnp.reshape(train_imgs, [-1, 3])
        logging.info("Batched rays shape: %s", train_rays.shape)
        val_rays = lax.map(lambda pose: prep_rays(hwf, pose), val_poses)

    test_rays = lax.map(lambda pose: prep_rays(r_hwf, pose), test_poses)
    test_rays = jnp.reshape(test_rays, render_shape)

    ### Init model parameters and optimizer
    input_pts_shape = (FLAGS.config.num_rand, FLAGS.config.num_samples, 3)
    input_views_shape = (FLAGS.config.num_rand, 3)
    model_coarse, params_coarse = initialized(
        init_rng_coarse, input_pts_shape, input_views_shape, FLAGS.config.model
    )

    optimizer = optim.Adam()
    state = TrainState(
        step=0, optimizer_coarse=optimizer.create(params_coarse), optimizer_fine=None
    )
    model_fn = (model_coarse.apply, None)

    if FLAGS.config.num_importance > 0:
        input_pts_shape = (
            FLAGS.config.num_rand,
            FLAGS.config.num_importance + FLAGS.config.num_samples,
            3,
        )
        model_fine, params_fine = initialized(
            init_rng_fine, input_pts_shape, input_views_shape, FLAGS.config.model_fine
        )
        state = state.replace(optimizer_fine=optimizer.create(params_fine))
        model_fn = (model_coarse.apply, model_fine.apply)

    state = checkpoints.restore_checkpoint(FLAGS.model_dir, state)
    start_step = int(state.step)
    state = jax_utils.replicate(state)

    ### Build 'pmapped' functions for distributed training
    learning_rate_fn = create_learning_rate_scheduler(
        factors=FLAGS.config.lr_schedule,
        base_learning_rate=FLAGS.config.learning_rate,
        decay_factor=FLAGS.config.decay_factor,
        steps_per_decay=FLAGS.config.lr_decay * 1000,
    )
    p_train_step = jax.pmap(
        functools.partial(
            train_step,
            model_fn,
            FLAGS.config,
            learning_rate_fn,
            (hwf, near, far),
        ),
        axis_name="batch",
        donate_argnums=(0,),
    )
    p_eval_step = jax.pmap(
        functools.partial(eval_step, model_fn, FLAGS.config),
        axis_name="batch",
    )

    t = time.time()
    train_metrics = []

    for step in range(start_step, FLAGS.config.num_steps + 1):
        rng, sample_rng, step_rng, test_rng = random.split(rng, 4)
        sharded_rngs = common_utils.shard_prng_key(step_rng)
        coords = None

        if FLAGS.config.batching:
            select_idx = random.randint(
                sample_rng,
                [n_devices * FLAGS.config.num_rand],
                minval=0,
                maxval=train_rays.shape[0],
            )
            inputs = train_rays[select_idx, ...]
            inputs = jnp.reshape(inputs, [n_devices, FLAGS.config.num_rand, -1])
            target = train_imgs[select_idx, ...]
            target = jnp.reshape(target, [n_devices, FLAGS.config.num_rand, 3])
        else:
            img_idx = random.randint(
                sample_rng, [n_devices], minval=0, maxval=counts[0]
            )
            inputs = train_poses[img_idx, ...]  # [n_devices, 4, 4]
            target = train_imgs[img_idx, ...]  # [n_devices, img_h, img_w, 3]

            if step < FLAGS.config.precrop_iters:
                dH = int(img_h // 2 * FLAGS.config.precrop_frac)
                dW = int(img_w // 2 * FLAGS.config.precrop_frac)
                coords = jnp.meshgrid(
                    jnp.arange(img_h // 2 - dH, img_h // 2 + dH),
                    jnp.arange(img_w // 2 - dW, img_w // 2 + dW),
                    indexing="ij",
                )
                coords = jax_utils.replicate(
                    jnp.stack(coords, axis=-1).reshape([-1, 2])
                )

        state, metrics, coarse_res, fine_res = p_train_step(
            state, (inputs, target), coords, rng=sharded_rngs
        )
        train_metrics.append(metrics)

        ### Write summaries to TB
        if step % FLAGS.config.i_print == 0 and step > 0:
            steps_per_sec = time.time() - t
            train_metrics = common_utils.get_metrics(train_metrics)
            train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)
            if jax.host_id() == 0:
                logging.info(
                    "Step: %6d, %.3f s/step, loss %.5f, psnr %6.3f",
                    step,
                    steps_per_sec,
                    train_summary["loss"],
                    train_summary["psnr"],
                )
                for key, val in train_summary.items():
                    summary_writer.scalar(f"train/{key}", val, step)

                summary_writer.scalar("steps per second", steps_per_sec, step)
                summary_writer.histogram("raw_c", np.array(coarse_res["raw"]), step)
                if FLAGS.config.num_importance > 0:
                    summary_writer.histogram("raw_f", np.array(fine_res["raw"]), step)
            train_metrics = []

            ### Eval a random validation image and plot it in TB
            if step % FLAGS.config.i_img == 0:
                val_idx = random.randint(test_rng, [1], minval=0, maxval=counts[1])
                if FLAGS.config.batching:
                    inputs = val_rays[tuple(val_idx)].reshape(render_shape)
                else:
                    inputs = prep_rays(hwf, val_poses[tuple(val_idx)])
                    inputs = jnp.reshape(inputs, render_shape)
                target = val_imgs[tuple(val_idx)]
                preds, preds_c, z_std = lax.map(lambda x: p_eval_step(state, x), inputs)
                rgb = to_np(preds["rgb"])
                loss = np.mean((rgb - target) ** 2)

                summary_writer.scalar(f"val/loss", loss, step)
                summary_writer.scalar(f"val/psnr", psnr_fn(loss), step)

                rgb = 255 * np.clip(rgb, 0, 1)
                summary_writer.image("val/rgb", rgb.astype(np.uint8), step)
                summary_writer.image("val/target", target, step)
                summary_writer.image("val/disp", to_np(preds["disp"]), step)
                summary_writer.image("val/acc", to_np(preds["acc"]), step)

                if FLAGS.config.num_importance > 0:
                    rgb = 255 * np.clip(to_np(preds_c["rgb"]), 0, 1)
                    summary_writer.image("val/rgb_c", rgb.astype(np.uint8), step)
                    summary_writer.image("val/disp_c", to_np(preds_c["disp"]), step)
                    summary_writer.image("val/z_std", to_np(z_std), step)

        ### Render a video with test poses
        if step % FLAGS.config.i_video == 0 and step > 0:
            logging.info("Rendering video at step %d", step)
            t = time.time()
            preds, *_ = lax.map(lambda x: p_eval_step(state, x), rays_render)
            gen_video(preds["rgb"], "rgb", r_hwf, step)
            gen_video(preds["disp"] / jnp.max(preds["disp"]), "disp", r_hwf, step, ch=1)

            if FLAGS.config.use_viewdirs:
                preds = lax.map(
                    lambda x: p_eval_step(state, x)[0]["rgb"], rays_render_vdirs
                )
                gen_video(preds, "rgb_still", r_hwf, step)
            logging.info("Video rendering done in %ds", time.time() - t)

        ### Save images in the test set
        if step % FLAGS.config.i_testset == 0 and step > 0:
            logging.info("Rendering test set at step %d", step)
            preds = lax.map(lambda x: p_eval_step(state, x)[0]["rgb"], test_rays)
            save_test_imgs(preds, r_hwf, step)

            if FLAGS.config.render_factor == 0:
                loss = np.mean((preds.reshape(test_imgs.shape) - test_imgs) ** 2.0)
                summary_writer.scalar(f"test/loss", loss, step)
                summary_writer.scalar(f"test/psnr", psnr_fn(loss), step)

        ### Save ckpt
        if step % FLAGS.config.i_weights == 0 and step > 0:
            if jax.host_id() == 0:
                checkpoints.save_checkpoint(
                    FLAGS.model_dir,
                    jax_utils.unreplicate(state),
                    step,
                    keep=5,
                )
        t = time.time()
Example #13
def train(config, model_def, device_batch_size, eval_ds, num_steps,
          steps_per_epoch, steps_per_eval, train_ds, image_size, data_source,
          workdir):
  """Train model."""

  make_lr_fn = schedulers.get_make_lr_fn(config)
  make_temp_fn = schedulers.get_make_temp_fn(config)
  make_step_size_fn = schedulers.get_make_step_size_fn(config)
  if jax.host_count() > 1:
    raise ValueError('CIFAR10 example should not be run on '
                     'more than 1 host due to preconditioner updating.')

  initial_step = 0  # TODO(basv): load from checkpoint.
  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.host_id() > 0)

  # Write config to the summary files. This makes the hyperparameters available
  # in TensorBoard and makes comparison of runs in TensorBoard easier.
  # with writer.summary_writer.as_default():
  writer.write_hparams(dict(config))

  rng = random.PRNGKey(config.seed)
  rng, opt_rng, init_key, sampler_rng = jax.random.split(rng, 4)

  base_learning_rate = config.learning_rate

  # Create the model.
  model, state = create_model(rng, device_batch_size, image_size, model_def)
  parameter_overview.log_parameter_overview(model.params)
  state = jax_utils.replicate(state)

  train_size = data_source.TRAIN_IMAGES

  with flax.deprecated.nn.stochastic(init_key):
    optimizer = create_optimizer(config, model, base_learning_rate, train_size,
                                 sampler_rng)
  del model  # Don't keep a copy of the initial model.

  # Learning rate schedule
  learning_rate_fn = make_lr_fn(base_learning_rate, steps_per_epoch)
  temperature_fn = make_temp_fn(config.base_temp, steps_per_epoch)
  step_size_fn = make_step_size_fn(steps_per_epoch)

  p_eval_step, _, p_train_step, p_update_grad_vars = make_step_functions(
      config, config.l2_reg, learning_rate_fn, train_size, temperature_fn,
      step_size_fn)

  # Create dataset batch iterators.
  train_iter = iter(train_ds)
  eval_iter = iter(eval_ds)

  # Gather metrics.
  train_metrics = []
  epoch = 0

  # Ensemble.
  ensemble = []
  ensemble_logits = []
  ensemble_labels = []
  ensemble_probs = []

  def ensemble_add_step(step):
    if config.lr_schedule == 'cosine':
      # Add if learning rate jumps up again in the next step.
      increase = step_size_fn(step) < step_size_fn(step + 1) - 1e-8
      _, temp_end = ast.literal_eval(config.temp_ramp)
      past_burn_in = step >= steps_per_epoch * temp_end
      return increase and past_burn_in

    elif config.lr_schedule == 'constant':
      if (step + 1) % steps_per_epoch == 0:
        return True
    return False

  logging.info('Starting training loop at step %d.', initial_step)

  for step in range(initial_step, num_steps):
    if config.optimizer in ['sym_euler'] and (step) % steps_per_epoch == 0:
      optimizer, rng = update_preconditioner(config, optimizer,
                                             p_update_grad_vars, rng, state,
                                             train_iter)
    # Generate a PRNG key that will be rolled into the batch
    step_key = jax.random.fold_in(rng, step)
    opt_step_rng = jax.random.fold_in(opt_rng, step)

    # Load and shard the TF batch
    batch = next(train_iter)
    batch = input_pipeline.load_and_shard_tf_batch(config, batch)
    if not config.debug_run:
      # Shard the step PRNG key
      # Don't shard the optimizer rng, as it should be equal among all machines.
      sharded_keys = common_utils.shard_prng_key(step_key)
    else:
      sharded_keys = step_key

    # Train step
    optimizer, state, metrics = p_train_step(optimizer, state, batch,
                                             sharded_keys, opt_step_rng)
    train_metrics.append(metrics)

    # Quick indication that training is happening.
    logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step)
    if step == initial_step:
      initial_train_metrics = get_metrics(config, train_metrics)
      train_summary = jax.tree_map(lambda x: x.mean(), initial_train_metrics)
      train_summary = {'train_' + k: v for k, v in train_summary.items()}
      logging.log(logging.INFO, 'initial metrics = %s',
                  str(train_summary.items()))

    if (step + 1) % steps_per_epoch == 0:
      # We've finished an epoch
      # Save model params/state.

      train_metrics = get_metrics(config, train_metrics)
      # Get training epoch summary for logging
      train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)

      train_summary = {'train_' + k: v for k, v in train_summary.items()}

      writer.write_scalars(epoch, train_summary)
      # Reset train metrics
      train_metrics = []

      # Evaluation
      if config.do_eval:
        eval_metrics = []
        eval_logits = []
        eval_labels = []
        for _ in range(steps_per_eval):
          eval_batch = next(eval_iter)
          # Load and shard the TF batch
          eval_batch = input_pipeline.load_and_shard_tf_batch(
              config, eval_batch)
          # Step
          logits, labels, metrics = p_eval_step(optimizer.target, state,
                                                eval_batch)
          eval_metrics.append(metrics)
          eval_logits.append(logits)
          eval_labels.append(labels)
        eval_metrics = get_metrics(config, eval_metrics)
        # Get eval epoch summary for logging
        eval_summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
        eval_summary = {'eval_' + k: v for k, v in eval_summary.items()}
        writer.write_scalars(epoch, eval_summary)

      if config.algorithm == 'sgmcmc' and ensemble_add_step(step):
        ensemble.append((serialization.to_state_dict(optimizer.target), state))

      if config.algorithm == 'sgmcmc' and ensemble_add_step(
          step) and len(ensemble) >= 1:
        # Gather predictions for this ensemble sample.
        eval_logits = jnp.concatenate(eval_logits, axis=0)
        eval_probs = jax.nn.softmax(eval_logits, axis=-1)
        eval_labels = jnp.concatenate(eval_labels, axis=0)
        # Ensure that labels are consistent between predict runs.
        if ensemble_labels:
          assert jnp.allclose(
              eval_labels,
              ensemble_labels[0]), 'Labels unordered between eval runs.'

        ensemble_logits.append(eval_logits)
        ensemble_probs.append(eval_probs)
        ensemble_labels.append(eval_labels)

        # Compute ensemble predictions over last config.ensemble_size samples.
        ensemble_last_probs = jnp.mean(
            jnp.array(ensemble_probs[-config.ensemble_size:]), axis=0)
        ensemble_metrics = train_functions.compute_metrics_probs(
            ensemble_last_probs, ensemble_labels[0])
        ensemble_summary = jax.tree_map(lambda x: x.mean(), ensemble_metrics)
        ensemble_summary = {'ens_' + k: v for k, v in ensemble_summary.items()}
        ensemble_summary['ensemble_size'] = min(config.ensemble_size,
                                                len(ensemble_probs))
        writer.write_scalars(epoch, ensemble_summary)

      epoch += 1

  return ensemble, optimizer
Example #14
def train_and_evaluate(config: ml_collections.ConfigDict, resume: str):
    """Execute model training and evaluation loop.

    Args:
      config: Hyperparameter configuration for training and evaluation.
      resume: Resume from checkpoints at specified dir if set (TODO: support specific checkpoint file/step)
    """
    rng = random.PRNGKey(42)

    if config.batch_size % jax.device_count() > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')
    local_batch_size = config.batch_size // jax.host_count()
    config.eval_batch_size = config.eval_batch_size or config.batch_size
    if config.eval_batch_size % jax.device_count() > 0:
        raise ValueError(
            'Validation batch size must be divisible by the number of devices')
    local_eval_batch_size = config.eval_batch_size // jax.host_count()

    platform = jax.local_devices()[0].platform
    half_prec = config.half_precision
    if half_prec:
        if platform == 'tpu':
            model_dtype = jnp.bfloat16
        else:
            model_dtype = jnp.float16
    else:
        model_dtype = jnp.float32

    rng, model_create_rng = random.split(rng)
    model, variables = create_model(config.model,
                                    dtype=model_dtype,
                                    drop_rate=config.drop_rate,
                                    drop_path_rate=config.drop_path_rate,
                                    rng=model_create_rng)
    image_size = config.image_size or model.default_cfg['input_size'][-1]

    dataset_builder = tfds.builder(config.dataset, data_dir=config.data_dir)

    train_iter = create_input_iter(
        dataset_builder,
        local_batch_size,
        train=True,
        image_size=image_size,
        augment_name=config.autoaugment,
        randaug_magnitude=config.randaug_magnitude,
        randaug_num_layers=config.randaug_num_layers,
        half_precision=half_prec,
        cache=config.cache)

    eval_iter = create_input_iter(dataset_builder,
                                  local_eval_batch_size,
                                  train=False,
                                  image_size=image_size,
                                  half_precision=half_prec,
                                  cache=config.cache)

    steps_per_epoch = dataset_builder.info.splits[
        'train'].num_examples // config.batch_size

    if config.num_train_steps == -1:
        num_steps = steps_per_epoch * config.num_epochs
    else:
        num_steps = config.num_train_steps

    if config.steps_per_eval == -1:
        num_validation_examples = dataset_builder.info.splits[
            'validation'].num_examples
        steps_per_eval = num_validation_examples // config.eval_batch_size
    else:
        steps_per_eval = config.steps_per_eval

    steps_per_checkpoint = steps_per_epoch * 1

    # Linear scaling rule: scale the base learning rate with the global batch
    # size, using 256 as the reference batch size.
    base_lr = config.lr * config.batch_size / 256.
    lr_fn = create_lr_schedule_epochs(base_lr,
                                      config.lr_schedule,
                                      steps_per_epoch=steps_per_epoch,
                                      total_epochs=config.num_epochs,
                                      decay_rate=config.lr_decay_rate,
                                      decay_epochs=config.lr_decay_epochs,
                                      warmup_epochs=config.lr_warmup_epochs,
                                      min_lr=config.lr_minimum)

    state = create_train_state(config, variables, lr_fn)
    if resume:
        state = restore_checkpoint(state, resume)
    # step_offset > 0 if restarting from checkpoint
    step_offset = int(state.step)
    state = flax.jax_utils.replicate(state)

    p_train_step = jax.pmap(functools.partial(
        train_step,
        model.apply,
        lr_fn=lr_fn,
        label_smoothing=config.label_smoothing,
        weight_decay=config.weight_decay),
                            axis_name='batch')
    p_eval_step = jax.pmap(functools.partial(eval_step, model.apply),
                           axis_name='batch')
    p_eval_step_ema = None
    if config.ema_decay != 0.:
        p_eval_step_ema = jax.pmap(functools.partial(eval_step_ema,
                                                     model.apply),
                                   axis_name='batch')

    if jax.host_id() == 0:
        if resume and step_offset > 0:
            output_dir = resume
        else:
            output_base = config.output_base_dir if config.output_base_dir else './output'
            exp_name = '-'.join(
                [datetime.now().strftime("%Y%m%d-%H%M%S"), config.model])
            output_dir = get_outdir(output_base, exp_name)
        summary_writer = tensorboard.SummaryWriter(output_dir)
        summary_writer.hparams(dict(config))

    epoch_metrics = []
    t_loop_start = time.time()
    num_samples = 0
    for step, batch in zip(range(step_offset, num_steps), train_iter):
        step_p1 = step + 1
        rng, step_rng = random.split(rng)
        sharded_rng = common_utils.shard_prng_key(step_rng)

        num_samples += config.batch_size
        state, metrics = p_train_step(state, batch, dropout_rng=sharded_rng)
        epoch_metrics.append(metrics)

        if step_p1 % steps_per_epoch == 0:
            epoch = step // steps_per_epoch
            epoch_metrics = common_utils.get_metrics(epoch_metrics)
            summary = jax.tree_map(lambda x: x.mean(), epoch_metrics)
            samples_per_sec = num_samples / (time.time() - t_loop_start)
            logging.info(
                'train epoch: %d, loss: %.4f, img/sec %.2f, top1: %.2f, top5: %.3f',
                epoch, summary['loss'], samples_per_sec, summary['top1'],
                summary['top5'])

            if jax.host_id() == 0:
                for key, vals in epoch_metrics.items():
                    tag = 'train_%s' % key
                    for i, val in enumerate(vals):
                        summary_writer.scalar(tag, val,
                                              step_p1 - len(vals) + i)
                summary_writer.scalar('samples per second', samples_per_sec,
                                      step)
            epoch_metrics = []
            state = sync_batch_stats(
                state)  # sync batch statistics across replicas

            eval_metrics = []
            for step_eval in range(steps_per_eval):
                eval_batch = next(eval_iter)
                metrics = p_eval_step(state, eval_batch)
                eval_metrics.append(metrics)

            eval_metrics = common_utils.get_metrics(eval_metrics)
            summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
            logging.info('eval epoch: %d, loss: %.4f, top1: %.2f, top5: %.3f',
                         epoch, summary['loss'], summary['top1'],
                         summary['top5'])

            if p_eval_step_ema is not None:
                # NOTE: both EMA and non-EMA eval are run while this script is
                # being improved.
                eval_metrics = []
                for step_eval in range(steps_per_eval):
                    eval_batch = next(eval_iter)
                    metrics = p_eval_step_ema(state, eval_batch)
                    eval_metrics.append(metrics)

                eval_metrics = common_utils.get_metrics(eval_metrics)
                summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
                logging.info(
                    'eval epoch ema: %d, loss: %.4f, top1: %.2f, top5: %.3f',
                    epoch, summary['loss'], summary['top1'], summary['top5'])

            if jax.host_id() == 0:
                for key, val in eval_metrics.items():
                    tag = 'eval_%s' % key
                    summary_writer.scalar(tag, val.mean(), step)
                summary_writer.flush()
            t_loop_start = time.time()
            num_samples = 0

        elif step_p1 % 100 == 0:
            summary = jax.tree_map(lambda x: x.mean(),
                                   common_utils.get_metrics(epoch_metrics))
            samples_per_sec = num_samples / (time.time() - t_loop_start)
            logging.info('train steps: %d, loss: %.4f, img/sec: %.2f', step_p1,
                         summary['loss'], samples_per_sec)

        if step_p1 % steps_per_checkpoint == 0 or step_p1 == num_steps:
            state = sync_batch_stats(state)
            save_checkpoint(state, output_dir)

    # Wait until computations are done before exiting
    jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
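
For reference, the sync_batch_stats call in the loop above is commonly implemented as a pmapped cross-replica mean over the BatchNorm statistics. A minimal sketch, assuming the replicated train state exposes a batch_stats collection and a dataclass-style replace method:

import jax

# Average each leaf of the batch statistics across the 'x' device axis.
cross_replica_mean = jax.pmap(lambda x: jax.lax.pmean(x, 'x'), 'x')

def sync_batch_stats(state):
    # state is the replicated train state; batch_stats holds the BatchNorm
    # running statistics with a leading device axis.
    return state.replace(batch_stats=cross_replica_mean(state.batch_stats))
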
Beispiel #15
0
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
    """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoint training will be resumed from the latest checkpoint.
  """
    tf.io.gfile.makedirs(workdir)

    batch_size = config.batch_size
    n_devices = jax.device_count()
    if jax.host_count() > 1:
        raise ValueError(
            'PixelCNN++ example should not be run on more than 1 host'
            ' (for now)')
    if batch_size % n_devices > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')

    train_summary_writer, eval_summary_writer = get_summary_writers(workdir)
    # Load dataset
    data_source = input_pipeline.DataSource(config)
    train_ds = data_source.train_ds
    eval_ds = data_source.eval_ds
    steps_per_epoch = data_source.ds_info.splits[
        'train'].num_examples // config.batch_size
    # Create dataset batch iterators
    train_iter = iter(train_ds)
    num_train_steps = train_ds.cardinality().numpy()
    steps_per_checkpoint = 1000

    # Create the model using data-dependent initialization. Don't shard the init
    # batch.
    assert config.init_batch_size <= batch_size
    init_batch = next(train_iter)['image']._numpy()[:config.init_batch_size]

    rng = jax.random.PRNGKey(config.seed)
    rng, init_rng, dropout_rng = jax.random.split(rng, 3)

    initial_variables = model(config).init(
        {
            'params': init_rng,
            'dropout': dropout_rng
        }, init_batch)['params']
    optimizer_def = optim.Adam(beta1=0.95, beta2=0.9995)
    optimizer = optimizer_def.create(initial_variables)

    # Restore the optimizer and EMA parameters from the latest checkpoint if
    # one exists; otherwise both keep their freshly initialized values.
    optimizer, ema = restore_checkpoint(workdir, optimizer, initial_variables)
    step_offset = int(optimizer.state.step)

    optimizer, ema = jax_utils.replicate((optimizer, ema))

    # Learning rate schedule
    learning_rate_fn = lambda step: config.learning_rate * config.lr_decay**step

    # pmap the train and eval functions
    p_train_step = jax.pmap(functools.partial(train_step, config,
                                              learning_rate_fn),
                            axis_name='batch')
    p_eval_step = jax.pmap(functools.partial(eval_step, config=config),
                           axis_name='batch')

    # Gather metrics
    train_metrics = []

    for step, batch in zip(range(step_offset, num_train_steps), train_iter):
        # Load and shard the TF batch
        batch = load_and_shard_tf_batch(batch)

        # Generate a PRNG key that will be rolled into the batch.
        rng, step_rng = jax.random.split(rng)
        sharded_rngs = common_utils.shard_prng_key(step_rng)

        # Train step
        optimizer, ema, metrics = p_train_step(optimizer, ema, batch,
                                               sharded_rngs)
        train_metrics.append(metrics)

        # Quick indication that training is happening.
        logging.log_first_n(logging.INFO, 'Finished training step %d.', 5,
                            step)

        if (step + 1) % steps_per_epoch == 0:
            epoch = step // steps_per_epoch
            # We've finished an epoch
            train_metrics = common_utils.get_metrics(train_metrics)
            # Get training epoch summary for logging
            train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)
            # Send stats to Tensorboard
            for key, vals in train_metrics.items():
                for i, val in enumerate(vals):
                    train_summary_writer.scalar(key, val,
                                                step - len(vals) + i + 1)
            # Reset train metrics
            train_metrics = []

            # Evaluation
            eval_metrics = []
            for eval_batch in eval_ds:
                # Load and shard the TF batch
                eval_batch = load_and_shard_tf_batch(eval_batch)
                # Step
                metrics = p_eval_step(ema, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            # Get eval epoch summary for logging
            eval_summary = jax.tree_map(lambda x: x.mean(), eval_metrics)

            # Log epoch summary
            logging.info('Epoch %d: TRAIN loss=%.6f, EVAL loss=%.6f', epoch,
                         train_summary['loss'], eval_summary['loss'])

            eval_summary_writer.scalar('loss', eval_summary['loss'], step)
            train_summary_writer.flush()
            eval_summary_writer.flush()

        if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_train_steps:
            save_checkpoint(workdir, optimizer, ema, step)
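
The load_and_shard_tf_batch helper used in this loop typically converts the TF batch to numpy and splits the leading batch dimension across the local devices so it can be fed to the pmapped step functions. A minimal sketch under that assumption:

import jax

def load_and_shard_tf_batch(xs):
    local_device_count = jax.local_device_count()

    def _prepare(x):
        # tf.Tensor -> np.ndarray, then reshape [batch, ...] to
        # [local_device_count, batch // local_device_count, ...].
        x = x._numpy()
        return x.reshape((local_device_count, -1) + x.shape[1:])

    return jax.tree_map(_prepare, xs)
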
Beispiel #16
0
    def replicate(self):
        return jax_utils.replicate(self).replace(
            dropout_rng=shard_prng_key(self.dropout_rng))
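
This method pairs flax.jax_utils.replicate, which copies the train state onto every local device, with shard_prng_key, which hands each device its own dropout key so replicas do not draw identical dropout masks. The underlying key sharding, spelled out on a bare PRNG key:

import jax
from flax.training.common_utils import shard_prng_key

rng = jax.random.PRNGKey(0)
per_device_rngs = shard_prng_key(rng)  # one key per local device
assert per_device_rngs.shape[0] == jax.local_device_count()
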
Beispiel #17
0
def main(_):
    if FLAGS.config.precrop_iters > 0 and FLAGS.config.batching:
        raise ValueError(
            "'precrop_iters has no effect when 'batching' the dataset")
    assert FLAGS.config.down_factor > 0 and FLAGS.config.render_factor > 0

    logging.info("JAX host: %d / %d", jax.process_index(), jax.host_count())
    logging.info("JAX local devices: %r", jax.local_devices())

    platform.work_unit().set_task_status(
        f"host_id: {jax.process_index()}, host_count: {jax.host_count()}")
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                         FLAGS.model_dir, "model_dir")

    os.makedirs(FLAGS.model_dir, exist_ok=True)
    rng = jax.random.PRNGKey(FLAGS.seed)
    rng, rng_coarse, rng_fine, data_rng, step_rng = jax.random.split(rng, 5)
    rngs = common_utils.shard_prng_key(step_rng)

    ### Load dataset and data values
    datasets, counts, optics, render_datasets = get_dataset(
        FLAGS.data_dir,
        FLAGS.config,
        rng=data_rng,
        num_poses=FLAGS.config.num_poses)
    train_ds, val_ds, test_ds = datasets
    *_, test_items = counts
    hwf, r_hwf, near, far = optics
    render_ds, render_vdirs_ds, num_poses = render_datasets
    iter_render_ds = zip(range(num_poses), render_ds)
    iter_vdirs_ds = zip(range(num_poses), render_vdirs_ds)
    iter_test_ds = zip(range(test_items), test_ds)
    img_h, img_w, _ = hwf

    logging.info("Num poses: %d", num_poses)
    logging.info("Splits: train - %d, val - %d, test - %d", *counts)
    logging.info("Images: height %d, width %d, focal %.5f", *hwf)
    logging.info("Render: height %d, width %d, focal %.5f", *r_hwf)

    ### Init model parameters and optimizer
    initialized_ = functools.partial(initialized,
                                     model_config=FLAGS.config.model)
    pts_shape = (FLAGS.config.num_rand, FLAGS.config.num_samples, 3)
    views_shape = (FLAGS.config.num_rand, 3)
    model_coarse, params_coarse = initialized_(rng_coarse, pts_shape,
                                               views_shape)

    schedule_fn = optax.exponential_decay(
        init_value=FLAGS.config.learning_rate,
        transition_steps=FLAGS.config.lr_decay * 1000,
        decay_rate=FLAGS.config.decay_factor,
    )
    tx = optax.adam(learning_rate=schedule_fn)
    state = train_state.TrainState.create(apply_fn=(model_coarse.apply, None),
                                          params={"coarse": params_coarse},
                                          tx=tx)

    if FLAGS.config.num_importance > 0:
        pts_shape = (
            FLAGS.config.num_rand,
            FLAGS.config.num_importance + FLAGS.config.num_samples,
            3,
        )
        model_fine, params_fine = initialized_(rng_fine, pts_shape,
                                               views_shape)
        state = train_state.TrainState.create(
            apply_fn=(model_coarse.apply, model_fine.apply),
            params={
                "coarse": params_coarse,
                "fine": params_fine
            },
            tx=tx,
        )

    state = checkpoints.restore_checkpoint(FLAGS.model_dir, state)
    start_step = int(state.step)

    # cycle already seen examples if resuming from checkpoint
    # (only useful for ensuring deterministic dataset, slow for large start_step)
    # if start_step != 0:
    #     for _ in range(start_step):
    #         _ = next(train_ds)

    # parameter_overview.log_parameter_overview(state.optimizer_coarse.target)
    # if FLAGS.config.num_importance > 0:
    #     parameter_overview.log_parameter_overview(state.optimizer_fine.target)

    state = jax.device_put_replicated(state, jax.local_devices())

    ### Build "pmapped" functions for distributed training
    train_fn = functools.partial(train_step, near, far, FLAGS.config,
                                 schedule_fn)
    p_train_step = jax.pmap(
        train_fn,
        axis_name="batch",
        in_axes=(0, 0, None, 0),
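        # Mapped over the leading device axis for (batch, state, rngs);
        # coords (None) is broadcast to every device unchanged.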
        # donate_argnums=(0, 1, 2),
    )

    def render_fn(state, rays):
        step_fn = functools.partial(eval_step, FLAGS.config, near, far, state)
        return lax.map(step_fn, rays)

    p_eval_step = jax.pmap(
        render_fn,
        axis_name="batch",
        # in_axes=(0, 0, None),
        # donate_argnums=(0, 1))
    )

    # TODO: add hparams
    writer = metric_writers.create_default_writer(
        FLAGS.model_dir, just_logging=jax.process_index() > 0)
    logging.info("Starting training loop.")

    hooks = []
    profiler = periodic_actions.Profile(num_profile_steps=5,
                                        logdir=FLAGS.model_dir)
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=FLAGS.config.num_steps, writer=writer)
    if jax.process_index() == 0:
        hooks += [profiler, report_progress]
    train_metrics = []
    gen_video_ = functools.partial(gen_video, FLAGS.model_dir)

    for step in range(start_step, FLAGS.config.num_steps + 1):
        is_last_step = step == FLAGS.config.num_steps

        batch = next(train_ds)
        coords = None
        if not FLAGS.config.batching:
            coords = jnp.meshgrid(jnp.arange(img_h),
                                  jnp.arange(img_w),
                                  indexing="ij")
            if step < FLAGS.config.precrop_iters:
                dH = int(img_h // 2 * FLAGS.config.precrop_frac)
                dW = int(img_w // 2 * FLAGS.config.precrop_frac)
                coords = jnp.meshgrid(
                    jnp.arange(img_h // 2 - dH, img_h // 2 + dH),
                    jnp.arange(img_w // 2 - dW, img_w // 2 + dW),
                    indexing="ij",
                )
            coords = jnp.stack(coords, axis=-1).reshape([-1, 2])

        with jax.profiler.StepTraceAnnotation("train", step_num=step):
            state, metrics = p_train_step(batch, state, coords, rngs)
        train_metrics.append(metrics)

        logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                            step)
        _ = [h(step) for h in hooks]

        ### Write train summaries to TB
        if step % FLAGS.config.i_print == 0 or is_last_step:
            with report_progress.timed("training_metrics"):
                train_metrics = common_utils.get_metrics(train_metrics)
                train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)
                summary = {f"train/{k}": v for k, v in train_summary.items()}
                writer.write_scalars(step, summary)
            train_metrics = []

        ### Eval a random validation image and plot it to TB
        if step % FLAGS.config.i_img == 0 and step > 0 or is_last_step:
            with report_progress.timed("validation"):
                inputs = next(val_ds)
                rays, padding = prepare_render_data(inputs["rays"]._numpy())
                outputs = p_eval_step(state, rays)
                preds, preds_c, z_std = jax.tree_map(
                    lambda x: to_np(x, hwf, padding), outputs)
                loss = np.mean((preds["rgb"] - inputs["image"])**2)
                summary = {"val/loss": loss, "val/psnr": psnr_fn(loss)}
                writer.write_scalars(step, summary)

                summary = {
                    "val/rgb": to_rgb(preds["rgb"]),
                    "val/target": to_np(inputs["image"], hwf, padding),
                    "val/disp": disp_post(preds["disp"], FLAGS.config),
                    "val/acc": preds["acc"],
                }
                if FLAGS.config.num_importance > 0:
                    summary["val/rgb_c"] = to_rgb(preds_c["rgb"])
                    summary["val/disp_c"] = disp_post(preds_c["disp"],
                                                      FLAGS.config)
                    summary["val/z_std"] = z_std
                writer.write_images(step, summary)

        ### Render a video with test poses
        if step % FLAGS.config.i_video == 0 and step > 0:
            with report_progress.timed("video_render"):
                logging.info("Rendering video at step %d", step)
                rgb_list = []
                disp_list = []
                for idx, inputs in tqdm(iter_render_ds, desc="Rays render"):
                    rays, padding = prepare_render_data(inputs["rays"].numpy())
                    preds, *_ = p_eval_step(state, rays)
                    preds = jax.tree_map(lambda x: to_np(x, r_hwf, padding),
                                         preds)
                    rgb_list.append(preds["rgb"])
                    disp_list.append(preds["disp"])

                gen_video_(np.stack(rgb_list), "rgb", r_hwf, step)
                disp = np.stack(disp_list)
                gen_video_(disp_post(disp, FLAGS.config),
                           "disp",
                           r_hwf,
                           step,
                           ch=1)

                if FLAGS.config.use_viewdirs:
                    rgb_list = []
                    for idx, inputs in tqdm(iter_vdirs_ds,
                                            desc="Viewdirs render"):
                        rays, padding = prepare_render_data(
                            inputs["rays"].numpy())
                        preds, *_ = p_eval_step(state, rays)
                        rgb_list.append(to_np(preds["rgb"], r_hwf, padding))
                    gen_video_(np.stack(rgb_list), "rgb_still", r_hwf, step)

        ### Save images in the test set
        if step % FLAGS.config.i_testset == 0 and step > 0:
            with report_progress.timed("test_render"):
                logging.info("Rendering test set at step %d", step)
                test_losses = []
                for idx, inputs in tqdm(iter_test_ds, desc="Test render"):
                    rays, padding = prepare_render_data(inputs["rays"].numpy())
                    preds, *_ = p_eval_step(state, rays)
                    save_test_imgs(FLAGS.model_dir, preds["rgb"], r_hwf, step,
                                   idx)

                    if FLAGS.config.render_factor == 0:
                        loss = np.mean((preds["rgb"] - inputs["image"])**2.0)
                        test_losses.append(loss)
                if FLAGS.config.render_factor == 0:
                    loss = np.mean(test_losses)
                    summary = {"test/loss": loss, "test/psnr": psnr_fn(loss)}
                    writer.write_scalars(step, summary)
        writer.flush()

        ### Save ckpt
        if step % FLAGS.config.i_weights == 0 or is_last_step:
            with report_progress.timed("checkpoint"):
                save_checkpoint(state, FLAGS.model_dir)
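
The val/psnr and test/psnr summaries above come from psnr_fn applied to a mean squared error over RGB values. For images scaled to [0, 1] the conventional definition is PSNR = -10 * log10(MSE); a minimal sketch under that assumption (the actual helper is defined elsewhere in this example):

import numpy as np

def psnr_fn(mse):
    # Peak signal-to-noise ratio in dB for images normalized to [0, 1].
    return -10.0 * np.log10(mse)

# An MSE of 1e-3 corresponds to 30 dB.
assert np.isclose(psnr_fn(1e-3), 30.0)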