def testRecoverSession(self):
  # Create a checkpoint.
  checkpoint_dir = os.path.join(self.get_temp_dir(), "recover_session")
  try:
    gfile.DeleteRecursively(checkpoint_dir)
  except errors.OpError:
    pass  # Ignore
  gfile.MakeDirs(checkpoint_dir)

  with ops.Graph().as_default():
    v = variables.Variable(1, name="v")
    sm = session_manager.SessionManager(
        ready_op=variables.report_uninitialized_variables())
    saver = saver_lib.Saver({"v": v})
    sess, initialized = sm.recover_session(
        "", saver=saver, checkpoint_dir=checkpoint_dir)
    self.assertFalse(initialized)
    sess.run(v.initializer)
    self.assertEqual(1, sess.run(v))
    saver.save(sess,
               os.path.join(checkpoint_dir, "recover_session_checkpoint"))
  self._test_recovered_variable(checkpoint_dir=checkpoint_dir)
  self._test_recovered_variable(
      checkpoint_filename_with_path=checkpoint_management.latest_checkpoint(
          checkpoint_dir))
  # Cannot set both checkpoint_dir and checkpoint_filename_with_path.
  with self.assertRaises(ValueError):
    self._test_recovered_variable(
        checkpoint_dir=checkpoint_dir,
        checkpoint_filename_with_path=checkpoint_management.latest_checkpoint(
            checkpoint_dir))
def export_fn(estimator, export_dir_base, checkpoint_path, eval_result=None):
  """Exports the given Estimator as a SavedModel.

  Args:
    estimator: the Estimator to export.
    export_dir_base: A string containing a directory to write the exported
      graph and checkpoints.
    checkpoint_path: The checkpoint path to export. If None (the default),
      the most recent checkpoint found within the model directory is chosen.
    eval_result: placeholder argument matching the call signature of
      ExportStrategy.

  Returns:
    The string path to the exported directory.
  """
  if not checkpoint_path:
    # TODO(b/67425018): switch to
    #    checkpoint_path = estimator.latest_checkpoint()
    # as soon as contrib is cleaned up and we can thus be sure that
    # estimator is a tf.estimator.Estimator and not a
    # tf.contrib.learn.Estimator
    checkpoint_path = checkpoint_management.latest_checkpoint(
        estimator.model_dir)
  export_checkpoint_path, export_eval_result = best_model_selector.update(
      checkpoint_path, eval_result)

  if export_checkpoint_path and export_eval_result is not None:
    checkpoint_base = os.path.basename(export_checkpoint_path)
    export_dir = os.path.join(export_dir_base, checkpoint_base)
    return best_model_export_strategy.export(
        estimator, export_dir, export_checkpoint_path, export_eval_result)
  else:
    return ''
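# The checkpoint-path fallback above is a common pattern. A minimal
# standalone sketch of it using the public tf.train.latest_checkpoint API
# (the model_dir value below is hypothetical):
import tensorflow as tf

def resolve_checkpoint(model_dir, checkpoint_path=None):
  # Fall back to the newest checkpoint in model_dir when no explicit path
  # is given; tf.train.latest_checkpoint returns None if there is none.
  return checkpoint_path or tf.train.latest_checkpoint(model_dir)

ckpt = resolve_checkpoint('/tmp/my_model')  # hypothetical directory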
def testGraphDistributionStrategy(self):
  self.skipTest("b/121381184")
  num_training_steps = 10
  checkpoint_directory = self.get_temp_dir()
  checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

  def _train_fn(optimizer, model):
    input_value = constant_op.constant([[3.]])
    return optimizer.minimize(
        functools.partial(model, input_value),
        global_step=root.optimizer_step)

  for training_continuation in range(3):
    with ops.Graph().as_default():
      strategy = mirrored_strategy.MirroredStrategy()
      with strategy.scope():
        model = MyModel()
        optimizer = adam.AdamOptimizer(0.001)
        root = checkpointable_utils.Checkpoint(
            optimizer=optimizer, model=model,
            optimizer_step=training_util.get_or_create_global_step())
        status = root.restore(checkpoint_management.latest_checkpoint(
            checkpoint_directory))
        train_op = strategy.extended.call_for_each_replica(
            functools.partial(_train_fn, optimizer, model))
        with self.session() as session:
          if training_continuation > 0:
            status.assert_consumed()
          status.initialize_or_restore()
          for _ in range(num_training_steps):
            session.run(train_op)
          root.save(file_prefix=checkpoint_prefix)
      self.assertEqual((training_continuation + 1) * num_training_steps,
                       root.optimizer_step.numpy())
def testEagerTPUDistributionStrategy(self):
  self.skipTest("b/121387144")
  num_training_steps = 10
  checkpoint_directory = self.get_temp_dir()
  checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

  def _train_fn(optimizer, model):
    input_value = constant_op.constant([[3.]])
    optimizer.minimize(
        functools.partial(model, input_value),
        global_step=root.optimizer_step)

  for training_continuation in range(3):
    strategy = tpu_strategy.TPUStrategy()
    with strategy.scope():
      model = Subclassed()
      optimizer = adam_v1.AdamOptimizer(0.001)
      root = checkpointable_utils.Checkpoint(
          optimizer=optimizer, model=model,
          optimizer_step=training_util.get_or_create_global_step())
      root.restore(checkpoint_management.latest_checkpoint(
          checkpoint_directory))

      for _ in range(num_training_steps):
        strategy.extended.call_for_each_replica(
            functools.partial(_train_fn, optimizer, model))
      root.save(file_prefix=checkpoint_prefix)
    self.assertEqual((training_continuation + 1) * num_training_steps,
                     root.optimizer_step.numpy())
def testUsageGraph(self):
  """Expected usage when graph building."""
  with context.graph_mode():
    num_training_steps = 10
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    for training_continuation in range(3):
      with ops.Graph().as_default():
        model = MyModel()
        optimizer = adam.AdamOptimizer(0.001)
        root = util.Checkpoint(
            optimizer=optimizer, model=model,
            global_step=training_util.get_or_create_global_step())
        input_value = constant_op.constant([[3.]])
        train_op = optimizer.minimize(
            model(input_value),
            global_step=root.global_step)
        checkpoint_path = checkpoint_management.latest_checkpoint(
            checkpoint_directory)
        with self.session(graph=ops.get_default_graph()) as session:
          status = root.restore(save_path=checkpoint_path)
          status.initialize_or_restore(session=session)
          if checkpoint_path is None:
            self.assertEqual(0, training_continuation)
            with self.assertRaises(AssertionError):
              status.assert_consumed()
          else:
            status.assert_consumed()
          for _ in range(num_training_steps):
            session.run(train_op)
          root.save(file_prefix=checkpoint_prefix, session=session)
          self.assertEqual((training_continuation + 1) * num_training_steps,
                           session.run(root.global_step))
          self.assertEqual(training_continuation + 1,
                           session.run(root.save_counter))
def testAgnosticUsage(self):
  """Graph/eager agnostic usage."""
  # Does create garbage when executing eagerly due to ops.Graph() creation.
  num_training_steps = 10
  checkpoint_directory = self.get_temp_dir()
  checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
  for training_continuation in range(3):
    with ops.Graph().as_default(), self.test_session(
        graph=ops.get_default_graph()), test_util.device(use_gpu=True):
      model = MyModel()
      optimizer = adam.AdamOptimizer(0.001)
      root = util.Checkpoint(
          optimizer=optimizer, model=model,
          global_step=training_util.get_or_create_global_step())
      checkpoint_path = checkpoint_management.latest_checkpoint(
          checkpoint_directory)
      status = root.restore(save_path=checkpoint_path)
      input_value = constant_op.constant([[3.]])
      train_fn = functools.partial(
          optimizer.minimize,
          functools.partial(model, input_value),
          global_step=root.global_step)
      if not context.executing_eagerly():
        train_fn = functools.partial(self.evaluate, train_fn())
      status.initialize_or_restore()
      for _ in range(num_training_steps):
        train_fn()
      root.save(file_prefix=checkpoint_prefix)
      self.assertEqual((training_continuation + 1) * num_training_steps,
                       self.evaluate(root.global_step))
      self.assertEqual(training_continuation + 1,
                       self.evaluate(root.save_counter))
def wait_for_new_checkpoint(checkpoint_dir,
                            last_checkpoint=None,
                            seconds_to_sleep=1,
                            timeout=None):
  """Waits until a new checkpoint file is found.

  Args:
    checkpoint_dir: The directory in which checkpoints are saved.
    last_checkpoint: The last checkpoint path used or `None` if we're
      expecting a checkpoint for the first time.
    seconds_to_sleep: The number of seconds to sleep for before looking for
      a new checkpoint.
    timeout: The maximum number of seconds to wait. If left as `None`, then
      the process will wait indefinitely.

  Returns:
    A new checkpoint path, or None if the timeout was reached.
  """
  logging.info('Waiting for new checkpoint at %s', checkpoint_dir)
  stop_time = time.time() + timeout if timeout is not None else None
  while True:
    checkpoint_path = checkpoint_management.latest_checkpoint(checkpoint_dir)
    if checkpoint_path is None or checkpoint_path == last_checkpoint:
      if stop_time is not None and time.time() + seconds_to_sleep > stop_time:
        return None
      time.sleep(seconds_to_sleep)
    else:
      logging.info('Found new checkpoint at %s', checkpoint_path)
      return checkpoint_path
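# A typical consumer of wait_for_new_checkpoint is a standalone evaluation
# loop. A minimal sketch, assuming run_eval is an evaluation callable you
# supply yourself (it is not part of the function above):
last_ckpt = None
while True:
  # Block until training writes a checkpoint newer than last_ckpt, giving
  # up after ten minutes of no progress.
  last_ckpt = wait_for_new_checkpoint(
      '/tmp/train_dir', last_checkpoint=last_ckpt, timeout=600)
  if last_ckpt is None:
    break  # Timed out; assume training has finished.
  run_eval(last_ckpt)  # hypothetical evaluation function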
def _new_layer_weight_loading_test_template(
    self, first_model_fn, second_model_fn, restore_init_fn):
  with self.cached_session() as session:
    model = first_model_fn()
    temp_dir = self.get_temp_dir()
    prefix = os.path.join(temp_dir, 'ckpt')

    x = constant_op.constant(np.random.random((3, 2)), dtype=dtypes.float32)
    executing_eagerly = context.executing_eagerly()
    ref_y_tensor = model(x)
    if not executing_eagerly:
      session.run([v.initializer for v in model.variables])
    ref_y = self.evaluate(ref_y_tensor)
    model.save_weights(prefix)
    self.assertEqual(
        prefix,
        checkpoint_management.latest_checkpoint(temp_dir))
    for v in model.variables:
      self.evaluate(
          v.assign(random_ops.random_normal(shape=array_ops.shape(v))))

    self.addCleanup(shutil.rmtree, temp_dir)

    second_model = second_model_fn()
    second_model.load_weights(prefix)
    second_model(x)
    self.evaluate(restore_init_fn(second_model))
    second_model.save_weights(prefix)

    # Check that the second model's checkpoint loads into the original model
    model.load_weights(prefix)
    y = self.evaluate(model(x))
    self.assertAllClose(ref_y, y)
def _restore_or_save_initial_ckpt(self, session):
  # Ideally this should be run in after_create_session but is not for the
  # following reason:
  # Currently there is no way of enforcing an order of running the
  # `SessionRunHooks`. Hence it is possible that the
  # `_DatasetInitializerHook` is run *after* this hook. That is troublesome
  # because
  # 1. If a checkpoint exists and this hook restores it, the initializer
  #    hook will override it.
  # 2. If no checkpoint exists, this hook will try to save an uninitialized
  #    iterator, which will result in an exception.
  #
  # As a temporary fix we enter the following implicit contract between this
  # hook and the _DatasetInitializerHook.
  # 1. The _DatasetInitializerHook initializes the iterator in the call to
  #    after_create_session.
  # 2. This hook saves the iterator on the first call to `before_run()`,
  #    which is guaranteed to happen after `after_create_session()` of all
  #    hooks has been run.

  # Check if there is an existing checkpoint. If so, restore from it.
  # pylint: disable=protected-access
  latest_checkpoint_path = checkpoint_management.latest_checkpoint(
      self._checkpoint_saver_hook._checkpoint_dir,
      latest_filename=self._latest_filename)
  if latest_checkpoint_path:
    self._checkpoint_saver_hook._get_saver().restore(session,
                                                     latest_checkpoint_path)
  else:
    # The checkpoint saved here is the state at step "global_step".
    # Note: We do not save the GraphDef or MetaGraphDef here.
    global_step = session.run(
        self._checkpoint_saver_hook._global_step_tensor)
    self._checkpoint_saver_hook._save(session, global_step)
    self._checkpoint_saver_hook._timer.update_last_triggered_step(global_step)
def _read_vars(self, model_dir):
  """Returns (global_step, latest_feature)."""
  with ops.Graph().as_default() as g:
    ckpt_path = checkpoint_management.latest_checkpoint(model_dir)
    meta_filename = ckpt_path + '.meta'
    saver_lib.import_meta_graph(meta_filename)
    saver = saver_lib.Saver()
    with self.test_session(graph=g) as sess:
      saver.restore(sess, ckpt_path)
      return sess.run(ops.get_collection('my_vars'))
def testRestoreInReconstructedIterator(self):
  checkpoint_directory = self.get_temp_dir()
  checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
  dataset = Dataset.range(10)
  for i in range(5):
    iterator = datasets.Iterator(dataset)
    checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
    checkpoint.restore(checkpoint_management.latest_checkpoint(
        checkpoint_directory))
    for j in range(2):
      self.assertEqual(i * 2 + j, iterator.get_next().numpy())
    checkpoint.save(file_prefix=checkpoint_prefix)
def testRestoreInReconstructedIteratorInitializable(self):
  checkpoint_directory = self.get_temp_dir()
  checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
  dataset = dataset_ops.Dataset.range(10)
  iterator = dataset.make_initializable_iterator()
  get_next = iterator.get_next()
  checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
  for i in range(5):
    with self.cached_session() as sess:
      checkpoint.restore(checkpoint_management.latest_checkpoint(
          checkpoint_directory)).initialize_or_restore(sess)
      for j in range(2):
        self.assertEqual(i * 2 + j, self.evaluate(get_next))
      checkpoint.save(file_prefix=checkpoint_prefix)
def testRestoreInReconstructedIteratorInitializable(self):
  checkpoint_directory = self.get_temp_dir()
  checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
  dataset = dataset_ops.Dataset.range(10)
  iterator = iter(dataset) if context.executing_eagerly(
  ) else dataset_ops.make_initializable_iterator(dataset)
  get_next = iterator.get_next
  checkpoint = trackable_utils.Checkpoint(iterator=iterator)
  for i in range(5):
    checkpoint.restore(
        checkpoint_management.latest_checkpoint(
            checkpoint_directory)).initialize_or_restore()
    for j in range(2):
      self.assertEqual(i * 2 + j, self.evaluate(get_next()))
    checkpoint.save(file_prefix=checkpoint_prefix)
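# The same iterator round trip is available through the public TF 2.x API,
# where tf.data iterators are trackable. A minimal eager-mode sketch (the
# checkpoint prefix below is hypothetical):
import tensorflow as tf

dataset = tf.data.Dataset.range(10)
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(iterator=iterator)
assert next(iterator).numpy() == 0
save_path = ckpt.save('/tmp/iterator_ckpt/ckpt')
assert next(iterator).numpy() == 1
ckpt.restore(save_path)
# The restore rewinds the iterator to its saved position.
assert next(iterator).numpy() == 1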
def testNameCollision(self):
  # Make sure we have a clean directory to work in.
  with self.tempDir() as tempdir:
    # Jump to that directory until this test is done.
    with self.tempWorkingDir(tempdir):
      # Save training snapshots to a relative path.
      traindir = "train/"
      os.mkdir(traindir)
      # Collides with the default name of the checkpoint state file.
      filepath = os.path.join(traindir, "checkpoint")

      with self.test_session() as sess:
        unused_a = variables.Variable(0.0)  # So that Saver saves something.
        variables.global_variables_initializer().run()

        # Should fail.
        saver = saver_module.Saver(sharded=False)
        with self.assertRaisesRegexp(ValueError, "collides with"):
          saver.save(sess, filepath)

        # Succeeds: the file will be named "checkpoint-<step>".
        saver.save(sess, filepath, global_step=1)
        self.assertIsNotNone(
            checkpoint_management.latest_checkpoint(traindir))

        # Succeeds: the file will be named "checkpoint-<i>-of-<n>".
        saver = saver_module.Saver(sharded=True)
        saver.save(sess, filepath)
        self.assertIsNotNone(
            checkpoint_management.latest_checkpoint(traindir))

        # Succeeds: the file will be named "checkpoint-<step>-<i>-of-<n>".
        saver = saver_module.Saver(sharded=True)
        saver.save(sess, filepath, global_step=1)
        self.assertIsNotNone(
            checkpoint_management.latest_checkpoint(traindir))
def _GetGraphDef(self, use_trt, max_batch_size, model_dir):
  """Get the frozen mnist GraphDef.

  Args:
    use_trt: whether to use TF-TRT to convert the graph.
    max_batch_size: the max batch size to apply during TF-TRT conversion.
    model_dir: the model directory to load the checkpoints.

  Returns:
    The frozen mnist GraphDef.
  """
  graph = ops.Graph()
  with self.session(graph=graph) as sess:
    with graph.device('/GPU:0'):
      x = array_ops.placeholder(
          shape=(None, 28, 28, 1), dtype=dtypes.float32, name=INPUT_NODE_NAME)
      self._BuildGraph(x)
    # Load weights
    mnist_saver = saver.Saver()
    checkpoint_file = latest_checkpoint(model_dir)
    mnist_saver.restore(sess, checkpoint_file)
    # Freeze
    graph_def = graph_util.convert_variables_to_constants(
        sess, sess.graph_def, output_node_names=[OUTPUT_NODE_NAME])
  # Convert with TF-TRT
  if use_trt:
    logging.info('Number of nodes before TF-TRT conversion: %d',
                 len(graph_def.node))
    converter = trt_convert.TrtGraphConverter(
        input_graph_def=graph_def,
        nodes_blacklist=[OUTPUT_NODE_NAME],
        max_batch_size=max_batch_size,
        precision_mode='INT8',
        # There is a 2GB GPU memory limit for each test, so we set
        # max_workspace_size_bytes to 256MB to leave enough room for TF
        # runtime to allocate GPU memory.
        max_workspace_size_bytes=1 << 28,
        minimum_segment_size=2,
        use_calibration=False,
        use_function_backup=False)
    graph_def = converter.convert()
    logging.info('Number of nodes after TF-TRT conversion: %d',
                 len(graph_def.node))
    num_engines = len(
        [1 for n in graph_def.node if str(n.op) == 'TRTEngineOp'])
    self.assertEqual(1, num_engines)
  return graph_def
def __init__(self,
             estimator,
             prediction_input_fn,
             input_alternative_key=None,
             output_alternative_key=None,
             graph=None,
             config=None):
  """Initialize a `ContribEstimatorPredictor`.

  Args:
    estimator: an instance of `tf.contrib.learn.Estimator`.
    prediction_input_fn: a function that takes no arguments and returns an
      instance of `InputFnOps`.
    input_alternative_key: Optional. Specify the input alternative used for
      prediction.
    output_alternative_key: Specify the output alternative used for
      prediction. Not needed for single-headed models but required for
      multi-headed models.
    graph: Optional. The TensorFlow `graph` in which prediction should be
      done.
    config: `ConfigProto` proto used to configure the session.
  """
  self._graph = graph or ops.Graph()
  with self._graph.as_default():
    input_fn_ops = prediction_input_fn()
    # pylint: disable=protected-access
    model_fn_ops = estimator._get_predict_ops(input_fn_ops.features)
    # pylint: enable=protected-access
    checkpoint_path = checkpoint_management.latest_checkpoint(
        estimator.model_dir)
    self._session = monitored_session.MonitoredSession(
        session_creator=monitored_session.ChiefSessionCreator(
            config=config,
            checkpoint_filename_with_path=checkpoint_path))

  input_alternative_key = (
      input_alternative_key or
      saved_model_export_utils.DEFAULT_INPUT_ALTERNATIVE_KEY)
  input_alternatives, _ = saved_model_export_utils.get_input_alternatives(
      input_fn_ops)
  self._feed_tensors = input_alternatives[input_alternative_key]

  (output_alternatives,
   output_alternative_key) = saved_model_export_utils.get_output_alternatives(
       model_fn_ops, output_alternative_key)
  _, fetch_tensors = output_alternatives[output_alternative_key]
  self._fetch_tensors = fetch_tensors
def testCheckpointExists(self):
  for sharded in (False, True):
    for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
      with self.session(graph=ops_lib.Graph()) as sess:
        unused_v = variables.Variable(1.0, name="v")
        variables.global_variables_initializer().run()
        saver = saver_module.Saver(sharded=sharded, write_version=version)

        path = os.path.join(self._base_dir, "%s-%s" % (sharded, version))
        self.assertFalse(
            checkpoint_management.checkpoint_exists(path))  # Not saved yet.

        ckpt_prefix = saver.save(sess, path)
        self.assertTrue(checkpoint_management.checkpoint_exists(ckpt_prefix))

        ckpt_prefix = checkpoint_management.latest_checkpoint(self._base_dir)
        self.assertTrue(checkpoint_management.checkpoint_exists(ckpt_prefix))
def testRelativePath(self):
  # Make sure we have a clean directory to work in.
  with self.tempDir() as tempdir:
    # Jump to that directory until this test is done.
    with self.tempWorkingDir(tempdir):
      # Save training snapshots to a relative path.
      traindir = "train/"
      os.mkdir(traindir)

      filename = "snapshot"
      filepath = os.path.join(traindir, filename)

      with self.test_session() as sess:
        # Build a simple graph.
        v0 = variables.Variable(0.0)
        inc = v0.assign_add(1.0)

        save = saver_module.Saver({"v0": v0})

        # Record a short training history.
        variables.global_variables_initializer().run()
        save.save(sess, filepath, global_step=0)
        inc.eval()
        save.save(sess, filepath, global_step=1)
        inc.eval()
        save.save(sess, filepath, global_step=2)

      with self.test_session() as sess:
        # Build a new graph with different initialization.
        v0 = variables.Variable(-1.0)

        # Create a new saver.
        save = saver_module.Saver({"v0": v0})
        variables.global_variables_initializer().run()

        # Get the most recent checkpoint name from the training history file.
        name = checkpoint_management.latest_checkpoint(traindir)
        self.assertIsNotNone(name)

        # Restore "v0" from that checkpoint.
        save.restore(sess, name)
        self.assertEqual(v0.eval(), 2.0)
def testTrainSpinn(self):
  """Test with fake toy SNLI data and GloVe vectors."""
  # 1. Create and load a fake SNLI data file and a fake GloVe embedding file.
  snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
  fake_train_file = self._create_test_data(snli_1_0_dir)

  vocab = data.load_vocabulary(self._temp_data_dir)
  word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab)

  train_data = data.SnliData(fake_train_file, word2index)
  dev_data = data.SnliData(fake_train_file, word2index)
  test_data = data.SnliData(fake_train_file, word2index)

  # 2. Create a fake config.
  config = _test_spinn_config(
      data.WORD_VECTOR_LEN, 4,
      logdir=os.path.join(self._temp_data_dir, "logdir"))

  # 3. Test training of a SPINN model.
  trainer = spinn.train_or_infer_spinn(
      embed, word2index, train_data, dev_data, test_data, config)

  # 4. Load train loss values from the summary files and verify that they
  #    decrease with training.
  summary_file = glob.glob(os.path.join(config.logdir, "events.out.*"))[0]
  events = summary_test_util.events_from_file(summary_file)
  train_losses = [event.summary.value[0].simple_value for event in events
                  if event.summary.value
                  and event.summary.value[0].tag == "train/loss"]
  self.assertEqual(config.epochs, len(train_losses))

  # 5. Verify that checkpoints exist and contain all the expected variables.
  self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*")))
  object_graph = checkpointable_utils.object_metadata(
      checkpoint_management.latest_checkpoint(config.logdir))
  ckpt_variable_names = set()
  for node in object_graph.nodes:
    for attribute in node.attributes:
      ckpt_variable_names.add(attribute.full_name)
  self.assertIn("global_step", ckpt_variable_names)
  for v in trainer.variables:
    variable_name = v.name[:v.name.index(":")] if ":" in v.name else v.name
    self.assertIn(variable_name, ckpt_variable_names)
def after_save(self, session, global_step_value):
  """Evaluates and exports the model after a checkpoint is created."""
  # Load and cache the path of the most recent checkpoint to avoid duplicate
  # searches on GCS.
  logging.info("Checking for checkpoint in %s", self._model_dir)
  latest_path = checkpoint_management.latest_checkpoint(self._model_dir)

  if not latest_path:
    logging.warning("Skipping evaluation and export since model has not been "
                    "saved yet.")
  elif latest_path == self._latest_path:
    logging.warning("Skipping evaluation due to same latest checkpoint %s.",
                    latest_path)
  else:
    self._latest_path = latest_path
    self._eval_result = self._eval_fn(
        name="intermediate_export", checkpoint_path=latest_path)
    self._export_results = self._export_fn(
        self._eval_result, checkpoint_path=latest_path)
def testWithDefun(self):
  num_training_steps = 2
  checkpoint_directory = self.get_temp_dir()
  checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
  for training_continuation in range(3):
    with ops.Graph().as_default(), self.test_session(
        graph=ops.get_default_graph()), test_util.device(use_gpu=True):
      model = MyModel()
      # Don't actually train so we can test variable values
      optimizer = adam.AdamOptimizer(0.)
      root = util.Checkpoint(
          optimizer=optimizer, model=model,
          global_step=training_util.get_or_create_global_step())
      checkpoint_path = checkpoint_management.latest_checkpoint(
          checkpoint_directory)
      status = root.restore(save_path=checkpoint_path)

      def train_fn():
        @function.defun
        def _call_model(x):
          return model(x)

        with backprop.GradientTape() as tape:
          loss = _call_model(constant_op.constant([[3.]]))
        gradients = tape.gradient(loss, model.variables)
        return optimizer.apply_gradients(zip(gradients, model.variables),
                                         global_step=root.global_step)

      if not context.executing_eagerly():
        train_fn = functools.partial(self.evaluate, train_fn())
      status.initialize_or_restore()
      for _ in range(num_training_steps):
        train_fn()
      if training_continuation > 0:
        status.assert_consumed()
        self.assertAllClose([[42.]], self.evaluate(model.variables[0]))
      else:
        self.evaluate(model.variables[0].assign([[42.]]))
      root.save(file_prefix=checkpoint_prefix)
      self.assertEqual((training_continuation + 1) * num_training_steps,
                       self.evaluate(root.global_step))
      self.assertEqual(training_continuation + 1,
                       self.evaluate(root.save_counter))
def _save_first_checkpoint(keras_model, custom_objects, config):
  """Save first checkpoint for the keras Estimator.

  Args:
    keras_model: an instance of a compiled Keras model.
    custom_objects: Dictionary for custom objects.
    config: Estimator config.

  Returns:
    The path where the keras model checkpoint is saved.
  """
  # save checkpoint into subdirectory to allow warm start
  keras_model_dir = os.path.join(config.model_dir, 'keras')
  # Load weights and save to checkpoint if there is no checkpoint
  latest_path = checkpoint_management.latest_checkpoint(keras_model_dir)
  if not latest_path:
    keras_weights = None
    if _any_weight_initialized(keras_model):
      keras_weights = keras_model.get_weights()
    if not gfile.IsDirectory(keras_model_dir):
      gfile.MakeDirs(keras_model_dir)
    with ops.Graph().as_default():
      random_seed.set_random_seed(config.tf_random_seed)
      training_util.create_global_step()
      model = _clone_and_build_model(model_fn_lib.ModeKeys.TRAIN, keras_model,
                                     custom_objects)

      # save to checkpoint
      with session.Session(config=config.session_config) as sess:
        if keras_weights:
          model.set_weights(keras_weights)
        # Make update ops and initialize all variables.
        if not model.train_function:
          # pylint: disable=protected-access
          model._make_train_function()
          K._initialize_variables(sess)
          # pylint: enable=protected-access
        saver = saver_lib.Saver()
        latest_path = os.path.join(keras_model_dir, 'keras_model.ckpt')
        saver.save(sess, latest_path)
  return latest_path
def testDeferredRestorationUsageEager(self):
  """An idiomatic eager execution example."""
  num_training_steps = 10
  checkpoint_directory = self.get_temp_dir()
  checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
  for training_continuation in range(3):
    model = MyModel()
    optimizer = adam.AdamOptimizer(0.001)
    root = util.Checkpoint(
        optimizer=optimizer, model=model,
        optimizer_step=training_util.get_or_create_global_step())
    root.restore(checkpoint_management.latest_checkpoint(
        checkpoint_directory))
    for _ in range(num_training_steps):
      # TODO(allenl): Use a Dataset and serialize/checkpoint it.
      input_value = constant_op.constant([[3.]])
      optimizer.minimize(
          lambda: model(input_value),  # pylint: disable=cell-var-from-loop
          global_step=root.optimizer_step)
    root.save(file_prefix=checkpoint_prefix)
    self.assertEqual((training_continuation + 1) * num_training_steps,
                     root.optimizer_step.numpy())
def end(self, session=None):
  super(ExportMonitor, self).end(session=session)
  latest_path = checkpoint_management.latest_checkpoint(
      self._estimator.model_dir)
  if latest_path is None:
    logging.info("Skipping export at the end since model has not been saved "
                 "yet.")
    return
  if isinstance(self._estimator, core_estimator.Estimator):
    raise ValueError(
        "ExportMonitor does not support `tf.estimator.Estimator`. "
        "Please pass an ExportStrategy to Experiment instead.")
  try:
    self._last_export_dir = self._estimator.export(
        self.export_dir,
        exports_to_keep=self.exports_to_keep,
        signature_fn=self.signature_fn,
        input_fn=self._input_fn,
        default_batch_size=self._default_batch_size,
        input_feature_key=self._input_feature_key,
        use_deprecated_input_fn=self._use_deprecated_input_fn)
  except RuntimeError:
    logging.info("Skipping exporting for the same step.")
def testCustomNumbering(self):
  directory = self.get_temp_dir()
  step = variables.Variable(0, dtype=dtypes.int64)
  checkpoint = util.Checkpoint(step=step)
  manager = checkpoint_management.CheckpointManager(
      checkpoint, directory, max_to_keep=2)
  self.evaluate(step.initializer)
  for i in range(5):
    path = manager.save(checkpoint_number=step)
    expected_suffix = "-%d" % (2 * i,)
    if not path.endswith(expected_suffix):
      self.fail("%s should have suffix %s" % (path, expected_suffix))
    self.evaluate(step.assign_add(2))
  self.assertEqual(5, self.evaluate(checkpoint.save_counter))
  # Test regular integers
  last_path = manager.save(checkpoint_number=32)
  self.assertIn("-32", last_path)
  self.assertEqual(last_path, manager.latest_checkpoint)
  self.assertEqual(
      last_path, checkpoint_management.latest_checkpoint(directory))
  state = checkpoint_management.get_checkpoint_state(directory)
  # Only the most recent two checkpoints are saved
  self.assertEqual([path, last_path], state.all_model_checkpoint_paths)
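# Outside of tests, the same custom-numbering behavior is available through
# the public tf.train.CheckpointManager API. A minimal sketch (the directory
# below is hypothetical):
import tensorflow as tf

step = tf.Variable(0, dtype=tf.int64)
manager = tf.train.CheckpointManager(
    tf.train.Checkpoint(step=step), '/tmp/ckpts', max_to_keep=2)
# Saves under the prefix '/tmp/ckpts/ckpt-0'; passing a variable (or a plain
# int) as checkpoint_number overrides the default save-counter numbering.
path = manager.save(checkpoint_number=step)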
def _get_step(self):
  ckpt = checkpoint_management.latest_checkpoint(self._estimator.model_dir)
  if ckpt:
    return int(os.path.basename(ckpt).split('-')[1])
  else:
    return 0
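# The parse above relies on checkpoint prefixes ending in '-<global_step>',
# as written by Saver and Checkpoint. A quick illustration of the string
# handling (the path is hypothetical):
import os

ckpt = '/tmp/model_dir/model.ckpt-1234'
step = int(os.path.basename(ckpt).split('-')[1])
assert step == 1234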
def _latest_ckpt(self):
  return checkpoint_management.latest_checkpoint(self.get_temp_dir())
def _get_checkpoint_filename(ckpt_dir_or_file):
  if gfile.IsDirectory(ckpt_dir_or_file):
    return checkpoint_management.latest_checkpoint(ckpt_dir_or_file)
  return ckpt_dir_or_file
def every_n_step_end(self, step, outputs):
  super(ValidationMonitor, self).every_n_step_end(step, outputs)
  # TODO(mdan): The use of step below is probably misleading.
  # The code should probably use the step from the checkpoint, because
  # that's what is being evaluated.
  if self._estimator is None:
    raise ValueError("Missing call to set_estimator.")
  current_time = time.time()
  if (self._check_interval_secs is not None and
      self._last_checkpoint_check_time is not None and
      current_time - self._last_checkpoint_check_time <=
      self._check_interval_secs):
    logging.debug(
        "Skipping evaluation since less than %d seconds have passed since "
        "last check for a new checkpoint.", self._check_interval_secs)
    return False
  self._last_checkpoint_check_time = current_time

  # Check that we are not running evaluation on the same checkpoint.
  latest_path = checkpoint_management.latest_checkpoint(
      self._estimator.model_dir)
  if latest_path is None:
    logging.debug("Skipping evaluation since model has not been saved yet "
                  "at step %d.", step)
    return False
  if latest_path is not None and latest_path == self._latest_path:
    logging.debug("Skipping evaluation due to same checkpoint %s for step "
                  "%d as for step %d.", latest_path, step,
                  self._latest_path_step)
    return False
  self._latest_path = latest_path
  self._latest_path_step = step

  # Run evaluation and log it.
  validation_outputs = self._evaluate_estimator()
  stats = []
  for name in validation_outputs:
    stats.append("%s = %s" % (name, str(validation_outputs[name])))
  logging.info("Validation (step %d): %s", step, ", ".join(stats))

  # Early stopping logic.
  if self.early_stopping_rounds is not None:
    if self.early_stopping_metric not in validation_outputs:
      raise ValueError("Metric %s missing from outputs %s." %
                       (self.early_stopping_metric,
                        set(validation_outputs.keys())))
    current_value = validation_outputs[self.early_stopping_metric]
    if (self._best_value is None or
        (self.early_stopping_metric_minimize and
         (current_value < self._best_value)) or
        (not self.early_stopping_metric_minimize and
         (current_value > self._best_value))):
      self._best_value = current_value
      self._best_metrics = copy.deepcopy(validation_outputs)
      self._best_value_step = step
    stop_now = (step - self._best_value_step >= self.early_stopping_rounds)
    if stop_now:
      logging.info("Stopping. Best step: {} with {} = {}.".format(
          self._best_value_step, self.early_stopping_metric,
          self._best_value))
      self._early_stopped = True
      return True
  return False
def test_initialize_if_not_restoring(self):
  checkpoint_directory = self.get_temp_dir()
  checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
  optimizer_only_prefix = os.path.join(checkpoint_directory, "opt")
  with test_util.device(use_gpu=True):
    model = MyModel()
    optimizer = adam.AdamOptimizer(0.001)
    root = trackable_utils.Checkpoint(
        model=model,  # Do not save the optimizer with the checkpoint.
        global_step=training_util.get_or_create_global_step())
    optimizer_checkpoint = trackable_utils.Checkpoint(
        optimizer=optimizer)

    checkpoint_path = checkpoint_management.latest_checkpoint(
        checkpoint_directory)
    status = root.restore(save_path=checkpoint_path)
    input_value = constant_op.constant([[3.]])
    train_fn = functools.partial(optimizer.minimize,
                                 functools.partial(model, input_value),
                                 global_step=root.global_step)
    if not context.executing_eagerly():
      train_fn = functools.partial(self.evaluate, train_fn())
    status.initialize_or_restore()
    self.evaluate([v.initializer for v in optimizer.variables()])
    train_fn()
    model_save_path = root.save(file_prefix=checkpoint_prefix)
    self.evaluate(optimizer.variables()[0].assign(42.))
    optimizer_save_path = optimizer_checkpoint.save(optimizer_only_prefix)

  # Restore into a graph with the optimizer
  with test_util.device(use_gpu=True):
    model = MyModel()
    optimizer = adam.AdamOptimizer(0.001)
    root = trackable_utils.Checkpoint(
        optimizer=optimizer, model=model,
        global_step=training_util.get_or_create_global_step())
    status = root.restore(save_path=model_save_path)
    input_value = constant_op.constant([[3.]])
    train_fn = functools.partial(optimizer.minimize,
                                 functools.partial(model, input_value),
                                 global_step=root.global_step)
    if not context.executing_eagerly():
      train_fn = functools.partial(self.evaluate, train_fn())
    status.initialize_or_restore()
    train_fn()
    with self.assertRaises(AssertionError):
      status.assert_existing_objects_matched()
    with self.assertRaises(AssertionError):
      status.assert_consumed()

  # Make sure initialization doesn't clobber later restores
  with test_util.device(use_gpu=True):
    model = MyModel()
    optimizer = adam.AdamOptimizer(0.001, beta1=1.0)
    root = trackable_utils.Checkpoint(
        optimizer=optimizer, model=model,
        global_step=training_util.get_or_create_global_step())
    opt_root = trackable_utils.Checkpoint(optimizer=optimizer)
    status = root.restore(save_path=model_save_path)
    init_only_optimizer_status = opt_root.restore(save_path=None)
    optimizer_status = opt_root.restore(save_path=optimizer_save_path)
    input_value = constant_op.constant([[3.]])
    train_fn = functools.partial(optimizer.minimize,
                                 functools.partial(model, input_value),
                                 global_step=root.global_step)
    if not context.executing_eagerly():
      train_fn = functools.partial(self.evaluate, train_fn())
    optimizer_status.run_restore_ops()
    status.initialize_or_restore()
    init_only_optimizer_status.initialize_or_restore()
    train_fn()
    self.assertEqual(42., self.evaluate(optimizer.variables()[0]))
def proc_func(model_path, checkpoint_dir):
  global_batch_size = per_worker_batch_size * num_workers
  strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()
  with strategy.scope():
    multi_worker_model = build_and_compile_cnn_model()

  callbacks = [
      keras.callbacks.ModelCheckpoint(
          filepath=os.path.join(self.get_temp_dir(), 'checkpoint'))
  ]

  multi_worker_dataset = mnist_dataset(global_batch_size)
  if shard_policy:
    options = dataset_ops.Options()
    options.experimental_distribute.auto_shard_policy = shard_policy
    multi_worker_dataset = multi_worker_dataset.with_options(options)

  multi_worker_model.fit(
      multi_worker_dataset,
      epochs=2,
      steps_per_epoch=20,
      callbacks=callbacks)

  def _is_chief(task_type, task_id):
    return task_type is None or task_type == 'chief' or (
        task_type == 'worker' and task_id == 0)

  def _get_temp_dir(dirpath, task_id):
    base_dirpath = 'workertemp_' + str(task_id)
    temp_dir = os.path.join(dirpath, base_dirpath)
    file_io.recursive_create_dir_v2(temp_dir)
    return temp_dir

  def write_filepath(filepath, task_type, task_id):
    dirpath = os.path.dirname(filepath)
    base = os.path.basename(filepath)
    if not _is_chief(task_type, task_id):
      dirpath = _get_temp_dir(dirpath, task_id)
    return os.path.join(dirpath, base)

  task_type, task_id = (strategy.cluster_resolver.task_type,
                        strategy.cluster_resolver.task_id)
  write_model_path = write_filepath(model_path, task_type, task_id)

  multi_worker_model.save(write_model_path)
  if not _is_chief(task_type, task_id):
    file_io.delete_recursively_v2(os.path.dirname(write_model_path))

  # Make sure chief finishes saving before non-chief's assertions.
  multi_process_runner.barrier().wait()

  if not file_io.file_exists(model_path):
    raise RuntimeError()
  if file_io.file_exists(write_model_path) != _is_chief(task_type, task_id):
    raise RuntimeError()

  loaded_model = keras.saving.save.load_model(model_path)
  loaded_model.fit(multi_worker_dataset, epochs=2, steps_per_epoch=20)

  checkpoint = tracking_util.Checkpoint(model=multi_worker_model)
  write_checkpoint_dir = write_filepath(checkpoint_dir, task_type, task_id)
  checkpoint_manager = checkpoint_management.CheckpointManager(
      checkpoint, directory=write_checkpoint_dir, max_to_keep=1)

  checkpoint_manager.save()
  if not _is_chief(task_type, task_id):
    file_io.delete_recursively_v2(write_checkpoint_dir)

  # Make sure chief finishes saving before non-chief's assertions.
  multi_process_runner.barrier().wait()

  if not file_io.file_exists(checkpoint_dir):
    raise RuntimeError()
  if file_io.file_exists(write_checkpoint_dir) != _is_chief(
      task_type, task_id):
    raise RuntimeError()

  latest_checkpoint = checkpoint_management.latest_checkpoint(checkpoint_dir)
  checkpoint.restore(latest_checkpoint)
  multi_worker_model.fit(multi_worker_dataset, epochs=2, steps_per_epoch=20)

  logging.info('testMultiWorkerTutorial successfully ends')
def _export_estimator(estimator,
                      export_dir,
                      signature_fn,
                      input_fn,
                      default_batch_size,
                      exports_to_keep,
                      input_feature_key=None,
                      use_deprecated_input_fn=True,
                      prediction_key=None,
                      checkpoint_path=None):
  if use_deprecated_input_fn:
    input_fn = input_fn or _default_input_fn
  elif input_fn is None:
    raise ValueError('input_fn must be defined.')

  # If checkpoint_path is specified, use the specified checkpoint path.
  checkpoint_path = (checkpoint_path or
                     checkpoint_management.latest_checkpoint(
                         estimator._model_dir))

  with ops.Graph().as_default() as g:
    training_util.create_global_step(g)

    if use_deprecated_input_fn:
      examples = array_ops.placeholder(dtype=dtypes.string,
                                       shape=[default_batch_size],
                                       name='input_example_tensor')
      features = input_fn(estimator, examples)
    else:
      features, _ = input_fn()
      examples = None
      if input_feature_key is not None:
        examples = features.pop(input_feature_key)

    if (not features) and (examples is None):
      raise ValueError('Either features or examples must be defined.')

    predictions = estimator._get_predict_ops(features).predictions

    if prediction_key is not None:
      predictions = predictions[prediction_key]

    # Explicit signature_fn takes priority
    if signature_fn:
      default_signature, named_graph_signatures = signature_fn(examples,
                                                               features,
                                                               predictions)
    else:
      try:
        # Some estimators provide a signature function.
        # TODO(zakaria): check if the estimator has this function,
        #   raise helpful error if not
        signature_fn = estimator._create_signature_fn()

        default_signature, named_graph_signatures = (
            signature_fn(examples, features, predictions))
      except AttributeError:
        logging.warn(
            'Change warning: `signature_fn` will be required after '
            '2016-08-01.\n'
            'Using generic signatures for now. To maintain this behavior, '
            'pass:\n'
            '  signature_fn=export.generic_signature_fn\n'
            'Also consider passing a regression or classification signature; '
            'see cl/126430915 for an example.')
        default_signature, named_graph_signatures = generic_signature_fn(
            examples, features, predictions)

    if exports_to_keep is not None:
      exports_to_keep = gc.largest_export_versions(exports_to_keep)

    return _export_graph(
        g,
        _get_saver(),
        checkpoint_path,
        export_dir,
        default_graph_signature=default_signature,
        named_graph_signatures=named_graph_signatures,
        exports_to_keep=exports_to_keep)
def _get_checkpoint_filename(filepattern):
  """Returns checkpoint filename given directory or specific filepattern."""
  if gfile.IsDirectory(filepattern):
    return checkpoint_management.latest_checkpoint(filepattern)
  return filepattern
def _get_checkpoint_filename(ckpt_dir_or_file):
  """Returns checkpoint filename given directory or specific checkpoint file."""
  if gfile.IsDirectory(ckpt_dir_or_file):
    return checkpoint_management.latest_checkpoint(ckpt_dir_or_file)
  return ckpt_dir_or_file
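# Usage of _get_checkpoint_filename is symmetric for directories and
# explicit checkpoint prefixes. A small sketch with hypothetical paths:
ckpt = _get_checkpoint_filename('/tmp/model_dir')  # resolves to newest prefix
same = _get_checkpoint_filename('/tmp/model_dir/model.ckpt-42')  # unchanged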
def main(argv):
  tf.logging.set_verbosity(tf.logging.INFO)

  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

  if len(argv) > 1:
    FLAGS.predict_file = argv[1]

  validate_flags_or_throw(bert_config)

  tf.gfile.MakeDirs(FLAGS.output_dir)

  moran_tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file,
      do_lower_case=FLAGS.do_lower_case,
      use_moran=True)
  basic_tokenizer = tokenization.BasicTokenizer(use_moran=False)

  tpu_cluster_resolver = None
  if FLAGS.use_tpu and FLAGS.tpu_name:
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

  is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
  run_config = tf.contrib.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      master=FLAGS.master,
      model_dir=FLAGS.output_dir,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      keep_checkpoint_max=FLAGS.keep_checkpoint_max,
      tpu_config=tf.contrib.tpu.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          num_shards=FLAGS.num_tpu_cores,
          per_host_input_for_training=is_per_host))

  train_examples = None
  num_train_steps = None
  num_warmup_steps = None
  if FLAGS.do_train:
    train_examples = read_squad_examples(
        input_file=FLAGS.train_file, is_training=True)
    num_train_steps = int(
        len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    # Pre-shuffle the input to avoid having to make a very large shuffle
    # buffer in the `input_fn`.
    rng = random.Random(42)
    rng.shuffle(train_examples)

  model_fn = model_fn_builder(
      bert_config=bert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=num_train_steps,
      num_warmup_steps=num_warmup_steps,
      use_tpu=FLAGS.use_tpu,
      use_one_hot_embeddings=FLAGS.use_tpu)

  # If TPU is not available, this will fall back to normal Estimator on CPU
  # or GPU.
  estimator = tf.contrib.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=FLAGS.train_batch_size,
      predict_batch_size=FLAGS.predict_batch_size)

  if FLAGS.do_train:
    train_record_exists = False
    train_writer = FeatureWriter(
        filename=os.path.join(FLAGS.output_dir, "train.tf_record"),
        is_training=True,
        record_file_exists=train_record_exists)
    convert_examples_to_features(
        examples=train_examples,
        tokenizer=moran_tokenizer,
        max_seq_length=FLAGS.max_seq_length,
        doc_stride=FLAGS.doc_stride,
        max_query_length=FLAGS.max_query_length,
        is_training=True,
        output_fn=train_writer.process_feature)
    train_writer.close()

    tf.logging.info("***** Running training *****")
    tf.logging.info("  Num orig examples = %d", len(train_examples))
    tf.logging.info("  Num split examples = %d", train_writer.num_features)
    tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
    tf.logging.info("  Num steps = %d", num_train_steps)
    del train_examples

    train_input_fn = input_fn_builder(
        input_file=train_writer.filename,
        seq_length=FLAGS.max_seq_length,
        is_training=True,
        drop_remainder=True)
    estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

  if FLAGS.do_predict:
    output_prediction_file_name = "predictions.json"
    output_nbest_file_name = "nbest_predictions.json"
    output_null_log_odds_file_name = "null_odds.json"
    if FLAGS.korquad_refine_answer_by_pos:
      output_prediction_file_name = "predictions_pos.json"

  if FLAGS.do_predict:
    eval_examples = read_squad_examples(
        input_file=FLAGS.predict_file, is_training=False)

    eval_record_exists = os.path.exists(
        os.path.join(FLAGS.output_dir, "eval.tf_record"))
    eval_record_exists = False
    if eval_record_exists:
      tf.logging.info("eval.tf_record exists. Do not write tf example file.")

    eval_writer = FeatureWriter(
        filename=os.path.join(FLAGS.output_dir, "eval.tf_record"),
        is_training=False,
        record_file_exists=eval_record_exists)
    eval_features = []

    def append_feature(feature):
      eval_features.append(feature)
      eval_writer.process_feature(feature)

    convert_examples_to_features(
        examples=eval_examples,
        tokenizer=moran_tokenizer,
        max_seq_length=FLAGS.max_seq_length,
        doc_stride=FLAGS.doc_stride,
        max_query_length=FLAGS.max_query_length,
        is_training=False,
        output_fn=append_feature)
    eval_writer.close()

    tf.logging.info("***** Running predictions *****")
    tf.logging.info("  Num orig examples = %d", len(eval_examples))
    tf.logging.info("  Num split examples = %d", len(eval_features))
    tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

    all_results = []

    predict_input_fn = input_fn_builder(
        input_file=eval_writer.filename,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=False)

    if FLAGS.do_train:
      init_checkpoint = None
    else:
      if FLAGS.init_checkpoint is not None and tf.gfile.IsDirectory(
          FLAGS.init_checkpoint):
        from tensorflow.python.training import checkpoint_management
        init_checkpoint = checkpoint_management.latest_checkpoint(
            FLAGS.init_checkpoint)
      else:
        init_checkpoint = FLAGS.init_checkpoint

    all_results = []
    for result in estimator.predict(
        predict_input_fn,
        yield_single_examples=True,
        checkpoint_path=init_checkpoint):
      unique_id = int(result["unique_ids"])
      start_logits = [float(x) for x in result["start_logits"].flat]
      end_logits = [float(x) for x in result["end_logits"].flat]
      all_results.append(
          RawResult(
              unique_id=unique_id,
              start_logits=start_logits,
              end_logits=end_logits))

    if len(argv) > 1:
      output_prediction_file = os.path.join(FLAGS.output_dir, argv[2])
    else:
      output_prediction_file = os.path.join(FLAGS.output_dir,
                                            output_prediction_file_name)
    output_nbest_file = os.path.join(FLAGS.output_dir, output_nbest_file_name)
    output_null_log_odds_file = os.path.join(FLAGS.output_dir,
                                             output_null_log_odds_file_name)

    write_predictions(eval_examples, eval_features, all_results,
                      FLAGS.n_best_size, FLAGS.max_answer_length,
                      FLAGS.do_lower_case, output_prediction_file,
                      output_nbest_file, output_null_log_odds_file,
                      basic_tokenizer)
def test_initialize_if_not_restoring(self):
  checkpoint_directory = self.get_temp_dir()
  checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
  optimizer_only_prefix = os.path.join(checkpoint_directory, "opt")
  with test_util.device(use_gpu=True):
    model = MyModel()
    optimizer = adam.AdamOptimizer(0.001)
    root = checkpointable_utils.Checkpoint(
        model=model,  # Do not save the optimizer with the checkpoint.
        global_step=training_util.get_or_create_global_step())
    optimizer_checkpoint = checkpointable_utils.Checkpoint(
        optimizer=optimizer)

    checkpoint_path = checkpoint_management.latest_checkpoint(
        checkpoint_directory)
    status = root.restore(save_path=checkpoint_path)
    input_value = constant_op.constant([[3.]])
    train_fn = functools.partial(
        optimizer.minimize,
        functools.partial(model, input_value),
        global_step=root.global_step)
    if not context.executing_eagerly():
      train_fn = functools.partial(self.evaluate, train_fn())
    status.initialize_or_restore()
    self.evaluate([v.initializer for v in optimizer.variables()])
    train_fn()
    model_save_path = root.save(file_prefix=checkpoint_prefix)
    self.evaluate(optimizer.variables()[0].assign(42.))
    optimizer_save_path = optimizer_checkpoint.save(optimizer_only_prefix)

  # Restore into a graph with the optimizer
  with test_util.device(use_gpu=True):
    model = MyModel()
    optimizer = adam.AdamOptimizer(0.001)
    root = checkpointable_utils.Checkpoint(
        optimizer=optimizer, model=model,
        global_step=training_util.get_or_create_global_step())
    status = root.restore(save_path=model_save_path)
    input_value = constant_op.constant([[3.]])
    train_fn = functools.partial(
        optimizer.minimize,
        functools.partial(model, input_value),
        global_step=root.global_step)
    if not context.executing_eagerly():
      train_fn = functools.partial(self.evaluate, train_fn())
    status.initialize_or_restore()
    train_fn()
    with self.assertRaises(AssertionError):
      status.assert_existing_objects_matched()
    with self.assertRaises(AssertionError):
      status.assert_consumed()

  # Make sure initialization doesn't clobber later restores
  with test_util.device(use_gpu=True):
    model = MyModel()
    optimizer = adam.AdamOptimizer(0.001, beta1=1.0)
    root = checkpointable_utils.Checkpoint(
        optimizer=optimizer, model=model,
        global_step=training_util.get_or_create_global_step())
    opt_root = checkpointable_utils.Checkpoint(
        optimizer=optimizer)
    status = root.restore(save_path=model_save_path)
    init_only_optimizer_status = opt_root.restore(save_path=None)
    optimizer_status = opt_root.restore(save_path=optimizer_save_path)
    input_value = constant_op.constant([[3.]])
    train_fn = functools.partial(
        optimizer.minimize,
        functools.partial(model, input_value),
        global_step=root.global_step)
    if not context.executing_eagerly():
      train_fn = functools.partial(self.evaluate, train_fn())
    optimizer_status.run_restore_ops()
    status.initialize_or_restore()
    init_only_optimizer_status.initialize_or_restore()
    train_fn()
    self.assertEqual(42., self.evaluate(optimizer.variables()[0]))
def continuous_train_and_eval(self, continuous_eval_predicate_fn=None):
  """Interleaves training and evaluation.

  The frequency of evaluation is controlled by the
  `train_steps_per_iteration` (via constructor). The model will be first
  trained for `train_steps_per_iteration`, and then be evaluated in turns.

  This method is intended for single machine usage.

  This differs from `train_and_evaluate` as follows:

    1. The procedure will have train and evaluation in turns. The model
    will be trained for a number of steps (usually smaller than
    `train_steps` if provided) and then be evaluated. `train_and_evaluate`
    will train the model for `train_steps` (no small training iterations).

    2. Due to the different approach this schedule takes, it leads to two
    differences in resource control. First, the resources (e.g., memory)
    used by training will be released before evaluation
    (`train_and_evaluate` takes double resources). Second, more checkpoints
    will be saved as a checkpoint is generated at the end of each training
    iteration.

    3. As estimator.train starts from scratch (new graph, new states for
    input, etc.) at each iteration, it is recommended to use a larger
    `train_steps_per_iteration`. It is also recommended to shuffle your
    input.

  Args:
    continuous_eval_predicate_fn: A predicate function determining whether
      to continue eval after each iteration. A `predicate_fn` has one of
      the following signatures:
        * (eval_results) -> boolean
        * (eval_results, checkpoint_path) -> boolean
      Where `eval_results` is the dictionary of metric evaluations and
      `checkpoint_path` is the path to the checkpoint containing the
      parameters on which that evaluation was based.
      At the beginning of evaluation, the passed `eval_results` and
      `checkpoint_path` will be None so it's expected that the predicate
      function handles that gracefully.
      When `predicate_fn` is not specified, continuous eval will run in an
      infinite loop if `train_steps` is None, or exit once global step
      reaches `train_steps`.

  Returns:
    A tuple of the result of the `evaluate` call to the `Estimator` and the
    export results using the specified `ExportStrategy`.

  Raises:
    ValueError: if `continuous_eval_predicate_fn` is neither None nor
      callable.
  """
  if continuous_eval_predicate_fn is not None:
    if not callable(continuous_eval_predicate_fn):
      raise ValueError(
          "`continuous_eval_predicate_fn` must be a callable, or None.")
    predicate_fn = _get_standardized_predicate_fn(
        continuous_eval_predicate_fn)
  else:
    predicate_fn = None

  export_results = None
  latest_checkpoint = None
  eval_result = None

  # Set the default value for train_steps_per_iteration, which will be
  # overridden by other settings.
  train_steps_per_iteration = 1000
  if self._train_steps_per_iteration is not None:
    train_steps_per_iteration = self._train_steps_per_iteration
  elif self._train_steps is not None:
    train_steps_per_iteration = int(self._train_steps / 10)

  while (not predicate_fn or predicate_fn(
      eval_result,
      checkpoint_path=latest_checkpoint if eval_result else None)):
    if self._has_training_stopped(eval_result):
      # Exits once max steps of training is satisfied.
      logging.info("Stop training model as max steps reached")
      break

    logging.info("Training model for %s steps", train_steps_per_iteration)
    self._call_train(
        input_fn=self._train_input_fn,
        steps=train_steps_per_iteration,
        hooks=self._train_monitors,
        saving_listeners=self._saving_listeners)

    logging.info("Evaluating model now.")
    latest_checkpoint = checkpoint_management.latest_checkpoint(
        self._estimator.model_dir)
    eval_result = self._call_evaluate(
        input_fn=self._eval_input_fn,
        steps=self._eval_steps,
        metrics=self._eval_metrics,
        name="one_pass",
        checkpoint_path=latest_checkpoint,
        hooks=self._eval_hooks)
    export_results = self._maybe_export(eval_result)

  return eval_result, export_results
  def _get_most_recently_modified_file_matching_pattern(self, pattern):
    """Returns the most recently modified filepath matching pattern.

    The pattern may contain python formatting placeholders. If
    `tf.train.latest_checkpoint()` does not return None, use that; otherwise,
    check for the most recently modified file that matches the pattern.

    In the rare case where more than one pattern-matching file shares the most
    recent modified time, return the filepath that is largest (by the `>`
    operator, lexicographically using the numeric equivalents). This provides
    a tie-breaker when multiple files are most recent. Note that a larger
    `filepath` can sometimes indicate a later time of modification (for
    instance, when epoch/batch is used as a formatting option), but not
    necessarily (when accuracy or loss is used). The tie-breaker is a best
    effort to return the most recent file and to avoid a nondeterministic
    result.

    Modified time of a file is obtained with `os.path.getmtime()`.

    This utility function is best demonstrated via an example:

    ```python
    file_pattern = 'f.batch{batch:02d}epoch{epoch:02d}.h5'
    test_dir = self.get_temp_dir()
    path_pattern = os.path.join(test_dir, file_pattern)
    file_paths = [
        os.path.join(test_dir, file_name) for file_name in
        ['f.batch03epoch02.h5', 'f.batch02epoch02.h5', 'f.batch01epoch01.h5']
    ]
    for file_path in file_paths:
      # Write something to each of the files
      ...
    self.assertEqual(
        _get_most_recently_modified_file_matching_pattern(path_pattern),
        file_paths[-1])
    ```

    Arguments:
      pattern: The file pattern that may optionally contain python
        placeholders such as `{epoch:02d}`.

    Returns:
      The most recently modified file's full filepath matching `pattern`. If
      `pattern` does not contain any placeholder, this returns the filepath
      that exactly matches `pattern`. Returns `None` if no match is found.
    """
    dir_name = os.path.dirname(pattern)
    base_name = os.path.basename(pattern)
    base_name_regex = '^' + re.sub(r'{.*}', r'.*', base_name) + '$'

    # If tf.train.latest_checkpoint tells us there exists a latest checkpoint,
    # use that as it is more robust than `os.path.getmtime()`.
    latest_tf_checkpoint = checkpoint_management.latest_checkpoint(dir_name)
    if latest_tf_checkpoint is not None and re.match(
        base_name_regex, os.path.basename(latest_tf_checkpoint)):
      return latest_tf_checkpoint

    latest_mod_time = 0
    file_path_with_latest_mod_time = None
    n_file_with_latest_mod_time = 0
    file_path_with_largest_file_name = None

    if os.path.exists(dir_name):
      for file_name in os.listdir(dir_name):
        # Only consider if `file_name` matches the pattern.
        if re.match(base_name_regex, file_name):
          file_path = os.path.join(dir_name, file_name)
          mod_time = os.path.getmtime(file_path)
          if (file_path_with_largest_file_name is None or
              file_path > file_path_with_largest_file_name):
            file_path_with_largest_file_name = file_path
          if mod_time > latest_mod_time:
            latest_mod_time = mod_time
            file_path_with_latest_mod_time = file_path
            # In the case a file with later modified time is found, reset
            # the counter for the number of files with latest modified time.
            n_file_with_latest_mod_time = 1
          elif mod_time == latest_mod_time:
            # In the case a file has modified time tied with the most recent,
            # increment the counter for the number of files with latest
            # modified time by 1.
            n_file_with_latest_mod_time += 1

    if n_file_with_latest_mod_time == 1:
      # Return the sole file that has the most recent modified time.
      return file_path_with_latest_mod_time
    else:
      # If more than one file shares the latest modified time, return the
      # file path with the largest file name.
      return file_path_with_largest_file_name
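# A minimal, self-contained sketch of the pattern-to-regex conversion used
# above. Note that the greedy `{.*}` collapses everything from the first `{`
# to the last `}` into a single `.*`, and that literal dots in the pattern are
# left unescaped, so the resulting match is loose. The pattern string below is
# illustrative.
import re

base_name = 'w.{epoch:02d}-{val_loss:.2f}.h5'
base_name_regex = '^' + re.sub(r'{.*}', r'.*', base_name) + '$'  # '^w..*.h5$'

assert re.match(base_name_regex, 'w.03-0.25.h5')
assert not re.match(base_name_regex, 'other.03-0.25.h5')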
  def _continuous_eval(self,
                       input_fn,
                       name,
                       delay_secs,
                       throttle_delay_secs,
                       evaluate_checkpoint_only_once=True,
                       continuous_eval_predicate_fn=None,
                       export=True):
    """Run continuous eval.

    Runs infinite eval on the evaluation data set. This function starts
    evaluating after `delay_secs` seconds and then runs no more than one
    evaluation (with `self._eval_steps` steps each time) per
    `throttle_delay_secs`. If `train_steps` is not None, will return after
    global_step reaches `train_steps`.

    Args:
      input_fn: The input to use for this eval.
      name: A string appended to the folder name of evaluation results.
      delay_secs: Start evaluating after this many seconds. If None, defaults
        to `self._eval_delay_secs`.
      throttle_delay_secs: Do not re-evaluate unless the last evaluation was
        started at least this many seconds ago. If None, defaults to
        `self._continuous_eval_throttle_secs`.
      evaluate_checkpoint_only_once: Whether to skip evaluation of checkpoints
        that have already been evaluated. Default is `True`.
      continuous_eval_predicate_fn: A predicate function determining whether to
        continue eval after each iteration. A `predicate_fn` has one of the
        following signatures:
          * (eval_results) -> boolean
          * (eval_results, checkpoint_path) -> boolean
        where `eval_results` is the dictionary of metric evaluations and
        `checkpoint_path` is the path to the checkpoint containing the
        parameters on which that evaluation was based. At the beginning of
        evaluation, the passed `eval_results` will be None, so it's expected
        that the predicate function handles that gracefully.

        Continuous eval behavior under different conditions:
          * When `predicate_fn` is specified:
            + If `train_steps` is None, run until `predicate_fn` returns
              False.
            + If `train_steps` is specified, run until either global step
              reaches `train_steps` or `predicate_fn` returns False.
          * When `predicate_fn` is not specified:
            + If `train_steps` is None, run in an infinite loop.
            + If `train_steps` is specified, run until global step reaches
              `train_steps`.
      export: Whether to export from this step. Default is `True`.

    Raises:
      ValueError: if `continuous_eval_predicate_fn` is neither None nor
        callable.
    """
    if continuous_eval_predicate_fn is not None:
      if not callable(continuous_eval_predicate_fn):
        raise ValueError(
            "`continuous_eval_predicate_fn` must be a callable, or None.")
      predicate_fn = _get_standardized_predicate_fn(
          continuous_eval_predicate_fn)
    else:
      predicate_fn = None

    if delay_secs is None:
      delay_secs = self._eval_delay_secs
    if throttle_delay_secs is None:
      throttle_delay_secs = self._continuous_eval_throttle_secs

    if delay_secs:
      logging.info("Waiting %f secs before starting eval.", delay_secs)
      time.sleep(delay_secs)

    previous_path = None
    eval_result = None
    last_warning_time = 0
    while (not predicate_fn or
           predicate_fn(eval_result, checkpoint_path=previous_path)):
      # Exit if we have already reached the number of steps to train.
      if self._has_training_stopped(eval_result):
        logging.info("Exiting continuous eval, global_step=%s >= "
                     "train_step=%s",
                     eval_result[ops.GraphKeys.GLOBAL_STEP],
                     self._train_steps)
        return

      start = time.time()

      error_msg = None
      latest_path = checkpoint_management.latest_checkpoint(
          self._estimator.model_dir)
      if not latest_path:
        error_msg = ("Estimator is not fitted yet. "
                     "Will start an evaluation when a checkpoint is ready.")
      elif evaluate_checkpoint_only_once and latest_path == previous_path:
        error_msg = "No new checkpoint ready for evaluation."

      if error_msg:
        # Print the warning message no more than every 10 mins.
        eval_result = {}
        if time.time() - last_warning_time > 600:
          logging.warning(error_msg)
          last_warning_time = time.time()
      else:
        eval_result = self._call_evaluate(
            input_fn=input_fn,
            steps=self._eval_steps,
            metrics=self._eval_metrics,
            name=name,
            checkpoint_path=latest_path,
            hooks=self._eval_hooks)
        # Ensure eval result is not None for the next round of evaluation.
        if not eval_result:
          eval_result = {}

        if export:
          self._maybe_export(eval_result, checkpoint_path=latest_path)

        # Clear the warning timer and update the last evaluated checkpoint.
        last_warning_time = 0
        previous_path = latest_path

      duration = time.time() - start
      if duration < throttle_delay_secs:
        difference = throttle_delay_secs - duration
        logging.info("Waiting %f secs before starting next eval run.",
                     difference)
        time.sleep(difference)
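# An illustrative `continuous_eval_predicate_fn` with the two-argument
# signature described in the docstring above. The metric key ('loss') and the
# 0.1 threshold are assumptions made for the sake of the example.
def stop_below_loss(eval_results, checkpoint_path=None):
  """Continue evaluating until eval loss drops below 0.1 (illustrative)."""
  del checkpoint_path  # This predicate does not inspect the checkpoint.
  if eval_results is None:
    # No evaluation has run yet (first iteration); keep going.
    return True
  return eval_results.get('loss', float('inf')) >= 0.1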
  def testTrainWithInitFromCheckpoint(self):
    logdir1 = os.path.join(self.get_temp_dir(), 'tmp_logs1/')
    logdir2 = os.path.join(self.get_temp_dir(), 'tmp_logs2/')

    if gfile.Exists(logdir1):  # For running on jenkins.
      gfile.DeleteRecursively(logdir1)
    if gfile.Exists(logdir2):  # For running on jenkins.
      gfile.DeleteRecursively(logdir2)

    # First, train the model one step (make sure the error is high).
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      train_op = self.create_train_op()
      saver = saver_lib.Saver()
      loss = training.train(
          train_op,
          logdir1,
          hooks=[
              basic_session_run_hooks.CheckpointSaverHook(
                  logdir1, save_steps=1, saver=saver),
              basic_session_run_hooks.StopAtStepHook(num_steps=1),
          ],
          save_checkpoint_secs=None,
          save_summaries_steps=None)
      self.assertGreater(loss, .5)

    # Next, train the model to convergence.
    with ops.Graph().as_default():
      random_seed.set_random_seed(1)
      train_op = self.create_train_op()
      saver = saver_lib.Saver()
      loss = training.train(
          train_op,
          logdir1,
          hooks=[
              basic_session_run_hooks.CheckpointSaverHook(
                  logdir1, save_steps=300, saver=saver),
              basic_session_run_hooks.StopAtStepHook(num_steps=300),
          ],
          save_checkpoint_secs=None,
          save_summaries_steps=None)
      self.assertIsNotNone(loss)
      self.assertLess(loss, .02)

    # Finally, advance the model a single step and validate that the loss is
    # still low.
    with ops.Graph().as_default():
      random_seed.set_random_seed(2)
      train_op = self.create_train_op()

      model_variables = variables_lib2.global_variables()
      model_path = checkpoint_management.latest_checkpoint(logdir1)
      assign_fn = variables_lib.assign_from_checkpoint_fn(
          model_path, model_variables)

      def init_fn(_, session):
        assign_fn(session)

      loss = training.train(
          train_op,
          None,
          scaffold=monitored_session.Scaffold(init_fn=init_fn),
          hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=1)],
          save_checkpoint_secs=None,
          save_summaries_steps=None)
      self.assertIsNotNone(loss)
      self.assertLess(loss, .02)
import re
import sys

import tensorflow as tf

from tensorflow.python.training import checkpoint_management

if len(sys.argv) < 2:
    print("Usage: fix_embed_mask_name.py [chkpt-dir]")
    print("Fix the latest checkpoint in the directory by renaming the "
          "embedding mask variables.")
    print("Note: overwrites original checkpoint!")
    sys.exit(1)

CHECKPOINT_DIR = sys.argv[1].rstrip('/')
CHECKPOINT_FILE = checkpoint_management.latest_checkpoint(CHECKPOINT_DIR)

with tf.Session() as sess:
    # Load all the variables from the checkpoint, renaming as we go.
    for var_name, _ in tf.train.list_variables(CHECKPOINT_FILE):
        var_tensor = tf.contrib.framework.load_variable(CHECKPOINT_FILE,
                                                        var_name)
        new_name = re.sub(r'bert/embeddings//mask',
                          r'bert/embeddings/embed_mask/mask', var_name)
        new_name = re.sub(r'bert/embeddings//threshold',
                          r'bert/embeddings/embed_mask/threshold', new_name)
        if new_name != var_name:
            print(f"Renaming {var_name} to {new_name}")
        # Assumption: finish with the standard rename-and-save recipe --
        # re-create each variable under its (possibly renamed) name, then
        # save the graph's variables over the original checkpoint prefix.
        tf.Variable(var_tensor, name=new_name)

    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    saver.save(sess, CHECKPOINT_FILE)
def _get_checkpoint_filename(filepattern):
  """Returns checkpoint filename given directory or specific filepattern."""
  if gfile.IsDirectory(filepattern):
    return checkpoint_management.latest_checkpoint(filepattern)
  return filepattern
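# Usage is symmetric for directories and explicit checkpoint prefixes; the
# paths below are hypothetical.
_get_checkpoint_filename('/tmp/model_dir')
# -> latest prefix in the directory, e.g. '/tmp/model_dir/model.ckpt-1000'
_get_checkpoint_filename('/tmp/model_dir/model.ckpt-500')
# -> returned as-is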