Exemple #1
0
    def loss_fn(model):
        """Noise-conditioned regression loss on the model's ``scores``.

        The model is applied to ``perturbed_data`` and ``labels`` (plus
        ``y=class_labels`` when ``class_conditional``) inside
        ``nn.stochastic(run_rng)``. In training mode model state is collected
        through ``nn.stateful``. Otherwise the state is read-only and
        ``state.model_state`` is returned unchanged.

        The target is ``-1 / used_sigmas**2 * noise``. Each per-example loss
        is half the squared error summed over the flattened trailing
        dimensions, weighted by ``used_sigmas.squeeze() ** anneal_power``.

        Returns:
          ``(loss, new_model_state)``, or ``(loss, new_model_state, losses)``
          when ``loss_per_sigma`` is truthy. ``loss`` is the mean of the
          per-example ``losses``.
        """
        # Conditioning labels are only forwarded for class-conditional models.
        conditioning = {'y': class_labels} if class_conditional else {}

        def _forward():
            with nn.stochastic(run_rng):
                return model(perturbed_data, labels, train=train,
                             **conditioning)

        if train:
            with nn.stateful(state.model_state) as new_model_state:
                scores = _forward()
        else:
            with nn.stateful(state.model_state, mutable=False):
                scores = _forward()
            new_model_state = state.model_state

        flat_scores = scores.reshape((scores.shape[0], -1))
        target = -1 / (used_sigmas**2) * noise
        flat_target = target.reshape((target.shape[0], -1))
        squared_error = ((flat_scores - flat_target)**2).sum(axis=-1)
        losses = 1 / 2. * squared_error * used_sigmas.squeeze()**anneal_power
        loss = jnp.mean(losses)

        return ((loss, new_model_state, losses) if loss_per_sigma
                else (loss, new_model_state))
Exemple #2
0
 def init_single_head(init_rng, args, kwargs):
   """Initialize one copy of ``wrapped_module`` and return its parameters.

   When the enclosing ``rng`` is set, initialization runs inside
   ``nn.stochastic(rng)`` so that ``nn.make_rng`` calls are available.
   """
   if rng is not None:
     with nn.stochastic(rng):
       _, params = wrapped_module.init(init_rng, *args, **kwargs)
     return params
   _, params = wrapped_module.init(init_rng, *args, **kwargs)
   return params
Exemple #3
0
def create_model(config):
    """Create a BERT pre-training model, starting from a checkpoint if given.

    If ``config.init_checkpoint`` is set, parameters are loaded with
    ``import_weights.load_params`` (keeping the masked-LM head). Otherwise
    they are initialized with ``init_by_shape`` from fixed PRNG keys. Any
    leaves that come back as ``jax.ShapeDtypeStruct`` are then materialized
    by ``fixup_for_tpu``.

    Args:
      config: Configuration object. Reads ``config.model``,
        ``config.init_checkpoint``, ``config.max_seq_length`` and
        ``config.max_predictions_per_seq``.

    Returns:
      An ``nn.Model`` wrapping ``modeling.BertForPreTraining``.
    """
    import itertools

    model_def = modeling.BertForPreTraining.partial(config=config.model)
    if config.init_checkpoint:
        initial_params = import_weights.load_params(
            init_checkpoint=config.init_checkpoint,
            hidden_size=config.model.hidden_size,
            num_attention_heads=config.model.num_attention_heads,
            keep_masked_lm_head=True)
    else:
        with nn.stochastic(jax.random.PRNGKey(0)):
            _, initial_params = model_def.init_by_shape(
                jax.random.PRNGKey(0),
                [((1, config.max_seq_length), jnp.int32),
                 ((1, config.max_seq_length), jnp.int32),
                 ((1, config.max_seq_length), jnp.int32),
                 ((1, config.max_predictions_per_seq), jnp.int32)],
                deterministic=True)

            # One counter per create_model call. It advances on every
            # ShapeDtypeStruct leaf (starting at 1), matching the original
            # numbering. This replaces the former mutable-default-argument
            # counter.
            leaf_counter = itertools.count(1)

            def fixup_for_tpu(x):
                """HACK to fix incorrect param initialization on TPU."""
                if not isinstance(x, jax.ShapeDtypeStruct):
                    return x
                leaf_index = next(leaf_counter)
                if len(x.shape) == 2:
                    return jnp.zeros(x.shape, x.dtype)
                return nn.linear.default_kernel_init(
                    jax.random.PRNGKey(leaf_index), x.shape, x.dtype)

            # jax.tree_map is deprecated/removed; tree_util.tree_map is the
            # stable spelling of the same function.
            initial_params = jax.tree_util.tree_map(fixup_for_tpu,
                                                    initial_params)
    model = nn.Model(model_def, initial_params)
    return model
