def test_build_all_summaries(self):
    """Every summary family must be produced when `summaries='all'`."""
    training.create_global_step()
    features = {'x': tf.placeholder(tf.float32, [2, 89])}
    labels = tf.constant([[1], [1]])
    model = BaseModel(plx.Modes.TRAIN,
                      graph_fn=self.get_dummy_graph_fn(),
                      loss_config=LossConfig(module='log_loss'),
                      optimizer_config=OptimizerConfig(
                          module='adadelta', decay_type='exponential_decay'),
                      model_type=BaseModel.Types.CLASSIFIER,
                      eval_metrics_config=[],
                      summaries='all',
                      name='test')
    model(features, labels, None, None)  # Only var are created

    # Count tracked summaries per family; first matching marker wins,
    # preserving the precedence of the original elif chain.
    counts = {'learning_rate': 0, 'Activation': 0, 'Loss': 0, 'Gradient': 0}
    for summary_name in get_tracked(
            collection=tf.GraphKeys.SUMMARIES_BY_NAMES).keys():
        for marker in ('learning_rate', 'Activation', 'Loss', 'Gradient'):
            if marker in summary_name:
                counts[marker] += 1
                break

    assert counts['learning_rate'] > 0
    assert counts['Activation'] > 0
    assert counts['Gradient'] > 0
    assert counts['Loss'] > 0
def predict(self, input_fn, predict_keys=None, hooks=None, checkpoint_path=None):
    """Returns predictions for given features.

    Args:
        input_fn: Input function returning features which is a dictionary of
            string feature name to `Tensor` or `SparseTensor`. If it returns a
            tuple, first item is extracted as features. Prediction continues
            until `input_fn` raises an end-of-input exception
            (`OutOfRangeError` or `StopIteration`).
        predict_keys: list of `str`, name of the keys to predict. It is used
            if the `EstimatorSpec.predictions` is a `dict`. If `predict_keys`
            is used then rest of the predictions will be filtered from the
            dictionary. If `None`, returns all.
        hooks: List of `SessionRunHook` subclass instances. Used for callbacks
            inside the prediction call.
        checkpoint_path: Path of a specific checkpoint to predict. If `None`,
            the latest checkpoint in `model_dir` is used.

    Yields:
        Evaluated values of `predictions` tensors.

    Raises:
        ValueError: Could not find a trained model in model_dir.
        ValueError: if batch length of predictions are not same.
        ValueError: If there is a conflict between `predict_keys` and
            `predictions`. For example if `predict_keys` is not `None` but
            `EstimatorSpec.predictions` is not a `dict`.
    """
    hooks = _check_hooks_type(hooks)
    # Check that model has been trained.
    if not checkpoint_path:
        checkpoint_path = saver.latest_checkpoint(self._model_dir)
    if not checkpoint_path:
        raise ValueError('Could not find trained model in model_dir: {}.'.format(
            self._model_dir))
    # Build a fresh inference graph seeded for reproducibility, then stream
    # predictions from a monitored session restored from the checkpoint.
    with ops.Graph().as_default() as g:
        random_seed.set_random_seed(self._config.tf_random_seed)
        training.create_global_step(g)
        features = self._get_features_from_input_fn(input_fn)
        estimator_spec = self._call_model_fn(features, None,
                                             model_fn_lib.ModeKeys.PREDICT)
        # Optionally filter the predictions dict down to `predict_keys`.
        predictions = self._extract_keys(estimator_spec.predictions, predict_keys)
        with training.MonitoredSession(
                session_creator=training.ChiefSessionCreator(
                    checkpoint_filename_with_path=checkpoint_path,
                    scaffold=estimator_spec.scaffold,
                    config=config_pb2.ConfigProto(allow_soft_placement=True)),
                hooks=hooks) as mon_sess:
            while not mon_sess.should_stop():
                preds_evaluated = mon_sess.run(predictions)
                if not isinstance(predictions, dict):
                    # Tensor predictions: yield one element per batch row.
                    for pred in preds_evaluated:
                        yield pred
                else:
                    # Dict predictions: slice every value at the same batch
                    # index so each yielded dict describes a single example.
                    for i in range(self._extract_batch_length(preds_evaluated)):
                        yield {
                            key: value[i]
                            for key, value in six.iteritems(preds_evaluated)
                        }
def test_build_learning_rate_summaries(self):
    """Requesting only `learning_rate` must track exactly that summary."""
    training.create_global_step()
    model_config = ModelConfig(
        loss_config=LossConfig(name='log_loss'),
        optimizer_config=OptimizerConfig(name='Adadelta',
                                         decay_type='exponential_decay'))
    features = {'source_ids': tf.placeholder(tf.float32, [2, 89])}
    labels = tf.constant([[1], [1]])
    model = BaseModel(plx.ModeKeys.TRAIN,
                      graph_fn=self.get_dummy_graph_fn(),
                      config=model_config,
                      model_type=BaseModel.Types.CLASSIFIER,
                      summaries=['learning_rate'],
                      name='test',
                      params=None)
    model(features, labels, None, None)  # Only var are created

    tracked = get_tracked(collection=tf.GraphKeys.SUMMARIES_BY_NAMES)
    summary_names = list(tracked.keys())
    assert len(summary_names) == 1
    assert summary_names[0] == 'learning_rate'
def predict(self, input_fn, predict_keys=None, hooks=None):
    """Returns predictions for given features.

    Args:
        input_fn: Input function returning features which is a dictionary of
            string feature name to `Tensor` or `SparseTensor`. If it returns a
            tuple, first item is extracted as features. Prediction continues
            until `input_fn` raises an end-of-input exception
            (`OutOfRangeError` or `StopIteration`).
        predict_keys: list of `str`, name of the keys to predict. It is used
            if the `EstimatorSpec.predictions` is a `dict`. If `predict_keys`
            is used then rest of the predictions will be filtered from the
            dictionary. If `None`, returns all.
        hooks: List of `SessionRunHook` subclass instances. Used for callbacks
            inside the prediction call.

    Yields:
        Evaluated values of `predictions` tensors.

    Raises:
        ValueError: Could not find a trained model in model_dir.
        ValueError: if batch length of predictions are not same.
        ValueError: If there is a conflict between `predict_keys` and
            `predictions`. For example if `predict_keys` is not `None` but
            `EstimatorSpec.predictions` is not a `dict`.
    """
    hooks = list(hooks or [])
    # Check that model has been trained.
    checkpoint_path = saver.latest_checkpoint(self._model_dir)
    if not checkpoint_path:
        raise ValueError('Could not find trained model in model_dir: {}.'.format(
            self._model_dir))
    with ops.Graph().as_default() as g:
        random_seed.set_random_seed(self._config.tf_random_seed)
        training.create_global_step(g)
        features = self._get_features_from_input_fn(input_fn)
        # NOTE(review): the model is built with `ModeKeys.FIT` inside
        # `predict`; a sibling implementation in this file uses a dedicated
        # predict mode here. Confirm FIT is really intended for inference.
        estimator_spec = self._call_model_fn(features, None,
                                             model_fn_lib.ModeKeys.FIT)
        predictions = self._extract_keys(estimator_spec.predictions, predict_keys)
        with training.MonitoredSession(
                session_creator=training.ChiefSessionCreator(
                    checkpoint_filename_with_path=checkpoint_path,
                    scaffold=estimator_spec.scaffold,
                    config=config_pb2.ConfigProto(allow_soft_placement=True)),
                hooks=hooks) as mon_sess:
            while not mon_sess.should_stop():
                preds_evaluated = mon_sess.run(predictions)
                if not isinstance(predictions, dict):
                    # Tensor predictions: yield per batch row.
                    for pred in preds_evaluated:
                        yield pred
                else:
                    # Dict predictions: one dict per example, sliced at the
                    # same batch index across every prediction key.
                    for i in range(self._extract_batch_length(preds_evaluated)):
                        yield {
                            key: value[i]
                            for key, value in six.iteritems(preds_evaluated)
                        }
