Example #1
def train_and_evaluate(seed, model_dir, num_epochs, batch_size, embedding_size,
                       hidden_size, min_freq, max_seq_len, dropout,
                       emb_dropout, word_dropout_rate, learning_rate,
                       checkpoints_to_keep, l2_reg):
    """Executes model training and evaluation loop.
  
  Args:
    seed: Random seed for network initialization.
    model_dir: Directory to store model data.
    num_epochs: Number of training epochs.
    batch_size: Batch size for training.
    embedding_size: Size of the word embeddings.
    hidden_size: Hidden size for the LSTM and MLP.
    min_freq: Minimum frequency for training set words to be in vocabulary.
    max_seq_len: Maximum sequence length in the dataset.
    dropout: Dropout rate.
    emb_dropout: Embedding dropout rate.
    word_dropout_rate: Word dropout rate.
    learning_rate: The learning rate for the Adam optimizer.
    checkpoints_to_keep: Number of checkpoints to keep.
    l2_reg: L2 regularization to keep. 
  """
    tf.enable_v2_behavior()

    # Prepare data.
    data_source = input_pipeline.SST2DataSource(min_freq=min_freq)

    # Create model.
    model = sst2_model.create_model(
        seed, batch_size, max_seq_len,
        dict(vocab_size=data_source.vocab_size,
             embedding_size=embedding_size,
             hidden_size=hidden_size,
             output_size=1,
             unk_idx=data_source.unk_idx,
             dropout=dropout,
             emb_dropout=emb_dropout,
             word_dropout_rate=word_dropout_rate))

    # Train the model.
    _, model = train(model,
                     learning_rate=learning_rate,
                     num_epochs=num_epochs,
                     seed=seed,
                     model_dir=model_dir,
                     data_source=data_source,
                     batch_size=batch_size,
                     checkpoints_to_keep=checkpoints_to_keep,
                     l2_reg=l2_reg)

    # Evaluate the best model.
    valid_batches = input_pipeline.get_batches(data_source.valid_dataset,
                                               batch_size=batch_size)
    metrics = evaluate(model, valid_batches)
    logging.info('Best validation accuracy: %.2f', metrics['acc'])
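For context, a direct invocation of train_and_evaluate might look like the sketch below. It is illustrative only: the module name sst2_train, the output directory, and every hyperparameter value are placeholder assumptions, not values prescribed by the example above.

# Illustrative sketch: assumes train_and_evaluate is importable from a module
# named sst2_train (hypothetical) and uses placeholder hyperparameter values.
from sst2_train import train_and_evaluate

train_and_evaluate(
    seed=0,
    model_dir='/tmp/sst2_model',  # hypothetical output directory
    num_epochs=10,
    batch_size=64,
    embedding_size=256,
    hidden_size=256,
    min_freq=1,
    max_seq_len=60,
    dropout=0.5,
    emb_dropout=0.5,
    word_dropout_rate=0.1,
    learning_rate=0.0005,
    checkpoints_to_keep=1,
    l2_reg=1e-6)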
Example #2
def main(argv):
    """Main function."""
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    assert FLAGS.model_dir is not None, 'Please provide model_dir.'
    if not gfile.exists(FLAGS.model_dir):
        gfile.makedirs(FLAGS.model_dir)

    tf.enable_v2_behavior()

    # Prepare data.
    data_source = input_pipeline.SST2DataSource(min_freq=FLAGS.min_freq)

    # Create model.
    model = sst2_model.create_model(
        FLAGS.seed, FLAGS.batch_size, FLAGS.max_seq_len,
        dict(vocab_size=data_source.vocab_size,
             embedding_size=FLAGS.embedding_size,
             hidden_size=FLAGS.hidden_size,
             output_size=1,
             unk_idx=data_source.unk_idx,
             dropout=FLAGS.dropout,
             emb_dropout=FLAGS.emb_dropout,
             word_dropout_rate=FLAGS.word_dropout_rate))

    # Train the model.
    train_stats, model = train(model,
                               learning_rate=FLAGS.learning_rate,
                               num_epochs=FLAGS.num_epochs,
                               seed=FLAGS.seed,
                               model_dir=FLAGS.model_dir,
                               data_source=data_source,
                               batch_size=FLAGS.batch_size,
                               checkpoints_to_keep=FLAGS.checkpoints_to_keep,
                               l2_reg=FLAGS.l2_reg)

    # Evaluate the best model.
    valid_batches = input_pipeline.get_batches(data_source.valid_dataset,
                                               batch_size=FLAGS.batch_size)
    metrics = evaluate(model, valid_batches)
    logging.info('Best validation accuracy: %.2f', metrics['acc'])
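The main function above relies on absl flags being defined elsewhere in the module. A minimal sketch of how such an entry point is typically wired up is shown below; the flag names mirror the FLAGS attributes used above, but the default values and help strings are assumptions rather than values taken from the source.

# Minimal sketch of the absl entry point assumed by main() above.
# Defaults and help strings are illustrative assumptions.
from absl import app
from absl import flags

FLAGS = flags.FLAGS

flags.DEFINE_string('model_dir', None, 'Directory to store model data.')
flags.DEFINE_integer('seed', 0, 'Random seed for network initialization.')
flags.DEFINE_integer('num_epochs', 10, 'Number of training epochs.')
flags.DEFINE_integer('batch_size', 64, 'Batch size for training.')
flags.DEFINE_integer('embedding_size', 256, 'Size of the word embeddings.')
flags.DEFINE_integer('hidden_size', 256, 'Hidden size for the LSTM and MLP.')
flags.DEFINE_integer('min_freq', 1, 'Minimum word frequency for the vocabulary.')
flags.DEFINE_integer('max_seq_len', 60, 'Maximum sequence length in the dataset.')
flags.DEFINE_float('dropout', 0.5, 'Dropout rate.')
flags.DEFINE_float('emb_dropout', 0.5, 'Embedding dropout rate.')
flags.DEFINE_float('word_dropout_rate', 0.1, 'Word dropout rate.')
flags.DEFINE_float('learning_rate', 0.0005, 'Learning rate for the Adam optimizer.')
flags.DEFINE_integer('checkpoints_to_keep', 1, 'Number of checkpoints to keep.')
flags.DEFINE_float('l2_reg', 1e-6, 'L2 regularization weight.')

if __name__ == '__main__':
    app.run(main)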