def _get_train_op_and_ensemble(self, head, config, is_classification,
                               train_in_memory):
  """Calls _bt_model_fn() in TRAIN mode and exposes its train/serialize ops.

  Args:
    head: Forwarded to `_bt_model_fn` as `head`.
    config: Forwarded to `_bt_model_fn` as `config`.
    is_classification: Forwarded to `_make_train_input_fn` to pick the input.
    train_in_memory: Forwarded to `_bt_model_fn` as `train_in_memory`.

  Returns:
    A `(train_op, ensemble_serialized)` pair; `ensemble_serialized` is the
    second output of the serialize-ensemble op, created under a control
    dependency on `train_op`.
  """
  input_fn = _make_train_input_fn(is_classification)
  features, labels = input_fn()
  spec = boosted_trees._bt_model_fn(  # pylint:disable=protected-access
      features=features,
      labels=labels,
      mode=model_fn.ModeKeys.TRAIN,
      head=head,
      feature_columns=self._feature_columns,
      tree_hparams=self._tree_hparams,
      example_id_column_name=EXAMPLE_ID_COLUMN,
      n_batches_per_layer=1,
      config=config,
      train_in_memory=train_in_memory)

  # Shared resources are initialized first, then global and local variables.
  resources.initialize_resources(resources.shared_resources()).run()
  variables.global_variables_initializer().run()
  variables.local_variables_initializer().run()

  # Exactly one shared resource is expected; its handle is what gets
  # serialized below.
  ensemble_resources = resources.shared_resources()
  self.assertEqual(1, len(ensemble_resources))
  train_op = spec.train_op
  with ops.control_dependencies([train_op]):
    serialize_outputs = gen_boosted_trees_ops.boosted_trees_serialize_ensemble(
        ensemble_resources[0].handle)
    _, ensemble_serialized = serialize_outputs
  return train_op, ensemble_serialized
def test_sync_replicas(self, create_gan_model_fn, create_global_step):
  """Checks `train.gan_train_ops` when both optimizers are sync optimizers.

  Args:
    create_gan_model_fn: Zero-argument callable returning the model handed to
      `train.gan_loss` and `train.gan_train_ops`.
    create_global_step: If truthy, a custom non-trainable int32 variable is
      added to the GLOBAL_STEP collection before the train ops are built.
  """
  model = create_gan_model_fn()
  loss = train.gan_loss(model)
  # Snapshot the count so we can verify that building train ops adds nothing.
  num_trainable_vars = len(variables_lib.get_trainable_variables())

  if create_global_step:
    gstep = variable_scope.get_variable(
        'custom_gstep', dtype=dtypes.int32, initializer=0, trainable=False)
    ops.add_to_collection(ops.GraphKeys.GLOBAL_STEP, gstep)

  g_opt = get_sync_optimizer()
  d_opt = get_sync_optimizer()
  train_ops = train.gan_train_ops(
      model, loss, generator_optimizer=g_opt, discriminator_optimizer=d_opt)
  self.assertIsInstance(train_ops, namedtuples.GANTrainOps)
  # No new trainable variables should have been added.
  self.assertLen(variables_lib.get_trainable_variables(), num_trainable_vars)

  # Sync hooks should be populated in the GANTrainOps.
  self.assertLen(train_ops.train_hooks, 2)
  for hook in train_ops.train_hooks:
    self.assertIsInstance(
        hook, sync_replicas_optimizer._SyncReplicasOptimizerHook)
  # The two hooks must wrap exactly the two optimizers created above.
  sync_opts = [hook._sync_optimizer for hook in train_ops.train_hooks]
  self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt)))

  g_sync_init_op = g_opt.get_init_tokens_op(num_tokens=1)
  d_sync_init_op = d_opt.get_init_tokens_op(num_tokens=1)

  # Check that update op is run properly.
  global_step = training_util.get_or_create_global_step()
  with self.test_session(use_gpu=True) as sess:
    variables.global_variables_initializer().run()
    variables.local_variables_initializer().run()
    g_opt.chief_init_op.run()
    d_opt.chief_init_op.run()

    gstep_before = global_step.eval()

    # Start required queue runner for SyncReplicasOptimizer.
    coord = coordinator.Coordinator()
    g_threads = g_opt.get_chief_queue_runner().create_threads(sess, coord)
    d_threads = d_opt.get_chief_queue_runner().create_threads(sess, coord)

    g_sync_init_op.run()
    d_sync_init_op.run()

    train_ops.generator_train_op.eval()
    # Check that global step wasn't incremented.
    self.assertEqual(gstep_before, global_step.eval())

    train_ops.discriminator_train_op.eval()
    # Check that global step wasn't incremented.
    self.assertEqual(gstep_before, global_step.eval())

    # Stop the queue-runner threads started above and wait for them.
    coord.request_stop()
    coord.join(g_threads + d_threads)
def test_empty_labels_and_scores_gives_nan_auc(self):
  """AUC built from zero-length labels and scores must evaluate to NaN."""
  with self.cached_session():
    empty_labels = constant_op.constant([], shape=[0], dtype=dtypes.bool)
    empty_scores = constant_op.constant([], shape=[0], dtype=dtypes.float32)
    unit_range = [0, 1.]
    auc, update_op = histogram_ops.auc_using_histogram(
        empty_labels, empty_scores, unit_range)
    variables.local_variables_initializer().run()
    update_op.run()
    auc_value = auc.eval()
    self.assertTrue(np.isnan(auc_value))
def testAccuracy(self):
  """`eval_metrics._accuracy` yields ~0.6 for the fixed inputs below."""
  predicted = constant_op.constant([0, 1, 3, 6, 5, 2, 7, 6, 4, 9])
  expected = constant_op.constant([0, 1, 4, 6, 5, 1, 7, 5, 4, 8])
  metric_pair = eval_metrics._accuracy(predicted, expected)
  accuracy_op, update_op = metric_pair
  with self.test_session():
    variables.local_variables_initializer().run()
    # Streaming metric: the update op has to run before the value op is read.
    update_op.eval()
    accuracy = accuracy_op.eval()
    self.assertNear(0.6, accuracy, 0.0001)
def testMetricsCollection(self):
  """Exports and re-imports a meta graph containing a `metrics.mean` op.

  Three phases, each in a fresh graph: (1) build a queue-fed mean metric, run
  it once and export the scoped meta graph; (2) import that export and build
  the local-variables initializer; (3) import a checked-in older meta graph
  and verify that building the initializer raises AttributeError.
  """

  def _enqueue_vector(sess, queue, values, shape=None):
    # Enqueues `values` as one constant of the queue's first dtype; the shape
    # defaults to a single row of len(values) columns.
    if not shape:
      shape = (1, len(values))
    dtype = queue.dtypes[0]
    sess.run(
        queue.enqueue(constant_op.constant(
            values, dtype=dtype, shape=shape)))

  meta_graph_filename = os.path.join(
      _TestDir("metrics_export"), "meta_graph.pb")

  graph = ops.Graph()
  with self.session(graph=graph) as sess:
    values_queue = data_flow_ops.FIFOQueue(
        4, dtypes.float32, shapes=(1, 2))
    _enqueue_vector(sess, values_queue, [0, 1])
    _enqueue_vector(sess, values_queue, [-4.2, 9.1])
    _enqueue_vector(sess, values_queue, [6.5, 0])
    _enqueue_vector(sess, values_queue, [-3.2, 4.0])
    values = values_queue.dequeue()

    _, update_op = metrics.mean(values)

    initializer = variables.local_variables_initializer()
    self.evaluate(initializer)
    self.evaluate(update_op)

  meta_graph.export_scoped_meta_graph(
      filename=meta_graph_filename, graph=graph)

  # Verifies that importing a meta_graph with LOCAL_VARIABLES collection
  # works correctly.
  graph = ops.Graph()
  with self.session(graph=graph) as sess:
    meta_graph.import_scoped_meta_graph(meta_graph_filename)
    initializer = variables.local_variables_initializer()
    self.evaluate(initializer)

  # Verifies that importing an old meta_graph where "local_variables"
  # collection is of node_list type works, but cannot build initializer
  # with the collection.
  graph = ops.Graph()
  with self.session(graph=graph) as sess:
    meta_graph.import_scoped_meta_graph(
        test.test_src_dir_path(
            "python/framework/testdata/metrics_export_meta_graph.pb"))
    self.assertEqual(len(ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES)),
                     2)
    with self.assertRaisesRegexp(
        AttributeError, "'Tensor' object has no attribute 'initializer'"):
      # The assignment is never completed; the call itself is expected to
      # raise.
      initializer = variables.local_variables_initializer()
def _check_auc(self,
               nbins=100,
               desired_auc=0.75,
               score_range=None,
               num_records=50,
               frac_true=0.5,
               atol=0.05,
               num_updates=10):
  """Check auc accuracy against synthetic data.

  Args:
    nbins:  nbins arg from contrib.metrics.auc_using_histogram.
    desired_auc:  Number in [0, 1].  The desired auc for synthetic data.
    score_range:  2-tuple, (low, high), giving the range of the resultant
      scores.  Defaults to [0, 1.].
    num_records:  Positive integer.  The number of records to return.
    frac_true:  Number in (0, 1).  Expected fraction of resultant labels that
      will be True.  This is just in expectation...more or less may actually
      be True.
    atol:  Absolute tolerance for final AUC estimate.
    num_updates:  Update internal histograms this many times, each with a new
      batch of synthetic data, before computing final AUC.

  Raises:
    AssertionError: If resultant AUC is not within atol of theoretical AUC
      from synthetic data.
  """
  # The previous `[0, 1.] or score_range` always evaluated to [0, 1.] (a
  # non-empty list is truthy), silently discarding the caller's range.
  if score_range is None:
    score_range = [0, 1.]
  with self.cached_session():
    labels = array_ops.placeholder(dtypes.bool, shape=[num_records])
    scores = array_ops.placeholder(dtypes.float32, shape=[num_records])
    auc, update_op = histogram_ops.auc_using_histogram(
        labels, scores, score_range, nbins=nbins)
    variables.local_variables_initializer().run()
    # Updates, then extract auc.
    for _ in range(num_updates):
      labels_a, scores_a = synthetic_data(desired_auc, score_range,
                                          num_records, self.rng, frac_true)
      update_op.run(feed_dict={labels: labels_a, scores: scores_a})
    # NOTE(review): this extra batch is never fed to the graph; it is kept so
    # that the number of draws taken from self.rng stays unchanged.
    labels_a, scores_a = synthetic_data(desired_auc, score_range,
                                        num_records, self.rng, frac_true)
    # Fetch current auc, and verify that fetching again doesn't change it.
    auc_eval = auc.eval()
    self.assertAlmostEqual(auc_eval, auc.eval(), places=5)

  msg = ('nbins: %s, desired_auc: %s, score_range: %s, '
         'num_records: %s, frac_true: %s, num_updates: %s') % (nbins,
                                                               desired_auc,
                                                               score_range,
                                                               num_records,
                                                               frac_true,
                                                               num_updates)
  np.testing.assert_allclose(desired_auc, auc_eval, atol=atol, err_msg=msg)
def testTop2(self):
  """The metric from `_top_k_generator(2)` yields ~0.5 on the fixed inputs."""
  class_probabilities = constant_op.constant([[0.1, 0.2, 0.3],
                                              [0.4, 0.7, 0.5],
                                              [0.9, 0.8, 0.2],
                                              [0.6, 0.4, 0.8]])
  expected_classes = constant_op.constant([[0], [2], [1], [1]])
  metric_fn = eval_metrics._top_k_generator(2)
  in_top_2_op, update_op = metric_fn(class_probabilities, expected_classes)
  with self.test_session():
    # The metric keeps its state in local variables.
    variables.local_variables_initializer().run()
    # Streaming metric: the update op has to run before the value op is read.
    update_op.eval()
    in_top_2 = in_top_2_op.eval()
    self.assertNear(0.5, in_top_2, 0.0001)
def testR2(self):
  """`eval_metrics._r2` yields ~0.813583 for the fixed inputs below."""
  predicted_scores = constant_op.constant(
      [1.2, 3.9, 2.1, 0.9, 2.2, 0.1, 6.0, 4.0, 0.9])
  true_values = constant_op.constant(
      [1.0, 4.3, 2.6, 0.5, 1.1, 0.7, 5.1, 3.4, 1.8])
  metric_pair = eval_metrics._r2(predicted_scores, true_values)
  r2_op, update_op = metric_pair
  with self.test_session():
    # The metric keeps its state in local variables.
    variables.local_variables_initializer().run()
    # Streaming metric: the update op has to run before the value op is read.
    update_op.eval()
    r2_value = r2_op.eval()
    self.assertNear(0.813583, r2_value, 0.0001)
