def _capture_tensor_as_extra_input(self, tensor, name=None):
  """Substitutes a captured `tensor` with a new placeholder input.

  The captured tensor is recorded in `self.extra_inputs`. A placeholder with
  the same dtype and static shape is created to stand in for it. Any handle
  shape/type data carried by `tensor` is copied onto the placeholder. The
  placeholder is then registered in `self.inputs`, `self._captured` (keyed by
  `tensor.ref()`) and `self.extra_args`.

  Args:
    tensor: The tensor being captured.
    name: Optional name for the substitute placeholder.

  Returns:
    The placeholder. When `_is_guaranteed_const(tensor)` is true, the
    placeholder is returned wrapped in `array_ops.guarantee_const`.
  """
  self.extra_inputs.append(tensor)

  # Clearing control dependencies hoists the placeholder out of whatever
  # control-flow context is active at capture time.
  with ops.control_dependencies(None):
    placeholder = array_ops.placeholder(
        tensor.dtype, shape=tensor.get_shape(), name=name)

  # pylint: disable=protected-access
  if not isinstance(tensor, ops.EagerTensor):
    serialized = c_api.GetHandleShapeAndType(tensor.graph._c_graph,
                                             tensor._as_tf_output())
  else:
    # Eager tensors expose handle data as a proto object (or a falsy value);
    # serialize only when something is present.
    eager_data = tensor._handle_data
    serialized = eager_data.SerializeToString() if eager_data else eager_data

  if serialized:
    c_api.SetHandleShapeAndType(placeholder.graph._c_graph,
                                placeholder._as_tf_output(),
                                compat.as_bytes(serialized))
  # pylint: enable=protected-access

  self.inputs.append(placeholder)
  self._captured[tensor.ref()] = placeholder
  self.extra_args.append(placeholder)

  if not _is_guaranteed_const(tensor):
    return placeholder
  with ops.control_dependencies(None):
    return array_ops.guarantee_const(placeholder)
def _capture_tensor_as_extra_input(self, tensor, name=None):
  """Substitutes a captured `tensor` with a new placeholder input.

  The captured tensor is recorded in `self.extra_inputs`. A placeholder with
  matching dtype and static shape is created in its place. Resource handle
  shape/type data is copied from `tensor` to the placeholder, via the C API
  when `ops._USE_C_SHAPES` is set and via `_handle_data` otherwise.

  Args:
    tensor: The tensor being captured.
    name: Optional name for the substitute placeholder.

  Returns:
    The placeholder, or the placeholder wrapped in
    `array_ops.guarantee_const` when `_is_guaranteed_const(tensor)` is true.
  """
  self.extra_inputs.append(tensor)

  # Build the placeholder outside any enclosing control-flow context.
  with ops.control_dependencies(None):
    placeholder = array_ops.placeholder(
        tensor.dtype, shape=tensor.get_shape(), name=name)

  # pylint: disable=protected-access
  if not ops._USE_C_SHAPES:
    placeholder._handle_data = tensor._handle_data
  else:
    shape_and_type = c_api.GetResourceHandleShapeAndType(
        tensor.graph._c_graph, tensor._as_tf_output())
    if shape_and_type:
      c_api.SetResourceHandleShapeAndType(
          placeholder.graph._c_graph, placeholder._as_tf_output(),
          compat.as_bytes(shape_and_type))
  # pylint: enable=protected-access

  self.inputs.append(placeholder)
  self._captured[tensor] = placeholder
  self.extra_args.append(placeholder)

  if not _is_guaranteed_const(tensor):
    return placeholder
  with ops.control_dependencies(None):
    return array_ops.guarantee_const(placeholder)
