def loss_fn(model):
  """Computes the sigma-weighted squared-error loss between scores and target.

  Reads `train`, `state`, `run_rng`, `class_conditional`, `perturbed_data`,
  `labels`, `class_labels`, `used_sigmas`, `noise`, `anneal_power` and
  `loss_per_sigma` from the enclosing scope.

  Args:
    model: Callable invoked as `model(perturbed_data, labels, ..., train=train)`.

  Returns:
    `(loss, new_model_state)`, or `(loss, new_model_state, losses)` when
    `loss_per_sigma` is truthy. `losses` holds one value per entry of the
    leading axis and `loss` is their mean.
  """

  def _predict():
    # Only the class-conditional variant receives the class labels as `y`.
    if class_conditional:
      return model(perturbed_data, labels, y=class_labels, train=train)
    return model(perturbed_data, labels, train=train)

  if train:
    # Training: the stateful context collects the updated model state.
    with nn.stateful(state.model_state) as new_model_state, \
        nn.stochastic(run_rng):
      scores = _predict()
  else:
    # Evaluation: the state is immutable, so the stored one is returned as-is.
    with nn.stateful(state.model_state, mutable=False), \
        nn.stochastic(run_rng):
      scores = _predict()
    new_model_state = state.model_state

  # Flatten everything but the leading axis before comparing.
  flat_scores = scores.reshape((scores.shape[0], -1))
  target = -1 / (used_sigmas**2) * noise
  flat_target = target.reshape((target.shape[0], -1))

  squared_error = ((flat_scores - flat_target)**2).sum(axis=-1)
  losses = 1 / 2. * squared_error * used_sigmas.squeeze()**anneal_power
  loss = jnp.mean(losses)

  if loss_per_sigma:
    return loss, new_model_state, losses
  return loss, new_model_state
def training_loss_fn(self, flax_model, train_state, batch, dropout_rng,
                     env_ids):
  """Runs the vmapped forward pass over all environments and computes the loss.

  Args:
    flax_model: A flax module.
    train_state: TrainState; its `model_state` and `global_step` are read here.
    batch: Batches from different environments.
    dropout_rng: FLAX PRNG key used for the stochastic context.
    env_ids: list(int); List of environment codes. Unused by this
      implementation.

  Returns:
    `(loss, (new_model_state, env_logits, logs))` where `logs` is always None.
  """
  del env_ids  # Not needed here; accepted for interface compatibility.

  env_inputs = pipeline_utils.get_multi_env_inputs(batch, 'inputs')
  with nn.stochastic(dropout_rng):
    env_logits, per_env_state = pipeline_utils.vmapped_flax_module_train(
        flax_model, train_state.model_state, env_inputs)

  # Model state, e.g. batch statistics, are averaged over all environments
  # because we use vmapped_flax_module_train.
  mean_over_envs = functools.partial(jnp.mean, axis=0)
  new_model_state = jax.tree_util.tree_map(mean_over_envs, per_env_state)

  loss = self.task.loss_function(env_logits, batch, flax_model.params,
                                 train_state.global_step)
  return loss, (new_model_state, env_logits, None)
def forward_pass(self,
                 flax_model,
                 train_state,
                 batch,
                 rng,
                 input_layer_key='input',
                 train=True):
  """Runs the model and returns logits together with layer representations.

  Args:
    flax_model: flax.deprecated.nn.Model; Flax model.
    train_state: TrainState object; only its `model_state` is read.
    batch: dict; A batch of examples with an 'inputs' entry.
    rng: Jax PRNG key; bound to the current host/device before use.
    input_layer_key: str; Which layer the input should be plugged in.
    train: bool; Train flag.

  Returns:
    `(logits, all_reps, selected_reps, new_model_state)`, with `selected_reps`
    reshaped so that everything after its first two axes is flattened.
  """
  # Bind the rng to the host/device we are on.
  bound_rng = pipeline_utils.bind_rng_to_host_device(
      rng, axis_name='batch', bind_to=['host', 'device'])

  with nn.stochastic(bound_rng):
    outputs = pipeline_utils.forward_pass_with_reps(
        batch['inputs'], flax_model, train_state.model_state, input_layer_key,
        train)
    (logits, all_reps, selected_reps, new_model_state) = outputs

  # Keep the first two axes and collapse the remaining ones.
  dim0 = selected_reps.shape[0]
  dim1 = selected_reps.shape[1]
  selected_reps = selected_reps.reshape((dim0, dim1, -1))

  return logits, all_reps, selected_reps, new_model_state
def impl_loss_fn(model_params):
  """Applies `module` to the image batch and sums the configured losses.

  Reads `rng`, `state`, `module`, `batch` and `loss_fn` from the enclosing
  scope. `loss_fn` may be a single callable or a list/tuple of callables.

  Args:
    model_params: Parameters passed to `module.call`.

  Returns:
    `(loss, (logits, new_model_state, stats))`.
  """
  with nn.stochastic(rng):
    with nn.stateful(state.model_state) as new_model_state:
      logits, stats = module.call(model_params, batch["image"])

  # Normalise to a sequence so one or many loss functions are handled alike.
  if isinstance(loss_fn, (list, tuple)):
    loss_fns = loss_fn
  else:
    loss_fns = [loss_fn]

  loss = 0
  for single_loss_fn in loss_fns:
    loss = loss + single_loss_fn(logits, batch["label"], stats)

  return loss, (logits, new_model_state, stats)
def training_loss_fn(self, flax_model, train_state, batch, dropout_rng,
                     env_ids):
  """Runs `self.forward_pass` and computes the representation-aware loss.

  Args:
    flax_model: A flax module.
    train_state: TrainState; forwarded to `forward_pass`, and its
      `global_step` is passed to the loss.
    batch: Batches from different environments.
    dropout_rng: FLAX PRNG key, used for both the forward pass and the loss.
    env_ids: list[int]; List of env codes.

  Returns:
    `(loss, (new_model_state, env_logits, logs))` where `logs` is always None.
  """
  forward_outputs = self.forward_pass(flax_model, train_state, batch,
                                      dropout_rng)
  env_logits, _, selected_env_reps, per_env_state = forward_outputs

  # Model state, e.g. batch statistics, are averaged over the leading
  # (environment) axis.
  mean_over_envs = functools.partial(jnp.mean, axis=0)
  new_model_state = jax.tree_util.tree_map(mean_over_envs, per_env_state)

  # The total loss is computed inside nn.stochastic.
  with nn.stochastic(dropout_rng):
    loss = self.task.loss_function(env_logits, selected_env_reps, batch,
                                   env_ids, flax_model.params,
                                   train_state.global_step)

  return loss, (new_model_state, env_logits, None)
def loss_fn(model):
  """Computes the mean squared error between model output and `noise`.

  Reads `train`, `state`, `run_rng`, `perturbed_data`, `T` and `noise` from
  the enclosing scope.

  Args:
    model: Callable invoked as `model(perturbed_data, T, train=train)`.

  Returns:
    `(loss, new_model_state)`.
  """
  if train:
    # Training: the stateful context collects the updated model state.
    with nn.stateful(state.model_state) as new_model_state, \
        nn.stochastic(run_rng):
      scores = model(perturbed_data, T, train=train)
  else:
    # Evaluation: the state is immutable, so the stored one is returned as-is.
    with nn.stateful(state.model_state, mutable=False), \
        nn.stochastic(run_rng):
      scores = model(perturbed_data, T, train=train)
    new_model_state = state.model_state

  # Flatten everything but the leading axis before comparing.
  flat_scores = scores.reshape((scores.shape[0], -1))
  flat_noise = noise.reshape((noise.shape[0], -1))
  residual = flat_scores - flat_noise
  return jnp.mean(residual**2), new_model_state