def _train_model(self, input_fn, hooks):
    """Builds the train graph (FIT mode) and runs it under a monitored session.

    Returns the last evaluated loss value, or `None` if the session stopped
    before a single train step ran.
    """
    all_hooks = []
    with ops.Graph().as_default() as g, g.device(self._device_fn):
        random_seed.set_random_seed(self._config.tf_random_seed)
        global_step_tensor = training.create_global_step(g)
        # Input pipeline is pinned to the CPU.
        with ops.device('/cpu:0'):
            features, labels = input_fn()
        estimator_spec = self._call_model_fn(features, labels,
                                             model_fn_lib.ModeKeys.FIT)
        ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
        # Always guard against NaN loss and log loss/step every 100 iterations.
        all_hooks.extend([
            training.NanTensorHook(estimator_spec.loss),
            training.LoggingTensorHook(
                {
                    'loss': estimator_spec.loss,
                    'step': global_step_tensor
                },
                every_n_iter=100)
        ])
        all_hooks.extend(hooks)
        all_hooks.extend(estimator_spec.training_hooks)

        # Fall back to a fresh Scaffold when the spec did not provide one.
        scaffold = estimator_spec.scaffold or training.Scaffold()
        # Register a default sharded Saver unless one already exists.
        if not (scaffold.saver or ops.get_collection(ops.GraphKeys.SAVERS)):
            ops.add_to_collection(
                ops.GraphKeys.SAVERS,
                training.Saver(
                    sharded=True,
                    max_to_keep=self._config.keep_checkpoint_max,
                    defer_build=True))

        chief_hooks = []
        if (self._config.save_checkpoints_secs or
                self._config.save_checkpoints_steps):
            # Only add a CheckpointSaverHook if none was supplied already.
            saver_hook_exists = any([
                isinstance(h, training.CheckpointSaverHook)
                for h in (all_hooks + chief_hooks +
                          estimator_spec.training_chief_hooks)
            ])
            if not saver_hook_exists:
                chief_hooks = [
                    training.CheckpointSaverHook(
                        self._model_dir,
                        save_secs=self._config.save_checkpoints_secs,
                        save_steps=self._config.save_checkpoints_steps,
                        scaffold=scaffold)
                ]
        with training.MonitoredTrainingSession(
                master=self._config.master,
                is_chief=self._config.is_chief,
                checkpoint_dir=self._model_dir,
                scaffold=scaffold,
                hooks=all_hooks,
                chief_only_hooks=chief_hooks + estimator_spec.training_chief_hooks,
                save_checkpoint_secs=0,  # Saving is handled by a hook.
                save_summaries_steps=self._config.save_summary_steps,
                config=config_pb2.ConfigProto(allow_soft_placement=True)) as mon_sess:
            loss = None
            while not mon_sess.should_stop():
                _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
        return loss
def _train_model(self, input_fn, hooks):
    """Builds the train graph (TRAIN mode) and runs it under a monitored session.

    Args:
        input_fn: Callable returning a `(features, labels)` pair; invoked with
            the input pipeline pinned to CPU.
        hooks: List of `SessionRunHook` instances attached to the session.

    Returns:
        The last evaluated loss value, or `None` if the session stopped before
        a single train step ran.
    """
    all_hooks = []
    with ops.Graph().as_default() as g, g.device(self._device_fn):
        random_seed.set_random_seed(self._config.tf_random_seed)
        global_step_tensor = training.create_global_step(g)
        # Input pipeline is pinned to the CPU.
        with ops.device('/cpu:0'):
            features, labels = input_fn()
        estimator_spec = self._call_model_fn(features, labels,
                                             model_fn_lib.ModeKeys.TRAIN)
        ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
        # Always guard against NaN loss and log loss/step every 100 iterations.
        all_hooks.extend([
            training.NanTensorHook(estimator_spec.loss),
            training.LoggingTensorHook(
                {
                    'loss': estimator_spec.loss,
                    'step': global_step_tensor
                },
                every_n_iter=100)
        ])
        all_hooks.extend(hooks)
        all_hooks.extend(estimator_spec.training_hooks)

        # Fall back to a fresh Scaffold when the spec did not provide one.
        # This matches the sibling _train_model implementation and avoids an
        # AttributeError on `.saver` when `estimator_spec.scaffold` is None.
        scaffold = estimator_spec.scaffold or training.Scaffold()
        # Register a default sharded Saver unless one already exists.
        if not (scaffold.saver or ops.get_collection(ops.GraphKeys.SAVERS)):
            ops.add_to_collection(
                ops.GraphKeys.SAVERS,
                training.Saver(
                    sharded=True,
                    max_to_keep=self._config.keep_checkpoint_max,
                    defer_build=True))

        chief_hooks = []
        if (self._config.save_checkpoints_secs or
                self._config.save_checkpoints_steps):
            # Only add a CheckpointSaverHook if none was supplied already.
            saver_hook_exists = any([
                isinstance(h, training.CheckpointSaverHook)
                for h in (all_hooks + chief_hooks +
                          estimator_spec.training_chief_hooks)
            ])
            if not saver_hook_exists:
                chief_hooks = [
                    training.CheckpointSaverHook(
                        self._model_dir,
                        save_secs=self._config.save_checkpoints_secs,
                        save_steps=self._config.save_checkpoints_steps,
                        scaffold=scaffold)
                ]
        with training.MonitoredTrainingSession(
                master=self._config.master,
                is_chief=self._config.is_chief,
                checkpoint_dir=self._model_dir,
                scaffold=scaffold,
                hooks=all_hooks,
                chief_only_hooks=chief_hooks + estimator_spec.training_chief_hooks,
                save_checkpoint_secs=0,  # Saving is handled by a hook.
                save_summaries_steps=self._config.save_summary_steps,
                config=config_pb2.ConfigProto(allow_soft_placement=True)) as mon_sess:
            loss = None
            while not mon_sess.should_stop():
                _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
        return loss