def testLargeCase(self):
  """Checks tp + fp at threshold 0 after many updates on large random input.

  Each update feeds a freshly sampled [32, 512, 256, 1] batch; after all
  updates the true-positive plus false-positive count at the first threshold
  must match the total number of elements seen, within 1.
  """
  shape = [32, 512, 256, 1]
  predictions = random_ops.random_uniform(
      shape, 0.0, 1.0, dtype=dtypes_lib.float32)
  labels = math_ops.greater(random_ops.random_uniform(shape, 0.0, 1.0), 0.5)

  result, update_op = metric_ops.precision_recall_at_equal_thresholds(
      labels=labels, predictions=predictions, num_thresholds=201)
  # Run many updates, enough to cause highly inaccurate values if the
  # code used float32 for accumulation.
  num_updates = 71

  with self.test_session() as sess:
    sess.run(variables.local_variables_initializer())
    for _ in xrange(num_updates):
      sess.run(update_op)

    prdata = sess.run(result)

    # Since we use random values, we won't know the tp/fp/tn/fn values, but
    # tp and fp at threshold 0 should be the total number of positive and
    # negative labels, hence their sum should be total number of pixels.
    # `np.prod` replaces the deprecated alias `np.product`, which was removed
    # in NumPy 2.0.
    expected_value = 1.0 * np.prod(shape) * num_updates
    got_value = prdata.tp[0] + prdata.fp[0]
    # They should be at least within 1.
    self.assertNear(got_value, expected_value, 1.0)
def _random_window_input_fn_test_template(
    self, time_series_reader, window_size, batch_size, num_features,
    discard_out_of_order=False):
  """Runs a `RandomWindowInputFn` once and validates the fetched batch.

  Args:
    time_series_reader: Reader forwarded to `RandomWindowInputFn`.
    window_size: Window length forwarded to the input function; also the
      expected second dimension of the fetched times/values.
    batch_size: Batch size forwarded to the input function; also the expected
      first dimension of the fetched times/values.
    num_features: Expected last dimension of the fetched values.
    discard_out_of_order: Accepted but not used by this template.

  Returns:
    The evaluated features dictionary.
  """
  window_input_fn = input_pipeline.RandomWindowInputFn(
      time_series_reader=time_series_reader,
      window_size=window_size,
      batch_size=batch_size)
  feature_tensors, _ = window_input_fn()
  local_init = variables.local_variables_initializer()
  with self.cached_session() as session:
    coord = coordinator_lib.Coordinator()
    queue_runner_impl.start_queue_runners(session, coord=coord)
    session.run(local_init)
    features = session.run(feature_tensors)
    coord.request_stop()
    coord.join()

  times_key = TrainEvalFeatures.TIMES
  values_key = TrainEvalFeatures.VALUES
  self.assertAllEqual([batch_size, window_size], features[times_key].shape)
  # Within every window, consecutive times must differ by exactly one.
  for step in range(window_size - 1):
    for row in range(batch_size):
      next_time = features[times_key][row, step + 1]
      this_time = features[times_key][row, step]
      self.assertEqual(next_time, this_time + 1)
  self.assertAllEqual([batch_size, window_size, num_features],
                      features[values_key].shape)
  self.assertEqual("int64", features[times_key].dtype)
  # Feature k is expected to equal 2 * time + k.
  for feature_index in range(num_features):
    expected_values = features[times_key] * 2. + feature_index
    self.assertAllEqual(expected_values,
                        features[values_key][:, :, feature_index])
  return features
def test_batch_text_lines(self):
  """Reads a five-line text file in batches of three for a single epoch."""
  gfile.Glob = self._orig_glob
  text_path = self._create_temp_file("A\nB\nC\nD\nE\n")
  batch_size = 3
  queue_capacity = 10
  name = "my_batch"

  graph_scope = ops.Graph().as_default()
  with graph_scope as g, self.test_session(graph=g) as session:
    batch = graph_io.read_batch_examples(
        [text_path],
        batch_size,
        reader=io_ops.TextLineReader,
        randomize_input=False,
        num_epochs=1,
        queue_capacity=queue_capacity,
        read_batch_size=10,
        name=name)
    self.assertAllEqual((None,), batch.get_shape().as_list())
    session.run(variables.local_variables_initializer())

    coord = coordinator.Coordinator()
    threads = queue_runner_impl.start_queue_runners(session, coord=coord)

    # A full batch first, then the two leftover lines, then exhaustion.
    for expected_lines in ([b"A", b"B", b"C"], [b"D", b"E"]):
      self.assertAllEqual(session.run(batch), expected_lines)
    with self.assertRaises(errors.OutOfRangeError):
      session.run(batch)

    coord.request_stop()
    coord.join(threads)
def begin(self):
  """Creates and stores the initialization ops this object uses later.

  Always stores a local-variables initializer and the optimizer's init op;
  the global-variables initializer is only created when `self._is_chief` is
  truthy (otherwise it stays None).
  """
  self._local_init_op = variables.local_variables_initializer()
  self._global_init_op = None
  # NOTE(review): block nesting below was reconstructed from a
  # whitespace-collapsed source; confirm that `_chief_init_op` belongs under
  # the chief-only branch.
  if self._is_chief:
    self._global_init_op = variables.global_variables_initializer()
    self._chief_init_op = self._ma_optimizer._chief_init_op  # pylint: disable=protected-access
  self._variable_init_op = self._ma_optimizer.get_init_op()
def testFinalOpsOnEvaluationLoop(self):
  """`evaluation_loop` returns the `final_op` value after one evaluation."""
  metric_pair = metric_ops.streaming_accuracy(self._predictions, self._labels)
  value_op, update_op = metric_pair
  global_init = variables.global_variables_initializer()
  local_init = variables.local_variables_initializer()
  init_op = control_flow_ops.group(global_init, local_init)

  # Separate directories for the checkpoint and for the evaluation logs.
  checkpoint_dir = os.path.join(self.get_temp_dir(), 'tmp_logs/')
  gfile.MakeDirs(checkpoint_dir)
  log_dir = os.path.join(self.get_temp_dir(), 'tmp_logs2/')
  gfile.MakeDirs(log_dir)

  # Write a checkpoint of freshly initialized variables for the loop to load.
  saver = saver_lib.Saver()
  with self.test_session() as sess:
    init_op.run()
    checkpoint_prefix = os.path.join(checkpoint_dir, 'chkpt')
    saver.save(sess, checkpoint_prefix)

  # A single evaluation pass; the loop's return value is the final op's value.
  accuracy_value = evaluation.evaluation_loop(
      '',
      checkpoint_dir,
      log_dir,
      eval_op=update_op,
      final_op=value_op,
      max_number_of_evaluations=1)
  self.assertAlmostEqual(accuracy_value, self._expected_accuracy)
def _test_metric(self, distribution, dataset_fn, metric_fn, expected_fn):
  """Runs a metric under `distribution` and checks it after each update.

  Args:
    distribution: Object providing `scope`, `distribute_dataset`,
      `call_for_each_tower`, `group` and `num_towers`.
    dataset_fn: Passed to `distribution.distribute_dataset`.
    metric_fn: Called per tower with the iterator's next element; must return
      a `(value, update)` pair.
    expected_fn: Maps the number of batches consumed so far to the expected
      metric value.
  """
  with ops.Graph().as_default(), distribution.scope():
    distributed = distribution.distribute_dataset(dataset_fn)
    iterator = distributed.make_one_shot_iterator()
    value, per_tower_update = distribution.call_for_each_tower(
        metric_fn, iterator.get_next())
    update = distribution.group(per_tower_update)
    self.evaluate(variables.local_variables_initializer())
    # TODO(josh11b): Once we switch to using a global batch size for input,
    # replace "distribution.num_towers" with "1".
    batches_per_update = distribution.num_towers

    def _update_and_check(num_updates, message):
      # One update consumes `batches_per_update` batches; the expected value
      # is computed before the metric value is fetched.
      self.evaluate(update)
      expected = expected_fn(num_updates * batches_per_update)
      self.assertAllClose(expected, self.evaluate(value), 0.001, msg=message)

    # The first two updates always run.
    for num_updates, message in ((1, "After first update"),
                                 (2, "After second update")):
      _update_and_check(num_updates, message)

    if batches_per_update == 1:
      # Consume 4 input batches
      for num_updates, message in ((3, "After third update"),
                                   (4, "After fourth update")):
        _update_and_check(num_updates, message)
def _all_window_input_fn_test_template(
    self, time_series_reader, num_samples, window_size,
    original_numpy_features=None):
  """Runs an `AllWindowInputFn` once and validates the chunked output.

  Args:
    time_series_reader: Reader forwarded to `AllWindowInputFn`.
    num_samples: Used with `window_size` to derive the expected number of
      windows (`num_samples - window_size + 1`).
    window_size: Window length forwarded to the input function.
    original_numpy_features: Optional dictionary holding the original times
      and values; when given, the chunked output is compared against it.
  """
  window_input_fn = test_utils.AllWindowInputFn(
      time_series_reader=time_series_reader, window_size=window_size)
  feature_tensors, _ = window_input_fn()
  local_init = variables.local_variables_initializer()
  with self.cached_session() as session:
    coord = coordinator_lib.Coordinator()
    queue_runner_impl.start_queue_runners(session, coord=coord)
    session.run(local_init)
    fetches = [feature_tensors[TrainEvalFeatures.TIMES],
               feature_tensors[TrainEvalFeatures.VALUES]]
    chunked_times, chunked_values = session.run(fetches)
    coord.request_stop()
    coord.join()

  expected_shape = [num_samples - window_size + 1, window_size]
  self.assertAllEqual(expected_shape, chunked_times.shape)
  if original_numpy_features is None:
    return
  original_times = original_numpy_features[TrainEvalFeatures.TIMES]
  original_values = original_numpy_features[TrainEvalFeatures.VALUES]
  # The distinct times across all windows must match the original times.
  self.assertAllEqual(original_times, numpy.unique(chunked_times))
  # Indexing the original values by the chunked times must give the chunks.
  self.assertAllEqual(original_values[chunked_times], chunked_values)
def testWithEpochLimit(self):
  """`evaluation_loop` on epoch-limited inputs still returns the accuracy."""
  limited_predictions = input.limit_epochs(self._predictions, num_epochs=1)
  limited_labels = input.limit_epochs(self._labels, num_epochs=1)

  metric_pair = metric_ops.streaming_accuracy(limited_predictions,
                                              limited_labels)
  value_op, update_op = metric_pair

  global_init = variables.global_variables_initializer()
  local_init = variables.local_variables_initializer()
  init_op = control_flow_ops.group(global_init, local_init)

  # Separate directories for the checkpoint and for the evaluation logs.
  checkpoint_dir = os.path.join(self.get_temp_dir(), 'tmp_logs/')
  gfile.MakeDirs(checkpoint_dir)
  log_dir = os.path.join(self.get_temp_dir(), 'tmp_logs2/')
  gfile.MakeDirs(log_dir)

  # Write a checkpoint of freshly initialized variables for the loop to load.
  saver = saver_lib.Saver()
  with self.test_session() as sess:
    init_op.run()
    checkpoint_prefix = os.path.join(checkpoint_dir, 'chkpt')
    saver.save(sess, checkpoint_prefix)

  # One evaluation pass, with far more evals requested than one epoch holds.
  accuracy_value = evaluation.evaluation_loop(
      '',
      checkpoint_dir,
      log_dir,
      eval_op=update_op,
      final_op=value_op,
      max_number_of_evaluations=1,
      num_evals=10000)
  self.assertAlmostEqual(accuracy_value, self._expected_accuracy)
def testWhileLoop(self):
  """Uses a `rate.Rate` instance inside a `while_loop` body."""
  with self.cached_session():
    rate_metric = rate.Rate()

    def _keep_going(unused_value, unused_denom, step, unused_rate):
      # Only the step counter decides whether the loop continues.
      return math_ops.less(step, 100)

    def _loop_body(current_value, current_denom, step, current_rate):
      step = step + 1
      current_rate = rate_metric(current_value, current_denom)
      # The increments are sequenced after the rate op via control deps.
      with ops.control_dependencies([current_rate]):
        current_value = math_ops.add(current_value, 2)
        current_denom = math_ops.add(current_denom, 1)
      return [current_value, current_denom, step, current_rate]

    start_step = constant_op.constant(0)
    start_value = constant_op.constant([1], dtype=dtypes.float64)
    start_denom = constant_op.constant([1], dtype=dtypes.float64)
    start_rate = rate_metric(start_value, start_denom)
    self.evaluate(variables.global_variables_initializer())
    self.evaluate(variables.local_variables_initializer())
    loop_vars = [start_value, start_denom, start_step, start_rate]
    loop_outputs = control_flow_ops.while_loop(_keep_going, _loop_body,
                                               loop_vars)
    final_rate = self.evaluate(loop_outputs[3])
    self.assertEqual([[2]], final_rate)
