def loss_fn(model):
  # Denoising score matching loss: the model's score output is regressed onto
  # -noise / sigma^2, and each example is weighted by sigma^anneal_power.
  if train:
    with nn.stateful(state.model_state) as new_model_state:
      with nn.stochastic(run_rng):
        if not class_conditional:
          scores = model(perturbed_data, labels, train=train)
        else:
          scores = model(perturbed_data, labels, y=class_labels, train=train)
  else:
    with nn.stateful(state.model_state, mutable=False):
      with nn.stochastic(run_rng):
        if not class_conditional:
          scores = model(perturbed_data, labels, train=train)
        else:
          scores = model(perturbed_data, labels, y=class_labels, train=train)
    new_model_state = state.model_state

  scores = scores.reshape((scores.shape[0], -1))
  target = -1 / (used_sigmas**2) * noise
  target = target.reshape((target.shape[0], -1))
  losses = 1 / 2. * ((scores - target)**2).sum(
      axis=-1) * used_sigmas.squeeze()**anneal_power
  loss = jnp.mean(losses)
  if loss_per_sigma:
    return loss, new_model_state, losses
  else:
    return loss, new_model_state

def init_single_head(init_rng, args, kwargs):
  if rng is None:
    _, head_params = wrapped_module.init(init_rng, *args, **kwargs)
  else:
    with nn.stochastic(rng):
      _, head_params = wrapped_module.init(init_rng, *args, **kwargs)
  return head_params

def create_model(config):
  """Create a model, starting with a pre-trained checkpoint."""
  model_kwargs = dict(config=config.model,)
  model_def = modeling.BertForPreTraining.partial(**model_kwargs)
  if config.init_checkpoint:
    initial_params = import_weights.load_params(
        init_checkpoint=config.init_checkpoint,
        hidden_size=config.model.hidden_size,
        num_attention_heads=config.model.num_attention_heads,
        keep_masked_lm_head=True)
  else:
    with nn.stochastic(jax.random.PRNGKey(0)):
      _, initial_params = model_def.init_by_shape(
          jax.random.PRNGKey(0),
          [((1, config.max_seq_length), jnp.int32),
           ((1, config.max_seq_length), jnp.int32),
           ((1, config.max_seq_length), jnp.int32),
           ((1, config.max_predictions_per_seq), jnp.int32)],
          deterministic=True)

    def fixup_for_tpu(x, i=[0]):
      """HACK to fix incorrect param initialization on TPU."""
      if isinstance(x, jax.ShapeDtypeStruct):
        i[0] += 1
        if len(x.shape) == 2:
          return jnp.zeros(x.shape, x.dtype)
        else:
          return nn.linear.default_kernel_init(
              jax.random.PRNGKey(i[0]), x.shape, x.dtype)
      else:
        return x

    initial_params = jax.tree_map(fixup_for_tpu, initial_params)

  model = nn.Model(model_def, initial_params)
  return model

def compute_pretraining_stats(model, batch):
  """Used for computing eval metrics during pre-training."""
  with nn.stochastic(jax.random.PRNGKey(0)):
    masked_lm_logits, next_sentence_logits = model(
        batch['input_ids'],
        (batch['input_ids'] > 0).astype(np.int32),
        batch['token_type_ids'],
        batch['masked_lm_positions'],
        deterministic=True)
    stats = model.compute_metrics(masked_lm_logits, next_sentence_logits,
                                  batch['masked_lm_ids'],
                                  batch['masked_lm_weights'],
                                  batch['next_sentence_label'])

  masked_lm_correct = jnp.sum(
      (masked_lm_logits.argmax(-1) == batch['masked_lm_ids'].reshape(
          (-1,))) * batch['masked_lm_weights'].reshape((-1,)))
  next_sentence_labels = batch['next_sentence_label'].reshape((-1,))
  next_sentence_correct = jnp.sum(
      next_sentence_logits.argmax(-1) == next_sentence_labels)
  stats = {
      'masked_lm_correct': masked_lm_correct,
      'masked_lm_total': jnp.sum(batch['masked_lm_weights']),
      'next_sentence_correct': next_sentence_correct,
      'next_sentence_total': jnp.sum(jnp.ones_like(next_sentence_labels)),
      **stats
  }
  return stats

