def train_and_evaluate(seed, model_dir, num_epochs, batch_size, embedding_size, hidden_size, min_freq, max_seq_len, dropout, emb_dropout, word_dropout_rate, learning_rate, checkpoints_to_keep, l2_reg):
  """Trains an SST-2 model and logs its best validation accuracy.

  Args:
    seed: Random seed for network initialization.
    model_dir: Directory to store model data.
    num_epochs: Number of training epochs.
    batch_size: Batch size for training.
    embedding_size: Size of the word embeddings.
    hidden_size: Hidden size for the LSTM and MLP.
    min_freq: Minimum frequency for training set words to be in vocabulary.
    max_seq_len: Maximum sequence length in the dataset.
    dropout: Dropout rate.
    emb_dropout: Embedding dropout rate.
    word_dropout_rate: Word dropout rate.
    learning_rate: The learning rate for the Adam optimizer.
    checkpoints_to_keep: Number of checkpoints to keep.
    l2_reg: L2 regularization weight.
  """
  tf.enable_v2_behavior()

  # The data source supplies the vocabulary plus train/validation datasets.
  data_source = input_pipeline.SST2DataSource(min_freq=min_freq)

  # Assemble the model configuration up front, then build the model.
  model_config = dict(
      vocab_size=data_source.vocab_size,
      embedding_size=embedding_size,
      hidden_size=hidden_size,
      output_size=1,
      unk_idx=data_source.unk_idx,
      dropout=dropout,
      emb_dropout=emb_dropout,
      word_dropout_rate=word_dropout_rate)
  model = sst2_model.create_model(seed, batch_size, max_seq_len, model_config)

  # Run the training loop; it returns the best checkpointed model.
  _, model = train(
      model,
      learning_rate=learning_rate,
      num_epochs=num_epochs,
      seed=seed,
      model_dir=model_dir,
      data_source=data_source,
      batch_size=batch_size,
      checkpoints_to_keep=checkpoints_to_keep,
      l2_reg=l2_reg)

  # Score the restored best model on the validation split.
  valid_batches = input_pipeline.get_batches(
      data_source.valid_dataset, batch_size=batch_size)
  metrics = evaluate(model, valid_batches)
  logging.info('Best validation accuracy: %.2f', metrics['acc'])
def main(argv):
  """Command-line entry point: trains a model configured via FLAGS."""
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # A model directory is mandatory; create it when absent.
  assert FLAGS.model_dir is not None, 'Please provide model_dir.'
  if not gfile.exists(FLAGS.model_dir):
    gfile.makedirs(FLAGS.model_dir)

  tf.enable_v2_behavior()

  # The data source supplies the vocabulary plus train/validation datasets.
  data_source = input_pipeline.SST2DataSource(min_freq=FLAGS.min_freq)

  # Build the model from the flag-provided hyperparameters.
  model_config = dict(
      vocab_size=data_source.vocab_size,
      embedding_size=FLAGS.embedding_size,
      hidden_size=FLAGS.hidden_size,
      output_size=1,
      unk_idx=data_source.unk_idx,
      dropout=FLAGS.dropout,
      emb_dropout=FLAGS.emb_dropout,
      word_dropout_rate=FLAGS.word_dropout_rate)
  model = sst2_model.create_model(
      FLAGS.seed, FLAGS.batch_size, FLAGS.max_seq_len, model_config)

  # Run the training loop; it returns the best checkpointed model.
  _, model = train(
      model,
      learning_rate=FLAGS.learning_rate,
      num_epochs=FLAGS.num_epochs,
      seed=FLAGS.seed,
      model_dir=FLAGS.model_dir,
      data_source=data_source,
      batch_size=FLAGS.batch_size,
      checkpoints_to_keep=FLAGS.checkpoints_to_keep,
      l2_reg=FLAGS.l2_reg)

  # Score the restored best model on the validation split.
  valid_batches = input_pipeline.get_batches(
      data_source.valid_dataset, batch_size=FLAGS.batch_size)
  metrics = evaluate(model, valid_batches)
  logging.info('Best validation accuracy: %.2f', metrics['acc'])
def train(
    model: nn.Model,
    learning_rate: float = None,
    num_epochs: int = None,
    seed: int = None,
    model_dir: Text = None,
    data_source: Any = None,
    batch_size: int = None,
    checkpoints_to_keep: int = None,
    l2_reg: float = None,
) -> Tuple[Dict[Text, Any], nn.Model]:
  """Training loop.

  Trains for `num_epochs` epochs, checkpointing whenever validation accuracy
  improves, and finally restores and returns the best checkpoint.

  Args:
    model: An initialized model to be trained.
    learning_rate: The learning rate.
    num_epochs: Train for this many epochs.
    seed: Seed for shuffling.
    model_dir: Directory to save best model.
    data_source: The data source with pre-processed data examples.
    batch_size: The batch size to use for training and validation data.
    checkpoints_to_keep: Number of checkpoints to keep in `model_dir`.
    l2_reg: L2 regularization weight.

  Returns:
    A dict with training statistics and the best model.
  """
  rng = jax.random.PRNGKey(seed)
  optimizer = flax.optim.Adam(learning_rate=learning_rate).create(model)
  stats = collections.defaultdict(list)
  best_score = 0.
  # Training batches are reshuffled per epoch using `seed`; validation
  # batches are deterministic so accuracies are comparable across epochs.
  train_batches = input_pipeline.get_shuffled_batches(
      data_source.train_dataset, batch_size=batch_size, seed=seed)
  valid_batches = input_pipeline.get_batches(
      data_source.valid_dataset, batch_size=batch_size)

  for epoch in range(num_epochs):
    train_metrics = collections.defaultdict(float)

    # Train for one epoch.
    for ex in tfds.as_numpy(train_batches):
      inputs, lengths, labels = ex['sentence'], ex['length'], ex['label']
      # train_step threads the PRNG key through so dropout noise differs
      # per step; it returns the updated optimizer state and mean loss.
      optimizer, loss, rng = train_step(optimizer, inputs, lengths, labels,
                                        rng, l2_reg)
      # Weight the per-batch mean loss by batch size (inputs.shape[0]) so
      # a correct epoch average can be derived even with a ragged last batch.
      train_metrics['loss'] += loss * inputs.shape[0]
      train_metrics['total'] += inputs.shape[0]

    # Evaluate on validation data. optimizer.target is the updated model.
    valid_metrics = evaluate(optimizer.target, valid_batches)
    log(stats, epoch, train_metrics, valid_metrics)

    # Save a checkpoint if this is the best model so far.
    if valid_metrics['acc'] > best_score:
      best_score = valid_metrics['acc']
      flax.training.checkpoints.save_checkpoint(model_dir, optimizer.target,
                                                epoch + 1,
                                                keep=checkpoints_to_keep)

  # Done training. Restore best model.
  logging.info('Training done! Best validation accuracy: %.2f', best_score)
  best_model = flax.training.checkpoints.restore_checkpoint(model_dir, model)
  return stats, best_model