コード例 #1
0
ファイル: list_ops.py プロジェクト: becster/tensorflow
def _TensorListSetItemGrad(op, dlist):
  _, index, item = op.inputs
  list_grad = gen_list_ops.tensor_list_set_item(
      dlist, index=index, item=array_ops.zeros_like(item))
  index_grad = None
  element_grad = gen_list_ops.tensor_list_get_item(
      dlist, index, element_dtype=item.dtype)
  return list_grad, index_grad, element_grad
コード例 #2
0
def _TensorListSetItemGrad(op, dlist):
  _, index, item = op.inputs
  list_grad = gen_list_ops.tensor_list_set_item(
      dlist, index=index, item=array_ops.zeros_like(item))
  index_grad = None
  element_grad = gen_list_ops.tensor_list_get_item(
      dlist, index, element_dtype=item.dtype)
  return list_grad, index_grad, element_grad
コード例 #3
0
def _TensorListSetItemGrad(op, dlist):
    """Gradient function for TensorListSetItem."""
    _, index, item = op.inputs
    list_grad = gen_list_ops.tensor_list_set_item(
        dlist, index=index, item=array_ops.zeros_like(item))
    index_grad = None
    element_grad = tensor_list_get_item(dlist,
                                        index,
                                        element_shape=array_ops.shape(item),
                                        element_dtype=item.dtype)
    return list_grad, index_grad, element_grad
コード例 #4
0
ファイル: list_ops.py プロジェクト: becster/tensorflow
def _TensorListGetItemGrad(op, ditem):
  """Gradient for TensorListGetItem."""
  list_size = gen_list_ops.tensor_list_length(op.inputs[0])
  list_grad = gen_list_ops.tensor_list_set_item(
      gen_list_ops.tensor_list_reserve(
          gen_list_ops.tensor_list_element_shape(op.inputs[0],
                                                 shape_type=dtypes.int32),
          list_size, element_dtype=ditem.dtype),
      index=op.inputs[1],
      item=ditem)
  index_grad = None
  return list_grad, index_grad
コード例 #5
0
def _TensorListGetItemGrad(op, ditem):
  """Gradient for TensorListGetItem."""
  list_size = gen_list_ops.tensor_list_length(op.inputs[0])
  list_grad = gen_list_ops.tensor_list_set_item(
      gen_list_ops.tensor_list_reserve(
          gen_list_ops.tensor_list_element_shape(op.inputs[0],
                                                 shape_type=dtypes.int32),
          list_size, element_dtype=ditem.dtype),
      index=op.inputs[1],
      item=ditem)
  index_grad = None
  return list_grad, index_grad
コード例 #6
0
ファイル: list_ops.py プロジェクト: adit-chandra/tensorflow
def _TensorListSetItemGrad(op, dlist):
  """Gradient function for TensorListSetItem."""
  _, index, item = op.inputs
  list_grad = gen_list_ops.tensor_list_set_item(
      dlist, index=index, item=array_ops.zeros_like(item))
  index_grad = None
  element_grad = tensor_list_get_item(
      dlist,
      index,
      element_shape=array_ops.shape(item),
      element_dtype=item.dtype)
  return list_grad, index_grad, element_grad
コード例 #7
0
ファイル: list_ops.py プロジェクト: datanonymous/TFandroid
def tensor_list_set_item(input_handle,
                         index,
                         item,
                         resize_if_index_out_of_bounds=False,
                         name=None):
  """Sets `item` at `index` in input list."""
  if resize_if_index_out_of_bounds:
    input_list_size = gen_list_ops.tensor_list_length(input_handle)
    # TODO(srbs): This could cause some slowdown. Consider fusing resize
    # functionality in the SetItem op.
    input_handle = control_flow_ops.cond(
        index >= input_list_size,
        lambda: gen_list_ops.tensor_list_resize(  # pylint: disable=g-long-lambda
            input_handle, index + 1),
        lambda: input_handle)
  return gen_list_ops.tensor_list_set_item(
      input_handle=input_handle, index=index, item=item, name=name)
コード例 #8
0
ファイル: list_ops.py プロジェクト: terrytangyuan/tensorflow
def tensor_list_set_item(input_handle,
                         index,
                         item,
                         resize_if_index_out_of_bounds=False,
                         name=None):
  """Sets `item` at `index` in input list."""
  if resize_if_index_out_of_bounds:
    input_list_size = gen_list_ops.tensor_list_length(input_handle)
    # TODO(srbs): This could cause some slowdown. Consider fusing resize
    # functionality in the SetItem op.
    input_handle = control_flow_ops.cond(
        index >= input_list_size,
        lambda: gen_list_ops.tensor_list_resize(  # pylint: disable=g-long-lambda
            input_handle, index + 1),
        lambda: input_handle)
  return gen_list_ops.tensor_list_set_item(
      input_handle=input_handle, index=index, item=item, name=name)