def eval_step(state, module, batch, metrics_dict, rng):
  """Computes the metrics for the given model in inference mode.

  The model is called with `train=False` and an immutable model state. The
  resulting metrics and stats are then all-gathered over the "batch" axis.

  Args:
    state: Model state holding `model_state` and `model_params`.
    module: Model function; `module.call` is applied to `batch["image"]`.
    batch: Inputs that should be evaluated ("image" and "label" entries).
    metrics_dict: A dictionary of metrics, mapping names to metric functions
      called as `fn(logits, labels, stats)`.
    rng: Jax pseudo-random number generator key.

  Returns:
    `(metrics, stats, new_rng)`: gathered metrics, gathered stats output by
    the model, and a fresh PRNG key.
  """
  step_rng, new_rng = jax.random.split(rng)

  with nn.stochastic(step_rng):
    with flax.deprecated.nn.stateful(state.model_state, mutable=False):
      logits, stats = module.call(
          state.model_params, batch["image"], train=False)

  metrics = {}
  for metric_name, metric_fn in metrics_dict.items():
    metrics[metric_name] = metric_fn(logits, batch["label"], stats)

  gathered_metrics = jax.lax.all_gather(metrics, axis_name="batch")
  gathered_stats = jax.lax.all_gather(stats, axis_name="batch")
  return gathered_metrics, gathered_stats, new_rng
def create_model(module, input_shape, rng):
  """Instantiates the model by initialising `module` on an all-ones input.

  Args:
    module: Module definition providing `init`.
    input_shape: Shape of the float32 dummy input used for initialisation.
    rng: Jax PRNG key; split into one key for the stochastic context and one
      for parameter initialisation.

  Returns:
    `(model, init_params, init_state)`.
  """
  stochastic_rng, params_rng = jax.random.split(rng)

  with nn.stochastic(stochastic_rng):
    with nn.stateful() as init_state:
      dummy_input = jnp.ones(input_shape, dtype=jnp.float32)
      # The first element returned by `init` (the output) is not needed.
      init_params = module.init(params_rng, dummy_input)[1]

  return nn.Model(module, init_params), init_params, init_state
def get_reps(train_state, flax_module, batch):
  """Returns the activations produced by `flax_module` on `batch['inputs']`.

  The module is called in train mode with `return_activations=True`, using the
  rng and model state stored on `train_state`.

  Args:
    train_state: Object providing `rng` and `model_state`.
    flax_module: Callable returning a 3-tuple whose second item is returned.
    batch: dict with an 'inputs' entry.

  Returns:
    The second element of the module's output.
  """
  with nn.stochastic(train_state.rng), nn.stateful(train_state.model_state):
    outputs = flax_module(
        batch['inputs'], train=True, return_activations=True)
    _, activations, _ = outputs
  return activations
def _create_flax_module():
  """Initialises the flax module from shapes and logs its parameter shapes.

  Reads `hparams`, `input_shape`, `rng`, `flax_module_def` and
  `model_input_dtype` from the enclosing scope.

  Returns:
    `(flax_module, init_model_state, num_trainable_params)`.
  """
  # Each device only sees its share of the global batch.
  per_device_batch = hparams.batch_size // jax.device_count()
  init_shape = (per_device_batch, *input_shape[1:])

  stochastic_rng, params_rng = jax.random.split(rng)
  with nn.stateful() as init_model_state, nn.stochastic(stochastic_rng):
    init_outputs = flax_module_def.init_by_shape(
        params_rng, [(init_shape, model_input_dtype)])
    _, initial_params = init_outputs

  flax_module = nn.Model(flax_module_def, initial_params)
  num_trainable_params = model_utils.log_param_shapes(flax_module)
  return flax_module, init_model_state, num_trainable_params
def train_step(state, batch, module, loss_fn, optimizer, metrics_dict, rng):
  """Performs a single training step.

  Args:
    state: Current training state. Updated training state will be returned.
    batch: Training inputs for this step ("image" and "label" entries).
    module: Module function; `module.call` is applied to `batch["image"]`.
    loss_fn: Loss function, or list/tuple of loss functions, each called as
      `fn(logits, labels, stats)`; their values are summed.
    optimizer: Optax optimizer to compute updates from gradients.
    metrics_dict: A dictionary of metrics, mapping names to metric functions
      called as `fn(logits, labels, stats)`.
    rng: Jax pseudo-random number generator key.

  Returns:
    Tuple `(new_state, grad, updates, metrics, stats, new_rng)`: the updated
    state, the gradients averaged over the "batch" axis, the optimizer
    updates, the gathered metrics, the first entry of the gathered stats, and
    a fresh PRNG key.
  """
  rng, new_rng = jax.random.split(rng)

  def impl_loss_fn(model_params):
    with nn.stochastic(rng), nn.stateful(state.model_state) as new_model_state:
      logits, stats = module.call(model_params, batch["image"])
    losses = loss_fn if isinstance(loss_fn, (list, tuple)) else [loss_fn]
    loss = sum(
        single_loss_fn(logits, batch["label"], stats)
        for single_loss_fn in losses)
    return loss, (logits, new_model_state, stats)

  grad_fn = jax.value_and_grad(impl_loss_fn, has_aux=True)
  with nn.stochastic(rng):
    (_, loss_aux), grad = grad_fn(state.model_params)
  logits, new_model_state, stats = loss_aux

  # Compute average gradient across multiple workers.
  grad = jax.lax.pmean(grad, axis_name="batch")
  updates, new_opt_state = optimizer.update(
      grad, state.optimizer_state, params=state.model_params)
  new_model_params = optax.apply_updates(state.model_params, updates)

  metrics = {
      m: fn(logits, batch["label"], stats) for (m, fn) in metrics_dict.items()
  }
  metrics = jax.lax.all_gather(metrics, axis_name="batch")
  stats = jax.lax.all_gather(stats, axis_name="batch")
  # Keep only the first gathered entry of every stats leaf.
  # `jax.tree_util.tree_map` is used because the top-level `jax.tree_map`
  # alias is deprecated and absent from newer JAX releases.
  stats = jax.tree_util.tree_map(lambda x: x[0], stats)

  new_state = state.replace(  # pytype: disable=attribute-error
      step=state.step + 1,
      optimizer_state=new_opt_state,
      model_state=new_model_state,
      model_params=new_model_params)
  return new_state, grad, updates, metrics, stats, new_rng
def training_loss_fn(self, flax_module, train_state, batch, dropout_rng):
  """Runs the forward pass in train mode and computes the task loss.

  Args:
    flax_module: A flax module; called on `batch['inputs']`.
    train_state: TrainState; only its `model_state` is read.
    batch: dict with an 'inputs' entry; also passed to the loss function.
    dropout_rng: FLAX PRNG key.

  Returns:
    `(loss, (new_model_state, logits))`.
  """
  with nn.stateful(train_state.model_state) as new_model_state, \
      nn.stochastic(dropout_rng):
    logits = flax_module(batch['inputs'], train=True)

  loss = self.task.loss_function(logits, batch, flax_module.params)
  aux = (new_model_state, logits)
  return loss, aux