Exemple #4
0
def compute_pretraining_stats(model, batch):
    """Compute evaluation statistics during pre-training.

    The model runs with ``deterministic=True`` under a fixed PRNG key.
    Correct/total counts for the masked-LM and next-sentence heads are
    combined with the output of ``model.compute_metrics``. On key
    collisions, the entries from ``compute_metrics`` win.
    """
    input_ids = batch['input_ids']
    with nn.stochastic(jax.random.PRNGKey(0)):
        masked_lm_logits, next_sentence_logits = model(
            input_ids,
            (input_ids > 0).astype(np.int32),
            batch['token_type_ids'],
            batch['masked_lm_positions'],
            deterministic=True)
        model_stats = model.compute_metrics(
            masked_lm_logits, next_sentence_logits, batch['masked_lm_ids'],
            batch['masked_lm_weights'], batch['next_sentence_label'])

    flat_lm_ids = batch['masked_lm_ids'].reshape((-1, ))
    flat_lm_weights = batch['masked_lm_weights'].reshape((-1, ))
    lm_hits = masked_lm_logits.argmax(-1) == flat_lm_ids

    next_sentence_labels = batch['next_sentence_label'].reshape((-1, ))
    nsp_hits = next_sentence_logits.argmax(-1) == next_sentence_labels

    return {
        'masked_lm_correct': jnp.sum(lm_hits * flat_lm_weights),
        'masked_lm_total': jnp.sum(batch['masked_lm_weights']),
        'next_sentence_correct': jnp.sum(nsp_hits),
        'next_sentence_total': jnp.sum(jnp.ones_like(next_sentence_labels)),
        **model_stats
    }
Exemple #5
0
    def training_cost(self, flax_module, batch_stats, batch, dropout_rng):
        """Return cross entropy loss with (optional) L2 penalty on the weights.

        Targets are optionally one-hot encoded (per dataset metadata) and
        label-smoothed (per ``hps.label_smoothing``) before ``self.loss_fn``.

        Returns:
          ``(total_loss, new_batch_stats)``, where ``new_batch_stats`` is the
          state collected during the forward pass.
        """
        with nn.stateful(batch_stats) as new_batch_stats, \
                nn.stochastic(dropout_rng):
            # Positions and segmentations are only needed for packed
            # examples; batch.get yields None when they are absent.
            packing_info = [
                batch.get(key) for key in ('inputs_positions',
                                           'targets_positions',
                                           'inputs_segmentation',
                                           'targets_segmentation')
            ]
            logits = flax_module(batch['inputs'], batch['targets'],
                                 *packing_info, train=True)

        targets = batch['targets']
        if self.dataset_meta_data['apply_one_hot_in_loss']:
            targets = one_hot(targets, logits.shape[-1])

        smoothing = self.hps.get('label_smoothing')
        if smoothing is not None:
            targets = model_utils.apply_label_smoothing(targets, smoothing)

        total_loss = self.loss_fn(logits, targets, batch.get('weights'))

        if self.hps.get('l2_decay_factor'):
            l2_loss = model_utils.l2_regularization(
                flax_module.params, self.hps.l2_decay_rank_threshold)
            total_loss += 0.5 * self.hps.l2_decay_factor * l2_loss
        return total_loss, new_batch_stats
