def _encode_and_add_to_node_collection(collection_prefix, key, node):
  """Record an encoded (key, node) pair in the default graph's collections.

  The encoded key is appended to the `collection_prefix` + KEY_SUFFIX
  collection, then the encoded node is appended to the
  `collection_prefix` + NODE_SUFFIX collection. One entry goes into each, so
  entries at the same position in the two collections describe the same node.

  Args:
    collection_prefix: Prefix naming the pair of collections to append to.
    key: Key for the node; serialized with `encoding.encode_key`.
    node: Tensor node; serialized with `encoding.encode_tensor_node`.
  """
  key_collection_name = encoding.with_suffix(collection_prefix,
                                             encoding.KEY_SUFFIX)
  tf.add_to_collection(key_collection_name, encoding.encode_key(key))

  node_collection_name = encoding.with_suffix(collection_prefix,
                                              encoding.NODE_SUFFIX)
  tf.add_to_collection(node_collection_name, encoding.encode_tensor_node(node))
def testMultipleCallsToEvalInputReceiver(self):
  """A second EvalInputReceiver call replaces the collections of the first.

  EvalInputReceiver records features, labels, the example reference and the
  TFMA version in graph collections. The second call below uses a different
  number of features and labels than the first, so the collection sizes show
  whether the earlier entries were cleared rather than appended to.
  """
  feature_keys_collection_name = encoding.with_suffix(
      encoding.FEATURES_COLLECTION, encoding.KEY_SUFFIX)
  feature_nodes_collection_name = encoding.with_suffix(
      encoding.FEATURES_COLLECTION, encoding.NODE_SUFFIX)
  label_keys_collection_name = encoding.with_suffix(
      encoding.LABELS_COLLECTION, encoding.KEY_SUFFIX)
  label_nodes_collection_name = encoding.with_suffix(
      encoding.LABELS_COLLECTION, encoding.NODE_SUFFIX)

  def assert_collection_sizes(num_features, num_labels):
    """Check the TFMA collection lengths in the current default graph."""
    self.assertEqual(
        num_features, len(tf.get_collection(feature_keys_collection_name)))
    self.assertEqual(
        num_features, len(tf.get_collection(feature_nodes_collection_name)))
    self.assertEqual(
        num_labels, len(tf.get_collection(label_keys_collection_name)))
    self.assertEqual(
        num_labels, len(tf.get_collection(label_nodes_collection_name)))
    # Exactly one example reference and one version entry, no matter how
    # many times EvalInputReceiver has been called.
    self.assertEqual(
        1, len(tf.get_collection(encoding.EXAMPLE_REF_COLLECTION)))
    self.assertEqual(
        1, len(tf.get_collection(encoding.TFMA_VERSION_COLLECTION)))

  graph = tf.Graph()
  with graph.as_default():
    # Build every tensor inside `graph`. The collections inspected below
    # belong to `graph`, so the nodes recorded in them must belong to it too.
    # This also keeps the test from adding ops to the process-wide default
    # graph.
    features1 = {'apple': tf.constant(1.0), 'banana': tf.constant(2.0)}
    labels1 = tf.constant(3.0)
    receiver_tensors1 = {'examples': tf.placeholder(tf.string)}
    features2 = {'cherry': tf.constant(3.0)}
    labels2 = {'alpha': tf.constant(4.0), 'bravo': tf.constant(5.0)}
    receiver_tensors2 = {'examples': tf.placeholder(tf.string)}

    export.EvalInputReceiver(
        features=features1,
        labels=labels1,
        receiver_tensors=receiver_tensors1)
    # labels1 is a single tensor, so it contributes exactly one label entry.
    assert_collection_sizes(num_features=2, num_labels=1)

    # Call again with a different set of features, labels and receiver
    # tensors, and check that the latest call overrides the earlier one.
    #
    # Note that we only check the lengths of some collections: more detailed
    # checks would require the test to include more knowledge about the
    # details of how exporting is done.
    export.EvalInputReceiver(
        features=features2,
        labels=labels2,
        receiver_tensors=receiver_tensors2)
    assert_collection_sizes(num_features=1, num_labels=2)
