コード例 #1
0
ファイル: train.py プロジェクト: wdevazelhes/flax
def train_and_evaluate(seed, model_dir, num_epochs, batch_size, embedding_size,
                       hidden_size, min_freq, max_seq_len, dropout,
                       emb_dropout, word_dropout_rate, learning_rate,
                       checkpoints_to_keep, l2_reg):
    """Executes model training and evaluation loop.

    Builds the SST-2 data source and model, trains it via `train`, then
    evaluates the model returned by `train` on the validation set and logs
    its accuracy.

    Args:
      seed: Random seed for network initialization.
      model_dir: Directory to store model data.
      num_epochs: Number of training epochs.
      batch_size: Batch size for training.
      embedding_size: Size of the word embeddings.
      hidden_size: Hidden size for the LSTM and MLP.
      min_freq: Minimum frequency for training set words to be in vocabulary.
      max_seq_len: Maximum sequence length in the dataset.
      dropout: Dropout rate.
      emb_dropout: Embedding dropout rate.
      word_dropout_rate: Word dropout rate.
      learning_rate: The learning rate for the Adam optimizer.
      checkpoints_to_keep: Number of checkpoints to keep.
      l2_reg: L2 regularization weight.
    """
    # NOTE(review): presumably required by the TF-based input pipeline; confirm.
    tf.enable_v2_behavior()

    # Prepare data.
    data_source = input_pipeline.SST2DataSource(min_freq=min_freq)

    # Create model. Vocabulary size and the unknown-token index come from the
    # data source so the embedding table matches the built vocabulary.
    model = sst2_model.create_model(
        seed, batch_size, max_seq_len,
        dict(vocab_size=data_source.vocab_size,
             embedding_size=embedding_size,
             hidden_size=hidden_size,
             output_size=1,
             unk_idx=data_source.unk_idx,
             dropout=dropout,
             emb_dropout=emb_dropout,
             word_dropout_rate=word_dropout_rate))

    # Train the model. The training statistics returned first are not used.
    _, model = train(model,
                     learning_rate=learning_rate,
                     num_epochs=num_epochs,
                     seed=seed,
                     model_dir=model_dir,
                     data_source=data_source,
                     batch_size=batch_size,
                     checkpoints_to_keep=checkpoints_to_keep,
                     l2_reg=l2_reg)

    # Evaluate the best model.
    valid_batches = input_pipeline.get_batches(data_source.valid_dataset,
                                               batch_size=batch_size)
    metrics = evaluate(model, valid_batches)
    logging.info('Best validation accuracy: %.2f', metrics['acc'])