def _evaluate_model(self, input_fn, hooks=None, checkpoint_path=None, name=''): """Evaluates the model using the training.evaluation library.""" # Check that model has been trained (if nothing has been set explicitly). if not checkpoint_path: latest_path = saver.latest_checkpoint(self._model_dir) if not latest_path: raise ValueError( 'Could not find trained model in model_dir: {}.'.format( self._model_dir)) checkpoint_path = latest_path # Setup output directory. eval_dir = os.path.join(self._model_dir, 'eval' if not name else 'eval_' + name) with ops.Graph().as_default() as g: random_seed.set_random_seed(self._config.tf_random_seed) global_step_tensor = training.create_global_step(g) features, labels = input_fn() estimator_spec = self._call_model_fn(features, labels, model_fn_lib.ModeKeys.EVAL) if model_fn_lib.MetricKeys.LOSS in estimator_spec.eval_metric_ops: raise ValueError( 'Metric with name "%s" is not allowed, because Estimator ' % (model_fn_lib.MetricKeys.LOSS) + 'already defines a default metric with the same name.') estimator_spec.eval_metric_ops[ model_fn_lib.MetricKeys.LOSS] = metrics_lib.mean( estimator_spec.loss) update_op, eval_dict = _extract_metric_update_ops( estimator_spec.eval_metric_ops) if ops.GraphKeys.GLOBAL_STEP in eval_dict: raise ValueError( 'Metric with name `global_step` is not allowed, because Estimator ' 'already defines a default metric with the same name.') eval_dict[ops.GraphKeys.GLOBAL_STEP] = global_step_tensor eval_results = evaluation._evaluate_once( # pylint: disable=protected-access checkpoint_path=checkpoint_path, master=self._config.evaluation_master, scaffold=estimator_spec.scaffold, eval_ops=update_op, final_ops=eval_dict, hooks=hooks, config=self._session_config) _write_dict_to_summary( output_dir=eval_dir, dictionary=eval_results, current_global_step=eval_results[ops.GraphKeys.GLOBAL_STEP]) return eval_results
def _evaluate_model(self, input_fn, hooks=None, checkpoint_path=None, name=''):
    """Evaluates the model at `checkpoint_path` (or the latest checkpoint).

    Results are written as summaries under the `eval`/`eval_<name>` directory
    and returned as a dict of metric name to value.
    """
    # Check that model has been trained (if nothing has been set explicitly).
    if not checkpoint_path:
        latest_path = saver.latest_checkpoint(self._model_dir)
        if not latest_path:
            error_message = "Could not find trained model at {}.".format(
                self._model_dir)
            raise EstimatorNotTrainedError(error_message)
        checkpoint_path = latest_path
    # Setup output directory.
    eval_dir = os.path.join(self._model_dir,
                            'eval' if not name else 'eval_' + name)
    with ops.Graph().as_default() as g:
        random_seed.set_random_seed(self._config.tf_random_seed)
        global_step = training.create_global_step(g)
        features, labels = input_fn()
        estimator_spec = self._call_model_fn(features, labels, ModeKeys.EVAL)
        # The loss metric name is reserved: a default streaming-mean loss
        # metric is injected below, so user metrics must not collide with it.
        if model_fn_lib.MetricKeys.LOSS in estimator_spec.eval_metric_ops:
            raise ValueError(
                "Metric with name `{}` is not allowed, because Estimator "
                "already defines a default metric "
                "with the same name.".format(model_fn_lib.MetricKeys.LOSS))
        estimator_spec.eval_metric_ops[
            model_fn_lib.MetricKeys.LOSS] = metrics_lib.streaming_mean(
                estimator_spec.loss)
        update_op, eval_dict = self._extract_metric_update_ops(
            estimator_spec.eval_metric_ops)
        # `global_step` is likewise reserved so it can be reported alongside.
        if ops.GraphKeys.GLOBAL_STEP in eval_dict:
            raise ValueError(
                "Metric with name `global_step` is not allowed, because "
                "Estimator already defines a default metric with the same name.")
        eval_dict[ops.GraphKeys.GLOBAL_STEP] = global_step
        eval_results = evaluation._evaluate_once(
            checkpoint_path=checkpoint_path,
            master=self._config.evaluation_master,
            scaffold=estimator_spec.scaffold,
            eval_ops=update_op,
            final_ops=eval_dict,
            hooks=hooks,
            config=self._session_config)
        # Persist the evaluation results as summaries for TensorBoard.
        self._write_dict_to_summary(
            output_dir=eval_dir,
            dictionary=eval_results,
            current_global_step=eval_results[ops.GraphKeys.GLOBAL_STEP])
        return eval_results
def test_build_learning_rate_summaries(self):
    """Only the `learning_rate` summary is tracked when requested alone."""
    training.create_global_step()
    features = {'x': tf.placeholder(tf.float32, [2, 89])}
    labels = tf.constant([[1], [1]])
    model = BaseModel(plx.Modes.TRAIN,
                      graph_fn=self.get_dummy_graph_fn(),
                      loss_config=LossConfig(module='log_loss'),
                      optimizer_config=OptimizerConfig(
                          module='adadelta', decay_type='exponential_decay'),
                      model_type=BaseModel.Types.CLASSIFIER,
                      eval_metrics_config=[],
                      summaries=['learning_rate'],
                      name='test')
    model(features, labels, None, None)  # Only var are created

    tracked_names = list(
        get_tracked(collection=tf.GraphKeys.SUMMARIES_BY_NAMES).keys())
    assert len(tracked_names) == 1
    assert tracked_names[0] == 'learning_rate'
def _evaluate_model(self, input_fn, hooks=None, checkpoint_path=None, name=''): """Evaluates the model using the training.evaluation library.""" # Check that model has been trained (if nothing has been set explicitly). if not checkpoint_path: latest_path = saver.latest_checkpoint(self._model_dir) if not latest_path: raise ValueError('Could not find trained model in model_dir: {}.'. format(self._model_dir)) checkpoint_path = latest_path # Setup output directory. eval_dir = os.path.join(self._model_dir, 'eval' if not name else 'eval_' + name) with ops.Graph().as_default() as g: random_seed.set_random_seed(self._config.tf_random_seed) global_step_tensor = training.create_global_step(g) features, labels = input_fn() estimator_spec = self._call_model_fn( features, labels, model_fn_lib.ModeKeys.EVAL) if model_fn_lib.MetricKeys.LOSS in estimator_spec.eval_metric_ops: raise ValueError( 'Metric with name "%s" is not allowed, because Estimator ' % ( model_fn_lib.MetricKeys.LOSS) + 'already defines a default metric with the same name.') estimator_spec.eval_metric_ops[ model_fn_lib.MetricKeys.LOSS] = metrics_lib.mean(estimator_spec.loss) update_op, eval_dict = _extract_metric_update_ops( estimator_spec.eval_metric_ops) if ops.GraphKeys.GLOBAL_STEP in eval_dict: raise ValueError( 'Metric with name `global_step` is not allowed, because Estimator ' 'already defines a default metric with the same name.') eval_dict[ops.GraphKeys.GLOBAL_STEP] = global_step_tensor eval_results = evaluation._evaluate_once( # pylint: disable=protected-access checkpoint_path=checkpoint_path, master=self._config.evaluation_master, scaffold=estimator_spec.scaffold, eval_ops=update_op, final_ops=eval_dict, hooks=hooks, config=self._session_config) _write_dict_to_summary( output_dir=eval_dir, dictionary=eval_results, current_global_step=eval_results[ops.GraphKeys.GLOBAL_STEP]) return eval_results
def _create_global_step(self, graph):
    """Creates the global step tensor in graph.

    The global step tensor must be an integer type with name 'global_step'
    and be added to the collection ${tf.GraphKeys.GLOBAL_STEP}.

    Args:
        graph: The graph in which to create the global step tensor.

    Returns:
        The global step `Tensor`.
    """
    global_step = training.create_global_step(graph)
    return global_step
def test_build_all_summaries(self):
    """`summaries='all'` must create lr, activation, gradient and loss summaries."""
    training.create_global_step()
    model_config = ModelConfig(
        loss_config=LossConfig(name='log_loss'),
        optimizer_config=OptimizerConfig(name='Adadelta',
                                         decay_type='exponential_decay'))
    features = {'source_ids': tf.placeholder(tf.float32, [2, 89])}
    labels = tf.constant([[1], [1]])
    model = BaseModel(plx.ModeKeys.TRAIN,
                      graph_fn=self.get_dummy_graph_fn(),
                      config=model_config,
                      model_type=BaseModel.Types.CLASSIFIER,
                      summaries='all',
                      name='test',
                      params=None)
    model(features, labels, None, None)  # Only var are created

    # Tally tracked summaries per family; first matching marker wins, which
    # preserves the precedence of the original elif chain.
    seen = {'learning_rate': 0, 'Activation': 0, 'Loss': 0, 'Gradient': 0}
    for summary_name in get_tracked(
            collection=tf.GraphKeys.SUMMARIES_BY_NAMES).keys():
        for marker in ('learning_rate', 'Activation', 'Loss', 'Gradient'):
            if marker in summary_name:
                seen[marker] += 1
                break

    assert seen['learning_rate'] > 0
    assert seen['Activation'] > 0
    assert seen['Gradient'] > 0
    assert seen['Loss'] > 0