def test_eval_single_tower(self):
  """EVAL-mode spec from `replicate_model_fn` with a single device."""
  features = np.array([[0.01], [0.002]])
  labels = np.array([[0.01], [0.02]])

  with self.test_session() as session:
    replicated_model_fn = replicate_model_fn.replicate_model_fn(
        self.model_fn, self.optimizer_fn, devices=['/gpu:0'])
    estimator_spec = replicated_model_fn(
        features, labels, model_fn_lib.ModeKeys.EVAL, self.params)
    session.run(variables.local_variables_initializer())
    session.run(variables.global_variables_initializer())

    metric_ops_by_name = estimator_spec.eval_metric_ops
    accuracy_value_op, accuracy_update_op = metric_ops_by_name['accuracy']
    auc_value_op, auc_update_op = metric_ops_by_name['auc']

    # Run both update ops together, then read each value op separately.
    session.run([accuracy_update_op, auc_update_op])
    accuracy = session.run(accuracy_value_op)
    auc = session.run(auc_value_op)

    # Accuracy is 0.0 (no match) in the first tower.
    # Accuracy is 1.0 (match) in the second tower, since the feature
    # times weight "c" happened to be equal to the label.
    first_term = 0.01 * 10 - 0.01
    second_term = 0.002 * 10 - 0.02
    total_loss = (first_term + second_term)

    self.assertNear((0.0 + 1.0) / 2.0, accuracy, 0.01)
    self.assertEqual(0, auc)
    self.assertNear(total_loss, session.run(estimator_spec.loss), 0.01)
def _prepareCheckpoint(self, checkpoint_path):
  """Initializes all variables and saves them to `checkpoint_path`.

  The saver is created with `write_version=saver_pb2.SaverDef.V1`.

  Args:
    checkpoint_path: Path handed to `saver.save`.
  """
  global_init = variables.global_variables_initializer()
  local_init = variables.local_variables_initializer()
  init_op = control_flow_ops.group(global_init, local_init)
  v1_saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V1)
  with self.test_session() as session:
    session.run(init_op)
    v1_saver.save(session, checkpoint_path)
def test_example(self):
  """`_eval_spec` over three tower specs aggregates metrics and loss."""
  with self.test_session() as session:
    losses = map(self.create_constant_loss, [2, 4, 6])
    metrics_per_tower = map(self.create_eval_metrics, [0, 0.2, 0.3])
    tower_specs = [
        self.create_estimator_spec(tower_loss, tower_metric)
        for tower_loss, tower_metric in zip(losses, metrics_per_tower)
    ]
    session.run(variables.local_variables_initializer())

    estimator_spec = replicate_model_fn._eval_spec(
        tower_specs, aggregation_device='/device:GPU:0')

    metric_ops_by_name = estimator_spec.eval_metric_ops
    accuracy_value_op, accuracy_update_op = metric_ops_by_name['accuracy']
    auc_value_op, auc_update_op = metric_ops_by_name['auc']

    # The metric value ops are expected on the CPU device.
    self.assertEqual('/device:CPU:0', accuracy_value_op.device)
    self.assertEqual('/device:CPU:0', auc_value_op.device)

    session.run([accuracy_update_op, auc_update_op])
    accuracy, auc = session.run([accuracy_value_op, auc_value_op])

    self.assertNear((12 - 2) / 12, accuracy, 0.01)
    self.assertEqual(0, auc)
    self.assertEqual(2 + 4 + 6, session.run(estimator_spec.loss))
def test_handles_single_tower(self):
  """`_eval_spec` also works when given only one tower spec."""
  with self.test_session() as session:
    losses = map(self.create_constant_loss, [5])
    metrics_per_tower = map(self.create_eval_metrics, [0.2])
    tower_specs = [
        self.create_estimator_spec(tower_loss, tower_metric)
        for tower_loss, tower_metric in zip(losses, metrics_per_tower)
    ]
    session.run(variables.local_variables_initializer())

    estimator_spec = replicate_model_fn._eval_spec(
        tower_specs, aggregation_device='/device:GPU:0')

    metric_ops_by_name = estimator_spec.eval_metric_ops
    accuracy_value_op, accuracy_update_op = metric_ops_by_name['accuracy']
    auc_value_op, auc_update_op = metric_ops_by_name['auc']

    # The metric value ops are expected on the CPU device.
    self.assertEqual('/device:CPU:0', accuracy_value_op.device)
    self.assertEqual('/device:CPU:0', auc_value_op.device)

    # Run both update ops together, then read each value op separately.
    session.run([accuracy_update_op, auc_update_op])
    accuracy = session.run(accuracy_value_op)
    auc = session.run(auc_value_op)

    self.assertNear((4 - 1) / 4, accuracy, 0.01)
    self.assertEqual(0, auc)
    self.assertEqual(5, session.run(estimator_spec.loss))
def test_keyed_features_filter(self):
  """Only JSON examples whose "age" is below 2 make it through `filter_fn`."""
  gfile.Glob = self._orig_glob
  lines = [
      '{"features": {"feature": {"age": {"int64_list": {"value": [2]}}}}}',
      '{"features": {"feature": {"age": {"int64_list": {"value": [0]}}}}}',
      '{"features": {"feature": {"age": {"int64_list": {"value": [1]}}}}}',
      '{"features": {"feature": {"age": {"int64_list": {"value": [0]}}}}}',
      '{"features": {"feature": {"age": {"int64_list": {"value": [3]}}}}}',
      '{"features": {"feature": {"age": {"int64_list": {"value": [5]}}}}}'
  ]
  filename = self._create_temp_file("\n".join(lines))

  batch_size = 2
  queue_capacity = 4
  name = "my_batch"
  feature_spec = {"age": parsing_ops.FixedLenFeature([], dtypes_lib.int64)}

  def filter_fn(keys, examples_json):
    # The keys play no part in the decision; only the parsed age does.
    del keys
    decoded = parsing_ops.decode_json_example(examples_json)
    parsed = parsing_ops.parse_example(decoded, feature_spec)
    return math_ops.less(parsed["age"], 2)

  with ops.Graph().as_default() as g, self.session(graph=g) as session:
    keys, inputs = graph_io._read_keyed_batch_examples_helper(
        filename,
        batch_size,
        reader=io_ops.TextLineReader,
        randomize_input=False,
        num_epochs=1,
        read_batch_size=batch_size,
        queue_capacity=queue_capacity,
        filter_fn=filter_fn,
        name=name)
    self.assertAllEqual((None,), keys.get_shape().as_list())
    self.assertAllEqual((None,), inputs.get_shape().as_list())
    session.run(variables.local_variables_initializer())

    coord = coordinator.Coordinator()
    threads = queue_runner_impl.start_queue_runners(session, coord=coord)
    fetches = (keys, inputs)

    # First batch: the two earliest examples that pass the filter.
    out_keys, out_vals = session.run(fetches)
    expected_keys = [filename.encode("utf-8") + b":2",
                     filename.encode("utf-8") + b":3"]
    self.assertAllEqual(expected_keys, out_keys)
    expected_vals = [lines[1].encode("utf-8"), lines[2].encode("utf-8")]
    self.assertAllEqual(expected_vals, out_vals)

    # Second batch: a single element, the only other example passing the
    # filter.
    out_keys, out_vals = session.run(fetches)
    self.assertAllEqual([filename.encode("utf-8") + b":4"], out_keys)
    self.assertAllEqual([lines[3].encode("utf-8")], out_vals)

    # Nothing is left after that.
    with self.assertRaises(errors.OutOfRangeError):
      session.run(fetches)

    coord.request_stop()
    coord.join(threads)
def export(self, last_checkpoint, output_dir):
  """Builds a prediction graph and exports the model.

  Args:
    last_checkpoint: Path to the latest checkpoint file from training.
    output_dir: Path to the folder to be used to output the model.
  """
  logging.info('Exporting prediction graph to %s', output_dir)
  with tf.Session(graph=tf.Graph()) as sess:
    # Build the prediction graph and describe it with a predict signature.
    inputs, outputs = self.build_prediction_graph()
    serving_signature = signature_def_utils.predict_signature_def(
        inputs, outputs)
    signature_def_map = {'serving_default': serving_signature}

    # Initialize every global variable, then restore from the checkpoints.
    sess.run(tf.global_variables_initializer())
    self.restore_from_checkpoint(sess, self.inception_checkpoint_file,
                                 last_checkpoint)

    # Passed to the builder as its legacy init op: local variables + tables.
    local_init = variables.local_variables_initializer()
    tables_init = tf.tables_initializer()
    init_op_serving = control_flow_ops.group(local_init, tables_init)

    builder = saved_model_builder.SavedModelBuilder(output_dir)
    builder.add_meta_graph_and_variables(
        sess, [tag_constants.SERVING],
        signature_def_map=signature_def_map,
        legacy_init_op=init_op_serving)
    builder.save(False)
def test_metrics_consistent(self):
  """Compares the head's EVAL metrics against a standard mean of the loss."""
  # Tests that the identity metrics used to report in-sample predictions match
  # the behavior of standard metrics.
  g = ops.Graph()
  with g.as_default():
    features = {
        feature_keys.TrainEvalFeatures.TIMES:
            array_ops.zeros((1, 1)),
        feature_keys.TrainEvalFeatures.VALUES:
            array_ops.zeros((1, 1, 1)),
        # A local int64 counter (count_up_to(10)), cast to float32 and
        # reshaped to (1, 1, 1).
        "ticker":
            array_ops.reshape(
                math_ops.cast(
                    variables.VariableV1(
                        name="ticker",
                        initial_value=0,
                        dtype=dtypes.int64,
                        collections=[ops.GraphKeys.LOCAL_VARIABLES])
                    .count_up_to(10),
                    dtype=dtypes.float32), (1, 1, 1))
    }
    model_fn = ts_head_lib.TimeSeriesRegressionHead(
        model=_TickerModel(),
        state_manager=state_management.PassthroughStateManager(),
        optimizer=train.GradientDescentOptimizer(0.001)).create_estimator_spec
    outputs = model_fn(
        features=features, labels=None, mode=estimator_lib.ModeKeys.EVAL)
    # Element [1] of each metric tuple is its update op.
    metric_update_ops = [
        metric[1] for metric in outputs.eval_metric_ops.values()]
    loss_mean, loss_update = metrics.mean(outputs.loss)
    metric_update_ops.append(loss_update)
    with self.cached_session() as sess:
      coordinator = coordinator_lib.Coordinator()
      queue_runner_impl.start_queue_runners(sess, coord=coordinator)
      variables.local_variables_initializer().run()
      # Run every update op once before reading any metric value.
      sess.run(metric_update_ops)
      loss_evaled, metric_evaled, nested_metric_evaled = sess.run(
          (loss_mean, outputs.eval_metric_ops["ticker"][0],
           outputs.eval_metric_ops[feature_keys.FilteringResults.STATE_TUPLE][
               0][0]))
      # The custom model_utils metrics for in-sample predictions should be in
      # sync with the Estimator's mean metric for model loss.
      self.assertAllClose(0., loss_evaled)
      self.assertAllClose((((0.,),),), metric_evaled)
      self.assertAllClose((((0.,),),), nested_metric_evaled)
      coordinator.request_stop()
      coordinator.join()
def test_read_text_lines_large(self):
  """Reads 49999 JSON-encoded records back in batches and checks them all."""
  gfile.Glob = self._orig_glob
  sequence_prefix = "abcdefghijklmnopqrstuvwxyz123456789"
  num_records = 49999
  lines = [
      (sequence_prefix + str(index)).encode("ascii")
      for index in xrange(num_records)
  ]
  # One JSON example per line, carrying the record base64-encoded.
  json_lines = [
      '{"features": { "feature": { "sequence": {' +
      '"bytes_list": { "value": ["' +
      base64.b64encode(record).decode("ascii") +
      '"]}}}}}\n'
      for record in lines
  ]
  filename = self._create_temp_file("".join(json_lines))
  batch_size = 10000
  queue_capacity = 10000
  name = "my_large_batch"

  feature_spec = {
      "sequence": parsing_ops.FixedLenFeature([], dtypes_lib.string)
  }

  with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
    keys, result = graph_io.read_keyed_batch_features(
        filename,
        batch_size,
        feature_spec,
        io_ops.TextLineReader,
        randomize_input=False,
        num_epochs=1,
        queue_capacity=queue_capacity,
        num_enqueue_threads=2,
        parse_fn=parsing_ops.decode_json_example,
        name=name)
    self.assertAllEqual((None,), keys.get_shape().as_list())
    self.assertEqual(1, len(result))
    self.assertAllEqual((None,), result["sequence"].get_shape().as_list())
    session.run(variables.local_variables_initializer())
    coord = coordinator.Coordinator()
    threads = queue_runner_impl.start_queue_runners(session, coord=coord)

    # Keep fetching batches until the single epoch runs out.
    batches = []
    try:
      while not coord.should_stop():
        batches.append(session.run(result))
    except errors.OutOfRangeError:
      pass
    finally:
      coord.request_stop()

    coord.join(threads)

  parsed_records = [
      record for batch in batches for record in batch["sequence"]
  ]
  # Every record must come back; set equality ignores ordering.
  self.assertEqual(len(parsed_records), num_records)
  self.assertEqual(set(parsed_records), set(lines))