def pseudo_label_generator(batch,
                           train_state,
                           pseudo_labels_transformer_fn=lambda x: (x, None),
                           input_key='inputs',
                           train=True):
  """Replaces the labels of `batch` with pseudo labels from a model.

  This function can be passed to a dataset initializer for self-supervised
  training or distillation. Note that `batch` is modified in place and the
  same object is returned.

  Args:
    batch: dict; Batch of examples, with an `input_key` entry.
    train_state: TrainState; Train state of the model used to generate the
      pseudo labels (`rng`, `model_state` and `optimizer.target` are read).
    pseudo_labels_transformer_fn: function; Maps the model logits to labels
      and per-example weights. A typical choice is a softmax or an argmax
      producing one-hot labels. It has the following API:
      ```
      new_labels, weights = pseudo_labels_transformer(logits)
      ```
      When the returned weights are None, `batch['weights']` is left as is.
    input_key: str; What key to use to retrieve the input field of the batch.
    train: bool; Train flag passed to the model forward pass.

  Returns:
    The batch with 'label' (and, if provided, 'weights') replaced by the
    pseudo labels and new weights.
  """
  model_inputs = batch[input_key]
  teacher = train_state.optimizer.target

  # Only the second half of the split key is used.
  _, forward_rng = jax.random.split(train_state.rng)
  with nn.stochastic(forward_rng), nn.stateful(train_state.model_state):
    logits = teacher(model_inputs, train=train)

  # Make sure the parameter of the teacher are not updated.
  logits = jax.lax.stop_gradient(logits)

  pseudo_labels, weights = pseudo_labels_transformer_fn(logits)
  batch['label'] = pseudo_labels
  if weights is not None:
    batch['weights'] = weights
  return batch
def loss_fn(model):
  """Loss function used for training.

  Sums a weighted cross-entropy term (when `classification_targets` is set)
  and a weighted MSE term (when `regression_targets` is set), each normalised
  by its weight sum. All targets, weights, `inputs`, `dropout_rng` and
  `epsilon` come from the enclosing scope.

  Args:
    model: Callable invoked as `model(inputs, train=True, cache=None)`. It may
      return a dict with 'logits' / 'regression' entries, or the logits
      directly.

  Returns:
    `(mean_loss, (outputs, batch_stats))`. When the regression term is
    computed, it is also stored in `outputs['regression_loss']`.
  """
  # Stateful collection for tracking internal state like activations.
  with nn.stateful() as batch_stats, nn.stochastic(dropout_rng):
    outputs = model(inputs, train=True, cache=None)

  if isinstance(outputs, dict):
    logits = outputs.get('logits')
    regression_predictions = outputs.get('regression')
  else:
    # A bare output is interpreted as the classification logits.
    logits, regression_predictions = outputs, None

  mean_loss = 0.0

  # Classification loss
  if classification_targets is not None:
    cls_loss, cls_weight_sum = utils.compute_weighted_cross_entropy(
        logits,
        classification_targets,
        token_weights=classification_weights,
        example_weights=example_weights)
    # Handle case where nothing is masked out in BERT
    # (Only occurs with very short sequences).
    cls_weight_sum = jnp.maximum(cls_weight_sum, epsilon)
    mean_loss = mean_loss + cls_loss / cls_weight_sum

  # Regression loss
  if regression_targets is not None:
    reg_loss, reg_weight_sum = utils.compute_weighted_mse(
        regression_predictions, regression_targets, weights=regression_weights)
    reg_weight_sum = jnp.maximum(reg_weight_sum, epsilon)
    mean_regression_loss = reg_loss / reg_weight_sum
    outputs['regression_loss'] = mean_regression_loss
    # TODO(ddohan): Allow weighting each loss separately.
    mean_loss = mean_loss + mean_regression_loss

  return mean_loss, (outputs, batch_stats)
def forward_pass(self,
                 flax_model,
                 train_state,
                 batch,
                 rng,
                 input_layer_key='input',
                 train=True):
  """Runs the vmapped multi-environment forward pass with representations.

  Args:
    flax_model: Flax model applied to every environment.
    train_state: TrainState object; only its `model_state` is read.
    batch: Batches from different environments.
    rng: Jax PRNG key; bound to the current host/device before use.
    input_layer_key: str; Which layer the input should be plugged in.
    train: bool; Train flag.

  Returns:
    `(env_logits, all_env_reps, selected_env_reps, new_model_state)`, with
    `selected_env_reps` reshaped so that everything after its first two axes
    is flattened.
  """
  # Bind the rng to the host/device we are on.
  bound_rng = pipeline_utils.bind_rng_to_host_device(
      rng, axis_name='batch', bind_to=['host', 'device'])

  env_inputs = pipeline_utils.get_multi_env_inputs(batch, 'inputs')
  with nn.stochastic(bound_rng):
    outputs = pipeline_utils.vmapped_flax_module_with_reps(
        env_inputs, flax_model, train_state.model_state, input_layer_key,
        train)
    (env_logits, all_env_reps, selected_env_reps, new_model_state) = outputs

  # Keep the first two axes and collapse the remaining ones.
  dim0 = selected_env_reps.shape[0]
  dim1 = selected_env_reps.shape[1]
  selected_env_reps = selected_env_reps.reshape((dim0, dim1, -1))

  return env_logits, all_env_reps, selected_env_reps, new_model_state
def __init__(self, model_cls, task, hparams, experiment_dir,
             tb_summary_writer, rng):
  """Initialises the trainer and sizes the task's representation transformers.

  After the parent initialiser has run, one pmapped forward pass is executed
  on a sample training batch; the size of the last axis of the selected
  representations is handed to `self.task.setup_transformers`.

  Args:
    model_cls: Forwarded to the parent initialiser.
    task: Forwarded to the parent initialiser.
    hparams: Forwarded to the parent initialiser.
    experiment_dir: Forwarded to the parent initialiser.
    tb_summary_writer: Forwarded to the parent initialiser.
    rng: Jax PRNG key; one half goes to the parent initialiser, the other is
      used for the stochastic context of the probing forward pass.
  """
  probe_rng, init_rng = jax.random.split(rng)
  super().__init__(model_cls, task, hparams, experiment_dir,
                   tb_summary_writer, init_rng)

  # Set up state transformers to compute the representation based
  # auxiliary loss.

  # Get sample batch
  # TODO(samiraabnar): Refactor this by implementing a sample_batch for task.
  train_split = dict(self.task.dataset.data_iters['train'])
  _, train_iters = list(zip(*train_split.items()))
  init_batch = self.get_next_batch(train_iters)

  # Run the forward pass once to get the representations and their dimensions.
  flax_model = self.train_state.optimizer.target
  with nn.stochastic(probe_rng):
    pmapped_forward = jax.pmap(self.forward_pass, axis_name='batch')
    forward_outputs = pmapped_forward(flax_model, self.train_state,
                                      init_batch, self.train_state.rng)
    _, _, selected_env_reps, _ = forward_outputs

  hidden_dim = selected_env_reps.shape[-1]
  self.task.setup_transformers(hidden_reps_dim=hidden_dim)