def _create_global_step(self, graph):
    """Builds the global-step tensor for `graph`.

    Delegates to `training.create_global_step`, which produces an
    integer-typed tensor named 'global_step' registered in the
    ${tf.GraphKeys.GLOBAL_STEP} collection.

    Args:
        graph: Target graph for the new global step tensor.

    Returns:
        The global step `Tensor`.
    """
    return training.create_global_step(graph)
def test_do_not_stop_if_checkpoint_is_not_there(self):
    """With no checkpoint on disk the hook sleeps instead of stopping."""
    with ops.Graph().as_default():
        global_step = training.create_global_step()
        set_step_to_ten = global_step.assign(10)
        noop = control_flow_ops.no_op()
        stop_hook = hooks_lib._StopAtCheckpointStepHook(
            model_dir=tempfile.mkdtemp(), last_step=10)
        with training.SingularMonitoredSession(hooks=[stop_hook]) as mon_sess:
            mon_sess.raw_session().run(set_step_to_ten)
            with test.mock.patch.object(time, 'sleep') as mock_sleep:
                mon_sess.run(noop)
                self.assertTrue(mock_sleep.called)
                self.assertFalse(mon_sess.should_stop())
def test_build_learning_rate_summaries(self):
    """Exactly one summary -- `learning_rate` -- is created when requested alone."""
    training.create_global_step()
    features = {'x': tf.placeholder(tf.float32, [2, 89])}
    labels = tf.constant([[1], [1]])
    model = BaseModel(
        plx.Modes.TRAIN,
        graph_fn=self.get_dummy_graph_fn(),
        loss=LogLossConfig(),
        optimizer=AdadeltaConfig(decay_type='exponential_decay'),
        model_type=BaseModel.Types.CLASSIFIER,
        metrics=[],
        summaries=['learning_rate'],
        name='test')
    model(features, labels, None, None)  # Only var are created

    tracked = get_tracked(collection=tf.GraphKeys.SUMMARIES_BY_NAMES)
    names = list(tracked.keys())
    assert len(names) == 1
    assert names[0] == 'learning_rate'
def test_do_not_stop_if_checkpoint_is_not_there(self):
    """Missing checkpoint: the hook waits (sleeps) rather than stop the session."""
    with ops.Graph().as_default():
        step_tensor = training.create_global_step()
        bump_to_ten = step_tensor.assign(10)
        dummy_op = control_flow_ops.no_op()
        checkpoint_hook = hooks_lib._StopAtCheckpointStepHook(
            model_dir=tempfile.mkdtemp(), last_step=10)
        with training.SingularMonitoredSession(
                hooks=[checkpoint_hook]) as session:
            session.raw_session().run(bump_to_ten)
            with test.mock.patch.object(time, 'sleep') as sleep_mock:
                session.run(dummy_op)
                self.assertTrue(sleep_mock.called)
                self.assertFalse(session.should_stop())
def _create_and_assert_global_step(self, graph):
    """Creates and asserts properties of the global step.

    Args:
        graph: The graph in which to create the global step tensor.

    Returns:
        The global step `tf.Tensor`.
    """
    # Scope the global step under the configured tag so its variable name is
    # namespaced per run.
    with variable_scope.variable_scope(self.params["args"].tag):
        global_step = training.create_global_step(graph)
        # A step counter only makes sense with an integer dtype.
        assert global_step.dtype.is_integer
        return global_step
def _evaluate_model(self, input_fn, hooks=None, checkpoint_path=None, name=''):
    """Evaluates the model at `checkpoint_path` (or the latest checkpoint).

    Results are written as summaries under the `eval`/`eval_<name>` directory
    and returned as a dict of metric name to value.
    """
    # Check that model has been trained (if nothing has been set explicitly).
    if not checkpoint_path:
        latest_path = saver.latest_checkpoint(self._model_dir)
        if not latest_path:
            error_message = "Could not find trained model at {}.".format(
                self._model_dir)
            raise EstimatorNotTrainedError(error_message)
        checkpoint_path = latest_path
    # Setup output directory.
    eval_dir = os.path.join(self._model_dir,
                            'eval' if not name else 'eval_' + name)
    with ops.Graph().as_default() as g:
        random_seed.set_random_seed(self._config.tf_random_seed)
        global_step = training.create_global_step(g)
        features, labels = input_fn()
        estimator_spec = self._call_model_fn(features, labels, Modes.EVAL)
        # The loss metric name is reserved: a default streaming-mean loss
        # metric is injected below, so user metrics must not collide with it.
        if MetricKeys.LOSS in estimator_spec.eval_metric_ops:
            raise ValueError(
                "Metric with name `{}` is not allowed, because Estimator "
                "already defines a default metric "
                "with the same name.".format(MetricKeys.LOSS))
        estimator_spec.eval_metric_ops[
            MetricKeys.LOSS] = metrics_lib.streaming_mean(estimator_spec.loss)
        update_op, eval_dict = self._extract_metric_update_ops(
            estimator_spec.eval_metric_ops)
        # `global_step` is likewise reserved so it can be reported alongside.
        if ops.GraphKeys.GLOBAL_STEP in eval_dict:
            raise ValueError(
                "Metric with name `global_step` is not allowed, because "
                "Estimator already defines a default metric with the same name.")
        eval_dict[ops.GraphKeys.GLOBAL_STEP] = global_step
        eval_results = evaluation._evaluate_once(
            checkpoint_path=checkpoint_path,
            master=self._config.evaluation_master,
            scaffold=estimator_spec.scaffold,
            eval_ops=update_op,
            final_ops=eval_dict,
            hooks=hooks,
            config=self._session_config)
        # Persist the evaluation results as summaries for TensorBoard.
        self._write_dict_to_summary(
            output_dir=eval_dir,
            dictionary=eval_results,
            current_global_step=eval_results[ops.GraphKeys.GLOBAL_STEP])
        return eval_results