def _get_local_init_op():
  """Returns the graph's local init op, creating and registering it if absent.

  If the LOCAL_INIT_OP collection already provides an op, that op is returned.
  Otherwise a group of the local-variables initializer and the all-tables
  initializer is built, added to the collection, and returned.
  """
  existing_op = _get_first_op_from_collection(ops.GraphKeys.LOCAL_INIT_OP)
  if existing_op is not None:
    return existing_op

  init_ops = [
      variables.local_variables_initializer(),
      data_flow_ops.initialize_all_tables()
  ]
  if not init_ops:
    return None
  grouped_op = control_flow_ops.group(*init_ops)
  ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, grouped_op)
  return grouped_op
def testWithMultipleUpdates(self):
  """Streams random data through `f1_score` and compares with a manual F1.

  The reference is the best F1 over three thresholds, computed in plain
  Python from the same labels and predictions that are fed to the metric.
  """
  num_samples = 1000
  batch_size = 10
  num_batches = int(num_samples / batch_size)

  # Random binary labels and noisy predictions clipped into [0, 1].
  labels = np.random.randint(0, 2, size=(num_samples, 1))
  noise = np.random.normal(0.0, scale=0.2, size=(num_samples, 1))
  predictions = 0.4 + 0.2 * labels + noise
  predictions[predictions > 1] = 1
  predictions[predictions < 0] = 0
  thresholds = [-0.01, 0.5, 1.01]

  epsilon = 1e-7
  expected_max_f1 = -1.0
  for threshold in thresholds:
    tp = fp = fn = 0
    for prediction, label in zip(predictions, labels):
      predicted_positive = prediction >= threshold
      actually_positive = label == 1
      if predicted_positive and actually_positive:
        tp += 1
      elif predicted_positive:
        fp += 1
      elif actually_positive:
        fn += 1
    expected_prec = tp / (epsilon + tp + fp)
    expected_rec = tp / (epsilon + tp + fn)
    expected_f1 = (2 * expected_prec * expected_rec /
                   (epsilon + expected_prec + expected_rec))
    expected_max_f1 = max(expected_max_f1, expected_f1)

  labels = labels.astype(np.float32)
  predictions = predictions.astype(np.float32)
  dataset = dataset_ops.Dataset.from_tensor_slices((predictions, labels))
  batched = dataset.repeat().batch(batch_size)
  tf_predictions, tf_labels = batched.make_one_shot_iterator().get_next()
  f1, f1_op = classification.f1_score(tf_labels, tf_predictions,
                                      num_thresholds=3)

  with self.cached_session() as sess:
    sess.run(variables.local_variables_initializer())
    for _ in range(num_batches):
      sess.run([f1_op])
    # The metric is approximate, so only two decimal places are compared;
    # more samples or thresholds should tighten the match.
    self.assertAlmostEqual(expected_max_f1, f1.eval(), 2)
def testSeekNextLimitEpochs(self):
  """`seek_next` with `num_epochs=1` yields each list element exactly once."""
  string_list = ["a", "b", "c"]
  with self.test_session() as session:
    elem = input_pipeline_ops.seek_next(string_list, num_epochs=1)
    init_ops = [
        variables.local_variables_initializer(),
        variables.global_variables_initializer()
    ]
    session.run(init_ops)
    expected_output = [b"a", b"b", b"c"]
    self._assert_output(expected_output, session, elem)
def testZeroLabelsPredictions(self):
  """F1 is 0.0 when both the labels and the predictions are all zeros."""
  with self.cached_session() as sess:
    predictions = array_ops.zeros([4], dtype=dtypes.float32)
    labels = array_ops.zeros([4])
    # `f1_score` takes labels first, then predictions. The previous call
    # passed them swapped, which only went unnoticed because both tensors are
    # all-zero. Keyword arguments make the roles explicit.
    f1, f1_op = classification.f1_score(
        labels=labels, predictions=predictions, num_thresholds=3)
    sess.run(variables.local_variables_initializer())
    sess.run([f1_op])

    self.assertAlmostEqual(0.0, f1.eval(), places=5)
def testSeekNextLimitEpochsThree(self):
  """`seek_next` with `num_epochs=3` cycles through the list three times."""
  string_list = ["a", "b", "c"]
  with self.cached_session() as session:
    elem = input_pipeline_ops.seek_next(string_list, num_epochs=3)
    init_ops = [
        variables.local_variables_initializer(),
        variables.global_variables_initializer()
    ]
    session.run(init_ops)
    # Three full passes over [a, b, c] are expected.
    expected_output = [b"a", b"b", b"c"] * 3
    self._assert_output(expected_output, session, elem)
def test_keyed_features_filter(self):
  """Only JSON examples whose "age" is below 2 make it through `filter_fn`."""
  gfile.Glob = self._orig_glob
  lines = [
      '{"features": {"feature": {"age": {"int64_list": {"value": [2]}}}}}',
      '{"features": {"feature": {"age": {"int64_list": {"value": [0]}}}}}',
      '{"features": {"feature": {"age": {"int64_list": {"value": [1]}}}}}',
      '{"features": {"feature": {"age": {"int64_list": {"value": [0]}}}}}',
      '{"features": {"feature": {"age": {"int64_list": {"value": [3]}}}}}',
      '{"features": {"feature": {"age": {"int64_list": {"value": [5]}}}}}'
  ]
  filename = self._create_temp_file("\n".join(lines))

  batch_size = 2
  queue_capacity = 4
  name = "my_batch"
  feature_spec = {"age": parsing_ops.FixedLenFeature([], dtypes_lib.int64)}

  def filter_fn(keys, examples_json):
    # The keys play no part in the decision; only the parsed age does.
    del keys
    decoded = parsing_ops.decode_json_example(examples_json)
    parsed = parsing_ops.parse_example(decoded, feature_spec)
    return math_ops.less(parsed["age"], 2)

  with ops.Graph().as_default() as g, self.test_session(
      graph=g) as session:
    keys, inputs = graph_io._read_keyed_batch_examples_helper(
        filename,
        batch_size,
        reader=io_ops.TextLineReader,
        randomize_input=False,
        num_epochs=1,
        read_batch_size=batch_size,
        queue_capacity=queue_capacity,
        filter_fn=filter_fn,
        name=name)
    self.assertAllEqual((None,), keys.get_shape().as_list())
    self.assertAllEqual((None,), inputs.get_shape().as_list())
    session.run(variables.local_variables_initializer())

    coord = coordinator.Coordinator()
    threads = queue_runner_impl.start_queue_runners(session, coord=coord)
    fetches = (keys, inputs)

    # First batch: the two earliest examples that pass the filter.
    out_keys, out_vals = session.run(fetches)
    expected_keys = [filename.encode("utf-8") + b":2",
                     filename.encode("utf-8") + b":3"]
    self.assertAllEqual(expected_keys, out_keys)
    expected_vals = [lines[1].encode("utf-8"), lines[2].encode("utf-8")]
    self.assertAllEqual(expected_vals, out_vals)

    # Second batch: a single element, the only other example passing the
    # filter.
    out_keys, out_vals = session.run(fetches)
    self.assertAllEqual([filename.encode("utf-8") + b":4"], out_keys)
    self.assertAllEqual([lines[3].encode("utf-8")], out_vals)

    # Nothing is left after that.
    with self.assertRaises(errors.OutOfRangeError):
      session.run(fetches)

    coord.request_stop()
    coord.join(threads)
def _default_local_init_op():
    """Returns one op grouping the default local initializers.

    The group covers local variables, tables, and local resources.
    """
    initializers = (
        variables.local_variables_initializer(),
        lookup_ops.tables_initializer(),
        resources.initialize_resources(resources.local_resources()),
    )
    return control_flow_ops.group(*initializers)
def do_training(train_op, init_fn=None, summary_op=None, lr=None):
    """Runs a chief-only training loop under a `Supervisor`.

    Builds the init/ready ops and the stop/log predicates, appends one
    restore `Saver` per sub-network ("loc/net" and each "stn<i>/net") to the
    module-level `savers` list, then repeatedly calls `train_step` until it
    reports `should_stop`, the supervisor asks to stop, or the input
    pipeline raises `OutOfRangeError`.

    Args:
      train_op: Passed through to `train_step` on every iteration.
      init_fn: Forwarded to the `Supervisor` as its `init_fn`.
      summary_op: Forwarded to the `Supervisor` as its `summary_op`.
      lr: Passed through to `train_step` on every iteration.
    """
    global savers

    def _restore_saver_for(prefix):
        # Saver over the trainable variables under `prefix` (excluding any
        # whose name contains "Logits/"), keyed by "InceptionV2" + the part
        # of the op name after the prefix.
        skip = len(prefix)
        name_to_var = {
            "InceptionV2" + v.op.name[skip:]: v
            for v in tf.trainable_variables()
            if v.name.startswith(prefix) and v.name.find("Logits/") < 0
        }
        return tf_saver.Saver(name_to_var)

    graph = ops.get_default_graph()
    with graph.as_default():
        global_step = variables.get_or_create_global_step()
        saver = tf_saver.Saver(max_to_keep=0)

        with ops.name_scope('init_ops'):
            init_op = tf_variables.global_variables_initializer()
            ready_op = tf_variables.report_uninitialized_variables()
            local_init_op = control_flow_ops.group(
                tf_variables.local_variables_initializer(),
                data_flow_ops.tables_initializer())

        summary_writer = supervisor.Supervisor.USE_DEFAULT

        with ops.name_scope('train_step'):
            train_step_kwargs = {}
            if FLAGS.max_number_of_steps is not None:
                should_stop_op = math_ops.greater_equal(
                    global_step, FLAGS.max_number_of_steps)
            else:
                should_stop_op = constant_op.constant(False)
            train_step_kwargs['should_stop'] = should_stop_op
            if FLAGS.log_every_n_steps > 0:
                train_step_kwargs['should_log'] = math_ops.equal(
                    math_ops.mod(global_step, FLAGS.log_every_n_steps), 0)

        # One restore saver for the localisation net, then one per STN net.
        savers.append(_restore_saver_for("loc/net"))
        for i in range(NUM_STN):
            savers.append(_restore_saver_for("stn%d/net" % i))

    prt("savers %d" % len(savers))

    # This loop always runs as the chief; there is no multi-worker path.
    is_chief = True
    logdir = FLAGS.train_dir

    sv = supervisor.Supervisor(graph=graph,
                               is_chief=is_chief,
                               logdir=logdir,
                               init_op=init_op,
                               init_feed_dict=None,
                               local_init_op=local_init_op,
                               ready_for_local_init_op=None,
                               ready_op=ready_op,
                               summary_op=summary_op,
                               summary_writer=summary_writer,
                               global_step=global_step,
                               saver=saver,
                               save_summaries_secs=FLAGS.save_summaries_secs,
                               save_model_secs=FLAGS.save_interval_secs,
                               init_fn=init_fn)

    if summary_writer is not None:
        train_step_kwargs['summary_writer'] = sv.summary_writer

    with sv.managed_session('', start_standard_services=False,
                            config=None) as sess:
        logging.info('Starting Session.')
        # Standard services are started manually because managed_session was
        # asked not to start them.
        if logdir:
            sv.start_standard_services(sess)
        sv.start_queue_runners(sess)
        logging.info('Starting Queues.')
        try:
            while not sv.should_stop():
                total_loss, global_step_value, should_stop = train_step(
                    sess, train_op, global_step, lr, train_step_kwargs)
                # Extra checkpoint every `save_every_n_steps` global steps.
                if (global_step_value > 0 and
                        global_step_value % FLAGS.save_every_n_steps == 0):
                    sv.saver.save(sess, sv.save_path,
                                  global_step=sv.global_step)
                if should_stop:
                    logging.info('Stopping Training.')
                    break
        except errors.OutOfRangeError:
            # OutOfRangeError is thrown when epoch limit per
            # tf.train.limit_epochs is reached.
            logging.info('Caught OutOfRangeError. Stopping Training.')
        if logdir and sv.is_chief:
            logging.info('Finished training! Saving model to disk.')
            sv.saver.save(sess, sv.save_path, global_step=sv.global_step)