def eval_step(state, module, batch, metrics_dict, rng):
    """Compute the metrics for the given model in inference mode.

    The model is applied with ``train=False`` under a read-only view of
    ``state.model_state``. The per-metric results and the model's stats are
    then gathered (not averaged) across the ``"batch"`` mapped axis with
    ``jax.lax.all_gather``.

    Args:
      state: Replicated model state; ``model_params`` and ``model_state``
        are read.
      module: Model function exposing ``call``.
      batch: Inputs to evaluate; ``"image"`` and ``"label"`` are read.
      metrics_dict: Mapping of metric names to ``fn(logits, labels, stats)``.
      rng: JAX PRNG key.

    Returns:
      A tuple ``(metrics, stats, new_rng)``: the gathered metrics dict, the
      gathered model stats, and a fresh PRNG key for the next step.
    """
    step_rng, next_rng = jax.random.split(rng)
    with nn.stochastic(step_rng), flax.nn.stateful(state.model_state,
                                                   mutable=False):
        logits, stats = module.call(state.model_params,
                                    batch["image"],
                                    train=False)

    labels = batch["label"]
    local_metrics = {}
    for name, metric_fn in metrics_dict.items():
        local_metrics[name] = metric_fn(logits, labels, stats)

    gathered_metrics = jax.lax.all_gather(local_metrics, axis_name="batch")
    gathered_stats = jax.lax.all_gather(stats, axis_name="batch")
    return gathered_metrics, gathered_stats, next_rng
 def impl_loss_fn(model_params):
     """Sum the configured loss function(s) for ``model_params``.

     ``loss_fn`` may be a single callable or a list/tuple of callables, each
     called as ``fn(logits, labels, stats)``.

     Returns:
       ``(loss, (logits, new_model_state, stats))``.
     """
     with nn.stochastic(rng), nn.stateful(
             state.model_state) as new_model_state:
         logits, stats = module.call(model_params, batch["image"])
     loss_fns = loss_fn if isinstance(loss_fn, (list, tuple)) else [loss_fn]
     labels = batch["label"]
     total = sum(fn(logits, labels, stats) for fn in loss_fns)
     return total, (logits, new_model_state, stats)
Exemple #8
0
 def test_stochastic_rngs(self):
   """Successive make_rng calls fold counters 1, 2, ... into the base key."""
   base_rng = random.PRNGKey(0)
   with nn.stochastic(base_rng):
     drawn = [nn.make_rng() for _ in range(2)]
   for counter, key in enumerate(drawn, start=1):
     self.assertTrue(onp.all(key == random.fold_in(base_rng, counter)))
Exemple #9
0
 def loss_fn(model):
   """Loss function used for training.

   Returns ``(mean_loss, logits)``, where ``mean_loss`` is the weighted
   cross-entropy divided by the total weight.
   """
   with nn.stochastic(dropout_rng):
     logits = model(inputs, train=True)
   loss_sum, weight_sum = compute_weighted_cross_entropy(
       logits, targets, weights)
   return loss_sum / weight_sum, logits
Exemple #10
0
    def loss_fn(model):
        """Mean squared error between the model output and ``noise``.

        The model is applied as ``model(perturbed_data, T, train=train)``
        inside ``nn.stochastic(run_rng)``. In training mode state is collected
        via ``nn.stateful``. Otherwise it is read-only and
        ``state.model_state`` is returned as-is.

        Returns:
          ``(loss, new_model_state)``.
        """
        def _forward():
            with nn.stochastic(run_rng):
                return model(perturbed_data, T, train=train)

        if train:
            with nn.stateful(state.model_state) as new_model_state:
                scores = _forward()
        else:
            with nn.stateful(state.model_state, mutable=False):
                scores = _forward()
            new_model_state = state.model_state

        flat_scores = scores.reshape((scores.shape[0], -1))
        flat_noise = noise.reshape((noise.shape[0], -1))
        return jnp.mean((flat_scores - flat_noise)**2), new_model_state
