Example #1
def _encode_and_add_to_node_collection(collection_prefix, key, node):
    tf.add_to_collection(
        encoding.with_suffix(collection_prefix, encoding.KEY_SUFFIX),
        encoding.encode_key(key))
    tf.add_to_collection(
        encoding.with_suffix(collection_prefix, encoding.NODE_SUFFIX),
        encoding.encode_tensor_node(node))
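
A hedged usage sketch (the feature dict below is illustrative, not from the TFMA sources): inside the graph being exported, the helper is called once per entry of a feature or label dict, so that each key and its encoded tensor node land at matching positions in two parallel collections.

my_features = {'age': tf.constant(29.0), 'language': tf.constant('en')}
for name, node in my_features.items():
    # Appends the encoded key and the encoded node to the key- and
    # node-suffixed FEATURES_COLLECTION collections, respectively.
    _encode_and_add_to_node_collection(encoding.FEATURES_COLLECTION, name, node)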
Example #2
    def testMultipleCallsToEvalInputReceiver(self):
        graph = tf.Graph()
        features1 = {'apple': tf.constant(1.0), 'banana': tf.constant(2.0)}
        labels1 = tf.constant(3.0)
        receiver_tensors1 = {'examples': tf.placeholder(tf.string)}

        features2 = {'cherry': tf.constant(3.0)}
        labels2 = {'alpha': tf.constant(4.0), 'bravo': tf.constant(5.0)}
        receiver_tensors2 = {'examples': tf.placeholder(tf.string)}

        with graph.as_default():
            export.EvalInputReceiver(features=features1,
                                     labels=labels1,
                                     receiver_tensors=receiver_tensors1)

            feature_keys_collection_name = encoding.with_suffix(
                encoding.FEATURES_COLLECTION, encoding.KEY_SUFFIX)
            feature_nodes_collection_name = encoding.with_suffix(
                encoding.FEATURES_COLLECTION, encoding.NODE_SUFFIX)
            label_keys_collection_name = encoding.with_suffix(
                encoding.LABELS_COLLECTION, encoding.KEY_SUFFIX)
            label_nodes_collection_name = encoding.with_suffix(
                encoding.LABELS_COLLECTION, encoding.NODE_SUFFIX)

            self.assertEqual(
                2, len(tf.get_collection(feature_keys_collection_name)))
            self.assertEqual(
                2, len(tf.get_collection(feature_nodes_collection_name)))
            self.assertEqual(
                1, len(tf.get_collection(label_keys_collection_name)))
            self.assertEqual(
                1, len(tf.get_collection(label_nodes_collection_name)))
            self.assertEqual(
                1, len(tf.get_collection(encoding.EXAMPLE_REF_COLLECTION)))
            self.assertEqual(
                1, len(tf.get_collection(encoding.TFMA_VERSION_COLLECTION)))

            # Call again with a different set of features, labels and receiver
            # tensors, check that the latest call overrides the earlier one.
            #
            # Note that we only check the lengths of some collections: more detailed
            # checks would require the test to include more knowledge about the
            # details of how exporting is done.
            export.EvalInputReceiver(features=features2,
                                     labels=labels2,
                                     receiver_tensors=receiver_tensors2)
            self.assertEqual(
                1, len(tf.get_collection(feature_keys_collection_name)))
            self.assertEqual(
                1, len(tf.get_collection(feature_nodes_collection_name)))
            self.assertEqual(
                2, len(tf.get_collection(label_keys_collection_name)))
            self.assertEqual(
                2, len(tf.get_collection(label_nodes_collection_name)))
            self.assertEqual(
                1, len(tf.get_collection(encoding.EXAMPLE_REF_COLLECTION)))
            self.assertEqual(
                1, len(tf.get_collection(encoding.TFMA_VERSION_COLLECTION)))
Example #3
def _encode_and_add_to_node_collection(collection_prefix: Text,
                                       key: types.FPLKeyType,
                                       node: types.TensorType) -> None:
  tf.compat.v1.add_to_collection(
      encoding.with_suffix(collection_prefix, encoding.KEY_SUFFIX),
      encoding.encode_key(key))
  tf.compat.v1.add_to_collection(
      encoding.with_suffix(collection_prefix, encoding.NODE_SUFFIX),
      encoding.encode_tensor_node(node))
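
For a quick check of what the helper wrote, a hedged read-back sketch (the feature name and value are illustrative); this mirrors the collection-length assertions in Example #2:

graph = tf.Graph()
with graph.as_default():
  _encode_and_add_to_node_collection(encoding.FEATURES_COLLECTION, 'age',
                                     tf.constant(29.0))
  # One encoded key and one encoded node should now be registered.
  keys = tf.compat.v1.get_collection(
      encoding.with_suffix(encoding.FEATURES_COLLECTION, encoding.KEY_SUFFIX))
  nodes = tf.compat.v1.get_collection(
      encoding.with_suffix(encoding.FEATURES_COLLECTION, encoding.NODE_SUFFIX))
  assert len(keys) == 1 and len(nodes) == 1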
Example #4
def _add_tfma_collections(features: types.TensorTypeMaybeDict,
                          labels: Optional[types.TensorTypeMaybeDict],
                          input_refs: types.TensorType):
    """Add extra collections for features, labels, input_refs, version.

  This should be called within the Graph that will be saved. Typical usage
  would be when features and labels have been parsed, i.e. in the
  input_receiver_fn.

  Args:
    features: dict of strings to tensors representing features
    labels: dict of strings to tensors or a single tensor
    input_refs: See EvalInputReceiver().
  """
    # Clear existing collections first, in case the EvalInputReceiver was called
    # multiple times.
    del tf.compat.v1.get_collection_ref(
        encoding.with_suffix(encoding.FEATURES_COLLECTION,
                             encoding.KEY_SUFFIX))[:]
    del tf.compat.v1.get_collection_ref(
        encoding.with_suffix(encoding.FEATURES_COLLECTION,
                             encoding.NODE_SUFFIX))[:]
    del tf.compat.v1.get_collection_ref(
        encoding.with_suffix(encoding.LABELS_COLLECTION,
                             encoding.KEY_SUFFIX))[:]
    del tf.compat.v1.get_collection_ref(
        encoding.with_suffix(encoding.LABELS_COLLECTION,
                             encoding.NODE_SUFFIX))[:]
    del tf.compat.v1.get_collection_ref(encoding.EXAMPLE_REF_COLLECTION)[:]
    del tf.compat.v1.get_collection_ref(encoding.TFMA_VERSION_COLLECTION)[:]

    for feature_name, feature_node in features.items():
        _encode_and_add_to_node_collection(encoding.FEATURES_COLLECTION,
                                           feature_name, feature_node)

    if labels is not None:
        # Labels can either be a Tensor, or a dict of Tensors.
        if not isinstance(labels, dict):
            labels = {util.default_dict_key(constants.LABELS_NAME): labels}

        for label_key, label_node in labels.items():
            _encode_and_add_to_node_collection(encoding.LABELS_COLLECTION,
                                               label_key, label_node)
    # Previously input_refs was called example_ref. Since this code path is
    # being deprecated, the collection was not renamed.
    example_ref_collection = tf.compat.v1.get_collection_ref(
        encoding.EXAMPLE_REF_COLLECTION)
    example_ref_collection.append(encoding.encode_tensor_node(input_refs))

    tf.compat.v1.add_to_collection(encoding.TFMA_VERSION_COLLECTION,
                                   version.VERSION)
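
A hedged sketch of the call site the docstring describes, inside an input_receiver_fn after parsing. The tensors are illustrative placeholders; how input_refs is actually constructed is defined by EvalInputReceiver(), and the one-index-per-example mapping below is an assumption.

with tf.Graph().as_default():
    serialized = tf.compat.v1.placeholder(tf.string, shape=[None])
    features = {'age': tf.constant([29.0]), 'language': tf.constant(['en'])}
    labels = tf.constant([1.0])
    # Assumption: one input_ref index per serialized input example.
    input_refs = tf.range(tf.size(serialized))
    _add_tfma_collections(features, labels, input_refs)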
Example #5
def _add_tfma_collections(features,
                          labels,
                          example_ref):
  """Add extra collections for features, labels, example_ref, version.

  This should be called within the Graph that will be saved. Typical usage
  would be when features and labels have been parsed, i.e. in the
  input_receiver_fn.

  Args:
    features: dict of strings to tensors representing features
    labels: dict of strings to tensors or a single tensor
    example_ref: See EvalInputReceiver().
  """
  # Clear existing collections first, in case the EvalInputReceiver was called
  # multiple times.
  del tf.get_collection_ref(
      encoding.with_suffix(encoding.FEATURES_COLLECTION,
                           encoding.KEY_SUFFIX))[:]
  del tf.get_collection_ref(
      encoding.with_suffix(encoding.FEATURES_COLLECTION,
                           encoding.NODE_SUFFIX))[:]
  del tf.get_collection_ref(
      encoding.with_suffix(encoding.LABELS_COLLECTION, encoding.KEY_SUFFIX))[:]
  del tf.get_collection_ref(
      encoding.with_suffix(encoding.LABELS_COLLECTION, encoding.NODE_SUFFIX))[:]
  del tf.get_collection_ref(encoding.EXAMPLE_REF_COLLECTION)[:]
  del tf.get_collection_ref(encoding.TFMA_VERSION_COLLECTION)[:]

  for feature_name, feature_node in features.items():
    _encode_and_add_to_node_collection(encoding.FEATURES_COLLECTION,
                                       feature_name, feature_node)

  if labels is not None:
    # Labels can either be a Tensor, or a dict of Tensors.
    if not isinstance(labels, dict):
      labels = {encoding.DEFAULT_LABELS_DICT_KEY: labels}

    for label_key, label_node in labels.items():
      _encode_and_add_to_node_collection(encoding.LABELS_COLLECTION, label_key,
                                         label_node)
  example_ref_collection = tf.get_collection_ref(
      encoding.EXAMPLE_REF_COLLECTION)
  example_ref_collection.append(encoding.encode_tensor_node(example_ref))

  tf.add_to_collection(encoding.TFMA_VERSION_COLLECTION, version.VERSION_STRING)
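
Since the function only processes labels when they are not None, labels are optional. A minimal hedged sketch of a call without labels (values are illustrative), in which case only the feature, example_ref and version collections get populated:

features = {'age': tf.constant([29.0])}
example_ref = tf.constant([0], dtype=tf.int64)  # illustrative example_ref tensor
_add_tfma_collections(features, None, example_ref)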
Example #6
def get_node_map(
    meta_graph_def: meta_graph_pb2.MetaGraphDef, prefix: str,
    node_suffixes: List[str]
) -> Dict[types.FPLKeyType, Dict[str, CollectionDefValueType]]:
    """Get node map from meta_graph_def.

  This is designed to extract structures of the following form from the
  meta_graph_def collection_def:
    prefix/key
      key1
      key2
      key3
    prefix/suffix_a
      node1
      node2
      node3
    prefix/suffix_b
      node4
      node5
      node6

   which will become a dictionary:
   {
     key1 : {suffix_a: node1, suffix_b: node4}
     key2 : {suffix_a: node2, suffix_b: node5}
     key3 : {suffix_a: node3, suffix_b: node6}
   }.

  Keys must always be bytes. Values can be any supported CollectionDef type
  (bytes_list, any_list, etc)

  Args:
     meta_graph_def: MetaGraphDef containing the CollectionDefs to extract the
       structure from.
     prefix: Prefix for the CollectionDef names.
     node_suffixes: The suffixes to the prefix to form the names of the
       CollectionDefs to extract the nodes from, e.g. in the example described
       above, node_suffixes would be ['suffix_a', 'suffix_b'].

  Returns:
    A dictionary of dictionaries, as described in the example above.

  Raises:
    ValueError: The length of some node list did not match length of the key
    list.
  """
    node_lists = []
    for node_suffix in node_suffixes:
        collection_def_name = encoding.with_suffix(prefix, node_suffix)
        collection_def = meta_graph_def.collection_def.get(collection_def_name)
        if collection_def is None:
            # If we can't find the CollectionDef, append an empty list.
            #
            # Either all the CollectionDefs are missing, in which case we correctly
            # return an empty dict, or some of the CollectionDefs are non-empty,
            # in which case we raise an exception below.
            node_lists.append([])
        else:
            node_lists.append(
                getattr(collection_def,
                        collection_def.WhichOneof('kind')).value)
    keys = meta_graph_def.collection_def[encoding.with_suffix(
        prefix, encoding.KEY_SUFFIX)].bytes_list.value
    if not all([len(node_list) == len(keys) for node_list in node_lists]):
        raise ValueError(
            'length of each node_list should match length of keys. '
            'prefix was %s, node_lists were %s, keys was %s' %
            (prefix, node_lists, keys))
    result = {}
    for key, elems in zip(keys, zip(*node_lists)):
        result[encoding.decode_key(key)] = dict(zip(node_suffixes, elems))
    return result
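
To make the transformation in the docstring concrete, here is a minimal standalone sketch of just the zipping step, with plain Python values standing in for the decoded keys and CollectionDef entries (encoding.decode_key is omitted):

keys = [b'key1', b'key2', b'key3']
node_lists = [
    ['node1', 'node2', 'node3'],  # contents of the prefix/suffix_a collection
    ['node4', 'node5', 'node6'],  # contents of the prefix/suffix_b collection
]
node_suffixes = ['suffix_a', 'suffix_b']
result = {}
for key, elems in zip(keys, zip(*node_lists)):
    result[key] = dict(zip(node_suffixes, elems))
# result[b'key1'] == {'suffix_a': 'node1', 'suffix_b': 'node4'}, and so on.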