def _export_mode(
        mode, has_saved_vars, builder, model, custom_objects, checkpoint_path):
    """Export a model, and optionally save new vars from the clone model.

    The model is cloned into a fresh graph, the clone's weights are loaded
    from `checkpoint_path`, and the resulting meta graph is added to
    `builder` under the tags for `mode`.

    Args:
      mode: A `tf.estimator.ModeKeys` string.
      has_saved_vars: A `boolean` indicating whether the SavedModel has already
        exported variables.
      builder: A `SavedModelBuilder` object.
      model: A `tf.keras.Model` object.
      custom_objects: A dictionary mapping string names to custom classes
        or functions.
      checkpoint_path: String path to checkpoint.

    Returns:
      None.

    Raises:
      ValueError: If the train/eval mode is being exported, but the model does
        not have an optimizer.
    """
    # Only the predict graph is built without compiling the clone.
    compile_clone = (mode != model_fn_lib.ModeKeys.PREDICT)
    if compile_clone and not model.optimizer:
        raise ValueError(
            'Model does not have an optimizer. Cannot export mode %s' % mode)

    # Captured before switching graphs; passed to the object-equivalence
    # check below together with the clone's graph.
    model_graph = ops.get_default_graph()
    with ops.Graph().as_default() as g:

        K.set_learning_phase(mode == model_fn_lib.ModeKeys.TRAIN)

        # Clone the model into blank graph. This will create placeholders for inputs
        # and targets.
        clone = models_lib.clone_and_build_model(
            model, custom_objects=custom_objects, compile_clone=compile_clone)

        # Make sure that iterations variable is added to the global step collection,
        # to ensure that, when the SavedModel graph is loaded, the iterations
        # variable is returned by `tf.train.get_global_step()`. This is required for
        # compatibility with the SavedModelEstimator.
        if compile_clone:
            g.add_to_collection(ops.GraphKeys.GLOBAL_STEP, clone.optimizer.iterations)

        # Extract update and train ops from train/test/predict functions.
        if mode == model_fn_lib.ModeKeys.TRAIN:
            clone._make_train_function()
            builder._add_train_op(clone.train_function.updates_op)
        elif mode == model_fn_lib.ModeKeys.EVAL:
            clone._make_test_function()
        else:
            clone._make_predict_function()
        g.get_collection_ref(ops.GraphKeys.UPDATE_OPS).extend(clone.state_updates)

        # Saveables of the clone; used to build the Saver stored in the meta graph.
        clone_var_list = checkpointable_utils.named_saveables(clone)

        with session.Session().as_default():
            if has_saved_vars:
                # Confirm all variables in the clone have an entry in the checkpoint.
                status = clone.load_weights(checkpoint_path)
                status.assert_existing_objects_matched()
            else:
                # Confirm that variables between the clone and model match up exactly,
                # not counting optimizer objects. Optimizer objects are ignored because
                # if the model has not trained, the slot variables will not have been
                # created yet.
                # TODO(b/113179535): Replace with checkpointable equivalence.
                _assert_same_non_optimizer_objects(model, model_graph, clone, g)

                # TODO(b/113178242): Use value transfer for checkpointable objects.
                clone.load_weights(checkpoint_path)

                # Add graph and variables to SavedModel.
                # TODO(b/113134168): Switch to add_meta_graph_and_variables.
                clone.save_weights(checkpoint_path, save_format='tf', overwrite=True)
                # Mark the builder so that variables are treated as already saved.
                builder._has_saved_variables = True

            # Add graph to the SavedModel builder.
            builder.add_meta_graph(
                model_fn_lib.EXPORT_TAG_MAP[mode],
                signature_def_map=_create_signature_def_map(clone, mode),
                saver=saver_lib.Saver(clone_var_list),
                main_op=variables.local_variables_initializer())
        return None
def export_fn(estimator, export_dir_base, checkpoint_path=None, eval_result=None):
    """Exports `estimator` as a SavedModel and mirrors it into the job dir.

    Builds a CSV serving graph, restores the checkpoint into it, writes a
    timestamped SavedModel under `export_dir_base`, copies any extra assets,
    prunes all but the three newest exports, and finally replaces
    `<args.job_dir>/model` (or `evaluation_model` when `keep_target` is set)
    with a copy of the new export.

    `train_config`, `args`, `keep_target` and `assets_extra` are not
    parameters; they are read from the surrounding scope.

    Args:
      estimator: Estimator whose `_call_model_fn` and `_model_dir` are used.
      export_dir_base: Directory under which the timestamped export is made.
      checkpoint_path: Checkpoint to restore. When falsy, the latest
        checkpoint in the estimator's model dir is used.
      eval_result: Unused.

    Returns:
      The path of the timestamped export directory.

    Raises:
      NotFittedError: If no checkpoint can be found.
    """
    with ops.Graph().as_default() as g:
        contrib_variables.create_global_step(g)

        input_ops = serving_from_csv_input(train_config, args, keep_target)
        model_fn_ops = estimator._call_model_fn(
            input_ops.features, None, model_fn_lib.ModeKeys.INFER)
        output_fetch_tensors = make_output_tensors(
            train_config=train_config,
            args=args,
            input_ops=input_ops,
            model_fn_ops=model_fn_ops,
            keep_target=keep_target)

        signature_def_map = {
            'serving_default': signature_def_utils.predict_signature_def(
                input_ops.default_inputs, output_fetch_tensors)
        }

        if not checkpoint_path:
            # Locate the latest checkpoint
            checkpoint_path = saver.latest_checkpoint(estimator._model_dir)
        if not checkpoint_path:
            raise NotFittedError("Couldn't find trained model at %s."
                                 % estimator._model_dir)

        export_dir = saved_model_export_utils.get_timestamped_export_dir(
            export_dir_base)

        with tf_session.Session('') as session:
            # The initializer ops are only needed once, as the exported
            # init op below; building extra copies here and discarding them
            # would just leave unused nodes in the exported graph.
            saver_for_restore = saver.Saver(
                variables.global_variables(), sharded=True)
            saver_for_restore.restore(session, checkpoint_path)

            init_op = control_flow_ops.group(
                variables.local_variables_initializer(),
                data_flow_ops.tables_initializer())

            # Perform the export
            builder = saved_model_builder.SavedModelBuilder(export_dir)
            builder.add_meta_graph_and_variables(
                session, [tag_constants.SERVING],
                signature_def_map=signature_def_map,
                assets_collection=ops.get_collection(
                    ops.GraphKeys.ASSET_FILEPATHS),
                legacy_init_op=init_op)
            builder.save(False)

        # Add the extra assets
        if assets_extra:
            assets_extra_path = os.path.join(
                compat.as_bytes(export_dir), compat.as_bytes('assets.extra'))
            for dest_relative, source in assets_extra.items():
                dest_absolute = os.path.join(
                    compat.as_bytes(assets_extra_path),
                    compat.as_bytes(dest_relative))
                dest_path = os.path.dirname(dest_absolute)
                gfile.MakeDirs(dest_path)
                gfile.Copy(source, dest_absolute)

    # only keep the last 3 models
    saved_model_export_utils.garbage_collect_exports(export_dir_base,
                                                     exports_to_keep=3)

    # save the last model to the model folder.
    # export_dir_base = A/B/intermediate_models/
    if keep_target:
        final_dir = os.path.join(args.job_dir, 'evaluation_model')
    else:
        final_dir = os.path.join(args.job_dir, 'model')
    if file_io.is_directory(final_dir):
        file_io.delete_recursively(final_dir)
    file_io.recursive_create_dir(final_dir)
    _recursive_copy(export_dir, final_dir)

    return export_dir
def test1Workers2Period(self):
    """One worker, communication period 2: checks local/global values.

    Runs the train op step by step, checking local variables, their global
    copies and the global step after each run, then saves from the worker
    session and verifies that a fresh graph restores the global values.
    """
    _, workers, _ = create_local_cluster(num_workers=1, num_ps=1)
    sessions, graphs, train_ops, savers = _get_workers(1, 2, workers, 1.0)

    sess0, graph0, step_op = sessions[0], graphs[0], train_ops[0]
    local_v0 = graph0.get_tensor_by_name("v0:0")
    local_v1 = graph0.get_tensor_by_name("v1:0")
    global_step = training_util.get_global_step(graph0)
    global_v0 = graph0.get_tensor_by_name(GLOBAL_VARIABLE_NAME + "/v0:0")
    global_v1 = graph0.get_tensor_by_name(GLOBAL_VARIABLE_NAME + "/v1:0")

    def expect(v0, v1, g0, g1, step):
        # Same five checks, in the same order, after every training step.
        self.assertAllEqual(v0, sess0.run(local_v0))
        self.assertAllEqual(v1, sess0.run(local_v1))
        self.assertAllEqual(g0, sess0.run(global_v0))
        self.assertAllEqual(g1, sess0.run(global_v1))
        self.assertAllEqual(step, sess0.run(global_step))

    # Verify the initialized value.
    expect(0.0, 1.0, 0.0, 1.0, 0)

    sess0.run(step_op)
    expect(1.0, 2.0, 0.0, 1.0, 0)

    # iteration 2, global variable update
    sess0.run(step_op)
    expect(0.0, 1.0, 2.0, 3.0, 1)

    # iteration 3
    sess0.run(step_op)
    expect(1.0, 2.0, 2.0, 3.0, 1)

    sess0.run(step_op)

    # save, data will be global value
    outfile = os.path.join(test.get_temp_dir(), "model")
    savers[0].save(sess0._sess._sess._sess._sess, save_path=outfile)

    # restore on a new graph
    ops.reset_default_graph()
    with session.Session() as sess:
        v0 = variable_scope.get_variable(initializer=0.0, name="v0")
        v1 = variable_scope.get_variable(initializer=1.0, name="v1")
        sess.run(variables.local_variables_initializer())
        restorer = saver.Saver(var_list=[v1, v0])
        restorer.restore(sess, outfile)
        self.assertAllEqual(2.0, sess.run(v0))
        self.assertAllEqual(3.0, sess.run(v1))
def main_op():
    """Returns an op grouping the local-variable and table initializers."""
    return control_flow_ops.group(
        variables.local_variables_initializer(),
        lookup_ops.tables_initializer())
