Пример #1
0
def _TensorListGatherGrad(op, dtensor):
    """Gradient function for TensorListGather (two-input form).

    Scatters the incoming gradient back into a list at the positions that
    were gathered, then resizes that list so its length matches the forward
    input list.

    Args:
      op: The forward gather op. Its inputs must unpack into exactly two
        values, ``(input_list, indices)``.
      dtensor: The gradient flowing into the gathered tensor.

    Returns:
      ``(dlist, None)``: the gradient for the input list, and no gradient
      for ``indices``.
    """
    forward_list, gather_indices = op.inputs
    # -1 is passed as the element shape; presumably this leaves the element
    # shape unspecified -- confirm against the TensorListScatter op docs.
    unknown_element_shape = ops.convert_to_tensor(-1, dtype=dtypes.int32)
    scattered = gen_list_ops.tensor_list_scatter(
        tensor=dtensor,
        indices=gather_indices,
        element_shape=unknown_element_shape)
    # The scattered list has `max(indices) + 1` elements, which may be
    # shorter than the forward input list, so resize it to that length.
    forward_length = gen_list_ops.tensor_list_length(forward_list)
    resized = gen_list_ops.tensor_list_resize(scattered, forward_length)
    return resized, None
Пример #2
0
def _TensorListGatherGrad(op, dtensor):
  """Gradient function for TensorListGather.

  The incoming gradient is scattered into a new list at the gathered
  indices. That list is then resized to the length of the forward input
  list.

  Args:
    op: The forward gather op. Its inputs must unpack into exactly three
      values; the first two are the input list and the indices. The third
      receives no gradient.
    dtensor: The gradient with respect to the gathered tensor.

  Returns:
    ``(dlist, None, None)``: the gradient for the input list, and no
    gradient for the other two inputs.
  """
  source_list, gathered_at, _unused_third_input = op.inputs
  # NOTE(review): -1 appears to request an unspecified element shape;
  # confirm against the TensorListScatter op documentation.
  shape_hint = ops.convert_to_tensor(-1, dtype=dtypes.int32)
  grad_list = gen_list_ops.tensor_list_scatter(
      tensor=dtensor, indices=gathered_at, element_shape=shape_hint)
  # Scatter yields `max(indices) + 1` elements; grow or shrink the result
  # to match the forward input list's length.
  grad_list = gen_list_ops.tensor_list_resize(
      grad_list, gen_list_ops.tensor_list_length(source_list))
  return (grad_list, None, None)
Пример #3
0
def tensor_list_set_item(input_handle,
                         index,
                         item,
                         resize_if_index_out_of_bounds=False,
                         name=None):
  """Sets `item` at `index` in input list.

  Args:
    input_handle: The list to write into.
    index: Position at which `item` is stored.
    item: The value to store.
    resize_if_index_out_of_bounds: If True, the list is first resized to
      `index + 1` elements whenever `index` is not less than its current
      length. Otherwise the list is used as-is.
    name: Optional name forwarded to the underlying set-item op.

  Returns:
    The result of `gen_list_ops.tensor_list_set_item`; presumably the
    updated list handle.
  """
  target_handle = input_handle
  if resize_if_index_out_of_bounds:
    current_length = gen_list_ops.tensor_list_length(input_handle)

    def _grown_handle():
      # Extend the list just far enough that `index` becomes valid.
      return gen_list_ops.tensor_list_resize(input_handle, index + 1)

    def _unchanged_handle():
      return input_handle

    # TODO(srbs): This could cause some slowdown. Consider fusing resize
    # functionality in the SetItem op.
    target_handle = control_flow_ops.cond(
        index >= current_length, _grown_handle, _unchanged_handle)
  return gen_list_ops.tensor_list_set_item(
      input_handle=target_handle, index=index, item=item, name=name)
Пример #4
0
def tensor_list_set_item(input_handle,
                         index,
                         item,
                         resize_if_index_out_of_bounds=False,
                         name=None):
  """Sets `item` at `index` in input list.

  When `resize_if_index_out_of_bounds` is set, the list length is compared
  with `index`. If `index` is at or past the end, the list is resized to
  `index + 1` elements before the write. Otherwise the original handle is
  written to directly.

  Args:
    input_handle: The list to modify.
    index: Target position for `item`.
    item: Value to write.
    resize_if_index_out_of_bounds: Whether to grow the list on demand.
    name: Optional name passed through to the set-item op.

  Returns:
    Whatever `gen_list_ops.tensor_list_set_item` returns (presumably the
    updated list handle).
  """
  handle = input_handle
  if resize_if_index_out_of_bounds:
    length = gen_list_ops.tensor_list_length(handle)
    index_past_end = index >= length
    # TODO(srbs): This could cause some slowdown. Consider fusing resize
    # functionality in the SetItem op.
    handle = control_flow_ops.cond(
        index_past_end,
        lambda h=handle: gen_list_ops.tensor_list_resize(h, index + 1),
        lambda h=handle: h)
  return gen_list_ops.tensor_list_set_item(
      input_handle=handle, index=index, item=item, name=name)
Пример #5
0
def _TensorListResizeGrad(op, dlist):
    """Gradient function for TensorListResize.

    The gradient list is resized back to the length of the list that was
    fed into the forward resize.

    Args:
      op: The forward resize op. Its inputs must unpack into exactly two
        values; the first is the input list.
      dlist: The gradient with respect to the resized list.

    Returns:
      A tuple of the resized gradient list and ``None``; the second input
      receives no gradient.
    """
    forward_list, _ = op.inputs
    original_length = gen_list_ops.tensor_list_length(forward_list)
    dinput = gen_list_ops.tensor_list_resize(dlist, original_length)
    return dinput, None
Пример #6
0
def _TensorListResizeGrad(op, dlist):
  """Gradient function for TensorListResize.

  Args:
    op: The forward resize op; its inputs unpack into two values, the
      first being the list that was resized.
    dlist: Incoming gradient for the resized list.

  Returns:
    ``(gradient_list, None)``, where `gradient_list` is `dlist` resized to
    the forward input list's length.
  """
  (source_list, _size_input) = op.inputs
  return (
      gen_list_ops.tensor_list_resize(
          dlist, gen_list_ops.tensor_list_length(source_list)),
      None,
  )