def get_env_aligned_pairs_idx(self, env_reps, env_batches, env_ids):
  """Computes alignments between all environment pairs.

  The sharded inputs are unsharded, the task computes the alignments inside a
  stochastic context seeded with the unreplicated train-state rng, and the
  result is sharded again.

  Args:
    env_reps: jnp array; Reps for different environments (sharded).
    env_batches: list of dict; Batches of different environments (sharded).
    env_ids: jnp array; Environment ids.

  Returns:
    Alignment between batches of environment pairs (sharded).
  """
  # TODO(riannevdberg, samiraabnar): aligning is done on the total
  #  unsharded batch, but that requires access between local batches
  #  when computing the loss. Unsure why this works! To be compatible
  #  with random alignment and sinkhorn soft alignment we should do
  #  alignment only within local batches.
  full_env_reps = shard_util.unshard_env_batch(env_reps)
  full_env_batches = shard_util.unshard(env_batches)

  alignment_rng = jax_utils.unreplicate(self.train_state.rng)
  with nn.stochastic(alignment_rng):
    unsharded_alignments = self.task.get_env_aligned_pairs_idx(
        full_env_reps, full_env_batches, env_ids)

  return dataset_utils.shard(unsharded_alignments)
def train(config, workdir):
  """Runs a training loop.

  Builds the model and optimizer, restores (or creates) checkpoints, then
  iterates pmapped training steps. Every 50 steps the training loss is logged
  on host 0, every 100 steps an evaluation loss is computed, and checkpoints /
  samples are written at the frequencies given in `config.training`.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoint training will be resumed from the latest checkpoint.
  """

  # Create directories for experimental logs
  tf.io.gfile.makedirs(workdir)
  sample_dir = os.path.join(workdir, "samples")
  tf.io.gfile.makedirs(sample_dir)
  rng = jax.random.PRNGKey(config.seed)
  tb_dir = os.path.join(workdir, "tensorboard")
  tf.io.gfile.makedirs(tb_dir)
  # Only host 0 writes summaries; `writer` is only used under the same check.
  if jax.host_id() == 0:
    writer = tensorboard.SummaryWriter(tb_dir)

  # Initialize model.
  rng, model_rng = jax.random.split(rng)
  model_name = config.model.name
  ncsn_def = mutils.get_model(model_name).partial(config=config)
  rng, run_rng = jax.random.split(rng)
  # Whether the generative model is conditioned on class labels
  class_conditional = "conditional" in config.training.loss.lower()
  with nn.stateful() as init_model_state:
    with nn.stochastic(run_rng):
      input_shape = (jax.local_device_count(), config.data.image_size,
                     config.data.image_size, 3)
      input_list = [(input_shape, jnp.float32), (input_shape[:1], jnp.int32)]
      if class_conditional:
        # The class-label input has the same spec as the second input.
        input_list.append(input_list[-1])
      _, initial_params = ncsn_def.init_by_shape(
          model_rng, input_list, train=True)
      ncsn = nn.Model(ncsn_def, initial_params)

  optimizer = losses.get_optimizer(config).create(ncsn)

  state = mutils.State(step=0, optimizer=optimizer, lr=config.optim.lr,
                       model_state=init_model_state,
                       ema_rate=config.model.ema_rate,
                       params_ema=initial_params,
                       rng=rng)  # pytype: disable=wrong-keyword-args

  del ncsn, init_model_state  # Do not keep a copy of the initial model.

  # Create checkpoints directory and the initial checkpoint
  checkpoint_dir = os.path.join(workdir, "checkpoints")
  ckpt = utils.Checkpoint(
      checkpoint_dir,
      max_to_keep=None)
  ckpt.restore_or_initialize(state)

  # Save intermediate checkpoints to resume training automatically
  checkpoint_meta_dir = os.path.join(workdir, "checkpoints-meta")
  ckpt_meta = utils.Checkpoint(
      checkpoint_meta_dir,
      max_to_keep=1)
  state = ckpt_meta.restore_or_initialize(state)
  initial_step = int(state.step)
  rng = state.rng

  # Build input pipeline.
  rng, ds_rng = jax.random.split(rng)
  train_ds, eval_ds, _ = datasets.get_dataset(ds_rng, config)
  train_iter = iter(train_ds)  # pytype: disable=wrong-arg-types
  eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types

  scaler = datasets.get_data_scaler(config)  # data normalizer
  inverse_scaler = datasets.get_data_inverse_scaler(config)

  # Distribute training.
  optimize_fn = losses.optimization_manager(config)
  if config.training.loss.lower() == "ddpm":
    # Use score matching loss with DDPM-type perturbation.
    ddpm_params = mutils.get_ddpm_params()
    train_step = functools.partial(losses.ddpm_loss, ddpm_params=ddpm_params,
                                   train=True, optimize_fn=optimize_fn)
    eval_step = functools.partial(losses.ddpm_loss, ddpm_params=ddpm_params,
                                  train=False)
  else:
    # Use score matching loss with NCSN-type perturbation.
    sigmas = mutils.get_sigmas(config)
    # Whether to use a continuous distribution of noise levels
    continuous = "continuous" in config.training.loss.lower()
    train_step = functools.partial(
        losses.ncsn_loss, sigmas=sigmas,
        class_conditional=class_conditional,
        continuous=continuous, train=True, optimize_fn=optimize_fn,
        anneal_power=config.training.anneal_power)
    eval_step = functools.partial(
        losses.ncsn_loss, sigmas=sigmas,
        class_conditional=class_conditional,
        continuous=continuous, train=False,
        anneal_power=config.training.anneal_power)

  p_train_step = jax.pmap(train_step, axis_name="batch")
  p_eval_step = jax.pmap(eval_step, axis_name="batch")
  state = flax_utils.replicate(state)

  num_train_steps = config.training.n_iters

  logging.info("Starting training loop at step %d.", initial_step)
  rng = jax.random.fold_in(rng, jax.host_id())
  for step in range(initial_step, num_train_steps + 1):
    # `step` is a Python integer. `state.step` is JAX integer on the GPU/TPU
    # devices.

    # Convert data to JAX arrays. Use ._numpy() to avoid copy.
    # `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map` alias.
    batch = jax.tree_util.tree_map(
        lambda x: scaler(x._numpy()), next(train_iter))  # pylint: disable=protected-access

    # One fresh key per local device, plus one carried over to the next step.
    rng, *next_rng = jax.random.split(rng, num=jax.local_device_count() + 1)
    next_rng = jnp.asarray(next_rng)
    loss, state = p_train_step(next_rng, state, batch)
    loss = flax.jax_utils.unreplicate(loss)

    # Quick indication that training is happening.
    logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)

    if jax.host_id() == 0 and step % 50 == 0:
      logging.info("step: %d, training_loss: %.5e", step, loss)
      writer.scalar("training_loss", loss, step)

    # Save a temporary checkpoint to resume training after pre-emption.
    if (step % config.training.snapshot_freq_for_preemption == 0 and
        jax.host_id() == 0):
      saved_state = flax_utils.unreplicate(state)
      saved_state = saved_state.replace(rng=rng)
      ckpt_meta.save(saved_state)

    # Report the loss on an evaluation dataset.
    if step % 100 == 0:
      rng, *next_rng = jax.random.split(rng, num=jax.local_device_count() + 1)
      next_rng = jnp.asarray(next_rng)
      eval_batch = jax.tree_util.tree_map(
          lambda x: scaler(x._numpy()), next(eval_iter))  # pylint: disable=protected-access
      eval_loss, _ = p_eval_step(next_rng, state, eval_batch)
      eval_loss = flax.jax_utils.unreplicate(eval_loss)
      if jax.host_id() == 0:
        logging.info("step: %d, eval_loss: %.5e", step, eval_loss)
        writer.scalar("eval_loss", eval_loss, step)

    # Save a checkpoint periodically and generate samples.
    if ((step + 1) % config.training.snapshot_freq == 0 or
        step == num_train_steps):
      # Save the checkpoint.
      if jax.host_id() == 0:
        saved_state = flax_utils.unreplicate(state)
        saved_state = saved_state.replace(rng=rng)
        ckpt.save(saved_state)

      # Generate and save samples
      if config.training.snapshot_sampling:
        rng, sample_rng = jax.random.split(rng)
        init_shape = tuple(train_ds.element_spec["image"].shape)
        samples = sampling.get_samples(sample_rng,
                                       config,
                                       flax_utils.unreplicate(state),
                                       init_shape,
                                       scaler,
                                       inverse_scaler,
                                       class_conditional=class_conditional)
        # Every host writes into its own per-step directory.
        this_sample_dir = os.path.join(
            sample_dir, "iter_{}_host_{}".format(step, jax.host_id()))
        tf.io.gfile.makedirs(this_sample_dir)

        if config.sampling.final_only:  # Do not save intermediate samples
          sample = samples[-1]
          # The image grid is built from the sample before it is converted to
          # uint8 below.
          image_grid = sample.reshape((-1, *sample.shape[2:]))
          nrow = int(np.sqrt(image_grid.shape[0]))
          sample = np.clip(sample * 255, 0, 255).astype(np.uint8)
          with tf.io.gfile.GFile(
              os.path.join(this_sample_dir, "sample.np"), "wb") as fout:
            np.save(fout, sample)

          with tf.io.gfile.GFile(
              os.path.join(this_sample_dir, "sample.png"), "wb") as fout:
            utils.save_image(image_grid, fout, nrow=nrow, padding=2)
        else:  # Save all intermediate samples produced during sampling.
          for i, sample in enumerate(samples):
            image_grid = sample.reshape((-1, *sample.shape[2:]))
            nrow = int(np.sqrt(image_grid.shape[0]))
            sample = np.clip(sample * 255, 0, 255).astype(np.uint8)
            with tf.io.gfile.GFile(
                os.path.join(this_sample_dir, "sample_{}.np".format(i)),
                "wb") as fout:
              np.save(fout, sample)

            with tf.io.gfile.GFile(
                os.path.join(this_sample_dir, "sample_{}.png".format(i)),
                "wb") as fout:
              utils.save_image(image_grid, fout, nrow=nrow, padding=2)