def training_cost(self, flax_module, batch_stats, batch, dropout_rng):
  """Return cross entropy loss with (optional) L2 penalty on the weights."""
  with nn.stateful(batch_stats) as new_batch_stats:
    with nn.stochastic(dropout_rng):
      # inputs/targets positions and segmentations are required
      # when we have packed examples.
      logits = flax_module(
          batch['inputs'],
          batch['targets'],
          batch.get('inputs_positions'),
          batch.get('targets_positions'),
          batch.get('inputs_segmentation'),
          batch.get('targets_segmentation'),
          train=True)

  weights = batch.get('weights')
  targets = batch['targets']
  if self.dataset_meta_data['apply_one_hot_in_loss']:
    targets = one_hot(batch['targets'], logits.shape[-1])
  # Optionally apply label smoothing.
  if self.hps.get('label_smoothing') is not None:
    targets = model_utils.apply_label_smoothing(
        targets, self.hps.get('label_smoothing'))
  total_loss = self.loss_fn(logits, targets, weights)
  if self.hps.get('l2_decay_factor'):
    l2_loss = model_utils.l2_regularization(
        flax_module.params, self.hps.l2_decay_rank_threshold)
    total_loss += 0.5 * self.hps.l2_decay_factor * l2_loss
  return total_loss, (new_batch_stats)

def eval_step(state, module, batch, metrics_dict, rng):
  """Compute the metrics for the given model in inference mode.

  The model is applied to the inputs using all devices on the host. Afterwards
  metrics are averaged across *all* devices (of all hosts).

  Args:
    state: Replicated model state.
    module: Model function.
    batch: Inputs that should be evaluated.
    metrics_dict: A dictionary of metrics, mapping names to metric functions.
    rng: Jax pseudo-random number generator key.

  Returns:
    Dictionary of replicated metrics, stats output by the model and updated
    PRNG key.
  """
  rng, new_rng = jax.random.split(rng)
  with nn.stochastic(rng), flax.nn.stateful(state.model_state, mutable=False):
    logits, stats = module.call(state.model_params, batch["image"],
                                train=False)
  metrics = {
      m: fn(logits, batch["label"], stats) for (m, fn) in metrics_dict.items()
  }
  metrics = jax.lax.all_gather(metrics, axis_name="batch")
  stats = jax.lax.all_gather(stats, axis_name="batch")
  return metrics, stats, new_rng

def impl_loss_fn(model_params):
  with nn.stochastic(rng), nn.stateful(
      state.model_state) as new_model_state:
    logits, stats = module.call(model_params, batch["image"])
  losses = loss_fn if isinstance(loss_fn, (list, tuple)) else [loss_fn]
  loss = sum(l(logits, batch["label"], stats) for l in losses)
  return loss, (logits, new_model_state, stats)

def test_stochastic_rngs(self):
  rng = random.PRNGKey(0)
  with nn.stochastic(rng):
    r1 = nn.make_rng()
    r2 = nn.make_rng()
  self.assertTrue(onp.all(r1 == random.fold_in(rng, 1)))
  self.assertTrue(onp.all(r2 == random.fold_in(rng, 2)))