Exemple #11
0
def create_model(key, input_shape):
    """Build and initialize a ``DeepGPModel`` for float64 inputs.

    Each layer ``1..FLAGS.num_layers`` gets two kwargs dicts:
    - kernel kwargs with ones-valued amplitude and length-scale initializers;
    - inducing-variable kwargs with learnable locations, ``FLAGS.whiten``,
      and a fixed grid initializer.
    The kwargs are bound via ``partial`` and also passed to
    ``init_by_shape``.
    """
    def inducing_loc_init(key, shape):
        # `shape` is ignored: the grid size comes from
        # FLAGS.num_inducing_points. Result is a (num_points, 1) column.
        grid = jnp.linspace(-1.5, 1.5, FLAGS.num_inducing_points)
        return grid[:, jnp.newaxis]

    def ones_init(key, shape):
        return jnp.ones(shape)

    layer_kwargs = {}
    for layer in range(1, FLAGS.num_layers + 1):
        layer_kwargs[f'kernel_fn_{layer}_kwargs'] = {
            'amplitude_init': ones_init,
            'length_scale_init': ones_init,
        }
        layer_kwargs[f'inducing_var_{layer}_kwargs'] = {
            'fixed_locations': False,
            'whiten': FLAGS.whiten,
            'inducing_locations_init': inducing_loc_init,
        }

    model_def = DeepGPModel.partial(**layer_kwargs)
    input_specs = [(input_shape, jnp.float64)]

    with nn.stochastic(key):
        _, params = model_def.init_by_shape(key, input_specs, nn.make_rng(),
                                            **layer_kwargs)
        return nn.Model(model_def, params)
Exemple #12
0
def create_model(key, flax_module, input_shape, model_kwargs):
    """Initialize ``flax_module`` (with ``model_kwargs`` bound) from shape.

    ``key`` seeds both the stochastic context and ``init_by_shape``.
    Returns an ``nn.Model``.
    """
    bound_module = flax_module.partial(**model_kwargs)
    input_specs = [(input_shape, jnp.float32)]
    with nn.stochastic(key):
        _, params = bound_module.init_by_shape(key, input_specs)
        model = nn.Model(bound_module, params)
    return model
Exemple #13
0
def compute_loss_and_metrics(model, batch, rng):
    """Compute cross-entropy loss for classification tasks.

    Returns ``(metrics['loss'], metrics)``, where ``metrics`` is whatever
    dict the model returns.
    """
    input_ids = batch['input_ids']
    # 1 where the token id is non-zero (presumably 0 is padding — confirm).
    input_mask = (input_ids > 0).astype(np.int32)
    with nn.stochastic(rng):
        metrics = model(input_ids, input_mask, batch['token_type_ids'],
                        batch['label'])
    return metrics['loss'], metrics
Exemple #14
0
def initialize(flax_module_def, initializer, loss_fn, input_shape,
               output_shape, hps, rng, metrics_logger):
    """Run the given initializer.

    Initialization happens in 3 phases:
    1. The default initializer specified by the model constructor
       (``init_by_shape`` with ``train=False``).
    2. Optional rescaling per ``hps.layer_rescale_factors``.
    3. The black-box ``initializer`` (the default is a no-op).

    Args:
      flax_module_def: An uninitialized flax module definition.
      initializer: An initializer defined in init_lib.
      loss_fn: A loss function.
      input_shape: The input shape of a single data example, or a list of
        such shapes (e.g. encoder and decoder inputs for seq2seq).
      output_shape: The output shape of a single data example.
      hps: A dictionary specifying the model and initializer hparams.
      rng: An rng key to seed the initialization.
      metrics_logger: Used for black box initializers that have learning
        curves.

    Returns:
      A tuple (model, batch_stats), where model is the initialized
      flax.nn.Model and batch_stats is the collection used for batch norm.
    """
    model_dtype = utils.dtype_from_str(hps.model_dtype)
    # init_by_shape takes a list of (shape, dtype) specs. A list-valued
    # input_shape (typical for seq2seq) yields one spec per entry; a single
    # shape (typical for classification) yields one spec.
    # TODO(gilmer,ankugarg): Support initializers for list of tuples.
    example_shapes = (input_shape if isinstance(input_shape, list) else
                      [input_shape])
    input_specs = [((hps.batch_size, *shape), model_dtype)
                   for shape in example_shapes]
    params_rng, init_rng, dropout_rng = jax.random.split(rng, num=3)

    # Using flax_module_def.create can OOM for larger models, so we must use
    # create by shape here.
    # TODO(gilmer) Link to flax issue when bug reporting process finalizes.
    with nn.stateful() as batch_stats, nn.stochastic(dropout_rng):
        _, params = flax_module_def.init_by_shape(params_rng,
                                                  input_specs,
                                                  train=False)
    model = nn.Model(flax_module_def, params)

    if hps.get('layer_rescale_factors'):
        model = model_utils.rescale_layers(model, hps.layer_rescale_factors)

    # batch_stats is deliberately not passed to the initializer: it runs
    # batch_norm in train mode and does not need to maintain the stats.
    # TODO(gilmer): We hardcode here weighted_cross_entropy, but this will need
    # to change for other models. Maybe have meta_loss_inner as an initializer
    # hyper_param?
    # TODO(gilmer): instead of passing in weighted_xent, pass in the model and
    # get the loss from that.
    new_model = initializer(loss_fn, model, hps, input_shape, output_shape,
                            init_rng, metrics_logger)
    return new_model, batch_stats