def evaluate(config, workdir, eval_folder="eval"):
  """Evaluate trained models.

  For every checkpoint in `[begin_ckpt, config.eval.end_ckpt]` this computes
  the loss over the evaluation dataset, generates samples in rounds, stores
  Inception statistics for them, and (on host 0) writes a report with
  IS / FID / KID. Progress is tracked in per-host meta checkpoints so that an
  interrupted run resumes where it stopped.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints.
    eval_folder: The subfolder for storing evaluation results. Default to
      "eval".
  """
  # Create eval_dir
  eval_dir = os.path.join(workdir, eval_folder)
  tf.io.gfile.makedirs(eval_dir)

  rng = jax.random.PRNGKey(config.seed + 1)

  # Build input pipeline.
  rng, ds_rng = jax.random.split(rng)
  _, eval_ds, _ = datasets.get_dataset(ds_rng, config, evaluation=True)
  scaler = datasets.get_data_scaler(config)
  inverse_scaler = datasets.get_data_inverse_scaler(config)

  # Initialize model.
  rng, model_rng = jax.random.split(rng)
  model_name = config.model.name
  ncsn_def = mutils.get_model(model_name).partial(config=config)
  rng, run_rng = jax.random.split(rng)
  class_conditional = "conditional" in config.training.loss.lower()
  with nn.stateful() as init_model_state:
    with nn.stochastic(run_rng):
      input_shape = tuple(eval_ds.element_spec["image"].shape[1:])
      input_list = [(input_shape, jnp.float32), (input_shape[:1], jnp.int32)]
      if class_conditional:
        # The class-label input has the same spec as the second input.
        input_list.append(input_list[-1])
      _, initial_params = ncsn_def.init_by_shape(
          model_rng, input_list, train=True)
      ncsn = nn.Model(ncsn_def, initial_params)

  optimizer = losses.get_optimizer(config).create(ncsn)
  state = mutils.State(step=0, optimizer=optimizer, lr=config.optim.lr,
                       model_state=init_model_state,
                       ema_rate=config.model.ema_rate,
                       params_ema=initial_params,
                       rng=rng)  # pytype: disable=wrong-keyword-args

  del ncsn, init_model_state  # Do not keep a copy of the initial model.

  checkpoint_dir = os.path.join(workdir, "checkpoints")
  if config.training.loss.lower() == "ddpm":
    # Use the score matching loss with DDPM-type perturbation.
    ddpm_params = mutils.get_ddpm_params()
    eval_step = functools.partial(
        losses.ddpm_loss, ddpm_params=ddpm_params, train=False)
  else:
    # Use the score matching loss with NCSN-type perturbation.
    sigmas = mutils.get_sigmas(config)
    continuous = "continuous" in config.training.loss.lower()
    eval_step = functools.partial(
        losses.ncsn_loss,
        sigmas=sigmas,
        continuous=continuous,
        class_conditional=class_conditional,
        train=False,
        anneal_power=config.training.anneal_power)

  p_eval_step = jax.pmap(eval_step, axis_name="batch")

  rng = jax.random.fold_in(rng, jax.host_id())

  # A data class for checkpointing.
  @flax.struct.dataclass
  class EvalMeta:
    ckpt_id: int
    round_id: int
    rng: Any

  # Add one additional round to get the exact number of samples as required.
  num_rounds = config.eval.num_samples // config.eval.batch_size + 1

  eval_meta = EvalMeta(ckpt_id=config.eval.begin_ckpt, round_id=-1, rng=rng)
  eval_meta = checkpoints.restore_checkpoint(
      eval_dir, eval_meta, step=None, prefix=f"meta_{jax.host_id()}_")

  # Resume inside the recorded checkpoint if rounds are left, otherwise start
  # with the next checkpoint.
  if eval_meta.round_id < num_rounds - 1:
    begin_ckpt = eval_meta.ckpt_id
    begin_round = eval_meta.round_id + 1
  else:
    begin_ckpt = eval_meta.ckpt_id + 1
    begin_round = 0

  rng = eval_meta.rng
  # Use inceptionV3 for images with higher resolution
  inceptionv3 = config.data.image_size >= 256
  inception_model = evaluation.get_inception_model(inceptionv3=inceptionv3)

  logging.info("begin checkpoint: %d", begin_ckpt)
  for ckpt in range(begin_ckpt, config.eval.end_ckpt + 1):
    ckpt_filename = os.path.join(checkpoint_dir, "ckpt-{}.flax".format(ckpt))

    # Wait if the target checkpoint hasn't been produced yet.
    waiting_message_printed = False
    while not tf.io.gfile.exists(ckpt_filename):
      if not waiting_message_printed and jax.host_id() == 0:
        logging.warning("Waiting for the arrival of ckpt-%d.flax", ckpt)
        waiting_message_printed = True
      time.sleep(10)

    # In case the file was just written and not ready to read from yet.
    # `except Exception` (rather than a bare `except`) keeps the retries but
    # lets KeyboardInterrupt / SystemExit abort the long sleeps.
    try:
      state = utils.load_state_dict(ckpt_filename, state)
    except Exception:  # pylint: disable=broad-except
      time.sleep(60)
      try:
        state = utils.load_state_dict(ckpt_filename, state)
      except Exception:  # pylint: disable=broad-except
        time.sleep(120)
        # Last attempt: any failure now propagates to the caller.
        state = utils.load_state_dict(ckpt_filename, state)

    pstate = flax.jax_utils.replicate(state)
    eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types

    # Compute the loss function on the full evaluation dataset.
    all_losses = []
    for i, batch in enumerate(eval_iter):
      rng, *next_rng = jax.random.split(rng, num=jax.local_device_count() + 1)
      next_rng = jnp.asarray(next_rng)
      # `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map` alias.
      eval_batch = jax.tree_util.tree_map(
          lambda x: scaler(x._numpy()), batch)  # pylint: disable=protected-access
      eval_loss, _ = p_eval_step(next_rng, pstate, eval_batch)
      eval_loss = flax.jax_utils.unreplicate(eval_loss)
      all_losses.append(eval_loss)
      if (i + 1) % 1000 == 0 and jax.host_id() == 0:
        logging.info("Finished %dth step loss evaluation", i + 1)

    all_losses = jnp.asarray(all_losses)

    state = jax.device_put(state)
    # Sampling and computing statistics for Inception scores, FIDs, and KIDs.
    # Designed to be pre-emption safe. Automatically resumes when interrupted.
    for r in range(begin_round, num_rounds):
      if jax.host_id() == 0:
        logging.info("sampling -- ckpt: %d, round: %d", ckpt, r)
      rng, sample_rng = jax.random.split(rng)
      init_shape = tuple(eval_ds.element_spec["image"].shape)

      this_sample_dir = os.path.join(
          eval_dir, f"ckpt_{ckpt}_host_{jax.host_id()}")
      tf.io.gfile.makedirs(this_sample_dir)
      samples = sampling.get_samples(sample_rng, config, state, init_shape,
                                     scaler, inverse_scaler,
                                     class_conditional=class_conditional)
      # Only the final sampling step is evaluated.
      samples = samples[-1]
      samples = np.clip(samples * 255., 0, 255).astype(np.uint8)
      samples = samples.reshape(
          (-1, config.data.image_size, config.data.image_size, 3))
      with tf.io.gfile.GFile(
          os.path.join(this_sample_dir, f"samples_{r}.npz"), "wb") as fout:
        io_buffer = io.BytesIO()
        np.savez_compressed(io_buffer, samples=samples)
        fout.write(io_buffer.getvalue())

      gc.collect()
      latents = evaluation.run_inception_distributed(samples, inception_model,
                                                     inceptionv3=inceptionv3)
      gc.collect()
      with tf.io.gfile.GFile(
          os.path.join(this_sample_dir, f"statistics_{r}.npz"), "wb") as fout:
        io_buffer = io.BytesIO()
        np.savez_compressed(
            io_buffer, pool_3=latents["pool_3"], logits=latents["logits"])
        fout.write(io_buffer.getvalue())

      eval_meta = eval_meta.replace(ckpt_id=ckpt, round_id=r, rng=rng)
      # Save an intermediate checkpoint directly if not the last round.
      # Otherwise save eval_meta after computing the Inception scores and FIDs
      if r < num_rounds - 1:
        checkpoints.save_checkpoint(
            eval_dir,
            eval_meta,
            step=ckpt * num_rounds + r,
            keep=1,
            prefix=f"meta_{jax.host_id()}_")

    # Compute inception scores, FIDs and KIDs.
    if jax.host_id() == 0:
      # Load all statistics that have been previously computed and saved.
      all_logits = []
      all_pools = []
      for host in range(jax.host_count()):
        this_sample_dir = os.path.join(eval_dir, f"ckpt_{ckpt}_host_{host}")

        stats = tf.io.gfile.glob(
            os.path.join(this_sample_dir, "statistics_*.npz"))
        wait_message = False
        # Poll until this host has written statistics for every round.
        while len(stats) < num_rounds:
          if not wait_message:
            logging.warning("Waiting for statistics on host %d", host)
            wait_message = True
          stats = tf.io.gfile.glob(
              os.path.join(this_sample_dir, "statistics_*.npz"))
          time.sleep(1)

        for stat_file in stats:
          with tf.io.gfile.GFile(stat_file, "rb") as fin:
            stat = np.load(fin)
            if not inceptionv3:
              all_logits.append(stat["logits"])
            all_pools.append(stat["pool_3"])

      if not inceptionv3:
        all_logits = np.concatenate(
            all_logits, axis=0)[:config.eval.num_samples]
      all_pools = np.concatenate(all_pools, axis=0)[:config.eval.num_samples]

      # Load pre-computed dataset statistics.
      data_stats = evaluation.load_dataset_stats(config)
      data_pools = data_stats["pool_3"]

      if hasattr(config.eval, "num_partitions"):
        # Divide samples into several partitions and compute FID/KID/IS on them.
        assert not inceptionv3
        fids = []
        kids = []
        inception_scores = []
        partition_size = config.eval.num_samples // config.eval.num_partitions
        tf_data_pools = tf.convert_to_tensor(data_pools)
        for i in range(config.eval.num_partitions):
          this_pools = all_pools[i * partition_size:(i + 1) * partition_size]
          this_logits = all_logits[i * partition_size:(i + 1) * partition_size]
          inception_scores.append(
              tfgan.eval.classifier_score_from_logits(this_logits))
          fids.append(
              tfgan.eval.frechet_classifier_distance_from_activations(
                  data_pools, this_pools))
          this_pools = tf.convert_to_tensor(this_pools)
          kids.append(
              tfgan.eval.kernel_classifier_distance_from_activations(
                  tf_data_pools, this_pools).numpy())

        fids = np.asarray(fids)
        inception_scores = np.asarray(inception_scores)
        kids = np.asarray(kids)
        with tf.io.gfile.GFile(os.path.join(eval_dir, f"report_all_{ckpt}.npz"),
                               "wb") as f:
          io_buffer = io.BytesIO()
          np.savez_compressed(
              io_buffer, all_losses=all_losses, mean_loss=all_losses.mean(),
              ISs=inception_scores, fids=fids, kids=kids)
          f.write(io_buffer.getvalue())

      else:
        # Compute FID/KID/IS on all samples together.
        if not inceptionv3:
          inception_score = tfgan.eval.classifier_score_from_logits(all_logits)
        else:
          # No logits were collected in this mode; -1 marks "not computed".
          inception_score = -1

        fid = tfgan.eval.frechet_classifier_distance_from_activations(
            data_pools, all_pools)
        # Hack to get tfgan KID work for eager execution.
        tf_data_pools = tf.convert_to_tensor(data_pools)
        tf_all_pools = tf.convert_to_tensor(all_pools)
        kid = tfgan.eval.kernel_classifier_distance_from_activations(
            tf_data_pools, tf_all_pools).numpy()
        del tf_data_pools, tf_all_pools

        logging.info(
            "ckpt-%d --- loss: %.6e, inception_score: %.6e, FID: %.6e, KID: %.6e",
            ckpt, all_losses.mean(), inception_score, fid, kid)

        with tf.io.gfile.GFile(os.path.join(eval_dir, f"report_{ckpt}.npz"),
                               "wb") as f:
          io_buffer = io.BytesIO()
          np.savez_compressed(
              io_buffer, all_losses=all_losses, mean_loss=all_losses.mean(),
              IS=inception_score, fid=fid, kid=kid)
          f.write(io_buffer.getvalue())
    else:
      # For host_id() != 0.
      # Use file existence to emulate synchronization across hosts.
      if hasattr(config.eval, "num_partitions"):
        assert not inceptionv3
        while not tf.io.gfile.exists(
            os.path.join(eval_dir, f"report_all_{ckpt}.npz")):
          time.sleep(1.)
      else:
        while not tf.io.gfile.exists(
            os.path.join(eval_dir, f"report_{ckpt}.npz")):
          time.sleep(1.)

    # Save eval_meta after computing IS/KID/FID to mark the end of evaluation
    # for this checkpoint. `r` is the last round index of the loop above.
    checkpoints.save_checkpoint(
        eval_dir,
        eval_meta,
        step=ckpt * num_rounds + r,
        keep=1,
        prefix=f"meta_{jax.host_id()}_")

    # Later checkpoints always start from the first round.
    begin_round = 0

  # Remove all meta files after finishing evaluation.
  meta_files = tf.io.gfile.glob(
      os.path.join(eval_dir, f"meta_{jax.host_id()}_*"))
  for file in meta_files:
    tf.io.gfile.remove(file)
