Example #1
def create_conv_model(only_digits: bool = False) -> models.Model:
    """Creates EMNIST CNN model with dropout with haiku.

  Matches the model used in:

  Adaptive Federated Optimization
    Sashank Reddi, Zachary Charles, Manzil Zaheer, Zachary Garrett, Keith Rush,
    Jakub Konečný, Sanjiv Kumar, H. Brendan McMahan.
    https://arxiv.org/abs/2003.00295

  Args:
    only_digits: Whether to use only digit classes [0-9] or include lower and
      upper case characters for a total of 62 classes.

  Returns:
    Model
  """
    num_classes = 10 if only_digits else 62

    def forward_pass(batch, is_train=True):
        return ConvDropoutModule(num_classes)(batch['x'], is_train)

    transformed_forward_pass = hk.transform(forward_pass)
    return models.create_model_from_haiku(
        transformed_forward_pass=transformed_forward_pass,
        sample_batch=_HAIKU_SAMPLE_BATCH,
        train_loss=_TRAIN_LOSS,
        eval_metrics=_EVAL_METRICS,
        # is_train determines whether to apply dropout or not.
        train_kwargs={'is_train': True},
        eval_kwargs={'is_train': False})
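
ConvDropoutModule is referenced above but not defined in this excerpt. The sketch below is an illustrative reconstruction, assuming the dropout CNN described in the Adaptive Federated Optimization paper (two 3x3 convolutions, max pooling, two dropout layers, and a 128-unit dense layer); the module actually shipped with the library may differ in its exact layers and rates.

import haiku as hk
import jax


class ConvDropoutModule(hk.Module):
  """Illustrative CNN with dropout; not necessarily the library's definition."""

  def __init__(self, num_classes, name=None):
    super().__init__(name=name)
    self._num_classes = num_classes

  def __call__(self, x, is_train):
    x = hk.Conv2D(output_channels=32, kernel_shape=3)(x)
    x = jax.nn.relu(x)
    x = hk.Conv2D(output_channels=64, kernel_shape=3)(x)
    x = jax.nn.relu(x)
    x = hk.MaxPool(window_shape=2, strides=2, padding='VALID')(x)
    if is_train:
      # Dropout needs an RNG key; hk.transform supplies it at apply time.
      x = hk.dropout(hk.next_rng_key(), 0.25, x)
    x = hk.Flatten()(x)
    x = hk.Linear(128)(x)
    x = jax.nn.relu(x)
    if is_train:
      x = hk.dropout(hk.next_rng_key(), 0.5, x)
    return hk.Linear(self._num_classes)(x)
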
Example #2
def create_regression_model() -> models.Model:
    """Creates toy regression model.

  Matches the model used in:

  Communication-Efficient Agnostic Federated Averaging
    Jae Ro, Mingqing Chen, Rajiv Mathews, Mehryar Mohri, Ananda Theertha Suresh
    https://arxiv.org/abs/2104.02748

  Returns:
    Model
  """
    def forward_pass(batch):
        network = hk.Sequential([hk.Linear(1, with_bias=False)])
        return jnp.mean(network(batch['x']))

    def train_loss(batch, preds):
        return jnp.square(jnp.mean(batch['y']) - preds)

    transformed_forward_pass = hk.transform(forward_pass)
    sample_batch = {'x': np.zeros((1, 1)), 'y': np.zeros((1, ))}
    return models.create_model_from_haiku(
        transformed_forward_pass=transformed_forward_pass,
        sample_batch=sample_batch,
        train_loss=train_loss)
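
The transformed forward pass can also be exercised directly with Haiku, outside the fedjax Model wrapper. The sketch below repeats the forward pass and loss so it is self-contained; the batch contents and shapes are illustrative only.

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np


def forward_pass(batch):
  network = hk.Sequential([hk.Linear(1, with_bias=False)])
  return jnp.mean(network(batch['x']))


transformed = hk.transform(forward_pass)
rng = jax.random.PRNGKey(0)
batch = {'x': np.ones((4, 1)), 'y': np.ones((4,))}
params = transformed.init(rng, batch)            # initializes the single weight
preds = transformed.apply(params, rng, batch)    # scalar mean prediction
loss = jnp.square(jnp.mean(batch['y']) - preds)  # same form as train_loss above
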
Example #3
    def test_create_model_from_haiku(self):
        def forward_pass(batch):
            return hk.Linear(10)(batch['x'])

        haiku_model = models.create_model_from_haiku(
            transformed_forward_pass=hk.transform(forward_pass),
            sample_batch={'x': jnp.ones((1, 2))},
            train_loss=train_loss,
            eval_metrics=eval_metrics)
        self.check_model(haiku_model)
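
The train_loss and eval_metrics used by this test are module-level fixtures not shown in the excerpt. A purely illustrative stand-in is sketched below; the loss is a made-up per-example quantity, and the Accuracy metric and import path are assumptions about the fedjax metrics module, so the real fixtures may differ.

import jax.numpy as jnp
from fedjax import metrics


def train_loss(batch, preds):
  # Hypothetical per-example loss over the 10 linear outputs.
  return jnp.mean(jnp.square(preds), axis=-1)


# Assumes the metrics module exposes an Accuracy metric keyed on batch['y'].
eval_metrics = {'accuracy': metrics.Accuracy()}
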
Example #4
def create_logistic_model(only_digits: bool = False) -> models.Model:
    """Creates EMNIST logistic model with haiku."""
    num_classes = 10 if only_digits else 62

    def forward_pass(batch):
        network = hk.Sequential([
            hk.Flatten(),
            hk.Linear(num_classes),
        ])
        return network(batch['x'])

    transformed_forward_pass = hk.transform(forward_pass)
    return models.create_model_from_haiku(
        transformed_forward_pass=transformed_forward_pass,
        sample_batch=_HAIKU_SAMPLE_BATCH,
        train_loss=_TRAIN_LOSS,
        eval_metrics=_EVAL_METRICS)
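
_HAIKU_SAMPLE_BATCH, _TRAIN_LOSS, and _EVAL_METRICS are module-level constants shared by the EMNIST models above but not shown in these excerpts. The sketch below gives plausible definitions, assuming 28x28x1 EMNIST images, integer labels, and that the fedjax metrics module exposes unreduced_cross_entropy_loss (used in the LSTM examples below) and an Accuracy metric; the library's actual values may differ.

import jax.numpy as jnp
from fedjax import metrics

# One dummy EMNIST example: a 28x28 grayscale image and an integer label.
_HAIKU_SAMPLE_BATCH = {
    'x': jnp.zeros((1, 28, 28, 1)),
    'y': jnp.zeros((1,), dtype=jnp.int32),
}


def _unreduced_loss(batch, preds):
  # Per-example cross entropy between integer labels and class logits.
  return metrics.unreduced_cross_entropy_loss(batch['y'], preds)


_TRAIN_LOSS = _unreduced_loss
_EVAL_METRICS = {'accuracy': metrics.Accuracy()}
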
Example #5
def create_lstm_model(vocab_size: int = 86,
                      embed_size: int = 8,
                      lstm_hidden_size: int = 256,
                      lstm_num_layers: int = 2) -> models.Model:
  """Creates LSTM language model.

  Character-level LSTM for Shakespeare language model.
  Defaults to the model used in:

  Communication-Efficient Learning of Deep Networks from Decentralized Data
    H. Brendan McMahan, Eider Moore, Daniel Ramage,
    Seth Hampson, Blaise Aguera y Arcas. AISTATS 2017.
    https://arxiv.org/abs/1602.05629

  Args:
    vocab_size: The number of possible output characters. This does not include
      special tokens like PAD, BOS, EOS, or OOV.
    embed_size: Embedding size for each character.
    lstm_hidden_size: Hidden size for LSTM cells.
    lstm_num_layers: Number of LSTM layers.

  Returns:
    Model.
  """
  # TODO(jaero): Replace these with direct references from dataset.
  pad = 0
  bos = vocab_size + 1
  eos = vocab_size + 2
  oov = vocab_size + 3
  full_vocab_size = vocab_size + 4
  # We do not guess EOS, and if we guess OOV, it's treated as a mistake.
  logits_mask = [0. for _ in range(full_vocab_size)]
  for i in (pad, bos, eos, oov):
    logits_mask[i] = jnp.NINF
  logits_mask = tuple(logits_mask)

  def forward_pass(batch):
    x = batch['x']
    # [time_steps, batch_size, ...].
    x = jnp.transpose(x)
    # [time_steps, batch_size, embed_dim].
    embedding_layer = hk.Embed(full_vocab_size, embed_size)
    embeddings = embedding_layer(x)

    lstm_layers = []
    for _ in range(lstm_num_layers):
      lstm_layers.extend([hk.LSTM(hidden_size=lstm_hidden_size), jnp.tanh])
    rnn_core = hk.DeepRNN(lstm_layers)
    initial_state = rnn_core.initial_state(batch_size=embeddings.shape[1])
    # [time_steps, batch_size, hidden_size].
    output, _ = hk.static_unroll(rnn_core, embeddings, initial_state)

    output = hk.Linear(full_vocab_size)(output)
    # [batch_size, time_steps, full_vocab_size].
    output = jnp.transpose(output, axes=(1, 0, 2))
    return output

  def train_loss(batch, preds):
    """Returns average token loss per sequence."""
    targets = batch['y']
    per_token_loss = metrics.unreduced_cross_entropy_loss(targets, preds)
    # Don't count padded values in loss.
    per_token_loss *= targets != pad
    return jnp.mean(per_token_loss, axis=-1)

  transformed_forward_pass = hk.transform(forward_pass)
  return models.create_model_from_haiku(
      transformed_forward_pass=transformed_forward_pass,
      sample_batch={
          'x': jnp.zeros((1, 1), dtype=jnp.int32),
          'y': jnp.zeros((1, 1), dtype=jnp.int32),
      },
      train_loss=train_loss,
      eval_metrics={
          'accuracy_in_vocab':
              metrics.SequenceTokenAccuracy(
                  masked_target_values=(pad, eos), logits_mask=logits_mask),
          'accuracy_no_eos':
              metrics.SequenceTokenAccuracy(masked_target_values=(pad, eos)),
          'num_tokens':
              metrics.SequenceTokenCount(masked_target_values=(pad,)),
          'sequence_length':
              metrics.SequenceLength(masked_target_values=(pad,)),
          'sequence_loss':
              metrics.SequenceCrossEntropyLoss(masked_target_values=(pad,)),
          'token_loss':
              metrics.SequenceTokenCrossEntropyLoss(
                  masked_target_values=(pad,)),
          'token_oov_rate':
              metrics.SequenceTokenOOVRate(
                  oov_target_values=(oov,), masked_target_values=(pad,)),
      })
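
The logits_mask built above is intended to be added to the output logits before the argmax inside the accuracy_in_vocab metric, so that PAD, BOS, EOS, and OOV can never be predicted. A toy illustration with a six-token vocabulary (the metrics implementation may apply the mask differently):

import jax.numpy as jnp

pad, bos, eos, oov = 0, 3, 4, 5          # toy special-token ids
full_vocab_size = 6                      # real tokens live at ids 1 and 2
logits_mask = jnp.array([-jnp.inf if i in (pad, bos, eos, oov) else 0.
                         for i in range(full_vocab_size)])
logits = jnp.ones((2, full_vocab_size))  # dummy [time_steps, vocab] logits
masked = logits + logits_mask            # special tokens get -inf scores
preds = jnp.argmax(masked, axis=-1)      # can only select the real tokens
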
Example #6
def create_lstm_model(vocab_size: int = 10000,
                      embed_size: int = 96,
                      lstm_hidden_size: int = 670,
                      lstm_num_layers: int = 1,
                      share_input_output_embeddings: bool = False,
                      expected_length: Optional[float] = None) -> models.Model:
    """Creates LSTM language model.

  Word-level language model for Stack Overflow.
  Defaults to the model used in:

  Adaptive Federated Optimization
    Sashank Reddi, Zachary Charles, Manzil Zaheer, Zachary Garrett, Keith Rush,
    Jakub Konečný, Sanjiv Kumar, H. Brendan McMahan.
    https://arxiv.org/abs/2003.00295

  Args:
    vocab_size: The number of possible output words. This does not include
      special tokens like PAD, BOS, EOS, or OOV.
    embed_size: Embedding size for each word.
    lstm_hidden_size: Hidden size for LSTM cells.
    lstm_num_layers: Number of LSTM layers.
    share_input_output_embeddings: Whether to share the input embeddings with
      the output logits.
    expected_length: Expected average sentence length used to scale the training
      loss down by `1. / expected_length`. This constant term is used so that
      the total loss over all the words in a sentence can be scaled down to per
      word cross entropy values by a constant factor instead of dividing by
      number of words which can vary across batches. Defaults to no scaling.

  Returns:
    Model.
  """
    # TODO(jaero): Replace these with direct references from dataset.
    pad = 0
    bos = 1
    eos = 2
    oov = vocab_size + 3
    full_vocab_size = vocab_size + 4
    # We do not guess EOS, and if we guess OOV, it's treated as a mistake.
    logits_mask = [0. for _ in range(full_vocab_size)]
    for i in (pad, bos, eos, oov):
        logits_mask[i] = jnp.NINF
    logits_mask = tuple(logits_mask)

    def forward_pass(batch):
        x = batch['x']
        # [time_steps, batch_size, ...].
        x = jnp.transpose(x)
        # [time_steps, batch_size, embed_dim].
        embedding_layer = hk.Embed(full_vocab_size, embed_size)
        embeddings = embedding_layer(x)

        lstm_layers = []
        for _ in range(lstm_num_layers):
            lstm_layers.extend([
                hk.LSTM(hidden_size=lstm_hidden_size),
                jnp.tanh,
                # Projection changes dimension from lstm_hidden_size to embed_size.
                hk.Linear(embed_size)
            ])
        rnn_core = hk.DeepRNN(lstm_layers)
        initial_state = rnn_core.initial_state(batch_size=embeddings.shape[1])
        # [time_steps, batch_size, hidden_size].
        output, _ = hk.static_unroll(rnn_core, embeddings, initial_state)

        if share_input_output_embeddings:
            output = jnp.dot(output, jnp.transpose(embedding_layer.embeddings))
            output = hk.Bias(bias_dims=[-1])(output)
        else:
            output = hk.Linear(full_vocab_size)(output)
        # [batch_size, time_steps, full_vocab_size].
        output = jnp.transpose(output, axes=(1, 0, 2))
        return output

    def train_loss(batch, preds):
        """Returns total loss per sentence optionally scaled down to token level."""
        targets = batch['y']
        per_token_loss = metrics.unreduced_cross_entropy_loss(targets, preds)
        # Don't count padded values in loss.
        per_token_loss *= targets != pad
        sentence_loss = jnp.sum(per_token_loss, axis=-1)
        if expected_length is not None:
            return sentence_loss * (1. / expected_length)
        return sentence_loss

    transformed_forward_pass = hk.transform(forward_pass)
    return models.create_model_from_haiku(
        transformed_forward_pass=transformed_forward_pass,
        sample_batch={
            'x': jnp.zeros((1, 1), dtype=jnp.int32),
            'y': jnp.zeros((1, 1), dtype=jnp.int32),
        },
        train_loss=train_loss,
        eval_metrics={
            'accuracy_in_vocab':
                metrics.SequenceTokenAccuracy(
                    masked_target_values=(pad, eos), logits_mask=logits_mask),
            'accuracy_no_eos':
                metrics.SequenceTokenAccuracy(masked_target_values=(pad, eos)),
            'num_tokens':
                metrics.SequenceTokenCount(masked_target_values=(pad,)),
            'sequence_length':
                metrics.SequenceLength(masked_target_values=(pad,)),
            'sequence_loss':
                metrics.SequenceCrossEntropyLoss(masked_target_values=(pad,)),
            'token_loss':
                metrics.SequenceTokenCrossEntropyLoss(
                    masked_target_values=(pad,)),
            'token_oov_rate':
                metrics.SequenceTokenOOVRate(
                    oov_target_values=(oov,), masked_target_values=(pad,)),
            'truncation_rate':
                metrics.SequenceTruncationRate(
                    eos_target_value=eos, masked_target_values=(pad,)),
        })
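
A small worked example of the train_loss above: padded positions are zeroed out, the remaining per-token losses are summed per sentence, and, when expected_length is set, the sum is scaled by the constant 1. / expected_length instead of by the varying number of real tokens. The numbers are illustrative.

import jax.numpy as jnp

pad = 0
targets = jnp.array([[5, 7, 2, 0]])                   # [batch=1, time_steps=4]; last is PAD
per_token_loss = jnp.array([[2.0, 1.0, 0.5, 0.3]])    # cross entropy per position
per_token_loss = per_token_loss * (targets != pad)    # padded position contributes 0
sentence_loss = jnp.sum(per_token_loss, axis=-1)      # [3.5]
expected_length = 7.0
scaled_loss = sentence_loss * (1. / expected_length)  # [0.5]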