def func_attr_value():
    """Build a sample AttrValue whose ``func`` field is populated.

    The wrapped NameAttrList is named 'test_attr_list_name' and carries two
    attributes: 'int' (an AttrValue with i=3333) and 'bool' (an AttrValue
    with b=True).
    """
    nested_attrs = {
        'int': AttrValue(i=3333),
        'bool': AttrValue(b=True),
    }
    name_attr_list = NameAttrList(name='test_attr_list_name',
                                  attr=nested_attrs)
    return AttrValue(func=name_attr_list)
def attr_value_to_python_type(attr_value: tf.AttrValue) -> Any:
    """
    Inverse of python_type_to_attr_value().

    Args:
      attr_value: Protocol buffer version of a node's attribute value

    Returns:
      A Python object or built-in type corresponding to the field in
      `attr_value` that is in use.

    Raises:
      ValueError: if none of the supported fields (s, i, f, b, type, shape,
        tensor) is set -- e.g. list, func, or placeholder values.
    """
    # TODO(frreiss): Handle AttrValues that are lists
    # TODO(frreiss): Convert the "func" and "placeholder" fields of the union
    #
    # (field name, converter) pairs, probed in this order; the first field
    # reported as set is converted and returned. Each converter touches `tf`
    # only when invoked, so TensorFlow attributes are resolved lazily.
    conversions = (
        # TODO(frreiss): Should we return the binary value here?
        ("s", lambda raw: tf.compat.as_str(raw)),      # str
        ("i", lambda raw: raw),                        # int
        ("f", lambda raw: raw),                        # float
        ("b", lambda raw: raw),                        # bool
        ("type", lambda raw: tf.DType(raw)),           # DType
        # Undocumented behavior of public API: tf.TensorShape constructor
        # accepts a TensorShapeProto.
        ("shape", lambda raw: tf.TensorShape(raw)),    # TensorShape
        ("tensor", lambda raw: tf.make_ndarray(raw)),  # TensorProto
    )
    for field_name, convert in conversions:
        if attr_value.HasField(field_name):
            return convert(getattr(attr_value, field_name))
    raise ValueError("Don't know how to convert AttrValue {} to "
                     "a Python object".format(attr_value))
def int_attr_value():
    """Build a sample AttrValue with its ``i`` (int) field set to 3333."""
    int_value = AttrValue(i=3333)
    return int_value
def placeholder_attr_value():
    """Build a sample AttrValue with its ``placeholder`` field set to 'placeholder'."""
    placeholder_value = AttrValue(placeholder='placeholder')
    return placeholder_value
def bool_attr_value():
    """Build a sample AttrValue with its ``b`` (bool) field set to True."""
    bool_value = AttrValue(b=True)
    return bool_value
def float_attr_value():
    """Build a sample AttrValue with its ``f`` (float) field set to 3.14159."""
    float_value = AttrValue(f=3.14159)
    return float_value
def bytes_attr_value():
    """Build a sample AttrValue with its ``s`` (bytes) field set to b'bytes'."""
    bytes_value = AttrValue(s=b'bytes')
    return bytes_value