コード例 #2
0
def main(argv):
    """Trains an SST-2 model from command-line flags and logs its accuracy.

    Reads every hyperparameter from FLAGS, creates `FLAGS.model_dir` if it
    does not exist, trains the model via `train`, and logs the validation
    accuracy of the model that `train` returns.

    Args:
      argv: Positional command-line arguments left after flag parsing; only
        the program name is accepted.

    Raises:
      app.UsageError: If extra positional arguments are given or
        `--model_dir` is not set.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Explicit raise instead of `assert`: asserts are stripped under
    # `python -O`, which would let a missing model_dir slip through.
    if FLAGS.model_dir is None:
        raise app.UsageError('Please provide model_dir.')
    if not gfile.exists(FLAGS.model_dir):
        gfile.makedirs(FLAGS.model_dir)

    # NOTE(review): presumably required by the TF-based input pipeline; confirm.
    tf.enable_v2_behavior()

    # Prepare data.
    data_source = input_pipeline.SST2DataSource(min_freq=FLAGS.min_freq)

    # Create model. Vocabulary size and the unknown-token index come from the
    # data source so the embedding table matches the built vocabulary.
    model = sst2_model.create_model(
        FLAGS.seed, FLAGS.batch_size, FLAGS.max_seq_len,
        dict(vocab_size=data_source.vocab_size,
             embedding_size=FLAGS.embedding_size,
             hidden_size=FLAGS.hidden_size,
             output_size=1,
             unk_idx=data_source.unk_idx,
             dropout=FLAGS.dropout,
             emb_dropout=FLAGS.emb_dropout,
             word_dropout_rate=FLAGS.word_dropout_rate))

    # Train the model. The training statistics returned first are not used.
    _, model = train(model,
                     learning_rate=FLAGS.learning_rate,
                     num_epochs=FLAGS.num_epochs,
                     seed=FLAGS.seed,
                     model_dir=FLAGS.model_dir,
                     data_source=data_source,
                     batch_size=FLAGS.batch_size,
                     checkpoints_to_keep=FLAGS.checkpoints_to_keep,
                     l2_reg=FLAGS.l2_reg)

    # Evaluate the best model.
    valid_batches = input_pipeline.get_batches(data_source.valid_dataset,
                                               batch_size=FLAGS.batch_size)
    metrics = evaluate(model, valid_batches)
    logging.info('Best validation accuracy: %.2f', metrics['acc'])
コード例 #3
0
def train(
    model: nn.Model,
    learning_rate: float = None,
    num_epochs: int = None,
    seed: int = None,
    model_dir: Text = None,
    data_source: Any = None,
    batch_size: int = None,
    checkpoints_to_keep: int = None,
    l2_reg: float = None,
) -> Tuple[Dict[Text, Any], nn.Model]:
    """Training loop.

    Trains `model` with Adam for `num_epochs` epochs. After every epoch it
    evaluates on the validation set and saves a checkpoint to `model_dir`
    whenever validation accuracy improves. When training finishes, it
    restores a checkpoint from `model_dir` and returns it.

    Args:
      model: An initialized model to be trained.
      learning_rate: The learning rate for the Adam optimizer.
      num_epochs: Train for this many epochs.
      seed: Seed for the dropout PRNG key and for shuffling training data.
      model_dir: Directory to save best model checkpoints in.
      data_source: The data source with pre-processed data examples; must
        expose `train_dataset` and `valid_dataset`.
      batch_size: The batch size to use for training and validation data.
      checkpoints_to_keep: Number of checkpoints to keep in `model_dir`.
      l2_reg: L2 regularization weight.

    Returns:
      A tuple `(stats, best_model)`. `stats` is a dict of training statistics
      (filled in by `log`). `best_model` is the model restored from
      `model_dir`.
    """
    rng = jax.random.PRNGKey(seed)
    optimizer = flax.optim.Adam(learning_rate=learning_rate).create(model)
    stats = collections.defaultdict(list)
    # Start below any achievable accuracy so the first epoch is always
    # checkpointed. With 0., a run whose accuracy never exceeds zero would
    # save nothing. The restore below would then silently return the
    # untrained `model` or a stale checkpoint left in `model_dir`.
    best_score = float('-inf')
    train_batches = input_pipeline.get_shuffled_batches(
        data_source.train_dataset, batch_size=batch_size, seed=seed)
    valid_batches = input_pipeline.get_batches(data_source.valid_dataset,
                                               batch_size=batch_size)

    for epoch in range(num_epochs):
        train_metrics = collections.defaultdict(float)

        # Train for one epoch.
        for ex in tfds.as_numpy(train_batches):
            inputs, lengths, labels = ex['sentence'], ex['length'], ex['label']
            optimizer, loss, rng = train_step(optimizer, inputs, lengths,
                                              labels, rng, l2_reg)
            # Weight by batch size: the last batch may be smaller.
            num_examples = inputs.shape[0]
            train_metrics['loss'] += loss * num_examples
            train_metrics['total'] += num_examples

        # Evaluate on validation data. optimizer.target is the updated model.
        valid_metrics = evaluate(optimizer.target, valid_batches)
        log(stats, epoch, train_metrics, valid_metrics)

        # Save a checkpoint if this is the best model so far.
        if valid_metrics['acc'] > best_score:
            best_score = valid_metrics['acc']
            flax.training.checkpoints.save_checkpoint(model_dir,
                                                      optimizer.target,
                                                      epoch + 1,
                                                      keep=checkpoints_to_keep)

    # Done training. Restore best model; `model` is the template target.
    logging.info('Training done! Best validation accuracy: %.2f', best_score)
    best_model = flax.training.checkpoints.restore_checkpoint(model_dir, model)

    return stats, best_model