def _test_multi_record_input_generator(self, input_generator, is_dataset=False):
    """Verifies batches from an input generator that reads two datasets.

    Feature and label specs tagged with the dataset keys 'd1' and 'd2' are
    registered on the generator, one element is pulled from its TRAIN dataset,
    and the packed result is checked against the expected shapes.

    Args:
      input_generator: The input generator under test.
      is_dataset: Not used by this helper.
    """
    make_spec = tensorspec_utils.ExtendedTensorSpec

    # Features: an encoded image and a pose, both tagged with dataset 'd1'.
    in_features = tensorspec_utils.TensorSpecStruct()
    in_features.state = make_spec(
        shape=(64, 64, 3),
        dtype=tf.uint8,
        name='state/image',
        data_format='jpeg',
        dataset_key='d1')
    in_features.action = make_spec(
        shape=(2), dtype=tf.float32, name='pose', dataset_key='d1')

    # Labels: both specs are named 'reward' but carry different dataset keys.
    in_labels = tensorspec_utils.TensorSpecStruct()
    in_labels.reward = make_spec(
        shape=(), dtype=tf.float32, name='reward', dataset_key='d1')
    in_labels.reward_2 = make_spec(
        shape=(), dtype=tf.float32, name='reward', dataset_key='d2')

    # The same spec is registered for both arguments of each setter.
    input_generator.set_feature_specifications(in_features, in_features)
    input_generator.set_label_specifications(in_labels, in_labels)

    input_fn = input_generator.create_dataset_input_fn(
        mode=tf.estimator.ModeKeys.TRAIN)
    iterator = input_fn().make_one_shot_iterator()
    batch_features, batch_labels = iterator.get_next()

    batch_features = tensorspec_utils.validate_and_pack(
        in_features, batch_features, ignore_batch=True)
    batch_labels = tensorspec_utils.validate_and_pack(
        in_labels, batch_labels, ignore_batch=True)

    # The leading dimension of every tensor is expected to be 2.
    self.assertAllEqual([2, 64, 64, 3], batch_features.state.shape)
    self.assertAllEqual([2, 2], batch_features.action.shape)
    self.assertAllEqual((2, ), batch_labels.reward.shape)
    self.assertAllEqual((2, ), batch_labels.reward_2.shape)
def preprocess(self, features, labels, mode):
    """Validates, preprocesses and flattens the features and labels.

    This method only does the boilerplate around the real work: inputs are
    validated and packed according to the in-specs, handed to
    `_preprocess_fn`, and the results are validated and flattened according
    to the out-specs.

    Args:
      features: The features of a single example.
      labels: (Optional None) The labels of a single example.
      mode: (ModeKeys) Specifies if this is training, evaluation or
        prediction.

    Returns:
      features_preprocessed: The preprocessed and flattened features, verified
        against the output feature spec.
      labels_preprocessed: (Optional None) The preprocessed labels; they are
        verified against the output label spec and flattened only when they
        are truthy.
    """
    pack = tensorspec_utils.validate_and_pack
    flatten = tensorspec_utils.validate_and_flatten

    # Step 1: check the raw inputs against the in-specs and restore the
    # (hierarchical) spec structure.
    packed_features = pack(
        expected_spec=self.get_in_feature_specification(mode),
        actual_tensors_or_spec=features,
        ignore_batch=True)
    packed_labels = labels
    if packed_labels is not None:
        packed_labels = pack(
            expected_spec=self.get_in_label_specification(mode),
            actual_tensors_or_spec=packed_labels,
            ignore_batch=True)

    # Step 2: the actual preprocessing.
    out_features, out_labels = self._preprocess_fn(
        features=packed_features, labels=packed_labels, mode=mode)

    # Step 3: check the results against the out-specs and flatten them.
    out_features = flatten(
        expected_spec=self.get_out_feature_specification(mode),
        actual_tensors_or_spec=out_features,
        ignore_batch=True)
    if not out_labels:
        return out_features, out_labels

    out_labels = flatten(
        expected_spec=self.get_out_label_specification(mode),
        actual_tensors_or_spec=out_labels,
        ignore_batch=True)
    return out_features, out_labels
def test_validate_flatten_and_pack(self):
    """Walks placeholders through a flatten -> pack hand-off between stages."""
    batch_agnostic = dict(ignore_batch=True)

    # Stage 1: an input generator produces features according to some spec.
    generated = utils.make_placeholders(mock_nested_spec)

    # Stage 2: a preprocessor (assumed to have altered the features) hands
    # them on; the output is checked against its spec and flattened.
    flattened = utils.validate_and_flatten(
        mock_nested_optional_spec, generated, **batch_agnostic)
    utils.assert_required(
        mock_nested_optional_spec, generated, **batch_agnostic)

    # Stage 3: the consumer (e.g. a model_fn) validates the flat features
    # against its own requirements and packs them back into the spec
    # structure.
    repacked = utils.validate_and_pack(
        mock_nested_subset_spec, flattened, **batch_agnostic)
    utils.assert_required(
        mock_nested_subset_spec, repacked, **batch_agnostic)