def training_loss_fn(self, flax_model, train_state, teacher_train_state, batch,
                     unlabeled_batch, dropout_rng, env_ids, unlabeled_env_ids,
                     sampled_layer):
    """Computes the student loss on labeled, mixed-up and pseudo-labeled data.

    The student is run on the labeled batches, the teacher on both the labeled
    and the unlabeled batches. Teacher predictions become the labels of the
    unlabeled batches, after which three optional interpolation steps
    (inter-env, intra-env and gradual labeled->unlabeled) produce extra
    logits/batches whose losses are added to the ground-truth loss.

    Args:
      flax_model: A flax module (the student).
      train_state: TrainState; The state of training including the current
        global_step, model_state, rng, and optimizer.
      teacher_train_state: TrainState; The state of training for the teacher.
        Its `optimizer.target` is used as the teacher model.
      batch: list(dict); A batch of data for each environment in the labeled
        set.
      unlabeled_batch: list(dict); A batch of data for each environment in the
        unlabeled set. NOTE: the 'label' entry of each element is overwritten
        in place with the teacher's argmax prediction.
      dropout_rng: FLAX PRNG key.
      env_ids: list(int); List of labeled training environments ids.
      unlabeled_env_ids: list(int); List of unlabeled environments ids.
      sampled_layer: str; Name of the layer on which mixup is applied.

    Returns:
      A pair `(loss, (model_state, std_env_logits, logs))`: the total loss,
      the model state of the train state returned by the last interpolation
      step, the student logits on the labeled batches, and a dict of logged
      values.
    """
    forward_rng, spare_rng = jax.random.split(dropout_rng)
    with nn.stochastic(forward_rng):
        # Student forward pass on the labeled envs; the returned train state
        # replaces the incoming one from here on.
        student_out = self.stateful_forward_pass(flax_model, train_state, batch)
        student_reps, student_logits, _, state = student_out

        teacher_model = teacher_train_state.optimizer.target

        # Teacher forward pass on the labeled envs.
        teacher_labeled_logits, _, _ = self.stateless_forward_pass(
            teacher_model, teacher_train_state, batch)

        # Teacher forward pass on the unlabeled envs.
        teacher_unlabeled_logits, teacher_unlabeled_reps, _ = (
            self.stateless_forward_pass(
                teacher_model, teacher_train_state, unlabeled_batch))

        # Pseudo-labeling: the teacher's most likely class becomes the label
        # of every unlabeled batch (this mutates the caller's batches).
        for env_index in range(len(unlabeled_env_ids)):
            pseudo_labels = jnp.argmax(
                teacher_unlabeled_logits[env_index], axis=-1)
            unlabeled_batch[env_index]['label'] = pseudo_labels

    # Representations of the layer on which interpolation happens.
    labeled_layer_reps = student_reps[sampled_layer]
    unlabeled_layer_reps = teacher_unlabeled_reps[sampled_layer]

    interpolation_rng, _ = jax.random.split(spare_rng)
    with nn.stochastic(interpolation_rng):
        # NOTE(review): the inter-env step is handed
        # `self.intra_interpolate_fn` -- confirm this is intended.
        inter_batches, inter_logits, _, state = (
            self.maybe_inter_env_interpolation(
                batch, env_ids, flax_model, self.intra_interpolate_fn,
                sampled_layer, labeled_layer_reps, labeled_layer_reps, state))

        intra_batches, intra_logits, _, state = (
            self.maybe_intra_env_interpolation(
                batch, env_ids, flax_model, self.intra_interpolate_fn,
                sampled_layer, labeled_layer_reps, state))

        (unlabeled_mix_batches, unlabeled_mix_logits, mix_lambdas, mix_alpha,
         mix_beta, state) = self.maybe_gradual_interpolation(
             batch, unlabeled_batch, env_ids, unlabeled_env_ids, flax_model,
             self.interpolate_fn, sampled_layer, labeled_layer_reps,
             unlabeled_layer_reps, labeled_layer_reps, unlabeled_layer_reps,
             teacher_labeled_logits, teacher_unlabeled_logits, state,
             teacher_train_state)

        # The total loss is computed inside the nn.stochastic scope.
        # env_reps and env_ids are passed as None to avoid computing a loss
        # for domain mapping (the mapping model is not trained and not used
        # in computing the loss).
        hparams = self.hparams
        step = state.global_step

        def scheduled_weight(name):
            # Weight named `name` (default 1.0) evaluated at the current step.
            params = pipeline_utils.get_weight_param(hparams, name, 1.0)
            return pipeline_utils.scheduler(step, params)

        def mixup_loss(logits, mixed_batches):
            # Loss on interpolated states; no model params are passed here.
            return self.task.loss_function(logits, None, mixed_batches, None,
                                           None, step)

        ground_truth_factor = scheduled_weight('ground_truth_factor')
        ground_truth_loss = self.task.loss_function(
            student_logits, None, batch, None, flax_model.params, step)
        loss = ground_truth_loss * ground_truth_factor

        # Cross-environment interpolated states (needs more than one env).
        if len(env_ids) > 1 and hparams.get('inter_env_interpolation', True):
            inter_factor = scheduled_weight('inter_mixup_factor')
            loss = loss + mixup_loss(inter_logits, inter_batches) * inter_factor

        # Same-environment interpolated states.
        if hparams.get('intra_env_interpolation', True):
            intra_factor = scheduled_weight('intra_mixup_factor')
            loss = loss + mixup_loss(intra_logits, intra_batches) * intra_factor

        # Gradual interpolation toward the unlabeled target environment(s);
        # both values stay 0 when the feature is switched off.
        unlabeled_mixup_factor = 0
        unlabeled_loss = 0
        if hparams.get('unlabeled_interpolation', True):
            unlabeled_mixup_factor = scheduled_weight('unlabeled_mixup_factor')
            unlabeled_loss = mixup_loss(unlabeled_mix_logits,
                                        unlabeled_mix_batches)
            loss = loss + unlabeled_loss * unlabeled_mixup_factor

    logs = {
        'unlabeled_mixup_lambda': mix_lambdas,
        'unlabeled_mixup_alpha': mix_alpha,
        'unlabeled_mixup_beta': mix_beta,
        'unlabeled_mixup_factor': unlabeled_mixup_factor,
        'train_loss': ground_truth_loss,
        'unlabeled_loss': unlabeled_loss,
    }
    return loss, (state.model_state, student_logits, logs)
