예제 #1
0
def _TensorListGatherGrad(op, dtensor):
  """Gradient function for TensorListGather.

  Reserves a list using the forward input list's element shape and length,
  then scatters the incoming gradient into it at the gathered indices.

  Args:
    op: The forward TensorListGather op. Its inputs unpack as
      (input_list, indices, <third input, unused here>).
    dtensor: Gradient with respect to the gathered tensor.

  Returns:
    A 3-tuple `(list_gradient, None, None)`; only the list input receives a
    gradient.
  """
  forward_list, gather_indices, _ = op.inputs
  # Mirror the forward list's element shape and length so the gradient list
  # lines up element-for-element with the input list.
  grad_list = tensor_list_reserve(
      gen_list_ops.tensor_list_element_shape(
          forward_list, shape_type=dtypes.int32),
      gen_list_ops.tensor_list_length(forward_list),
      dtensor.dtype)
  return (tensor_list_scatter(tensor=dtensor,
                              indices=gather_indices,
                              input_handle=grad_list),
          None,
          None)
예제 #2
0
def _TensorListGatherGrad(op, dtensor):
    """Gradient function for TensorListGather.

    Scatters `dtensor` into a new list at the gathered indices. The new list
    is built with the forward input list's element shape and length.

    Args:
        op: The forward TensorListGather op. Its inputs unpack as
            (input_list, indices, <third input, unused here>).
        dtensor: Gradient with respect to the gathered tensor.

    Returns:
        `(list_gradient, None, None)`.
    """
    forward_list, gather_indices, _ = op.inputs
    shape_of_elements = gen_list_ops.tensor_list_element_shape(
        forward_list, shape_type=dtypes.int32)
    list_length = gen_list_ops.tensor_list_length(forward_list)
    grad_list = gen_list_ops.tensor_list_scatter_v2(
        tensor=dtensor,
        indices=gather_indices,
        element_shape=shape_of_elements,
        num_elements=list_length)
    return grad_list, None, None
예제 #3
0
def _TensorListGatherGrad(op, dtensor):
  """Gradient function for TensorListGather.

  The gradient list is produced by a single TensorListScatterV2 call. It is
  sized and shaped after the forward input list and holds `dtensor`'s rows
  at the positions that were gathered.

  Args:
    op: The forward TensorListGather op; inputs unpack as
      (input_list, indices, <third input, unused here>).
    dtensor: Gradient with respect to the gathered tensor.

  Returns:
    `(list_gradient, None, None)`.
  """
  source_list, positions, _ = op.inputs
  elem_shape = gen_list_ops.tensor_list_element_shape(
      source_list, shape_type=dtypes.int32)
  size = gen_list_ops.tensor_list_length(source_list)
  return (
      gen_list_ops.tensor_list_scatter_v2(
          tensor=dtensor,
          indices=positions,
          element_shape=elem_shape,
          num_elements=size),
      None,
      None,
  )
예제 #4
0
def _TensorListGatherGrad(op, dtensor):
  """Gradient function for TensorListGather.

  Args:
    op: The forward TensorListGather op; inputs unpack as
      (input_list, indices, <third input, unused here>).
    dtensor: Gradient with respect to the gathered tensor.

  Returns:
    `(list_gradient, None, None)`. The list gradient is a list reserved with
    the input list's element shape and length, with `dtensor` scattered into
    it at `indices`.
  """
  src, idx, _ = op.inputs
  src_shape = gen_list_ops.tensor_list_element_shape(
      src, shape_type=dtypes.int32)
  src_len = gen_list_ops.tensor_list_length(src)
  grad_dtype = dtensor.dtype
  empty_grad = tensor_list_reserve(src_shape, src_len, grad_dtype)
  filled_grad = tensor_list_scatter(
      tensor=dtensor, indices=idx, input_handle=empty_grad)
  return filled_grad, None, None
예제 #5
0
def _TensorListGatherGrad(op, dtensor):
    """Gradient function for the two-input form of TensorListGather.

    Args:
        op: The forward TensorListGather op; its inputs unpack as exactly
            (input_list, indices).
        dtensor: Gradient with respect to the gathered tensor.

    Returns:
        `(list_gradient, None)`; the indices input receives no gradient.
    """
    forward_list, gather_indices = op.inputs
    # -1 is passed as the element shape (presumably "unknown shape" for list
    # ops -- confirm against the op definition).
    scattered = gen_list_ops.tensor_list_scatter(
        tensor=dtensor,
        indices=gather_indices,
        element_shape=ops.convert_to_tensor(-1, dtype=dtypes.int32))
    # TensorListScatter returns a list with size `max(indices) + 1`, so it is
    # resized to the length of the forward input list.
    target_size = gen_list_ops.tensor_list_length(forward_list)
    return gen_list_ops.tensor_list_resize(scattered, target_size), None
예제 #6
0
def _TensorListGetItemGrad(op, ditem):
  """Gradient for TensorListGetItem.

  Reserves a list with the forward input list's element shape and length,
  then writes `ditem` at the index that was read in the forward pass.

  Args:
    op: The forward TensorListGetItem op; `op.inputs[0]` is the list and
      `op.inputs[1]` the index.
    ditem: Gradient with respect to the retrieved item.

  Returns:
    `(list_gradient, None)`; the index input receives no gradient.
  """
  source_list = op.inputs[0]
  size = gen_list_ops.tensor_list_length(source_list)
  elem_shape = gen_list_ops.tensor_list_element_shape(
      source_list, shape_type=dtypes.int32)
  reserved = gen_list_ops.tensor_list_reserve(
      elem_shape, size, element_dtype=ditem.dtype)
  grad_list = gen_list_ops.tensor_list_set_item(
      reserved, index=op.inputs[1], item=ditem)
  return grad_list, None