def model_fn(self, features, labels, mode, config=None, params=None):
    """Estimator model_fn.

    Args:
      features: This is the first item returned from the input_fn and parsed by
        tensorspec_utils.validate_and_pack. A spec_structure which fulfills the
        requirements of the self.get_feature_specification.
      labels: This is the second item returned from the input_fn and parsed by
        tensorspec_utils.validate_and_pack. A spec_structure which fulfills the
        requirements of the self.get_label_specification.
      mode: (ModeKeys) Specifies if this is training, evaluation or prediction.
      config: (Optional tf.estimator.RunConfig or contrib_tpu.RunConfig) Will
        receive what is passed to Estimator in config parameter, or the default
        config (tf.estimator.RunConfig). Allows updating things in your
        model_fn based on configuration such as num_ps_replicas, or model_dir.
      params: An optional dict of hyper parameters that will be passed into
        input_fn and model_fn. Keys are names of parameters, values are basic
        python types. There are reserved keys for TPUEstimator, including
        'batch_size'.

    Raises:
      ValueError: If the mode key is not supported, not in [PREDICT, TRAIN,
        EVAL], if inference_network_fn returns a tuple whose length is not 2,
        or if create_export_outputs_fn / model_train_fn return a value of an
        unsupported type.

    Returns:
      An EstimatorSpec.
    """
    features = tensorspec_utils.validate_and_pack(
        expected_spec=self.get_feature_specification(mode),
        actual_tensors_or_spec=features,
        ignore_batch=True)
    if labels:
        labels = tensorspec_utils.validate_and_pack(
            expected_spec=self.get_label_specification(mode),
            actual_tensors_or_spec=labels,
            ignore_batch=True)

    inference_outputs = self.inference_network_fn(features, labels, mode,
                                                  config, params)
    # inference_network_fn may return either the outputs alone or a
    # (outputs, update_ops) pair.
    update_ops = None
    if isinstance(inference_outputs, tuple):
        if len(inference_outputs) != 2:
            raise ValueError('Unknown output of inference_network_fn: '
                             'tuple of length %d' % len(inference_outputs))
        outputs = inference_outputs[0]
        update_ops = inference_outputs[1]
        inference_outputs = outputs

    if mode == tf.estimator.ModeKeys.PREDICT:
        model_fn_results = self.create_export_outputs_fn(
            features, inference_outputs, mode, config, params)
        export_outputs = None
        if isinstance(model_fn_results, tuple):
            predictions = model_fn_results[0]
            export_outputs = model_fn_results[1]
        elif isinstance(model_fn_results, dict):
            export_outputs = {}
            # A single prediction is additionally exported as a regression
            # output under its own name.
            if len(model_fn_results) == 1:
                name, output = list(model_fn_results.items())[0]
                export_outputs[name] = tf.estimator.export.RegressionOutput(
                    output)
            export_outputs[tf.saved_model.signature_constants
                           .DEFAULT_SERVING_SIGNATURE_DEF_KEY] = (
                               tf.estimator.export.PredictOutput(
                                   model_fn_results))
            predictions = model_fn_results
        else:
            raise ValueError(
                'The create_export_outputs_fn should return a '
                'tuple(predictions, export_outputs) or predictions.')

        return tf.estimator.EstimatorSpec(
            mode=mode, predictions=predictions, export_outputs=export_outputs)

    train_fn_result = self.model_train_fn(features, labels, inference_outputs,
                                          mode, config, params)
    if isinstance(train_fn_result, tf.Tensor):
        train_loss = train_fn_result
        train_outputs = {}
    elif isinstance(train_fn_result, tuple):
        train_loss = train_fn_result[0]
        train_outputs = train_fn_result[1]
    else:
        raise ValueError('The model_train_fn should return a '
                         'tuple(loss, train_outputs) or loss.')

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Create the tf.train.Optimizer.
        optimizer = self.create_optimizer()
        train_op = self.create_train_op(train_loss, optimizer, update_ops,
                                        train_outputs)

        self.add_summaries(features, labels, inference_outputs, train_loss,
                           train_outputs, mode, config, params)

        # Now the optimizer has been created, therefore, the checkpoint could
        # be initialized.
        # No new variables are allowed to be added, otherwise
        # we would not initialize these variables.
        # Note, this feature is only available for train to bootstrap a model
        # (partially) from a different model. As soon as this checkpoint is
        # written all other modes will use the local checkpoint within
        # model_dir.
        self.maybe_init_from_checkpoint()

        training_hooks = []

        # EstimatorSpec has training_chief_hooks, but TPUEstimatorSpec does
        # not, so we have to use training_hooks here and check is_chief.
        # NOTE(review): the nesting of the writer_init_ops check under the
        # chief check was reconstructed from whitespace-collapsed source;
        # confirm against the upstream file.
        if config and config.is_chief:  # pytype: disable=attribute-error
            training_hooks.append(
                gin_utils.GinConfigSaverHook(
                    config.model_dir, summarize_config=True))
            if hasattr(self, 'writer_init_ops'):
                training_hooks.append(
                    V2SummaryInitHook(self.writer_init_ops[mode]))

        # `SyncReplicasOptimizer` needs to attach a training hook.
        if self._sync_replicas_optimizer:
            training_hooks.append(
                self._sync_replicas_optimizer.make_session_run_hook(
                    config.is_chief))  # pytype: disable=attribute-error

        # Return the value of the property first since it might be changed.
        scaffold_fn = self.scaffold_fn
        scaffold = scaffold_fn()

        # In order to export asynchronously the saver has to be registered
        # in the graph collection. The scaffold function might register a
        # saver already which is why it is checked here and a saver is only
        # added if none has been added.
        if not tf.get_collection(tf.GraphKeys.SAVERS):
            # TODO(T2R_CONTRIBUTORS): Switch to using gin config for all saver
            # params.
            keep_checkpoint_every_n_hours = None
            max_to_keep = None
            if config is not None:
                keep_checkpoint_every_n_hours = (
                    config.keep_checkpoint_every_n_hours)
                max_to_keep = config.keep_checkpoint_max
            saver = gin_configurable_saver(
                keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
                max_to_keep=max_to_keep,
            )
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)

        return tf.estimator.EstimatorSpec(
            mode=mode,
            loss=train_loss,
            train_op=train_op,
            training_hooks=training_hooks,
            scaffold=scaffold)

    if mode == tf.estimator.ModeKeys.EVAL:
        self.add_summaries(features, labels, inference_outputs, train_loss,
                           train_outputs, mode, config, params)

        eval_metrics = self.model_eval_fn(features, labels, inference_outputs,
                                          train_loss, train_outputs, mode,
                                          config, params)
        evaluation_hooks = self.get_eval_hooks(config, params)
        if config and config.is_chief:  # pytype: disable=attribute-error
            # `params` defaults to None in the signature, so fall back to an
            # empty dict instead of failing on `None.get`.
            eval_name = (params or {}).get('eval_name', 'eval')
            evaluation_hooks.append(
                gin_utils.GinConfigSaverHook(
                    os.path.join(config.model_dir, eval_name),
                    summarize_config=True))
            if hasattr(self, 'writer_init_ops'):
                evaluation_hooks.append(
                    V2SummaryInitHook(self.writer_init_ops[mode]))

        return tf.estimator.EstimatorSpec(
            mode=mode,
            loss=train_loss,
            eval_metric_ops=eval_metrics,
            evaluation_hooks=evaluation_hooks)

    raise ValueError('The mode {} is not supported yet.'.format(mode))
