Esempio n. 1
0
  def _capture_tensor_as_extra_input(self, tensor, name=None):
    """Captures an outer `tensor` as an extra input backed by a placeholder.

    `tensor` is appended to `self.extra_inputs`. A placeholder with the same
    dtype and static shape is created with control dependencies cleared and
    is recorded in `self.inputs`, `self._captured` (keyed by `tensor.ref()`)
    and `self.extra_args`. Any handle shape/type data found on `tensor` is
    copied onto the placeholder.

    Args:
      tensor: The tensor being captured. Either an `ops.EagerTensor` or a
        tensor belonging to some graph.
      name: Optional name for the placeholder.

    Returns:
      The placeholder, or `array_ops.guarantee_const` applied to it when
      `_is_guaranteed_const(tensor)` is true.
    """
    # The captured tensor becomes an extra input of the function.
    self.extra_inputs.append(tensor)

    # Clearing control dependencies hoists the placeholder out of whatever
    # control-flow context is currently active.
    with ops.control_dependencies(None):
      placeholder = array_ops.placeholder(
          tensor.dtype, shape=tensor.get_shape(), name=name)

    # Fetch handle shape/type data: eager tensors carry it as a Python-side
    # proto (serialized here), graph tensors expose it through the C API.
    # pylint: disable=protected-access
    if not isinstance(tensor, ops.EagerTensor):
      serialized = c_api.GetHandleShapeAndType(tensor.graph._c_graph,
                                               tensor._as_tf_output())
    else:
      raw_handle_data = tensor._handle_data
      serialized = (raw_handle_data.SerializeToString()
                    if raw_handle_data else raw_handle_data)

    if serialized:
      c_api.SetHandleShapeAndType(placeholder.graph._c_graph,
                                  placeholder._as_tf_output(),
                                  compat.as_bytes(serialized))
    # pylint: enable=protected-access

    self.inputs.append(placeholder)
    self._captured[tensor.ref()] = placeholder
    self.extra_args.append(placeholder)

    if not _is_guaranteed_const(tensor):
      return placeholder
    # The guarantee_const op is also created outside any control context.
    with ops.control_dependencies(None):
      return array_ops.guarantee_const(placeholder)
Esempio n. 2
0
 def _capture_tensor_as_extra_input(self, tensor, name=None):
   """Replaces an outer-graph `tensor` with a new placeholder input.

   `tensor` is appended to `self.extra_inputs`. The substitute placeholder is
   recorded in `self.inputs`, `self._captured` (keyed by `tensor`) and
   `self.extra_args`.

   Args:
     tensor: The tensor being captured.
     name: Optional name for the placeholder.

   Returns:
     The placeholder, or `array_ops.guarantee_const` applied to it when
     `_is_guaranteed_const(tensor)` is true.
   """
   self.extra_inputs.append(tensor)

   # Create the placeholder with control dependencies cleared so it does not
   # end up nested inside the current control-flow context.
   with ops.control_dependencies(None):
     substitute = array_ops.placeholder(
         tensor.dtype, shape=tensor.get_shape(), name=name)

   # Propagate handle shape/type metadata. When `ops._USE_C_SHAPES` is off it
   # lives on the Python `_handle_data` attribute; otherwise go through the
   # C API.
   # pylint: disable=protected-access
   if not ops._USE_C_SHAPES:
     substitute._handle_data = tensor._handle_data
   else:
     shape_and_type = c_api.GetResourceHandleShapeAndType(
         tensor.graph._c_graph, tensor._as_tf_output())
     if shape_and_type:
       c_api.SetResourceHandleShapeAndType(substitute.graph._c_graph,
                                           substitute._as_tf_output(),
                                           compat.as_bytes(shape_and_type))
   # pylint: enable=protected-access

   self.inputs.append(substitute)
   self._captured[tensor] = substitute
   self.extra_args.append(substitute)

   if not _is_guaranteed_const(tensor):
     return substitute
   with ops.control_dependencies(None):
     return array_ops.guarantee_const(substitute)
Esempio n. 3
0
 def _capture_tensor_as_extra_input(self, tensor, name=None):
     """Captures an outer-graph `tensor` by substituting a placeholder for it.

     `tensor` is appended to `self.extra_inputs`. A placeholder with the same
     dtype and static shape is created with control dependencies cleared. It
     receives the handle shape/type metadata of `tensor` and is recorded in
     `self._captured` (keyed by `tensor`) and `self.extra_args`.

     Args:
         tensor: The tensor being captured.
         name: Optional name for the placeholder. The default `None` is
             forwarded unchanged to `array_ops.placeholder`, so callers
             passing only `tensor` get the same placeholder as before this
             parameter existed.

     Returns:
         The placeholder, or `array_ops.guarantee_const` applied to it when
         `_is_guaranteed_const(tensor)` is true.
     """
     # Substitute with a placeholder.
     self.extra_inputs.append(tensor)
     # Hoist the new input placeholder out of any control flow context
     # we're currently in.
     with ops.control_dependencies(None):
         ph = array_ops.placeholder(
             tensor.dtype, shape=tensor.get_shape(), name=name)
     # Copy handle shape/type metadata from the captured tensor. Which
     # mechanism is used depends on `ops._USE_C_SHAPES`: the C API, or the
     # Python-side `_handle_data` attribute.
     # pylint: disable=protected-access
     if ops._USE_C_SHAPES:
         handle_data = c_api.GetResourceHandleShapeAndType(
             tensor.graph._c_graph, tensor._as_tf_output())
         if handle_data:
             c_api.SetResourceHandleShapeAndType(
                 ph.graph._c_graph, ph._as_tf_output(),
                 compat.as_bytes(handle_data))
     else:
         ph._handle_data = tensor._handle_data
     # pylint: enable=protected-access
     # NOTE(review): this version does not append the placeholder to
     # `self.inputs`; confirm the enclosing class has no such list to keep
     # in sync.
     self._captured[tensor] = ph
     self.extra_args.append(ph)
     if _is_guaranteed_const(tensor):
         with ops.control_dependencies(None):
             return array_ops.guarantee_const(ph)
     else:
         return ph
