Example #1
    def train(self):
        dataset, test_dataset = get_dataset(self.args, self.config)
        dataloader = DataLoader(dataset, batch_size=self.config.training.batch_size, shuffle=True,
                                num_workers=self.config.data.num_workers)
        test_loader = DataLoader(test_dataset, batch_size=self.config.training.batch_size, shuffle=True,
                                 num_workers=self.config.data.num_workers, drop_last=True)
        test_iter = iter(test_loader)
        self.config.input_dim = self.config.data.image_size ** 2 * self.config.data.channels

        tb_logger = self.config.tb_logger

        score = get_model(self.config)

        score = torch.nn.DataParallel(score)
        optimizer = get_optimizer(self.config, score.parameters())

        start_epoch = 0
        step = 0

        if self.config.model.ema:
            ema_helper = EMAHelper(mu=self.config.model.ema_rate)
            ema_helper.register(score)

        if self.args.resume_training:
            states = torch.load(os.path.join(self.args.log_path, 'checkpoint.pth'))
            score.load_state_dict(states[0])
            ### Make sure we can resume with different eps
            states[1]['param_groups'][0]['eps'] = self.config.optim.eps
            optimizer.load_state_dict(states[1])
            start_epoch = states[2]
            step = states[3]
            if self.config.model.ema:
                ema_helper.load_state_dict(states[4])

        sigmas = get_sigmas(self.config)

        if self.config.training.log_all_sigmas:
            ### Commented out training time logging to save time.
            test_loss_per_sigma = [None for _ in range(len(sigmas))]

            def hook(loss, labels):
                # for i in range(len(sigmas)):
                #     if torch.any(labels == i):
                #         test_loss_per_sigma[i] = torch.mean(loss[labels == i])
                pass

            def tb_hook():
                # for i in range(len(sigmas)):
                #     if test_loss_per_sigma[i] is not None:
                #         tb_logger.add_scalar('test_loss_sigma_{}'.format(i), test_loss_per_sigma[i],
                #                              global_step=step)
                pass

            def test_hook(loss, labels):
                for i in range(len(sigmas)):
                    if torch.any(labels == i):
                        test_loss_per_sigma[i] = torch.mean(loss[labels == i])

            def test_tb_hook():
                for i in range(len(sigmas)):
                    if test_loss_per_sigma[i] is not None:
                        tb_logger.add_scalar('test_loss_sigma_{}'.format(i), test_loss_per_sigma[i],
                                             global_step=step)

        else:
            hook = test_hook = None

            def tb_hook():
                pass

            def test_tb_hook():
                pass

        for epoch in range(start_epoch, self.config.training.n_epochs):
            for i, (X, y) in enumerate(dataloader):
                score.train()
                step += 1

                X = X.to(self.config.device)
                X = data_transform(self.config, X)

                loss = anneal_dsm_score_estimation(score, X, sigmas, None,
                                                   self.config.training.anneal_power,
                                                   hook)
                tb_logger.add_scalar('loss', loss, global_step=step)
                tb_hook()

                logging.info("step: {}, loss: {}".format(step, loss.item()))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if self.config.model.ema:
                    ema_helper.update(score)

                if step >= self.config.training.n_iters:
                    return 0

                if step % 100 == 0:
                    if self.config.model.ema:
                        test_score = ema_helper.ema_copy(score)
                    else:
                        test_score = score

                    test_score.eval()
                    try:
                        test_X, test_y = next(test_iter)
                    except StopIteration:
                        test_iter = iter(test_loader)
                        test_X, test_y = next(test_iter)

                    test_X = test_X.to(self.config.device)
                    test_X = data_transform(self.config, test_X)

                    with torch.no_grad():
                        test_dsm_loss = anneal_dsm_score_estimation(test_score, test_X, sigmas, None,
                                                                    self.config.training.anneal_power,
                                                                    hook=test_hook)
                        tb_logger.add_scalar('test_loss', test_dsm_loss, global_step=step)
                        test_tb_hook()
                        logging.info("step: {}, test_loss: {}".format(step, test_dsm_loss.item()))

                        del test_score

                if step % self.config.training.snapshot_freq == 0:
                    states = [
                        score.state_dict(),
                        optimizer.state_dict(),
                        epoch,
                        step,
                    ]
                    if self.config.model.ema:
                        states.append(ema_helper.state_dict())

                    torch.save(states, os.path.join(self.args.log_path, 'checkpoint_{}.pth'.format(step)))
                    torch.save(states, os.path.join(self.args.log_path, 'checkpoint.pth'))

                    if self.config.training.snapshot_sampling:
                        if self.config.model.ema:
                            test_score = ema_helper.ema_copy(score)
                        else:
                            test_score = score

                        test_score.eval()

                        ## Different part from NeurIPS 2019.
                        ## Random state will be affected because of sampling during training time.
                        init_samples = torch.rand(36, self.config.data.channels,
                                                  self.config.data.image_size, self.config.data.image_size,
                                                  device=self.config.device)
                        init_samples = data_transform(self.config, init_samples)

                        all_samples = anneal_Langevin_dynamics(init_samples, test_score, sigmas.cpu().numpy(),
                                                               self.config.sampling.n_steps_each,
                                                               self.config.sampling.step_lr,
                                                               final_only=True, verbose=True,
                                                               denoise=self.config.sampling.denoise)

                        sample = all_samples[-1].view(all_samples[-1].shape[0], self.config.data.channels,
                                                      self.config.data.image_size,
                                                      self.config.data.image_size)

                        sample = inverse_data_transform(self.config, sample)

                        image_grid = make_grid(sample, 6)
                        save_image(image_grid,
                                   os.path.join(self.args.log_sample_path, 'image_grid_{}.png'.format(step)))
                        torch.save(sample, os.path.join(self.args.log_sample_path, 'samples_{}.pth'.format(step)))

                        del test_score
                        del all_samples
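A note on Example #1: the loop maintains an exponential moving average of the score network through EMAHelper, which is not defined in the snippet. A minimal sketch of the interface it relies on (register/update/ema_copy/state_dict) could look like the following; this illustrates the pattern under that assumption and is not the repository's exact class.

import copy

class EMAHelper:
    def __init__(self, mu=0.999):
        self.mu = mu
        self.shadow = {}

    def register(self, module):
        # Snapshot the trainable parameters at registration time.
        for name, param in module.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self, module):
        # shadow <- mu * shadow + (1 - mu) * param, called once per optimizer step.
        for name, param in module.named_parameters():
            if param.requires_grad:
                self.shadow[name].mul_(self.mu).add_(param.data, alpha=1. - self.mu)

    def ema_copy(self, module):
        # Return a copy of the module carrying the averaged weights.
        module_copy = copy.deepcopy(module)
        state = module_copy.state_dict()
        state.update(self.shadow)
        module_copy.load_state_dict(state)
        return module_copy

    def state_dict(self):
        return self.shadow

    def load_state_dict(self, state_dict):
        self.shadow = state_dict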
Example #2
def evaluate(config, workdir, eval_folder="eval"):
    """Evaluate trained models.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints.
    eval_folder: The subfolder for storing evaluation results. Default to
      "eval".
  """
    # Create the eval_folder directory
    eval_dir = os.path.join(workdir, eval_folder)
    tf.io.gfile.makedirs(eval_dir)

    # Build data pipeline
    train_ds, eval_ds, _ = datasets.get_dataset(
        config,
        uniform_dequantization=config.data.uniform_dequantization,
        evaluation=True)

    # Create data normalizer and its inverse
    scaler = datasets.get_data_scaler(config)
    inverse_scaler = datasets.get_data_inverse_scaler(config)

    # Initialize model
    score_model = mutils.create_model(config)
    optimizer = losses.get_optimizer(config, score_model.parameters())
    ema = ExponentialMovingAverage(score_model.parameters(),
                                   decay=config.model.ema_rate)
    state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0)

    checkpoint_dir = os.path.join(workdir, "checkpoints")

    # Setup SDEs
    if config.training.sde.lower() == 'vpsde':
        sde = sde_lib.VPSDE(beta_min=config.model.beta_min,
                            beta_max=config.model.beta_max,
                            N=config.model.num_scales)
        sampling_eps = 1e-3
    elif config.training.sde.lower() == 'subvpsde':
        sde = sde_lib.subVPSDE(beta_min=config.model.beta_min,
                               beta_max=config.model.beta_max,
                               N=config.model.num_scales)
        sampling_eps = 1e-3
    elif config.training.sde.lower() == 'vesde':
        sde = sde_lib.VESDE(sigma_min=config.model.sigma_min,
                            sigma_max=config.model.sigma_max,
                            N=config.model.num_scales)
        sampling_eps = 1e-5
    else:
        raise NotImplementedError(f"SDE {config.training.sde} unknown.")

    # Create the one-step evaluation function when loss computation is enabled
    if config.eval.enable_loss:
        optimize_fn = losses.optimization_manager(config)
        continuous = config.training.continuous
        likelihood_weighting = config.training.likelihood_weighting

        reduce_mean = config.training.reduce_mean
        eval_step = losses.get_step_fn(
            sde,
            train=False,
            optimize_fn=optimize_fn,
            reduce_mean=reduce_mean,
            continuous=continuous,
            likelihood_weighting=likelihood_weighting)

    # Create data loaders for likelihood evaluation. Only evaluate on uniformly dequantized data
    train_ds_bpd, eval_ds_bpd, _ = datasets.get_dataset(
        config, uniform_dequantization=True, evaluation=True)
    if config.eval.bpd_dataset.lower() == 'train':
        ds_bpd = train_ds_bpd
        bpd_num_repeats = 1
    elif config.eval.bpd_dataset.lower() == 'test':
        # Go over the dataset 5 times when computing likelihood on the test dataset
        ds_bpd = eval_ds_bpd
        bpd_num_repeats = 5
    else:
        raise ValueError(
            f"No bpd dataset {config.eval.bpd_dataset} recognized.")

    # Build the likelihood computation function when likelihood is enabled
    if config.eval.enable_bpd:
        likelihood_fn = likelihood.get_likelihood_fn(sde, inverse_scaler)

    # Build the sampling function when sampling is enabled
    if config.eval.enable_sampling:
        sampling_shape = (config.eval.batch_size, config.data.num_channels,
                          config.data.image_size, config.data.image_size)
        sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape,
                                               inverse_scaler, sampling_eps)

    # Use InceptionV3 for images with resolution of 256 and above.
    inceptionv3 = config.data.image_size >= 256
    inception_model = evaluation.get_inception_model(inceptionv3=inceptionv3)

    begin_ckpt = config.eval.begin_ckpt
    logging.info("begin checkpoint: %d" % (begin_ckpt, ))
    for ckpt in range(begin_ckpt, config.eval.end_ckpt + 1):
        # Wait if the target checkpoint doesn't exist yet
        waiting_message_printed = False
        ckpt_filename = os.path.join(checkpoint_dir,
                                     "checkpoint_{}.pth".format(ckpt))
        while not tf.io.gfile.exists(ckpt_filename):
            if not waiting_message_printed:
                logging.warning("Waiting for the arrival of checkpoint_%d" %
                                (ckpt, ))
                waiting_message_printed = True
            time.sleep(60)

        # Wait for 2 additional mins in case the file exists but is not ready for reading
        ckpt_path = os.path.join(checkpoint_dir, f'checkpoint_{ckpt}.pth')
        try:
            state = restore_checkpoint(ckpt_path, state, device=config.device)
        except Exception:
            time.sleep(60)
            try:
                state = restore_checkpoint(ckpt_path,
                                           state,
                                           device=config.device)
            except Exception:
                time.sleep(120)
                state = restore_checkpoint(ckpt_path,
                                           state,
                                           device=config.device)
        ema.copy_to(score_model.parameters())
        # Compute the loss function on the full evaluation dataset if loss computation is enabled
        if config.eval.enable_loss:
            all_losses = []
            eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types
            for i, batch in enumerate(eval_iter):
                eval_batch = torch.from_numpy(batch['image']._numpy()).to(
                    config.device).float()
                eval_batch = eval_batch.permute(0, 3, 1, 2)
                eval_batch = scaler(eval_batch)
                eval_loss = eval_step(state, eval_batch)
                all_losses.append(eval_loss.item())
                if (i + 1) % 1000 == 0:
                    logging.info("Finished %dth step loss evaluation" %
                                 (i + 1))

            # Save loss values to disk or Google Cloud Storage
            all_losses = np.asarray(all_losses)
            with tf.io.gfile.GFile(
                    os.path.join(eval_dir, f"ckpt_{ckpt}_loss.npz"),
                    "wb") as fout:
                io_buffer = io.BytesIO()
                np.savez_compressed(io_buffer,
                                    all_losses=all_losses,
                                    mean_loss=all_losses.mean())
                fout.write(io_buffer.getvalue())

        # Compute log-likelihoods (bits/dim) if enabled
        if config.eval.enable_bpd:
            bpds = []
            for repeat in range(bpd_num_repeats):
                bpd_iter = iter(ds_bpd)  # pytype: disable=wrong-arg-types
                for batch_id in range(len(ds_bpd)):
                    batch = next(bpd_iter)
                    eval_batch = torch.from_numpy(batch['image']._numpy()).to(
                        config.device).float()
                    eval_batch = eval_batch.permute(0, 3, 1, 2)
                    eval_batch = scaler(eval_batch)
                    bpd = likelihood_fn(score_model, eval_batch)[0]
                    bpd = bpd.detach().cpu().numpy().reshape(-1)
                    bpds.extend(bpd)
                    logging.info(
                        "ckpt: %d, repeat: %d, batch: %d, mean bpd: %6f" %
                        (ckpt, repeat, batch_id, np.mean(np.asarray(bpds))))
                    bpd_round_id = batch_id + len(ds_bpd) * repeat
                    # Save bits/dim to disk or Google Cloud Storage
                    with tf.io.gfile.GFile(
                            os.path.join(
                                eval_dir,
                                f"{config.eval.bpd_dataset}_ckpt_{ckpt}_bpd_{bpd_round_id}.npz"
                            ), "wb") as fout:
                        io_buffer = io.BytesIO()
                        np.savez_compressed(io_buffer, bpd)
                        fout.write(io_buffer.getvalue())

        # Generate samples and compute IS/FID/KID when enabled
        if config.eval.enable_sampling:
            num_sampling_rounds = config.eval.num_samples // config.eval.batch_size + 1
            for r in range(num_sampling_rounds):
                logging.info("sampling -- ckpt: %d, round: %d" % (ckpt, r))

                # Directory to save samples. Different for each host to avoid writing conflicts
                this_sample_dir = os.path.join(eval_dir, f"ckpt_{ckpt}")
                tf.io.gfile.makedirs(this_sample_dir)
                samples, n = sampling_fn(score_model)
                samples = np.clip(
                    samples.permute(0, 2, 3, 1).cpu().numpy() * 255., 0,
                    255).astype(np.uint8)
                samples = samples.reshape(
                    (-1, config.data.image_size, config.data.image_size,
                     config.data.num_channels))
                # Write samples to disk or Google Cloud Storage
                with tf.io.gfile.GFile(
                        os.path.join(this_sample_dir, f"samples_{r}.npz"),
                        "wb") as fout:
                    io_buffer = io.BytesIO()
                    np.savez_compressed(io_buffer, samples=samples)
                    fout.write(io_buffer.getvalue())

                # Force garbage collection before calling TensorFlow code for Inception network
                gc.collect()
                latents = evaluation.run_inception_distributed(
                    samples, inception_model, inceptionv3=inceptionv3)
                # Force garbage collection again before returning to JAX code
                gc.collect()
                # Save latent representations from the Inception network to disk or Google Cloud Storage
                with tf.io.gfile.GFile(
                        os.path.join(this_sample_dir, f"statistics_{r}.npz"),
                        "wb") as fout:
                    io_buffer = io.BytesIO()
                    np.savez_compressed(io_buffer,
                                        pool_3=latents["pool_3"],
                                        logits=latents["logits"])
                    fout.write(io_buffer.getvalue())

            # Compute inception scores, FIDs and KIDs.
            # Load all statistics that have been previously computed and saved for each host
            all_logits = []
            all_pools = []
            this_sample_dir = os.path.join(eval_dir, f"ckpt_{ckpt}")
            stats = tf.io.gfile.glob(
                os.path.join(this_sample_dir, "statistics_*.npz"))
            for stat_file in stats:
                with tf.io.gfile.GFile(stat_file, "rb") as fin:
                    stat = np.load(fin)
                    if not inceptionv3:
                        all_logits.append(stat["logits"])
                    all_pools.append(stat["pool_3"])

            if not inceptionv3:
                all_logits = np.concatenate(all_logits,
                                            axis=0)[:config.eval.num_samples]
            all_pools = np.concatenate(all_pools,
                                       axis=0)[:config.eval.num_samples]

            # Load pre-computed dataset statistics.
            data_stats = evaluation.load_dataset_stats(config)
            data_pools = data_stats["pool_3"]

            # Compute FID/KID/IS on all samples together.
            if not inceptionv3:
                inception_score = tfgan.eval.classifier_score_from_logits(
                    all_logits)
            else:
                inception_score = -1

            fid = tfgan.eval.frechet_classifier_distance_from_activations(
                data_pools, all_pools)
            # Hack to get tfgan KID to work in eager execution.
            tf_data_pools = tf.convert_to_tensor(data_pools)
            tf_all_pools = tf.convert_to_tensor(all_pools)
            kid = tfgan.eval.kernel_classifier_distance_from_activations(
                tf_data_pools, tf_all_pools).numpy()
            del tf_data_pools, tf_all_pools

            logging.info(
                "ckpt-%d --- inception_score: %.6e, FID: %.6e, KID: %.6e" %
                (ckpt, inception_score, fid, kid))

            with tf.io.gfile.GFile(
                    os.path.join(eval_dir, f"report_{ckpt}.npz"), "wb") as f:
                io_buffer = io.BytesIO()
                np.savez_compressed(io_buffer,
                                    IS=inception_score,
                                    fid=fid,
                                    kid=kid)
                f.write(io_buffer.getvalue())
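restore_checkpoint in the loop above is assumed rather than shown. Given the state layout this example builds (optimizer/model/ema/step), a hedged sketch of the save/restore pair might be as follows; the repository's actual helpers may differ in details.

import logging
import torch
import tensorflow as tf

def save_checkpoint(ckpt_path, state):
    saved_state = {
        'optimizer': state['optimizer'].state_dict(),
        'model': state['model'].state_dict(),
        'ema': state['ema'].state_dict(),
        'step': state['step'],
    }
    torch.save(saved_state, ckpt_path)

def restore_checkpoint(ckpt_path, state, device):
    if not tf.io.gfile.exists(ckpt_path):
        logging.warning("No checkpoint found at %s. Returned the same state.", ckpt_path)
        return state
    loaded_state = torch.load(ckpt_path, map_location=device)
    state['optimizer'].load_state_dict(loaded_state['optimizer'])
    state['model'].load_state_dict(loaded_state['model'])
    state['ema'].load_state_dict(loaded_state['ema'])
    state['step'] = loaded_state['step']
    return state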
Example #3
def train(config, workdir):
    """Runs the training pipeline.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoint training will be resumed from the latest checkpoint.
  """

    # Create directories for experimental logs
    sample_dir = os.path.join(workdir, "samples")
    tf.io.gfile.makedirs(sample_dir)

    tb_dir = os.path.join(workdir, "tensorboard")
    tf.io.gfile.makedirs(tb_dir)
    writer = tensorboard.SummaryWriter(tb_dir)

    # Initialize model.
    score_model = mutils.create_model(config)
    ema = ExponentialMovingAverage(score_model.parameters(),
                                   decay=config.model.ema_rate)
    optimizer = losses.get_optimizer(config, score_model.parameters())
    state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0)

    # Create checkpoints directory
    checkpoint_dir = os.path.join(workdir, "checkpoints")
    # Intermediate checkpoints to resume training after pre-emption in cloud environments
    checkpoint_meta_dir = os.path.join(workdir, "checkpoints-meta",
                                       "checkpoint.pth")
    tf.io.gfile.makedirs(checkpoint_dir)
    tf.io.gfile.makedirs(os.path.dirname(checkpoint_meta_dir))
    # Resume training when intermediate checkpoints are detected
    state = restore_checkpoint(checkpoint_meta_dir, state, config.device)
    initial_step = int(state['step'])

    # Build data iterators
    train_ds, eval_ds, _ = datasets.get_dataset(
        config, uniform_dequantization=config.data.uniform_dequantization)
    train_iter = iter(train_ds)  # pytype: disable=wrong-arg-types
    eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types
    # Create data normalizer and its inverse
    scaler = datasets.get_data_scaler(config)
    inverse_scaler = datasets.get_data_inverse_scaler(config)

    # Setup SDEs
    if config.training.sde.lower() == 'vpsde':
        sde = sde_lib.VPSDE(beta_min=config.model.beta_min,
                            beta_max=config.model.beta_max,
                            N=config.model.num_scales)
        sampling_eps = 1e-3
    elif config.training.sde.lower() == 'subvpsde':
        sde = sde_lib.subVPSDE(beta_min=config.model.beta_min,
                               beta_max=config.model.beta_max,
                               N=config.model.num_scales)
        sampling_eps = 1e-3
    elif config.training.sde.lower() == 'vesde':
        sde = sde_lib.VESDE(sigma_min=config.model.sigma_min,
                            sigma_max=config.model.sigma_max,
                            N=config.model.num_scales)
        sampling_eps = 1e-5
    else:
        raise NotImplementedError(f"SDE {config.training.sde} unknown.")

    # Build one-step training and evaluation functions
    optimize_fn = losses.optimization_manager(config)
    continuous = config.training.continuous
    reduce_mean = config.training.reduce_mean
    likelihood_weighting = config.training.likelihood_weighting
    train_step_fn = losses.get_step_fn(
        sde,
        train=True,
        optimize_fn=optimize_fn,
        reduce_mean=reduce_mean,
        continuous=continuous,
        likelihood_weighting=likelihood_weighting)
    eval_step_fn = losses.get_step_fn(
        sde,
        train=False,
        optimize_fn=optimize_fn,
        reduce_mean=reduce_mean,
        continuous=continuous,
        likelihood_weighting=likelihood_weighting)

    # Building sampling functions
    if config.training.snapshot_sampling:
        sampling_shape = (config.training.batch_size, config.data.num_channels,
                          config.data.image_size, config.data.image_size)
        sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape,
                                               inverse_scaler, sampling_eps)

    num_train_steps = config.training.n_iters

    # In case there are multiple hosts (e.g., TPU pods), only log to host 0
    logging.info("Starting training loop at step %d." % (initial_step, ))

    for step in range(initial_step, num_train_steps + 1):
        # Convert data to JAX arrays and normalize them. Use ._numpy() to avoid copy.
        batch = torch.from_numpy(next(train_iter)['image']._numpy()).to(
            config.device).float()
        batch = batch.permute(0, 3, 1, 2)
        batch = scaler(batch)
        # Execute one training step
        loss = train_step_fn(state, batch)
        if step % config.training.log_freq == 0:
            logging.info("step: %d, training_loss: %.5e" % (step, loss.item()))
            writer.add_scalar("training_loss", loss, step)

        # Save a temporary checkpoint to resume training after pre-emption periodically
        if step != 0 and step % config.training.snapshot_freq_for_preemption == 0:
            save_checkpoint(checkpoint_meta_dir, state)

        # Report the loss on an evaluation dataset periodically
        if step % config.training.eval_freq == 0:
            eval_batch = torch.from_numpy(
                next(eval_iter)['image']._numpy()).to(config.device).float()
            eval_batch = eval_batch.permute(0, 3, 1, 2)
            eval_batch = scaler(eval_batch)
            eval_loss = eval_step_fn(state, eval_batch)
            logging.info("step: %d, eval_loss: %.5e" %
                         (step, eval_loss.item()))
            writer.add_scalar("eval_loss", eval_loss.item(), step)

        # Save a checkpoint periodically and generate samples if needed
        if (step != 0 and step % config.training.snapshot_freq == 0) or step == num_train_steps:
            # Save the checkpoint.
            save_step = step // config.training.snapshot_freq
            save_checkpoint(
                os.path.join(checkpoint_dir, f'checkpoint_{save_step}.pth'),
                state)

            # Generate and save samples
            if config.training.snapshot_sampling:
                ema.store(score_model.parameters())
                ema.copy_to(score_model.parameters())
                sample, n = sampling_fn(score_model)
                ema.restore(score_model.parameters())
                this_sample_dir = os.path.join(sample_dir,
                                               "iter_{}".format(step))
                tf.io.gfile.makedirs(this_sample_dir)
                nrow = int(np.sqrt(sample.shape[0]))
                image_grid = make_grid(sample, nrow, padding=2)
                sample = np.clip(
                    sample.permute(0, 2, 3, 1).cpu().numpy() * 255, 0,
                    255).astype(np.uint8)
                with tf.io.gfile.GFile(
                        os.path.join(this_sample_dir, "sample.np"),
                        "wb") as fout:
                    np.save(fout, sample)

                with tf.io.gfile.GFile(
                        os.path.join(this_sample_dir, "sample.png"),
                        "wb") as fout:
                    save_image(image_grid, fout)
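The SDE-selection block recurs verbatim in both train() and evaluate(); a small helper built only on the sde_lib constructors already used above removes the duplication (a sketch, not part of the original pipeline).

def build_sde(config):
    name = config.training.sde.lower()
    if name == 'vpsde':
        return sde_lib.VPSDE(beta_min=config.model.beta_min,
                             beta_max=config.model.beta_max,
                             N=config.model.num_scales), 1e-3
    if name == 'subvpsde':
        return sde_lib.subVPSDE(beta_min=config.model.beta_min,
                                beta_max=config.model.beta_max,
                                N=config.model.num_scales), 1e-3
    if name == 'vesde':
        return sde_lib.VESDE(sigma_min=config.model.sigma_min,
                             sigma_max=config.model.sigma_max,
                             N=config.model.num_scales), 1e-5
    raise NotImplementedError(f"SDE {config.training.sde} unknown.")

# Usage: sde, sampling_eps = build_sde(config)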
Example #4
config = configs.get_config()
sde = subVPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
sampling_eps = 1e-3

batch_size = 64  #@param {"type":"integer"}
config.training.batch_size = batch_size
config.eval.batch_size = batch_size

random_seed = 0 #@param {"type": "integer"}

sigmas = mutils.get_sigmas(config)
scaler = datasets.get_data_scaler(config)
inverse_scaler = datasets.get_data_inverse_scaler(config)
score_model = mutils.create_model(config)

optimizer = get_optimizer(config, score_model.parameters())
ema = ExponentialMovingAverage(score_model.parameters(),
                               decay=config.model.ema_rate)
state = dict(step=0, optimizer=optimizer,
             model=score_model, ema=ema)

state = restore_checkpoint(ckpt_filename, state, config.device)
ema.copy_to(score_model.parameters())

#@title Visualization code

def image_grid(x):
  size = config.data.image_size
  channels = config.data.num_channels
  img = x.reshape(-1, size, size, channels)
  w = int(np.sqrt(img.shape[0]))
  # Tile the batch into a single (w*size, w*size, channels) grid image.
  img = img.reshape((w, w, size, size, channels)).transpose((0, 2, 1, 3, 4)).reshape((w * size, w * size, channels))
  return img
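A natural companion for displaying the grid in a notebook is a show_samples helper. The sketch below is hypothetical (it is not part of the snippet above) and assumes x is an NCHW torch tensor of samples in [0, 1].

import matplotlib.pyplot as plt

def show_samples(x):
  # Convert NCHW torch samples to an NHWC numpy array, tile, and display.
  x = x.permute(0, 2, 3, 1).detach().cpu().numpy()
  img = image_grid(x)
  plt.figure(figsize=(8, 8))
  plt.axis('off')
  plt.imshow(img)
  plt.show()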
Example #5
 def get_optimizer(self, params, adv=False):
     return get_optimizer(self.config.optim, params, adv_opt=adv)
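In use, the wrapper simply forwards the optimizer sub-config; a hypothetical call site (runner, model, and discriminator are illustrative names, not from the source) might be:

optimizer = runner.get_optimizer(model.parameters())
# adv=True selects a separate optimizer configuration for an adversarial component.
adv_optimizer = runner.get_optimizer(discriminator.parameters(), adv=True)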
Example #6
def train(config, workdir):
    """Runs the training pipeline.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoint training will be resumed from the latest checkpoint.
  """

    # Create directories for experimental logs
    sample_dir = os.path.join(workdir, "samples")
    tf.io.gfile.makedirs(sample_dir)

    rng = jax.random.PRNGKey(config.seed)
    tb_dir = os.path.join(workdir, "tensorboard")
    tf.io.gfile.makedirs(tb_dir)
    if jax.host_id() == 0:
        writer = tensorboard.SummaryWriter(tb_dir)

    # Initialize model.
    rng, step_rng = jax.random.split(rng)
    score_model, init_model_state, initial_params = mutils.init_model(
        step_rng, config)
    optimizer = losses.get_optimizer(config).create(initial_params)
    state = mutils.State(step=0,
                         optimizer=optimizer,
                         lr=config.optim.lr,
                         model_state=init_model_state,
                         ema_rate=config.model.ema_rate,
                         params_ema=initial_params,
                         rng=rng)  # pytype: disable=wrong-keyword-args

    # Create checkpoints directory
    checkpoint_dir = os.path.join(workdir, "checkpoints")
    # Intermediate checkpoints to resume training after pre-emption in cloud environments
    checkpoint_meta_dir = os.path.join(workdir, "checkpoints-meta")
    tf.io.gfile.makedirs(checkpoint_dir)
    tf.io.gfile.makedirs(checkpoint_meta_dir)
    # Resume training when intermediate checkpoints are detected
    state = checkpoints.restore_checkpoint(checkpoint_meta_dir, state)
    # `state.step` is JAX integer on the GPU/TPU devices
    initial_step = int(state.step)
    rng = state.rng

    # Build data iterators
    train_ds, eval_ds, _ = datasets.get_dataset(
        config,
        additional_dim=config.training.n_jitted_steps,
        uniform_dequantization=config.data.uniform_dequantization)
    train_iter = iter(train_ds)  # pytype: disable=wrong-arg-types
    eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types
    # Create data normalizer and its inverse
    scaler = datasets.get_data_scaler(config)
    inverse_scaler = datasets.get_data_inverse_scaler(config)

    # Setup SDEs
    if config.training.sde.lower() == 'vpsde':
        sde = sde_lib.VPSDE(beta_min=config.model.beta_min,
                            beta_max=config.model.beta_max,
                            N=config.model.num_scales)
        sampling_eps = 1e-3
    elif config.training.sde.lower() == 'subvpsde':
        sde = sde_lib.subVPSDE(beta_min=config.model.beta_min,
                               beta_max=config.model.beta_max,
                               N=config.model.num_scales)
        sampling_eps = 1e-3
    elif config.training.sde.lower() == 'vesde':
        sde = sde_lib.VESDE(sigma_min=config.model.sigma_min,
                            sigma_max=config.model.sigma_max,
                            N=config.model.num_scales)
        sampling_eps = 1e-5
    else:
        raise NotImplementedError(f"SDE {config.training.sde} unknown.")

    # Build one-step training and evaluation functions
    optimize_fn = losses.optimization_manager(config)
    continuous = config.training.continuous
    reduce_mean = config.training.reduce_mean
    likelihood_weighting = config.training.likelihood_weighting
    train_step_fn = losses.get_step_fn(
        sde,
        score_model,
        train=True,
        optimize_fn=optimize_fn,
        reduce_mean=reduce_mean,
        continuous=continuous,
        likelihood_weighting=likelihood_weighting)
    # Pmap (and jit-compile) multiple training steps together for faster running
    p_train_step = jax.pmap(functools.partial(jax.lax.scan, train_step_fn),
                            axis_name='batch',
                            donate_argnums=1)
    eval_step_fn = losses.get_step_fn(
        sde,
        score_model,
        train=False,
        optimize_fn=optimize_fn,
        reduce_mean=reduce_mean,
        continuous=continuous,
        likelihood_weighting=likelihood_weighting)
    # Pmap (and jit-compile) multiple evaluation steps together for faster running
    p_eval_step = jax.pmap(functools.partial(jax.lax.scan, eval_step_fn),
                           axis_name='batch',
                           donate_argnums=1)

    # Building sampling functions
    if config.training.snapshot_sampling:
        sampling_shape = (config.training.batch_size //
                          jax.local_device_count(), config.data.image_size,
                          config.data.image_size, config.data.num_channels)
        sampling_fn = sampling.get_sampling_fn(config, sde, score_model,
                                               sampling_shape, inverse_scaler,
                                               sampling_eps)

    # Replicate the training state to run on multiple devices
    pstate = flax_utils.replicate(state)
    num_train_steps = config.training.n_iters

    # In case there are multiple hosts (e.g., TPU pods), only log to host 0
    if jax.host_id() == 0:
        logging.info("Starting training loop at step %d." % (initial_step, ))
    rng = jax.random.fold_in(rng, jax.host_id())

    # JIT multiple training steps together for faster training
    n_jitted_steps = config.training.n_jitted_steps
    # Must be divisible by the number of steps jitted together
    assert config.training.log_freq % n_jitted_steps == 0 and \
           config.training.snapshot_freq_for_preemption % n_jitted_steps == 0 and \
           config.training.eval_freq % n_jitted_steps == 0 and \
           config.training.snapshot_freq % n_jitted_steps == 0, "Missing logs or checkpoints!"

    for step in range(initial_step, num_train_steps + 1,
                      config.training.n_jitted_steps):
        # Convert data to JAX arrays and normalize them. Use ._numpy() to avoid copy.
        batch = jax.tree_map(lambda x: scaler(x._numpy()), next(train_iter))  # pylint: disable=protected-access
        rng, *next_rng = jax.random.split(rng,
                                          num=jax.local_device_count() + 1)
        next_rng = jnp.asarray(next_rng)
        # Execute one training step
        (_, pstate), ploss = p_train_step((next_rng, pstate), batch)
        loss = flax.jax_utils.unreplicate(ploss).mean()
        # Log to console, file and tensorboard on host 0
        if jax.host_id() == 0 and step % config.training.log_freq == 0:
            logging.info("step: %d, training_loss: %.5e" % (step, loss))
            writer.scalar("training_loss", loss, step)

        # Save a temporary checkpoint to resume training after pre-emption periodically
        if step != 0 and step % config.training.snapshot_freq_for_preemption == 0 and jax.host_id(
        ) == 0:
            saved_state = flax_utils.unreplicate(pstate)
            saved_state = saved_state.replace(rng=rng)
            checkpoints.save_checkpoint(
                checkpoint_meta_dir,
                saved_state,
                step=step // config.training.snapshot_freq_for_preemption,
                keep=1)

        # Report the loss on an evaluation dataset periodically
        if step % config.training.eval_freq == 0:
            eval_batch = jax.tree_map(lambda x: scaler(x._numpy()),
                                      next(eval_iter))  # pylint: disable=protected-access
            rng, *next_rng = jax.random.split(rng,
                                              num=jax.local_device_count() + 1)
            next_rng = jnp.asarray(next_rng)
            (_, _), peval_loss = p_eval_step((next_rng, pstate), eval_batch)
            eval_loss = flax.jax_utils.unreplicate(peval_loss).mean()
            if jax.host_id() == 0:
                logging.info("step: %d, eval_loss: %.5e" % (step, eval_loss))
                writer.scalar("eval_loss", eval_loss, step)

        # Save a checkpoint periodically and generate samples if needed
        if (step != 0 and step % config.training.snapshot_freq == 0) or step == num_train_steps:
            # Save the checkpoint.
            if jax.host_id() == 0:
                saved_state = flax_utils.unreplicate(pstate)
                saved_state = saved_state.replace(rng=rng)
                checkpoints.save_checkpoint(checkpoint_dir,
                                            saved_state,
                                            step=step //
                                            config.training.snapshot_freq,
                                            keep=np.inf)

            # Generate and save samples
            if config.training.snapshot_sampling:
                rng, *sample_rng = jax.random.split(
                    rng,
                    jax.local_device_count() + 1)
                sample_rng = jnp.asarray(sample_rng)
                sample, n = sampling_fn(sample_rng, pstate)
                this_sample_dir = os.path.join(
                    sample_dir, "iter_{}_host_{}".format(step, jax.host_id()))
                tf.io.gfile.makedirs(this_sample_dir)
                image_grid = sample.reshape((-1, *sample.shape[2:]))
                nrow = int(np.sqrt(image_grid.shape[0]))
                sample = np.clip(sample * 255, 0, 255).astype(np.uint8)
                with tf.io.gfile.GFile(
                        os.path.join(this_sample_dir, "sample.np"),
                        "wb") as fout:
                    np.save(fout, sample)

                with tf.io.gfile.GFile(
                        os.path.join(this_sample_dir, "sample.png"),
                        "wb") as fout:
                    utils.save_image(image_grid, fout, nrow=nrow, padding=2)
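The p_train_step construction above fuses config.training.n_jitted_steps consecutive steps with jax.lax.scan and replicates the fused step across local devices with jax.pmap. A self-contained toy version of the same pattern (shapes and the step body are illustrative only):

import functools
import jax
import jax.numpy as jnp

def toy_step(carry, batch):
    # One "training step": fold the batch mean into the carry and report it as a loss.
    carry = carry + batch.mean()
    return carry, carry

p_steps = jax.pmap(functools.partial(jax.lax.scan, toy_step), axis_name='batch')

n_dev = jax.local_device_count()
n_jitted_steps = 4
# Leading axes: [devices, n_jitted_steps, per-device batch].
batches = jnp.ones((n_dev, n_jitted_steps, 8))
init = jnp.zeros((n_dev,))
carry, losses = p_steps(init, batches)  # losses has shape [devices, n_jitted_steps]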
Example #7
def evaluate(config, workdir, eval_folder="eval"):
    """Evaluate trained models.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints.
    eval_folder: The subfolder for storing evaluation results. Default to
      "eval".
  """
    # Create the eval_folder directory
    eval_dir = os.path.join(workdir, eval_folder)
    tf.io.gfile.makedirs(eval_dir)

    rng = jax.random.PRNGKey(config.seed + 1)

    # Build data pipeline
    train_ds, eval_ds, _ = datasets.get_dataset(
        config,
        additional_dim=1,
        uniform_dequantization=config.data.uniform_dequantization,
        evaluation=True)

    # Create data normalizer and its inverse
    scaler = datasets.get_data_scaler(config)
    inverse_scaler = datasets.get_data_inverse_scaler(config)

    # Initialize model
    rng, model_rng = jax.random.split(rng)
    score_model, init_model_state, initial_params = mutils.init_model(
        model_rng, config)
    optimizer = losses.get_optimizer(config).create(initial_params)
    state = mutils.State(step=0,
                         optimizer=optimizer,
                         lr=config.optim.lr,
                         model_state=init_model_state,
                         ema_rate=config.model.ema_rate,
                         params_ema=initial_params,
                         rng=rng)  # pytype: disable=wrong-keyword-args

    checkpoint_dir = os.path.join(workdir, "checkpoints")

    # Setup SDEs
    if config.training.sde.lower() == 'vpsde':
        sde = sde_lib.VPSDE(beta_min=config.model.beta_min,
                            beta_max=config.model.beta_max,
                            N=config.model.num_scales)
        sampling_eps = 1e-3
    elif config.training.sde.lower() == 'subvpsde':
        sde = sde_lib.subVPSDE(beta_min=config.model.beta_min,
                               beta_max=config.model.beta_max,
                               N=config.model.num_scales)
        sampling_eps = 1e-3
    elif config.training.sde.lower() == 'vesde':
        sde = sde_lib.VESDE(sigma_min=config.model.sigma_min,
                            sigma_max=config.model.sigma_max,
                            N=config.model.num_scales)
        sampling_eps = 1e-5
    else:
        raise NotImplementedError(f"SDE {config.training.sde} unknown.")

    # Create the one-step evaluation function when loss computation is enabled
    if config.eval.enable_loss:
        optimize_fn = losses.optimization_manager(config)
        continuous = config.training.continuous
        likelihood_weighting = config.training.likelihood_weighting

        reduce_mean = config.training.reduce_mean
        eval_step = losses.get_step_fn(
            sde,
            score_model,
            train=False,
            optimize_fn=optimize_fn,
            reduce_mean=reduce_mean,
            continuous=continuous,
            likelihood_weighting=likelihood_weighting)
        # Pmap (and jit-compile) multiple evaluation steps together for faster execution
        p_eval_step = jax.pmap(functools.partial(jax.lax.scan, eval_step),
                               axis_name='batch',
                               donate_argnums=1)

    # Create data loaders for likelihood evaluation. Only evaluate on uniformly dequantized data
    train_ds_bpd, eval_ds_bpd, _ = datasets.get_dataset(
        config,
        additional_dim=None,
        uniform_dequantization=True,
        evaluation=True)
    if config.eval.bpd_dataset.lower() == 'train':
        ds_bpd = train_ds_bpd
        bpd_num_repeats = 1
    elif config.eval.bpd_dataset.lower() == 'test':
        # Go over the dataset 5 times when computing likelihood on the test dataset
        ds_bpd = eval_ds_bpd
        bpd_num_repeats = 5
    else:
        raise ValueError(
            f"No bpd dataset {config.eval.bpd_dataset} recognized.")

    # Build the likelihood computation function when likelihood is enabled
    if config.eval.enable_bpd:
        likelihood_fn = likelihood.get_likelihood_fn(sde, score_model,
                                                     inverse_scaler)

    # Build the sampling function when sampling is enabled
    if config.eval.enable_sampling:
        sampling_shape = (config.eval.batch_size // jax.local_device_count(),
                          config.data.image_size, config.data.image_size,
                          config.data.num_channels)
        sampling_fn = sampling.get_sampling_fn(config, sde, score_model,
                                               sampling_shape, inverse_scaler,
                                               sampling_eps)

    # Create different random states for different hosts in a multi-host environment (e.g., TPU pods)
    rng = jax.random.fold_in(rng, jax.host_id())

    # A data class for storing intermediate results to resume evaluation after pre-emption
    @flax.struct.dataclass
    class EvalMeta:
        ckpt_id: int
        sampling_round_id: int
        bpd_round_id: int
        rng: Any

    # Add one additional round to get the exact number of samples as required.
    num_sampling_rounds = config.eval.num_samples // config.eval.batch_size + 1
    num_bpd_rounds = len(ds_bpd) * bpd_num_repeats

    # Restore evaluation after pre-emption
    eval_meta = EvalMeta(ckpt_id=config.eval.begin_ckpt,
                         sampling_round_id=-1,
                         bpd_round_id=-1,
                         rng=rng)
    eval_meta = checkpoints.restore_checkpoint(eval_dir,
                                               eval_meta,
                                               step=None,
                                               prefix=f"meta_{jax.host_id()}_")

    if eval_meta.bpd_round_id < num_bpd_rounds - 1:
        begin_ckpt = eval_meta.ckpt_id
        begin_bpd_round = eval_meta.bpd_round_id + 1
        begin_sampling_round = 0

    elif eval_meta.sampling_round_id < num_sampling_rounds - 1:
        begin_ckpt = eval_meta.ckpt_id
        begin_bpd_round = num_bpd_rounds
        begin_sampling_round = eval_meta.sampling_round_id + 1

    else:
        begin_ckpt = eval_meta.ckpt_id + 1
        begin_bpd_round = 0
        begin_sampling_round = 0

    rng = eval_meta.rng

    # Use InceptionV3 for images with resolution of 256 and above.
    inceptionv3 = config.data.image_size >= 256
    inception_model = evaluation.get_inception_model(inceptionv3=inceptionv3)

    logging.info("begin checkpoint: %d" % (begin_ckpt, ))
    for ckpt in range(begin_ckpt, config.eval.end_ckpt + 1):
        # Wait if the target checkpoint doesn't exist yet
        waiting_message_printed = False
        ckpt_filename = os.path.join(checkpoint_dir,
                                     "checkpoint_{}".format(ckpt))
        while not tf.io.gfile.exists(ckpt_filename):
            if not waiting_message_printed and jax.host_id() == 0:
                logging.warning("Waiting for the arrival of checkpoint_%d" %
                                (ckpt, ))
                waiting_message_printed = True
            time.sleep(60)

        # Wait for 2 additional mins in case the file exists but is not ready for reading
        try:
            state = checkpoints.restore_checkpoint(checkpoint_dir,
                                                   state,
                                                   step=ckpt)
        except Exception:
            time.sleep(60)
            try:
                state = checkpoints.restore_checkpoint(checkpoint_dir,
                                                       state,
                                                       step=ckpt)
            except Exception:
                time.sleep(120)
                state = checkpoints.restore_checkpoint(checkpoint_dir,
                                                       state,
                                                       step=ckpt)

        # Replicate the training state for executing on multiple devices
        pstate = flax.jax_utils.replicate(state)
        # Compute the loss function on the full evaluation dataset if loss computation is enabled
        if config.eval.enable_loss:
            all_losses = []
            eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types
            for i, batch in enumerate(eval_iter):
                eval_batch = jax.tree_map(lambda x: scaler(x._numpy()), batch)  # pylint: disable=protected-access
                rng, *next_rng = jax.random.split(
                    rng, num=jax.local_device_count() + 1)
                next_rng = jnp.asarray(next_rng)
                (_, _), p_eval_loss = p_eval_step((next_rng, pstate),
                                                  eval_batch)
                eval_loss = flax.jax_utils.unreplicate(p_eval_loss)
                all_losses.extend(eval_loss)
                if (i + 1) % 1000 == 0 and jax.host_id() == 0:
                    logging.info("Finished %dth step loss evaluation" %
                                 (i + 1))

            # Save loss values to disk or Google Cloud Storage
            all_losses = jnp.asarray(all_losses)
            with tf.io.gfile.GFile(
                    os.path.join(eval_dir, f"ckpt_{ckpt}_loss.npz"),
                    "wb") as fout:
                io_buffer = io.BytesIO()
                np.savez_compressed(io_buffer,
                                    all_losses=all_losses,
                                    mean_loss=all_losses.mean())
                fout.write(io_buffer.getvalue())

        # Compute log-likelihoods (bits/dim) if enabled
        if config.eval.enable_bpd:
            bpds = []
            begin_repeat_id = begin_bpd_round // len(ds_bpd)
            begin_batch_id = begin_bpd_round % len(ds_bpd)
            # Repeat multiple times to reduce variance when needed
            for repeat in range(begin_repeat_id, bpd_num_repeats):
                bpd_iter = iter(ds_bpd)  # pytype: disable=wrong-arg-types
                for _ in range(begin_batch_id):
                    next(bpd_iter)
                for batch_id in range(begin_batch_id, len(ds_bpd)):
                    batch = next(bpd_iter)
                    eval_batch = jax.tree_map(lambda x: scaler(x._numpy()),
                                              batch)
                    rng, *step_rng = jax.random.split(
                        rng,
                        jax.local_device_count() + 1)
                    step_rng = jnp.asarray(step_rng)
                    bpd = likelihood_fn(step_rng, pstate,
                                        eval_batch['image'])[0]
                    bpd = bpd.reshape(-1)
                    bpds.extend(bpd)
                    logging.info(
                        "ckpt: %d, repeat: %d, batch: %d, mean bpd: %6f" %
                        (ckpt, repeat, batch_id, jnp.mean(jnp.asarray(bpds))))
                    bpd_round_id = batch_id + len(ds_bpd) * repeat
                    # Save bits/dim to disk or Google Cloud Storage
                    with tf.io.gfile.GFile(
                            os.path.join(
                                eval_dir,
                                f"{config.eval.bpd_dataset}_ckpt_{ckpt}_bpd_{bpd_round_id}.npz"
                            ), "wb") as fout:
                        io_buffer = io.BytesIO()
                        np.savez_compressed(io_buffer, bpd)
                        fout.write(io_buffer.getvalue())

                    eval_meta = eval_meta.replace(ckpt_id=ckpt,
                                                  bpd_round_id=bpd_round_id,
                                                  rng=rng)
                    # Save intermediate states to resume evaluation after pre-emption
                    checkpoints.save_checkpoint(
                        eval_dir,
                        eval_meta,
                        step=ckpt * (num_sampling_rounds + num_bpd_rounds) +
                        bpd_round_id,
                        keep=1,
                        prefix=f"meta_{jax.host_id()}_")
        else:
            # Skip likelihood computation and save intermediate states for pre-emption
            eval_meta = eval_meta.replace(ckpt_id=ckpt,
                                          bpd_round_id=num_bpd_rounds - 1)
            checkpoints.save_checkpoint(
                eval_dir,
                eval_meta,
                step=ckpt * (num_sampling_rounds + num_bpd_rounds) +
                num_bpd_rounds - 1,
                keep=1,
                prefix=f"meta_{jax.host_id()}_")

        # Generate samples and compute IS/FID/KID when enabled
        if config.eval.enable_sampling:
            state = jax.device_put(state)
            # Run sample generation for multiple rounds to create enough samples
            # Designed to be pre-emption safe. Automatically resumes when interrupted
            for r in range(begin_sampling_round, num_sampling_rounds):
                if jax.host_id() == 0:
                    logging.info("sampling -- ckpt: %d, round: %d" % (ckpt, r))

                # Directory to save samples. Different for each host to avoid writing conflicts
                this_sample_dir = os.path.join(
                    eval_dir, f"ckpt_{ckpt}_host_{jax.host_id()}")
                tf.io.gfile.makedirs(this_sample_dir)

                rng, *sample_rng = jax.random.split(
                    rng,
                    jax.local_device_count() + 1)
                sample_rng = jnp.asarray(sample_rng)
                samples, n = sampling_fn(sample_rng, pstate)
                samples = np.clip(samples * 255., 0, 255).astype(np.uint8)
                samples = samples.reshape(
                    (-1, config.data.image_size, config.data.image_size,
                     config.data.num_channels))
                # Write samples to disk or Google Cloud Storage
                with tf.io.gfile.GFile(
                        os.path.join(this_sample_dir, f"samples_{r}.npz"),
                        "wb") as fout:
                    io_buffer = io.BytesIO()
                    np.savez_compressed(io_buffer, samples=samples)
                    fout.write(io_buffer.getvalue())

                # Force garbage collection before calling TensorFlow code for Inception network
                gc.collect()
                latents = evaluation.run_inception_distributed(
                    samples, inception_model, inceptionv3=inceptionv3)
                # Force garbage collection again before returning to JAX code
                gc.collect()
                # Save latent representations from the Inception network to disk or Google Cloud Storage
                with tf.io.gfile.GFile(
                        os.path.join(this_sample_dir, f"statistics_{r}.npz"),
                        "wb") as fout:
                    io_buffer = io.BytesIO()
                    np.savez_compressed(io_buffer,
                                        pool_3=latents["pool_3"],
                                        logits=latents["logits"])
                    fout.write(io_buffer.getvalue())

                # Update the intermediate evaluation state
                eval_meta = eval_meta.replace(ckpt_id=ckpt,
                                              sampling_round_id=r,
                                              rng=rng)
                # Save an intermediate checkpoint directly if not the last round.
                # Otherwise save eval_meta after computing the Inception scores and FIDs
                if r < num_sampling_rounds - 1:
                    checkpoints.save_checkpoint(
                        eval_dir,
                        eval_meta,
                        step=ckpt * (num_sampling_rounds + num_bpd_rounds) +
                        r + num_bpd_rounds,
                        keep=1,
                        prefix=f"meta_{jax.host_id()}_")

            # Compute inception scores, FIDs and KIDs.
            if jax.host_id() == 0:
                # Load all statistics that have been previously computed and saved for each host
                all_logits = []
                all_pools = []
                for host in range(jax.host_count()):
                    this_sample_dir = os.path.join(eval_dir,
                                                   f"ckpt_{ckpt}_host_{host}")

                    stats = tf.io.gfile.glob(
                        os.path.join(this_sample_dir, "statistics_*.npz"))
                    wait_message = False
                    while len(stats) < num_sampling_rounds:
                        if not wait_message:
                            logging.warning(
                                "Waiting for statistics on host %d" % (host, ))
                            wait_message = True
                        stats = tf.io.gfile.glob(
                            os.path.join(this_sample_dir, "statistics_*.npz"))
                        time.sleep(30)

                    for stat_file in stats:
                        with tf.io.gfile.GFile(stat_file, "rb") as fin:
                            stat = np.load(fin)
                            if not inceptionv3:
                                all_logits.append(stat["logits"])
                            all_pools.append(stat["pool_3"])

                if not inceptionv3:
                    all_logits = np.concatenate(
                        all_logits, axis=0)[:config.eval.num_samples]
                all_pools = np.concatenate(all_pools,
                                           axis=0)[:config.eval.num_samples]

                # Load pre-computed dataset statistics.
                data_stats = evaluation.load_dataset_stats(config)
                data_pools = data_stats["pool_3"]

                # Compute FID/KID/IS on all samples together.
                if not inceptionv3:
                    inception_score = tfgan.eval.classifier_score_from_logits(
                        all_logits)
                else:
                    inception_score = -1

                fid = tfgan.eval.frechet_classifier_distance_from_activations(
                    data_pools, all_pools)
                # Hack to get tfgan KID to work in eager execution.
                tf_data_pools = tf.convert_to_tensor(data_pools)
                tf_all_pools = tf.convert_to_tensor(all_pools)
                kid = tfgan.eval.kernel_classifier_distance_from_activations(
                    tf_data_pools, tf_all_pools).numpy()
                del tf_data_pools, tf_all_pools

                logging.info(
                    "ckpt-%d --- inception_score: %.6e, FID: %.6e, KID: %.6e" %
                    (ckpt, inception_score, fid, kid))

                with tf.io.gfile.GFile(
                        os.path.join(eval_dir, f"report_{ckpt}.npz"),
                        "wb") as f:
                    io_buffer = io.BytesIO()
                    np.savez_compressed(io_buffer,
                                        IS=inception_score,
                                        fid=fid,
                                        kid=kid)
                    f.write(io_buffer.getvalue())
            else:
                # For host_id() != 0.
                # Use file existence to emulate synchronization across hosts
                while not tf.io.gfile.exists(
                        os.path.join(eval_dir, f"report_{ckpt}.npz")):
                    time.sleep(1.)

            # Save eval_meta after computing IS/KID/FID to mark the end of evaluation for this checkpoint
            checkpoints.save_checkpoint(
                eval_dir,
                eval_meta,
                step=ckpt * (num_sampling_rounds + num_bpd_rounds) + r +
                num_bpd_rounds,
                keep=1,
                prefix=f"meta_{jax.host_id()}_")

        else:
            # Skip sampling and save intermediate evaluation states for pre-emption
            eval_meta = eval_meta.replace(
                ckpt_id=ckpt,
                sampling_round_id=num_sampling_rounds - 1,
                rng=rng)
            checkpoints.save_checkpoint(
                eval_dir,
                eval_meta,
                step=ckpt * (num_sampling_rounds + num_bpd_rounds) +
                num_sampling_rounds - 1 + num_bpd_rounds,
                keep=1,
                prefix=f"meta_{jax.host_id()}_")

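        # Reset the resumption offsets so the next checkpoint starts its
        # BPD and sampling rounds from the beginning.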
        begin_bpd_round = 0
        begin_sampling_round = 0

    # Remove all meta files after finishing evaluation
    meta_files = tf.io.gfile.glob(
        os.path.join(eval_dir, f"meta_{jax.host_id()}_*"))
    for file in meta_files:
        tf.io.gfile.remove(file)
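
A minimal, self-contained sketch of the three tensorflow_gan metric calls used above, run here on random activations so it executes outside the evaluation pipeline. The shapes (2048-d pool_3 features, 1008-d logits) match the default Inception graph the example assumes; the sample count and values are made up for illustration.

import numpy as np
import tensorflow as tf
import tensorflow_gan as tfgan

rng = np.random.default_rng(0)
# Stand-ins for the dataset/sample statistics loaded from the .npz files above.
data_pools = tf.convert_to_tensor(
    rng.normal(size=(64, 2048)).astype(np.float32))
sample_pools = tf.convert_to_tensor(
    rng.normal(size=(64, 2048)).astype(np.float32))
sample_logits = tf.convert_to_tensor(
    rng.normal(size=(64, 1008)).astype(np.float32))

# Same three calls as the example: IS from logits, FID and KID from pools.
inception_score = tfgan.eval.classifier_score_from_logits(sample_logits)
fid = tfgan.eval.frechet_classifier_distance_from_activations(
    data_pools, sample_pools)
kid = tfgan.eval.kernel_classifier_distance_from_activations(
    data_pools, sample_pools)
print(inception_score.numpy(), fid.numpy(), kid.numpy())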
Example #8
0
def train(model, dataset, data_augmentation, epochs, batch_size, beta, M,
          initial_lr, lr_schedule, strategy, output_dir, class_loss, cov_type):

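    # Keep the original "<network>/<architecture>" string: `model` is rebound
    # to the instantiated network below, and the summary records this value.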
    model_conf = model

    train_set, test_set, small_set = datasets.get_dataset(dataset)

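    # Shuffle-buffer sizes for the train/test tf.data pipelines.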
    TRAIN_BUF, TEST_BUF = datasets.dataset_size[dataset]

    if data_augmentation:
        base_dataset = dataset.split("-")[0]
        print(f"Using image generator params from {base_dataset}")
        with open(f"./datasets/image-generator-config/{base_dataset}.yml",
                  "r") as fh:
            params = yaml.safe_load(fh)
            print(params)
        train_dataset = tf.keras.preprocessing.image.ImageDataGenerator(
            **params)
        train_dataset.fit(train_set[0])

    else:
        train_dataset = tf.data.Dataset.from_tensor_slices(train_set) \
            .shuffle(TRAIN_BUF).batch(batch_size)

    test_dataset = tf.data.Dataset.from_tensor_slices(test_set) \
        .shuffle(TEST_BUF).batch(batch_size)

    print(
        f"Training with {model} on {dataset} for {epochs} epochs (lr={initial_lr}, schedule={lr_schedule})"
    )
    print(
        f"Params: batch-size={batch_size} beta={beta} M={M} lr={initial_lr} strategy={strategy}"
    )

    optimizers, strategy_name, opt_params = losses.get_optimizer(
        strategy, initial_lr, lr_schedule, dataset, batch_size)

    network_name, architecture = model.split("/")
    experiment_name = utils.get_experiment_name(
        f"{network_name}-{class_loss}-{cov_type}-{dataset}")

    print(f"Experiment name: {experiment_name}")
    artifact_dir = f"{output_dir}/{experiment_name}"
    print(f"Artifact directory: {artifact_dir}")

    train_log_dir = f"{artifact_dir}/logs/train"
    test_log_dir = f"{artifact_dir}/logs/test"

    train_summary_writer = tf.summary.create_file_writer(train_log_dir)
    test_summary_writer = tf.summary.create_file_writer(test_log_dir)

    # Instantiate model
    architecture = utils.parse_arch(architecture)

    model = nets.get_network(network_name)(architecture,
                                           datasets.input_dims[dataset],
                                           datasets.num_classes[dataset],
                                           cov_type,
                                           beta=beta,
                                           M=M)

    model.build(input_shape=(batch_size, *datasets.input_dims[dataset]))
    model.summary()

    print(f"Class loss: {class_loss}")
    model.class_loss = getattr(losses, f"compute_{class_loss}_class_loss")

    lr_labels = [f"lr_{i}" for i in range(len(optimizers))]

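    # Strategies named "algo2/..." use the alternate training step;
    # everything else falls back to train_algo1.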
    train_step = (train_algo2
                  if strategy.split("/")[0] == "algo2" else train_algo1)

    print("Using trainstep: ", train_step)

    train_start_time = time.time()

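    # Gradient steps per epoch, counting the final partial batch.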
    steps_per_epoch = int(np.ceil(train_set[0].shape[0] / batch_size))

    for epoch in range(1, epochs + 1):
        start_time = time.time()

        print(f"Epoch {epoch}")

        batches = (train_dataset.flow(train_set[0],
                                      train_set[1],
                                      batch_size=batch_size)
                   if data_augmentation else train_dataset)
        m, am = train_step(model, optimizers, batches, train_summary_writer,
                           M, lr_labels, strategy_name, opt_params, epoch,
                           steps_per_epoch)

        m = m.result().numpy()
        am = am.result().numpy()

        print(utils.format_metrics("Train", m, am))

        tfutils.log_metrics(train_summary_writer, metric_labels, m, epoch)
        tfutils.log_metrics(train_summary_writer, acc_labels, am, epoch)

        tfutils.log_metrics(
            train_summary_writer, lr_labels,
            [opt._decayed_lr(tf.float32) for opt in optimizers], epoch)

        train_metrics = m.astype(float).tolist() + am.astype(float).tolist()
        end_time = time.time()

        test_metrics = evaluate(model, test_dataset, test_summary_writer, M,
                                epoch)

        print(f"--- Time elapse for current epoch {end_time - start_time}")

    train_end_time = time.time()
    elapsed_time = (train_end_time - train_start_time) / 60.

    test_metrics_dict = dict(zip(metric_labels + acc_labels, test_metrics))
    summary = dict(
        dataset=dataset,
        model=model_conf,
        strategy=strategy,
        beta=beta,
        epoch=epoch,
        M=M,
        lr=initial_lr,
        lr_schedule=lr_schedule,
        metrics=dict(
            train=dict(zip(metric_labels + acc_labels, train_metrics)),
            test=test_metrics_dict,
        ),
        class_loss=class_loss,
        cov_type=cov_type,
        batch_size=batch_size,
        elapsed_time=elapsed_time,  # in minutes
        test_accuracy_L12=test_metrics_dict["accuracy_L12"],
        data_augmentation=data_augmentation)

    if model.latent_dim == 2:
        plot_helper.plot_2d_representation(
            model,
            small_set,
            title="Epoch=%d Strategy=%s  Beta=%f M=%f" %
            (epoch, strategy, beta, M),
            path=f"{artifact_dir}/latent-representation.png")

    with train_summary_writer.as_default():
        tf.summary.text("setting",
                        json.dumps(summary, sort_keys=True, indent=4),
                        step=0)

    with open(f"{artifact_dir}/summary.yml", 'w') as f:
        print(summary)
        yaml.dump(summary, f, default_flow_style=False)

    model.save_weights(f"{artifact_dir}/model")

    print(f"Training took {elapsed_time:.4f} minutes")
    print(f"Please see artifact at: {artifact_dir}")