def model_fn(self, features, labels, mode, config=None, params=None):
    """Estimator model_fn.

    Note, this function overwrites the model_fn of the wrapped t2r_model since
    it replaces specifications with their TPU corresponding calls and
    introduces additional casting conversion after the specification has been
    verified.

    Args:
      features: This is the first item returned from the input_fn and parsed by
        tensorspec_utils.validate_and_pack. A spec_structure which fulfills the
        requirements of the self.get_feature_specification.
      labels: This is the second item returned from the input_fn and parsed by
        tensorspec_utils.validate_and_pack. A spec_structure which fulfills the
        requirements of the self.get_label_specification.
      mode: (ModeKeys) Specifies if this is training, evaluation or prediction.
      config: (Optional tf.estimator.RunConfig or tf.contrib.tpu.RunConfig)
        Will receive what is passed to Estimator in config parameter, or the
        default config (tf.estimator.RunConfig). Allows updating things in
        your model_fn based on configuration such as num_ps_replicas, or
        model_dir.
      params: An optional dict of hyper parameters that will be passed into
        input_fn and model_fn. Keys are names of parameters, values are basic
        python types. There are reserved keys for TPUEstimator, including
        'batch_size'.

    Raises:
      ValueError: If the mode key is not supported, not in [PREDICT, TRAIN,
        EVAL], or if create_export_outputs_fn / model_train_fn of the wrapped
        model return a value of an unsupported type.

    Returns:
      A TPUEstimatorSpec.
    """
    features = tensorspec_utils.validate_and_pack(
        expected_spec=self.get_feature_specification(mode),
        actual_tensors_or_spec=features,
        ignore_batch=True)
    if labels:
        labels = tensorspec_utils.validate_and_pack(
            expected_spec=self.get_label_specification(mode),
            actual_tensors_or_spec=labels,
            ignore_batch=True)

    # In order to support both TPU and CPU for inference, tensors
    # with dtype=bfloat16 will be casted to float32.
    # Note, despite casting the benefit of bfloat16 are still maintained
    # for TPUs since this operation is a noop on this platform.
    # See http://shortn/_TTg3ZyATRo for rationale.
    features = tensorspec_utils.cast_bfloat16_to_float32(features)
    if labels is not None:
        labels = tensorspec_utils.cast_bfloat16_to_float32(labels)

    inference_outputs = self._t2r_model.inference_network_fn(
        features, labels, mode, config, params)

    if mode == tf.estimator.ModeKeys.PREDICT:
        model_fn_results = self._t2r_model.create_export_outputs_fn(
            features, inference_outputs, mode, config, params)
        export_outputs = None
        if isinstance(model_fn_results, tuple):
            predictions = model_fn_results[0]
            export_outputs = model_fn_results[1]
        elif isinstance(model_fn_results, dict):
            export_outputs = {}
            # A single prediction is additionally exported as a regression
            # output under its own name.
            if len(model_fn_results) == 1:
                # dict.items() is a view on Python 3 and cannot be indexed,
                # so materialize it before taking the only entry.
                name, output = list(model_fn_results.items())[0]
                export_outputs[name] = tf.estimator.export.RegressionOutput(
                    output)
            export_outputs[tf.saved_model.signature_constants.
                           DEFAULT_SERVING_SIGNATURE_DEF_KEY] = (
                               tf.estimator.export.PredictOutput(
                                   model_fn_results))
            predictions = model_fn_results
        else:
            raise ValueError(
                'The create_export_outputs_fn should return a '
                'tuple(predictions, export_outputs) or predictions.')

        return tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode, predictions=predictions, export_outputs=export_outputs)

    train_fn_result = self._t2r_model.model_train_fn(
        features, labels, inference_outputs, mode, config, params)
    if isinstance(train_fn_result, tf.Tensor):
        train_loss = train_fn_result
        train_outputs = {}
    elif isinstance(train_fn_result, tuple):
        train_loss = train_fn_result[0]
        train_outputs = train_fn_result[1]
    else:
        raise ValueError('The model_train_fn should return a '
                         'tuple(loss, train_outputs) or loss.')

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Create the tf.train.Optimizer.
        optimizer = get_cross_shard_optimizer(
            self._t2r_model.create_optimizer())

        # Required for batch norm usage.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = self._t2r_model.create_train_op(train_loss, optimizer)

        self._t2r_model.add_summaries(features, labels, inference_outputs,
                                      train_loss, train_outputs, mode, config,
                                      params)

        # For TPUs the init has to happen in a scaffold function. Since the
        # model already contains one implementation which is internal to the
        # model this call is simply wrapped.
        # No new variables are allowed to be added, otherwise
        # we would not initialize these variables.
        # Note, this feature is only available for train to bootstrap a model
        # (partially) from a different model. As soon as this checkpoint is
        # written all other modes will use the local checkpoint within
        # model_dir.
        def create_scaffold_fn():
            """Creates a scaffold instance."""
            self._t2r_model.maybe_init_from_checkpoint()
            # Return the value of the property first since it might be
            # changed.
            scaffold_fn = self._t2r_model.scaffold_fn
            scaffold = scaffold_fn()
            # In order to export asynchronously the saver has to be
            # registered in the graph collection. The scaffold function might
            # register a saver already which is why it is checked here and a
            # saver is only added if none has been added.
            if not tf.get_collection(tf.GraphKeys.SAVERS):
                # TODO(T2R_CONTRIBUTORS): Switch to using gin config for all
                # saver params.
                keep_checkpoint_every_n_hours = None
                max_to_keep = None
                if config is not None:
                    keep_checkpoint_every_n_hours = (
                        config.keep_checkpoint_every_n_hours)
                    max_to_keep = config.keep_checkpoint_max
                saver = abstract_model.gin_configurable_saver(
                    keep_checkpoint_every_n_hours=(
                        keep_checkpoint_every_n_hours),
                    max_to_keep=max_to_keep,
                )
                tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            return scaffold

        training_hooks = []

        # EstimatorSpec has training_chief_hooks, but TPUEstimatorSpec does
        # not, so we have to use training_hooks here and check is_chief.
        if config and config.is_chief:  # pytype: disable=attribute-error
            training_hooks.append(
                gin_utils.GinConfigSaverHook(
                    config.model_dir, summarize_config=True))

        return tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=train_loss,
            train_op=train_op,
            training_hooks=training_hooks,
            scaffold_fn=create_scaffold_fn)

    if mode == tf.estimator.ModeKeys.EVAL:
        self._t2r_model.add_summaries(features, labels, inference_outputs,
                                      train_loss, train_outputs, mode, config,
                                      params)
        eval_metrics = self._t2r_model.model_eval_fn(
            features, labels, inference_outputs, train_loss, train_outputs,
            mode, config, params)
        evaluation_hooks = self._t2r_model.get_eval_hooks(config, params)
        if config and config.is_chief:  # pytype: disable=attribute-error
            # `params` defaults to None in the signature, so fall back to an
            # empty dict instead of failing on `None.get`.
            eval_name = (params or {}).get('eval_name', 'eval')
            evaluation_hooks.append(
                gin_utils.GinConfigSaverHook(
                    os.path.join(config.model_dir, eval_name),
                    summarize_config=True))

        return tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=train_loss,
            eval_metrics=eval_metrics,
            evaluation_hooks=evaluation_hooks)

    raise ValueError('The mode {} is not supported yet.'.format(mode))
