def create_conv_model(only_digits: bool = False) -> models.Model:
  """Creates EMNIST CNN model with dropout with haiku.

  Matches the model used in:

  Adaptive Federated Optimization
    Sashank Reddi, Zachary Charles, Manzil Zaheer, Zachary Garrett, Keith Rush,
    Jakub Konečný, Sanjiv Kumar, H. Brendan McMahan.
    https://arxiv.org/abs/2003.00295

  Args:
    only_digits: Whether to use only digit classes [0-9] or include lower and
      upper case characters for a total of 62 classes.

  Returns:
    Model
  """
  num_classes = 10 if only_digits else 62

  def forward_pass(batch, is_train=True):
    # Dropout inside ConvDropoutModule is applied only when is_train is True.
    module = ConvDropoutModule(num_classes)
    return module(batch['x'], is_train)

  return models.create_model_from_haiku(
      transformed_forward_pass=hk.transform(forward_pass),
      sample_batch=_HAIKU_SAMPLE_BATCH,
      train_loss=_TRAIN_LOSS,
      eval_metrics=_EVAL_METRICS,
      # is_train toggles dropout on for training and off for evaluation.
      train_kwargs={'is_train': True},
      eval_kwargs={'is_train': False})
def create_regression_model() -> models.Model:
  """Creates toy regression model.

  Matches the model used in:

  Communication-Efficient Agnostic Federated Averaging
    Jae Ro, Mingqing Chen, Rajiv Mathews, Mehryar Mohri, Ananda Theertha Suresh
    https://arxiv.org/abs/2104.02748

  Returns:
    Model
  """

  def forward_pass(batch):
    # Single bias-free linear unit; prediction is the mean of its outputs.
    return jnp.mean(hk.Linear(1, with_bias=False)(batch['x']))

  def train_loss(batch, preds):
    # Squared error between the mean target and the scalar prediction.
    target = jnp.mean(batch['y'])
    return jnp.square(target - preds)

  return models.create_model_from_haiku(
      transformed_forward_pass=hk.transform(forward_pass),
      sample_batch={'x': np.zeros((1, 1)), 'y': np.zeros((1,))},
      train_loss=train_loss)
def test_create_model_from_haiku(self):

  def forward_pass(batch):
    # Minimal single-layer network, enough to exercise model creation.
    layer = hk.Linear(10)
    return layer(batch['x'])

  model = models.create_model_from_haiku(
      transformed_forward_pass=hk.transform(forward_pass),
      sample_batch={'x': jnp.ones((1, 2))},
      train_loss=train_loss,
      eval_metrics=eval_metrics)
  self.check_model(model)
def create_logistic_model(only_digits: bool = False) -> models.Model:
  """Creates EMNIST logistic model with haiku."""
  num_classes = 10 if only_digits else 62

  def forward_pass(batch):
    # Flatten the image features, then map directly to class logits.
    flattened = hk.Flatten()(batch['x'])
    return hk.Linear(num_classes)(flattened)

  return models.create_model_from_haiku(
      transformed_forward_pass=hk.transform(forward_pass),
      sample_batch=_HAIKU_SAMPLE_BATCH,
      train_loss=_TRAIN_LOSS,
      eval_metrics=_EVAL_METRICS)