def train(train_op,
          logdir,
          train_step_fn=train_step,
          train_step_kwargs=_USE_DEFAULT,
          log_every_n_steps=1,
          graph=None,
          master='',
          is_chief=True,
          global_step=None,
          number_of_steps=None,
          init_op=_USE_DEFAULT,
          init_feed_dict=None,
          local_init_op=_USE_DEFAULT,
          init_fn=None,
          ready_op=_USE_DEFAULT,
          summary_op=_USE_DEFAULT,
          save_summaries_secs=600,
          summary_writer=_USE_DEFAULT,
          startup_delay_steps=0,
          saver=None,
          save_interval_secs=600,
          sync_optimizer=None,
          session_config=None,
          session_wrapper=None,
          trace_every_n_steps=None,
          ignore_live_threads=False):
    """Runs a training loop using a TensorFlow supervisor.

    When the sync_optimizer is supplied, gradient updates are applied
    synchronously. Otherwise, gradient updates are applied asynchronously.

    Args:
      train_op: A `Tensor` that, when executed, will apply the gradients and
        return the loss value.
      logdir: The directory where training logs are written to. If None, model
        checkpoints and summaries will not be written.
      train_step_fn: The function to call in order to execute a single gradient
        step. The function must take exactly four arguments: the current
        session, the `train_op` `Tensor`, a global step `Tensor` and a
        dictionary.
      train_step_kwargs: A dictionary which is passed to the `train_step_fn`. By
        default, two `Boolean`, scalar ops called "should_stop" and "should_log"
        are provided.
      log_every_n_steps: The frequency, in terms of global steps, that the loss
        and global step are logged.
      graph: The graph to pass to the supervisor. If no graph is supplied the
        default graph is used.
      master: The address of the tensorflow master.
      is_chief: Specifies whether or not the training is being run by the primary
        replica during replica training.
      global_step: The `Tensor` representing the global step. If left as `None`,
        then `training_util.get_or_create_global_step()` is used.
      number_of_steps: The max number of gradient steps to take during training,
        as measured by 'global_step': training will stop once global_step is
        greater than or equal to 'number_of_steps'. If the value is left as
        None, training proceeds indefinitely.
      init_op: The initialization operation. If left to its default value, then
        the session is initialized by calling
        `tf.global_variables_initializer()`.
      init_feed_dict: A feed dictionary to use when executing the `init_op`.
      local_init_op: The local initialization operation. If left to its default
        value, then the session is initialized by calling
        `tf.local_variables_initializer()` and `tf.tables_initializer()`.
      init_fn: An optional callable to be executed after `init_op` is called. The
        callable must accept one argument, the session being initialized.
      ready_op: Operation to check if the model is ready to use. If left to its
        default value, then the session checks for readiness by calling
        `tf.report_uninitialized_variables()`.
      summary_op: The summary operation.
      save_summaries_secs: How often, in seconds, to save summaries.
      summary_writer: `SummaryWriter` to use. Can be `None`
        to indicate that no summaries should be written. If unset, we
        create a SummaryWriter.
      startup_delay_steps: The number of steps to wait for before beginning. Note
        that this must be 0 if a sync_optimizer is supplied.
      saver: Saver to save checkpoints. If None, a default one will be created
        and used.
      save_interval_secs: How often, in seconds, to save the model to `logdir`.
      sync_optimizer: an instance of tf.train.SyncReplicasOptimizer, or a list of
        them. If the argument is supplied, gradient updates will be synchronous.
        If left as `None`, gradient updates will be asynchronous.
      session_config: An instance of `tf.ConfigProto` that will be used to
        configure the `Session`. If left as `None`, the default will be used.
      session_wrapper: A function that takes a `tf.Session` object as the only
        argument and returns a wrapped session object that has the same methods
        that the original object has, or `None`. Iff not `None`, the wrapped
        object will be used for training.
      trace_every_n_steps: produce and save a `Timeline` in Chrome trace format
        and add it to the summaries every `trace_every_n_steps`. If None, no trace
        information will be produced or saved.
      ignore_live_threads: If `True` ignores threads that remain running after
        a grace period when stopping the supervisor, instead of raising a
        RuntimeError.

    Returns:
      the value of the loss function after training.

    Raises:
      ValueError: if `train_op` is None, if `startup_delay_steps` is
        non-zero when `sync_optimizer` is supplied, if `number_of_steps` is
        not positive, if an element of `sync_optimizer` is not a
        `SyncReplicasOptimizer`, or if `summary_op`, `saver` or
        `trace_every_n_steps` is provided while `logdir` is None.
    """
    if train_op is None:
        raise ValueError('train_op cannot be None.')

    # Without a logdir nothing can be written, so options that only make
    # sense with one are rejected up front.
    if logdir is None:
        if summary_op != _USE_DEFAULT:
            raise ValueError('Cannot provide summary_op because logdir=None')
        if saver is not None:
            raise ValueError('Cannot provide saver because logdir=None')
        if trace_every_n_steps is not None:
            raise ValueError('Cannot provide trace_every_n_steps because '
                             'logdir=None')

    # Normalize a single optimizer to a list so the code below can iterate.
    if isinstance(sync_optimizer, sync_replicas_optimizer.SyncReplicasOptimizer):
        sync_optimizer = [sync_optimizer]
    if sync_optimizer is not None and startup_delay_steps > 0:
        raise ValueError(
            'startup_delay_steps must be zero when sync_optimizer is supplied.')

    if number_of_steps is not None and number_of_steps <= 0:
        raise ValueError(
            '`number_of_steps` must be either None or a positive number.')

    graph = graph or ops.get_default_graph()
    with graph.as_default():
        if global_step is None:
            global_step = training_util.get_or_create_global_step()
        saver = saver or tf_saver.Saver()

        if sync_optimizer is not None:
            for opt in sync_optimizer:
                if not isinstance(
                        opt, sync_replicas_optimizer.SyncReplicasOptimizer):
                    raise ValueError(
                        '`sync_optimizer` must be a tf.train.SyncReplicasOptimizer.')

        with ops.name_scope('init_ops'):
            if init_op == _USE_DEFAULT:
                init_op = variables.global_variables_initializer()

            if ready_op == _USE_DEFAULT:
                ready_op = variables.report_uninitialized_variables()

            if local_init_op == _USE_DEFAULT:
                local_init_op = control_flow_ops.group(
                    variables.local_variables_initializer(),
                    lookup_ops.tables_initializer())

            if sync_optimizer is not None and isinstance(sync_optimizer, list):
                # The optimizers' own init ops replace local_init_op, but are
                # created under a control dependency on the previous
                # local_init_op (when there is one).
                with ops.control_dependencies(
                        [local_init_op] if local_init_op is not None else []):
                    if is_chief:
                        local_init_op = control_flow_ops.group(
                            *[opt.chief_init_op for opt in sync_optimizer])
                    else:
                        local_init_op = control_flow_ops.group(
                            * [opt.local_step_init_op for opt in sync_optimizer])
                ready_for_local_init_op = control_flow_ops.group(
                    *[opt.ready_for_local_init_op for opt in sync_optimizer])
            else:
                ready_for_local_init_op = None

        if summary_op == _USE_DEFAULT:
            summary_op = summary.merge_all()

        if summary_writer == _USE_DEFAULT:
            summary_writer = supervisor.Supervisor.USE_DEFAULT

        if is_chief and sync_optimizer is not None:
            # Need to create these BEFORE the supervisor finalizes the graph:
            init_tokens_op = [
                opt.get_init_tokens_op() for opt in sync_optimizer
            ]
            chief_queue_runner = [
                opt.get_chief_queue_runner() for opt in sync_optimizer
            ]

        if train_step_kwargs == _USE_DEFAULT:
            with ops.name_scope('train_step'):
                train_step_kwargs = {}

                if number_of_steps:
                    should_stop_op = math_ops.greater_equal(
                        global_step, number_of_steps)
                else:
                    # No step limit: the predicate never asks to stop.
                    should_stop_op = constant_op.constant(False)
                train_step_kwargs['should_stop'] = should_stop_op
                if log_every_n_steps > 0:
                    train_step_kwargs['should_log'] = math_ops.equal(
                        math_ops.mod(global_step, log_every_n_steps), 0)
                if is_chief and trace_every_n_steps is not None:
                    train_step_kwargs['should_trace'] = math_ops.equal(
                        math_ops.mod(global_step, trace_every_n_steps), 0)
                    train_step_kwargs['logdir'] = logdir

    sv = supervisor.Supervisor(graph=graph,
                               is_chief=is_chief,
                               logdir=logdir,
                               init_op=init_op,
                               init_feed_dict=init_feed_dict,
                               local_init_op=local_init_op,
                               ready_for_local_init_op=ready_for_local_init_op,
                               ready_op=ready_op,
                               summary_op=summary_op,
                               summary_writer=summary_writer,
                               global_step=global_step,
                               saver=saver,
                               save_summaries_secs=save_summaries_secs,
                               save_model_secs=save_interval_secs,
                               init_fn=init_fn)

    if summary_writer is not None:
        train_step_kwargs['summary_writer'] = sv.summary_writer

    total_loss = 0
    # The whole session is re-created and training re-entered whenever an
    # AbortedError escapes; any other outcome leaves the loop.
    should_retry = True
    while should_retry:
        try:
            should_retry = False
            with sv.managed_session(master,
                                    start_standard_services=False,
                                    config=session_config) as sess:
                logging.info('Starting Session.')
                if session_wrapper is not None:
                    logging.info('Wrapping session with wrapper function: %s',
                                 session_wrapper)
                    sess = session_wrapper(sess)
                if is_chief:
                    if logdir:
                        sv.start_standard_services(sess)
                elif startup_delay_steps > 0:
                    # (use sys.maxsize because sys.maxint doesn't exist in Python 3)
                    _wait_for_step(
                        sess, global_step,
                        min(startup_delay_steps, number_of_steps or sys.maxsize))
                threads = sv.start_queue_runners(sess)
                logging.info('Starting Queues.')
                if is_chief and sync_optimizer is not None:
                    sv.start_queue_runners(sess, chief_queue_runner)
                    sess.run(init_tokens_op)
                try:
                    while not sv.should_stop():
                        total_loss, should_stop = train_step_fn(
                            sess, train_op, global_step, train_step_kwargs)
                        if should_stop:
                            logging.info('Stopping Training.')
                            sv.request_stop()
                            break
                except errors.OutOfRangeError:
                    # OutOfRangeError is thrown when epoch limit per
                    # tf.train.limit_epochs is reached.
                    logging.info('Caught OutOfRangeError. Stopping Training.')
                if logdir and sv.is_chief:
                    logging.info('Finished training! Saving model to disk.')
                    sv.saver.save(sess, sv.save_path, global_step=sv.global_step)
                    sv.stop(threads,
                            close_summary_writer=True,
                            ignore_live_threads=ignore_live_threads)

        except errors.AbortedError:
            # Always re-run on AbortedError as it indicates a restart of one of the
            # distributed tensorflow servers.
            logging.info('Retrying training!')
            should_retry = True

    return total_loss
def _default_local_init_op():
    """Returns an op grouping the local-variable and table initializers."""
    local_vars_init = variables.local_variables_initializer()
    tables_init = data_flow_ops.tables_initializer()
    return control_flow_ops.group(local_vars_init, tables_init)
def test_read_text_lines_large(self):
    """Reads 49999 JSON text lines in batches and checks none are lost.

    Each record is the fixed prefix followed by its index, base64-encoded
    into a JSON example. After reading to exhaustion, both the number of
    parsed records and the set of their values must match what was written.
    """
    gfile.Glob = self._orig_glob
    sequence_prefix = "abcdefghijklmnopqrstuvwxyz123456789"
    num_records = 49999
    lines = [
        (sequence_prefix + str(index)).encode("ascii")
        for index in xrange(num_records)
    ]
    json_head = '{"features": { "feature": { "sequence": {' '"bytes_list": { "value": ["'
    json_tail = '"]}}}}}\n'
    json_lines = [
        json_head + base64.b64encode(record).decode("ascii") + json_tail
        for record in lines
    ]
    filename = self._create_temp_file("".join(json_lines))

    feature_spec = {
        "sequence": parsing_ops.FixedLenFeature([], dtypes_lib.string)
    }

    with ops.Graph().as_default() as g, self.test_session(
            graph=g) as session:
        keys, result = graph_io.read_keyed_batch_features(
            filename,
            10000,
            feature_spec,
            io_ops.TextLineReader,
            randomize_input=False,
            num_epochs=1,
            queue_capacity=10000,
            num_enqueue_threads=2,
            parse_fn=parsing_ops.decode_json_example,
            name="my_large_batch")

        self.assertAllEqual((None,), keys.get_shape().as_list())
        self.assertEqual(1, len(result))
        self.assertAllEqual((None,), result["sequence"].get_shape().as_list())
        session.run(variables.local_variables_initializer())
        coord = coordinator.Coordinator()
        threads = queue_runner_impl.start_queue_runners(session, coord=coord)

        batches = []
        try:
            # Keep pulling batches until the single epoch is exhausted.
            while not coord.should_stop():
                batches.append(session.run(result))
        except errors.OutOfRangeError:
            pass
        finally:
            coord.request_stop()

        coord.join(threads)

    parsed_records = [
        item for batch in batches for item in batch["sequence"]
    ]
    # Check that the number of records matches expected and all records
    # are present.
    self.assertEqual(len(parsed_records), num_records)
    self.assertEqual(set(parsed_records), set(lines))