def training_loss_fn(self, flax_module, train_state, batch, dropout_rng,
                     mixup_rng, sampled_layer):
    """Computes the standard loss plus the loss on mixed-up hidden states.

    The module is first run on the batch inputs. Examples of the mini batch
    are then matched with each other, the activations of `sampled_layer` are
    interpolated between matched examples, and the loss on the resulting
    predictions (with equally interpolated labels/weights) is added to the
    standard loss.

    Args:
      flax_module: A flax module.
      train_state: TrainState, the state of training including the current
        global_step, model_state, rng, and optimizer.
      batch: dict; A batch with at least 'inputs' and 'label', and optionally
        'weights'. It is deep-copied before labels/weights are replaced.
      dropout_rng: FLAX PRNG key used for the first forward pass.
      mixup_rng: FLAX PRNG key used for the interpolation pass.
      sampled_layer: str; Name of the layer on which mixup will be applied.

    Returns:
      A pair `(loss, (new_model_state, logits))` where `new_model_state` is
      the state produced by the interpolation pass (itself started from the
      state of the first forward pass) and `logits` come from the first,
      non-interpolated forward pass.
    """
    hparams = self.hparams

    with nn.stochastic(dropout_rng):
        with nn.stateful(train_state.model_state) as clean_state:
            logits, reps, _ = flax_module(
                batch['inputs'], train=True, return_activations=True)

        # Matching between examples of the mini batch.
        # NOTE(review): kept inside the first RNG scope on the assumption
        # that a random matching mode may draw from it -- confirm against
        # pipeline_utils.get_self_matching_matrix.
        matching_matrix = pipeline_utils.get_self_matching_matrix(
            batch,
            reps[sampled_layer],
            mode=hparams.get('intra_mixup_mode', 'random'),
            label_cost=hparams.get('intra_mixup_label_cost', 1.0),
            l2_cost=hparams.get('intra_mixup_l2_cost', 0.001))

    # Missing or empty schedule params fall back to a constant value of 1.
    beta_params = hparams.get('beta_schedule_params')
    if not beta_params:
        beta_params = {'initial_value': 1.0, 'mode': 'constant'}
    alpha_params = hparams.get('alpha_schedule_params')
    if not alpha_params:
        alpha_params = {'initial_value': 1.0, 'mode': 'constant'}

    step = train_state.global_step
    beta = pipeline_utils.scheduler(step, beta_params)
    alpha = pipeline_utils.scheduler(step, alpha_params)

    with nn.stochastic(mixup_rng):
        # Continue from the state left by the first forward pass.
        with nn.stateful(clean_state) as new_model_state:
            new_logits, sample_lambdas = self.interpolate_and_predict(
                nn.make_rng(), flax_module, matching_matrix, reps,
                sampled_layer, alpha, beta)

    # Index of the example each row of the matching matrix points at most
    # strongly; used to pick the partner's label/weight.
    partner_index = jnp.argmax(matching_matrix, axis=-1)

    new_batch = copy.deepcopy(batch)

    # Labels of the interpolated states.
    labels = batch['label']
    new_batch['label'] = tensor_util.convex_interpolate(
        labels, labels[partner_index], sample_lambdas)

    # Weights of the interpolated states (only if the batch carries weights).
    weights = batch.get('weights')
    if weights is not None:
        new_batch['weights'] = tensor_util.convex_interpolate(
            batch['weights'], batch['weights'][partner_index], sample_lambdas)

    # Standard loss (with model params) plus the loss on interpolated states.
    standard_loss = self.task.loss_function(logits, batch, flax_module.params)
    interpolated_loss = self.task.loss_function(new_logits, new_batch)
    loss = standard_loss + interpolated_loss

    return loss, (new_model_state, logits)
