Exemplo n.º 1
0
def fit(model, dataset, optimizer, verbose=False, logdir=None):
  """Train the linear-regression model over the batches of a dataset.

  Every (xs, ys) batch drawn from `dataset` yields one loss/gradient
  evaluation whose gradients are applied through `optimizer`.

  Args:
    model: The LinearModel to fit.
    dataset: The tf.data.Dataset supplying the (xs, ys) training batches.
    optimizer: The TensorFlow Optimizer object used to apply the gradients.
    verbose: If true, prints the loss value at every iteration.
    logdir: Optional directory in which TensorBoard summaries are written.
      When falsy, no summary writer is created and nothing is recorded.
  """

  def batch_loss(xs, ys):
    # Objective being minimized: mean-square loss of `model` on one batch.
    return mean_square_loss(model, xs, ys)

  value_and_grads_fn = tfe.implicit_value_and_gradients(batch_loss)

  # TensorBoard support. Once training has started, use:
  #   tensorboard --logdir=<logdir>
  # The writer is only created when a log directory was requested.
  summary_writer = (
      tf.contrib.summary.create_file_writer(logdir) if logdir else None)

  # Training loop: one optimizer update per batch.
  for step, (xs, ys) in enumerate(tfe.Iterator(dataset)):
    loss, grads = value_and_grads_fn(xs, ys)
    if verbose:
      print("Iteration %d: loss = %s" % (step, loss.numpy()))

    optimizer.apply_gradients(grads)

    if not logdir:
      continue

    # Record the loss and the step counter for this iteration.
    with summary_writer.as_default(), \
        tf.contrib.summary.always_record_summaries():
      tf.contrib.summary.scalar("loss", loss, step=step)
      tf.contrib.summary.scalar("step", step, step=step)
  def testSyntheticDataset(self):
    """Checks batch shapes, dtypes and exhaustion of the synthetic dataset."""
    batch_size = 10
    num_batches = 2
    noise_level = 0.
    true_w = tf.random_uniform([3, 1])
    true_b = [1.0]

    dataset = linear_regression.synthetic_dataset(
        true_w, true_b, noise_level, batch_size, num_batches)
    iterator = tfe.Iterator(dataset)

    # Every one of the requested batches must have the expected layout.
    expected_shapes = ((batch_size, 3), (batch_size, 1))
    for _ in range(num_batches):
      features, labels = iterator.next()
      batch = (features, labels)
      for expected_shape, tensor in zip(expected_shapes, batch):
        self.assertEqual(expected_shape, tensor.shape)
      for tensor in batch:
        self.assertEqual(tf.float32, tensor.dtype)

    # After `num_batches` batches the iterator has to be exhausted.
    with self.assertRaises(StopIteration):
      iterator.next()
Exemplo n.º 3
0
 def make_iterator(tensors):
     """Return a tfe.Iterator over a repeated dataset built from `tensors`.

     The dataset is created with `Dataset.from_tensors(tensors)` followed by
     `repeat()` (no count given), inside a CPU device scope; the iterator
     itself is constructed after that scope has been left.
     """
     with tf.device('/device:CPU:0'):
         single_element = tf.data.Dataset.from_tensors(tensors)
         repeated = single_element.repeat()
     iterator = tfe.Iterator(repeated)
     return iterator