def create_lstm_model(vocab_size: int = 86,
                      embed_size: int = 8,
                      lstm_hidden_size: int = 256,
                      lstm_num_layers: int = 2) -> models.Model:
  """Creates LSTM language model.

  Character-level LSTM for Shakespeare language model.
  Defaults to the model used in:

  Communication-Efficient Learning of Deep Networks from Decentralized Data
    H. Brendan McMahan, Eider Moore, Daniel Ramage,
    Seth Hampson, Blaise Aguera y Arcas. AISTATS 2017.
    https://arxiv.org/abs/1602.05629

  Args:
    vocab_size: The number of possible output characters. This does not include
      special tokens like PAD, BOS, EOS, or OOV.
    embed_size: Embedding size for each character.
    lstm_hidden_size: Hidden size for LSTM cells.
    lstm_num_layers: Number of LSTM layers.

  Returns:
    Model.
  """
  # TODO(jaero): Replace these with direct references from dataset.
  pad = 0
  bos = vocab_size + 1
  eos = vocab_size + 2
  oov = vocab_size + 3
  full_vocab_size = vocab_size + 4
  # We do not guess EOS, and if we guess OOV, it's treated as a mistake.
  logits_mask = [0. for _ in range(full_vocab_size)]
  for i in (pad, bos, eos, oov):
    # Use -jnp.inf rather than jnp.NINF: NINF was removed in NumPy 2.0 and
    # has been deprecated/removed in recent JAX releases. Same value.
    logits_mask[i] = -jnp.inf
  logits_mask = tuple(logits_mask)

  def forward_pass(batch):
    x = batch['x']
    # [time_steps, batch_size, ...].
    x = jnp.transpose(x)
    # [time_steps, batch_size, embed_dim].
    embedding_layer = hk.Embed(full_vocab_size, embed_size)
    embeddings = embedding_layer(x)

    lstm_layers = []
    for _ in range(lstm_num_layers):
      lstm_layers.extend([hk.LSTM(hidden_size=lstm_hidden_size), jnp.tanh])
    rnn_core = hk.DeepRNN(lstm_layers)
    initial_state = rnn_core.initial_state(batch_size=embeddings.shape[1])
    # [time_steps, batch_size, hidden_size].
    output, _ = hk.static_unroll(rnn_core, embeddings, initial_state)
    output = hk.Linear(full_vocab_size)(output)
    # [batch_size, time_steps, full_vocab_size].
    output = jnp.transpose(output, axes=(1, 0, 2))
    return output

  def train_loss(batch, preds):
    """Returns average token loss per sequence."""
    targets = batch['y']
    per_token_loss = metrics.unreduced_cross_entropy_loss(targets, preds)
    # Don't count padded values in loss.
    per_token_loss *= targets != pad
    # NOTE(review): mean is taken over all time steps, so padded positions
    # contribute zeros to the numerator but still count in the denominator.
    return jnp.mean(per_token_loss, axis=-1)

  transformed_forward_pass = hk.transform(forward_pass)
  return models.create_model_from_haiku(
      transformed_forward_pass=transformed_forward_pass,
      sample_batch={
          'x': jnp.zeros((1, 1), dtype=jnp.int32),
          'y': jnp.zeros((1, 1), dtype=jnp.int32),
      },
      train_loss=train_loss,
      eval_metrics={
          'accuracy_in_vocab':
              metrics.SequenceTokenAccuracy(
                  masked_target_values=(pad, eos), logits_mask=logits_mask),
          'accuracy_no_eos':
              metrics.SequenceTokenAccuracy(masked_target_values=(pad, eos)),
          'num_tokens':
              metrics.SequenceTokenCount(masked_target_values=(pad,)),
          'sequence_length':
              metrics.SequenceLength(masked_target_values=(pad,)),
          'sequence_loss':
              metrics.SequenceCrossEntropyLoss(masked_target_values=(pad,)),
          'token_loss':
              metrics.SequenceTokenCrossEntropyLoss(
                  masked_target_values=(pad,)),
          'token_oov_rate':
              metrics.SequenceTokenOOVRate(
                  oov_target_values=(oov,), masked_target_values=(pad,)),
      })