def export_savedmodel(self, export_dir_base, serving_input_receiver_fn,
                      assets_extra=None, as_text=False, checkpoint_path=None):
    """Exports inference graph as a SavedModel into given dir.

    This method builds a new graph by first calling the
    serving_input_receiver_fn to obtain feature `Tensor`s, and then calling
    this `Estimator`'s model_fn to generate the model graph based on those
    features. It restores the given checkpoint (or, lacking that, the most
    recent checkpoint) into this graph in a fresh session. Finally it creates
    a timestamped export directory below the given export_dir_base, and
    writes a `SavedModel` into it containing a single `MetaGraphDef` saved
    from this session.

    The exported `MetaGraphDef` will provide one `SignatureDef` for each
    element of the export_outputs dict returned from the model_fn, named
    using the same keys. One of these keys is always
    signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, indicating which
    signature will be served when a serving request does not specify one.
    For each signature, the outputs are provided by the corresponding
    `ExportOutput`s, and the inputs are always the input receivers provided
    by the serving_input_receiver_fn.

    Extra assets may be written into the SavedModel via the extra_assets
    argument. This should be a dict, where each key gives a destination path
    (including the filename) relative to the assets.extra directory. The
    corresponding value gives the full path of the source file to be copied.
    For example, the simple case of copying a single file without renaming it
    is specified as `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.

    Args:
        export_dir_base: A string containing a directory in which to create
            timestamped subdirectories containing exported SavedModels.
        serving_input_receiver_fn: A function that takes no argument and
            returns a `ServingInputReceiver`.
        assets_extra: A dict specifying how to populate the assets.extra
            directory within the exported SavedModel, or `None` if no extra
            assets are needed.
        as_text: whether to write the SavedModel proto in text format.
        checkpoint_path: The checkpoint path to export. If `None` (the
            default), the most recent checkpoint found within the model
            directory is chosen.

    Returns:
        The string path to the exported directory.

    Raises:
        ValueError: if no serving_input_receiver_fn is provided, no
            export_outputs are provided, or no checkpoint can be found.
    """
    if serving_input_receiver_fn is None:
        raise ValueError('serving_input_receiver_fn must be defined.')
    with ops.Graph().as_default() as g:
        training.create_global_step(g)
        random_seed.set_random_seed(self._config.tf_random_seed)
        serving_input_receiver = serving_input_receiver_fn()
        # Call the model_fn and collect the export_outputs.
        estimator_spec = self._call_model_fn(
            features=serving_input_receiver.features,
            labels=None,
            mode=model_fn_lib.ModeKeys.PREDICT)
        # Build the SignatureDefs from receivers and all outputs
        signature_def_map = build_all_signature_defs(
            serving_input_receiver.receiver_tensors,
            estimator_spec.export_outputs)
        if not checkpoint_path:
            # Locate the latest checkpoint
            checkpoint_path = saver.latest_checkpoint(self._model_dir)
        if not checkpoint_path:
            raise ValueError("Couldn't find trained model at %s."
                             % self._model_dir)
        export_dir = get_timestamped_export_dir(export_dir_base)

        # TODO(soergel): Consider whether MonitoredSession makes sense here
        with tf_session.Session() as session:
            # Restore variables from the checkpoint before exporting.
            saver_for_restore = estimator_spec.scaffold.saver or saver.Saver(
                sharded=True)
            saver_for_restore.restore(session, checkpoint_path)

            # TODO(b/36111876): replace legacy_init_op with main_op mechanism
            # pylint: disable=protected-access
            local_init_op = (
                estimator_spec.scaffold.local_init_op or
                monitored_session.Scaffold._default_local_init_op())
            # pylint: enable=protected-access

            # Perform the export
            builder = saved_model_builder.SavedModelBuilder(export_dir)
            builder.add_meta_graph_and_variables(
                session, [tag_constants.SERVING],
                signature_def_map=signature_def_map,
                assets_collection=ops.get_collection(
                    ops.GraphKeys.ASSET_FILEPATHS),
                legacy_init_op=local_init_op)
            builder.save(as_text)

        # Add the extra assets
        if assets_extra:
            assets_extra_path = os.path.join(compat.as_bytes(export_dir),
                                             compat.as_bytes('assets.extra'))
            for dest_relative, source in assets_extra.items():
                dest_absolute = os.path.join(compat.as_bytes(assets_extra_path),
                                             compat.as_bytes(dest_relative))
                dest_path = os.path.dirname(dest_absolute)
                gfile.MakeDirs(dest_path)
                gfile.Copy(source, dest_absolute)
        return export_dir
def _testQuantize_FCWithBatchNorm(self, activation, activation_op_name,
                                  with_bypass, delay, use_ema):
    """Tests quantization: inputs -> FC with batch norm -> Activation.

    Args:
        activation: Callable that returns an Operation, a factory method for
            the Activation.
        activation_op_name: String, name of the Activation operation.
        with_bypass: Bool, when true there is an extra connection added from
            inputs to just before Activation.
        delay: Int (optional), delay in number of steps until quantization
            starts.
        use_ema: Bool, when true uses EMA quantization for BN folded weights.
    """
    graph = ops.Graph()
    with graph.as_default():
        training.create_global_step(graph)
        batch_size, depth = 5, 256
        inputs = array_ops.zeros((batch_size, depth))
        out_depth = 256 if with_bypass else 128
        scope = 'test/test2' if with_bypass else 'test'
        node = fully_connected(inputs, out_depth,
                               weights_initializer=self._WeightInit(0.03),
                               activation_fn=None,
                               normalizer_fn=batch_norm,
                               normalizer_params=_DEFAULT_BATCH_NORM_PARAMS,
                               scope=scope)

        # Manually fold the batch norm.
        weights = graph.get_operation_by_name(scope +
                                              '/weights/read').outputs[0]
        bn_mult = (graph.get_operation_by_name(scope +
                                               '/BatchNorm/batchnorm/mul')
                   .outputs[0])
        mul_fold = math_ops.multiply(weights, bn_mult, name=scope + '/mul_fold')
        fc_fold = math_ops.matmul(inputs, mul_fold, name=scope + '/MatMul_Fold')
        bn_bias = (graph.get_operation_by_name(scope +
                                               '/BatchNorm/batchnorm/sub')
                   .outputs[0])
        add_fold = math_ops.add(fc_fold, bn_bias, name=scope + '/add_fold')

        # Manually add a bypass (optionaly) and an activation.
        if with_bypass:
            node = math_ops.add(inputs, add_fold, name='test/Add')
        else:
            node = add_fold
        node = activation(node, name='test/' + activation_op_name)

        update_barrier = control_flow_ops.no_op(name='update_barrier')
        with ops.control_dependencies([update_barrier]):
            array_ops.identity(node, name='control_dependency')

        quantize.Quantize(
            graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema)

        quantization_node_name = 'FakeQuantWithMinMaxVars'
        # Weights quantization node: min/max come from EMA variables or from
        # Minimum/Maximum ops depending on `use_ema`.
        weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' +
                                                    quantization_node_name)
        self.assertEqual(weights_quant.type, quantization_node_name)
        expected_inputs = [
            scope + '/weights_quant/' + ('min/read' if use_ema else 'Minimum'),
            scope + '/weights_quant/' + ('max/read' if use_ema else 'Maximum'),
            scope + '/mul_fold'
        ]
        self._AssertInputOpsAre(weights_quant, expected_inputs)
        output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1'
                                  if delay and use_ema else '/MatMul_Fold')
        self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])

        if with_bypass:
            # The bypass add gets its own quantization node.
            conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' +
                                                     quantization_node_name)
            self.assertEqual(conv_quant.type, quantization_node_name)
            expected_inputs = [
                scope + '/conv_quant/min/read', scope + '/conv_quant/max/read',
                scope + '/add_fold'
            ]
            self._AssertInputOpsAre(conv_quant, expected_inputs)
            output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1'
                              if delay else 'test/Add')
            self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name])

        # Activation quantization node feeds the control-dependency identity.
        act_quant = graph.get_operation_by_name('test/act_quant/' +
                                                quantization_node_name)
        self.assertEqual(act_quant.type, quantization_node_name)
        expected_inputs = [
            'test/act_quant/min/read', 'test/act_quant/max/read',
            'test/' + activation_op_name
        ]
        self._AssertInputOpsAre(act_quant, expected_inputs)
        output_op_name = ('test/act_quant/delayed_quant/Switch_1'
                          if delay else 'control_dependency')
        self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])