Exemple #15
0
Fichier : train.py Projet : us/flax
def decode(model, inputs, rng):
  """Decode inputs without teacher forcing.

  Every decoder step is seeded with the one-hot encoding of '='. The
  model's second output (predictions) is returned.
  """
  start_token = onehot(CTABLE.encode('=')[0:1], CTABLE.vocab_size)
  tile_reps = (inputs.shape[0], get_max_output_len(), 1)
  decoder_inputs = jnp.tile(start_token, tile_reps)
  with nn.stochastic(rng):
    _, predictions = model(inputs, decoder_inputs, teacher_force=False)
  return predictions
Exemple #16
0
 def loss_fn(model):
     """Loss function used for training.

     Weighted cross-entropy over 2 classes with no per-example weights,
     normalized by the returned weight sum. Returns ``(mean_loss, logits)``.
     """
     with nn.stochastic(dropout_rng):
         logits = model(inputs1, inputs2, train=True)
     loss_sum, weight_sum = train_utils.compute_weighted_cross_entropy(
         logits, targets, num_classes=2, weights=None)
     return loss_sum / weight_sum, logits
def create_model(module, input_shape, rng):
    """Instantiate the model from a float32 ones-valued example input.

    Returns:
      ``(model, init_params, init_state)``: the ``nn.Model``, its raw
      parameters, and the state collected by ``nn.stateful`` during init.
    """
    model_rng, init_rng = jax.random.split(rng)
    example_input = jnp.ones(input_shape, dtype=jnp.float32)
    with nn.stochastic(model_rng), nn.stateful() as init_state:
        _, init_params = module.init(init_rng, example_input)
    return nn.Model(module, init_params), init_params, init_state
Exemple #18
0
def create_model(key, flax_module, input_shape, model_kwargs):
    """Create and initialize the model.

    Returns:
      ``(model, init_state)``, where ``init_state`` is the collection
      gathered by ``nn.stateful`` during initialization.
    """
    bound_module = flax_module.partial(**model_kwargs)
    input_specs = [(input_shape, jnp.float32)]
    with nn.stateful() as init_state, nn.stochastic(key):
        _, params = bound_module.init_by_shape(key, input_specs)
        model = nn.Model(bound_module, params)
    return model, init_state
Exemple #19
0
def compute_pretraining_loss_and_metrics(model, batch, rng):
    """Compute the pre-training loss and metrics for one batch.

    The model receives the masked-LM and next-sentence fields of ``batch``
    and returns a metrics dict. Returns ``(metrics['loss'], metrics)``.
    """
    input_ids = batch['input_ids']
    model_args = (
        input_ids,
        (input_ids > 0).astype(np.int32),
        batch['token_type_ids'],
        batch['masked_lm_positions'],
        batch['masked_lm_ids'],
        batch['masked_lm_weights'],
        batch['next_sentence_label'],
    )
    with nn.stochastic(rng):
        metrics = model(*model_args)
    return metrics['loss'], metrics