def loss_fn(model): """Loss function used for training.""" with nn.stochastic(dropout_rng): logits = model(inputs, train=True) loss, weight_sum = compute_weighted_cross_entropy(logits, targets, weights) mean_loss = loss / weight_sum return mean_loss, logits
def loss_fn(model):
  # Noise-prediction objective: the model output is regressed onto the
  # injected noise with a mean squared error.
  if train:
    with nn.stateful(state.model_state) as new_model_state:
      with nn.stochastic(run_rng):
        scores = model(perturbed_data, T, train=train)
  else:
    with nn.stateful(state.model_state, mutable=False):
      with nn.stochastic(run_rng):
        scores = model(perturbed_data, T, train=train)
    new_model_state = state.model_state

  scores = scores.reshape((scores.shape[0], -1))
  target = noise.reshape((noise.shape[0], -1))
  loss = jnp.mean((scores - target)**2)
  return loss, new_model_state

def create_model(key, input_shape):

  def inducing_loc_init(key, shape):
    return jnp.linspace(-1.5, 1.5, FLAGS.num_inducing_points)[:, jnp.newaxis]

  kwargs = {}
  for i in range(1, FLAGS.num_layers + 1):
    kwargs['kernel_fn_{}_kwargs'.format(i)] = {
        'amplitude_init': lambda key, shape: jnp.ones(shape),
        'length_scale_init': lambda key, shape: jnp.ones(shape)
    }
    kwargs['inducing_var_{}_kwargs'.format(i)] = {
        'fixed_locations': False,
        'whiten': FLAGS.whiten,
        'inducing_locations_init': inducing_loc_init
    }

  model_def = DeepGPModel.partial(**kwargs)

  with nn.stochastic(key):
    _, params = model_def.init_by_shape(
        key,
        [(input_shape, jnp.float64),],
        nn.make_rng(),
        **kwargs)

  return nn.Model(model_def, params)

def create_model(key, flax_module, input_shape, model_kwargs):
  module = flax_module.partial(**model_kwargs)
  with nn.stochastic(key):
    _, initial_params = module.init_by_shape(key,
                                             [(input_shape, jnp.float32)])
    model = nn.Model(module, initial_params)
  return model

def compute_loss_and_metrics(model, batch, rng):
  """Compute cross-entropy loss for classification tasks."""
  with nn.stochastic(rng):
    metrics = model(
        batch['input_ids'],
        (batch['input_ids'] > 0).astype(np.int32),
        batch['token_type_ids'],
        batch['label'])
  return metrics['loss'], metrics

def initialize(flax_module_def, initializer, loss_fn, input_shape,
               output_shape, hps, rng, metrics_logger):
  """Run the given initializer.

  We initialize in 3 phases. First we run the default initializer that is
  specified by the model constructor. Next we apply any rescaling as specified
  by hps.layer_rescale_factors. Finally we run the black box initializer
  provided by the initializer arg (the default is noop).

  Args:
    flax_module_def: An uninitialized flax module definition.
    initializer: An initializer defined in init_lib.
    loss_fn: A loss function.
    input_shape: The input shape of a single data example.
    output_shape: The output shape of a single data example.
    hps: A dictionary specifying the model and initializer hparams.
    rng: An rng key to seed the initialization.
    metrics_logger: Used for black box initializers that have learning curves.

  Returns:
    A tuple (model, batch_stats), where model is the initialized flax.nn.Model
    and batch_stats is the collection used for batch norm.
  """
  model_dtype = utils.dtype_from_str(hps.model_dtype)
  # init_by_shape should either pass in a tuple or a list of tuples.
  # For example, for vision tasks typically input_shape is (image_shape)
  # For seq2seq tasks, shape can be a list of two tuples corresponding to
  # input_sequence_shape for encoder and output_sequence_shape for decoder.
  # TODO(gilmer,ankugarg): Support initializers for list of tuples.
  if isinstance(input_shape, list):  # Typical case for seq2seq models
    input_specs = [((hps.batch_size, *x), model_dtype) for x in input_shape]
  else:  # Typical case for classification models
    input_specs = [((hps.batch_size, *input_shape), model_dtype)]
  params_rng, init_rng, dropout_rng = jax.random.split(rng, num=3)

  with nn.stateful() as batch_stats:
    with nn.stochastic(dropout_rng):
      # Using flax_module_def.create can OOM for larger models, so we must use
      # create by shape here.
      # TODO(gilmer) Link to flax issue when bug reporting process finalizes.
      _, params = flax_module_def.init_by_shape(
          params_rng, input_specs, train=False)

  model = nn.Model(flax_module_def, params)

  if hps.get('layer_rescale_factors'):
    model = model_utils.rescale_layers(model, hps.layer_rescale_factors)

  # We don't pass batch_stats to the initializer, the initializer will just
  # run batch_norm in train mode and does not need to maintain the batch_stats.
  # TODO(gilmer): We hardcode here weighted_cross_entropy, but this will need
  # to change for other models. Maybe have meta_loss_inner as an initializer
  # hyper_param?
  # TODO(gilmer): instead of passing in weighted_xent, pass in the model and
  # get the loss from that.
  new_model = initializer(loss_fn, model, hps, input_shape, output_shape,
                          init_rng, metrics_logger)

  return new_model, batch_stats

