Example No. 1
 def scatter(self, indices, value, name=None):
   """See TensorArray."""
   with ops.name_scope(name, "TensorArrayScatter",
                       [self._flow, value, indices]):
     value = ops.convert_to_tensor(value, name="value")
     if self._infer_shape and not context.executing_eagerly():
       self._merge_element_shape(value.shape[1:])
     # This older variant builds a new TensorList from `value` rather than
     # scattering into an existing list; element_shape=-1 leaves the element
     # shape unspecified.
     flow_out = list_ops.tensor_list_scatter(
         tensor=value, indices=indices, element_shape=-1)
     return build_ta_with_new_flow(self, flow_out)
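A minimal sketch of how this code path is normally reached (an illustration, not one of the scraped examples): in TF 2.x the public tf.TensorArray API delegates scatter() to internal implementations like the ones on this page.

import tensorflow as tf

# Illustrative sketch, assuming TF 2.x with eager execution enabled.
ta = tf.TensorArray(dtype=tf.float32, size=3)
# Row 0 of `value` is written to position 2, row 1 to position 0.
ta = ta.scatter(indices=[2, 0], value=tf.constant([[1.0, 2.0], [3.0, 4.0]]))
print(ta.read(0).numpy())  # [3. 4.]
print(ta.read(2).numpy())  # [1. 2.]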
Example No. 2
 def scatter(self, indices, value, name=None):
   """See TensorArray."""
   with ops.name_scope(name, "TensorArrayScatter",
                       [self._flow, value, indices]):
     value = ops.convert_to_tensor(value, name="value")
     if self._infer_shape and not context.executing_eagerly():
       self._merge_element_shape(value.shape[1:])
     # NOTE: element_shape is computed here but never used; the call below
     # scatters into the existing list behind self._flow.
     element_shape = self._element_shape[0] if self._element_shape else None
     flow_out = list_ops.tensor_list_scatter(
         tensor=value, indices=indices, input_handle=self._flow)
     return build_ta_with_new_flow(self, flow_out)
Example No. 3
 def testScatterGrad(self):
   with backprop.GradientTape() as tape:
     c0 = constant_op.constant([1.0, 2.0])
     tape.watch(c0)
      # The empty int32 tensor is the element_shape argument (scalar elements);
      # scattering c0 at indices [1, 0] reverses the order of its entries.
      l = list_ops.tensor_list_scatter(
          c0, [1, 0], ops.convert_to_tensor([], dtype=dtypes.int32))
     t0 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32)
     t1 = list_ops.tensor_list_get_item(l, 1, element_dtype=dtypes.float32)
     self.assertAllEqual(self.evaluate(t0), 2.0)
     self.assertAllEqual(self.evaluate(t1), 1.0)
     loss = t0 * t0 + t1 * t1
    # loss = c0[1]**2 + c0[0]**2, so d(loss)/d(c0) = 2 * c0 = [2., 4.].
    dt = tape.gradient(loss, c0)
    self.assertAllEqual(self.evaluate(dt), [2., 4.])
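Roughly the same check can be written against the public API. The following is an illustrative sketch (not taken from the TensorFlow test suite), assuming TF 2.x eager execution:

import tensorflow as tf

c0 = tf.constant([1.0, 2.0])
with tf.GradientTape() as tape:
  tape.watch(c0)
  # Scatter c0 into positions [1, 0], i.e. store it in reversed order.
  ta = tf.TensorArray(dtype=tf.float32, size=2).scatter([1, 0], c0)
  t0, t1 = ta.read(0), ta.read(1)  # t0 == 2.0, t1 == 1.0
  loss = t0 * t0 + t1 * t1
# d(loss)/d(c0) = 2 * c0 = [2., 4.]
print(tape.gradient(loss, c0).numpy())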
Example No. 4
 def scatter(self, indices, value, name=None):
   """See TensorArray."""
   with ops.name_scope(name, "TensorArrayScatter",
                       [self._flow, value, indices]):
     # TODO(b/129870929): Fix after all callers provide proper init dtype.
     value = ops.convert_to_tensor(
         value, preferred_dtype=self._dtype, name="value")
     _check_dtypes(value, self._dtype)
     if self._infer_shape and not context.executing_eagerly():
       self._merge_element_shape(value.shape[1:])
     # NOTE: element_shape is computed here but never used; the call below
     # scatters into the existing list behind self._flow.
     element_shape = self._element_shape[0] if self._element_shape else None
     flow_out = list_ops.tensor_list_scatter(
         tensor=value, indices=indices, input_handle=self._flow)
     return build_ta_with_new_flow(self, flow_out)
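This variant also validates the value's dtype. A quick sketch of what a mismatch looks like through the public API (an assumption-laden illustration; the exact exception type depends on whether the code runs eagerly or in a graph):

import tensorflow as tf

ta = tf.TensorArray(dtype=tf.float32, size=2)
try:
  # int32 values into a float32 TensorArray: the dtype check rejects this.
  ta.scatter([0, 1], tf.constant([1, 2]))
except (ValueError, tf.errors.InvalidArgumentError) as e:
  print("dtype mismatch rejected:", e)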
Example No. 5
 def scatter(self, indices, value, name=None):
   """See TensorArray."""
   with ops.name_scope(name, "TensorArrayScatter",
                       [self._flow, value, indices]):
     value = ops.convert_to_tensor(value, name="value")
     if self._infer_shape and not context.executing_eagerly():
       self._merge_element_shape(value.shape[1:])
     flow_out = list_ops.tensor_list_scatter(
         tensor=value, indices=indices, element_shape=-1)
     # Rebuild the TensorArray by hand (instead of using build_ta_with_new_flow),
     # copying the private shape and colocation attributes onto the new object.
     ta = TensorArray(
         dtype=self._dtype,
         handle=self.handle,
         flow=flow_out,
         colocate_with_first_write_call=self._colocate_with_first_write_call)
     ta._infer_shape = self._infer_shape
     ta._element_shape = self._element_shape
     ta._colocate_with = self._colocate_with
     return ta