def _encode_and_add_to_node_collection(collection_prefix: Text,
                                       key: types.FPLKeyType,
                                       node: types.TensorType) -> None:
  """Append an encoded key and its encoded node to a pair of collections.

  The key collection is `collection_prefix` + KEY_SUFFIX and the node
  collection is `collection_prefix` + NODE_SUFFIX. The key is encoded and
  added first, then the node. Each call adds one entry to each collection,
  so matching positions in the two collections refer to the same node.

  Args:
    collection_prefix: Prefix naming the pair of collections to append to.
    key: Key for the node; serialized with `encoding.encode_key`.
    node: Tensor node; serialized with `encoding.encode_tensor_node`.
  """
  # Encoders are passed uncalled so each encoding happens right before its
  # own add_to_collection, exactly as in a straight-line sequence.
  for suffix, encode, value in (
      (encoding.KEY_SUFFIX, encoding.encode_key, key),
      (encoding.NODE_SUFFIX, encoding.encode_tensor_node, node),
  ):
    tf.compat.v1.add_to_collection(
        encoding.with_suffix(collection_prefix, suffix), encode(value))
def _add_tfma_collections(features: types.TensorTypeMaybeDict,
                          labels: Optional[types.TensorTypeMaybeDict],
                          input_refs: types.TensorType):
  """Populate the TFMA collections for features, labels, input_refs, version.

  Call this inside the Graph that will be saved, typically from the
  input_receiver_fn once features and labels have been parsed.

  Every TFMA collection is emptied first, so if EvalInputReceiver is called
  more than once only the entries from the latest call remain.

  Args:
    features: dict of strings to tensors representing features
    labels: dict of strings to tensors or a single tensor
    input_refs: See EvalInputReceiver().
  """
  # get_collection_ref returns the collection's own list, so slicing it away
  # empties the collection in place.
  for collection_prefix in (encoding.FEATURES_COLLECTION,
                            encoding.LABELS_COLLECTION):
    for suffix in (encoding.KEY_SUFFIX, encoding.NODE_SUFFIX):
      del tf.compat.v1.get_collection_ref(
          encoding.with_suffix(collection_prefix, suffix))[:]
  for collection_name in (encoding.EXAMPLE_REF_COLLECTION,
                          encoding.TFMA_VERSION_COLLECTION):
    del tf.compat.v1.get_collection_ref(collection_name)[:]

  for name, tensor in features.items():
    _encode_and_add_to_node_collection(encoding.FEATURES_COLLECTION, name,
                                       tensor)

  if labels is not None:
    # A bare Tensor (not a dict) is stored under the default labels key.
    label_dict = labels if isinstance(labels, dict) else {
        util.default_dict_key(constants.LABELS_NAME): labels
    }
    for name, tensor in label_dict.items():
      _encode_and_add_to_node_collection(encoding.LABELS_COLLECTION, name,
                                         tensor)

  # input_refs used to be called example_ref; the collection keeps the old
  # name because this code path is being deprecated.
  tf.compat.v1.get_collection_ref(encoding.EXAMPLE_REF_COLLECTION).append(
      encoding.encode_tensor_node(input_refs))
  tf.compat.v1.add_to_collection(encoding.TFMA_VERSION_COLLECTION,
                                 version.VERSION)