def _TestQuantize_Conv2dWithoutBatchNorm(self, activation, activation_op_name,
                                         with_bypass, delay):
    """Tests quantization: inputs -> Conv2d no batch norm -> Activation.

    Args:
        activation: Callable that returns an Operation, a factory method for
            the Activation.
        activation_op_name: String, name of the Activation operation.
        with_bypass: Bool, when true there is an extra connection added from
            inputs to just before Activation.
        delay: Int (optional), delay in number of steps until quantization
            starts.
    """
    graph = ops.Graph()
    with graph.as_default():
        training.create_global_step(graph)
        batch_size, height, width, depth = 5, 128, 128, 3
        inputs = array_ops.zeros((batch_size, height, width, depth))
        stride = 1 if with_bypass else 2
        out_depth = 3 if with_bypass else 32
        # Without a bypass the activation is applied inside conv2d; with a
        # bypass it is applied manually after the Add below.
        activation_fn = None if with_bypass else activation
        scope = 'test/test2' if with_bypass else 'test'
        node = conv2d(inputs, out_depth, [5, 5], stride=stride, padding='SAME',
                      weights_initializer=self._WeightInit(0.09),
                      activation_fn=activation_fn, scope=scope)
        if with_bypass:
            node = math_ops.add(inputs, node, name='test/Add')
            node = activation(node, name='test/' + activation_op_name)
        update_barrier = control_flow_ops.no_op(name='update_barrier')
        with ops.control_dependencies([update_barrier]):
            array_ops.identity(node, name='control_dependency')

        quantize.Quantize(graph, quant_delay=delay)

        quantization_node_name = 'FakeQuantWithMinMaxVars'
        # Weights quantization node feeds the convolution.
        weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' +
                                                    quantization_node_name)
        self.assertEqual(weights_quant.type, quantization_node_name)
        expected_inputs = [
            scope + '/weights_quant/Minimum', scope + '/weights_quant/Maximum',
            scope + '/weights/read'
        ]
        self._AssertInputOpsAre(weights_quant, expected_inputs)
        output_op_name = scope + '/convolution'
        self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])

        if with_bypass:
            # The bypass add gets its own quantization node, fed by BiasAdd.
            conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' +
                                                     quantization_node_name)
            self.assertEqual(conv_quant.type, quantization_node_name)
            expected_inputs = [
                scope + '/conv_quant/min/read', scope + '/conv_quant/max/read',
                scope + '/BiasAdd'
            ]
            self._AssertInputOpsAre(conv_quant, expected_inputs)
            output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1'
                              if delay else 'test/Add')
            self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name])

        # Activation quantization node feeds the control-dependency identity.
        act_quant = graph.get_operation_by_name('test/act_quant/' +
                                                quantization_node_name)
        self.assertEqual(act_quant.type, quantization_node_name)
        expected_inputs = [
            'test/act_quant/min/read', 'test/act_quant/max/read',
            'test/' + activation_op_name
        ]
        self._AssertInputOpsAre(act_quant, expected_inputs)
        output_op_name = ('test/act_quant/delayed_quant/Switch_1'
                          if delay else 'control_dependency')
        self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])
def _testQuantize_DepthwiseConv2dWithBatchNorm(
    self, activation, activation_op_name, with_bypass, delay,
    fused_batch_norm, use_ema):
    """Tests quantization: inputs -> DWConv2d with batch norm -> Activation.

    Builds the depthwise-conv + batch-norm graph, folds the batch norms, runs
    `quantize.Quantize`, then asserts the expected FakeQuant ops and their
    wiring.

    Args:
      activation: Callable that returns an Operation, a factory method for the
        Activation.
      activation_op_name: String, name of the Activation operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Activation.
      delay: Int (optional), delay in number of steps until quantization
        starts.
      fused_batch_norm: Bool, when true use FusedBatchNorm.
      use_ema: Bool, when true uses EMA quantization for BN folded weights.
    """
    graph = ops.Graph()
    with graph.as_default():
      # NOTE(review): global step presumably required by the quant_delay
      # machinery; confirm.
      training.create_global_step(graph)
      batch_size, height, width, depth = 5, 128, 128, 3
      inputs = array_ops.zeros((batch_size, height, width, depth))
      # Stride 1 in the bypass case keeps the output shape addable to inputs.
      stride = 1 if with_bypass else 2
      scope = 'test/test2' if with_bypass else 'test'
      # num_outputs=None makes separable_conv2d depthwise-only.
      node = separable_conv2d(
          inputs,
          None, [5, 5],
          stride=stride,
          depth_multiplier=1.0,
          padding='SAME',
          weights_initializer=self._WeightInit(0.09),
          activation_fn=None,
          normalizer_fn=batch_norm,
          normalizer_params=self._BatchNormParams(fused_batch_norm),
          scope=scope)

      # Manually add a bypass (optionally) and an activation.
      if with_bypass:
        node = math_ops.add(inputs, node, name='test/Add')
      node = activation(node, name='test/' + activation_op_name)
      # Terminal identity behind a control dependency gives the final quant
      # node a well-known consumer to assert against.
      update_barrier = control_flow_ops.no_op(name='update_barrier')
      with ops.control_dependencies([update_barrier]):
        array_ops.identity(node, name='control_dependency')

      # Fold batch norms first; Quantize then operates on the folded graph.
      fold_batch_norms.FoldBatchNorms(graph)

      quantize.Quantize(
          graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema)

    quantization_node_name = 'FakeQuantWithMinMaxVars'
    # Folded weights (mul_fold) are quantized before the depthwise conv;
    # input op names differ depending on EMA vs. last-value quantization.
    weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' +
                                                quantization_node_name)
    self.assertEqual(weights_quant.type, quantization_node_name)
    expected_inputs = [
        scope + '/weights_quant/' + ('min/read' if use_ema else 'Minimum'),
        scope + '/weights_quant/' + ('max/read' if use_ema else 'Maximum'),
        scope + '/mul_fold'
    ]
    self._AssertInputOpsAre(weights_quant, expected_inputs)
    output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1'
                              if delay and use_ema else '/depthwise_Fold')
    self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])

    if with_bypass:
      # The folded conv output is also quantized just before the bypass add.
      conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' +
                                               quantization_node_name)
      self.assertEqual(conv_quant.type, quantization_node_name)
      expected_inputs = [
          scope + '/conv_quant/min/read', scope + '/conv_quant/max/read',
          scope + '/add_fold'
      ]
      self._AssertInputOpsAre(conv_quant, expected_inputs)
      output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1'
                        if delay else 'test/Add')
      self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name])

    # The activation output is quantized last.
    act_quant = graph.get_operation_by_name('test/act_quant/' +
                                            quantization_node_name)
    self.assertEqual(act_quant.type, quantization_node_name)
    expected_inputs = [
        'test/act_quant/min/read', 'test/act_quant/max/read',
        'test/' + activation_op_name
    ]
    self._AssertInputOpsAre(act_quant, expected_inputs)
    output_op_name = ('test/act_quant/delayed_quant/Switch_1'
                      if delay else 'control_dependency')
    self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])
def _TestQuantize_Conv2dWithoutBatchNorm(self, activation, activation_op_name,
                                         with_bypass, delay):
    """Tests quantization: inputs -> Conv2d no batch norm -> Activation.

    Builds the conv graph, runs `quantize.Quantize`, then asserts that the
    expected FakeQuant ops were inserted with the expected inputs/outputs.
    NOTE(review): this variant expects AssignMinLast/AssignMaxLast and
    '/Conv2D' op names, unlike the Minimum/Maximum variant of the same test —
    presumably from a different library revision; confirm which applies.

    Args:
      activation: Callable that returns an Operation, a factory method for the
        Activation.
      activation_op_name: String, name of the Activation operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Activation.
      delay: Int (optional), delay in number of steps until quantization
        starts.
    """
    graph = ops.Graph()
    with graph.as_default():
      # NOTE(review): global step presumably required by the quant_delay
      # machinery; confirm.
      training.create_global_step(graph)

      batch_size, height, width, depth = 5, 128, 128, 3
      inputs = array_ops.zeros((batch_size, height, width, depth))
      # The bypass case uses stride 1 / depth 3 so the conv output shape
      # matches `inputs` and can be added to it.
      stride = 1 if with_bypass else 2
      out_depth = 3 if with_bypass else 32
      # With a bypass the activation is applied manually after the add.
      activation_fn = None if with_bypass else activation
      scope = 'test/test2' if with_bypass else 'test'
      node = conv2d(inputs, out_depth, [5, 5], stride=stride, padding='SAME',
                    weights_initializer=self._WeightInit(0.09),
                    activation_fn=activation_fn, scope=scope)
      if with_bypass:
        node = math_ops.add(inputs, node, name='test/Add')
        node = activation(node, name='test/' + activation_op_name)
      # Terminal identity behind a control dependency gives the final quant
      # node a well-known consumer to assert against.
      update_barrier = control_flow_ops.no_op(name='update_barrier')
      with ops.control_dependencies([update_barrier]):
        array_ops.identity(node, name='control_dependency')

      quantize.Quantize(graph, quant_delay=delay)

    quantization_node_name = 'FakeQuantWithMinMaxVars'
    # Weights are quantized (last-value min/max) between the variable read
    # and the Conv2D op.
    weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' +
                                                quantization_node_name)
    self.assertEqual(weights_quant.type, quantization_node_name)
    expected_inputs = [
        scope + '/weights_quant/AssignMinLast',
        scope + '/weights_quant/AssignMaxLast', scope + '/weights/read'
    ]
    self._AssertInputOpsAre(weights_quant, expected_inputs)
    output_op_name = scope + '/Conv2D'
    self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])

    if with_bypass:
      # The conv output is quantized with EMA min/max just before the add.
      conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' +
                                               quantization_node_name)
      self.assertEqual(conv_quant.type, quantization_node_name)
      expected_inputs = [
          scope + '/conv_quant/AssignMinEma',
          scope + '/conv_quant/AssignMaxEma', scope + '/BiasAdd'
      ]
      self._AssertInputOpsAre(conv_quant, expected_inputs)
      output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1'
                        if delay else 'test/Add')
      self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name])

    # The activation output is quantized last.
    act_quant = graph.get_operation_by_name('test/act_quant/' +
                                            quantization_node_name)
    self.assertEqual(act_quant.type, quantization_node_name)
    expected_inputs = [
        'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma',
        'test/' + activation_op_name
    ]
    self._AssertInputOpsAre(act_quant, expected_inputs)
    output_op_name = ('test/act_quant/delayed_quant/Switch_1'
                      if delay else 'control_dependency')
    self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])