def _TensorListGetItemGrad(op, ditem):
  """Gradient for TensorListGetItem.

  Args:
    op: The forward TensorListGetItem op (inputs: list, index).
    ditem: Gradient with respect to the item that was read.

  Returns:
    A pair whose first element is a list (reserved to match the input list's
    element shape and length) with `ditem` set at the read index. The second
    element is None, because the index is not differentiable.
  """
  list_in, read_index = op.inputs[0], op.inputs[1]
  n_elems = gen_list_ops.tensor_list_length(list_in)
  blank = gen_list_ops.tensor_list_reserve(
      gen_list_ops.tensor_list_element_shape(list_in, shape_type=dtypes.int32),
      n_elems,
      element_dtype=ditem.dtype)
  return (
      gen_list_ops.tensor_list_set_item(blank, index=read_index, item=ditem),
      None,
  )
예제 #8
0
def _TensorListGatherGrad(op, dtensor):
  """Gradient function for TensorListGather.

  Scatters `dtensor` into a new list at the gathered indices, then resizes
  that list to the forward input list's length.

  Args:
    op: The forward TensorListGather op; inputs unpack as
      (input_list, indices, <third input, unused here>).
    dtensor: Gradient with respect to the gathered tensor.

  Returns:
    `(list_gradient, None, None)`.
  """
  forward_list, gather_indices, _ = op.inputs
  # -1 is passed as the element shape (presumably "unknown"; confirm).
  unspecified_shape = ops.convert_to_tensor(-1, dtype=dtypes.int32)
  scattered = gen_list_ops.tensor_list_scatter(
      tensor=dtensor, indices=gather_indices, element_shape=unspecified_shape)
  # TensorListScatter yields a list of size `max(indices) + 1`; resize it so
  # it matches the size of the input list.
  wanted_len = gen_list_ops.tensor_list_length(forward_list)
  resized = gen_list_ops.tensor_list_resize(scattered, wanted_len)
  return resized, None, None
예제 #9
0
def tensor_list_set_item(input_handle,
                         index,
                         item,
                         resize_if_index_out_of_bounds=False,
                         name=None):
  """Sets `item` at `index` in input list.

  Args:
    input_handle: The list to update.
    index: Position at which `item` is written.
    item: Value to store.
    resize_if_index_out_of_bounds: If True, the list is first resized to
      `index + 1` whenever `index >= length(list)`, chosen via
      `control_flow_ops.cond`.
    name: Optional name forwarded to the SetItem op.

  Returns:
    The result of `gen_list_ops.tensor_list_set_item`.
  """
  handle = input_handle
  if resize_if_index_out_of_bounds:
    current_len = gen_list_ops.tensor_list_length(input_handle)

    def _grown():
      return gen_list_ops.tensor_list_resize(input_handle, index + 1)

    def _unchanged():
      return input_handle

    # TODO(srbs): This could cause some slowdown. Consider fusing resize
    # functionality in the SetItem op.
    handle = control_flow_ops.cond(index >= current_len, _grown, _unchanged)
  return gen_list_ops.tensor_list_set_item(
      input_handle=handle, index=index, item=item, name=name)
예제 #10
0
def tensor_list_set_item(input_handle,
                         index,
                         item,
                         resize_if_index_out_of_bounds=False,
                         name=None):
  """Sets `item` at `index` in input list.

  Args:
    input_handle: List to write into.
    index: Target position.
    item: Value to write.
    resize_if_index_out_of_bounds: When True, an out-of-range `index`
      (i.e. `index >= length`) first grows the list to `index + 1`.
    name: Optional op name for the SetItem op.

  Returns:
    The output of `gen_list_ops.tensor_list_set_item`.
  """
  if not resize_if_index_out_of_bounds:
    return gen_list_ops.tensor_list_set_item(
        input_handle=input_handle, index=index, item=item, name=name)

  list_len = gen_list_ops.tensor_list_length(input_handle)
  # TODO(srbs): This could cause some slowdown. Consider fusing resize
  # functionality in the SetItem op.
  maybe_grown = control_flow_ops.cond(
      index >= list_len,
      lambda: gen_list_ops.tensor_list_resize(input_handle, index + 1),
      lambda: input_handle)
  return gen_list_ops.tensor_list_set_item(
      input_handle=maybe_grown, index=index, item=item, name=name)
예제 #11
0
def _TensorListResizeGrad(op, dlist):
    """Gradient for a TensorList resize.

    Resizes the incoming list gradient back to the length of the forward
    input list. The second (size) input receives no gradient.
    """
    forward_list, _ = op.inputs
    original_length = gen_list_ops.tensor_list_length(forward_list)
    list_grad = gen_list_ops.tensor_list_resize(dlist, original_length)
    return list_grad, None
예제 #12
0
def _TensorListResizeGrad(op, dlist):
  """Gradient for a TensorList resize.

  Args:
    op: The forward op; its inputs unpack as exactly (input_list, size).
    dlist: Gradient list flowing back from the resized output.

  Returns:
    `(dlist resized to len(input_list), None)`.
  """
  source, _ = op.inputs
  return (
      gen_list_ops.tensor_list_resize(
          dlist, gen_list_ops.tensor_list_length(source)),
      None,
  )