def create_lstm_model(vocab_size: int = 10000,
                      embed_size: int = 96,
                      lstm_hidden_size: int = 670,
                      lstm_num_layers: int = 1,
                      share_input_output_embeddings: bool = False,
                      expected_length: Optional[float] = None) -> models.Model:
  """Creates LSTM language model.

  Word-level language model for Stack Overflow.
  Defaults to the model used in:

  Adaptive Federated Optimization
    Sashank Reddi, Zachary Charles, Manzil Zaheer, Zachary Garrett, Keith Rush,
    Jakub Konečný, Sanjiv Kumar, H. Brendan McMahan.
    https://arxiv.org/abs/2003.00295

  Args:
    vocab_size: The number of possible output words. This does not include
      special tokens like PAD, BOS, EOS, or OOV.
    embed_size: Embedding size for each word.
    lstm_hidden_size: Hidden size for LSTM cells.
    lstm_num_layers: Number of LSTM layers.
    share_input_output_embeddings: Whether to share the input embeddings with
      the output logits.
    expected_length: Expected average sentence length used to scale the
      training loss down by `1. / expected_length`. This constant term is used
      so that the total loss over all the words in a sentence can be scaled
      down to per word cross entropy values by a constant factor instead of
      dividing by number of words which can vary across batches. Defaults to
      no scaling.

  Returns:
    Model.
  """
  # TODO(jaero): Replace these with direct references from dataset.
  pad = 0
  bos = 1
  eos = 2
  oov = vocab_size + 3
  full_vocab_size = vocab_size + 4
  # We do not guess EOS, and if we guess OOV, it's treated as a mistake.
  logits_mask = [0. for _ in range(full_vocab_size)]
  for i in (pad, bos, eos, oov):
    # Use -jnp.inf rather than jnp.NINF: NINF was removed in NumPy 2.0 and
    # has been deprecated/removed in recent JAX releases. Same value.
    logits_mask[i] = -jnp.inf
  logits_mask = tuple(logits_mask)

  def forward_pass(batch):
    x = batch['x']
    # [time_steps, batch_size, ...].
    x = jnp.transpose(x)
    # [time_steps, batch_size, embed_dim].
    embedding_layer = hk.Embed(full_vocab_size, embed_size)
    embeddings = embedding_layer(x)

    lstm_layers = []
    for _ in range(lstm_num_layers):
      lstm_layers.extend([
          hk.LSTM(hidden_size=lstm_hidden_size),
          jnp.tanh,
          # Projection changes dimension from lstm_hidden_size to embed_size.
          hk.Linear(embed_size)
      ])
    rnn_core = hk.DeepRNN(lstm_layers)
    initial_state = rnn_core.initial_state(batch_size=embeddings.shape[1])
    # [time_steps, batch_size, hidden_size].
    output, _ = hk.static_unroll(rnn_core, embeddings, initial_state)

    if share_input_output_embeddings:
      # Tied weights: reuse the input embedding matrix as the output
      # projection, plus a learned per-vocab-entry bias.
      output = jnp.dot(output, jnp.transpose(embedding_layer.embeddings))
      output = hk.Bias(bias_dims=[-1])(output)
    else:
      output = hk.Linear(full_vocab_size)(output)
    # [batch_size, time_steps, full_vocab_size].
    output = jnp.transpose(output, axes=(1, 0, 2))
    return output

  def train_loss(batch, preds):
    """Returns total loss per sentence optionally scaled down to token level."""
    targets = batch['y']
    per_token_loss = metrics.unreduced_cross_entropy_loss(targets, preds)
    # Don't count padded values in loss.
    per_token_loss *= targets != pad
    sentence_loss = jnp.sum(per_token_loss, axis=-1)
    if expected_length is not None:
      return sentence_loss * (1. / expected_length)
    return sentence_loss

  transformed_forward_pass = hk.transform(forward_pass)
  return models.create_model_from_haiku(
      transformed_forward_pass=transformed_forward_pass,
      sample_batch={
          'x': jnp.zeros((1, 1), dtype=jnp.int32),
          'y': jnp.zeros((1, 1), dtype=jnp.int32),
      },
      train_loss=train_loss,
      eval_metrics={
          'accuracy_in_vocab':
              metrics.SequenceTokenAccuracy(
                  masked_target_values=(pad, eos), logits_mask=logits_mask),
          'accuracy_no_eos':
              metrics.SequenceTokenAccuracy(masked_target_values=(pad, eos)),
          'num_tokens':
              metrics.SequenceTokenCount(masked_target_values=(pad,)),
          'sequence_length':
              metrics.SequenceLength(masked_target_values=(pad,)),
          'sequence_loss':
              metrics.SequenceCrossEntropyLoss(masked_target_values=(pad,)),
          'token_loss':
              metrics.SequenceTokenCrossEntropyLoss(
                  masked_target_values=(pad,)),
          'token_oov_rate':
              metrics.SequenceTokenOOVRate(
                  oov_target_values=(oov,), masked_target_values=(pad,)),
          'truncation_rate':
              metrics.SequenceTruncationRate(
                  eos_target_value=eos, masked_target_values=(pad,)),
      })