def export_savedmodel(self, export_dir_base, serving_input_receiver_fn,
                      assets_extra=None,
                      as_text=False,
                      checkpoint_path=None):
    """Exports inference graph as a SavedModel into given dir.

    This method builds a new graph by first calling the
    serving_input_receiver_fn to obtain feature `Tensor`s, and then calling
    this `Estimator`'s model_fn to generate the model graph based on those
    features. It restores the given checkpoint (or, lacking that, the most
    recent checkpoint) into this graph in a fresh session.  Finally it creates
    a timestamped export directory below the given export_dir_base, and writes
    a `SavedModel` into it containing a single `MetaGraphDef` saved from this
    session.

    The exported `MetaGraphDef` will provide one `SignatureDef` for each
    element of the export_outputs dict returned from the model_fn, named using
    the same keys.  One of these keys is always
    signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, indicating which
    signature will be served when a serving request does not specify one.
    For each signature, the outputs are provided by the corresponding
    `ExportOutput`s, and the inputs are always the input receivers provided by
    the serving_input_receiver_fn.

    Extra assets may be written into the SavedModel via the extra_assets
    argument.  This should be a dict, where each key gives a destination path
    (including the filename) relative to the assets.extra directory.  The
    corresponding value gives the full path of the source file to be copied.
    For example, the simple case of copying a single file without renaming it
    is specified as `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.

    Args:
      export_dir_base: A string containing a directory in which to create
        timestamped subdirectories containing exported SavedModels.
      serving_input_receiver_fn: A function that takes no argument and
        returns a `ServingInputReceiver`.
      assets_extra: A dict specifying how to populate the assets.extra
        directory within the exported SavedModel, or `None` if no extra assets
        are needed.
      as_text: whether to write the SavedModel proto in text format.
      checkpoint_path: The checkpoint path to export.  If `None` (the
        default), the most recent checkpoint found within the model directory
        is chosen.

    Returns:
      The string path to the exported directory.

    Raises:
      ValueError: if no serving_input_receiver_fn is provided, no
      export_outputs are provided, or no checkpoint can be found.
    """
    if serving_input_receiver_fn is None:
      raise ValueError('serving_input_receiver_fn must be defined.')

    # Build a fresh graph containing only what is needed for serving.
    with ops.Graph().as_default() as g:
      training.create_global_step(g)
      random_seed.set_random_seed(self._config.tf_random_seed)
      serving_input_receiver = serving_input_receiver_fn()

      # Call the model_fn and collect the export_outputs.
      estimator_spec = self._call_model_fn(
          features=serving_input_receiver.features,
          labels=None,
          mode=model_fn_lib.ModeKeys.PREDICT)

      # Build the SignatureDefs from receivers and all outputs.
      signature_def_map = build_all_signature_defs(
          serving_input_receiver.receiver_tensors,
          estimator_spec.export_outputs)

      if not checkpoint_path:
        # Locate the latest checkpoint.
        checkpoint_path = saver.latest_checkpoint(self._model_dir)
      if not checkpoint_path:
        raise ValueError("Couldn't find trained model at %s." % self._model_dir)

      export_dir = get_timestamped_export_dir(export_dir_base)

      # TODO(soergel): Consider whether MonitoredSession makes sense here
      with tf_session.Session() as session:

        # Prefer the model's own saver if the scaffold provides one.
        saver_for_restore = estimator_spec.scaffold.saver or saver.Saver(
            sharded=True)
        saver_for_restore.restore(session, checkpoint_path)

        # TODO(b/36111876): replace legacy_init_op with main_op mechanism
        # pylint: disable=protected-access
        local_init_op = (
            estimator_spec.scaffold.local_init_op or
            monitored_session.Scaffold._default_local_init_op())
        # pylint: enable=protected-access

        # Perform the export.
        builder = saved_model_builder.SavedModelBuilder(export_dir)
        builder.add_meta_graph_and_variables(
            session, [tag_constants.SERVING],
            signature_def_map=signature_def_map,
            assets_collection=ops.get_collection(
                ops.GraphKeys.ASSET_FILEPATHS),
            legacy_init_op=local_init_op)
        builder.save(as_text)

      # Add the extra assets.  Paths are byte strings to tolerate non-ASCII
      # filenames from the timestamped export dir.
      if assets_extra:
        assets_extra_path = os.path.join(
            compat.as_bytes(export_dir), compat.as_bytes('assets.extra'))
        for dest_relative, source in assets_extra.items():
          dest_absolute = os.path.join(
              compat.as_bytes(assets_extra_path),
              compat.as_bytes(dest_relative))
          dest_path = os.path.dirname(dest_absolute)
          gfile.MakeDirs(dest_path)
          gfile.Copy(source, dest_absolute)

      return export_dir