def decode(model, inputs, rng):
  """Decode inputs."""
  init_decoder_input = onehot(CTABLE.encode('=')[0:1], CTABLE.vocab_size)
  init_decoder_inputs = jnp.tile(init_decoder_input,
                                 (inputs.shape[0], get_max_output_len(), 1))
  with nn.stochastic(rng):
    _, predictions = model(inputs, init_decoder_inputs, teacher_force=False)
  return predictions

def loss_fn(model): """Loss function used for training.""" with nn.stochastic(dropout_rng): logits = model(inputs1, inputs2, train=True) loss, weight_sum = train_utils.compute_weighted_cross_entropy( logits, targets, num_classes=2, weights=None) mean_loss = loss / weight_sum return mean_loss, logits
def create_model(module, input_shape, rng):
  """Instantiates the model."""
  model_rng, init_rng = jax.random.split(rng)
  with nn.stochastic(model_rng), nn.stateful() as init_state:
    x = jnp.ones(input_shape, dtype=jnp.float32)
    _, init_params = module.init(init_rng, x)
  model = nn.Model(module, init_params)
  return model, init_params, init_state

def create_model(key, flax_module, input_shape, model_kwargs):
  """Creates and initializes the model."""
  module = flax_module.partial(**model_kwargs)
  with nn.stateful() as init_state:
    with nn.stochastic(key):
      _, initial_params = module.init_by_shape(
          key, [(input_shape, jnp.float32)])
      model = nn.Model(module, initial_params)
  return model, init_state

def compute_pretraining_loss_and_metrics(model, batch, rng):
  """Compute cross-entropy loss for pre-training tasks."""
  with nn.stochastic(rng):
    metrics = model(
        batch['input_ids'],
        (batch['input_ids'] > 0).astype(np.int32),
        batch['token_type_ids'],
        batch['masked_lm_positions'],
        batch['masked_lm_ids'],
        batch['masked_lm_weights'],
        batch['next_sentence_label'])
  return metrics['loss'], metrics