Exemple #20
0
def get_pretrain_model():
  """Get the pretrain model, optionally with pretrained TF weights loaded.

  The seed is ``FLAGS.seed`` if truthy, else the current time. It seeds
  both the JAX PRNG and ``tf.random``.

  Returns:
    ``(jax_model, model_kwargs)``. ``jax_model`` is the JAX pretrain model;
    when ``FLAGS.load_tf_weights`` is set, MLPerf TF weights are loaded into
    its transformer encoder part. ``model_kwargs`` are the kwargs it was
    built with.

  Raises:
    NotImplementedError: if TF weights are requested from a non-MLPerf
      (kerasBERT) checkpoint.
  """
  # Pretrained TF model configuration (and variables, when loading weights).
  if not FLAGS.load_tf_weights:
    tf_config = utils.get_tf_config(FLAGS.bert_config_file)
  else:
    load_variables = (utils.get_mlperf_model_variables
                      if FLAGS.load_mlperf_weights
                      else utils.get_tf_model_variables)
    tf_config, tf_vars, _ = load_variables(FLAGS.bert_config_file,
                                           FLAGS.init_checkpoint)

  seed = FLAGS.seed if FLAGS.seed else np.int64(time.time())
  logging.info('RNG seed is %d.', seed)
  rng = random.PRNGKey(seed)
  tf.random.set_seed(seed)

  # JAX model built from the same configuration as the TF model.
  model_kwargs = utils.convert_tf_config_to_jax_bert(tf_config)
  model_kwargs['num_token_predictions'] = FLAGS.max_predictions_per_seq
  model_kwargs['num_classes'] = 2
  per_device_batch = int(FLAGS.train_batch_size // jax.device_count())
  input_shape = (per_device_batch, FLAGS.max_seq_length)
  with nn.stochastic(rng):
    model_def = bert_models.PretrainModel.partial(**model_kwargs)
    dummy_inputs = jnp.zeros(input_shape, dtype=jnp.int32)
    _, jax_model = model_def.create(rng, [dummy_inputs] * 4)

  if FLAGS.load_tf_weights:
    if not FLAGS.load_mlperf_weights:
      raise NotImplementedError(
          'Loading kerasBERT checkpoint for pretraining not supported yet.')
    jax_model.params.update(
        utils.convert_mlperf_param_dict_to_jax(tf_vars,
                                               model_kwargs['emb_dim'],
                                               model_kwargs['num_heads']))
  else:
    # Overwrite these entries with 0.0 placeholders (in place).
    # NOTE(review): the reason for zeroing exactly these entries is not
    # visible here — confirm against the model definition.
    encoder_vars = jax_model.params['transformer_encoder']
    encoder_vars['self_attention_mask'] = 0.0
    masked_lm_vars = jax_model.params['masked_lm']
    masked_lm_vars['0'] = 0.0
    masked_lm_vars['GatherIndexes_0'] = 0.0
    jax_model.params.update({
        'transformer_encoder': encoder_vars,
        'masked_lm': masked_lm_vars,
    })

  return jax_model, model_kwargs
Exemple #21
0
  def test_init_by_shape_lifts_stochastic(self):
    """make_rng inside init_by_shape draws from the enclosing context.

    The expected key is the outer key with counter 1 folded in twice.
    """

    class StochasticModule(nn.Module):

      def apply(self):
        return nn.make_rng()

    outer_key = random.PRNGKey(0)
    with nn.stochastic(outer_key):
      rng, _ = StochasticModule.init_by_shape(random.PRNGKey(1), [])
    expected_rng = random.fold_in(random.fold_in(outer_key, 1), 1)
    self.assertTrue(onp.all(rng == expected_rng))
Exemple #22
0
def compute_regression_stats(model, batch):
    """Run the model deterministically and collect regression predictions.

    Returns a dict with ``idx`` and ``label`` from ``batch`` and
    ``prediction`` taken from index 0 of the output's last axis.
    """
    input_ids = batch["input_ids"]
    with nn.stochastic(jax.random.PRNGKey(0)):
        outputs = model(input_ids,
                        (input_ids > 0).astype(np.int32),
                        batch["type_ids"],
                        deterministic=True)
    return dict(idx=batch["idx"],
                label=batch["label"],
                prediction=outputs[..., 0])
Exemple #23
0
 def loss_fn(model):
   """Pre-training loss with activations in bfloat16 or float32.

   Returns ``(total_loss, (lm_loss, sentence_loss))``.
   """
   compute_dtype = (jnp.bfloat16 if FLAGS.use_bfloat16_activation
                    else jnp.float32)
   with nn.stochastic(dropout_rng):
     lm_outputs, sentence_outputs = model(
         inputs, train=True, dtype=compute_dtype)
     # Outputs must be float32 regardless of the activation dtype.
     assert lm_outputs.dtype == jnp.float32
     assert sentence_outputs.dtype == jnp.float32
   total_loss, lm_loss, sentence_loss = get_pretrain_loss(
       labels, lm_outputs, sentence_outputs)
   return total_loss, (lm_loss, sentence_loss)
Exemple #24
0
 def loss_fn(model):
     """Loss function used for training.

     ``inputs`` is also passed as the targets of the weighted cross-entropy.
     Returns ``(mean_loss, logits)``.
     """
     with nn.stochastic(dropout_rng):
         logits = model(inputs, train=True, cache=None)
     loss_sum, weight_sum = utils.compute_weighted_cross_entropy(
         logits,
         inputs,
         token_weights=token_weights,
         example_weights=example_weights)
     return loss_sum / weight_sum, logits
Exemple #25
0
def compute_classification_stats(model, batch):
    """Run the model deterministically and collect class predictions.

    Returns a dict with ``idx`` and ``label`` from ``batch`` and
    ``prediction`` as the argmax over the output's last axis.
    """
    input_ids = batch['input_ids']
    with nn.stochastic(jax.random.PRNGKey(0)):
        logits = model(input_ids,
                       (input_ids > 0).astype(np.int32),
                       batch['token_type_ids'],
                       deterministic=True)
    return dict(idx=batch['idx'],
                label=batch['label'],
                prediction=logits.argmax(-1))
Exemple #26
0
  def test_train_one_step(self):
    """One training step yields loss <= 5 and non-negative accuracy."""
    batch = train.get_batch(128)
    with nn.stochastic(random.PRNGKey(0)):
      optimizer = train.create_optimizer(train.create_model(), 0.003)
      _, step_metrics = train.train_step(optimizer, batch, nn.make_rng())

    self.assertLessEqual(step_metrics['loss'], 5)
    self.assertGreaterEqual(step_metrics['accuracy'], 0)
Exemple #27
0
    def apply(self, g, x, in_feats, hidden_feats, out_feats, num_layers, dropout):
        """Apply a stack of SAGEConv layers followed by log-softmax.

        Layout:
        - one ``in_feats -> hidden_feats`` SAGEConv;
        - ``num_layers - 2`` blocks of [SAGEConv, BatchNorm, dropout] at
          ``hidden_feats``;
        - a final ``hidden_feats -> out_feats`` SAGEConv.

        NOTE(review): ``nn.dropout`` is called without a ``deterministic``
        argument, so no train/eval switch is visible here. Confirm that
        dropout at evaluation time is intended.

        Returns:
          ``jax.nn.log_softmax`` of the final layer output over the last axis.
        """
        import contextlib

        with contextlib.ExitStack() as stack:
            # Previously a fixed PRNGKey(0) context was always pushed here.
            # That shadowed the caller's RNG and produced the same dropout
            # mask on every call. Now the fixed key is only a fallback when
            # no stochastic context is active.
            if not nn.is_stochastic():
                stack.enter_context(nn.stochastic(jax.random.PRNGKey(0)))

            x = SAGEConv(g, x, in_feats, hidden_feats)

            for _ in range(num_layers - 2):
                x = SAGEConv(g, x, hidden_feats, hidden_feats)
                x = nn.BatchNorm(x)
                x = nn.dropout(x, rate=dropout)

            x = SAGEConv(g, x, hidden_feats, out_feats)

        return jax.nn.log_softmax(x, axis=-1)
Exemple #28
0
    def loss_fn(model):
        """Binary cross-entropy plus an L2 penalty on the LSTM classifier.

        The penalty is ``l2_reg`` times the sum of squares of every leaf
        under ``model.params['lstm_classifier']``.

        Returns:
          ``(loss, logits)``.
        """
        with nn.stochastic(rng):
            logits = model(inputs, lengths, train=True)
        data_loss = jnp.mean(binary_cross_entropy_loss(logits, labels))

        # L2 regularization. jax.tree_leaves is deprecated/removed in favour
        # of jax.tree_util.tree_leaves. jnp.sum does not accept a Python list
        # in current JAX, so per-leaf scalars are accumulated with sum().
        l2_params = jax.tree_util.tree_leaves(model.params['lstm_classifier'])
        l2_weight = sum((jnp.sum(p**2) for p in l2_params), 0.0)
        l2_penalty = l2_reg * l2_weight

        return data_loss + l2_penalty, logits
Exemple #29
0
Fichier : train.py Projet : us/flax
def train_model():
  """Train for a fixed number of steps and decode during training.

  Every ``FLAGS.decode_frequency`` steps (starting at step 0), the loss and
  accuracy are logged and a batch of 5 examples is decoded. Returns the
  trained model (``optimizer.target``).
  """
  with nn.stochastic(jax.random.PRNGKey(0)):
    optimizer = create_optimizer(create_model(), FLAGS.learning_rate)
    for step in range(FLAGS.num_train_steps):
      optimizer, metrics = train_step(optimizer,
                                      get_batch(FLAGS.batch_size),
                                      nn.make_rng())
      if step % FLAGS.decode_frequency != 0:
        continue
      logging.info('train step: %d, loss: %.4f, accuracy: %.2f', step,
                   metrics['loss'], metrics['accuracy'] * 100)
      decode_batch(optimizer.target, 5)
  return optimizer.target
Exemple #30
0
    def test_autoencoder_model(self, model_str):
        """Test forward pass of the autoencoder models.

        Checks the output shape in the default call mode and with
        ``train=False``. For models that carry batch stats, also checks that
        those stats change after a forward pass.
        """
        model_cls = models.get_model(model_str)
        hps = copy.copy(models.get_model_hparams(model_str))
        hps.update({'output_shape': OUTPUT_SHAPE[model_str]})
        model = model_cls(hps, {}, 'sigmoid_binary_cross_entropy',
                          'binary_autoencoder_metrics')
        xs = jnp.array(np.random.normal(size=INPUT_SHAPE[model_str]))
        _, params_rng = jax.random.split(jax.random.PRNGKey(0))
        expected_shape = tuple([INPUT_SHAPE[model_str][0]] +
                               list(OUTPUT_SHAPE[model_str]))

        with nn.stateful() as batch_stats, nn.stochastic(params_rng):
            _, flax_module = model.flax_module_def.create(params_rng, xs)

        # The forward pass should work while collecting mutated batch_stats.
        with nn.stateful(batch_stats) as new_batch_stats, \
                nn.stochastic(params_rng):
            outputs = flax_module(xs)
            self.assertEqual(outputs.shape, expected_shape)

        # Only batch-norm models carry stats; for those, they must change.
        if batch_stats.as_dict():
            old_flat, _ = ravel_pytree(batch_stats)
            updated_flat, _ = ravel_pytree(new_batch_stats)
            self.assertFalse(jnp.array_equal(old_flat, updated_flat))

        # Inference mode with read-only batch_stats.
        with nn.stateful(batch_stats, mutable=False):
            outputs = flax_module(xs, train=False)
        self.assertEqual(outputs.shape, expected_shape)