Example #1
def train(config, workdir):
    """Runs the training pipeline.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest checkpoint.
  """

    # Create directories for experimental logs
    sample_dir = os.path.join(workdir, "samples")
    tf.io.gfile.makedirs(sample_dir)

    tb_dir = os.path.join(workdir, "tensorboard")
    tf.io.gfile.makedirs(tb_dir)
    writer = tensorboard.SummaryWriter(tb_dir)

    # Initialize model.
    score_model = mutils.create_model(config)
    ema = ExponentialMovingAverage(score_model.parameters(),
                                   decay=config.model.ema_rate)
    optimizer = losses.get_optimizer(config, score_model.parameters())
    state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0)

    # Create checkpoints directory
    checkpoint_dir = os.path.join(workdir, "checkpoints")
    # Intermediate checkpoints to resume training after pre-emption in cloud environments
    checkpoint_meta_dir = os.path.join(workdir, "checkpoints-meta",
                                       "checkpoint.pth")
    tf.io.gfile.makedirs(checkpoint_dir)
    tf.io.gfile.makedirs(os.path.dirname(checkpoint_meta_dir))
    # Resume training when intermediate checkpoints are detected
    state = restore_checkpoint(checkpoint_meta_dir, state, config.device)
    initial_step = int(state['step'])

    # Build data iterators
    train_ds, eval_ds, _ = datasets.get_dataset(
        config, uniform_dequantization=config.data.uniform_dequantization)
    train_iter = iter(train_ds)  # pytype: disable=wrong-arg-types
    eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types
    # Create data normalizer and its inverse
    scaler = datasets.get_data_scaler(config)
    inverse_scaler = datasets.get_data_inverse_scaler(config)

    # Setup SDEs
    if config.training.sde.lower() == 'vpsde':
        sde = sde_lib.VPSDE(beta_min=config.model.beta_min,
                            beta_max=config.model.beta_max,
                            N=config.model.num_scales)
        sampling_eps = 1e-3
    elif config.training.sde.lower() == 'subvpsde':
        sde = sde_lib.subVPSDE(beta_min=config.model.beta_min,
                               beta_max=config.model.beta_max,
                               N=config.model.num_scales)
        sampling_eps = 1e-3
    elif config.training.sde.lower() == 'vesde':
        sde = sde_lib.VESDE(sigma_min=config.model.sigma_min,
                            sigma_max=config.model.sigma_max,
                            N=config.model.num_scales)
        sampling_eps = 1e-5
    else:
        raise NotImplementedError(f"SDE {config.training.sde} unknown.")

    # Build one-step training and evaluation functions
    optimize_fn = losses.optimization_manager(config)
    continuous = config.training.continuous
    reduce_mean = config.training.reduce_mean
    likelihood_weighting = config.training.likelihood_weighting
    train_step_fn = losses.get_step_fn(
        sde,
        train=True,
        optimize_fn=optimize_fn,
        reduce_mean=reduce_mean,
        continuous=continuous,
        likelihood_weighting=likelihood_weighting)
    eval_step_fn = losses.get_step_fn(
        sde,
        train=False,
        optimize_fn=optimize_fn,
        reduce_mean=reduce_mean,
        continuous=continuous,
        likelihood_weighting=likelihood_weighting)

    # Build the sampling function
    if config.training.snapshot_sampling:
        sampling_shape = (config.training.batch_size, config.data.num_channels,
                          config.data.image_size, config.data.image_size)
        sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape,
                                               inverse_scaler, sampling_eps)

    num_train_steps = config.training.n_iters

    # Log the starting step of the training loop
    logging.info("Starting training loop at step %d." % (initial_step, ))

    for step in range(initial_step, num_train_steps + 1):
        # Convert data to PyTorch tensors and normalize them. Use ._numpy() to avoid an extra copy.
        batch = torch.from_numpy(next(train_iter)['image']._numpy()).to(
            config.device).float()
        batch = batch.permute(0, 3, 1, 2)
        batch = scaler(batch)
        # Execute one training step
        loss = train_step_fn(state, batch)
        if step % config.training.log_freq == 0:
            logging.info("step: %d, training_loss: %.5e" % (step, loss.item()))
            writer.add_scalar("training_loss", loss, step)

        # Save a temporary checkpoint to resume training after pre-emption periodically
        if step != 0 and step % config.training.snapshot_freq_for_preemption == 0:
            save_checkpoint(checkpoint_meta_dir, state)

        # Report the loss on an evaluation dataset periodically
        if step % config.training.eval_freq == 0:
            eval_batch = torch.from_numpy(
                next(eval_iter)['image']._numpy()).to(config.device).float()
            eval_batch = eval_batch.permute(0, 3, 1, 2)
            eval_batch = scaler(eval_batch)
            eval_loss = eval_step_fn(state, eval_batch)
            logging.info("step: %d, eval_loss: %.5e" %
                         (step, eval_loss.item()))
            writer.add_scalar("eval_loss", eval_loss.item(), step)

        # Save a checkpoint periodically and generate samples if needed
        if step != 0 and step % config.training.snapshot_freq == 0 or step == num_train_steps:
            # Save the checkpoint.
            save_step = step // config.training.snapshot_freq
            save_checkpoint(
                os.path.join(checkpoint_dir, f'checkpoint_{save_step}.pth'),
                state)

            # Generate and save samples
            if config.training.snapshot_sampling:
                ema.store(score_model.parameters())
                ema.copy_to(score_model.parameters())
                sample, n = sampling_fn(score_model)
                ema.restore(score_model.parameters())
                this_sample_dir = os.path.join(sample_dir,
                                               "iter_{}".format(step))
                tf.io.gfile.makedirs(this_sample_dir)
                nrow = int(np.sqrt(sample.shape[0]))
                image_grid = make_grid(sample, nrow, padding=2)
                sample = np.clip(
                    sample.permute(0, 2, 3, 1).cpu().numpy() * 255, 0,
                    255).astype(np.uint8)
                with tf.io.gfile.GFile(
                        os.path.join(this_sample_dir, "sample.np"),
                        "wb") as fout:
                    np.save(fout, sample)

                with tf.io.gfile.GFile(
                        os.path.join(this_sample_dir, "sample.png"),
                        "wb") as fout:
                    save_image(image_grid, fout)
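
Example #1 relies on `save_checkpoint` and `restore_checkpoint` helpers that are not shown here. Below is a minimal sketch of what such helpers could look like, assuming the `state` dictionary layout used above (keys `optimizer`, `model`, `ema`, `step`); the project's actual helpers may differ, and cloud-storage paths would need a gfile-aware loader.

import logging
import os

import tensorflow as tf
import torch


def restore_checkpoint(ckpt_path, state, device):
    # Assumed helper: if no checkpoint exists yet, create the parent directory
    # and return the input state unchanged so training starts from step 0.
    if not tf.io.gfile.exists(ckpt_path):
        tf.io.gfile.makedirs(os.path.dirname(ckpt_path))
        logging.warning("No checkpoint found at %s; returning the input state.", ckpt_path)
        return state
    loaded = torch.load(ckpt_path, map_location=device)
    state['optimizer'].load_state_dict(loaded['optimizer'])
    state['model'].load_state_dict(loaded['model'], strict=False)
    state['ema'].load_state_dict(loaded['ema'])
    state['step'] = loaded['step']
    return state


def save_checkpoint(ckpt_path, state):
    # Assumed helper: serialize only state_dicts so the file can be reloaded
    # onto any device via map_location.
    saved_state = {
        'optimizer': state['optimizer'].state_dict(),
        'model': state['model'].state_dict(),
        'ema': state['ema'].state_dict(),
        'step': state['step'],
    }
    torch.save(saved_state, ckpt_path)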
Example #2
def train(config, workdir):
    """Runs the training pipeline.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest checkpoint.
  """

    # Create directories for experimental logs
    sample_dir = os.path.join(workdir, "samples")
    tf.io.gfile.makedirs(sample_dir)

    rng = jax.random.PRNGKey(config.seed)
    tb_dir = os.path.join(workdir, "tensorboard")
    tf.io.gfile.makedirs(tb_dir)
    if jax.host_id() == 0:
        writer = tensorboard.SummaryWriter(tb_dir)

    # Initialize model.
    rng, step_rng = jax.random.split(rng)
    score_model, init_model_state, initial_params = mutils.init_model(
        step_rng, config)
    optimizer = losses.get_optimizer(config).create(initial_params)
    state = mutils.State(step=0,
                         optimizer=optimizer,
                         lr=config.optim.lr,
                         model_state=init_model_state,
                         ema_rate=config.model.ema_rate,
                         params_ema=initial_params,
                         rng=rng)  # pytype: disable=wrong-keyword-args

    # Create checkpoints directory
    checkpoint_dir = os.path.join(workdir, "checkpoints")
    # Intermediate checkpoints to resume training after pre-emption in cloud environments
    checkpoint_meta_dir = os.path.join(workdir, "checkpoints-meta")
    tf.io.gfile.makedirs(checkpoint_dir)
    tf.io.gfile.makedirs(checkpoint_meta_dir)
    # Resume training when intermediate checkpoints are detected
    state = checkpoints.restore_checkpoint(checkpoint_meta_dir, state)
    # `state.step` is a JAX integer on the GPU/TPU devices
    initial_step = int(state.step)
    rng = state.rng

    # Build data iterators
    train_ds, eval_ds, _ = datasets.get_dataset(
        config,
        additional_dim=config.training.n_jitted_steps,
        uniform_dequantization=config.data.uniform_dequantization)
    train_iter = iter(train_ds)  # pytype: disable=wrong-arg-types
    eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types
    # Create data normalizer and its inverse
    scaler = datasets.get_data_scaler(config)
    inverse_scaler = datasets.get_data_inverse_scaler(config)

    # Setup SDEs
    if config.training.sde.lower() == 'vpsde':
        sde = sde_lib.VPSDE(beta_min=config.model.beta_min,
                            beta_max=config.model.beta_max,
                            N=config.model.num_scales)
        sampling_eps = 1e-3
    elif config.training.sde.lower() == 'subvpsde':
        sde = sde_lib.subVPSDE(beta_min=config.model.beta_min,
                               beta_max=config.model.beta_max,
                               N=config.model.num_scales)
        sampling_eps = 1e-3
    elif config.training.sde.lower() == 'vesde':
        sde = sde_lib.VESDE(sigma_min=config.model.sigma_min,
                            sigma_max=config.model.sigma_max,
                            N=config.model.num_scales)
        sampling_eps = 1e-5
    else:
        raise NotImplementedError(f"SDE {config.training.sde} unknown.")

    # Build one-step training and evaluation functions
    optimize_fn = losses.optimization_manager(config)
    continuous = config.training.continuous
    reduce_mean = config.training.reduce_mean
    likelihood_weighting = config.training.likelihood_weighting
    train_step_fn = losses.get_step_fn(
        sde,
        score_model,
        train=True,
        optimize_fn=optimize_fn,
        reduce_mean=reduce_mean,
        continuous=continuous,
        likelihood_weighting=likelihood_weighting)
    # Pmap (and jit-compile) multiple training steps together for faster running
    p_train_step = jax.pmap(functools.partial(jax.lax.scan, train_step_fn),
                            axis_name='batch',
                            donate_argnums=1)
    eval_step_fn = losses.get_step_fn(
        sde,
        score_model,
        train=False,
        optimize_fn=optimize_fn,
        reduce_mean=reduce_mean,
        continuous=continuous,
        likelihood_weighting=likelihood_weighting)
    # Pmap (and jit-compile) multiple evaluation steps together for faster running
    p_eval_step = jax.pmap(functools.partial(jax.lax.scan, eval_step_fn),
                           axis_name='batch',
                           donate_argnums=1)

    # Build the sampling function
    if config.training.snapshot_sampling:
        sampling_shape = (config.training.batch_size //
                          jax.local_device_count(), config.data.image_size,
                          config.data.image_size, config.data.num_channels)
        sampling_fn = sampling.get_sampling_fn(config, sde, score_model,
                                               sampling_shape, inverse_scaler,
                                               sampling_eps)

    # Replicate the training state to run on multiple devices
    pstate = flax_utils.replicate(state)
    num_train_steps = config.training.n_iters

    # In case there are multiple hosts (e.g., TPU pods), only log to host 0
    if jax.host_id() == 0:
        logging.info("Starting training loop at step %d." % (initial_step, ))
    rng = jax.random.fold_in(rng, jax.host_id())

    # JIT multiple training steps together for faster training
    n_jitted_steps = config.training.n_jitted_steps
    # Must be divisible by the number of steps jitted together
    assert config.training.log_freq % n_jitted_steps == 0 and \
           config.training.snapshot_freq_for_preemption % n_jitted_steps == 0 and \
           config.training.eval_freq % n_jitted_steps == 0 and \
           config.training.snapshot_freq % n_jitted_steps == 0, "Missing logs or checkpoints!"

    for step in range(initial_step, num_train_steps + 1,
                      config.training.n_jitted_steps):
        # Convert data to JAX arrays and normalize them. Use ._numpy() to avoid copy.
        batch = jax.tree_map(lambda x: scaler(x._numpy()), next(train_iter))  # pylint: disable=protected-access
        rng, *next_rng = jax.random.split(rng,
                                          num=jax.local_device_count() + 1)
        next_rng = jnp.asarray(next_rng)
        # Execute one training step
        (_, pstate), ploss = p_train_step((next_rng, pstate), batch)
        loss = flax.jax_utils.unreplicate(ploss).mean()
        # Log to console, file and tensorboard on host 0
        if jax.host_id() == 0 and step % config.training.log_freq == 0:
            logging.info("step: %d, training_loss: %.5e" % (step, loss))
            writer.scalar("training_loss", loss, step)

        # Save a temporary checkpoint to resume training after pre-emption periodically
        if step != 0 and step % config.training.snapshot_freq_for_preemption == 0 and jax.host_id(
        ) == 0:
            saved_state = flax_utils.unreplicate(pstate)
            saved_state = saved_state.replace(rng=rng)
            checkpoints.save_checkpoint(
                checkpoint_meta_dir,
                saved_state,
                step=step // config.training.snapshot_freq_for_preemption,
                keep=1)

        # Report the loss on an evaluation dataset periodically
        if step % config.training.eval_freq == 0:
            eval_batch = jax.tree_map(lambda x: scaler(x._numpy()),
                                      next(eval_iter))  # pylint: disable=protected-access
            rng, *next_rng = jax.random.split(rng,
                                              num=jax.local_device_count() + 1)
            next_rng = jnp.asarray(next_rng)
            (_, _), peval_loss = p_eval_step((next_rng, pstate), eval_batch)
            eval_loss = flax.jax_utils.unreplicate(peval_loss).mean()
            if jax.host_id() == 0:
                logging.info("step: %d, eval_loss: %.5e" % (step, eval_loss))
                writer.scalar("eval_loss", eval_loss, step)

        # Save a checkpoint periodically and generate samples if needed
        if step != 0 and step % config.training.snapshot_freq == 0 or step == num_train_steps:
            # Save the checkpoint.
            if jax.host_id() == 0:
                saved_state = flax_utils.unreplicate(pstate)
                saved_state = saved_state.replace(rng=rng)
                checkpoints.save_checkpoint(checkpoint_dir,
                                            saved_state,
                                            step=step //
                                            config.training.snapshot_freq,
                                            keep=np.inf)

            # Generate and save samples
            if config.training.snapshot_sampling:
                rng, *sample_rng = jax.random.split(
                    rng,
                    jax.local_device_count() + 1)
                sample_rng = jnp.asarray(sample_rng)
                sample, n = sampling_fn(sample_rng, pstate)
                this_sample_dir = os.path.join(
                    sample_dir, "iter_{}_host_{}".format(step, jax.host_id()))
                tf.io.gfile.makedirs(this_sample_dir)
                image_grid = sample.reshape((-1, *sample.shape[2:]))
                nrow = int(np.sqrt(image_grid.shape[0]))
                sample = np.clip(sample * 255, 0, 255).astype(np.uint8)
                with tf.io.gfile.GFile(
                        os.path.join(this_sample_dir, "sample.np"),
                        "wb") as fout:
                    np.save(fout, sample)

                with tf.io.gfile.GFile(
                        os.path.join(this_sample_dir, "sample.png"),
                        "wb") as fout:
                    utils.save_image(image_grid, fout, nrow=nrow, padding=2)
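
Because `p_train_step` in Example #2 composes `jax.lax.scan` with `jax.pmap`, each batch pulled from `train_iter` must already carry a leading device axis and an `n_jitted_steps` axis; `datasets.get_dataset` is called with `additional_dim=config.training.n_jitted_steps` for exactly this reason. A small sketch of the expected layout follows; the helper name `shard_for_scan` and the concrete sizes are assumptions for illustration only.

import jax
import numpy as np


def shard_for_scan(images, n_jitted_steps):
    # Reshape a flat host batch of NHWC images into
    # [num_local_devices, n_jitted_steps, per_device_batch, H, W, C],
    # which is what pmap(scan(train_step_fn)) consumes.
    num_devices = jax.local_device_count()
    batch, h, w, c = images.shape
    per_device_batch = batch // (num_devices * n_jitted_steps)
    return images.reshape(num_devices, n_jitted_steps, per_device_batch, h, w, c)


# Illustration with made-up sizes: 2 jitted steps, per-device batch of 16,
# 32x32 RGB images on however many local devices are available.
# images = np.zeros((jax.local_device_count() * 2 * 16, 32, 32, 3), np.float32)
# sharded = shard_for_scan(images, n_jitted_steps=2)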
Example #3
def evaluate(config, workdir, eval_folder="eval"):
    """Evaluate trained models.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints.
    eval_folder: The subfolder for storing evaluation results. Defaults to
      "eval".
  """
    # Create the directory for evaluation results
    eval_dir = os.path.join(workdir, eval_folder)
    tf.io.gfile.makedirs(eval_dir)

    # Build data pipeline
    train_ds, eval_ds, _ = datasets.get_dataset(
        config,
        uniform_dequantization=config.data.uniform_dequantization,
        evaluation=True)

    # Create data normalizer and its inverse
    scaler = datasets.get_data_scaler(config)
    inverse_scaler = datasets.get_data_inverse_scaler(config)

    # Initialize model
    score_model = mutils.create_model(config)
    optimizer = losses.get_optimizer(config, score_model.parameters())
    ema = ExponentialMovingAverage(score_model.parameters(),
                                   decay=config.model.ema_rate)
    state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0)

    checkpoint_dir = os.path.join(workdir, "checkpoints")

    # Setup SDEs
    if config.training.sde.lower() == 'vpsde':
        sde = sde_lib.VPSDE(beta_min=config.model.beta_min,
                            beta_max=config.model.beta_max,
                            N=config.model.num_scales)
        sampling_eps = 1e-3
    elif config.training.sde.lower() == 'subvpsde':
        sde = sde_lib.subVPSDE(beta_min=config.model.beta_min,
                               beta_max=config.model.beta_max,
                               N=config.model.num_scales)
        sampling_eps = 1e-3
    elif config.training.sde.lower() == 'vesde':
        sde = sde_lib.VESDE(sigma_min=config.model.sigma_min,
                            sigma_max=config.model.sigma_max,
                            N=config.model.num_scales)
        sampling_eps = 1e-5
    else:
        raise NotImplementedError(f"SDE {config.training.sde} unknown.")

    # Create the one-step evaluation function when loss computation is enabled
    if config.eval.enable_loss:
        optimize_fn = losses.optimization_manager(config)
        continuous = config.training.continuous
        likelihood_weighting = config.training.likelihood_weighting

        reduce_mean = config.training.reduce_mean
        eval_step = losses.get_step_fn(
            sde,
            train=False,
            optimize_fn=optimize_fn,
            reduce_mean=reduce_mean,
            continuous=continuous,
            likelihood_weighting=likelihood_weighting)

    # Create data loaders for likelihood evaluation. Only evaluate on uniformly dequantized data
    train_ds_bpd, eval_ds_bpd, _ = datasets.get_dataset(
        config, uniform_dequantization=True, evaluation=True)
    if config.eval.bpd_dataset.lower() == 'train':
        ds_bpd = train_ds_bpd
        bpd_num_repeats = 1
    elif config.eval.bpd_dataset.lower() == 'test':
        # Go over the dataset 5 times when computing likelihood on the test dataset
        ds_bpd = eval_ds_bpd
        bpd_num_repeats = 5
    else:
        raise ValueError(
            f"No bpd dataset {config.eval.bpd_dataset} recognized.")

    # Build the likelihood computation function when likelihood is enabled
    if config.eval.enable_bpd:
        likelihood_fn = likelihood.get_likelihood_fn(sde, inverse_scaler)

    # Build the sampling function when sampling is enabled
    if config.eval.enable_sampling:
        sampling_shape = (config.eval.batch_size, config.data.num_channels,
                          config.data.image_size, config.data.image_size)
        sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape,
                                               inverse_scaler, sampling_eps)

    # Use InceptionV3 for images with resolution of 256 or higher.
    inceptionv3 = config.data.image_size >= 256
    inception_model = evaluation.get_inception_model(inceptionv3=inceptionv3)

    begin_ckpt = config.eval.begin_ckpt
    logging.info("begin checkpoint: %d" % (begin_ckpt, ))
    for ckpt in range(begin_ckpt, config.eval.end_ckpt + 1):
        # Wait if the target checkpoint doesn't exist yet
        waiting_message_printed = False
        ckpt_filename = os.path.join(checkpoint_dir,
                                     "checkpoint_{}.pth".format(ckpt))
        while not tf.io.gfile.exists(ckpt_filename):
            if not waiting_message_printed:
                logging.warning("Waiting for the arrival of checkpoint_%d" %
                                (ckpt, ))
                waiting_message_printed = True
            time.sleep(60)

        # Retry with additional waits in case the file exists but is not yet ready for reading
        ckpt_path = os.path.join(checkpoint_dir, f'checkpoint_{ckpt}.pth')
        try:
            state = restore_checkpoint(ckpt_path, state, device=config.device)
        except:
            time.sleep(60)
            try:
                state = restore_checkpoint(ckpt_path,
                                           state,
                                           device=config.device)
            except:
                time.sleep(120)
                state = restore_checkpoint(ckpt_path,
                                           state,
                                           device=config.device)
        ema.copy_to(score_model.parameters())
        # Compute the loss function on the full evaluation dataset if loss computation is enabled
        if config.eval.enable_loss:
            all_losses = []
            eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types
            for i, batch in enumerate(eval_iter):
                eval_batch = torch.from_numpy(batch['image']._numpy()).to(
                    config.device).float()
                eval_batch = eval_batch.permute(0, 3, 1, 2)
                eval_batch = scaler(eval_batch)
                eval_loss = eval_step(state, eval_batch)
                all_losses.append(eval_loss.item())
                if (i + 1) % 1000 == 0:
                    logging.info("Finished %dth step loss evaluation" %
                                 (i + 1))

            # Save loss values to disk or Google Cloud Storage
            all_losses = np.asarray(all_losses)
            with tf.io.gfile.GFile(
                    os.path.join(eval_dir, f"ckpt_{ckpt}_loss.npz"),
                    "wb") as fout:
                io_buffer = io.BytesIO()
                np.savez_compressed(io_buffer,
                                    all_losses=all_losses,
                                    mean_loss=all_losses.mean())
                fout.write(io_buffer.getvalue())

        # Compute log-likelihoods (bits/dim) if enabled
        if config.eval.enable_bpd:
            bpds = []
            for repeat in range(bpd_num_repeats):
                bpd_iter = iter(ds_bpd)  # pytype: disable=wrong-arg-types
                for batch_id in range(len(ds_bpd)):
                    batch = next(bpd_iter)
                    eval_batch = torch.from_numpy(batch['image']._numpy()).to(
                        config.device).float()
                    eval_batch = eval_batch.permute(0, 3, 1, 2)
                    eval_batch = scaler(eval_batch)
                    bpd = likelihood_fn(score_model, eval_batch)[0]
                    bpd = bpd.detach().cpu().numpy().reshape(-1)
                    bpds.extend(bpd)
                    logging.info(
                        "ckpt: %d, repeat: %d, batch: %d, mean bpd: %6f" %
                        (ckpt, repeat, batch_id, np.mean(np.asarray(bpds))))
                    bpd_round_id = batch_id + len(ds_bpd) * repeat
                    # Save bits/dim to disk or Google Cloud Storage
                    with tf.io.gfile.GFile(
                            os.path.join(
                                eval_dir,
                                f"{config.eval.bpd_dataset}_ckpt_{ckpt}_bpd_{bpd_round_id}.npz"
                            ), "wb") as fout:
                        io_buffer = io.BytesIO()
                        np.savez_compressed(io_buffer, bpd)
                        fout.write(io_buffer.getvalue())

        # Generate samples and compute IS/FID/KID when enabled
        if config.eval.enable_sampling:
            num_sampling_rounds = config.eval.num_samples // config.eval.batch_size + 1
            for r in range(num_sampling_rounds):
                logging.info("sampling -- ckpt: %d, round: %d" % (ckpt, r))

                # Directory to save samples for this checkpoint
                this_sample_dir = os.path.join(eval_dir, f"ckpt_{ckpt}")
                tf.io.gfile.makedirs(this_sample_dir)
                samples, n = sampling_fn(score_model)
                samples = np.clip(
                    samples.permute(0, 2, 3, 1).cpu().numpy() * 255., 0,
                    255).astype(np.uint8)
                samples = samples.reshape(
                    (-1, config.data.image_size, config.data.image_size,
                     config.data.num_channels))
                # Write samples to disk or Google Cloud Storage
                with tf.io.gfile.GFile(
                        os.path.join(this_sample_dir, f"samples_{r}.npz"),
                        "wb") as fout:
                    io_buffer = io.BytesIO()
                    np.savez_compressed(io_buffer, samples=samples)
                    fout.write(io_buffer.getvalue())

                # Force garbage collection before calling TensorFlow code for Inception network
                gc.collect()
                latents = evaluation.run_inception_distributed(
                    samples, inception_model, inceptionv3=inceptionv3)
                # Force garbage collection again before returning to JAX code
                gc.collect()
                # Save latent representations from the Inception network to disk or Google Cloud Storage
                with tf.io.gfile.GFile(
                        os.path.join(this_sample_dir, f"statistics_{r}.npz"),
                        "wb") as fout:
                    io_buffer = io.BytesIO()
                    np.savez_compressed(io_buffer,
                                        pool_3=latents["pool_3"],
                                        logits=latents["logits"])
                    fout.write(io_buffer.getvalue())

            # Compute inception scores, FIDs and KIDs.
            # Load all statistics that have been previously computed and saved
            all_logits = []
            all_pools = []
            this_sample_dir = os.path.join(eval_dir, f"ckpt_{ckpt}")
            stats = tf.io.gfile.glob(
                os.path.join(this_sample_dir, "statistics_*.npz"))
            for stat_file in stats:
                with tf.io.gfile.GFile(stat_file, "rb") as fin:
                    stat = np.load(fin)
                    if not inceptionv3:
                        all_logits.append(stat["logits"])
                    all_pools.append(stat["pool_3"])

            if not inceptionv3:
                all_logits = np.concatenate(all_logits,
                                            axis=0)[:config.eval.num_samples]
            all_pools = np.concatenate(all_pools,
                                       axis=0)[:config.eval.num_samples]

            # Load pre-computed dataset statistics.
            data_stats = evaluation.load_dataset_stats(config)
            data_pools = data_stats["pool_3"]

            # Compute FID/KID/IS on all samples together.
            if not inceptionv3:
                inception_score = tfgan.eval.classifier_score_from_logits(
                    all_logits)
            else:
                inception_score = -1

            fid = tfgan.eval.frechet_classifier_distance_from_activations(
                data_pools, all_pools)
            # Hack to get tfgan KID to work with eager execution.
            tf_data_pools = tf.convert_to_tensor(data_pools)
            tf_all_pools = tf.convert_to_tensor(all_pools)
            kid = tfgan.eval.kernel_classifier_distance_from_activations(
                tf_data_pools, tf_all_pools).numpy()
            del tf_data_pools, tf_all_pools

            logging.info(
                "ckpt-%d --- inception_score: %.6e, FID: %.6e, KID: %.6e" %
                (ckpt, inception_score, fid, kid))

            with tf.io.gfile.GFile(
                    os.path.join(eval_dir, f"report_{ckpt}.npz"), "wb") as f:
                io_buffer = io.BytesIO()
                np.savez_compressed(io_buffer,
                                    IS=inception_score,
                                    fid=fid,
                                    kid=kid)
                f.write(io_buffer.getvalue())
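
In practice, the `train` and `evaluate` functions above are invoked from a small launcher that loads a configuration file and a working directory. The sketch below uses absl flags with ml_collections config files; the module name `run_lib` and the exact flag set are assumptions for illustration, not part of the examples themselves.

# Hypothetical launcher; run_lib is assumed to expose the train()/evaluate() shown above.
from absl import app, flags
from ml_collections.config_flags import config_flags

import run_lib

FLAGS = flags.FLAGS
config_flags.DEFINE_config_file("config", None, "Path to the configuration file.", lock_config=True)
flags.DEFINE_string("workdir", None, "Working directory for checkpoints and logs.")
flags.DEFINE_enum("mode", None, ["train", "eval"], "Whether to run training or evaluation.")
flags.DEFINE_string("eval_folder", "eval", "Subfolder under workdir for evaluation results.")
flags.mark_flags_as_required(["config", "workdir", "mode"])


def main(argv):
    del argv  # Unused.
    if FLAGS.mode == "train":
        run_lib.train(FLAGS.config, FLAGS.workdir)
    elif FLAGS.mode == "eval":
        run_lib.evaluate(FLAGS.config, FLAGS.workdir, FLAGS.eval_folder)


if __name__ == "__main__":
    app.run(main)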
Example #4
def evaluate(config, workdir, eval_folder="eval"):
    """Evaluate trained models.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints.
    eval_folder: The subfolder for storing evaluation results. Defaults to
      "eval".
  """
    # Create the directory for evaluation results
    eval_dir = os.path.join(workdir, eval_folder)
    tf.io.gfile.makedirs(eval_dir)

    rng = jax.random.PRNGKey(config.seed + 1)

    # Build data pipeline
    train_ds, eval_ds, _ = datasets.get_dataset(
        config,
        additional_dim=1,
        uniform_dequantization=config.data.uniform_dequantization,
        evaluation=True)

    # Create data normalizer and its inverse
    scaler = datasets.get_data_scaler(config)
    inverse_scaler = datasets.get_data_inverse_scaler(config)

    # Initialize model
    rng, model_rng = jax.random.split(rng)
    score_model, init_model_state, initial_params = mutils.init_model(
        model_rng, config)
    optimizer = losses.get_optimizer(config).create(initial_params)
    state = mutils.State(step=0,
                         optimizer=optimizer,
                         lr=config.optim.lr,
                         model_state=init_model_state,
                         ema_rate=config.model.ema_rate,
                         params_ema=initial_params,
                         rng=rng)  # pytype: disable=wrong-keyword-args

    checkpoint_dir = os.path.join(workdir, "checkpoints")

    # Setup SDEs
    if config.training.sde.lower() == 'vpsde':
        sde = sde_lib.VPSDE(beta_min=config.model.beta_min,
                            beta_max=config.model.beta_max,
                            N=config.model.num_scales)
        sampling_eps = 1e-3
    elif config.training.sde.lower() == 'subvpsde':
        sde = sde_lib.subVPSDE(beta_min=config.model.beta_min,
                               beta_max=config.model.beta_max,
                               N=config.model.num_scales)
        sampling_eps = 1e-3
    elif config.training.sde.lower() == 'vesde':
        sde = sde_lib.VESDE(sigma_min=config.model.sigma_min,
                            sigma_max=config.model.sigma_max,
                            N=config.model.num_scales)
        sampling_eps = 1e-5
    else:
        raise NotImplementedError(f"SDE {config.training.sde} unknown.")

    # Create the one-step evaluation function when loss computation is enabled
    if config.eval.enable_loss:
        optimize_fn = losses.optimization_manager(config)
        continuous = config.training.continuous
        likelihood_weighting = config.training.likelihood_weighting

        reduce_mean = config.training.reduce_mean
        eval_step = losses.get_step_fn(
            sde,
            score_model,
            train=False,
            optimize_fn=optimize_fn,
            reduce_mean=reduce_mean,
            continuous=continuous,
            likelihood_weighting=likelihood_weighting)
        # Pmap (and jit-compile) multiple evaluation steps together for faster execution
        p_eval_step = jax.pmap(functools.partial(jax.lax.scan, eval_step),
                               axis_name='batch',
                               donate_argnums=1)

    # Create data loaders for likelihood evaluation. Only evaluate on uniformly dequantized data
    train_ds_bpd, eval_ds_bpd, _ = datasets.get_dataset(
        config,
        additional_dim=None,
        uniform_dequantization=True,
        evaluation=True)
    if config.eval.bpd_dataset.lower() == 'train':
        ds_bpd = train_ds_bpd
        bpd_num_repeats = 1
    elif config.eval.bpd_dataset.lower() == 'test':
        # Go over the dataset 5 times when computing likelihood on the test dataset
        ds_bpd = eval_ds_bpd
        bpd_num_repeats = 5
    else:
        raise ValueError(
            f"No bpd dataset {config.eval.bpd_dataset} recognized.")

    # Build the likelihood computation function when likelihood is enabled
    if config.eval.enable_bpd:
        likelihood_fn = likelihood.get_likelihood_fn(sde, score_model,
                                                     inverse_scaler)

    # Build the sampling function when sampling is enabled
    if config.eval.enable_sampling:
        sampling_shape = (config.eval.batch_size // jax.local_device_count(),
                          config.data.image_size, config.data.image_size,
                          config.data.num_channels)
        sampling_fn = sampling.get_sampling_fn(config, sde, score_model,
                                               sampling_shape, inverse_scaler,
                                               sampling_eps)

    # Create different random states for different hosts in a multi-host environment (e.g., TPU pods)
    rng = jax.random.fold_in(rng, jax.host_id())

    # A data class for storing intermediate results to resume evaluation after pre-emption
    @flax.struct.dataclass
    class EvalMeta:
        ckpt_id: int
        sampling_round_id: int
        bpd_round_id: int
        rng: Any

    # Add one additional round to get the exact number of samples as required.
    num_sampling_rounds = config.eval.num_samples // config.eval.batch_size + 1
    num_bpd_rounds = len(ds_bpd) * bpd_num_repeats

    # Restore evaluation after pre-emption
    eval_meta = EvalMeta(ckpt_id=config.eval.begin_ckpt,
                         sampling_round_id=-1,
                         bpd_round_id=-1,
                         rng=rng)
    eval_meta = checkpoints.restore_checkpoint(eval_dir,
                                               eval_meta,
                                               step=None,
                                               prefix=f"meta_{jax.host_id()}_")

    if eval_meta.bpd_round_id < num_bpd_rounds - 1:
        begin_ckpt = eval_meta.ckpt_id
        begin_bpd_round = eval_meta.bpd_round_id + 1
        begin_sampling_round = 0

    elif eval_meta.sampling_round_id < num_sampling_rounds - 1:
        begin_ckpt = eval_meta.ckpt_id
        begin_bpd_round = num_bpd_rounds
        begin_sampling_round = eval_meta.sampling_round_id + 1

    else:
        begin_ckpt = eval_meta.ckpt_id + 1
        begin_bpd_round = 0
        begin_sampling_round = 0

    rng = eval_meta.rng

    # Use InceptionV3 for images with resolution of 256 or higher.
    inceptionv3 = config.data.image_size >= 256
    inception_model = evaluation.get_inception_model(inceptionv3=inceptionv3)

    logging.info("begin checkpoint: %d" % (begin_ckpt, ))
    for ckpt in range(begin_ckpt, config.eval.end_ckpt + 1):
        # Wait if the target checkpoint doesn't exist yet
        waiting_message_printed = False
        ckpt_filename = os.path.join(checkpoint_dir,
                                     "checkpoint_{}".format(ckpt))
        while not tf.io.gfile.exists(ckpt_filename):
            if not waiting_message_printed and jax.host_id() == 0:
                logging.warning("Waiting for the arrival of checkpoint_%d" %
                                (ckpt, ))
                waiting_message_printed = True
            time.sleep(60)

        # Retry with additional waits in case the file exists but is not yet ready for reading
        try:
            state = checkpoints.restore_checkpoint(checkpoint_dir,
                                                   state,
                                                   step=ckpt)
        except:
            time.sleep(60)
            try:
                state = checkpoints.restore_checkpoint(checkpoint_dir,
                                                       state,
                                                       step=ckpt)
            except:
                time.sleep(120)
                state = checkpoints.restore_checkpoint(checkpoint_dir,
                                                       state,
                                                       step=ckpt)

        # Replicate the training state for executing on multiple devices
        pstate = flax.jax_utils.replicate(state)
        # Compute the loss function on the full evaluation dataset if loss computation is enabled
        if config.eval.enable_loss:
            all_losses = []
            eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types
            for i, batch in enumerate(eval_iter):
                eval_batch = jax.tree_map(lambda x: scaler(x._numpy()), batch)  # pylint: disable=protected-access
                rng, *next_rng = jax.random.split(
                    rng, num=jax.local_device_count() + 1)
                next_rng = jnp.asarray(next_rng)
                (_, _), p_eval_loss = p_eval_step((next_rng, pstate),
                                                  eval_batch)
                eval_loss = flax.jax_utils.unreplicate(p_eval_loss)
                all_losses.extend(eval_loss)
                if (i + 1) % 1000 == 0 and jax.host_id() == 0:
                    logging.info("Finished %dth step loss evaluation" %
                                 (i + 1))

            # Save loss values to disk or Google Cloud Storage
            all_losses = jnp.asarray(all_losses)
            with tf.io.gfile.GFile(
                    os.path.join(eval_dir, f"ckpt_{ckpt}_loss.npz"),
                    "wb") as fout:
                io_buffer = io.BytesIO()
                np.savez_compressed(io_buffer,
                                    all_losses=all_losses,
                                    mean_loss=all_losses.mean())
                fout.write(io_buffer.getvalue())

        # Compute log-likelihoods (bits/dim) if enabled
        if config.eval.enable_bpd:
            bpds = []
            begin_repeat_id = begin_bpd_round // len(ds_bpd)
            begin_batch_id = begin_bpd_round % len(ds_bpd)
            # Repeat multiple times to reduce variance when needed
            for repeat in range(begin_repeat_id, bpd_num_repeats):
                bpd_iter = iter(ds_bpd)  # pytype: disable=wrong-arg-types
                for _ in range(begin_batch_id):
                    next(bpd_iter)
                for batch_id in range(begin_batch_id, len(ds_bpd)):
                    batch = next(bpd_iter)
                    eval_batch = jax.tree_map(lambda x: scaler(x._numpy()),
                                              batch)
                    rng, *step_rng = jax.random.split(
                        rng,
                        jax.local_device_count() + 1)
                    step_rng = jnp.asarray(step_rng)
                    bpd = likelihood_fn(step_rng, pstate,
                                        eval_batch['image'])[0]
                    bpd = bpd.reshape(-1)
                    bpds.extend(bpd)
                    logging.info(
                        "ckpt: %d, repeat: %d, batch: %d, mean bpd: %6f" %
                        (ckpt, repeat, batch_id, jnp.mean(jnp.asarray(bpds))))
                    bpd_round_id = batch_id + len(ds_bpd) * repeat
                    # Save bits/dim to disk or Google Cloud Storage
                    with tf.io.gfile.GFile(
                            os.path.join(
                                eval_dir,
                                f"{config.eval.bpd_dataset}_ckpt_{ckpt}_bpd_{bpd_round_id}.npz"
                            ), "wb") as fout:
                        io_buffer = io.BytesIO()
                        np.savez_compressed(io_buffer, bpd)
                        fout.write(io_buffer.getvalue())

                    eval_meta = eval_meta.replace(ckpt_id=ckpt,
                                                  bpd_round_id=bpd_round_id,
                                                  rng=rng)
                    # Save intermediate states to resume evaluation after pre-emption
                    checkpoints.save_checkpoint(
                        eval_dir,
                        eval_meta,
                        step=ckpt * (num_sampling_rounds + num_bpd_rounds) +
                        bpd_round_id,
                        keep=1,
                        prefix=f"meta_{jax.host_id()}_")
        else:
            # Skip likelihood computation and save intermediate states for pre-emption
            eval_meta = eval_meta.replace(ckpt_id=ckpt,
                                          bpd_round_id=num_bpd_rounds - 1)
            checkpoints.save_checkpoint(
                eval_dir,
                eval_meta,
                step=ckpt * (num_sampling_rounds + num_bpd_rounds) +
                num_bpd_rounds - 1,
                keep=1,
                prefix=f"meta_{jax.host_id()}_")

        # Generate samples and compute IS/FID/KID when enabled
        if config.eval.enable_sampling:
            state = jax.device_put(state)
            # Run sample generation for multiple rounds to create enough samples
            # Designed to be pre-emption safe. Automatically resumes when interrupted
            for r in range(begin_sampling_round, num_sampling_rounds):
                if jax.host_id() == 0:
                    logging.info("sampling -- ckpt: %d, round: %d" % (ckpt, r))

                # Directory to save samples. Different for each host to avoid writing conflicts
                this_sample_dir = os.path.join(
                    eval_dir, f"ckpt_{ckpt}_host_{jax.host_id()}")
                tf.io.gfile.makedirs(this_sample_dir)

                rng, *sample_rng = jax.random.split(
                    rng,
                    jax.local_device_count() + 1)
                sample_rng = jnp.asarray(sample_rng)
                samples, n = sampling_fn(sample_rng, pstate)
                samples = np.clip(samples * 255., 0, 255).astype(np.uint8)
                samples = samples.reshape(
                    (-1, config.data.image_size, config.data.image_size,
                     config.data.num_channels))
                # Write samples to disk or Google Cloud Storage
                with tf.io.gfile.GFile(
                        os.path.join(this_sample_dir, f"samples_{r}.npz"),
                        "wb") as fout:
                    io_buffer = io.BytesIO()
                    np.savez_compressed(io_buffer, samples=samples)
                    fout.write(io_buffer.getvalue())

                # Force garbage collection before calling TensorFlow code for Inception network
                gc.collect()
                latents = evaluation.run_inception_distributed(
                    samples, inception_model, inceptionv3=inceptionv3)
                # Force garbage collection again before returning to JAX code
                gc.collect()
                # Save latent representations from the Inception network to disk or Google Cloud Storage
                with tf.io.gfile.GFile(
                        os.path.join(this_sample_dir, f"statistics_{r}.npz"),
                        "wb") as fout:
                    io_buffer = io.BytesIO()
                    np.savez_compressed(io_buffer,
                                        pool_3=latents["pool_3"],
                                        logits=latents["logits"])
                    fout.write(io_buffer.getvalue())

                # Update the intermediate evaluation state
                eval_meta = eval_meta.replace(ckpt_id=ckpt,
                                              sampling_round_id=r,
                                              rng=rng)
                # Save an intermediate checkpoint directly if not the last round.
                # Otherwise save eval_meta after computing the Inception scores and FIDs
                if r < num_sampling_rounds - 1:
                    checkpoints.save_checkpoint(
                        eval_dir,
                        eval_meta,
                        step=ckpt * (num_sampling_rounds + num_bpd_rounds) +
                        r + num_bpd_rounds,
                        keep=1,
                        prefix=f"meta_{jax.host_id()}_")

            # Compute inception scores, FIDs and KIDs.
            if jax.host_id() == 0:
                # Load all statistics that have been previously computed and saved for each host
                all_logits = []
                all_pools = []
                for host in range(jax.host_count()):
                    this_sample_dir = os.path.join(eval_dir,
                                                   f"ckpt_{ckpt}_host_{host}")

                    stats = tf.io.gfile.glob(
                        os.path.join(this_sample_dir, "statistics_*.npz"))
                    wait_message = False
                    while len(stats) < num_sampling_rounds:
                        if not wait_message:
                            logging.warning(
                                "Waiting for statistics on host %d" % (host, ))
                            wait_message = True
                        stats = tf.io.gfile.glob(
                            os.path.join(this_sample_dir, "statistics_*.npz"))
                        time.sleep(30)

                    for stat_file in stats:
                        with tf.io.gfile.GFile(stat_file, "rb") as fin:
                            stat = np.load(fin)
                            if not inceptionv3:
                                all_logits.append(stat["logits"])
                            all_pools.append(stat["pool_3"])

                if not inceptionv3:
                    all_logits = np.concatenate(
                        all_logits, axis=0)[:config.eval.num_samples]
                all_pools = np.concatenate(all_pools,
                                           axis=0)[:config.eval.num_samples]

                # Load pre-computed dataset statistics.
                data_stats = evaluation.load_dataset_stats(config)
                data_pools = data_stats["pool_3"]

                # Compute FID/KID/IS on all samples together.
                if not inceptionv3:
                    inception_score = tfgan.eval.classifier_score_from_logits(
                        all_logits)
                else:
                    inception_score = -1

                fid = tfgan.eval.frechet_classifier_distance_from_activations(
                    data_pools, all_pools)
                # Hack to get tfgan KID to work with eager execution.
                tf_data_pools = tf.convert_to_tensor(data_pools)
                tf_all_pools = tf.convert_to_tensor(all_pools)
                kid = tfgan.eval.kernel_classifier_distance_from_activations(
                    tf_data_pools, tf_all_pools).numpy()
                del tf_data_pools, tf_all_pools

                logging.info(
                    "ckpt-%d --- inception_score: %.6e, FID: %.6e, KID: %.6e" %
                    (ckpt, inception_score, fid, kid))

                with tf.io.gfile.GFile(
                        os.path.join(eval_dir, f"report_{ckpt}.npz"),
                        "wb") as f:
                    io_buffer = io.BytesIO()
                    np.savez_compressed(io_buffer,
                                        IS=inception_score,
                                        fid=fid,
                                        kid=kid)
                    f.write(io_buffer.getvalue())
            else:
                # For host_id() != 0.
                # Use file existence to emulate synchronization across hosts
                while not tf.io.gfile.exists(
                        os.path.join(eval_dir, f"report_{ckpt}.npz")):
                    time.sleep(1.)

            # Save eval_meta after computing IS/KID/FID to mark the end of evaluation for this checkpoint
            checkpoints.save_checkpoint(
                eval_dir,
                eval_meta,
                step=ckpt * (num_sampling_rounds + num_bpd_rounds) + r +
                num_bpd_rounds,
                keep=1,
                prefix=f"meta_{jax.host_id()}_")

        else:
            # Skip sampling and save intermediate evaluation states for pre-emption
            eval_meta = eval_meta.replace(
                ckpt_id=ckpt,
                sampling_round_id=num_sampling_rounds - 1,
                rng=rng)
            checkpoints.save_checkpoint(
                eval_dir,
                eval_meta,
                step=ckpt * (num_sampling_rounds + num_bpd_rounds) +
                num_sampling_rounds - 1 + num_bpd_rounds,
                keep=1,
                prefix=f"meta_{jax.host_id()}_")

        begin_bpd_round = 0
        begin_sampling_round = 0

    # Remove all meta files after finishing evaluation
    meta_files = tf.io.gfile.glob(
        os.path.join(eval_dir, f"meta_{jax.host_id()}_*"))
    for file in meta_files:
        tf.io.gfile.remove(file)
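
A note on the `step` argument passed to `checkpoints.save_checkpoint` for the `meta_*` files in Example #4: it flattens (checkpoint id, bpd round, sampling round) into one strictly increasing index so that flax only ever sees a newer step when saving the next meta state. A toy illustration with made-up round counts:

# Toy numbers chosen only for illustration.
num_bpd_rounds = 3
num_sampling_rounds = 4


def bpd_meta_step(ckpt, bpd_round_id):
    # Meta-checkpoint index used while computing bits/dim.
    return ckpt * (num_sampling_rounds + num_bpd_rounds) + bpd_round_id


def sampling_meta_step(ckpt, sampling_round_id):
    # Meta-checkpoint index used while sampling; always larger than any
    # bpd-round index for the same checkpoint.
    return ckpt * (num_sampling_rounds + num_bpd_rounds) + num_bpd_rounds + sampling_round_id


# Checkpoint 2 maps bpd rounds to 14, 15, 16 and sampling rounds to 17..20,
# so each saved meta state has a strictly larger step than the previous one.
assert [bpd_meta_step(2, r) for r in range(num_bpd_rounds)] == [14, 15, 16]
assert [sampling_meta_step(2, r) for r in range(num_sampling_rounds)] == [17, 18, 19, 20]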