def _capture_tensor_as_extra_input(self, tensor):
  """Substitutes a captured `tensor` with a new placeholder.

  The captured tensor is recorded in `self.extra_inputs`. A placeholder of
  matching dtype and static shape takes its place and is registered in
  `self._captured` and `self.extra_args`. Resource handle shape/type data
  follows the tensor onto the placeholder.

  Args:
    tensor: The tensor being captured.

  Returns:
    The placeholder, or the placeholder wrapped in
    `array_ops.guarantee_const` when `_is_guaranteed_const(tensor)` is true.
  """
  self.extra_inputs.append(tensor)

  # No control dependencies: keep the placeholder out of any active
  # control-flow context.
  with ops.control_dependencies(None):
    placeholder = array_ops.placeholder(tensor.dtype, shape=tensor.get_shape())

  # pylint: disable=protected-access
  if not ops._USE_C_SHAPES:
    placeholder._handle_data = tensor._handle_data
  else:
    shape_and_type = c_api.GetResourceHandleShapeAndType(
        tensor.graph._c_graph, tensor._as_tf_output())
    if shape_and_type:
      c_api.SetResourceHandleShapeAndType(
          placeholder.graph._c_graph, placeholder._as_tf_output(),
          compat.as_bytes(shape_and_type))
  # pylint: enable=protected-access

  self._captured[tensor] = placeholder
  self.extra_args.append(placeholder)

  if not _is_guaranteed_const(tensor):
    return placeholder
  with ops.control_dependencies(None):
    return array_ops.guarantee_const(placeholder)
def testVariables(self):
  """`guarantee_const` on a variable evaluates to the variable's value.

  Exercised for both `use_resource=False` and `use_resource=True`.
  """
  with self.test_session() as sess:
    for use_resource in (False, True):
      var = variable_scope.get_variable(
          "var_{}".format(use_resource), [],
          initializer=init_ops.constant_initializer(10.0),
          use_resource=use_resource)
      guarded = array_ops.guarantee_const(var)
      sess.run(variables.global_variables_initializer())
      observed = guarded.eval()
      self.assertEqual(10.0, observed)
def testResourceRejection(self):
  """`guarantee_const` on a resource variable's handle fails at eval time.

  The failure is expected as an `InvalidArgumentError` whose message matches
  "cannot be a resource variable".
  """
  with self.test_session() as sess:
    resource_var = variable_scope.get_variable(
        "resource_var", [],
        initializer=init_ops.constant_initializer(10.0),
        use_resource=True)
    guarded_handle = array_ops.guarantee_const(resource_var.handle)
    sess.run(variables.global_variables_initializer())
    expected_error = errors.InvalidArgumentError
    with self.assertRaisesWithPredicateMatch(
        expected_error, "cannot be a resource variable"):
      guarded_handle.eval()
def guarantee_const_getter(getter, name, *args, **kwargs):
  """Wraps the value produced by `getter` in `array_ops.guarantee_const`.

  The signature matches a custom-getter wrapper (presumably used as a
  `variable_scope` `custom_getter`; confirm against callers). Both the
  underlying `getter` call and the `guarantee_const` op run with control
  dependencies cleared. The resulting op is named `name + "/GuaranteeConst"`.
  """
  with ops.control_dependencies(None):
    underlying = getter(name, *args, **kwargs)
    return array_ops.guarantee_const(
        underlying, name=name + "/GuaranteeConst")
def testSimple(self):
  """`guarantee_const` on a constant evaluates to that constant (10)."""
  with self.test_session():
    ten = array_ops.constant(10)
    self.assertEqual(10, array_ops.guarantee_const(ten).eval())
def guarantee_const_getter(getter, name, *args, **kwargs):
  """Returns `getter(name, ...)` wrapped in a `guarantee_const` op.

  Looks intended as a custom variable getter (TODO: confirm how callers
  register it). Everything runs with control dependencies cleared, so the
  fetched value and its wrapper are created outside any enclosing
  control-dependency context.
  """
  with ops.control_dependencies(None):
    fetched = getter(name, *args, **kwargs)
    wrapper_name = name + "/GuaranteeConst"
    return array_ops.guarantee_const(fetched, name=wrapper_name)