def testRecoverSession(self):
    # Create a checkpoint.
    checkpoint_dir = os.path.join(self.get_temp_dir(), "recover_session")
    try:
      gfile.DeleteRecursively(checkpoint_dir)
    except errors.OpError:
      pass  # Ignore
    gfile.MakeDirs(checkpoint_dir)

    with ops.Graph().as_default():
      v = variables.Variable(1, name="v")
      sm = session_manager.SessionManager(
          ready_op=variables.report_uninitialized_variables())
      saver = saver_lib.Saver({"v": v})
      sess, initialized = sm.recover_session(
          "", saver=saver, checkpoint_dir=checkpoint_dir)
      self.assertFalse(initialized)
      sess.run(v.initializer)
      self.assertEqual(1, sess.run(v))
      saver.save(sess,
                 os.path.join(checkpoint_dir, "recover_session_checkpoint"))
    self._test_recovered_variable(checkpoint_dir=checkpoint_dir)
    self._test_recovered_variable(
        checkpoint_filename_with_path=checkpoint_management.latest_checkpoint(
            checkpoint_dir))
    # Cannot set both checkpoint_dir and checkpoint_filename_with_path.
    with self.assertRaises(ValueError):
      self._test_recovered_variable(
          checkpoint_dir=checkpoint_dir,
          checkpoint_filename_with_path=checkpoint_management.latest_checkpoint(
              checkpoint_dir))
  def export_fn(estimator, export_dir_base, checkpoint_path, eval_result=None):
    """Exports the given Estimator as a SavedModel.

    Args:
      estimator: the Estimator to export.
      export_dir_base: A string containing a directory to write the exported
        graph and checkpoints.
      checkpoint_path: The checkpoint path to export.  If None (the default),
        the most recent checkpoint found within the model directory is chosen.
      eval_result: placeholder argument matching the call signature expected
        by ExportStrategy.

    Returns:
      The string path to the exported directory.
    """
    if not checkpoint_path:
      # TODO(b/67425018): switch to
      #    checkpoint_path = estimator.latest_checkpoint()
      #  as soon as contrib is cleaned up and we can thus be sure that
      #  estimator is a tf.estimator.Estimator and not a
      #  tf.contrib.learn.Estimator
      checkpoint_path = checkpoint_management.latest_checkpoint(
          estimator.model_dir)
    export_checkpoint_path, export_eval_result = best_model_selector.update(
        checkpoint_path, eval_result)

    if export_checkpoint_path and export_eval_result is not None:
      checkpoint_base = os.path.basename(export_checkpoint_path)
      export_dir = os.path.join(export_dir_base, checkpoint_base)
      return best_model_export_strategy.export(
          estimator, export_dir, export_checkpoint_path, export_eval_result)
    else:
      return ''
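A hedged usage sketch (not part of the original snippet): export_fn closes over best_model_selector and best_model_export_strategy in its enclosing factory and is typically handed to a contrib.learn ExportStrategy. The import alias below is an assumption.

from tensorflow.contrib.learn.python.learn import export_strategy as export_strategy_lib

# Wrap the closure so an Experiment (or similar driver) can call it with
# (estimator, export_dir_base, checkpoint_path, eval_result) after evaluation.
best_model_strategy = export_strategy_lib.ExportStrategy(
    name='best_model', export_fn=export_fn)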
  def testGraphDistributionStrategy(self):
    self.skipTest("b/121381184")
    num_training_steps = 10
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

    def _train_fn(optimizer, model):
      input_value = constant_op.constant([[3.]])
      return optimizer.minimize(
          functools.partial(model, input_value),
          global_step=root.optimizer_step)

    for training_continuation in range(3):
      with ops.Graph().as_default():
        strategy = mirrored_strategy.MirroredStrategy()
        with strategy.scope():
          model = MyModel()
          optimizer = adam.AdamOptimizer(0.001)
          root = checkpointable_utils.Checkpoint(
              optimizer=optimizer, model=model,
              optimizer_step=training_util.get_or_create_global_step())
          status = root.restore(checkpoint_management.latest_checkpoint(
              checkpoint_directory))
          train_op = strategy.extended.call_for_each_replica(
              functools.partial(_train_fn, optimizer, model))
          with self.session() as session:
            if training_continuation > 0:
              status.assert_consumed()
            status.initialize_or_restore()
            for _ in range(num_training_steps):
              session.run(train_op)
            root.save(file_prefix=checkpoint_prefix)
        self.assertEqual((training_continuation + 1) * num_training_steps,
                         root.optimizer_step.numpy())
  def testEagerTPUDistributionStrategy(self):
    self.skipTest("b/121387144")
    num_training_steps = 10
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

    def _train_fn(optimizer, model):
      input_value = constant_op.constant([[3.]])
      optimizer.minimize(
          functools.partial(model, input_value),
          global_step=root.optimizer_step)

    for training_continuation in range(3):
      strategy = tpu_strategy.TPUStrategy()
      with strategy.scope():
        model = Subclassed()
        optimizer = adam_v1.AdamOptimizer(0.001)
        root = checkpointable_utils.Checkpoint(
            optimizer=optimizer, model=model,
            optimizer_step=training_util.get_or_create_global_step())
        root.restore(checkpoint_management.latest_checkpoint(
            checkpoint_directory))

        for _ in range(num_training_steps):
          strategy.extended.call_for_each_replica(
              functools.partial(_train_fn, optimizer, model))
        root.save(file_prefix=checkpoint_prefix)
        self.assertEqual((training_continuation + 1) * num_training_steps,
                         root.optimizer_step.numpy())
 def testUsageGraph(self):
   """Expected usage when graph building."""
   with context.graph_mode():
     num_training_steps = 10
     checkpoint_directory = self.get_temp_dir()
     checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
     for training_continuation in range(3):
       with ops.Graph().as_default():
         model = MyModel()
         optimizer = adam.AdamOptimizer(0.001)
         root = util.Checkpoint(
             optimizer=optimizer, model=model,
             global_step=training_util.get_or_create_global_step())
         input_value = constant_op.constant([[3.]])
         train_op = optimizer.minimize(
             model(input_value),
             global_step=root.global_step)
         checkpoint_path = checkpoint_management.latest_checkpoint(
             checkpoint_directory)
         with self.session(graph=ops.get_default_graph()) as session:
           status = root.restore(save_path=checkpoint_path)
           status.initialize_or_restore(session=session)
           if checkpoint_path is None:
             self.assertEqual(0, training_continuation)
             with self.assertRaises(AssertionError):
               status.assert_consumed()
           else:
             status.assert_consumed()
           for _ in range(num_training_steps):
             session.run(train_op)
           root.save(file_prefix=checkpoint_prefix, session=session)
           self.assertEqual((training_continuation + 1) * num_training_steps,
                            session.run(root.global_step))
           self.assertEqual(training_continuation + 1,
                            session.run(root.save_counter))
 def testAgnosticUsage(self):
   """Graph/eager agnostic usage."""
   # Does create garbage when executing eagerly due to ops.Graph() creation.
   num_training_steps = 10
   checkpoint_directory = self.get_temp_dir()
   checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
   for training_continuation in range(3):
     with ops.Graph().as_default(), self.test_session(
         graph=ops.get_default_graph()), test_util.device(use_gpu=True):
       model = MyModel()
       optimizer = adam.AdamOptimizer(0.001)
       root = util.Checkpoint(
           optimizer=optimizer, model=model,
           global_step=training_util.get_or_create_global_step())
       checkpoint_path = checkpoint_management.latest_checkpoint(
           checkpoint_directory)
       status = root.restore(save_path=checkpoint_path)
       input_value = constant_op.constant([[3.]])
       train_fn = functools.partial(
           optimizer.minimize,
           functools.partial(model, input_value),
           global_step=root.global_step)
       if not context.executing_eagerly():
         train_fn = functools.partial(self.evaluate, train_fn())
       status.initialize_or_restore()
       for _ in range(num_training_steps):
         train_fn()
       root.save(file_prefix=checkpoint_prefix)
       self.assertEqual((training_continuation + 1) * num_training_steps,
                        self.evaluate(root.global_step))
       self.assertEqual(training_continuation + 1,
                        self.evaluate(root.save_counter))
Example 7
def wait_for_new_checkpoint(checkpoint_dir,
                            last_checkpoint=None,
                            seconds_to_sleep=1,
                            timeout=None):
  """Waits until a new checkpoint file is found.

  Args:
    checkpoint_dir: The directory in which checkpoints are saved.
    last_checkpoint: The last checkpoint path used or `None` if we're expecting
      a checkpoint for the first time.
    seconds_to_sleep: The number of seconds to sleep for before looking for a
      new checkpoint.
    timeout: The maximum number of seconds to wait. If left as `None`, then the
      process will wait indefinitely.

  Returns:
    a new checkpoint path, or None if the timeout was reached.
  """
  logging.info('Waiting for new checkpoint at %s', checkpoint_dir)
  stop_time = time.time() + timeout if timeout is not None else None
  while True:
    checkpoint_path = checkpoint_management.latest_checkpoint(checkpoint_dir)
    if checkpoint_path is None or checkpoint_path == last_checkpoint:
      if stop_time is not None and time.time() + seconds_to_sleep > stop_time:
        return None
      time.sleep(seconds_to_sleep)
    else:
      logging.info('Found new checkpoint at %s', checkpoint_path)
      return checkpoint_path
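A minimal polling-loop sketch (illustrative, not from the source); run_eval and the directory path are placeholders.

last_checkpoint = None
while True:
  last_checkpoint = wait_for_new_checkpoint(
      '/tmp/train_dir', last_checkpoint, seconds_to_sleep=5, timeout=600)
  if last_checkpoint is None:
    break  # Timed out: no new checkpoint appeared within 10 minutes.
  run_eval(last_checkpoint)  # Hypothetical evaluation callback.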
  def _new_layer_weight_loading_test_template(
      self, first_model_fn, second_model_fn, restore_init_fn):
    with self.cached_session() as session:
      model = first_model_fn()
      temp_dir = self.get_temp_dir()
      prefix = os.path.join(temp_dir, 'ckpt')

      x = constant_op.constant(np.random.random((3, 2)), dtype=dtypes.float32)
      executing_eagerly = context.executing_eagerly()
      ref_y_tensor = model(x)
      if not executing_eagerly:
        session.run([v.initializer for v in model.variables])
      ref_y = self.evaluate(ref_y_tensor)
      model.save_weights(prefix)
      self.assertEqual(
          prefix,
          checkpoint_management.latest_checkpoint(temp_dir))
      for v in model.variables:
        self.evaluate(
            v.assign(random_ops.random_normal(shape=array_ops.shape(v))))

      self.addCleanup(shutil.rmtree, temp_dir)

      second_model = second_model_fn()
      second_model.load_weights(prefix)
      second_model(x)
      self.evaluate(restore_init_fn(second_model))
      second_model.save_weights(prefix)
      # Check that the second model's checkpoint loads into the original model
      model.load_weights(prefix)
      y = self.evaluate(model(x))
      self.assertAllClose(ref_y, y)
Example 9
  def _restore_or_save_initial_ckpt(self, session):
    # Ideally this should be run in after_create_session but is not for the
    # following reason:
    # Currently there is no way of enforcing an order of running the
    # `SessionRunHooks`. Hence it is possible that the `_DatasetInitializerHook`
    # is run *after* this hook. That is troublesome because
    # 1. If a checkpoint exists and this hook restores it, the initializer hook
    #    will override it.
    # 2. If no checkpoint exists, this hook will try to save an uninitialized
    #    iterator which will result in an exception.
    #
    # As a temporary fix we enter the following implicit contract between this
    # hook and the _DatasetInitializerHook.
    # 1. The _DatasetInitializerHook initializes the iterator in the call to
    #    after_create_session.
    # 2. This hook saves the iterator on the first call to `before_run()`, which
    #    is guaranteed to happen after `after_create_session()` of all hooks
    #    have been run.

    # Check if there is an existing checkpoint. If so, restore from it.
    # pylint: disable=protected-access
    latest_checkpoint_path = checkpoint_management.latest_checkpoint(
        self._checkpoint_saver_hook._checkpoint_dir,
        latest_filename=self._latest_filename)
    if latest_checkpoint_path:
      self._checkpoint_saver_hook._get_saver().restore(session,
                                                       latest_checkpoint_path)
    else:
      # The checkpoint saved here is the state at step "global_step".
      # Note: We do not save the GraphDef or MetaGraphDef here.
      global_step = session.run(self._checkpoint_saver_hook._global_step_tensor)
      self._checkpoint_saver_hook._save(session, global_step)
      self._checkpoint_saver_hook._timer.update_last_triggered_step(global_step)
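A hedged sketch of attaching the surrounding hook, assuming it is tf.data.experimental.CheckpointInputPipelineHook; the estimator, model_fn and input_fn names are illustrative.

import tensorflow as tf

est = tf.estimator.Estimator(model_fn=my_model_fn, model_dir='/tmp/model_dir')
# The estimator registers the iterator initializer hook itself; the input
# pipeline checkpointing hook below relies on the before_run /
# after_create_session contract described in the comment above.
ckpt_hook = tf.data.experimental.CheckpointInputPipelineHook(est)
est.train(input_fn=my_input_fn, hooks=[ckpt_hook], steps=100)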
Example 10
 def _read_vars(self, model_dir):
   """Returns (global_step, latest_feature)."""
   with ops.Graph().as_default() as g:
     ckpt_path = checkpoint_management.latest_checkpoint(model_dir)
     meta_filename = ckpt_path + '.meta'
     saver_lib.import_meta_graph(meta_filename)
     saver = saver_lib.Saver()
     with self.test_session(graph=g) as sess:
       saver.restore(sess, ckpt_path)
       return sess.run(ops.get_collection('my_vars'))
Example 11
 def testRestoreInReconstructedIterator(self):
   checkpoint_directory = self.get_temp_dir()
   checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
   dataset = Dataset.range(10)
   for i in range(5):
     iterator = datasets.Iterator(dataset)
     checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
     checkpoint.restore(checkpoint_management.latest_checkpoint(
         checkpoint_directory))
     for j in range(2):
       self.assertEqual(i * 2 + j, iterator.get_next().numpy())
     checkpoint.save(file_prefix=checkpoint_prefix)
 def testRestoreInReconstructedIteratorInitializable(self):
   checkpoint_directory = self.get_temp_dir()
   checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
   dataset = dataset_ops.Dataset.range(10)
   iterator = dataset.make_initializable_iterator()
   get_next = iterator.get_next()
   checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
   for i in range(5):
     with self.cached_session() as sess:
       checkpoint.restore(checkpoint_management.latest_checkpoint(
           checkpoint_directory)).initialize_or_restore(sess)
       for j in range(2):
         self.assertEqual(i * 2 + j, self.evaluate(get_next))
       checkpoint.save(file_prefix=checkpoint_prefix)
 def testRestoreInReconstructedIteratorInitializable(self):
   checkpoint_directory = self.get_temp_dir()
   checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
   dataset = dataset_ops.Dataset.range(10)
   iterator = iter(dataset) if context.executing_eagerly(
   ) else dataset_ops.make_initializable_iterator(dataset)
   get_next = iterator.get_next
   checkpoint = trackable_utils.Checkpoint(iterator=iterator)
   for i in range(5):
     checkpoint.restore(
         checkpoint_management.latest_checkpoint(
             checkpoint_directory)).initialize_or_restore()
     for j in range(2):
       self.assertEqual(i * 2 + j, self.evaluate(get_next()))
     checkpoint.save(file_prefix=checkpoint_prefix)
  def testNameCollision(self):
    # Make sure we have a clean directory to work in.
    with self.tempDir() as tempdir:
      # Jump to that directory until this test is done.
      with self.tempWorkingDir(tempdir):
        # Save training snapshots to a relative path.
        traindir = "train/"
        os.mkdir(traindir)
        # Collides with the default name of the checkpoint state file.
        filepath = os.path.join(traindir, "checkpoint")

        with self.test_session() as sess:
          unused_a = variables.Variable(0.0)  # So that Saver saves something.
          variables.global_variables_initializer().run()

          # Should fail.
          saver = saver_module.Saver(sharded=False)
          with self.assertRaisesRegexp(ValueError, "collides with"):
            saver.save(sess, filepath)

          # Succeeds: the file will be named "checkpoint-<step>".
          saver.save(sess, filepath, global_step=1)
          self.assertIsNotNone(
              checkpoint_management.latest_checkpoint(traindir))

          # Succeeds: the file will be named "checkpoint-<i>-of-<n>".
          saver = saver_module.Saver(sharded=True)
          saver.save(sess, filepath)
          self.assertIsNotNone(
              checkpoint_management.latest_checkpoint(traindir))

          # Succeeds: the file will be named "checkpoint-<step>-<i>-of-<n>".
          saver = saver_module.Saver(sharded=True)
          saver.save(sess, filepath, global_step=1)
          self.assertIsNotNone(
              checkpoint_management.latest_checkpoint(traindir))
  def _GetGraphDef(self, use_trt, max_batch_size, model_dir):
    """Get the frozen mnist GraphDef.

    Args:
      use_trt: whether use TF-TRT to convert the graph.
      max_batch_size: the max batch size to apply during TF-TRT conversion.
      model_dir: the model directory to load the checkpoints.

    Returns:
      The frozen mnist GraphDef.
    """
    graph = ops.Graph()
    with self.session(graph=graph) as sess:
      with graph.device('/GPU:0'):
        x = array_ops.placeholder(
            shape=(None, 28, 28, 1), dtype=dtypes.float32, name=INPUT_NODE_NAME)
        self._BuildGraph(x)
      # Load weights
      mnist_saver = saver.Saver()
      checkpoint_file = latest_checkpoint(model_dir)
      mnist_saver.restore(sess, checkpoint_file)
      # Freeze
      graph_def = graph_util.convert_variables_to_constants(
          sess, sess.graph_def, output_node_names=[OUTPUT_NODE_NAME])
    # Convert with TF-TRT
    if use_trt:
      logging.info('Number of nodes before TF-TRT conversion: %d',
                   len(graph_def.node))
      converter = trt_convert.TrtGraphConverter(
          input_graph_def=graph_def,
          nodes_blacklist=[OUTPUT_NODE_NAME],
          max_batch_size=max_batch_size,
          precision_mode='INT8',
          # There is a 2GB GPU memory limit for each test, so we set
          # max_workspace_size_bytes to 256MB to leave enough room for TF
          # runtime to allocate GPU memory.
          max_workspace_size_bytes=1 << 28,
          minimum_segment_size=2,
          use_calibration=False,
          use_function_backup=False)
      graph_def = converter.convert()
      logging.info('Number of nodes after TF-TRT conversion: %d',
                   len(graph_def.node))
      num_engines = len(
          [1 for n in graph_def.node if str(n.op) == 'TRTEngineOp'])
      self.assertEqual(1, num_engines)
    return graph_def
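An illustrative pair of calls from within the test class (model_dir is a placeholder), covering both conversion modes exercised above.

native_gdef = self._GetGraphDef(use_trt=False, max_batch_size=128,
                                model_dir='/tmp/mnist_ckpts')
trt_gdef = self._GetGraphDef(use_trt=True, max_batch_size=128,
                             model_dir='/tmp/mnist_ckpts')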
  def __init__(self,
               estimator,
               prediction_input_fn,
               input_alternative_key=None,
               output_alternative_key=None,
               graph=None,
               config=None):
    """Initialize a `ContribEstimatorPredictor`.

    Args:
      estimator: an instance of `tf.contrib.learn.Estimator`.
      prediction_input_fn: a function that takes no arguments and returns an
        instance of `InputFnOps`.
      input_alternative_key: Optional. Specify the input alternative used for
        prediction.
      output_alternative_key: Specify the output alternative used for
        prediction. Not needed for single-headed models but required for
        multi-headed models.
      graph: Optional. The TensorFlow `graph` in which prediction should be
        done.
      config: `ConfigProto` proto used to configure the session.
    """
    self._graph = graph or ops.Graph()
    with self._graph.as_default():
      input_fn_ops = prediction_input_fn()
      # pylint: disable=protected-access
      model_fn_ops = estimator._get_predict_ops(input_fn_ops.features)
      # pylint: enable=protected-access
      checkpoint_path = checkpoint_management.latest_checkpoint(
          estimator.model_dir)
      self._session = monitored_session.MonitoredSession(
          session_creator=monitored_session.ChiefSessionCreator(
              config=config,
              checkpoint_filename_with_path=checkpoint_path))

    input_alternative_key = (
        input_alternative_key or
        saved_model_export_utils.DEFAULT_INPUT_ALTERNATIVE_KEY)
    input_alternatives, _ = saved_model_export_utils.get_input_alternatives(
        input_fn_ops)
    self._feed_tensors = input_alternatives[input_alternative_key]

    (output_alternatives,
     output_alternative_key) = saved_model_export_utils.get_output_alternatives(
         model_fn_ops, output_alternative_key)
    _, fetch_tensors = output_alternatives[output_alternative_key]
    self._fetch_tensors = fetch_tensors
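A hedged usage sketch, assuming the Predictor base class supplies the usual dict-in/dict-out __call__; the feed key depends on the model's input alternative and is illustrative.

predictor = ContribEstimatorPredictor(
    estimator=my_contrib_estimator,
    prediction_input_fn=my_prediction_input_fn)
outputs = predictor({'examples': serialized_example_batch})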
  def testCheckpointExists(self):
    for sharded in (False, True):
      for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
        with self.session(graph=ops_lib.Graph()) as sess:
          unused_v = variables.Variable(1.0, name="v")
          variables.global_variables_initializer().run()
          saver = saver_module.Saver(sharded=sharded, write_version=version)

          path = os.path.join(self._base_dir, "%s-%s" % (sharded, version))
          self.assertFalse(
              checkpoint_management.checkpoint_exists(path))  # Not saved yet.

          ckpt_prefix = saver.save(sess, path)
          self.assertTrue(checkpoint_management.checkpoint_exists(ckpt_prefix))

          ckpt_prefix = checkpoint_management.latest_checkpoint(self._base_dir)
          self.assertTrue(checkpoint_management.checkpoint_exists(ckpt_prefix))
  def testRelativePath(self):
    # Make sure we have a clean directory to work in.
    with self.tempDir() as tempdir:

      # Jump to that directory until this test is done.
      with self.tempWorkingDir(tempdir):

        # Save training snapshots to a relative path.
        traindir = "train/"
        os.mkdir(traindir)

        filename = "snapshot"
        filepath = os.path.join(traindir, filename)

        with self.test_session() as sess:
          # Build a simple graph.
          v0 = variables.Variable(0.0)
          inc = v0.assign_add(1.0)

          save = saver_module.Saver({"v0": v0})

          # Record a short training history.
          variables.global_variables_initializer().run()
          save.save(sess, filepath, global_step=0)
          inc.eval()
          save.save(sess, filepath, global_step=1)
          inc.eval()
          save.save(sess, filepath, global_step=2)

        with self.test_session() as sess:
          # Build a new graph with different initialization.
          v0 = variables.Variable(-1.0)

          # Create a new saver.
          save = saver_module.Saver({"v0": v0})
          variables.global_variables_initializer().run()

          # Get the most recent checkpoint name from the training history file.
          name = checkpoint_management.latest_checkpoint(traindir)
          self.assertIsNotNone(name)

          # Restore "v0" from that checkpoint.
          save.restore(sess, name)
          self.assertEqual(v0.eval(), 2.0)
Example 19
  def testTrainSpinn(self):
    """Test with fake toy SNLI data and GloVe vectors."""

    # 1. Create and load a fake SNLI data file and a fake GloVe embedding file.
    snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
    fake_train_file = self._create_test_data(snli_1_0_dir)

    vocab = data.load_vocabulary(self._temp_data_dir)
    word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab)

    train_data = data.SnliData(fake_train_file, word2index)
    dev_data = data.SnliData(fake_train_file, word2index)
    test_data = data.SnliData(fake_train_file, word2index)

    # 2. Create a fake config.
    config = _test_spinn_config(
        data.WORD_VECTOR_LEN, 4,
        logdir=os.path.join(self._temp_data_dir, "logdir"))

    # 3. Test training of a SPINN model.
    trainer = spinn.train_or_infer_spinn(
        embed, word2index, train_data, dev_data, test_data, config)

    # 4. Load train loss values from the summary files and verify that they
    #    decrease with training.
    summary_file = glob.glob(os.path.join(config.logdir, "events.out.*"))[0]
    events = summary_test_util.events_from_file(summary_file)
    train_losses = [event.summary.value[0].simple_value for event in events
                    if event.summary.value
                    and event.summary.value[0].tag == "train/loss"]
    self.assertEqual(config.epochs, len(train_losses))

    # 5. Verify that checkpoints exist and contain all the expected variables.
    self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*")))
    object_graph = checkpointable_utils.object_metadata(
        checkpoint_management.latest_checkpoint(config.logdir))
    ckpt_variable_names = set()
    for node in object_graph.nodes:
      for attribute in node.attributes:
        ckpt_variable_names.add(attribute.full_name)
    self.assertIn("global_step", ckpt_variable_names)
    for v in trainer.variables:
      variable_name = v.name[:v.name.index(":")] if ":" in v.name else v.name
      self.assertIn(variable_name, ckpt_variable_names)
Example 20
  def after_save(self, session, global_step_value):
    """Evaluates and exports the model after a checkpoint is created."""
    # Load and cache the path of the most recent checkpoint to avoid duplicate
    # searches on GCS.
    logging.info("Checking for checkpoint in %s", self._model_dir)
    latest_path = checkpoint_management.latest_checkpoint(self._model_dir)

    if not latest_path:
      logging.warning("Skipping evaluation and export since model has not been "
                      "saved yet.")
    elif latest_path == self._latest_path:
      logging.warning("Skipping evaluation due to same latest checkpoint %s.",
                      latest_path)
    else:
      self._latest_path = latest_path
      self._eval_result = self._eval_fn(
          name="intermediate_export", checkpoint_path=latest_path)
      self._export_results = self._export_fn(
          self._eval_result, checkpoint_path=latest_path)
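A hedged sketch of wiring a listener like this into training through tf.train.CheckpointSaverHook; the listener class name and its constructor arguments are assumptions, not the source API.

listener = EvalAndExportListener(  # Hypothetical constructor for the class above.
    eval_fn=my_eval_fn, export_fn=my_export_fn, model_dir='/tmp/model_dir')
saver_hook = tf.train.CheckpointSaverHook(
    checkpoint_dir='/tmp/model_dir', save_steps=1000, listeners=[listener])
estimator.train(input_fn=my_input_fn, hooks=[saver_hook], max_steps=10000)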
 def testWithDefun(self):
   num_training_steps = 2
   checkpoint_directory = self.get_temp_dir()
   checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
   for training_continuation in range(3):
     with ops.Graph().as_default(), self.test_session(
         graph=ops.get_default_graph()), test_util.device(use_gpu=True):
       model = MyModel()
       # Don't actually train so we can test variable values
       optimizer = adam.AdamOptimizer(0.)
       root = util.Checkpoint(
           optimizer=optimizer, model=model,
           global_step=training_util.get_or_create_global_step())
       checkpoint_path = checkpoint_management.latest_checkpoint(
           checkpoint_directory)
       status = root.restore(save_path=checkpoint_path)
       def train_fn():
         @function.defun
         def _call_model(x):
           return model(x)
         with backprop.GradientTape() as tape:
           loss = _call_model(constant_op.constant([[3.]]))
         gradients = tape.gradient(loss, model.variables)
         return optimizer.apply_gradients(zip(gradients, model.variables),
                                          global_step=root.global_step)
       if not context.executing_eagerly():
         train_fn = functools.partial(
             self.evaluate, train_fn())
       status.initialize_or_restore()
       for _ in range(num_training_steps):
         train_fn()
       if training_continuation > 0:
         status.assert_consumed()
         self.assertAllClose([[42.]], self.evaluate(model.variables[0]))
       else:
         self.evaluate(model.variables[0].assign([[42.]]))
       root.save(file_prefix=checkpoint_prefix)
       self.assertEqual((training_continuation + 1) * num_training_steps,
                        self.evaluate(root.global_step))
       self.assertEqual(training_continuation + 1,
                        self.evaluate(root.save_counter))
Example 22
def _save_first_checkpoint(keras_model, custom_objects, config):
  """Save first checkpoint for the keras Estimator.

  Args:
    keras_model: an instance of compiled keras model.
    custom_objects: Dictionary for custom objects.
    config: Estimator config.

  Returns:
    The path where keras model checkpoint is saved.
  """
  # save checkpoint into subdirectory to allow warm start
  keras_model_dir = os.path.join(config.model_dir, 'keras')
  # Load weights and save to checkpoint if there is no checkpoint
  latest_path = checkpoint_management.latest_checkpoint(keras_model_dir)
  if not latest_path:
    keras_weights = None
    if _any_weight_initialized(keras_model):
      keras_weights = keras_model.get_weights()
    if not gfile.IsDirectory(keras_model_dir):
      gfile.MakeDirs(keras_model_dir)
    with ops.Graph().as_default():
      random_seed.set_random_seed(config.tf_random_seed)
      training_util.create_global_step()
      model = _clone_and_build_model(model_fn_lib.ModeKeys.TRAIN, keras_model,
                                     custom_objects)
      # save to checkpoint
      with session.Session(config=config.session_config) as sess:
        if keras_weights:
          model.set_weights(keras_weights)
        # Make update ops and initialize all variables.
        if not model.train_function:
          # pylint: disable=protected-access
          model._make_train_function()
          K._initialize_variables(sess)
          # pylint: enable=protected-access
        saver = saver_lib.Saver()
        latest_path = os.path.join(keras_model_dir, 'keras_model.ckpt')
        saver.save(sess, latest_path)
  return latest_path
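A hedged sketch of calling the helper directly; in practice it is invoked from keras model_to_estimator, and the compile/RunConfig setup below is illustrative.

import tensorflow as tf

keras_model.compile(optimizer='adam', loss='mse')
config = tf.estimator.RunConfig(model_dir='/tmp/keras_estimator',
                                tf_random_seed=42)
warm_start_path = _save_first_checkpoint(
    keras_model, custom_objects={}, config=config)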
 def testDeferredRestorationUsageEager(self):
   """An idiomatic eager execution example."""
   num_training_steps = 10
   checkpoint_directory = self.get_temp_dir()
   checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
   for training_continuation in range(3):
     model = MyModel()
     optimizer = adam.AdamOptimizer(0.001)
     root = util.Checkpoint(
         optimizer=optimizer, model=model,
         optimizer_step=training_util.get_or_create_global_step())
     root.restore(checkpoint_management.latest_checkpoint(
         checkpoint_directory))
     for _ in range(num_training_steps):
       # TODO(allenl): Use a Dataset and serialize/checkpoint it.
       input_value = constant_op.constant([[3.]])
       optimizer.minimize(
           lambda: model(input_value),  # pylint: disable=cell-var-from-loop
           global_step=root.optimizer_step)
     root.save(file_prefix=checkpoint_prefix)
     self.assertEqual((training_continuation + 1) * num_training_steps,
                      root.optimizer_step.numpy())
Example 24
 def end(self, session=None):
   super(ExportMonitor, self).end(session=session)
   latest_path = checkpoint_management.latest_checkpoint(
       self._estimator.model_dir)
   if latest_path is None:
     logging.info("Skipping export at the end since model has not been saved "
                  "yet.")
     return
   if isinstance(self._estimator, core_estimator.Estimator):
     raise ValueError(
         "ExportMonitor does not support `tf.estimator.Estimator`. "
         "Please pass an ExportStrategy to Experiment instead.")
   try:
     self._last_export_dir = self._estimator.export(
         self.export_dir,
         exports_to_keep=self.exports_to_keep,
         signature_fn=self.signature_fn,
         input_fn=self._input_fn,
         default_batch_size=self._default_batch_size,
         input_feature_key=self._input_feature_key,
         use_deprecated_input_fn=self._use_deprecated_input_fn)
   except RuntimeError:
     logging.info("Skipping exporting for the same step.")
 def testCustomNumbering(self):
   directory = self.get_temp_dir()
   step = variables.Variable(0, dtype=dtypes.int64)
   checkpoint = util.Checkpoint(step=step)
   manager = checkpoint_management.CheckpointManager(
       checkpoint, directory, max_to_keep=2)
   self.evaluate(step.initializer)
   for i in range(5):
     path = manager.save(checkpoint_number=step)
     expected_suffix = "-%d" % (2 * i,)
     if not path.endswith(expected_suffix):
       self.fail("%s should have suffix %s" % (path, expected_suffix))
     self.evaluate(step.assign_add(2))
   self.assertEqual(5, self.evaluate(checkpoint.save_counter))
   # Test regular integers
   last_path = manager.save(checkpoint_number=32)
   self.assertIn("-32", last_path)
   self.assertEqual(last_path, manager.latest_checkpoint)
   self.assertEqual(
       last_path, checkpoint_management.latest_checkpoint(directory))
   state = checkpoint_management.get_checkpoint_state(directory)
   # Only the most recent two checkpoints are saved
   self.assertEqual([path, last_path], state.all_model_checkpoint_paths)
 def _get_step(self):
   ckpt = checkpoint_management.latest_checkpoint(self._estimator.model_dir)
   if ckpt:
     return int(os.path.basename(ckpt).split('-')[1])
   else:
     return 0
 def _latest_ckpt(self):
   return checkpoint_management.latest_checkpoint(self.get_temp_dir())
def _get_checkpoint_filename(ckpt_dir_or_file):
  if gfile.IsDirectory(ckpt_dir_or_file):
    return checkpoint_management.latest_checkpoint(ckpt_dir_or_file)
  return ckpt_dir_or_file
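Two illustrative calls (paths are placeholders), showing the directory and explicit-prefix cases.

ckpt = _get_checkpoint_filename('/tmp/train_dir')                  # resolved via latest_checkpoint
ckpt = _get_checkpoint_filename('/tmp/train_dir/model.ckpt-1000')  # returned unchanged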
Example 29
  def every_n_step_end(self, step, outputs):
    super(ValidationMonitor, self).every_n_step_end(step, outputs)
    # TODO(mdan): The use of step below is probably misleading.
    # The code should probably use the step from the checkpoint, because
    # that's what is being evaluated.
    if self._estimator is None:
      raise ValueError("Missing call to set_estimator.")
    current_time = time.time()
    if (self._check_interval_secs is not None and
        self._last_checkpoint_check_time is not None and
        current_time - self._last_checkpoint_check_time <=
        self._check_interval_secs):
      logging.debug(
          "Skipping evaluation since less than %d seconds have passed since "
          "last check for a new checkpoint.", self._check_interval_secs)
      return False
    self._last_checkpoint_check_time = current_time
    # Check that we are not running evaluation on the same checkpoint.
    latest_path = checkpoint_management.latest_checkpoint(
        self._estimator.model_dir)
    if latest_path is None:
      logging.debug("Skipping evaluation since model has not been saved yet "
                    "at step %d.", step)
      return False
    if latest_path is not None and latest_path == self._latest_path:
      logging.debug("Skipping evaluation due to same checkpoint %s for step %d "
                    "as for step %d.", latest_path, step,
                    self._latest_path_step)
      return False
    self._latest_path = latest_path
    self._latest_path_step = step

    # Run evaluation and log it.
    validation_outputs = self._evaluate_estimator()
    stats = []
    for name in validation_outputs:
      stats.append("%s = %s" % (name, str(validation_outputs[name])))
    logging.info("Validation (step %d): %s", step, ", ".join(stats))

    # Early stopping logic.
    if self.early_stopping_rounds is not None:
      if self.early_stopping_metric not in validation_outputs:
        raise ValueError("Metric %s missing from outputs %s." %
                         (self.early_stopping_metric,
                          set(validation_outputs.keys())))
      current_value = validation_outputs[self.early_stopping_metric]
      if (self._best_value is None or (self.early_stopping_metric_minimize and
                                       (current_value < self._best_value)) or
          (not self.early_stopping_metric_minimize and
           (current_value > self._best_value))):
        self._best_value = current_value
        self._best_metrics = copy.deepcopy(validation_outputs)
        self._best_value_step = step
      stop_now = (step - self._best_value_step >= self.early_stopping_rounds)
      if stop_now:
        logging.info("Stopping. Best step: {} with {} = {}.".format(
            self._best_value_step, self.early_stopping_metric,
            self._best_value))
        self._early_stopped = True
        return True
    return False
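A hedged sketch of attaching this monitor to contrib.learn training; the constructor arguments and values below are illustrative assumptions.

monitor = ValidationMonitor(
    input_fn=my_eval_input_fn, every_n_steps=100,
    early_stopping_rounds=500, early_stopping_metric='loss',
    early_stopping_metric_minimize=True)
estimator.fit(input_fn=my_train_input_fn, steps=10000, monitors=[monitor])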
Example 30
    def test_initialize_if_not_restoring(self):
        checkpoint_directory = self.get_temp_dir()
        checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
        optimizer_only_prefix = os.path.join(checkpoint_directory, "opt")
        with test_util.device(use_gpu=True):
            model = MyModel()
            optimizer = adam.AdamOptimizer(0.001)
            root = trackable_utils.Checkpoint(
                model=model,  # Do not save the optimizer with the checkpoint.
                global_step=training_util.get_or_create_global_step())
            optimizer_checkpoint = trackable_utils.Checkpoint(
                optimizer=optimizer)

            checkpoint_path = checkpoint_management.latest_checkpoint(
                checkpoint_directory)
            status = root.restore(save_path=checkpoint_path)
            input_value = constant_op.constant([[3.]])
            train_fn = functools.partial(optimizer.minimize,
                                         functools.partial(model, input_value),
                                         global_step=root.global_step)
            if not context.executing_eagerly():
                train_fn = functools.partial(self.evaluate, train_fn())
            status.initialize_or_restore()
            self.evaluate([v.initializer for v in optimizer.variables()])
            train_fn()
            model_save_path = root.save(file_prefix=checkpoint_prefix)
            self.evaluate(optimizer.variables()[0].assign(42.))
            optimizer_save_path = optimizer_checkpoint.save(
                optimizer_only_prefix)

        # Restore into a graph with the optimizer
        with test_util.device(use_gpu=True):
            model = MyModel()
            optimizer = adam.AdamOptimizer(0.001)
            root = trackable_utils.Checkpoint(
                optimizer=optimizer,
                model=model,
                global_step=training_util.get_or_create_global_step())
            status = root.restore(save_path=model_save_path)
            input_value = constant_op.constant([[3.]])
            train_fn = functools.partial(optimizer.minimize,
                                         functools.partial(model, input_value),
                                         global_step=root.global_step)
            if not context.executing_eagerly():
                train_fn = functools.partial(self.evaluate, train_fn())
            status.initialize_or_restore()
            train_fn()
            with self.assertRaises(AssertionError):
                status.assert_existing_objects_matched()
            with self.assertRaises(AssertionError):
                status.assert_consumed()

        # Make sure initialization doesn't clobber later restores
        with test_util.device(use_gpu=True):
            model = MyModel()
            optimizer = adam.AdamOptimizer(0.001, beta1=1.0)
            root = trackable_utils.Checkpoint(
                optimizer=optimizer,
                model=model,
                global_step=training_util.get_or_create_global_step())
            opt_root = trackable_utils.Checkpoint(optimizer=optimizer)
            status = root.restore(save_path=model_save_path)
            init_only_optimizer_status = opt_root.restore(save_path=None)
            optimizer_status = opt_root.restore(save_path=optimizer_save_path)
            input_value = constant_op.constant([[3.]])
            train_fn = functools.partial(optimizer.minimize,
                                         functools.partial(model, input_value),
                                         global_step=root.global_step)
            if not context.executing_eagerly():
                train_fn = functools.partial(self.evaluate, train_fn())
            optimizer_status.run_restore_ops()
            status.initialize_or_restore()
            init_only_optimizer_status.initialize_or_restore()
            train_fn()
            self.assertEqual(42., self.evaluate(optimizer.variables()[0]))
    def proc_func(model_path, checkpoint_dir):
      global_batch_size = per_worker_batch_size * num_workers
      strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()
      with strategy.scope():
        multi_worker_model = build_and_compile_cnn_model()

      callbacks = [
          keras.callbacks.ModelCheckpoint(
              filepath=os.path.join(self.get_temp_dir(), 'checkpoint'))
      ]

      multi_worker_dataset = mnist_dataset(global_batch_size)
      if shard_policy:
        options = dataset_ops.Options()
        options.experimental_distribute.auto_shard_policy = shard_policy
        multi_worker_dataset = multi_worker_dataset.with_options(options)

      multi_worker_model.fit(
          multi_worker_dataset,
          epochs=2,
          steps_per_epoch=20,
          callbacks=callbacks)

      def _is_chief(task_type, task_id):
        return task_type is None or task_type == 'chief' or (
            task_type == 'worker' and task_id == 0)

      def _get_temp_dir(dirpath, task_id):
        base_dirpath = 'workertemp_' + str(task_id)
        temp_dir = os.path.join(dirpath, base_dirpath)
        file_io.recursive_create_dir_v2(temp_dir)
        return temp_dir

      def write_filepath(filepath, task_type, task_id):
        dirpath = os.path.dirname(filepath)
        base = os.path.basename(filepath)
        if not _is_chief(task_type, task_id):
          dirpath = _get_temp_dir(dirpath, task_id)
        return os.path.join(dirpath, base)

      task_type, task_id = (strategy.cluster_resolver.task_type,
                            strategy.cluster_resolver.task_id)
      write_model_path = write_filepath(model_path, task_type, task_id)

      multi_worker_model.save(write_model_path)
      if not _is_chief(task_type, task_id):
        file_io.delete_recursively_v2(os.path.dirname(write_model_path))

      # Make sure chief finishes saving before non-chief's assertions.
      multi_process_runner.barrier().wait()

      if not file_io.file_exists(model_path):
        raise RuntimeError()
      if file_io.file_exists(write_model_path) != _is_chief(task_type, task_id):
        raise RuntimeError()

      loaded_model = keras.saving.save.load_model(model_path)
      loaded_model.fit(multi_worker_dataset, epochs=2, steps_per_epoch=20)

      checkpoint = tracking_util.Checkpoint(model=multi_worker_model)
      write_checkpoint_dir = write_filepath(checkpoint_dir, task_type, task_id)
      checkpoint_manager = checkpoint_management.CheckpointManager(
          checkpoint, directory=write_checkpoint_dir, max_to_keep=1)

      checkpoint_manager.save()
      if not _is_chief(task_type, task_id):
        file_io.delete_recursively_v2(write_checkpoint_dir)

      # Make sure chief finishes saving before non-chief's assertions.
      multi_process_runner.barrier().wait()

      if not file_io.file_exists(checkpoint_dir):
        raise RuntimeError()
      if file_io.file_exists(write_checkpoint_dir) != _is_chief(
          task_type, task_id):
        raise RuntimeError()

      latest_checkpoint = checkpoint_management.latest_checkpoint(
          checkpoint_dir)
      checkpoint.restore(latest_checkpoint)
      multi_worker_model.fit(multi_worker_dataset, epochs=2, steps_per_epoch=20)

      logging.info('testMultiWorkerTutorial successfully ends')
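A hedged sketch of how the test might launch proc_func across workers with multi_process_runner; the cluster-spec helper and argument wiring are assumptions.

multi_process_runner.run(
    proc_func,
    cluster_spec=multi_worker_test_base.create_cluster_spec(
        num_workers=num_workers),
    args=(model_path, checkpoint_dir))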
Example 32
def _export_estimator(estimator,
                      export_dir,
                      signature_fn,
                      input_fn,
                      default_batch_size,
                      exports_to_keep,
                      input_feature_key=None,
                      use_deprecated_input_fn=True,
                      prediction_key=None,
                      checkpoint_path=None):
  if use_deprecated_input_fn:
    input_fn = input_fn or _default_input_fn
  elif input_fn is None:
    raise ValueError('input_fn must be defined.')

  # If checkpoint_path is specified, use the specified checkpoint path.
  checkpoint_path = (checkpoint_path or
                     checkpoint_management.latest_checkpoint(
                         estimator._model_dir))
  with ops.Graph().as_default() as g:
    training_util.create_global_step(g)

    if use_deprecated_input_fn:
      examples = array_ops.placeholder(dtype=dtypes.string,
                                       shape=[default_batch_size],
                                       name='input_example_tensor')
      features = input_fn(estimator, examples)
    else:
      features, _ = input_fn()
      examples = None
      if input_feature_key is not None:
        examples = features.pop(input_feature_key)

    if (not features) and (examples is None):
      raise ValueError('Either features or examples must be defined.')

    predictions = estimator._get_predict_ops(features).predictions

    if prediction_key is not None:
      predictions = predictions[prediction_key]

    # Explicit signature_fn takes priority
    if signature_fn:
      default_signature, named_graph_signatures = signature_fn(examples,
                                                               features,
                                                               predictions)
    else:
      try:
        # Some estimators provide a signature function.
        # TODO(zakaria): check if the estimator has this function,
        #   raise helpful error if not
        signature_fn = estimator._create_signature_fn()

        default_signature, named_graph_signatures = (
            signature_fn(examples, features, predictions))
      except AttributeError:
        logging.warn(
            'Change warning: `signature_fn` will be required after '
            '2016-08-01.\n'
            'Using generic signatures for now.  To maintain this behavior, '
            'pass:\n'
            '  signature_fn=export.generic_signature_fn\n'
            'Also consider passing a regression or classification signature; '
            'see cl/126430915 for an example.')
        default_signature, named_graph_signatures = generic_signature_fn(
            examples, features, predictions)
    if exports_to_keep is not None:
      exports_to_keep = gc.largest_export_versions(exports_to_keep)
    return _export_graph(
        g,
        _get_saver(),
        checkpoint_path,
        export_dir,
        default_graph_signature=default_signature,
        named_graph_signatures=named_graph_signatures,
        exports_to_keep=exports_to_keep)
Example 34
def _get_checkpoint_filename(filepattern):
  """Returns checkpoint filename given directory or specific filepattern."""
  if gfile.IsDirectory(filepattern):
    return checkpoint_management.latest_checkpoint(filepattern)
  return filepattern
Example 36
def _get_checkpoint_filename(ckpt_dir_or_file):
  """Returns checkpoint filename given directory or specific checkpoint file."""
  if gfile.IsDirectory(ckpt_dir_or_file):
    return checkpoint_management.latest_checkpoint(ckpt_dir_or_file)
  return ckpt_dir_or_file
Example 37
def main(argv):
  tf.logging.set_verbosity(tf.logging.INFO)

  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)


  if len(argv) > 1:
    FLAGS.predict_file = argv[1]

  validate_flags_or_throw(bert_config)
  tf.gfile.MakeDirs(FLAGS.output_dir)
  moran_tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case,
      use_moran=True)
  basic_tokenizer = tokenization.BasicTokenizer(use_moran=False)

  tpu_cluster_resolver = None
  if FLAGS.use_tpu and FLAGS.tpu_name:
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

  is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
  run_config = tf.contrib.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      master=FLAGS.master,
      model_dir=FLAGS.output_dir,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      keep_checkpoint_max=FLAGS.keep_checkpoint_max,
      tpu_config=tf.contrib.tpu.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          num_shards=FLAGS.num_tpu_cores,
          per_host_input_for_training=is_per_host))

  train_examples = None
  num_train_steps = None
  num_warmup_steps = None
  if FLAGS.do_train:
    train_examples = read_squad_examples(
        input_file=FLAGS.train_file, is_training=True)
    num_train_steps = int(
        len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    # Pre-shuffle the input to avoid having to make a very large shuffle
    # buffer in the `input_fn`.
    rng = random.Random(42)
    rng.shuffle(train_examples)

  model_fn = model_fn_builder(
      bert_config=bert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=num_train_steps,
      num_warmup_steps=num_warmup_steps,
      use_tpu=FLAGS.use_tpu,
      use_one_hot_embeddings=FLAGS.use_tpu)

  # If TPU is not available, this will fall back to normal Estimator on CPU
  # or GPU.
  estimator = tf.contrib.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=FLAGS.train_batch_size,
      predict_batch_size=FLAGS.predict_batch_size)


  if FLAGS.do_train:

    train_record_exists = False
    train_writer = FeatureWriter(
        filename=os.path.join(FLAGS.output_dir, "train.tf_record"),
        is_training=True,
        record_file_exists=train_record_exists)
    convert_examples_to_features(
        examples=train_examples,
        tokenizer=moran_tokenizer,
        max_seq_length=FLAGS.max_seq_length,
        doc_stride=FLAGS.doc_stride,
        max_query_length=FLAGS.max_query_length,
        is_training=True,
        output_fn=train_writer.process_feature)
    train_writer.close()

    tf.logging.info("***** Running training *****")
    tf.logging.info("  Num orig examples = %d", len(train_examples))
    tf.logging.info("  Num split examples = %d", train_writer.num_features)
    tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
    tf.logging.info("  Num steps = %d", num_train_steps)
    del train_examples

    train_input_fn = input_fn_builder(
        input_file=train_writer.filename,
        seq_length=FLAGS.max_seq_length,
        is_training=True,
        drop_remainder=True)
    estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

  if FLAGS.do_predict:
    output_prediction_file_name = "predictions.json"
    output_nbest_file_name = "nbest_predictions.json"
    output_null_log_odds_file_name = "null_odds.json"
    if FLAGS.korquad_refine_answer_by_pos:
        output_prediction_file_name = "predictions_pos.json"

    eval_examples = read_squad_examples(
        input_file=FLAGS.predict_file, is_training=False)

    eval_record_exists = os.path.exists(
        os.path.join(FLAGS.output_dir, "eval.tf_record"))
    # The flag is immediately overridden so the eval record is always rewritten.
    eval_record_exists = False
    if eval_record_exists:
      tf.logging.info("eval.tf_record exists. Do not write tf example file.")
    eval_writer = FeatureWriter(
        filename=os.path.join(FLAGS.output_dir, "eval.tf_record"),
        is_training=False,
        record_file_exists=eval_record_exists)
    eval_features = []

    def append_feature(feature):
      eval_features.append(feature)
      eval_writer.process_feature(feature)

    convert_examples_to_features(
        examples=eval_examples,
        tokenizer=moran_tokenizer,
        max_seq_length=FLAGS.max_seq_length,
        doc_stride=FLAGS.doc_stride,
        max_query_length=FLAGS.max_query_length,
        is_training=False,
        output_fn=append_feature)
    eval_writer.close()

    tf.logging.info("***** Running predictions *****")
    tf.logging.info("  Num orig examples = %d", len(eval_examples))
    tf.logging.info("  Num split examples = %d", len(eval_features))
    tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

    all_results = []

    predict_input_fn = input_fn_builder(
        input_file=eval_writer.filename,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=False)

    if FLAGS.do_train:
      init_checkpoint = None
    else:
      if FLAGS.init_checkpoint is not None and tf.gfile.IsDirectory(FLAGS.init_checkpoint):
        from tensorflow.python.training import checkpoint_management
        init_checkpoint = checkpoint_management.latest_checkpoint(FLAGS.init_checkpoint)
      else:
        init_checkpoint = FLAGS.init_checkpoint

    all_results = []
    for result in estimator.predict(
        predict_input_fn, yield_single_examples=True, checkpoint_path=init_checkpoint):
      unique_id = int(result["unique_ids"])
      start_logits = [float(x) for x in result["start_logits"].flat]
      end_logits = [float(x) for x in result["end_logits"].flat]
      all_results.append(
          RawResult(
              unique_id=unique_id,
              start_logits=start_logits,
              end_logits=end_logits))

    if len(argv) > 2:
      output_prediction_file = os.path.join(FLAGS.output_dir, argv[2])
    else:
      output_prediction_file = os.path.join(FLAGS.output_dir, output_prediction_file_name)
    output_nbest_file = os.path.join(FLAGS.output_dir, output_nbest_file_name)
    output_null_log_odds_file = os.path.join(FLAGS.output_dir, output_null_log_odds_file_name)

    write_predictions(eval_examples, eval_features, all_results,
                      FLAGS.n_best_size, FLAGS.max_answer_length,
                      FLAGS.do_lower_case, output_prediction_file,
                      output_nbest_file, output_null_log_odds_file, basic_tokenizer)
  def test_initialize_if_not_restoring(self):
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    optimizer_only_prefix = os.path.join(checkpoint_directory, "opt")
    with test_util.device(use_gpu=True):
      model = MyModel()
      optimizer = adam.AdamOptimizer(0.001)
      root = checkpointable_utils.Checkpoint(
          model=model,  # Do not save the optimizer with the checkpoint.
          global_step=training_util.get_or_create_global_step())
      optimizer_checkpoint = checkpointable_utils.Checkpoint(
          optimizer=optimizer)

      checkpoint_path = checkpoint_management.latest_checkpoint(
          checkpoint_directory)
      status = root.restore(save_path=checkpoint_path)
      input_value = constant_op.constant([[3.]])
      train_fn = functools.partial(
          optimizer.minimize,
          functools.partial(model, input_value),
          global_step=root.global_step)
      if not context.executing_eagerly():
        train_fn = functools.partial(self.evaluate, train_fn())
      status.initialize_or_restore()
      self.evaluate([v.initializer for v in optimizer.variables()])
      train_fn()
      model_save_path = root.save(file_prefix=checkpoint_prefix)
      self.evaluate(optimizer.variables()[0].assign(42.))
      optimizer_save_path = optimizer_checkpoint.save(optimizer_only_prefix)

    # Restore into a graph with the optimizer
    with test_util.device(use_gpu=True):
      model = MyModel()
      optimizer = adam.AdamOptimizer(0.001)
      root = checkpointable_utils.Checkpoint(
          optimizer=optimizer, model=model,
          global_step=training_util.get_or_create_global_step())
      status = root.restore(save_path=model_save_path)
      input_value = constant_op.constant([[3.]])
      train_fn = functools.partial(
          optimizer.minimize,
          functools.partial(model, input_value),
          global_step=root.global_step)
      if not context.executing_eagerly():
        train_fn = functools.partial(self.evaluate, train_fn())
      status.initialize_or_restore()
      train_fn()
      with self.assertRaises(AssertionError):
        status.assert_existing_objects_matched()
      with self.assertRaises(AssertionError):
        status.assert_consumed()

    # Make sure initialization doesn't clobber later restores
    with test_util.device(use_gpu=True):
      model = MyModel()
      optimizer = adam.AdamOptimizer(0.001, beta1=1.0)
      root = checkpointable_utils.Checkpoint(
          optimizer=optimizer, model=model,
          global_step=training_util.get_or_create_global_step())
      opt_root = checkpointable_utils.Checkpoint(
          optimizer=optimizer)
      status = root.restore(save_path=model_save_path)
      init_only_optimizer_status = opt_root.restore(save_path=None)
      optimizer_status = opt_root.restore(save_path=optimizer_save_path)
      input_value = constant_op.constant([[3.]])
      train_fn = functools.partial(
          optimizer.minimize,
          functools.partial(model, input_value),
          global_step=root.global_step)
      if not context.executing_eagerly():
        train_fn = functools.partial(self.evaluate, train_fn())
      optimizer_status.run_restore_ops()
      status.initialize_or_restore()
      init_only_optimizer_status.initialize_or_restore()
      train_fn()
      self.assertEqual(42., self.evaluate(optimizer.variables()[0]))
Example 39
from tensorflow.python.platform import gfile
from tensorflow.python.training import checkpoint_management


def _get_checkpoint_filename(ckpt_dir_or_file):
    """Returns the checkpoint filename given a directory or a specific checkpoint file."""
    if gfile.IsDirectory(ckpt_dir_or_file):
        return checkpoint_management.latest_checkpoint(ckpt_dir_or_file)
    return ckpt_dir_or_file
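
A minimal usage sketch for the helper above; the paths are hypothetical and only illustrate the two branches:

# Hypothetical paths, for illustration only.
print(_get_checkpoint_filename("/tmp/model_dir"))
# -> whatever latest_checkpoint finds, e.g. "/tmp/model_dir/model.ckpt-1000"
print(_get_checkpoint_filename("/tmp/model_dir/model.ckpt-500"))
# -> "/tmp/model_dir/model.ckpt-500" (returned unchanged)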
Example 40
  def continuous_train_and_eval(self, continuous_eval_predicate_fn=None):
    """Interleaves training and evaluation.

    The frequency of evaluation is controlled by the constructor argument
    `train_steps_per_iteration`. The model will first be trained for
    `train_steps_per_iteration` steps and then evaluated, in turns.

    This method is intended for single machine usage.

    This differs from `train_and_evaluate` as follows:

      1. The procedure alternates between training and evaluation. The model
      will be trained for a number of steps (usually smaller than `train_steps`
      if provided) and then be evaluated. `train_and_evaluate` will train the
      model for `train_steps` (no small training iterations).

      2. Due to the different approach this schedule takes, resource control
      differs in two ways. First, the resources (e.g., memory) used by training
      will be released before evaluation (`train_and_evaluate` takes double
      resources). Second, more checkpoints will be saved, as a checkpoint is
      generated at the end of each training iteration.

      3. As `estimator.train` starts from scratch (new graph, new input
      pipeline state, etc.) at each iteration, it is recommended to use a
      larger `train_steps_per_iteration`. It is also recommended to shuffle
      your input.

    Args:
      continuous_eval_predicate_fn: A predicate function determining whether to
        continue eval after each iteration. A `predicate_fn` has one of the
        following signatures:
          * (eval_results) -> boolean
          * (eval_results, checkpoint_path) -> boolean
        Where `eval_results` is the dictionary of metric evaluations and
        checkpoint_path is the path to the checkpoint containing the parameters
        on which that evaluation was based.
        At the beginning of evaluation, the passed `eval_results` and
        `checkpoint_path` will be None so it's expected that the predicate
        function handles that gracefully.
        When `predicate_fn` is not specified, continuous eval will run in an
        infinite loop (if `train_steps` is None) or exit once global step
        reaches `train_steps`.

    Returns:
      A tuple of the result of the `evaluate` call to the `Estimator` and the
      export results using the specified `ExportStrategy`.

    Raises:
      ValueError: if `continuous_eval_predicate_fn` is neither None nor
        callable.
    """

    if continuous_eval_predicate_fn is not None:
      if not callable(continuous_eval_predicate_fn):
        raise ValueError(
            "`continuous_eval_predicate_fn` must be a callable, or None.")
      predicate_fn = _get_standardized_predicate_fn(
          continuous_eval_predicate_fn)
    else:
      predicate_fn = None

    export_results = None
    latest_checkpoint = None
    eval_result = None

    # Set the default value for train_steps_per_iteration, which may be
    # overridden by the explicit settings below.
    train_steps_per_iteration = 1000
    if self._train_steps_per_iteration is not None:
      train_steps_per_iteration = self._train_steps_per_iteration
    elif self._train_steps is not None:
      train_steps_per_iteration = int(self._train_steps / 10)

    while (not predicate_fn or predicate_fn(
        eval_result, checkpoint_path=latest_checkpoint
        if eval_result else None)):

      if self._has_training_stopped(eval_result):
        # Exits once max steps of training is satisfied.
        logging.info("Stop training model as max steps reached")
        break

      logging.info("Training model for %s steps", train_steps_per_iteration)
      self._call_train(
          input_fn=self._train_input_fn,
          steps=train_steps_per_iteration,
          hooks=self._train_monitors,
          saving_listeners=self._saving_listeners)

      logging.info("Evaluating model now.")
      latest_checkpoint = checkpoint_management.latest_checkpoint(
          self._estimator.model_dir)
      eval_result = self._call_evaluate(
          input_fn=self._eval_input_fn,
          steps=self._eval_steps,
          metrics=self._eval_metrics,
          name="one_pass",
          checkpoint_path=latest_checkpoint,
          hooks=self._eval_hooks)
      export_results = self._maybe_export(eval_result)

    return eval_result, export_results
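
The docstring above accepts a one- or two-argument predicate. Below is a minimal sketch of a two-argument predicate that keeps iterating until a loss metric drops below a threshold; the `loss` key and the `experiment` object in the commented call are assumptions for illustration, not part of the snippet:

def stop_when_loss_low(eval_results, checkpoint_path=None):
  # The first call passes eval_results=None (and checkpoint_path=None); keep going.
  if not eval_results:
    return True
  # Continue (True) while the evaluated loss is still above the threshold;
  # "loss" is assumed to be one of the metrics returned by evaluate().
  return eval_results.get("loss", float("inf")) > 0.1

# experiment.continuous_train_and_eval(
#     continuous_eval_predicate_fn=stop_when_loss_low)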
Example 41
    def _get_most_recently_modified_file_matching_pattern(self, pattern):
        """Returns the most recently modified filepath matching pattern.

    The pattern may contain a Python formatting placeholder. If
    `tf.train.latest_checkpoint()` does not return None, use that; otherwise,
    check for the most recently modified file that matches the pattern.

    In the rare case where more than one pattern-matching file shares the most
    recent modified time, return the filepath that is largest (by the `>`
    operator, lexicographically using the numeric equivalents). This provides a
    tie-breaker when multiple files are equally recent. Note that a larger
    `filepath` can sometimes indicate a later modification time (for instance,
    when epoch/batch is used as a formatting option), but not necessarily (when
    accuracy or loss is used). The tie-breaker is a best effort to return the
    most recent file and to avoid nondeterministic results.

    Modified time of a file is obtained with `os.path.getmtime()`.

    This utility function is best demonstrated via an example:

    ```python
    file_pattern = 'f.batch{batch:02d}epoch{epoch:02d}.h5'
    test_dir = self.get_temp_dir()
    path_pattern = os.path.join(test_dir, file_pattern)
    file_paths = [
        os.path.join(test_dir, file_name) for file_name in
        ['f.batch03epoch02.h5', 'f.batch02epoch02.h5', 'f.batch01epoch01.h5']
    ]
    for file_path in file_paths:
      # Write something to each of the files
    self.assertEqual(
        _get_most_recently_modified_file_matching_pattern(path_pattern),
        file_paths[-1])
    ```

    Arguments:
        pattern: The file pattern that may optionally contain python placeholder
            such as `{epoch:02d}`.

    Returns:
        The most recently modified file's full filepath matching `pattern`. If
        `pattern` does not contain any placeholder, this returns the filepath
        that exactly matches `pattern`. Returns `None` if no match is found.
    """
        dir_name = os.path.dirname(pattern)
        base_name = os.path.basename(pattern)
        base_name_regex = '^' + re.sub(r'{.*}', r'.*', base_name) + '$'

        # If tf.train.latest_checkpoint tells us there exists a latest checkpoint,
        # use that as it is more robust than `os.path.getmtime()`.
        latest_tf_checkpoint = checkpoint_management.latest_checkpoint(
            dir_name)
        if latest_tf_checkpoint is not None and re.match(
                base_name_regex, os.path.basename(latest_tf_checkpoint)):
            return latest_tf_checkpoint

        latest_mod_time = 0
        file_path_with_latest_mod_time = None
        n_file_with_latest_mod_time = 0
        file_path_with_largest_file_name = None

        if os.path.exists(dir_name):
            for file_name in os.listdir(dir_name):
                # Only consider if `file_name` matches the pattern.
                if re.match(base_name_regex, file_name):
                    file_path = os.path.join(dir_name, file_name)
                    mod_time = os.path.getmtime(file_path)
                    if (file_path_with_largest_file_name is None
                            or file_path > file_path_with_largest_file_name):
                        file_path_with_largest_file_name = file_path
                    if mod_time > latest_mod_time:
                        latest_mod_time = mod_time
                        file_path_with_latest_mod_time = file_path
                        # In the case a file with later modified time is found, reset
                        # the counter for the number of files with latest modified time.
                        n_file_with_latest_mod_time = 1
                    elif mod_time == latest_mod_time:
                        # In the case a file has modified time tied with the most recent,
                        # increment the counter for the number of files with latest modified
                        # time by 1.
                        n_file_with_latest_mod_time += 1

        if n_file_with_latest_mod_time == 1:
            # Return the sole file that has most recent modified time.
            return file_path_with_latest_mod_time
        else:
            # If there are more than one file having latest modified time, return
            # the file path with the largest file name.
            return file_path_with_largest_file_name
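
A small standalone sketch of the regex construction this method relies on, using a hypothetical checkpoint pattern:

import os
import re

pattern = "/tmp/ckpts/weights.{epoch:02d}-{val_loss:.2f}.h5"  # hypothetical
base_name = os.path.basename(pattern)
base_name_regex = "^" + re.sub(r"{.*}", r".*", base_name) + "$"
# The greedy '{.*}' collapses everything from the first '{' to the last '}',
# so base_name_regex is '^weights..*.h5$'.
print(base_name_regex)
print(bool(re.match(base_name_regex, "weights.03-0.27.h5")))  # True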
Example 42
  def _continuous_eval(self,
                       input_fn,
                       name,
                       delay_secs,
                       throttle_delay_secs,
                       evaluate_checkpoint_only_once=True,
                       continuous_eval_predicate_fn=None,
                       export=True):
    """Run continuous eval.

    Runs infinite eval on the evaluation data set. This function starts
    evaluating after `delay_secs` seconds and then runs no more than one
    evaluation (with `self._eval_steps` steps each time) per
    `throttle_delay_secs`. If `train_steps` is not None, it will return after
    `global_step` reaches `train_steps`.

    Args:
      input_fn: The input to use for this eval.
      name: A string appended to the folder name of evaluation results.
      delay_secs: Start evaluating after this many seconds. If None, defaults to
        self._eval_delay_secs.
      throttle_delay_secs: Do not re-evaluate unless the last evaluation was
        started at least this many seconds ago. If None, defaults to
        self._continuous_eval_throttle_secs.
      evaluate_checkpoint_only_once: Whether to skip evaluation of checkpoints
        that have already been evaluated. Default is `True`.
      continuous_eval_predicate_fn: A predicate function determining whether to
        continue eval after each iteration. A `predicate_fn` has one of the
        following signatures:
          * (eval_results) -> boolean
          * (eval_results, checkpoint_path) -> boolean
        Where `eval_results` is the dictionary of metric evaluations and
        checkpoint_path is the path to the checkpoint containing the parameters
        on which that evaluation was based.
        At the beginning of evaluation, the passed `eval_results` will be None
        so it's expected that the predicate function handles that gracefully.
        Continuous eval behavior under different conditions:
          * When `predicate_fn` is specified:
            + if `train_steps` is None, run until `predicate_fn` returns False.
            + if `train_steps` is specified, run until either global step
              reaches `train_steps` or `predicate_fn` returns False.
          * When `predicate_fn` is not specified:
            + if `train_steps` is None, run in an infinite loop.
            + if `train_steps` is specified, run until global step reaches
              `train_steps`.
      export: Whether to export from this step. Default is `True`.

    Raises:
      ValueError: if `continuous_eval_predicate_fn` is neither None nor
        callable.
    """
    if continuous_eval_predicate_fn is not None:
      if not callable(continuous_eval_predicate_fn):
        raise ValueError(
            "`continuous_eval_predicate_fn` must be a callable, or None.")
      predicate_fn = _get_standardized_predicate_fn(
          continuous_eval_predicate_fn)
    else:
      predicate_fn = None

    if delay_secs is None:
      delay_secs = self._eval_delay_secs
    if throttle_delay_secs is None:
      throttle_delay_secs = self._continuous_eval_throttle_secs

    if delay_secs:
      logging.info("Waiting %f secs before starting eval.", delay_secs)
      time.sleep(delay_secs)

    previous_path = None
    eval_result = None
    last_warning_time = 0
    while (not predicate_fn or predicate_fn(
        eval_result, checkpoint_path=previous_path)):
      # Exit if we have already reached number of steps to train.
      if self._has_training_stopped(eval_result):
        logging.info("Exiting continuous eval, global_step=%s >= "
                     "train_step=%s", eval_result[ops.GraphKeys.GLOBAL_STEP],
                     self._train_steps)
        return

      start = time.time()

      error_msg = None
      latest_path = checkpoint_management.latest_checkpoint(
          self._estimator.model_dir)
      if not latest_path:
        error_msg = ("Estimator is not fitted yet. "
                     "Will start an evaluation when a checkpoint is ready.")
      elif evaluate_checkpoint_only_once and latest_path == previous_path:
        error_msg = "No new checkpoint ready for evaluation."

      if error_msg:
        # Print warning message every 10 mins.
        eval_result = {}
        if time.time() - last_warning_time > 600:
          logging.warning(error_msg)
          last_warning_time = time.time()
      else:
        eval_result = self._call_evaluate(
            input_fn=input_fn,
            steps=self._eval_steps,
            metrics=self._eval_metrics,
            name=name,
            checkpoint_path=latest_path,
            hooks=self._eval_hooks)
        # Ensure eval result is not None for next round of evaluation.
        if not eval_result:
          eval_result = {}

        if export:
          self._maybe_export(eval_result, checkpoint_path=latest_path)

        # Clear warning timer and update last evaluated checkpoint
        last_warning_time = 0
        previous_path = latest_path

      duration = time.time() - start
      if duration < throttle_delay_secs:
        difference = throttle_delay_secs - duration
        logging.info("Waiting %f secs before starting next eval run.",
                     difference)
        time.sleep(difference)
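
As in `continuous_train_and_eval`, the predicate may take a single argument. Here is a sketch of a closure-based predicate that allows at most a fixed number of evaluation rounds; the factory name and the counting scheme are assumptions for illustration:

def make_stop_after_n_evals(n):
  """Returns a one-argument predicate that allows at most `n` eval rounds."""
  state = {"remaining": n}

  def predicate(eval_results):
    # The first call receives eval_results=None, before any evaluation has run.
    if eval_results is None:
      return True
    state["remaining"] -= 1
    return state["remaining"] > 0

  return predicate

# experiment._continuous_eval(..., continuous_eval_predicate_fn=make_stop_after_n_evals(5))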
Example 43
  def testTrainWithInitFromCheckpoint(self):
    logdir1 = os.path.join(self.get_temp_dir(), 'tmp_logs1/')
    logdir2 = os.path.join(self.get_temp_dir(), 'tmp_logs2/')

    if gfile.Exists(logdir1):  # For running on jenkins.
      gfile.DeleteRecursively(logdir1)
    if gfile.Exists(logdir2):  # For running on jenkins.
      gfile.DeleteRecursively(logdir2)

    # First, train the model one step (make sure the error is high).
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      train_op = self.create_train_op()
      saver = saver_lib.Saver()
      loss = training.train(
          train_op,
          logdir1,
          hooks=[
              basic_session_run_hooks.CheckpointSaverHook(
                  logdir1, save_steps=1, saver=saver),
              basic_session_run_hooks.StopAtStepHook(num_steps=1),
          ],
          save_checkpoint_secs=None,
          save_summaries_steps=None)
      self.assertGreater(loss, .5)

    # Next, train the model to convergence.
    with ops.Graph().as_default():
      random_seed.set_random_seed(1)
      train_op = self.create_train_op()
      saver = saver_lib.Saver()
      loss = training.train(
          train_op,
          logdir1,
          hooks=[
              basic_session_run_hooks.CheckpointSaverHook(
                  logdir1, save_steps=300, saver=saver),
              basic_session_run_hooks.StopAtStepHook(num_steps=300),
          ],
          save_checkpoint_secs=None,
          save_summaries_steps=None)
      self.assertIsNotNone(loss)
      self.assertLess(loss, .02)

    # Finally, advance the model a single step and validate that the loss is
    # still low.
    with ops.Graph().as_default():
      random_seed.set_random_seed(2)
      train_op = self.create_train_op()

      model_variables = variables_lib2.global_variables()
      model_path = checkpoint_management.latest_checkpoint(logdir1)

      assign_fn = variables_lib.assign_from_checkpoint_fn(
          model_path, model_variables)

      def init_fn(_, session):
        assign_fn(session)

      loss = training.train(
          train_op,
          None,
          scaffold=monitored_session.Scaffold(init_fn=init_fn),
          hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=1)],
          save_checkpoint_secs=None,
          save_summaries_steps=None)

      self.assertIsNotNone(loss)
      self.assertLess(loss, .02)
Example 44
import os
import re
import sys
from shutil import copyfile

import tensorflow as tf  # needed for tf.Session / tf.train / tf.contrib below
from tensorflow.python.training import checkpoint_management

if len(sys.argv) < 2:
    print("Usage: fix_embed_mask_name.py [chkpt-dir]")
    print(
        "Fix the latest checkpoint in the directory by renaming the embedding mask variable."
    )
    print("Note: overwrites original checkpoint!")
    exit()
else:
    CHECKPOINT_DIR = sys.argv[1].rstrip('/')
    CHECKPOINT_FILE = checkpoint_management.latest_checkpoint(CHECKPOINT_DIR)

with tf.Session() as sess:

    # Load all the variables from the checkpoint, renaming as we go
    for var_name, _ in tf.train.list_variables(CHECKPOINT_FILE):
        var_tensor = tf.contrib.framework.load_variable(
            CHECKPOINT_FILE, var_name)

        new_name = re.sub(r'bert/embeddings//mask',
                          r'bert/embeddings/embed_mask/mask', var_name)
        new_name = re.sub(r'bert/embeddings//threshold',
                          r'bert/embeddings/embed_mask/threshold', new_name)

        if new_name != var_name:
            print(f"Renaming {var_name} to {new_name}")
Example 45
from tensorflow.python.platform import gfile
from tensorflow.python.training import checkpoint_management


def _get_checkpoint_filename(filepattern):
    """Returns the checkpoint filename given a directory or a specific filepattern."""
    if gfile.IsDirectory(filepattern):
        return checkpoint_management.latest_checkpoint(filepattern)
    return filepattern