def _testQuantize_DepthwiseConv2dWithBatchNorm(
    self, activation, activation_op_name, with_bypass, delay,
    fused_batch_norm, use_ema):
    """Tests quantization: inputs -> DWConv2d with batch norm -> Activation.

    Builds the depthwise-conv + batch-norm graph, folds the batch norms, runs
    `quantize.Quantize`, then asserts the expected FakeQuant ops and their
    wiring.  NOTE(review): this variant expects AssignMin/Max{Ema,Last} op
    names, unlike the min/read variant of the same test — presumably from a
    different library revision; confirm which applies.

    Args:
      activation: Callable that returns an Operation, a factory method for the
        Activation.
      activation_op_name: String, name of the Activation operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Activation.
      delay: Int (optional), delay in number of steps until quantization
        starts.
      fused_batch_norm: Bool, when true use FusedBatchNorm.
      use_ema: Bool, when true uses EMA quantization for BN folded weights.
    """
    graph = ops.Graph()
    with graph.as_default():
      # NOTE(review): global step presumably required by the quant_delay
      # machinery; confirm.
      training.create_global_step(graph)
      batch_size, height, width, depth = 5, 128, 128, 3
      inputs = array_ops.zeros((batch_size, height, width, depth))
      # Stride 1 in the bypass case keeps the output shape addable to inputs.
      stride = 1 if with_bypass else 2
      scope = 'test/test2' if with_bypass else 'test'
      # num_outputs=None makes separable_conv2d depthwise-only.
      node = separable_conv2d(
          inputs,
          None, [5, 5],
          stride=stride,
          depth_multiplier=1.0,
          padding='SAME',
          weights_initializer=self._WeightInit(0.09),
          activation_fn=None,
          normalizer_fn=batch_norm,
          normalizer_params=self._BatchNormParams(fused_batch_norm),
          scope=scope)

      # Manually add a bypass (optionally) and an activation.
      if with_bypass:
        node = math_ops.add(inputs, node, name='test/Add')
      node = activation(node, name='test/' + activation_op_name)
      # Terminal identity behind a control dependency gives the final quant
      # node a well-known consumer to assert against.
      update_barrier = control_flow_ops.no_op(name='update_barrier')
      with ops.control_dependencies([update_barrier]):
        array_ops.identity(node, name='control_dependency')

      # Fold batch norms first; Quantize then operates on the folded graph.
      fold_batch_norms.FoldBatchNorms(graph)

      quantize.Quantize(
          graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema)

    quantization_node_name = 'FakeQuantWithMinMaxVars'
    # Folded weights (mul_fold) are quantized before the depthwise conv;
    # input op names differ depending on EMA vs. last-value quantization.
    weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' +
                                                quantization_node_name)
    self.assertEqual(weights_quant.type, quantization_node_name)
    expected_inputs = [
        scope + '/weights_quant/' +
        ('AssignMinEma' if use_ema else 'AssignMinLast'),
        scope + '/weights_quant/' +
        ('AssignMaxEma' if use_ema else 'AssignMaxLast'),
        scope + '/mul_fold'
    ]
    self._AssertInputOpsAre(weights_quant, expected_inputs)
    output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1'
                              if delay and use_ema else '/depthwise_Fold')
    self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])

    if with_bypass:
      # The folded conv output is also quantized just before the bypass add.
      conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' +
                                               quantization_node_name)
      self.assertEqual(conv_quant.type, quantization_node_name)
      expected_inputs = [
          scope + '/conv_quant/AssignMinEma',
          scope + '/conv_quant/AssignMaxEma', scope + '/add_fold'
      ]
      self._AssertInputOpsAre(conv_quant, expected_inputs)
      output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1'
                        if delay else 'test/Add')
      self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name])

    # The activation output is quantized last.
    act_quant = graph.get_operation_by_name('test/act_quant/' +
                                            quantization_node_name)
    self.assertEqual(act_quant.type, quantization_node_name)
    expected_inputs = [
        'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma',
        'test/' + activation_op_name
    ]
    self._AssertInputOpsAre(act_quant, expected_inputs)
    output_op_name = ('test/act_quant/delayed_quant/Switch_1'
                      if delay else 'control_dependency')
    self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])
def _testQuantize_DepthwiseConv2dWithBatchNorm(
    self, activation, activation_op_name, with_bypass, delay, use_ema):
    """Tests quantization: inputs -> DWConv2d with batch norm -> Activation.

    This variant folds the batch norm by hand (rather than via
    fold_batch_norms), reproducing the op names Quantize expects
    ('mul_fold', 'depthwise_Fold', 'add_fold'), then runs `quantize.Quantize`
    and asserts the inserted FakeQuant ops and their wiring.

    Args:
      activation: Callable that returns an Operation, a factory method for the
        Activation.
      activation_op_name: String, name of the Activation operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Activation.
      delay: Int (optional), delay in number of steps until quantization
        starts.
      use_ema: Bool, when true uses EMA quantization for BN folded weights.
    """
    graph = ops.Graph()
    with graph.as_default():
      # NOTE(review): global step presumably required by the quant_delay
      # machinery; confirm.
      training.create_global_step(graph)
      batch_size, height, width, depth = 5, 128, 128, 3
      inputs = array_ops.zeros((batch_size, height, width, depth))
      # Stride 1 in the bypass case keeps the output shape addable to inputs.
      stride = 1 if with_bypass else 2
      scope = 'test/test2' if with_bypass else 'test'
      # num_outputs=None makes separable_conv2d depthwise-only.
      node = separable_conv2d(inputs, None, [5, 5], stride=stride,
                              depth_multiplier=1.0, padding='SAME',
                              weights_initializer=self._WeightInit(0.09),
                              activation_fn=None, normalizer_fn=batch_norm,
                              normalizer_params=_DEFAULT_BATCH_NORM_PARAMS,
                              scope=scope)

      # Manually fold the batch norm: weights * gamma-scale, then conv,
      # then add the BN bias term.
      weights = (graph.get_operation_by_name(scope + '/depthwise_weights/read')
                 .outputs[0])
      bn_mult = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/mul')
                 .outputs[0])
      # Reshape the per-channel BN multiplier to the weights' last two dims
      # so it broadcasts over the depthwise kernel.
      new_shape = [
          weights.get_shape().as_list()[2], weights.get_shape().as_list()[3]
      ]
      bn_mult_reshaped = array_ops.reshape(
          bn_mult, new_shape, name=scope + '/gamma_reshape')
      mul_fold = math_ops.multiply(
          weights, bn_mult_reshaped, name=scope + '/mul_fold')
      stride = [1, stride, stride, 1]
      conv_fold = nn_ops.depthwise_conv2d(
          input=inputs,
          filter=mul_fold,
          padding='SAME',
          strides=stride,
          name=scope + '/depthwise_Fold')
      bn_bias = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/sub')
                 .outputs[0])
      add_fold = math_ops.add(conv_fold, bn_bias, name=scope + '/add_fold')

      # Manually add a bypass (optionally) and an activation.
      if with_bypass:
        node = math_ops.add(inputs, add_fold, name='test/Add')
      else:
        node = add_fold
      node = activation(node, name='test/' + activation_op_name)
      # Terminal identity behind a control dependency gives the final quant
      # node a well-known consumer to assert against.
      update_barrier = control_flow_ops.no_op(name='update_barrier')
      with ops.control_dependencies([update_barrier]):
        array_ops.identity(node, name='control_dependency')

      quantize.Quantize(
          graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema)

    quantization_node_name = 'FakeQuantWithMinMaxVars'
    # Folded weights (mul_fold) are quantized before the depthwise conv;
    # input op names differ depending on EMA vs. min/max quantization.
    weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' +
                                                quantization_node_name)
    self.assertEqual(weights_quant.type, quantization_node_name)
    expected_inputs = [
        scope + '/weights_quant/' + ('min/read' if use_ema else 'Minimum'),
        scope + '/weights_quant/' + ('max/read' if use_ema else 'Maximum'),
        scope + '/mul_fold'
    ]
    self._AssertInputOpsAre(weights_quant, expected_inputs)
    output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1'
                              if delay and use_ema else '/depthwise_Fold')
    self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])

    if with_bypass:
      # The folded conv output is also quantized just before the bypass add.
      conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' +
                                               quantization_node_name)
      self.assertEqual(conv_quant.type, quantization_node_name)
      expected_inputs = [
          scope + '/conv_quant/min/read', scope + '/conv_quant/max/read',
          scope + '/add_fold'
      ]
      self._AssertInputOpsAre(conv_quant, expected_inputs)
      output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1'
                        if delay else 'test/Add')
      self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name])

    # The activation output is quantized last.
    act_quant = graph.get_operation_by_name('test/act_quant/' +
                                            quantization_node_name)
    self.assertEqual(act_quant.type, quantization_node_name)
    expected_inputs = [
        'test/act_quant/min/read', 'test/act_quant/max/read',
        'test/' + activation_op_name
    ]
    self._AssertInputOpsAre(act_quant, expected_inputs)
    output_op_name = ('test/act_quant/delayed_quant/Switch_1'
                      if delay else 'control_dependency')
    self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])