def parse_tf_example_fn(*input_values):
    """Maps string tensors (serialized TFExamples) to parsed tensors.

    Args:
      *input_values: A (string tensor,) tuple if mapping from a RecordIODataset
        or TFRecordDataset, or a (key, string tensor) tuple if mapping from a
        SSTableDataset, or (Dict[dataset_key, values],) if mapping from
        multiple datasets.

    Returns:
      features: Collection of tensors conforming to feature_tspec.
      labels: Collection of tensors conforming to label_tspec. Only returned
        when label_tspec is not None; otherwise features is returned alone.

    Raises:
      ValueError: If dtype other than uint8 or uint16 is supplied for image
        specs, if a feature has a dtype that cannot be parsed from a
        tf.Example, or if an image spec has an unsupported shape.
    """
    dict_extracted = _get_sstable_proto_dict(*input_values)

    def parse_wrapper(example, spec_dict):
        """Wrap tf.parse_example to support bfloat16 dtypes.

        This allows models which declare bfloat16 as inputs to not require an
        additional preprocessing step to cast all inputs from float32 to
        bfloat16. Consider this to be analogous to JPEG decoding in the data
        step.

        Args:
          example: TFExample
          spec_dict: Dictionary of feature name -> tf.FixedLenFeature,
            tf.VarLenFeature or tf.FixedLenSequenceFeature.

        Returns:
          Parsed feature map
        """

        def is_bfloat_feature(value):
            return value.dtype == tf.bfloat16

        def maybe_map_bfloat(value):
            """Maps bfloat16 to float32."""
            if not is_bfloat_feature(value):
                return value
            if isinstance(value, tf.FixedLenFeature):
                return tf.FixedLenFeature(
                    value.shape, tf.float32,
                    default_value=value.default_value)
            if isinstance(value, tf.VarLenFeature):
                # VarLenFeature only has a `dtype` field; it carries neither
                # a shape nor a default value.
                return tf.VarLenFeature(tf.float32)
            # NOTE(review): `allow_missing` of the original feature is not
            # carried over here -- confirm this is intended.
            return tf.FixedLenSequenceFeature(
                value.shape, tf.float32, default_value=value.default_value)

        # Change bfloat features to float32 for parsing.
        new_spec_dict = {
            k: maybe_map_bfloat(v) for k, v in six.iteritems(spec_dict)
        }
        for k, v in six.iteritems(new_spec_dict):
            if v.dtype not in [tf.float32, tf.string, tf.int64]:
                raise ValueError(
                    'Feature specification with invalid data type for '
                    'tf.Example parsing: "%s": %s' % (k, v.dtype))

        # Separate new_spec_dict into Context and Sequence features. In the
        # event that there are no SequenceFeatures, the context_features
        # dictionary (containing FixedLenFeatures) is passed to
        # tf.parse_examples.
        context_features, sequence_features = {}, {}
        for k, v in six.iteritems(new_spec_dict):
            if isinstance(v, tf.FixedLenSequenceFeature):
                sequence_features[k] = v
            elif isinstance(v, (tf.FixedLenFeature, tf.VarLenFeature)):
                context_features[k] = v
            else:
                raise ValueError(
                    'Only FixedLenFeature and FixedLenSequenceFeature are '
                    'currently supported.')

        # If there are any sequence features, we use parse_sequence_example.
        if sequence_features:
            # Filter out '_length' context features; don't parse them from
            # records.
            for parse_name in sequence_features:
                # Sometimes, the '_length' context feature doesn't exist.
                context_features.pop(parse_name + '_length', None)

            result, sequence_result, feature_lengths = (
                tf.io.parse_sequence_example(
                    example,
                    context_features=context_features,
                    sequence_features=sequence_features))
            result.update(sequence_result)
            # Augment the parsed tensors with feature length tensors.
            for parse_name, length_tensor in feature_lengths.items():
                result[parse_name + '_length'] = length_tensor
        else:
            result = tf.parse_example(example, context_features)

        # Cast everything that was declared bfloat16 back from float32.
        to_convert = [
            k for k, v in six.iteritems(spec_dict) if is_bfloat_feature(v)
        ]
        for c in to_convert:
            result[c] = tf.cast(result[c], tf.bfloat16)

        return result

    def prepend_keys(d, pre):
        """Returns a copy of d with `pre` prepended to every key."""
        return {pre + k: v for k, v in d.items()}

    # Parse each dataset's tensors. Parsed results from parse_wrapper get
    # dataset_key prepended to ensure uniqueness of keys among datasets.
    parsed_tensors = {}
    # {Prepended parsed key : TensorSpecs} for all datasets. Will contain
    # '_length' TensorSpecs that won't actually get parsed. We filter those out
    # before passing to the parse_sequence_example call.
    tensor_spec_dict = {}
    for dataset_key, example_proto in dict_extracted.items():
        # Parsed key to Feature Specs (retained only for this dataset).
        tensor_dict = {}
        sub_feature_tspec = tensorspec_utils.filter_spec_structure_by_dataset(
            feature_tspec, dataset_key)
        feature_dict, feature_tspec_dict = (
            tensorspec_utils.tensorspec_to_feature_dict(
                sub_feature_tspec, decode_images=decode_images))
        tensor_dict.update(feature_dict)
        tensor_spec_dict.update(prepend_keys(feature_tspec_dict, dataset_key))
        if label_tspec is not None:
            sub_label_tspec = (
                tensorspec_utils.filter_spec_structure_by_dataset(
                    label_tspec, dataset_key))
            label_dict, label_tspec_dict = (
                tensorspec_utils.tensorspec_to_feature_dict(
                    sub_label_tspec, decode_images=decode_images))
            tensor_dict.update(label_dict)
            tensor_spec_dict.update(
                prepend_keys(label_tspec_dict, dataset_key))
        for key, parsed in parse_wrapper(example_proto, tensor_dict).items():
            parsed_tensors[dataset_key + key] = parsed

    # At this point, all tensors have been parsed into a single flat map.

    # Interpret encoded images.
    def decode_image(key, raw_bytes):
        """Decodes single or batches of JPEG- or PNG-encoded string tensors.

        Args:
          key: String key specified in feature map.
          raw_bytes: String tensor to decode as JPEG or PNG.

        Returns:
          Decoded image tensor with shape specified by tensor spec.

        Raises:
          ValueError: If dtype other than uint8 or uint16 is supplied for
            image specs, if the spec has fewer than 3 dimensions, or if its
            last dimension is neither 1 nor 3.
        """
        img_batch_dims = tf.shape(raw_bytes)
        image_spec = tensor_spec_dict[key]
        # The spatial + channel dimensions of a single image, assumed to be
        # the last 3 entries of the image feature's tensor spec.
        if len(image_spec.shape) < 3:
            raise ValueError(
                'Shape of tensor spec for image feature "%s" must '
                'be 3 dimensional (h, w, c), but is %s' %
                (image_spec.name, image_spec.shape))
        single_img_dims = image_spec.shape[-3:]
        num_channels = single_img_dims[2]
        if num_channels not in [1, 3]:
            raise ValueError(
                'Last dimension of shape of tensor spec for image '
                'feature "%s" must 1 or 3, but the shape is %s' %
                (image_spec.name, image_spec.shape))

        # Collapse (possibly multiple) batch dims to a single batch dim for
        # decoding purposes.
        raw_bytes = tf.reshape(raw_bytes, [-1])
        data_type = image_spec.dtype
        if data_type not in SUPPORTED_PIXEL_ENCODINGS:
            raise ValueError('Decoding an image requires tensorspec.data_type '
                             'to be uint8 or uint16.')

        def _decode_images(image_bytes):
            """Decode single image."""

            def _zero_image():
                return tf.zeros(single_img_dims, dtype=data_type)

            def _tf_decode_image():
                return tf.image.decode_image(
                    image_bytes, channels=num_channels, dtype=data_type)

            # An empty string yields an all-zero image instead of being
            # handed to the decoder.
            image = tf.cond(
                tf.equal(image_bytes, ''), _zero_image, _tf_decode_image)
            image.set_shape(single_img_dims)
            return image

        img = tf.map_fn(
            _decode_images, raw_bytes, dtype=data_type, back_prop=False)
        img.set_shape(raw_bytes.shape.concatenate(single_img_dims))

        # Expand the collapsed batch dim back to the original img_batch_dims.
        img = tf.reshape(img, tf.concat([img_batch_dims, single_img_dims], 0))

        return img

    # Convert all sparse tensors to dense tensors.
    for key, val in parsed_tensors.items():
        tensor_spec = tensor_spec_dict[key]
        if tensor_spec.varlen_default_value is not None:
            if tensorspec_utils.is_encoded_image_spec(tensor_spec):
                default_value = ''
            else:
                default_value = tf.cast(
                    tf.constant(tensor_spec.varlen_default_value),
                    dtype=tensor_spec.dtype)
            parsed_tensors[key] = tf.sparse.to_dense(
                val, default_value=default_value)

    # Ensure that all images are properly decoded.
    for key, val in parsed_tensors.items():
        tensor_spec = tensor_spec_dict[key]
        if (tensorspec_utils.is_encoded_image_spec(tensor_spec) and
                decode_images):
            parsed_tensors[key] = decode_image(key, val)
            if tensor_spec.dtype not in SUPPORTED_PIXEL_ENCODINGS:
                raise ValueError(
                    'Encoded images with key {} must be '
                    'specified with uint8 or uint16 dtype.'.format(key))

    # Pad all varlen features to the corresponding spec.
    for key, val in parsed_tensors.items():
        tensor_spec = tensor_spec_dict[key]
        if tensor_spec.varlen_default_value is not None:
            parsed_tensors[key] = (
                tensorspec_utils.pad_or_clip_tensor_to_spec_shape(
                    val, tensor_spec))

    # Ensure that we have a consistent ordered mapping despite the underlying
    # spec structure.
    flat_feature_tspec = tensorspec_utils.TensorSpecStruct(
        sorted(tensorspec_utils.flatten_spec_structure(feature_tspec).items()))
    # Using the flat spec structure we allow to map the same parsed_tensor
    # to multiple features or labels. Note, the spec structure ensures that
    # the corresponding tensorspecs are identical in such cases.
    features = tensorspec_utils.TensorSpecStruct([
        (key, parsed_tensors[value.dataset_key + value.name])
        for key, value in flat_feature_tspec.items()
    ])

    features = tensorspec_utils.validate_and_pack(
        flat_feature_tspec, features, ignore_batch=True)
    if label_tspec is not None:
        # Ensure that we have a consistent ordered mapping despite the
        # underlying spec structure.
        flat_label_tspec = tensorspec_utils.TensorSpecStruct(
            sorted(
                tensorspec_utils.flatten_spec_structure(label_tspec).items()))
        labels = tensorspec_utils.TensorSpecStruct([
            (key, parsed_tensors[value.dataset_key + value.name])
            for key, value in flat_label_tspec.items()
        ])
        labels = tensorspec_utils.validate_and_pack(
            flat_label_tspec, labels, ignore_batch=True)
        return features, labels
    return features