def _add_tfma_collections(features, labels, example_ref):
  """Populate the TFMA collections for features, labels, example_ref, version.

  Call this inside the Graph that will be saved, typically from the
  input_receiver_fn once features and labels have been parsed.

  Every TFMA collection is emptied first, so if EvalInputReceiver is called
  more than once only the entries from the latest call remain.

  Args:
    features: dict of strings to tensors representing features
    labels: dict of strings to tensors or a single tensor
    example_ref: See EvalInputReceiver().
  """
  # get_collection_ref hands back the collection's own list; deleting its
  # full slice clears the collection in place.
  for collection_prefix in (encoding.FEATURES_COLLECTION,
                            encoding.LABELS_COLLECTION):
    for suffix in (encoding.KEY_SUFFIX, encoding.NODE_SUFFIX):
      del tf.get_collection_ref(
          encoding.with_suffix(collection_prefix, suffix))[:]
  for collection_name in (encoding.EXAMPLE_REF_COLLECTION,
                          encoding.TFMA_VERSION_COLLECTION):
    del tf.get_collection_ref(collection_name)[:]

  for name, tensor in features.items():
    _encode_and_add_to_node_collection(encoding.FEATURES_COLLECTION, name,
                                       tensor)

  if labels is not None:
    # A bare Tensor (not a dict) is stored under the default labels key.
    label_dict = (
        labels if isinstance(labels, dict) else
        {encoding.DEFAULT_LABELS_DICT_KEY: labels})
    for name, tensor in label_dict.items():
      _encode_and_add_to_node_collection(encoding.LABELS_COLLECTION, name,
                                         tensor)

  tf.get_collection_ref(encoding.EXAMPLE_REF_COLLECTION).append(
      encoding.encode_tensor_node(example_ref))
  tf.add_to_collection(encoding.TFMA_VERSION_COLLECTION,
                       version.VERSION_STRING)
def get_node_map(
    meta_graph_def: meta_graph_pb2.MetaGraphDef, prefix: str,
    node_suffixes: List[str]
) -> Dict[types.FPLKeyType, Dict[str, CollectionDefValueType]]:
  """Get node map from meta_graph_def.

  This is designed to extract structures of the following form from the
  meta_graph_def collection_def:
    prefix/key
      key1
      key2
      key3
    prefix/suffix_a
      node1
      node2
      node3
    prefix/suffix_b
      node4
      node5
      node6

  which will become a dictionary:
  {
    key1 : {suffix_a: node1, suffix_b: node4}
    key2 : {suffix_a: node2, suffix_b: node5}
    key3 : {suffix_a: node3, suffix_b: node6}
  }.

  Keys must always be bytes. Values can be any supported CollectionDef type
  (bytes_list, any_list, etc).

  A CollectionDef that is absent, or that has no value list populated, is
  treated as empty. `meta_graph_def` is only read, never modified.

  Args:
    meta_graph_def: MetaGraphDef containing the CollectionDefs to extract the
      structure from.
    prefix: Prefix for the CollectionDef names.
    node_suffixes: The suffixes to the prefix to form the names of the
      CollectionDefs to extract the nodes from, e.g. in the example described
      above, node_suffixes would be ['suffix_a', 'suffix_b'].

  Returns:
    A dictionary of dictionaries, as described in the example above.

  Raises:
    ValueError: The length of some node list did not match length of the key
      list.
  """
  # Always look up with .get(): indexing a protobuf message map with a missing
  # key inserts a new default entry, which would mutate the caller's proto.
  collection_defs = meta_graph_def.collection_def

  node_lists = []
  for node_suffix in node_suffixes:
    collection_def = collection_defs.get(
        encoding.with_suffix(prefix, node_suffix))
    kind = None if collection_def is None else collection_def.WhichOneof('kind')
    if kind is None:
      # A missing (or empty) CollectionDef contributes an empty list.
      #
      # Either all the CollectionDefs are missing, in which case we correctly
      # return an empty dict, or some of the CollectionDefs are non-empty,
      # in which case we raise an exception below.
      node_lists.append([])
    else:
      node_lists.append(getattr(collection_def, kind).value)

  key_collection_def = collection_defs.get(
      encoding.with_suffix(prefix, encoding.KEY_SUFFIX))
  keys = ([] if key_collection_def is None else
          key_collection_def.bytes_list.value)

  if not all(len(node_list) == len(keys) for node_list in node_lists):
    raise ValueError(
        'length of each node_list should match length of keys. '
        'prefix was %s, node_lists were %s, keys was %s' %
        (prefix, node_lists, keys))

  # zip(*node_lists) regroups the per-suffix lists into one tuple per key,
  # ordered like node_suffixes.
  return {
      encoding.decode_key(key): dict(zip(node_suffixes, elems))
      for key, elems in zip(keys, zip(*node_lists))
  }