def loss_fn(model):
  if train:
    with nn.stateful(state.model_state) as new_model_state:
      with nn.stochastic(run_rng):
        if not class_conditional:
          scores = model(perturbed_data, labels, train=train)
        else:
          scores = model(perturbed_data, labels, y=class_labels, train=train)
  else:
    with nn.stateful(state.model_state, mutable=False):
      with nn.stochastic(run_rng):
        if not class_conditional:
          scores = model(perturbed_data, labels, train=train)
        else:
          scores = model(perturbed_data, labels, y=class_labels, train=train)
    new_model_state = state.model_state

  scores = scores.reshape((scores.shape[0], -1))
  target = -1 / (used_sigmas**2) * noise
  target = target.reshape((target.shape[0], -1))
  losses = 1 / 2. * ((scores - target)**2).sum(
      axis=-1) * used_sigmas.squeeze()**anneal_power
  loss = jnp.mean(losses)
  if loss_per_sigma:
    return loss, new_model_state, losses
  else:
    return loss, new_model_state
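# Hedged usage sketch (not from the original file): one way the
# `perturbed_data`, `labels`, `noise`, and `used_sigmas` consumed by
# `loss_fn` above could be built for NCSN-style annealed score matching.
# All names below are assumptions for illustration; `jax`/`jnp` are the
# imports used throughout this file.
import jax
import jax.numpy as jnp


def make_ncsn_inputs(rng, batch, sigmas):
  """Perturbs `batch` with per-example noise levels sampled from `sigmas`."""
  label_rng, noise_rng = jax.random.split(rng)
  # One noise-level index per example; these double as the model's `labels`.
  labels = jax.random.randint(label_rng, (batch.shape[0],), 0, len(sigmas))
  used_sigmas = sigmas[labels].reshape(
      (batch.shape[0],) + (1,) * (len(batch.shape) - 1))
  noise = jax.random.normal(noise_rng, batch.shape) * used_sigmas
  # The score target in `loss_fn` is then -noise / used_sigmas**2.
  return batch + noise, labels, noise, used_sigmas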
def loss_fn(model):
  if train:
    with nn.stateful(state.model_state) as new_model_state:
      with nn.stochastic(run_rng):
        scores = model(perturbed_data, T, train=train)
  else:
    with nn.stateful(state.model_state, mutable=False):
      with nn.stochastic(run_rng):
        scores = model(perturbed_data, T, train=train)
    new_model_state = state.model_state

  scores = scores.reshape((scores.shape[0], -1))
  target = noise.reshape((noise.shape[0], -1))
  loss = jnp.mean((scores - target)**2)
  return loss, new_model_state
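# Hedged usage sketch (assumed, not the original `ddpm_loss` code): how the
# `perturbed_data`, `T`, and `noise` for the DDPM `loss_fn` above could be
# produced. The `sqrt_alphas_bar` / `sqrt_1m_alphas_bar` keys are an
# assumption about what `ddpm_params` holds.
def make_ddpm_inputs(rng, batch, ddpm_params, num_timesteps=1000):
  t_rng, noise_rng = jax.random.split(rng)
  T = jax.random.randint(t_rng, (batch.shape[0],), 0, num_timesteps)
  noise = jax.random.normal(noise_rng, batch.shape)
  bcast = (batch.shape[0],) + (1,) * (len(batch.shape) - 1)
  sqrt_ab = ddpm_params['sqrt_alphas_bar'][T].reshape(bcast)
  sqrt_1m_ab = ddpm_params['sqrt_1m_alphas_bar'][T].reshape(bcast)
  # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps; the model
  # regresses eps, which matches `target = noise` in `loss_fn` above.
  return sqrt_ab * batch + sqrt_1m_ab * noise, T, noise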
def impl_loss_fn(model_params):
  with nn.stochastic(rng), nn.stateful(
      state.model_state) as new_model_state:
    logits, stats = module.call(model_params, batch["image"])
  losses = loss_fn if isinstance(loss_fn, (list, tuple)) else [loss_fn]
  loss = sum(l(logits, batch["label"], stats) for l in losses)
  return loss, (logits, new_model_state, stats)
def dann_forward_pass(
    batch,
    flax_module,
    model_state,
    input_layer_key,
    train,
):
  """Forward pass of flax_module for DANN.

  Args:
    batch: dict; A batch of examples.
    flax_module: flax.nn.Model; Flax model.
    model_state: flax.nn.Collection; Model state.
    input_layer_key: str; Which layer the input should be plugged in.
    train: bool; Train flag.

  Returns:
    logits, hidden activations, activations of key layer, and new model state.
  """
  with nn.stateful(model_state) as new_model_state:
    logits, reps, reps_key, domain_logits = flax_module(
        batch,
        train=train,
        return_activations=True,
        input_layer_key=input_layer_key,
        discriminator=True)
  key_reps = reps[reps_key]

  return logits, domain_logits, reps, key_reps, new_model_state
def eval_step(self, train_state, batch, env_id, all_env_ids):
  """Runs a single step of evaluation.

  Args:
    train_state: TrainState; The state of training, including the current
      global_step, model_state, rng, and optimizer.
    batch: A single batch of data.
    env_id: int; Code of the evaluation environment.
    all_env_ids: list(int); Ids of all evaluation environments.

  Returns:
    Calculated metrics.
  """
  flax_model = train_state.optimizer.target
  inputs = pipeline_utils.get_multi_env_inputs(batch, 'inputs')

  with nn.stateful(train_state.model_state, mutable=False):
    env_logits = pipeline_utils.vmapped_flax_module_eval(flax_model, inputs)

  if env_id >= 0:
    metrics = self.metrics_fn(env_logits, batch, [env_id], flax_model)
  else:
    metrics = self.metrics_fn(env_logits, batch, all_env_ids, flax_model)

  return metrics
def create_model(module, input_shape, rng):
  """Instantiates the model."""
  model_rng, init_rng = jax.random.split(rng)
  with nn.stochastic(model_rng), nn.stateful() as init_state:
    x = jnp.ones(input_shape, dtype=jnp.float32)
    _, init_params = module.init(init_rng, x)
  model = nn.Model(module, init_params)
  return model, init_params, init_state
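# Hypothetical usage of `create_model` (a sketch; `MLP` and the input shape
# are illustrative, not part of the original file, and assume the legacy
# `flax.nn` Module API used throughout this file):
class MLP(nn.Module):  # hypothetical module for illustration

  def apply(self, x, num_outputs=10):
    x = x.reshape((x.shape[0], -1))
    return nn.Dense(x, num_outputs)


model, params, state = create_model(
    MLP, input_shape=(1, 28, 28, 1), rng=jax.random.PRNGKey(0))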
def get_reps(train_state, flax_module, batch):
  with nn.stochastic(train_state.rng):
    with nn.stateful(train_state.model_state):
      _, reps, _ = flax_module(
          batch['inputs'], train=True, return_activations=True)
  return reps
def _create_flax_module():
  device_batch_size = hparams.batch_size // jax.device_count()
  shape = (device_batch_size,) + tuple(input_shape[1:])
  model_rng, init_rng = jax.random.split(rng)
  with nn.stateful() as init_model_state:
    with nn.stochastic(model_rng):
      _, initial_params = flax_module_def.init_by_shape(
          init_rng, [(shape, model_input_dtype)])
      flax_module = nn.Model(flax_module_def, initial_params)
  num_trainable_params = model_utils.log_param_shapes(flax_module)
  return flax_module, init_model_state, num_trainable_params
def vmapped_flax_module_train(flax_module, model_state, env_batch):
  """Vmapped forward pass of flax_module (with train flag == True).

  Args:
    flax_module: flax.nn.Model; Flax model.
    model_state: flax.nn.Collection; Model state.
    env_batch: dict; A batch of examples.

  Returns:
    Logits and the new model state.
  """
  with nn.stateful(model_state) as new_model_state:
    return flax_module(
        env_batch, train=True, return_activations=False), new_model_state
def loss_fn(model):
  """Loss function used for training."""
  # Stateful collection for tracking internal state like activations.
  with nn.stateful() as batch_stats:
    with nn.stochastic(dropout_rng):
      outputs = model(inputs, train=True, cache=None)

  if isinstance(outputs, dict):
    logits = outputs.get('logits', None)
    regression_predictions = outputs.get('regression', None)
  else:
    logits = outputs
    regression_predictions = None

  mean_loss = 0.0

  # Classification loss.
  if classification_targets is not None:
    classification_loss, classification_weight_sum = utils.compute_weighted_cross_entropy(
        logits,
        classification_targets,
        token_weights=classification_weights,
        example_weights=example_weights)
    # Handle the case where nothing is masked out in BERT
    # (only occurs with very short sequences).
    classification_weight_sum = jnp.maximum(classification_weight_sum,
                                            epsilon)
    mean_classification_loss = classification_loss / classification_weight_sum
    mean_loss += mean_classification_loss

  # Regression loss.
  if regression_targets is not None:
    regression_loss, regression_weight_sum = utils.compute_weighted_mse(
        regression_predictions,
        regression_targets,
        weights=regression_weights)
    regression_weight_sum = jnp.maximum(regression_weight_sum, epsilon)
    mean_regression_loss = regression_loss / regression_weight_sum
    outputs['regression_loss'] = mean_regression_loss
    # TODO(ddohan): Allow weighting each loss separately.
    mean_loss += mean_regression_loss

  return mean_loss, (outputs, batch_stats)
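# A rough sketch (an assumption, not the actual `utils` implementation) of
# the weighted cross-entropy used above: per-token losses are scaled by the
# weights, and the weight sum is returned so the caller can normalize.
def compute_weighted_cross_entropy(logits, targets, token_weights=None,
                                   example_weights=None):
  """logits: [batch, length, vocab]; targets: [batch, length] int labels."""
  onehot = jax.nn.one_hot(targets, logits.shape[-1])
  losses = -jnp.sum(onehot * jax.nn.log_softmax(logits), axis=-1)
  weights = jnp.ones_like(losses)
  if token_weights is not None:
    weights = weights * token_weights
  if example_weights is not None:
    weights = weights * example_weights[:, None]
  return jnp.sum(losses * weights), jnp.sum(weights)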
def training_loss_fn(self, flax_module, train_state, batch, dropout_rng):
  """Runs forward pass and computes loss.

  Args:
    flax_module: A flax module.
    train_state: TrainState; The state of training, including the current
      global_step, model_state, rng, and optimizer.
    batch: Batches from different environments.
    dropout_rng: FLAX PRNG key.

  Returns:
    loss, new_model_state, and computed logits for each batch.
  """
  with nn.stateful(train_state.model_state) as new_model_state:
    with nn.stochastic(dropout_rng):
      logits = flax_module(batch['inputs'], train=True)
  loss = self.task.loss_function(logits, batch, flax_module.params)
  return loss, (new_model_state, logits)
def pseudo_label_generator(batch,
                           train_state,
                           pseudo_labels_transformer_fn=lambda x: (x, None),
                           input_key='inputs',
                           train=True):
  """Pseudo label generator passed to the dataset class.

  This function can be passed to a dataset's initializer for self-supervised
  training or distillation.

  Args:
    batch: dict; Batch of examples, with an 'inputs' key.
    train_state: TrainState; Train state of the model which we want to use to
      generate pseudo labels.
    pseudo_labels_transformer_fn: function; A function that applies a specific
      transformation on the logits from the model to generate the labels. The
      most basic function to be used here is a simple softmax or argmax to get
      one-hot labels. This function should return the labels and the weights
      for each example in the batch (for each label) and has the following
      API: `new_labels, weights = pseudo_labels_transformer(logits)`.
    input_key: str; What key to use to retrieve the input field of the batch.
    train: bool; Train flag passed to the model forward pass.

  Returns:
    The batch with ground-truth labels and weights replaced with pseudo labels
    and new weights.
  """
  inputs = batch[input_key]
  _, dropout_rng = jax.random.split(train_state.rng)
  with nn.stochastic(dropout_rng):
    with nn.stateful(train_state.model_state):
      logits = train_state.optimizer.target(inputs, train=train)
  # Make sure the parameters of the teacher are not updated.
  logits = jax.lax.stop_gradient(logits)
  batch['label'], weights = pseudo_labels_transformer_fn(logits)
  if weights is not None:
    batch['weights'] = weights
  return batch
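# A hypothetical `pseudo_labels_transformer_fn` (sketch only, not from the
# original file): soft pseudo labels via softmax, with per-example weights
# taken from the confidence of the predicted class.
def softmax_pseudo_labels(logits):
  probs = jax.nn.softmax(logits, axis=-1)
  weights = jnp.max(probs, axis=-1)  # confidence of the predicted class
  return probs, weights

# Example call (assumed names):
#   batch = pseudo_label_generator(
#       batch, teacher_train_state,
#       pseudo_labels_transformer_fn=softmax_pseudo_labels)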
def eval_step(train_state, batch, metrics_fn):
  """Runs a single step of evaluation.

  Args:
    train_state: TrainState; The state of training, including the current
      global_step, model_state, rng, and optimizer.
    batch: A single batch of data.
    metrics_fn: A metrics function that, given logits and a batch of data,
      calculates the metrics as well as the loss.

  Returns:
    Calculated metrics.
  """
  flax_module = train_state.optimizer.target
  with nn.stateful(train_state.model_state, mutable=False):
    logits = flax_module(batch['inputs'], train=False)
  metrics = metrics_fn(logits, batch)
  return metrics
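# A hypothetical `metrics_fn` matching the signature used above (a sketch,
# not the project's actual metrics): top-1 accuracy plus mean cross-entropy.
def simple_metrics_fn(logits, batch):
  onehot = jax.nn.one_hot(batch['label'], logits.shape[-1])
  loss = -jnp.mean(jnp.sum(onehot * jax.nn.log_softmax(logits), axis=-1))
  accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == batch['label'])
  return {'loss': loss, 'accuracy': accuracy}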
def train(config, workdir):
  """Runs a training loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest
      checkpoint.
  """
  # Create directories for experimental logs.
  tf.io.gfile.makedirs(workdir)
  sample_dir = os.path.join(workdir, "samples")
  tf.io.gfile.makedirs(sample_dir)
  rng = jax.random.PRNGKey(config.seed)
  tb_dir = os.path.join(workdir, "tensorboard")
  tf.io.gfile.makedirs(tb_dir)
  if jax.host_id() == 0:
    writer = tensorboard.SummaryWriter(tb_dir)

  # Initialize the model.
  rng, model_rng = jax.random.split(rng)
  model_name = config.model.name
  ncsn_def = mutils.get_model(model_name).partial(config=config)
  rng, run_rng = jax.random.split(rng)
  # Whether the generative model is conditioned on class labels.
  class_conditional = "conditional" in config.training.loss.lower()
  with nn.stateful() as init_model_state:
    with nn.stochastic(run_rng):
      input_shape = (jax.local_device_count(), config.data.image_size,
                     config.data.image_size, 3)
      input_list = [(input_shape, jnp.float32), (input_shape[:1], jnp.int32)]
      if class_conditional:
        input_list.append(input_list[-1])
      _, initial_params = ncsn_def.init_by_shape(
          model_rng, input_list, train=True)
      ncsn = nn.Model(ncsn_def, initial_params)

  optimizer = losses.get_optimizer(config).create(ncsn)
  state = mutils.State(step=0,
                       optimizer=optimizer,
                       lr=config.optim.lr,
                       model_state=init_model_state,
                       ema_rate=config.model.ema_rate,
                       params_ema=initial_params,
                       rng=rng)  # pytype: disable=wrong-keyword-args
  del ncsn, init_model_state  # Do not keep a copy of the initial model.

  # Create the checkpoints directory and the initial checkpoint.
  checkpoint_dir = os.path.join(workdir, "checkpoints")
  ckpt = utils.Checkpoint(checkpoint_dir, max_to_keep=None)
  ckpt.restore_or_initialize(state)

  # Save intermediate checkpoints to resume training automatically.
  checkpoint_meta_dir = os.path.join(workdir, "checkpoints-meta")
  ckpt_meta = utils.Checkpoint(checkpoint_meta_dir, max_to_keep=1)
  state = ckpt_meta.restore_or_initialize(state)
  initial_step = int(state.step)
  rng = state.rng

  # Build the input pipeline.
  rng, ds_rng = jax.random.split(rng)
  train_ds, eval_ds, _ = datasets.get_dataset(ds_rng, config)
  train_iter = iter(train_ds)  # pytype: disable=wrong-arg-types
  eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types
  scaler = datasets.get_data_scaler(config)  # data normalizer
  inverse_scaler = datasets.get_data_inverse_scaler(config)

  # Distribute training.
  optimize_fn = losses.optimization_manager(config)
  if config.training.loss.lower() == "ddpm":
    # Use score matching loss with DDPM-type perturbation.
    ddpm_params = mutils.get_ddpm_params()
    train_step = functools.partial(losses.ddpm_loss,
                                   ddpm_params=ddpm_params,
                                   train=True,
                                   optimize_fn=optimize_fn)
    eval_step = functools.partial(losses.ddpm_loss,
                                  ddpm_params=ddpm_params,
                                  train=False)
  else:
    # Use score matching loss with NCSN-type perturbation.
    sigmas = mutils.get_sigmas(config)
    # Whether to use a continuous distribution of noise levels.
    continuous = "continuous" in config.training.loss.lower()
    train_step = functools.partial(
        losses.ncsn_loss,
        sigmas=sigmas,
        class_conditional=class_conditional,
        continuous=continuous,
        train=True,
        optimize_fn=optimize_fn,
        anneal_power=config.training.anneal_power)
    eval_step = functools.partial(
        losses.ncsn_loss,
        sigmas=sigmas,
        class_conditional=class_conditional,
        continuous=continuous,
        train=False,
        anneal_power=config.training.anneal_power)

  p_train_step = jax.pmap(train_step, axis_name="batch")
  p_eval_step = jax.pmap(eval_step, axis_name="batch")
  state = flax_utils.replicate(state)

  num_train_steps = config.training.n_iters

  logging.info("Starting training loop at step %d.", initial_step)
  rng = jax.random.fold_in(rng, jax.host_id())
  for step in range(initial_step, num_train_steps + 1):
    # `step` is a Python integer. `state.step` is a JAX integer on the
    # GPU/TPU devices.

    # Convert data to JAX arrays. Use ._numpy() to avoid a copy.
    batch = jax.tree_map(lambda x: scaler(x._numpy()), next(train_iter))  # pylint: disable=protected-access
    rng, *next_rng = jax.random.split(rng, num=jax.local_device_count() + 1)
    next_rng = jnp.asarray(next_rng)
    loss, state = p_train_step(next_rng, state, batch)
    loss = flax.jax_utils.unreplicate(loss)

    # Quick indication that training is happening.
    logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)

    if jax.host_id() == 0 and step % 50 == 0:
      logging.info("step: %d, training_loss: %.5e", step, loss)
      writer.scalar("training_loss", loss, step)

    # Save a temporary checkpoint to resume training after pre-emption.
    if (step % config.training.snapshot_freq_for_preemption == 0
        and jax.host_id() == 0):
      saved_state = flax_utils.unreplicate(state)
      saved_state = saved_state.replace(rng=rng)
      ckpt_meta.save(saved_state)

    # Report the loss on an evaluation dataset.
    if step % 100 == 0:
      rng, *next_rng = jax.random.split(rng, num=jax.local_device_count() + 1)
      next_rng = jnp.asarray(next_rng)
      eval_batch = jax.tree_map(lambda x: scaler(x._numpy()), next(eval_iter))  # pylint: disable=protected-access
      eval_loss, _ = p_eval_step(next_rng, state, eval_batch)
      eval_loss = flax.jax_utils.unreplicate(eval_loss)
      if jax.host_id() == 0:
        logging.info("step: %d, eval_loss: %.5e", step, eval_loss)
        writer.scalar("eval_loss", eval_loss, step)

    # Save a checkpoint periodically and generate samples.
    if ((step + 1) % config.training.snapshot_freq == 0
        or step == num_train_steps):
      # Save the checkpoint.
      if jax.host_id() == 0:
        saved_state = flax_utils.unreplicate(state)
        saved_state = saved_state.replace(rng=rng)
        ckpt.save(saved_state)

      # Generate and save samples.
      if config.training.snapshot_sampling:
        rng, sample_rng = jax.random.split(rng)
        init_shape = tuple(train_ds.element_spec["image"].shape)
        samples = sampling.get_samples(sample_rng,
                                       config,
                                       flax_utils.unreplicate(state),
                                       init_shape,
                                       scaler,
                                       inverse_scaler,
                                       class_conditional=class_conditional)
        this_sample_dir = os.path.join(
            sample_dir, "iter_{}_host_{}".format(step, jax.host_id()))
        tf.io.gfile.makedirs(this_sample_dir)

        if config.sampling.final_only:  # Do not save intermediate samples.
          sample = samples[-1]
          image_grid = sample.reshape((-1, *sample.shape[2:]))
          nrow = int(np.sqrt(image_grid.shape[0]))
          sample = np.clip(sample * 255, 0, 255).astype(np.uint8)
          with tf.io.gfile.GFile(
              os.path.join(this_sample_dir, "sample.np"), "wb") as fout:
            np.save(fout, sample)

          with tf.io.gfile.GFile(
              os.path.join(this_sample_dir, "sample.png"), "wb") as fout:
            utils.save_image(image_grid, fout, nrow=nrow, padding=2)
        else:  # Save all intermediate samples produced during sampling.
          for i, sample in enumerate(samples):
            image_grid = sample.reshape((-1, *sample.shape[2:]))
            nrow = int(np.sqrt(image_grid.shape[0]))
            sample = np.clip(sample * 255, 0, 255).astype(np.uint8)
            with tf.io.gfile.GFile(
                os.path.join(this_sample_dir,
                             "sample_{}.np".format(i)), "wb") as fout:
              np.save(fout, sample)

            with tf.io.gfile.GFile(
                os.path.join(this_sample_dir,
                             "sample_{}.png".format(i)), "wb") as fout:
              utils.save_image(image_grid, fout, nrow=nrow, padding=2)
def evaluate(config, workdir, eval_folder="eval"):
  """Evaluates trained models.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints.
    eval_folder: The subfolder for storing evaluation results. Defaults to
      "eval".
  """
  # Create eval_dir.
  eval_dir = os.path.join(workdir, eval_folder)
  tf.io.gfile.makedirs(eval_dir)
  rng = jax.random.PRNGKey(config.seed + 1)

  # Build the input pipeline.
  rng, ds_rng = jax.random.split(rng)
  _, eval_ds, _ = datasets.get_dataset(ds_rng, config, evaluation=True)
  scaler = datasets.get_data_scaler(config)
  inverse_scaler = datasets.get_data_inverse_scaler(config)

  # Initialize the model.
  rng, model_rng = jax.random.split(rng)
  model_name = config.model.name
  ncsn_def = mutils.get_model(model_name).partial(config=config)
  rng, run_rng = jax.random.split(rng)
  class_conditional = "conditional" in config.training.loss.lower()
  with nn.stateful() as init_model_state:
    with nn.stochastic(run_rng):
      input_shape = tuple(eval_ds.element_spec["image"].shape[1:])
      input_list = [(input_shape, jnp.float32), (input_shape[:1], jnp.int32)]
      if class_conditional:
        input_list.append(input_list[-1])
      _, initial_params = ncsn_def.init_by_shape(
          model_rng, input_list, train=True)
      ncsn = nn.Model(ncsn_def, initial_params)

  optimizer = losses.get_optimizer(config).create(ncsn)
  state = mutils.State(step=0,
                       optimizer=optimizer,
                       lr=config.optim.lr,
                       model_state=init_model_state,
                       ema_rate=config.model.ema_rate,
                       params_ema=initial_params,
                       rng=rng)  # pytype: disable=wrong-keyword-args
  del ncsn, init_model_state  # Do not keep a copy of the initial model.

  checkpoint_dir = os.path.join(workdir, "checkpoints")
  if config.training.loss.lower() == "ddpm":
    # Use the score matching loss with DDPM-type perturbation.
    ddpm_params = mutils.get_ddpm_params()
    eval_step = functools.partial(
        losses.ddpm_loss, ddpm_params=ddpm_params, train=False)
  else:
    # Use the score matching loss with NCSN-type perturbation.
    sigmas = mutils.get_sigmas(config)
    continuous = "continuous" in config.training.loss.lower()
    eval_step = functools.partial(
        losses.ncsn_loss,
        sigmas=sigmas,
        continuous=continuous,
        class_conditional=class_conditional,
        train=False,
        anneal_power=config.training.anneal_power)

  p_eval_step = jax.pmap(eval_step, axis_name="batch")
  rng = jax.random.fold_in(rng, jax.host_id())

  # A data class for checkpointing.
  @flax.struct.dataclass
  class EvalMeta:
    ckpt_id: int
    round_id: int
    rng: Any

  # Add one additional round to get the exact number of samples as required.
  num_rounds = config.eval.num_samples // config.eval.batch_size + 1

  eval_meta = EvalMeta(ckpt_id=config.eval.begin_ckpt, round_id=-1, rng=rng)
  eval_meta = checkpoints.restore_checkpoint(
      eval_dir, eval_meta, step=None, prefix=f"meta_{jax.host_id()}_")

  if eval_meta.round_id < num_rounds - 1:
    begin_ckpt = eval_meta.ckpt_id
    begin_round = eval_meta.round_id + 1
  else:
    begin_ckpt = eval_meta.ckpt_id + 1
    begin_round = 0

  rng = eval_meta.rng
  # Use InceptionV3 for images with higher resolution.
  inceptionv3 = config.data.image_size >= 256
  inception_model = evaluation.get_inception_model(inceptionv3=inceptionv3)

  logging.info("begin checkpoint: %d", begin_ckpt)
  for ckpt in range(begin_ckpt, config.eval.end_ckpt + 1):
    ckpt_filename = os.path.join(checkpoint_dir, "ckpt-{}.flax".format(ckpt))

    # Wait if the target checkpoint hasn't been produced yet.
    waiting_message_printed = False
    while not tf.io.gfile.exists(ckpt_filename):
      if not waiting_message_printed and jax.host_id() == 0:
        logging.warn("Waiting for the arrival of ckpt-%d.flax", ckpt)
        waiting_message_printed = True
      time.sleep(10)

    # In case the file was just written and is not ready to read from yet.
    try:
      state = utils.load_state_dict(ckpt_filename, state)
    except:
      time.sleep(60)
      try:
        state = utils.load_state_dict(ckpt_filename, state)
      except:
        time.sleep(120)
        state = utils.load_state_dict(ckpt_filename, state)

    pstate = flax.jax_utils.replicate(state)
    eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types

    # Compute the loss function on the full evaluation dataset.
    all_losses = []
    for i, batch in enumerate(eval_iter):
      rng, *next_rng = jax.random.split(rng, num=jax.local_device_count() + 1)
      next_rng = jnp.asarray(next_rng)
      eval_batch = jax.tree_map(lambda x: scaler(x._numpy()), batch)  # pylint: disable=protected-access
      eval_loss, _ = p_eval_step(next_rng, pstate, eval_batch)
      eval_loss = flax.jax_utils.unreplicate(eval_loss)
      all_losses.append(eval_loss)
      if (i + 1) % 1000 == 0 and jax.host_id() == 0:
        logging.info("Finished %dth step loss evaluation", i + 1)

    all_losses = jnp.asarray(all_losses)

    state = jax.device_put(state)
    # Sampling and computing statistics for Inception scores, FIDs, and KIDs.
    # Designed to be pre-emption safe. Automatically resumes when interrupted.
    for r in range(begin_round, num_rounds):
      if jax.host_id() == 0:
        logging.info("sampling -- ckpt: %d, round: %d", ckpt, r)
      rng, sample_rng = jax.random.split(rng)
      init_shape = tuple(eval_ds.element_spec["image"].shape)

      this_sample_dir = os.path.join(
          eval_dir, f"ckpt_{ckpt}_host_{jax.host_id()}")
      tf.io.gfile.makedirs(this_sample_dir)
      samples = sampling.get_samples(sample_rng,
                                     config,
                                     state,
                                     init_shape,
                                     scaler,
                                     inverse_scaler,
                                     class_conditional=class_conditional)
      samples = samples[-1]
      samples = np.clip(samples * 255., 0, 255).astype(np.uint8)
      samples = samples.reshape(
          (-1, config.data.image_size, config.data.image_size, 3))
      with tf.io.gfile.GFile(
          os.path.join(this_sample_dir, f"samples_{r}.npz"), "wb") as fout:
        io_buffer = io.BytesIO()
        np.savez_compressed(io_buffer, samples=samples)
        fout.write(io_buffer.getvalue())

      gc.collect()
      latents = evaluation.run_inception_distributed(
          samples, inception_model, inceptionv3=inceptionv3)
      gc.collect()
      with tf.io.gfile.GFile(
          os.path.join(this_sample_dir, f"statistics_{r}.npz"), "wb") as fout:
        io_buffer = io.BytesIO()
        np.savez_compressed(
            io_buffer, pool_3=latents["pool_3"], logits=latents["logits"])
        fout.write(io_buffer.getvalue())

      eval_meta = eval_meta.replace(ckpt_id=ckpt, round_id=r, rng=rng)
      # Save an intermediate checkpoint directly if not the last round.
      # Otherwise save eval_meta after computing the Inception scores and
      # FIDs.
      if r < num_rounds - 1:
        checkpoints.save_checkpoint(
            eval_dir,
            eval_meta,
            step=ckpt * num_rounds + r,
            keep=1,
            prefix=f"meta_{jax.host_id()}_")

    # Compute Inception scores, FIDs and KIDs.
    if jax.host_id() == 0:
      # Load all statistics that have been previously computed and saved.
      all_logits = []
      all_pools = []
      for host in range(jax.host_count()):
        this_sample_dir = os.path.join(eval_dir, f"ckpt_{ckpt}_host_{host}")

        stats = tf.io.gfile.glob(
            os.path.join(this_sample_dir, "statistics_*.npz"))
        wait_message = False
        while len(stats) < num_rounds:
          if not wait_message:
            logging.warn("Waiting for statistics on host %d", host)
            wait_message = True
          stats = tf.io.gfile.glob(
              os.path.join(this_sample_dir, "statistics_*.npz"))
          time.sleep(1)

        for stat_file in stats:
          with tf.io.gfile.GFile(stat_file, "rb") as fin:
            stat = np.load(fin)
            if not inceptionv3:
              all_logits.append(stat["logits"])
            all_pools.append(stat["pool_3"])

      if not inceptionv3:
        all_logits = np.concatenate(
            all_logits, axis=0)[:config.eval.num_samples]
      all_pools = np.concatenate(all_pools, axis=0)[:config.eval.num_samples]

      # Load pre-computed dataset statistics.
      data_stats = evaluation.load_dataset_stats(config)
      data_pools = data_stats["pool_3"]

      if hasattr(config.eval, "num_partitions"):
        # Divide samples into several partitions and compute FID/KID/IS on
        # each of them.
        assert not inceptionv3
        fids = []
        kids = []
        inception_scores = []
        partition_size = config.eval.num_samples // config.eval.num_partitions
        tf_data_pools = tf.convert_to_tensor(data_pools)
        for i in range(config.eval.num_partitions):
          this_pools = all_pools[i * partition_size:(i + 1) * partition_size]
          this_logits = all_logits[i * partition_size:(i + 1) *
                                   partition_size]
          inception_scores.append(
              tfgan.eval.classifier_score_from_logits(this_logits))
          fids.append(
              tfgan.eval.frechet_classifier_distance_from_activations(
                  data_pools, this_pools))
          this_pools = tf.convert_to_tensor(this_pools)
          kids.append(
              tfgan.eval.kernel_classifier_distance_from_activations(
                  tf_data_pools, this_pools).numpy())

        fids = np.asarray(fids)
        inception_scores = np.asarray(inception_scores)
        kids = np.asarray(kids)
        with tf.io.gfile.GFile(
            os.path.join(eval_dir, f"report_all_{ckpt}.npz"), "wb") as f:
          io_buffer = io.BytesIO()
          np.savez_compressed(
              io_buffer,
              all_losses=all_losses,
              mean_loss=all_losses.mean(),
              ISs=inception_scores,
              fids=fids,
              kids=kids)
          f.write(io_buffer.getvalue())

      else:
        # Compute FID/KID/IS on all samples together.
        if not inceptionv3:
          inception_score = tfgan.eval.classifier_score_from_logits(
              all_logits)
        else:
          inception_score = -1

        fid = tfgan.eval.frechet_classifier_distance_from_activations(
            data_pools, all_pools)
        # Hack to get tfgan KID to work with eager execution.
        tf_data_pools = tf.convert_to_tensor(data_pools)
        tf_all_pools = tf.convert_to_tensor(all_pools)
        kid = tfgan.eval.kernel_classifier_distance_from_activations(
            tf_data_pools, tf_all_pools).numpy()
        del tf_data_pools, tf_all_pools

        logging.info(
            "ckpt-%d --- loss: %.6e, inception_score: %.6e, FID: %.6e, "
            "KID: %.6e", ckpt, all_losses.mean(), inception_score, fid, kid)

        with tf.io.gfile.GFile(
            os.path.join(eval_dir, f"report_{ckpt}.npz"), "wb") as f:
          io_buffer = io.BytesIO()
          np.savez_compressed(
              io_buffer,
              all_losses=all_losses,
              mean_loss=all_losses.mean(),
              IS=inception_score,
              fid=fid,
              kid=kid)
          f.write(io_buffer.getvalue())
    else:
      # For host_id() != 0.
      # Use file existence to emulate synchronization across hosts.
      if hasattr(config.eval, "num_partitions"):
        assert not inceptionv3
        while not tf.io.gfile.exists(
            os.path.join(eval_dir, f"report_all_{ckpt}.npz")):
          time.sleep(1.)
      else:
        while not tf.io.gfile.exists(
            os.path.join(eval_dir, f"report_{ckpt}.npz")):
          time.sleep(1.)

    # Save eval_meta after computing IS/KID/FID to mark the end of evaluation
    # for this checkpoint.
    checkpoints.save_checkpoint(
        eval_dir,
        eval_meta,
        step=ckpt * num_rounds + r,
        keep=1,
        prefix=f"meta_{jax.host_id()}_")

    begin_round = 0

  # Remove all meta files after finishing evaluation.
  meta_files = tf.io.gfile.glob(
      os.path.join(eval_dir, f"meta_{jax.host_id()}_*"))
  for file in meta_files:
    tf.io.gfile.remove(file)
def training_loss_fn(self, flax_module, train_state, batch, dropout_rng,
                     mixup_rng, sampled_layer):
  """Runs forward pass and computes loss.

  Args:
    flax_module: A flax module.
    train_state: TrainState; The state of training, including the current
      global_step, model_state, rng, and optimizer.
    batch: Batches from different environments.
    dropout_rng: FLAX PRNG key.
    mixup_rng: FLAX PRNG key.
    sampled_layer: str; Name of the layer on which mixup will be applied.

  Returns:
    loss, new_model_state, and computed logits for each batch.
  """
  with nn.stochastic(dropout_rng):
    with nn.stateful(train_state.model_state) as new_model_state:
      logits, reps, _ = flax_module(
          batch['inputs'], train=True, return_activations=True)

  # Get matching between examples from the mini-batch.
  matching_matrix = pipeline_utils.get_self_matching_matrix(
      batch,
      reps[sampled_layer],
      mode=self.hparams.get('intra_mixup_mode', 'random'),
      label_cost=self.hparams.get('intra_mixup_label_cost', 1.0),
      l2_cost=self.hparams.get('intra_mixup_l2_cost', 0.001))

  beta_params = self.hparams.get('beta_schedule_params') or {
      'initial_value': 1.0,
      'mode': 'constant'
  }
  alpha_params = self.hparams.get('alpha_schedule_params') or {
      'initial_value': 1.0,
      'mode': 'constant'
  }
  step = train_state.global_step
  beta = pipeline_utils.scheduler(step, beta_params)
  alpha = pipeline_utils.scheduler(step, alpha_params)

  with nn.stochastic(mixup_rng):
    with nn.stateful(new_model_state) as new_model_state:
      new_logits, sample_lambdas = self.interpolate_and_predict(
          nn.make_rng(), flax_module, matching_matrix, reps, sampled_layer,
          alpha, beta)

  new_batch = copy.deepcopy(batch)
  # Compute labels for the interpolated states.
  new_batch['label'] = tensor_util.convex_interpolate(
      batch['label'], batch['label'][jnp.argmax(matching_matrix, axis=-1)],
      sample_lambdas)
  # Compute weights for the interpolated states.
  if batch.get('weights') is not None:
    new_batch['weights'] = tensor_util.convex_interpolate(
        batch['weights'],
        batch['weights'][jnp.argmax(matching_matrix, axis=-1)],
        sample_lambdas)

  # Standard loss.
  loss = self.task.loss_function(logits, batch, flax_module.params)
  # Add the loss from the interpolated states.
  loss += self.task.loss_function(new_logits, new_batch)

  return loss, (new_model_state, logits)
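# A minimal sketch (an assumption, not the actual `tensor_util` source) of
# the convex interpolation applied to the labels and weights above:
def convex_interpolate(x, y, lambdas):
  """Mixes x and y per example: lambda * x + (1 - lambda) * y."""
  lambdas = lambdas.reshape((-1,) + (1,) * (x.ndim - 1))
  return lambdas * x + (1.0 - lambdas) * y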