def model_fn(self, features, labels, mode, config=None, params=None):
    """Estimator model_fn.

    Args:
      features: This is the first item returned from the input_fn and parsed by
        tensorspec_utils.validate_and_pack. A spec_structure which fulfills the
        requirements of the self.get_feature_specification.
      labels: This is the second item returned from the input_fn and parsed by
        tensorspec_utils.validate_and_pack. A spec_structure which fulfills the
        requirements of the self.get_label_specification.
      mode: (ModeKeys) Specifies if this is training, evaluation or prediction.
      config: (Optional tf.estimator.RunConfig or tf.contrib.tpu.RunConfig)
        Will receive what is passed to Estimator in config parameter, or the
        default config (tf.estimator.RunConfig). Allows updating things in
        your model_fn based on configuration such as num_ps_replicas, or
        model_dir.
      params: An optional dict of hyper parameters that will be passed into
        input_fn and model_fn. Keys are names of parameters, values are basic
        python types. There are reserved keys for TPUEstimator, including
        'batch_size'.

    Raises:
      ValueError: If the mode key is not supported, not in [PREDICT, TRAIN,
        EVAL], or if create_export_outputs_fn / model_train_fn return a value
        of an unsupported type.

    Returns:
      An EstimatorSpec.
    """
    features = tensorspec_utils.validate_and_pack(
        expected_spec=self.get_feature_specification(mode),
        actual_tensors_or_spec=features,
        ignore_batch=True)
    if labels:
        labels = tensorspec_utils.validate_and_pack(
            expected_spec=self.get_label_specification(mode),
            actual_tensors_or_spec=labels,
            ignore_batch=True)
    inference_outputs = self.inference_network_fn(features, labels, mode,
                                                  config, params)

    # After inference_fn no new variables are allowed to be added, otherwise
    # we would not initialize these variables.
    self.maybe_init_from_checkpoint()

    if mode == tf.estimator.ModeKeys.PREDICT:
        model_fn_results = self.create_export_outputs_fn(
            features, inference_outputs, mode, config, params)
        export_outputs = None
        if isinstance(model_fn_results, tuple):
            predictions = model_fn_results[0]
            export_outputs = model_fn_results[1]
        elif isinstance(model_fn_results, dict):
            export_outputs = {}
            # A single prediction is additionally exported as a regression
            # output under its own name.
            if len(model_fn_results) == 1:
                name, output = list(model_fn_results.items())[0]
                export_outputs[name] = tf.estimator.export.RegressionOutput(
                    output)
            export_outputs[tf.saved_model.signature_constants.
                           DEFAULT_SERVING_SIGNATURE_DEF_KEY] = (
                               tf.estimator.export.PredictOutput(
                                   model_fn_results))
            predictions = model_fn_results
        else:
            raise ValueError(
                'The create_export_outputs_fn should return a '
                'tuple(predictions, export_outputs) or predictions.')

        return tf.estimator.EstimatorSpec(
            mode=mode, predictions=predictions, export_outputs=export_outputs)

    train_fn_result = self.model_train_fn(features, labels, inference_outputs,
                                          mode, config, params)
    if isinstance(train_fn_result, tf.Tensor):
        train_loss = train_fn_result
        train_outputs = {}
    elif isinstance(train_fn_result, tuple):
        train_loss = train_fn_result[0]
        train_outputs = train_fn_result[1]
    else:
        raise ValueError('The model_train_fn should return a '
                         'tuple(loss, train_outputs) or loss.')

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Create the tf.train.Optimizer.
        optimizer = self.create_optimizer()

        # Required for batch norm usage.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = self.create_train_op(train_loss, optimizer)

        self.add_summaries(features, labels, inference_outputs, train_loss,
                           train_outputs, mode, config, params)

        training_hooks = []

        # EstimatorSpec has training_chief_hooks, but TPUEstimatorSpec does
        # not, so we have to use training_hooks here and check is_chief.
        if config and config.is_chief:  # pytype: disable=attribute-error
            training_hooks.append(
                gin_utils.GinConfigSaverHook(
                    config.model_dir, summarize_config=True))

        # `SyncReplicasOptimizer` needs to attach a training hook.
        if self._sync_replicas_optimizer:
            training_hooks.append(
                self._sync_replicas_optimizer.make_session_run_hook(
                    config.is_chief))  # pytype: disable=attribute-error

        return tf.estimator.EstimatorSpec(
            mode=mode,
            loss=train_loss,
            train_op=train_op,
            training_hooks=training_hooks,
            scaffold=self._scaffold_fn())

    if mode == tf.estimator.ModeKeys.EVAL:
        self.add_summaries(features, labels, inference_outputs, train_loss,
                           train_outputs, mode, config, params)
        eval_metrics = self.model_eval_fn(features, labels, inference_outputs,
                                          train_loss, train_outputs, mode,
                                          config, params)
        evaluation_hooks = self.get_eval_hooks(config, params)
        if config and config.is_chief:  # pytype: disable=attribute-error
            # `params` defaults to None in the signature, so fall back to an
            # empty dict instead of failing on `None.get`.
            eval_name = (params or {}).get('eval_name', 'eval')
            evaluation_hooks.append(
                gin_utils.GinConfigSaverHook(
                    os.path.join(config.model_dir, eval_name),
                    summarize_config=True))

        return tf.estimator.EstimatorSpec(
            mode=mode,
            loss=train_loss,
            eval_metric_ops=eval_metrics,
            evaluation_hooks=evaluation_hooks)

    raise ValueError('The mode {} is not supported yet.'.format(mode))