def train(datasets_dicts, epochs, val_every, iters_cnt, validate_with_eval_model, pipeline_config, num_clones=1, save_cback=None):
    """Trains a detection model and validates it periodically.

    Builds one training clone per GPU, sums their gradients, applies the
    optional bias-multiplier / freeze / clipping steps from the train config,
    then runs `epochs` epochs of `iters_cnt['train']` steps each. Whenever
    the `EvalPlanner` asks for it, the validation loss is averaged over
    `iters_cnt['val']` steps and reported, and `save_cback` is invoked.

    Args:
      datasets_dicts: Mapping with 'train' and 'val' entries; each entry is
        passed to `build_dataset`.
      epochs: Number of epochs to run.
      val_every: Forwarded to `EvalPlanner` together with `epochs`.
      iters_cnt: Mapping with 'train' and 'val' iteration counts.
      validate_with_eval_model: When true, the validation model is built with
        `is_training=False`; otherwise with `is_training=True`.
      pipeline_config: Passed to `configs_from_pipeline`; its 'model' and
        'train_config' entries are used.
      num_clones: Number of GPU clones to build (devices '/gpu:0' ...).
      save_cback: Optional callable invoked after every validation as
        `save_cback(saver, sess, model_is_best, opt_data=...)`. When None,
        nothing is called.
    """
    logger.info('Start train')
    configs = configs_from_pipeline(pipeline_config)
    model_config = configs['model']
    train_config = configs['train_config']

    create_model_fn = functools.partial(
        model_builder.build,
        model_config=model_config,
        is_training=True)
    detection_model = create_model_fn()

    def get_next(dataset):
        return dataset_util.make_initializable_iterator(
            build_dataset(dataset)).get_next()

    create_tensor_dict_fn = functools.partial(get_next, datasets_dicts['train'])
    create_tensor_dict_fn_val = functools.partial(get_next, datasets_dicts['val'])

    data_augmentation_options = [
        preprocessor_builder.build(step)
        for step in train_config.data_augmentation_options]

    with tf.Graph().as_default():
        # Build a configuration specifying multi-GPU and multi-replicas.
        # NOTE(review): num_clones is hard-coded to 4 here while the clone
        # loop below uses the `num_clones` argument -- confirm this is
        # intended.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=4,
            clone_on_cpu=False,
            replica_id=0,
            num_replicas=1,
            num_ps_tasks=0,
            worker_job_name='lonely_worker')

        # Place the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        with tf.device(deploy_config.inputs_device()):
            coord = coordinator.Coordinator()
            input_queue = create_input_queue(
                train_config.batch_size, create_tensor_dict_fn,
                train_config.batch_queue_capacity,
                train_config.num_batch_queue_threads,
                train_config.prefetch_queue_capacity,
                data_augmentation_options)

            input_queue_val = create_input_queue(
                train_config.batch_size, create_tensor_dict_fn_val,
                train_config.batch_queue_capacity,
                train_config.num_batch_queue_threads,
                train_config.prefetch_queue_capacity,
                data_augmentation_options)

        # create validation graph
        create_model_fn_val = functools.partial(
            model_builder.build,
            model_config=model_config,
            is_training=not validate_with_eval_model)

        with tf.device(deploy_config.optimizer_device()):
            training_optimizer, optimizer_summary_vars = optimizer_builder.build(
                train_config.optimizer)
            for var in optimizer_summary_vars:
                tf.summary.scalar(var.op.name, var, family='LearningRate')

        train_losses = []
        grads_and_vars = []
        with slim.arg_scope([slim.model_variable, slim.variable],
                            device='/device:CPU:0'):
            for curr_dev_id in range(num_clones):
                with tf.device('/gpu:{}'.format(curr_dev_id)):
                    with tf.name_scope('clone_{}'.format(curr_dev_id)) as scope:
                        # Every clone after the first reuses the variables
                        # created by clone 0.
                        with tf.variable_scope(
                                tf.get_variable_scope(),
                                reuse=True if curr_dev_id > 0 else None):
                            losses = _create_losses_val(
                                input_queue, create_model_fn, train_config)
                            clones_loss = tf.add_n(losses)
                            clones_loss = tf.divide(clones_loss, 1.0 * num_clones)
                            grads = training_optimizer.compute_gradients(clones_loss)
                            train_losses.append(clones_loss)
                            grads_and_vars.append(grads)
                            if curr_dev_id == 0:
                                update_ops = tf.get_collection(
                                    tf.GraphKeys.UPDATE_OPS)

        val_total_loss = get_val_loss(num_clones, input_queue_val,
                                      create_model_fn_val, train_config)

        with tf.device(deploy_config.optimizer_device()):
            total_loss = tf.add_n(train_losses)
            grads_and_vars = model_deploy._sum_clones_gradients(grads_and_vars)
            total_loss = tf.check_numerics(total_loss, 'LossTensor is inf or nan.')

            # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
            if train_config.bias_grad_multiplier:
                biases_regex_list = ['.*/biases']
                grads_and_vars = variables_helper.multiply_gradients_matching_regex(
                    grads_and_vars,
                    biases_regex_list,
                    multiplier=train_config.bias_grad_multiplier)

            # Optionally freeze some layers by setting their gradients to be zero.
            if train_config.freeze_variables:
                grads_and_vars = variables_helper.freeze_gradients_matching_regex(
                    grads_and_vars, train_config.freeze_variables)

            # Optionally clip gradients
            if train_config.gradient_clipping_by_norm > 0:
                with tf.name_scope('clip_grads'):
                    grads_and_vars = slim.learning.clip_gradient_norms(
                        grads_and_vars, train_config.gradient_clipping_by_norm)

            # Create gradient updates.
            grad_updates = training_optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            update_ops.append(grad_updates)

            # Fetching train_tensor runs every update op, then yields the loss.
            update_op = tf.group(*update_ops, name='update_barrier')
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        config = tf.ConfigProto(allow_soft_placement=True,
                                log_device_placement=False)
        coord.clear_stop()
        sess = tf.Session(config=config)
        saver = tf.train.Saver()

        graph = ops.get_default_graph()
        with graph.as_default():
            with ops.name_scope('init_ops'):
                init_op = variables.global_variables_initializer()
                ready_op = variables.report_uninitialized_variables()
                local_init_op = control_flow_ops.group(
                    variables.local_variables_initializer(),
                    lookup_ops.tables_initializer())

        # graph.finalize()
        sess.run([init_op, ready_op, local_init_op])

        queue_runners = graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS)
        threads = []
        for qr in queue_runners:
            threads.extend(
                qr.create_threads(sess, coord=coord, daemon=True, start=True))

        logger.info('Start restore')
        if train_config.fine_tune_checkpoint:
            var_map = detection_model.restore_map(
                fine_tune_checkpoint_type=train_config.fine_tune_checkpoint_type,
                load_all_detection_checkpoint_vars=(
                    train_config.load_all_detection_checkpoint_vars))
            available_var_map = (variables_helper.
                                 get_variables_available_in_checkpoint(
                                     var_map, train_config.fine_tune_checkpoint))
            # The global step is never restored from the fine-tune checkpoint.
            if 'global_step' in available_var_map:
                del available_var_map['global_step']
            init_saver = tf.train.Saver(available_var_map)
            logger.info('Restoring model weights from previous checkpoint.')
            init_saver.restore(sess, train_config.fine_tune_checkpoint)
            logger.info('Model restored.')

        eval_planner = EvalPlanner(epochs, val_every)
        progress = sly.progress_counter_train(epochs, iters_cnt['train'])
        best_val_loss = float('inf')
        epoch_flt = 0

        for epoch in range(epochs):
            logger.info("Before new epoch", extra={'epoch': epoch_flt})

            for train_it in range(iters_cnt['train']):
                total_loss, np_global_step = sess.run([train_tensor, global_step])

                metrics_values_train = {
                    'loss': total_loss,
                }

                progress.iter_done_report()
                epoch_flt = epoch_float(epoch, train_it + 1, iters_cnt['train'])
                sly.report_metrics_training(epoch_flt, metrics_values_train)

                if eval_planner.need_validation(epoch_flt):
                    logger.info("Before validation", extra={'epoch': epoch_flt})

                    overall_val_loss = 0
                    for val_it in range(iters_cnt['val']):
                        overall_val_loss += sess.run(val_total_loss)
                        logger.info("Validation in progress",
                                    extra={'epoch': epoch_flt,
                                           'val_iter': val_it,
                                           'val_iters': iters_cnt['val']})

                    metrics_values_val = {
                        'loss': overall_val_loss / iters_cnt['val'],
                    }
                    sly.report_metrics_validation(epoch_flt, metrics_values_val)
                    logger.info("Validation has been finished",
                                extra={'epoch': epoch_flt})

                    eval_planner.validation_performed()

                    val_loss = metrics_values_val['loss']
                    model_is_best = val_loss < best_val_loss
                    if model_is_best:
                        best_val_loss = val_loss
                        logger.info('It\'s been determined that current model is the best one for a while.')

                    # The callback is optional (default None); calling it
                    # unconditionally would raise TypeError.
                    if save_cback is not None:
                        save_cback(saver, sess, model_is_best, opt_data={
                            'epoch': epoch_flt,
                            'val_metrics': metrics_values_val,
                        })

            logger.info("Epoch was finished", extra={'epoch': epoch_flt})

        coord.request_stop()
        coord.join(threads)
def begin(self):
    """Creates the initialization ops this hook keeps on `self`.

    The global initializer is only built when `self._is_chief` is truthy;
    otherwise `self._global_init_op` is left as None.
    """
    self._local_init_op = variables.local_variables_initializer()
    self._global_init_op = (
        variables.global_variables_initializer() if self._is_chief else None)
    self._variable_init_op = self._ea_optimizer.get_init_op(self._task_index)
