def testVariant(self): # TODO(ebrevdo): Re-enable use_gpu=True once non-DMA Variant # copying between CPU and GPU is supported. with self.test_session(use_gpu=False): variant_tensor = tensor_pb2.TensorProto( dtype=dtypes_lib.variant.as_datatype_enum, tensor_shape=tensor_shape.TensorShape([]).as_proto(), variant_val=[ tensor_pb2.VariantTensorDataProto( # Match registration in variant_op_registry.cc type_name=b"int", metadata=np.array(1, dtype=np.int32).tobytes()) ]) const = constant_op.constant(variant_tensor) const_value = const.op.get_attr("value") # Ensure we stored the tensor proto properly. self.assertProtoEquals(variant_tensor, const_value) # Smoke test -- ensure this executes without trouble. # Right now, non-numpy-compatible objects cannot be returned from a # session.run call; similarly, objects that can't be converted to # native numpy types cannot be passed to ops.convert_to_tensor. # TODO(ebrevdo): Add registration mechanism for # ops.convert_to_tensor and for session.run output. logging_const_op = logging_ops.Print( const, [const], message="Variant storing an int, decoded const value:").op logging_const_op.run()
def testZerosLikeVariant(self): # TODO(ebrevdo): Re-enable use_gpu=True once non-DMA Variant # copying between CPU and GPU is supported AND we register a # ZerosLike callback for GPU for Variant storing primitive types # in variant_op_registry.cc. with self.test_session(use_gpu=False): variant_tensor = tensor_pb2.TensorProto( dtype=dtypes_lib.variant.as_datatype_enum, tensor_shape=tensor_shape.TensorShape([]).as_proto(), variant_val=[ tensor_pb2.VariantTensorDataProto( # Match registration in variant_op_registry.cc type_name=b"int", metadata=np.array(1, dtype=np.int32).tobytes()) ]) const_variant = constant_op.constant(variant_tensor) zeros_like = array_ops.zeros_like(const_variant) zeros_like_op = logging_ops.Print( zeros_like, [const_variant, zeros_like], message= "Variant storing an int, input and output of zeros_like:").op # Smoke test -- ensure this executes without trouble. # Right now, non-numpy-compatible objects cannot be returned from a # session.run call; similarly, objects that can't be converted to # native numpy types cannot be passed to ops.convert_to_tensor. # TODO(ebrevdo): Add registration mechanism for # ops.convert_to_tensor and for session.run output. zeros_like_op.run()
def create_constant_variant(value): return constant_op.constant( tensor_pb2.TensorProto( dtype=dtypes.variant.as_datatype_enum, tensor_shape=tensor_shape.TensorShape([]).as_proto(), variant_val=[ tensor_pb2.VariantTensorDataProto( # Match registration in variant_op_registry.cc type_name=b"int", metadata=np.array(value, dtype=np.int32).tobytes()) ]))