def parse_tf_example_fn(*input_values):
  """Maps a serialized string tensor to parsed features (and labels).

  The parsing specs (`tensor_dict`), the tensor specs (`tensor_spec_dict`) and
  the feature/label spec structures (`feature_tspec`, `label_tspec`) are not
  arguments; they are resolved from the enclosing scope.

  Args:
    *input_values: A string tensor if mapping from a RecordIODataset or
      TFRecordDataset, or a (key, string tensor) tuple if mapping from a
      SSTableDataset.

  Returns:
    features: Collection of tensors conforming to feature_tspec.
    labels: Collection of tensors conforming to label_tspec. Only returned
      when label_tspec is not None; otherwise features is returned alone.

  Raises:
    ValueError: If dtype other than uint8 is supplied for image specs, if a
      parsing spec has a dtype other than float32, string or int64 (after
      bfloat16 has been mapped to float32), or if a parsing spec is neither a
      FixedLenFeature nor a FixedLenSequenceFeature.
  """
  if len(input_values) == 2:
    # Assume an SSTable key, value pair.
    _, example_proto = input_values
  else:
    example_proto, = input_values

  def parse_wrapper(example, spec_dict):
    """Wraps tf.parse_example to support bfloat16 dtypes.

    This allows models which declare bfloat16 as inputs to not require an
    additional preprocessing step to cast all inputs from float32 to bfloat16.
    Consider this to be analogous to JPEG decoding in the data step.

    Args:
      example: TFExample
      spec_dict: Dictionary of feature name -> tf.FixedLenFeature or
        tf.FixedLenSequenceFeature.

    Returns:
      Parsed feature map, with the features declared as bfloat16 cast back to
      bfloat16.

    Raises:
      ValueError: If a spec has an unsupported dtype or feature type.
    """

    def is_bfloat_feature(value):
      return value.dtype == tf.bfloat16

    def maybe_map_bfloat(value):
      """Returns a float32 copy of a bfloat16 spec, other specs unchanged."""
      if not is_bfloat_feature(value):
        return value
      if isinstance(value, tf.FixedLenFeature):
        return tf.FixedLenFeature(
            value.shape, tf.float32, default_value=value.default_value)
      if isinstance(value, tf.FixedLenSequenceFeature):
        # Carry over allow_missing explicitly; otherwise the rebuilt spec
        # silently falls back to the constructor default.
        return tf.FixedLenSequenceFeature(
            value.shape,
            tf.float32,
            allow_missing=value.allow_missing,
            default_value=value.default_value)
      # Unsupported feature types are returned untouched so that the
      # validation below rejects them with a ValueError.
      return value

    # Change bfloat features to float32 for parsing.
    new_spec_dict = {
        k: maybe_map_bfloat(v) for k, v in six.iteritems(spec_dict)
    }
    for k, v in six.iteritems(new_spec_dict):
      if v.dtype not in [tf.float32, tf.string, tf.int64]:
        raise ValueError(
            'Feature specification with invalid data type for '
            'tf.Example parsing: "%s": %s' % (k, v.dtype))

    # Separate new_spec_dict into Context and Sequence features. In the event
    # that there are no SequenceFeatures, the context_features dictionary
    # (containing FixedLenFeatures) is passed to tf.parse_example.
    context_features, sequence_features = {}, {}
    for k, v in six.iteritems(new_spec_dict):
      if isinstance(v, tf.FixedLenSequenceFeature):
        sequence_features[k] = v
      elif isinstance(v, tf.FixedLenFeature):
        context_features[k] = v
      else:
        raise ValueError(
            'Only FixedLenFeature and FixedLenSequenceFeature are currently '
            'supported.')

    # If there are any sequence features, we use parse_sequence_example.
    if sequence_features:
      result, sequence_result, feature_lengths = tf.io.parse_sequence_example(
          example,
          context_features=context_features,
          sequence_features=sequence_features)
      del feature_lengths  # Unused.
      result.update(sequence_result)
    else:
      result = tf.parse_example(example, context_features)

    # Cast the features which were declared as bfloat16 back from float32.
    for k, v in six.iteritems(spec_dict):
      if is_bfloat_feature(v):
        result[k] = tf.cast(result[k], tf.bfloat16)
    return result

  parsed_tensors = parse_wrapper(example_proto, tensor_dict)

  # Interpret encoded images.
  def decode_image(key, raw_bytes):
    """Decodes single or batches of JPEG- or PNG-encoded string tensors.

    Args:
      key: String key specified in feature map.
      raw_bytes: String tensor to decode as JPEG or PNG.

    Returns:
      Decoded image tensor with shape specified by tensor spec.
    """
    img_batch_dims = tf.shape(raw_bytes)
    # The spatial + channel dimensions of a single image, assumed to be the
    # last 3 entries of the image feature's tensor spec.
    single_img_dims = tensor_spec_dict[key].shape[-3:]

    # Collapse (possibly multiple) batch dims to a single batch dim for
    # decoding purposes.
    raw_bytes = tf.reshape(raw_bytes, [-1])
    img = tf.map_fn(
        tf.image.decode_image, raw_bytes, dtype=tf.uint8, back_prop=False)
    img.set_shape(raw_bytes.shape.concatenate(single_img_dims))

    # Expand the collapsed batch dim back to the original img_batch_dims.
    img = tf.reshape(img, tf.concat([img_batch_dims, single_img_dims], 0))
    return img

  # Only values of existing keys are replaced, so the dict does not change
  # size while it is being iterated.
  for key, val in parsed_tensors.items():
    tensor_spec = tensor_spec_dict[key]
    if tensorspec_utils.is_encoded_image_spec(tensor_spec):
      # Validate the declared dtype before building any decoding ops.
      if tensor_spec.dtype != tf.uint8:
        raise ValueError('Encoded images with key {} must be '
                         'specified with uint8 dtype.'.format(key))
      parsed_tensors[key] = decode_image(key, val)

  # Ensure that we have a consistent ordered mapping despite the underlying
  # spec structure.
  flat_feature_tspec = tensorspec_utils.TensorSpecStruct(
      sorted(tensorspec_utils.flatten_spec_structure(feature_tspec).items()))

  # Using the flat spec structure we allow to map the same parsed_tensor
  # to multiple features or labels. Note, the spec structure ensures that
  # the corresponding tensorspecs are identical in such cases.
  features = tensorspec_utils.TensorSpecStruct([
      (key, parsed_tensors[value.name])
      for key, value in flat_feature_tspec.items()
  ])
  features = tensorspec_utils.validate_and_pack(
      flat_feature_tspec, features, ignore_batch=True)

  if label_tspec is None:
    return features

  # Ensure that we have a consistent ordered mapping despite the underlying
  # spec structure.
  flat_label_tspec = tensorspec_utils.TensorSpecStruct(
      sorted(tensorspec_utils.flatten_spec_structure(label_tspec).items()))
  labels = tensorspec_utils.TensorSpecStruct([
      (key, parsed_tensors[value.name])
      for key, value in flat_label_tspec.items()
  ])
  labels = tensorspec_utils.validate_and_pack(
      flat_label_tspec, labels, ignore_batch=True)
  return features, labels