예제 #1
0
 def scatter(self, indices, value, name=None):
   """Scatters `value` into this TensorArray at `indices`. See TensorArray.

   Args:
     indices: Positions to write to; passed straight through to
       `gen_data_flow_ops.tensor_array_scatter_v3`.
     value: Data to scatter. Converted with `ops.convert_to_tensor` using
       this array's dtype as the preferred dtype, then checked against that
       dtype by `_check_dtypes` (presumably raising on mismatch -- confirm).
     name: Optional name, used for the name scope and for the scatter op.

   Returns:
     A new `TensorArray` over the same handle whose flow is the scatter op's
     output, carrying over this array's shape-inference, element-shape,
     colocation and dynamic-size state.
   """
   scope_inputs = [self._handle, value, indices]
   with ops.name_scope(name, "TensorArrayScatter", scope_inputs):
     # TODO(b/129870929): Fix after all callers provide proper init dtype.
     value = ops.convert_to_tensor(
         value, name="value", preferred_dtype=self._dtype)
     _check_dtypes(value, self._dtype)
     # Element-shape merging only happens in graph mode; the trailing dims
     # `value.shape[1:]` are what get merged into the element shape.
     merge_shape = self._infer_shape and not context.executing_eagerly()
     if merge_shape:
       self._merge_element_shape(value.shape[1:])
     with self._maybe_colocate_with(value):
       new_flow = gen_data_flow_ops.tensor_array_scatter_v3(
           handle=self._handle, indices=indices, value=value,
           flow_in=self._flow, name=name)
     result = TensorArray(
         dtype=self._dtype,
         handle=self._handle,
         flow=new_flow,
         colocate_with_first_write_call=self._colocate_with_first_write_call)
     # Hand the bookkeeping state to the new array. Values are assigned by
     # reference (e.g. the element-shape object is shared, not cloned).
     # NOTE(review): if `TensorArray` is a wrapper around an implementation
     # object, these may land on the wrapper rather than the implementation
     # -- confirm against the TensorArray class.
     for attr_name in ("_infer_shape", "_element_shape", "_colocate_with",
                       "_dynamic_size"):
       setattr(result, attr_name, getattr(self, attr_name))
     return result
예제 #2
0
 def scatter(self, indices, value, name=None):
   """Scatters `value` into this TensorArray at `indices`. See TensorArray.

   Args:
     indices: Positions to write to; passed straight through to
       `gen_data_flow_ops.tensor_array_scatter_v3`.
     value: Data to scatter, converted with `ops.convert_to_tensor`.
     name: Optional name, used for the name scope and for the scatter op.

   Returns:
     A new `TensorArray` over the same handle whose flow is the scatter op's
     output, carrying over this array's shape-inference, element-shape and
     colocation state.
   """
   scope_inputs = [self._handle, value, indices]
   with ops.name_scope(name, "TensorArrayScatter", scope_inputs):
     value = ops.convert_to_tensor(value, name="value")
     # Element-shape merging only happens in graph mode; the trailing dims
     # `value.shape[1:]` are what get merged into the element shape.
     merge_shape = self._infer_shape and not context.executing_eagerly()
     if merge_shape:
       self._merge_element_shape(value.shape[1:])
     with self._maybe_colocate_with(value):
       new_flow = gen_data_flow_ops.tensor_array_scatter_v3(
           handle=self._handle, indices=indices, value=value,
           flow_in=self._flow, name=name)
     result = TensorArray(
         dtype=self._dtype,
         handle=self._handle,
         flow=new_flow,
         colocate_with_first_write_call=self._colocate_with_first_write_call)
     # Hand the bookkeeping state to the new array. Values are assigned by
     # reference (e.g. the element-shape object is shared, not cloned).
     # NOTE(review): if `TensorArray` is a wrapper around an implementation
     # object, these may land on the wrapper rather than the implementation
     # -- confirm against the TensorArray class.
     for attr_name in ("_infer_shape", "_element_shape", "_colocate_with"):
       setattr(result, attr_name, getattr(self, attr_name))
     return result
예제 #3
0
 def scatter(self, indices, value, name=None):
   """Scatters `value` into this TensorArray at `indices`. See TensorArray.

   Args:
     indices: Positions to write to; passed straight through to
       `gen_data_flow_ops.tensor_array_scatter_v3`.
     value: Data to scatter, converted with `ops.convert_to_tensor`.
     name: Optional name, used for the name scope and for the scatter op.

   Returns:
     A new `TensorArray` over the same handle whose flow is the scatter op's
     output, carrying over this array's shape-inference, element-shape and
     colocation state.
   """
   scope_inputs = [self._handle, value, indices]
   with ops.name_scope(name, "TensorArrayScatter", scope_inputs):
     value = ops.convert_to_tensor(value, name="value")
     # Element-shape merging only happens in graph mode; the trailing dims
     # `value.shape[1:]` are what get merged into the element shape.
     # NOTE(review): later TF code spells this check
     # `not context.executing_eagerly()`; confirm which API this version has
     # before switching.
     merge_shape = self._infer_shape and context.in_graph_mode()
     if merge_shape:
       self._merge_element_shape(value.shape[1:])
     with self._maybe_colocate_with(value):
       new_flow = gen_data_flow_ops.tensor_array_scatter_v3(
           handle=self._handle, indices=indices, value=value,
           flow_in=self._flow, name=name)
     result = TensorArray(
         dtype=self._dtype,
         handle=self._handle,
         flow=new_flow,
         colocate_with_first_write_call=self._colocate_with_first_write_call)
     # Hand the bookkeeping state to the new array. Values are assigned by
     # reference (e.g. the element-shape object is shared, not cloned).
     # NOTE(review): if `TensorArray` is a wrapper around an implementation
     # object, these may land on the wrapper rather than the implementation
     # -- confirm against the TensorArray class.
     for attr_name in ("_infer_shape", "_element_shape", "_colocate_with"):
       setattr(result, attr_name, getattr(self, attr_name))
     return result