Example #1
def train(num_training_iterations, report_interval,
          reduce_learning_rate_interval):
  """Run the training of the deep LSTM model on tiny shakespeare."""

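  # Get datasets.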
  dataset_train = dataset_shakespeare.TinyShakespeareDataset(
      num_steps=FLAGS.truncation_length,
      batch_size=FLAGS.batch_size,
      subset="train",
      random=True,
      name="shake_train")

  dataset_valid = dataset_shakespeare.TinyShakespeareDataset(
      num_steps=FLAGS.truncation_length,
      batch_size=FLAGS.batch_size,
      subset="valid",
      random=False,
      name="shake_valid")

  dataset_test = dataset_shakespeare.TinyShakespeareDataset(
      num_steps=FLAGS.truncation_length,
      batch_size=FLAGS.batch_size,
      subset="test",
      random=False,
      name="shake_test")

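  # Define model.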
  model = TextModel(
      num_embedding=FLAGS.num_embedding,
      num_hidden=FLAGS.num_hidden,
      lstm_depth=FLAGS.lstm_depth,
      output_size=dataset_valid.vocab_size,
      use_dynamic_rnn=True,
      use_skip_connections=True)

  # Build the training model and get the training loss.
  train_input_sequence, train_target_sequence = dataset_train()
  train_output_sequence_logits, train_final_state = model(train_input_sequence)  # pylint: disable=not-callable
  train_loss = dataset_train.cost(train_output_sequence_logits,
                                  train_target_sequence)

  # Get the validation loss.
  valid_input_sequence, valid_target_sequence = dataset_valid()
  valid_output_sequence_logits, _ = model(valid_input_sequence)  # pylint: disable=not-callable
  valid_loss = dataset_valid.cost(valid_output_sequence_logits,
                                  valid_target_sequence)

  # Get the test loss.
  test_input_sequence, test_target_sequence = dataset_test()
  test_output_sequence_logits, _ = model(test_input_sequence)  # pylint: disable=not-callable
  test_loss = dataset_test.cost(test_output_sequence_logits,
                                test_target_sequence)

  # Build graph to sample some strings during training.
  initial_logits = train_output_sequence_logits[FLAGS.truncation_length - 1]
  train_generated_string = model.generate_string(
      initial_logits=initial_logits,
      initial_state=train_final_state,
      sequence_length=FLAGS.sample_length)

  # Set up global norm clipping of gradients.
  trainable_variables = tf.trainable_variables()
  grads, _ = tf.clip_by_global_norm(
      tf.gradients(train_loss, trainable_variables),
      FLAGS.max_grad_norm)

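  # Get learning rate and define annealing.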
  learning_rate = tf.get_variable(
      "learning_rate",
      shape=[],
      dtype=tf.float32,
      initializer=tf.constant_initializer(FLAGS.learning_rate),
      trainable=False)
  reduce_learning_rate = learning_rate.assign(
      learning_rate * FLAGS.reduce_learning_rate_multiplier)

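  # Get training step counter.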
  global_step = tf.get_variable(
      name="global_step",
      shape=[],
      dtype=tf.int64,
      initializer=tf.zeros_initializer(),
      trainable=False,
      collections=[tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.GLOBAL_STEP])

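  # Define optimizer and training step.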
  optimizer = tf.train.AdamOptimizer(
      learning_rate, epsilon=FLAGS.optimizer_epsilon)
  train_step = optimizer.apply_gradients(
      zip(grads, trainable_variables),
      global_step=global_step)

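  # Save checkpoints periodically during training.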
  saver = tf.train.Saver()

  hooks = [
      tf.train.CheckpointSaverHook(
          checkpoint_dir=FLAGS.checkpoint_dir,
          save_steps=FLAGS.checkpoint_interval,
          saver=saver)
  ]

  # Train.
  with tf.train.SingularMonitoredSession(
      hooks=hooks, checkpoint_dir=FLAGS.checkpoint_dir) as sess:

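    # Resume from the global step restored from the latest checkpoint, if any.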
    start_iteration = sess.run(global_step)

    for train_iteration in range(start_iteration, num_training_iterations):
      if (train_iteration + 1) % report_interval == 0:
        train_loss_v, valid_loss_v, _ = sess.run(
            (train_loss, valid_loss, train_step))

        train_generated_string_v = sess.run(train_generated_string)

        train_generated_string_human = dataset_train.to_human_readable(
            (train_generated_string_v, 0), indices=[0])

        tf.logging.info("%d: Training loss %f. Validation loss %f. Sample = %s",
                        train_iteration,
                        train_loss_v,
                        valid_loss_v,
                        train_generated_string_human)
      else:
        train_loss_v, _ = sess.run((train_loss, train_step))
        tf.logging.info("%d: Training loss %f.", train_iteration, train_loss_v)

      if (train_iteration + 1) % reduce_learning_rate_interval == 0:
        sess.run(reduce_learning_rate)
        tf.logging.info("Reducing learning rate.")

    test_loss_v = sess.run(test_loss)
    tf.logging.info("Test loss %f", test_loss_v)
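
The function above reads its hyperparameters from a module-level FLAGS object and is normally launched through tf.app.run. The sketch below shows one way that wiring could look; it is an illustrative assumption, not part of the original example. The flag defaults are copied from the keyword defaults of build_graph in Example #2 where available, while the checkpoint settings and the iteration counts passed to train are placeholder values.

import tensorflow as tf

# Flags read by train() above. Defaults mirror Example #2 where possible;
# the remaining values are placeholder assumptions.
tf.flags.DEFINE_integer("lstm_depth", 3, "Number of LSTM layers.")
tf.flags.DEFINE_integer("batch_size", 32, "Batch size.")
tf.flags.DEFINE_integer("num_embedding", 32, "Embedding size.")
tf.flags.DEFINE_integer("num_hidden", 128, "Hidden units per LSTM layer.")
tf.flags.DEFINE_integer("truncation_length", 64,
                        "Sequence length for truncated BPTT.")
tf.flags.DEFINE_integer("sample_length", 1000, "Length of generated samples.")
tf.flags.DEFINE_float("max_grad_norm", 5.0, "Gradient clipping norm.")
tf.flags.DEFINE_float("learning_rate", 0.1, "Initial learning rate.")
tf.flags.DEFINE_float("reduce_learning_rate_multiplier", 0.1,
                      "Factor applied when annealing the learning rate.")
tf.flags.DEFINE_float("optimizer_epsilon", 0.01, "Adam epsilon.")
tf.flags.DEFINE_string("checkpoint_dir", "/tmp/tf/rnn_shakespeare",
                       "Checkpoint directory (placeholder).")
tf.flags.DEFINE_integer("checkpoint_interval", 500,
                        "Steps between checkpoints (placeholder).")

FLAGS = tf.flags.FLAGS


def main(unused_argv):
  train(
      num_training_iterations=10000,       # placeholder
      report_interval=100,                 # placeholder
      reduce_learning_rate_interval=5000)  # placeholder


if __name__ == "__main__":
  tf.app.run()
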
Example #2
def build_graph(lstm_depth=3, batch_size=32, num_embedding=32, num_hidden=128,
                truncation_length=64, sample_length=1000, max_grad_norm=5,
                initial_learning_rate=0.1, reduce_learning_rate_multiplier=0.1,
                optimizer_epsilon=0.01):
  """Constructs the computation graph."""

  # Get datasets.
  dataset_train = dataset_shakespeare.TinyShakespeareDataset(
      num_steps=truncation_length,
      batch_size=batch_size,
      subset="train",
      random=True,
      name="shake_train")

  dataset_valid = dataset_shakespeare.TinyShakespeareDataset(
      num_steps=truncation_length,
      batch_size=batch_size,
      subset="valid",
      random=False,
      name="shake_valid")

  dataset_test = dataset_shakespeare.TinyShakespeareDataset(
      num_steps=truncation_length,
      batch_size=batch_size,
      subset="test",
      random=False,
      name="shake_test")

  # Define model.
  model = TextModel(
      num_embedding=num_embedding,
      num_hidden=num_hidden,
      lstm_depth=lstm_depth,
      output_size=dataset_valid.vocab_size,
      use_dynamic_rnn=True,
      use_skip_connections=True)

  # Get the training loss.
  train_input_sequence, train_target_sequence = dataset_train()
  train_output_sequence_logits, train_final_state = model(train_input_sequence)  # pylint: disable=not-callable
  train_loss = dataset_train.cost(train_output_sequence_logits,
                                  train_target_sequence)

  # Get the validation loss.
  valid_input_sequence, valid_target_sequence = dataset_valid()
  valid_output_sequence_logits, _ = model(valid_input_sequence)  # pylint: disable=not-callable
  valid_loss = dataset_valid.cost(valid_output_sequence_logits,
                                  valid_target_sequence)

  # Get the test loss.
  test_input_sequence, test_target_sequence = dataset_test()
  test_output_sequence_logits, _ = model(test_input_sequence)  # pylint: disable=not-callable
  test_loss = dataset_test.cost(test_output_sequence_logits,
                                test_target_sequence)

  # Build graph to sample some strings during training.
  initial_logits = train_output_sequence_logits[truncation_length - 1]
  train_generated_string = model.generate_string(
      initial_logits=initial_logits,
      initial_state=train_final_state,
      sequence_length=sample_length)

  # Set up global norm clipping of gradients.
  trainable_variables = tf.trainable_variables()
  grads, _ = tf.clip_by_global_norm(
      tf.gradients(train_loss, trainable_variables), max_grad_norm)

  # Get learning rate and define annealing.
  learning_rate = tf.get_variable(
      "learning_rate",
      shape=[],
      dtype=tf.float32,
      initializer=tf.constant_initializer(initial_learning_rate),
      trainable=False)
  reduce_learning_rate = learning_rate.assign(
      learning_rate * reduce_learning_rate_multiplier)

  # Get training step counter.
  global_step = tf.get_variable(
      name="global_step",
      shape=[],
      dtype=tf.int64,
      initializer=tf.zeros_initializer(),
      trainable=False,
      collections=[tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.GLOBAL_STEP])

  # Define optimizer and training step.
  optimizer = tf.train.AdamOptimizer(
      learning_rate, epsilon=optimizer_epsilon)
  train_step = optimizer.apply_gradients(
      zip(grads, trainable_variables),
      global_step=global_step)

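  # Collect the tensors the training loop needs to run.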
  graph_tensors = {
      "train_loss": train_loss,
      "valid_loss": valid_loss,
      "test_loss": test_loss,
      "train_generated_string": train_generated_string,
      "reduce_learning_rate": reduce_learning_rate,
      "global_step": global_step,
      "train_step": train_step
  }

  # Return dataset_train for translation to human readable text.
  return graph_tensors, dataset_train
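
build_graph only constructs the graph; the returned graph_tensors dictionary is meant to be consumed by a separate training loop, and dataset_train is returned so sampled indices can be turned back into text. A minimal sketch of such a loop is below, modeled on the monitored-session loop in Example #1; the function name run_training, the checkpoint settings, and the iteration counts are illustrative assumptions.

def run_training(num_training_iterations=10000, report_interval=100,
                 reduce_learning_rate_interval=5000,
                 checkpoint_dir="/tmp/tf/rnn_shakespeare",
                 checkpoint_interval=500):
  """Builds the graph and runs a training loop over it (illustrative sketch)."""
  graph_tensors, dataset_train = build_graph()

  saver = tf.train.Saver()
  hooks = [
      tf.train.CheckpointSaverHook(
          checkpoint_dir=checkpoint_dir,
          save_steps=checkpoint_interval,
          saver=saver)
  ]

  with tf.train.SingularMonitoredSession(
      hooks=hooks, checkpoint_dir=checkpoint_dir) as sess:

    start_iteration = sess.run(graph_tensors["global_step"])

    for train_iteration in range(start_iteration, num_training_iterations):
      # One optimization step; fetch the training loss alongside it.
      train_loss_v, _ = sess.run(
          (graph_tensors["train_loss"], graph_tensors["train_step"]))

      if (train_iteration + 1) % report_interval == 0:
        # Periodically report validation loss and a generated sample.
        valid_loss_v, sample_v = sess.run(
            (graph_tensors["valid_loss"],
             graph_tensors["train_generated_string"]))
        sample_human = dataset_train.to_human_readable(
            (sample_v, 0), indices=[0])
        tf.logging.info("%d: Training loss %f. Validation loss %f. Sample = %s",
                        train_iteration, train_loss_v, valid_loss_v,
                        sample_human)
      else:
        tf.logging.info("%d: Training loss %f.", train_iteration, train_loss_v)

      if (train_iteration + 1) % reduce_learning_rate_interval == 0:
        # Anneal the learning rate on a fixed schedule.
        sess.run(graph_tensors["reduce_learning_rate"])
        tf.logging.info("Reducing learning rate.")

    test_loss_v = sess.run(graph_tensors["test_loss"])
    tf.logging.info("Test loss %f", test_loss_v)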