def training_loss_fn(self, flax_model, train_state, batch, dropout_rng,
                     env_ids, sampled_layer):
    """Computes the multi-environment loss with inter/intra-env mixup.

    Runs the model on all environment batches, optionally interpolates the
    activations of `sampled_layer` across and within environments, and sums
    the loss on the original logits with the (weighted) losses on the
    interpolated logits.

    Args:
      flax_model: A flax module.
      train_state: TrainState, the state of training including the current
        global_step, model_state, rng, and optimizer.
      batch: Batches from different environments.
      dropout_rng: FLAX PRNG key; split into separate keys for the forward
        pass, the interpolation step and the loss computation.
      env_ids: list[int]; List of env codes.
      sampled_layer: str; Name of the layer on which mixup is applied.

    Returns:
      A pair `(loss, (new_model_state, env_logits, logs))` where `logs` holds
      the lambdas sampled by the inter-env interpolation step.
    """
    forward_rng, spare_rng = jax.random.split(dropout_rng)
    with nn.stochastic(forward_rng):
        # Forward pass on the environment batches.
        forward_out = self.stateful_forward_pass(flax_model, train_state, batch)
        all_env_reps, env_logits, selected_env_reps, state = forward_out
        # NOTE(review): the returned model state is this snapshot, taken
        # before the interpolation passes update the train state -- confirm
        # that later state updates are meant to be dropped.
        new_model_state = state.model_state

    layer_reps = all_env_reps[sampled_layer]

    # Map over the first four arguments of pipeline_utils.interpolate and
    # broadcast the remaining four.
    batched_axes = (0,) * 4 + (None,) * 4
    interpolate_fn = jax.vmap(pipeline_utils.interpolate, in_axes=batched_axes)

    interpolate_rng, spare_rng = jax.random.split(spare_rng)
    with nn.stochastic(interpolate_rng):
        inter_batches, inter_logits, sampled_lambdas, state = (
            self.maybe_inter_env_interpolation(
                batch, env_ids, flax_model, interpolate_fn, sampled_layer,
                layer_reps, selected_env_reps, state))

        intra_batches, intra_logits, _, state = (
            self.maybe_intra_env_interpolation(
                batch, env_ids, flax_model, interpolate_fn, sampled_layer,
                layer_reps, state))

    loss_rng, _ = jax.random.split(spare_rng)
    with nn.stochastic(loss_rng):
        # The total loss is computed inside its own nn.stochastic scope.
        hparams = self.hparams
        step = state.global_step

        loss = self.task.loss_function(env_logits, selected_env_reps, batch,
                                       env_ids, flax_model.params, step)

        # Cross-environment interpolated states (needs more than one env).
        if len(env_ids) > 1 and hparams.get('inter_env_interpolation', True):
            inter_mixup_factor = hparams.get('inter_mixup_factor', 1.0)
            inter_loss = self.task.loss_function(
                inter_logits, None, inter_batches, None, None, step)
            loss = loss + inter_loss * inter_mixup_factor

        # Same-environment interpolated states.
        if hparams.get('intra_env_interpolation', True):
            intra_mixup_factor = hparams.get('intra_mixup_factor', 1.0)
            intra_loss = self.task.loss_function(
                intra_logits, None, intra_batches, None, None, step)
            loss = loss + intra_loss * intra_mixup_factor

    logs = {'sampled_lambdas': sampled_lambdas}
    return loss, (new_model_state, env_logits, logs)