Esempio n. 4
0
 def testVariables(self):
   """guarantee_const yields the variable's value for use_resource False/True."""
   with self.test_session() as sess:
     for use_resource in (False, True):
       var_name = "var_{}".format(use_resource)
       var = variable_scope.get_variable(
           var_name,
           [],
           initializer=init_ops.constant_initializer(10.0),
           use_resource=use_resource)
       guaranteed = array_ops.guarantee_const(var)
       sess.run(variables.global_variables_initializer())
       self.assertEqual(10.0, guaranteed.eval())
 def testVariables(self):
   """Checks that guarantee_const of a variable evaluates to 10.0.

   The check runs once with use_resource=False and once with use_resource=True.
   """
   with self.test_session() as sess:
     for resource_flag in [False, True]:
       variable = variable_scope.get_variable(
           name="var_{}".format(resource_flag),
           shape=[],
           initializer=init_ops.constant_initializer(10.0),
           use_resource=resource_flag)
       const_view = array_ops.guarantee_const(variable)
       # Variables must be initialized before the guarded value is read.
       sess.run(variables.global_variables_initializer())
       observed = const_view.eval()
       self.assertEqual(10.0, observed)
Esempio n. 6
0
 def testResourceRejection(self):
   """Evaluating guarantee_const on a resource handle raises.

   The expected error is InvalidArgumentError, with a message containing
   "cannot be a resource variable".
   """
   with self.test_session() as sess:
     resource_var = variable_scope.get_variable(
         "resource_var",
         [],
         initializer=init_ops.constant_initializer(10.0),
         use_resource=True)
     guarded_handle = array_ops.guarantee_const(resource_var.handle)
     sess.run(variables.global_variables_initializer())
     expected_message = "cannot be a resource variable"
     with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
                                              expected_message):
       guarded_handle.eval()
 def testResourceRejection(self):
   """A guarantee_const wrapped around a resource handle fails when run."""
   with self.test_session() as sess:
     var = variable_scope.get_variable(
         name="resource_var",
         shape=[],
         initializer=init_ops.constant_initializer(10.0),
         use_resource=True)
     wrapped = array_ops.guarantee_const(var.handle)
     init_op = variables.global_variables_initializer()
     sess.run(init_op)
     # The failure only surfaces at execution time, not at graph build time.
     with self.assertRaisesWithPredicateMatch(
         errors.InvalidArgumentError, "cannot be a resource variable"):
       wrapped.eval()
Esempio n. 8
0
 def guarantee_const_getter(getter, name, *args, **kwargs):
     """Fetches a value via `getter` and wraps it in `guarantee_const`.

     The signature presumably matches the variable-scope custom-getter
     convention; verify against callers.

     Args:
         getter: Callable invoked as `getter(name, *args, **kwargs)`.
         name: Name forwarded to `getter`. It is also used to derive the
             op name `name + "/GuaranteeConst"`.
         *args: Extra positional arguments forwarded to `getter`.
         **kwargs: Extra keyword arguments forwarded to `getter`.

     Returns:
         The result of `array_ops.guarantee_const` on the fetched value.
     """
     # Both the getter call and the guarantee_const op run with control
     # dependencies cleared.
     with ops.control_dependencies(None):
         fetched = getter(name, *args, **kwargs)
         const_name = name + "/GuaranteeConst"
         return array_ops.guarantee_const(fetched, name=const_name)
Esempio n. 9
0
 def testSimple(self):
   """guarantee_const of a scalar constant evaluates to that constant."""
   with self.test_session():
     ten = array_ops.constant(10)
     self.assertEqual(10, array_ops.guarantee_const(ten).eval())
Esempio n. 10
0
 def testSimple(self):
   """Wrapping constant(10) in guarantee_const still yields 10."""
   with self.test_session():
     source = array_ops.constant(10)
     wrapped = array_ops.guarantee_const(source)
     result = wrapped.eval()
     self.assertEqual(10, result)
Esempio n. 11
0
 def guarantee_const_getter(getter, name, *args, **kwargs):
   """Returns `getter(name, ...)` wrapped in `array_ops.guarantee_const`.

   The guarantee_const op is named `name + "/GuaranteeConst"`. Presumably
   this is used as a custom variable getter; confirm with callers.

   Args:
     getter: Callable producing the underlying value.
     name: Name passed to `getter` and used as the prefix of the op name.
     *args: Positional arguments forwarded to `getter`.
     **kwargs: Keyword arguments forwarded to `getter`.

   Returns:
     The guarantee_const-wrapped value.
   """
   with ops.control_dependencies(None):
     underlying = getter(name, *args, **kwargs)
     op_name = name + "/GuaranteeConst"
     wrapped = array_ops.guarantee_const(underlying, name=op_name)
     return wrapped