예제 #1
0
def tensor_list_scatter(tensor, indices, element_shape, name=None):
  """Creates a new TensorList whose entries are scattered from `tensor`.

  Args:
    tensor: Values to scatter; forwarded unchanged to the scatter op.
    indices: Destination positions in the new list, forwarded to the op.
    element_shape: Per-element shape, normalized via `_build_element_shape`
      before being handed to the op.
    name: Optional name for the created op.

  Returns:
    The list handle produced by `gen_list_ops.tensor_list_scatter_v2`.
  """
  shape_arg = _build_element_shape(element_shape)
  # num_elements=-1 leaves the list length unspecified. Per the
  # TensorListScatterV2 op docs, the op then sizes the list to fit the
  # largest index (verify against the op definition).
  return gen_list_ops.tensor_list_scatter_v2(
      tensor=tensor, indices=indices, element_shape=shape_arg,
      num_elements=-1, name=name)
예제 #2
0
def tensor_list_scatter(tensor, indices, element_shape, name=None):
    """Creates a new TensorList whose entries are scattered from `tensor`.

    Args:
        tensor: Values to scatter; forwarded unchanged to the scatter op.
        indices: Destination positions in the new list, forwarded to the op.
        element_shape: Per-element shape, normalized via
            `_build_element_shape` before being handed to the op.
        name: Optional name for the created op.

    Returns:
        The list handle produced by `gen_list_ops.tensor_list_scatter_v2`.
    """
    shape_arg = _build_element_shape(element_shape)
    # num_elements=-1 leaves the list length unspecified. Per the
    # TensorListScatterV2 op docs, the op then sizes the list to fit the
    # largest index (verify against the op definition).
    return gen_list_ops.tensor_list_scatter_v2(
        tensor=tensor, indices=indices, element_shape=shape_arg,
        num_elements=-1, name=name)
예제 #3
0
def _TensorListGatherGrad(op, dtensor):
    """Computes the gradient of a TensorListGather op.

    The incoming gradient `dtensor` is scattered back, at the gathered
    indices, into a fresh list. That list has the same element shape and
    length as the forward input list.

    Args:
        op: The forward op; its `inputs` must unpack into exactly three values.
        dtensor: Gradient with respect to the gathered output.

    Returns:
        A 3-tuple: the gradient list for the first input, then `None` for the
        second and third inputs.
    """
    source_list, gather_indices, _unused = op.inputs
    # Rebuild the forward list's element shape (as int32) and its length so
    # the gradient list matches the input list it flows back into.
    elem_shape = gen_list_ops.tensor_list_element_shape(
        source_list, shape_type=dtypes.int32)
    list_len = gen_list_ops.tensor_list_length(source_list)
    grad_list = gen_list_ops.tensor_list_scatter_v2(
        tensor=dtensor,
        indices=gather_indices,
        element_shape=elem_shape,
        num_elements=list_len)
    return grad_list, None, None
예제 #4
0
def _TensorListGatherGrad(op, dtensor):
  """Computes the gradient of a TensorListGather op.

  The incoming gradient `dtensor` is scattered back, at the gathered
  indices, into a fresh list. That list has the same element shape and
  length as the forward input list.

  Args:
    op: The forward op; its `inputs` must unpack into exactly three values.
    dtensor: Gradient with respect to the gathered output.

  Returns:
    A 3-tuple: the gradient list for the first input, then `None` for the
    second and third inputs.
  """
  source_list, gather_indices, _unused = op.inputs
  # Rebuild the forward list's element shape (as int32) and its length so
  # the gradient list matches the input list it flows back into.
  elem_shape = gen_list_ops.tensor_list_element_shape(
      source_list, shape_type=dtypes.int32)
  list_len = gen_list_ops.tensor_list_length(source_list)
  grad_list = gen_list_ops.tensor_list_scatter_v2(
      tensor=dtensor,
      indices=gather_indices,
      element_shape=elem_shape,
      num_elements=list_len)
  return grad_list, None, None
예제 #5
0
def tensor_list_scatter(tensor,
                        indices,
                        element_shape=None,
                        input_handle=None,
                        name=None):
  """Scatters `tensor` into a new TensorList or into an existing one.

  Args:
    tensor: Values to scatter.
    indices: Destination positions for the scattered values.
    element_shape: Per-element shape for a newly created list. It is ignored
      when `input_handle` is given.
    input_handle: Optional existing list to scatter into. When `None`, a new
      list is created.
    name: Optional name for the created op.

  Returns:
    The list handle produced by the underlying scatter op.
  """
  if input_handle is None:
    # Fresh list. num_elements=-1 leaves the length unspecified; per the op
    # docs it is sized to fit the largest index (verify).
    return gen_list_ops.tensor_list_scatter_v2(
        tensor=tensor,
        indices=indices,
        element_shape=_build_element_shape(element_shape),
        num_elements=-1,
        name=name)
  return gen_list_ops.tensor_list_scatter_into_existing_list(
      input_handle=input_handle,
      tensor=tensor,
      indices=indices,
      name=name)
예제 #6
0
def tensor_list_scatter(tensor,
                        indices,
                        element_shape=None,
                        input_handle=None,
                        name=None):
  """Scatters `tensor` into a new TensorList or into an existing one.

  Args:
    tensor: Values to scatter.
    indices: Destination positions for the scattered values.
    element_shape: Per-element shape for a newly created list. It is ignored
      when `input_handle` is given.
    input_handle: Optional existing list to scatter into. When `None`, a new
      list is created.
    name: Optional name for the created op.

  Returns:
    The list handle produced by the underlying scatter op.
  """
  if input_handle is None:
    # Fresh list. num_elements=-1 leaves the length unspecified; per the op
    # docs it is sized to fit the largest index (verify).
    return gen_list_ops.tensor_list_scatter_v2(
        tensor=tensor,
        indices=indices,
        element_shape=_build_element_shape(element_shape),
        num_elements=-1,
        name=name)
  return gen_list_ops.tensor_list_scatter_into_existing_list(
      input_handle=input_handle,
      tensor=tensor,
      indices=indices,
      name=name)
예제 #7
0
def tensor_list_scatter(tensor,
                        indices,
                        element_shape=None,
                        input_handle=None,
                        name=None):
    """Scatters `tensor` into a TensorList and returns the resulting handle.

    Args:
        tensor: Values to scatter; converted with `ops.convert_to_tensor`
            first.
        indices: Destination positions for the scattered values.
        element_shape: Per-element shape for a newly created list. It is used
            for the op argument and for the handle data. It is ignored when
            `input_handle` is given.
        input_handle: Optional existing list to scatter into. When `None`, a
            new list is created.
        name: Optional name for the created op.

    Returns:
        The output list handle. Its handle data is copied from `input_handle`
        for an existing list, or set from `element_shape` and the converted
        tensor's dtype for a new list.
    """
    values = ops.convert_to_tensor(tensor)

    if input_handle is None:
        # Fresh list. num_elements=-1 leaves the length unspecified; per the
        # op docs it is sized to fit the largest index (verify).
        new_list = gen_list_ops.tensor_list_scatter_v2(
            tensor=values,
            indices=indices,
            element_shape=_build_element_shape(element_shape),
            num_elements=-1,
            name=name)
        _set_handle_data(new_list, element_shape, values.dtype)
        return new_list

    updated_list = gen_list_ops.tensor_list_scatter_into_existing_list(
        input_handle=input_handle,
        tensor=values,
        indices=indices,
        name=name)
    # Carry the existing list's handle data over to the updated handle.
    handle_data_util.copy_handle_data(input_handle, updated_list)
    return updated_list