def get_pretrain_model():
  """Get pretrain model with pretrained weights loaded.

  Returns:
    jax_model: pretrain model with TF pretrained weights loaded to its
      transformer encoder part
  """
  # Get pretrained TF model configuration and variables from checkpoint
  if FLAGS.load_tf_weights:
    if FLAGS.load_mlperf_weights:
      tf_config, tf_vars, _ = utils.get_mlperf_model_variables(
          FLAGS.bert_config_file, FLAGS.init_checkpoint)
    else:
      tf_config, tf_vars, _ = utils.get_tf_model_variables(
          FLAGS.bert_config_file, FLAGS.init_checkpoint)
  else:
    tf_config = utils.get_tf_config(FLAGS.bert_config_file)

  # Generate JAX model using same model configuration as TF model
  if FLAGS.seed:
    seed = FLAGS.seed
  else:
    seed = np.int64(time.time())
  logging.info('RNG seed is %d.', seed)
  rng = random.PRNGKey(seed)
  tf.random.set_seed(seed)

  sequence_length = FLAGS.max_seq_length
  device_batch_size = int(FLAGS.train_batch_size // jax.device_count())
  model_kwargs = utils.convert_tf_config_to_jax_bert(tf_config)
  model_kwargs['num_token_predictions'] = FLAGS.max_predictions_per_seq
  model_kwargs['num_classes'] = 2

  with nn.stochastic(rng):
    model_def = bert_models.PretrainModel.partial(**model_kwargs)
    input_shape = (device_batch_size, sequence_length)
    inputs = jax.numpy.zeros(input_shape, dtype=jnp.int32)
    _, jax_model = model_def.create(rng, [inputs] * 4)

  # Update transformer encoder parameters with TF model pretrained weights
  if FLAGS.load_tf_weights:
    if FLAGS.load_mlperf_weights:
      jax_transformer_vars = utils.convert_mlperf_param_dict_to_jax(
          tf_vars, model_kwargs['emb_dim'], model_kwargs['num_heads'])
      jax_model.params.update(jax_transformer_vars)
    else:
      raise NotImplementedError(
          'Loading kerasBERT checkpoint for pretraining not supported yet.')
  else:
    encoder_vars = jax_model.params['transformer_encoder']
    encoder_vars['self_attention_mask'] = 0.0
    masked_lm_vars = jax_model.params['masked_lm']
    masked_lm_vars['0'] = 0.0
    masked_lm_vars['GatherIndexes_0'] = 0.0
    jax_model.params.update({'transformer_encoder': encoder_vars})
    jax_model.params.update({'masked_lm': masked_lm_vars})

  return jax_model, model_kwargs

def test_init_by_shape_lifts_stochastic(self):

  class StochasticModule(nn.Module):

    def apply(self):
      return nn.make_rng()

  with nn.stochastic(random.PRNGKey(0)):
    rng, _ = StochasticModule.init_by_shape(random.PRNGKey(1), [])
  expected_rng = random.fold_in(random.PRNGKey(0), 1)
  expected_rng = random.fold_in(expected_rng, 1)
  self.assertTrue(onp.all(rng == expected_rng))

def compute_regression_stats(model, batch):
  with nn.stochastic(jax.random.PRNGKey(0)):
    y = model(
        batch["input_ids"],
        (batch["input_ids"] > 0).astype(np.int32),
        batch["type_ids"],
        deterministic=True)
  return {
      "idx": batch["idx"],
      "label": batch["label"],
      "prediction": y[..., 0],
  }

def loss_fn(model):
  with nn.stochastic(dropout_rng):
    use_bf16 = FLAGS.use_bfloat16_activation
    dtype = jnp.bfloat16 if use_bf16 else jnp.float32
    lm_outputs, sentence_outputs = model(inputs, train=True, dtype=dtype)
    assert lm_outputs.dtype == jnp.float32
    assert sentence_outputs.dtype == jnp.float32
  total_loss, lm_loss, sentence_loss = get_pretrain_loss(
      labels, lm_outputs, sentence_outputs)
  return total_loss, (lm_loss, sentence_loss)

def loss_fn(model): """Loss function used for training.""" with nn.stochastic(dropout_rng): logits = model(inputs, train=True, cache=None) loss, weight_sum = utils.compute_weighted_cross_entropy( logits, inputs, token_weights=token_weights, example_weights=example_weights) mean_loss = loss / weight_sum return mean_loss, logits
def compute_classification_stats(model, batch):
  with nn.stochastic(jax.random.PRNGKey(0)):
    y = model(
        batch['input_ids'],
        (batch['input_ids'] > 0).astype(np.int32),
        batch['token_type_ids'],
        deterministic=True)
  return {
      'idx': batch['idx'],
      'label': batch['label'],
      'prediction': y.argmax(-1)
  }

def test_train_one_step(self):
  batch = train.get_batch(128)
  rng = random.PRNGKey(0)
  with nn.stochastic(rng):
    model = train.create_model()
    optimizer = train.create_optimizer(model, 0.003)
    optimizer, train_metrics = train.train_step(optimizer, batch,
                                                nn.make_rng())

  self.assertLessEqual(train_metrics['loss'], 5)
  self.assertGreaterEqual(train_metrics['accuracy'], 0)

def apply(self, g, x, in_feats, hidden_feats, out_feats, num_layers, dropout):
  with nn.stochastic(jax.random.PRNGKey(0)):
    x = SAGEConv(g, x, in_feats, hidden_feats)
    for idx in range(num_layers - 2):
      x = SAGEConv(g, x, hidden_feats, hidden_feats)
      x = nn.BatchNorm(x)
      x = nn.dropout(x, rate=dropout)
    x = SAGEConv(g, x, hidden_feats, out_feats)
  return jax.nn.log_softmax(x, axis=-1)

def loss_fn(model):
  with nn.stochastic(rng):
    logits = model(inputs, lengths, train=True)
    loss = jnp.mean(binary_cross_entropy_loss(logits, labels))
  # L2 regularization
  l2_params = jax.tree_leaves(model.params['lstm_classifier'])
  l2_weight = jnp.sum([jnp.sum(p**2) for p in l2_params])
  l2_penalty = l2_reg * l2_weight
  loss = loss + l2_penalty
  return loss, logits

def train_model():
  """Train for a fixed number of steps and decode during training."""
  with nn.stochastic(jax.random.PRNGKey(0)):
    model = create_model()
    optimizer = create_optimizer(model, FLAGS.learning_rate)
    for step in range(FLAGS.num_train_steps):
      batch = get_batch(FLAGS.batch_size)
      optimizer, metrics = train_step(optimizer, batch, nn.make_rng())
      if step % FLAGS.decode_frequency == 0:
        logging.info('train step: %d, loss: %.4f, accuracy: %.2f', step,
                     metrics['loss'], metrics['accuracy'] * 100)
        decode_batch(optimizer.target, 5)

  return optimizer.target

def test_autoencoder_model(self, model_str):
  """Test forward pass of the autoencoder models."""
  model_cls = models.get_model(model_str)
  model_hps = models.get_model_hparams(model_str)
  loss = 'sigmoid_binary_cross_entropy'
  metrics = 'binary_autoencoder_metrics'
  hps = copy.copy(model_hps)
  hps.update({'output_shape': OUTPUT_SHAPE[model_str]})
  rng = jax.random.PRNGKey(0)
  model = model_cls(hps, {}, loss, metrics)
  xs = jnp.array(np.random.normal(size=INPUT_SHAPE[model_str]))
  rng, params_rng = jax.random.split(rng)
  with nn.stateful() as batch_stats:
    with nn.stochastic(params_rng):
      _, flax_module = model.flax_module_def.create(params_rng, xs)

  # Check that the forward pass works with mutated batch_stats.
  with nn.stateful(batch_stats) as new_batch_stats:
    with nn.stochastic(params_rng):
      outputs = flax_module(xs)
  self.assertEqual(
      outputs.shape,
      tuple([INPUT_SHAPE[model_str][0]] + list(OUTPUT_SHAPE[model_str])))

  # If it's a batch norm model check the batch stats changed.
  if batch_stats.as_dict():
    bflat, _ = ravel_pytree(batch_stats)
    new_bflat, _ = ravel_pytree(new_batch_stats)
    self.assertFalse(jnp.array_equal(bflat, new_bflat))

  # Test batch_norm in inference mode.
  with nn.stateful(batch_stats, mutable=False):
    outputs = flax_module(xs, train=False)
  self.assertEqual(
      outputs.shape,
      tuple([INPUT_SHAPE[model_str][0]] + list(OUTPUT_SHAPE[model_str])))