def export_fn(estimator, export_dir_base, checkpoint_path=None, eval_result=None):
    """Exports `estimator` as a SavedModel and copies it into the job dir.

    A CSV serving graph is built, the checkpoint is restored into it, and a
    timestamped SavedModel is written under `export_dir_base`. Extra assets
    are copied, all but the three newest exports are removed, and the new
    export replaces `<args.job_dir>/model` (or `evaluation_model` when
    `keep_target` is set). `args`, `features`, `schema`, `stats`,
    `keep_target` and `assets_extra` are read from the surrounding scope.

    Returns:
      The path of the timestamped export directory.

    Raises:
      ValueError: If no checkpoint can be found.
    """

    def _as_tensor_infos(named_tensors):
        return {
            key: tf.saved_model.utils.build_tensor_info(tensor)
            for key, tensor in six.iteritems(named_tensors)
        }

    with ops.Graph().as_default() as g:
        contrib_variables.create_global_step(g)

        input_ops = feature_transforms.build_csv_serving_tensors_for_training_step(
            args.analysis, features, schema, stats, keep_target)
        model_fn_ops = estimator._call_model_fn(
            input_ops.features, None, model_fn_lib.ModeKeys.INFER)
        output_fetch_tensors = make_prediction_output_tensors(
            args=args,
            features=features,
            input_ops=input_ops,
            model_fn_ops=model_fn_ops,
            keep_target=keep_target)

        # Don't use signature_def_utils.predict_signature_def as that renames
        # tensor names if there is only 1 input/output tensor!
        signature_inputs = _as_tensor_infos(input_ops.default_inputs)
        signature_outputs = _as_tensor_infos(output_fetch_tensors)
        signature_def_map = {
            'serving_default': signature_def_utils.build_signature_def(
                signature_inputs,
                signature_outputs,
                tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
        }

        # Fall back to the latest checkpoint in the model dir.
        checkpoint_path = (
            checkpoint_path or saver.latest_checkpoint(estimator._model_dir))
        if not checkpoint_path:
            raise ValueError("Couldn't find trained model at %s."
                             % estimator._model_dir)

        export_dir = saved_model_export_utils.get_timestamped_export_dir(
            export_dir_base)

        # Prefer the saver supplied by the model's scaffold, if any.
        has_scaffold_saver = (
            model_fn_ops.scaffold is not None and
            model_fn_ops.scaffold.saver is not None)
        if has_scaffold_saver:
            saver_for_restore = model_fn_ops.scaffold.saver
        else:
            saver_for_restore = saver.Saver(sharded=True)

        with tf_session.Session('') as session:
            saver_for_restore.restore(session, checkpoint_path)
            init_op = control_flow_ops.group(
                variables.local_variables_initializer(),
                resources.initialize_resources(resources.shared_resources()),
                tf.tables_initializer())

            # Perform the export
            asset_paths = ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
            builder = saved_model_builder.SavedModelBuilder(export_dir)
            builder.add_meta_graph_and_variables(
                session, [tag_constants.SERVING],
                signature_def_map=signature_def_map,
                assets_collection=asset_paths,
                legacy_init_op=init_op)
            builder.save(False)

        # Add the extra assets
        if assets_extra:
            assets_extra_path = os.path.join(
                compat.as_bytes(export_dir), compat.as_bytes('assets.extra'))
            for dest_relative, source in assets_extra.items():
                dest_absolute = os.path.join(
                    compat.as_bytes(assets_extra_path),
                    compat.as_bytes(dest_relative))
                file_io.recursive_create_dir(os.path.dirname(dest_absolute))
                file_io.copy(source, dest_absolute)

    # only keep the last 3 models
    saved_model_export_utils.garbage_collect_exports(export_dir_base,
                                                     exports_to_keep=3)

    # save the last model to the model folder.
    # export_dir_base = A/B/intermediate_models/
    final_dir = os.path.join(
        args.job_dir, 'evaluation_model' if keep_target else 'model')
    if file_io.is_directory(final_dir):
        file_io.delete_recursively(final_dir)
    file_io.recursive_create_dir(final_dir)
    recursive_copy(export_dir, final_dir)

    return export_dir
def _report_validation_metrics(sess, text_ph, label_ph, logits_ts,
                               text_ts_v, label_ts_v):
    """Print mean `precision`/`recall` (cutoff 3) over validation batches.

    Args:
      sess: Session in which the queue runners are already started.
      text_ph: Placeholder fed with the dense text batch.
      label_ph: Placeholder fed with the label batch.
      logits_ts: Logits tensor evaluated for every validation batch.
      text_ts_v: Tensor yielding validation text batches.
      label_ts_v: Tensor yielding validation label batches.
    """
    cutoff = 3
    prec_v, rec_v = [], []
    # NOTE(review): the batch count assumes roughly 2000 validation examples;
    # with fewer, the reader presumably raises OutOfRangeError here - confirm.
    for _ in range(int(2000 / FLAGS.batch_size)):
        text_np, label_np = sess.run([text_ts_v, label_ts_v])
        logits, = sess.run(
            [logits_ts], feed_dict={text_ph: text_np, label_ph: label_np})
        prec_v.append(precision(logits, label_np, cutoff))
        rec_v.append(recall(logits, label_np, cutoff))
    print('prec={0:.4f} rec={1:.4f}'.format(np.mean(prec_v), np.mean(rec_v)))


def train():
    """Train the averaged-embedding multi-label model, then validate it.

    The model looks up an embedding per word id, averages the embeddings over
    axis -2, applies a fully connected layer without activation and minimizes
    the mean sigmoid cross-entropy with Adam. Batches are read from
    `FLAGS.train_tfrecord` through queue runners and fed back in through
    placeholders. When the training reader raises `OutOfRangeError`, the
    number of consumed examples is printed and validation metrics are
    computed on `FLAGS.valid_tfrecord`.

    Every 100th step the merged summaries are written to `FLAGS.logs_dir` and
    the loss is printed. The summary writer is always closed on exit so that
    buffered events reach disk.
    """
    from tensorflow.python.framework import errors
    from tensorflow.python.ops import variables
    from tensorflow.python.training import coordinator
    from tensorflow.python.training import queue_runner_impl

    # Count the vocabulary lines without leaking the file handle or holding
    # the whole file in memory.
    with open(FLAGS.vocab_file) as vocab_file:
        vocab_size = sum(1 for _ in vocab_file)
    id_to_label = load_id_to_label()
    num_label = len(id_to_label)
    print('#vocab={} #label={}'.format(vocab_size, num_label))

    # Training input pipeline.
    parse_spec = get_parse_spec(FLAGS.use_ngrams, num_label)
    features = tf.contrib.learn.read_batch_features(
        FLAGS.train_tfrecord,
        FLAGS.batch_size,
        parse_spec,
        tf.TFRecordReader,
        num_epochs=FLAGS.num_epochs,
        reader_num_threads=FLAGS.num_threads)
    text_ts = tf.sparse_tensor_to_dense(
        features[TEXT_KEY], default_value=DEFAULT_WORD)
    label_ts = features.pop(LABELS_KEY)

    # Model: the text placeholder already carries integer word ids, so no
    # vocabulary lookup table is applied.
    text_ph = tf.placeholder(tf.int64, shape=(None, None))
    label_ph = tf.placeholder(tf.float32, shape=(None, num_label))
    text_ids = text_ph
    # NOTE(review): the extra embedding row (vocab_size + 1) is presumably
    # for the DEFAULT_WORD padding id - confirm.
    text_embedding_w = tf.Variable(
        tf.random_uniform(
            [vocab_size + 1, FLAGS.embedding_dimension], -0.1, 0.1))
    text_embedding = tf.reduce_mean(
        tf.nn.embedding_lookup(text_embedding_w, text_ids), axis=-2)
    input_layer = text_embedding
    logits_ts = tf.contrib.layers.fully_connected(
        inputs=input_layer, num_outputs=num_label, activation_fn=None)
    loss_ts = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=label_ph, logits=logits_ts))
    optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
    # NOTE(review): this function never creates a global step, so
    # get_global_step() may return None here - confirm that is intended.
    train_op = optimizer.minimize(
        loss_ts, global_step=tf.train.get_global_step())
    var_init = tf.global_variables_initializer()
    tab_init = tf.tables_initializer()
    tf.summary.scalar('loss', loss_ts)
    summary_op = tf.summary.merge_all()

    # Validation input pipeline (single pass over the data).
    features_v = tf.contrib.learn.read_batch_features(
        FLAGS.valid_tfrecord,
        FLAGS.batch_size,
        parse_spec,
        tf.TFRecordReader,
        num_epochs=1,
        reader_num_threads=FLAGS.num_threads)
    text_ts_v = tf.sparse_tensor_to_dense(
        features_v[TEXT_KEY], default_value=DEFAULT_WORD)
    label_ts_v = features_v.pop(LABELS_KEY)

    with tf.Session() as sess:
        writer = tf.summary.FileWriter(
            FLAGS.logs_dir, graph=tf.get_default_graph())
        try:
            sess.run(variables.local_variables_initializer())
            coord = coordinator.Coordinator()
            threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
            sess.run(var_init)
            sess.run(tab_init)
            total_size = 0
            try:
                while not coord.should_stop():
                    for train_step in range(1000000):
                        text_np, label_np = sess.run([text_ts, label_ts])
                        # Count the rows actually fetched; a batch may hold
                        # fewer than FLAGS.batch_size examples.
                        total_size += len(label_np)
                        feed_dict = {text_ph: text_np, label_ph: label_np}
                        _, loss, summary = sess.run(
                            [train_op, loss_ts, summary_op],
                            feed_dict=feed_dict)
                        if (train_step + 1) % 100 == 0:
                            writer.add_summary(summary, train_step)
                            print('#{0} loss={1:.4f}'.format(
                                train_step, loss))
            except errors.OutOfRangeError:
                # The training reader is exhausted: report and validate.
                print('total={}'.format(total_size))
                _report_validation_metrics(
                    sess, text_ph, label_ph, logits_ts, text_ts_v, label_ts_v)
            finally:
                coord.request_stop()
                coord.join(threads)
        finally:
            # Flush and release the event file even if training failed.
            writer.close()
def test_multiple_workers_with_shared_queue(self):
    """Checks that two graphs on one local server read from shared queues.

    The first graph/session reads "ABC" and "DEF"; a second graph/session
    built with the same queue name against the same server target is then
    expected to continue with "GHI".
    """
    gfile.Glob = self._orig_glob
    filenames = self._create_sorted_temp_files([
        "ABC\n", "DEF\n", "GHI\n", "JKL\n", "MNO\n", "PQR\n", "STU\n",
        "VWX\n", "YZ\n"
    ])

    batch_size = 1
    queue_capacity = 5
    name = "my_batch"
    example_queue_name = "%s/fifo_queue" % name
    worker_file_name_queue_name = "%s/file_name_queue/fifo_queue" % name

    server = server_lib.Server.create_local_server()

    def build_reader():
        # Creates the reader ops in the current default graph and checks the
        # static shapes of both outputs.
        keys, inputs = graph_io.read_keyed_batch_examples_shared_queue(
            filenames,
            batch_size,
            reader=io_ops.TextLineReader,
            randomize_input=False,
            num_epochs=1,
            queue_capacity=queue_capacity,
            name=name)
        for tensor in (keys, inputs):
            self.assertAllEqual((None,), tensor.get_shape().as_list())
        return inputs

    def pump_queues_and_expect(session, inputs, expected):
        # Runs the file-name queue, then the example queue, once each by
        # hand before reading the next batch.
        self._run_queue(worker_file_name_queue_name, session)
        self._run_queue(example_queue_name, session)
        self.assertAllEqual(session.run(inputs), expected)

    with ops.Graph().as_default() as g1:
        with session_lib.Session(server.target, graph=g1) as session:
            inputs = build_reader()
            session.run([
                variables.local_variables_initializer(),
                variables.global_variables_initializer()
            ])
            pump_queues_and_expect(session, inputs, [b"ABC"])
            pump_queues_and_expect(session, inputs, [b"DEF"])

    with ops.Graph().as_default() as g2:
        with session_lib.Session(server.target, graph=g2) as session:
            inputs = build_reader()
            # No initializers are run here; only the queues are pumped.
            pump_queues_and_expect(session, inputs, [b"GHI"])

    self.assertTrue(g1 is not g2)
def _export_mode(mode, has_saved_vars, builder, model, custom_objects,
                 checkpoint_path, input_signature):
    """Exports a model, and optionally saves new vars from the clone model.

    The model is cloned into a blank graph, the function for `mode` is built
    on the clone, weights are loaded from `checkpoint_path`, and the graph is
    added to `builder` as a meta graph. When no variables have been saved yet,
    the clone's weights are also written back to `checkpoint_path` and the
    builder is marked as having saved variables.

    Args:
      mode: A `tf.estimator.ModeKeys` string.
      has_saved_vars: A `boolean` indicating whether the SavedModel has
        already exported variables.
      builder: A `SavedModelBuilder` object.
      model: A `tf.keras.Model` object.
      custom_objects: A dictionary mapping string names to custom classes or
        functions.
      checkpoint_path: String path to checkpoint.
      input_signature: Nested TensorSpec containing the expected inputs. Can
        be `None`, in which case the signature will be inferred from the
        model.

    Returns:
      None.

    Raises:
      ValueError: If the train/eval mode is being exported, but the model does
        not have an optimizer.
    """
    compile_clone = mode != mode_keys.ModeKeys.PREDICT
    if compile_clone and not model.optimizer:
        raise ValueError(
            'Model does not have an optimizer. Cannot export mode %s' % mode)

    model_graph = ops.get_default_graph()
    with ops.Graph().as_default() as g, K.learning_phase_scope(
            mode == mode_keys.ModeKeys.TRAIN):

        input_tensors = (
            None if input_signature is None
            else nest.map_structure(create_placeholder, input_signature))

        # Clone the model into blank graph. This will create placeholders for
        # inputs and targets.
        clone = models_lib.clone_and_build_model(
            model,
            input_tensors=input_tensors,
            custom_objects=custom_objects,
            compile_clone=compile_clone)

        # Make sure that iterations variable is added to the global step
        # collection, to ensure that, when the SavedModel graph is loaded, the
        # iterations variable is returned by
        # `tf.compat.v1.train.get_global_step()`. This is required for
        # compatibility with the SavedModelEstimator.
        if compile_clone:
            g.add_to_collection(
                ops.GraphKeys.GLOBAL_STEP, clone.optimizer.iterations)

        # Extract update and train ops from train/test/predict functions.
        train_op = None
        if mode == mode_keys.ModeKeys.TRAIN:
            clone._make_train_function()  # pylint: disable=protected-access
            train_op = clone.train_function.updates_op
        elif mode == mode_keys.ModeKeys.TEST:
            clone._make_test_function()  # pylint: disable=protected-access
        else:
            clone._make_predict_function()  # pylint: disable=protected-access
        g.get_collection_ref(ops.GraphKeys.UPDATE_OPS).extend(
            clone.state_updates)

        # The session is used directly as the context manager so that it is
        # closed on exit; `Session.as_default()` installs the default session
        # but leaves it open, which leaked one session per exported mode.
        with session.Session():
            clone_var_list = _get_var_list(clone)
            if has_saved_vars:
                # Confirm all variables in the clone have an entry in the
                # checkpoint.
                status = clone.load_weights(checkpoint_path)
                status.assert_existing_objects_matched()
            else:
                # Confirm that variables between the clone and model match up
                # exactly, not counting optimizer objects. Optimizer objects
                # are ignored because if the model has not trained, the slot
                # variables will not have been created yet.
                # TODO(b/113179535): Replace with trackable equivalence.
                _assert_same_non_optimizer_objects(
                    model, model_graph, clone, g)

                # TODO(b/113178242): Use value transfer for trackable objects.
                clone.load_weights(checkpoint_path)

                # Add graph and variables to SavedModel.
                # TODO(b/113134168): Switch to add_meta_graph_and_variables.
                clone.save_weights(
                    checkpoint_path, save_format='tf', overwrite=True)
                builder._has_saved_variables = True  # pylint: disable=protected-access

            # Add graph to the SavedModel builder.
            builder.add_meta_graph(
                model_utils.EXPORT_TAG_MAP[mode],
                signature_def_map=_create_signature_def_map(clone, mode),
                saver=saver_lib.Saver(
                    clone_var_list,
                    # Allow saving Models with no variables. This is somewhat
                    # odd, but it's not necessarily a bug.
                    allow_empty=True),
                init_op=variables.local_variables_initializer(),
                train_op=train_op)
    return None