Beispiel #1
0
def _PopBackGrad(op, dlist, delement):
    """Gradient for TensorListPopBack.

    The popped element's gradient `delement` is pushed back onto the list
    gradient. If no list gradient arrived (`dlist` is None), an empty list is
    created first. That list takes its dtype from `delement` and its element
    shape from `op.outputs[0]`.

    Returns:
      A pair: the list gradient, and None for the remaining op input.
    """
    if dlist is None:
        shape_of_elements = gen_list_ops.tensor_list_element_shape(
            op.outputs[0], shape_type=dtypes.int32)
        dlist = empty_tensor_list(
            element_dtype=delement.dtype, element_shape=shape_of_elements)
    grad_list = gen_list_ops.tensor_list_push_back(dlist, delement)
    return grad_list, None
Beispiel #2
0
def _PopBackGrad(op, dlist, delement):
  """Gradient for TensorListPopBack.

  Appends `delement` to the incoming list gradient `dlist`. When `dlist` is
  None, an empty list is created first, typed by `delement.dtype` and with
  the element shape of `op.outputs[0]`.

  Returns:
    The list gradient (a single value).
  """
  if dlist is None:
    elem_shape = gen_list_ops.tensor_list_element_shape(
        op.outputs[0], shape_type=dtypes.int32)
    dlist = gen_list_ops.empty_tensor_list(
        element_dtype=delement.dtype, element_shape=elem_shape)
  pushed = gen_list_ops.tensor_list_push_back(dlist, delement)
  return pushed
Beispiel #3
0
def _TensorListGatherGrad(op, dtensor):
    """Gradient function for TensorListGather.

    Scatters the incoming gradient `dtensor` into a new list at the gathered
    indices. The new list gets the element shape and length of the forward
    op's input list.

    Returns:
      (list gradient, None, None): no gradient flows to the other two inputs.
    """
    source_list, gather_indices, _ = op.inputs
    shape = gen_list_ops.tensor_list_element_shape(
        source_list, shape_type=dtypes.int32)
    length = gen_list_ops.tensor_list_length(source_list)
    scattered = gen_list_ops.tensor_list_scatter_v2(
        tensor=dtensor,
        indices=gather_indices,
        element_shape=shape,
        num_elements=length)
    return scattered, None, None
def _TensorListGatherGrad(op, dtensor):
  """Gradient function for TensorListGather.

  Reserves a list with the input list's element shape and length, typed like
  `dtensor`. It then scatters `dtensor` into that list at the gathered
  indices. The two trailing op inputs receive no gradient.
  """
  src, idx, _ = op.inputs
  elem_shape = gen_list_ops.tensor_list_element_shape(
      src, shape_type=dtypes.int32)
  size = gen_list_ops.tensor_list_length(src)
  reserved = tensor_list_reserve(elem_shape, size, dtensor.dtype)
  return (
      tensor_list_scatter(tensor=dtensor, indices=idx, input_handle=reserved),
      None,
      None,
  )
Beispiel #5
0
def _TensorListGatherGrad(op, dtensor):
  """Gradient function for TensorListGather.

  The gradient w.r.t. the list is `dtensor` scattered back at the gathered
  indices, into a list sized and shaped like the forward input list. The
  indices and the third input get None.
  """
  in_list, positions, _ = op.inputs
  target_shape = gen_list_ops.tensor_list_element_shape(
      in_list, shape_type=dtypes.int32)
  target_len = gen_list_ops.tensor_list_length(in_list)
  list_grad = gen_list_ops.tensor_list_scatter_v2(
      tensor=dtensor, indices=positions,
      element_shape=target_shape, num_elements=target_len)
  return list_grad, None, None
Beispiel #6
0
def _TensorListGatherGrad(op, dtensor):
  """Gradient function for TensorListGather.

  Scatters `dtensor` at the gathered indices into a freshly reserved list.
  The reserved list mirrors the input list's element shape and length and
  uses `dtensor`'s dtype.

  Returns:
    (list gradient, None, None).
  """
  origin, where, _ = op.inputs
  shape_t = gen_list_ops.tensor_list_element_shape(
      origin, shape_type=dtypes.int32)
  count_t = gen_list_ops.tensor_list_length(origin)
  empty_grad = tensor_list_reserve(shape_t, count_t, dtensor.dtype)
  filled_grad = tensor_list_scatter(
      tensor=dtensor, indices=where, input_handle=empty_grad)
  return filled_grad, None, None
Beispiel #7
0
def _TensorListConcatGrad(op, dtensor, unused_dlengths):
    """Gradient function for TensorListConcat.

    Splits `dtensor` back into a list. The split uses the per-element lengths
    that the forward op emitted as `op.outputs[1]` and the element shape of
    the input list.

    Returns:
      For "TensorListConcatV2", (list gradient, None, None); presumably the
      Nones cover that op's extra inputs. Otherwise, only the list gradient.
    """
    element_shape = gen_list_ops.tensor_list_element_shape(
        op.inputs[0], shape_type=dtypes.int32)
    split_list = tensor_list_split(
        dtensor, element_shape=element_shape, lengths=op.outputs[1])
    if op.type != "TensorListConcatV2":
        return split_list
    return split_list, None, None
Beispiel #8
0
def _TensorListConcatGrad(op, dtensor, unused_dlengths):
  """Gradient function for TensorListConcat.

  Reverses the concatenation by splitting `dtensor` with the lengths in
  `op.outputs[1]` and the element shape of `op.inputs[0]`. The gradient
  w.r.t. the lengths output (`unused_dlengths`) is ignored.
  """
  shape_of_items = gen_list_ops.tensor_list_element_shape(
      op.inputs[0], shape_type=dtypes.int32)
  list_grad = tensor_list_split(
      dtensor, element_shape=shape_of_items, lengths=op.outputs[1])
  # The V2 op variant expects two extra (None) gradient slots.
  is_v2 = op.type == "TensorListConcatV2"
  return (list_grad, None, None) if is_v2 else list_grad
Beispiel #9
0
def _TensorListGetItemGrad(op, ditem):
  """Gradient for TensorListGetItem.

  Reserves a list with the input list's length and element shape, typed like
  `ditem`. It then writes `ditem` at the looked-up index (`op.inputs[1]`).

  Returns:
    (list gradient, None): the index input receives no gradient.
  """
  source = op.inputs[0]
  num_items = gen_list_ops.tensor_list_length(source)
  item_shape = gen_list_ops.tensor_list_element_shape(
      source, shape_type=dtypes.int32)
  blank = gen_list_ops.tensor_list_reserve(
      item_shape, num_items, element_dtype=ditem.dtype)
  filled = gen_list_ops.tensor_list_set_item(
      blank, index=op.inputs[1], item=ditem)
  return filled, None
def _TensorListGetItemGrad(op, ditem):
  """Gradient for TensorListGetItem.

  The list gradient is a reserved list, matching the input list in length
  and element shape, with `ditem` stored at index `op.inputs[1]`. No
  gradient is produced for the index.
  """
  handle, position = op.inputs[0], op.inputs[1]
  length_t = gen_list_ops.tensor_list_length(handle)
  reserved = gen_list_ops.tensor_list_reserve(
      gen_list_ops.tensor_list_element_shape(handle, shape_type=dtypes.int32),
      length_t,
      element_dtype=ditem.dtype)
  return (
      gen_list_ops.tensor_list_set_item(reserved, index=position, item=ditem),
      None,
  )
def _TensorListFromTensorGrad(op, dlist):
    """Gradient for TensorListFromTensor.

    Stacks the incoming list gradient back into one tensor with the dtype of
    the forward op's input tensor (`op.inputs[0]`). If the input's leading
    dimension is statically known, it is passed as `num_elements`; otherwise
    `num_elements` is None. When no list gradient arrived (`dlist` is None),
    an empty list whose element shape matches `op.outputs[0]` is stacked
    instead.

    NOTE(review): returns a single gradient. If the registered op also has an
    element_shape input, a trailing None is required. Confirm the arity.
    """
    input_shape = op.inputs[0].shape
    # Under TF1 shape semantics `shape[0]` is a `Dimension` object, never
    # None, so an unknown leading dim would be forwarded as Dimension(None).
    # `.dims[0].value` yields a plain int or None. `.dims` is None for an
    # unknown rank and [] for rank 0; both fall through to None here.
    if input_shape.dims and input_shape.dims[0].value is not None:
        num_elements = input_shape.dims[0].value
    else:
        num_elements = None
    if dlist is None:
        dlist = gen_list_ops.empty_tensor_list(
            element_dtype=op.inputs[0].dtype,
            element_shape=gen_list_ops.tensor_list_element_shape(
                op.outputs[0], shape_type=dtypes.int32))
    return gen_list_ops.tensor_list_stack(dlist,
                                          element_dtype=op.inputs[0].dtype,
                                          num_elements=num_elements)
Beispiel #12
0
def _TensorListFromTensorGrad(op, dlist):
  """Gradient for TensorListFromTensor.

  Stacks the incoming list gradient back into a tensor with the dtype of
  `op.inputs[0]`. The static leading dimension of that input, when known, is
  passed as `num_elements`; otherwise None is passed. If `dlist` is None, an
  empty list with the element shape of `op.outputs[0]` is stacked instead.

  NOTE(review): returns a single gradient. If the registered op also has an
  element_shape input, a trailing None is required. Confirm the arity.
  """
  input_shape = op.inputs[0].shape
  # `shape[0]` is a `Dimension` (never None) under TF1 shape semantics, which
  # would let Dimension(None) leak into `num_elements`. Read the plain value
  # via `.dims`; this also covers unknown rank (dims is None).
  if input_shape.dims and input_shape.dims[0].value is not None:
    num_elements = input_shape.dims[0].value
  else:
    num_elements = None
  if dlist is None:
    dlist = gen_list_ops.empty_tensor_list(
        element_dtype=op.inputs[0].dtype,
        element_shape=gen_list_ops.tensor_list_element_shape(
            op.outputs[0], shape_type=dtypes.int32))
  return gen_list_ops.tensor_list_stack(
      dlist, element_dtype=op.inputs[0].dtype,
      num_elements=num_elements)
Beispiel #13
0
def _TensorListFromTensorGrad(op, dlist):
    """Gradient for TensorListFromTensor.

    Stacks the list gradient into a tensor with the input tensor's dtype. The
    statically known leading dimension, if any, is forwarded as
    `num_elements`. A missing list gradient (`dlist` is None) is replaced by
    an empty list with the element shape of `op.outputs[0]`.

    Returns:
      (tensor gradient, None): no gradient for the shape input.
    """
    source = op.inputs[0]
    dims = source.shape.dims
    num_elements = None
    if dims and dims[0].value is not None:
        num_elements = dims[0].value
    if dlist is None:
        elem_shape = gen_list_ops.tensor_list_element_shape(
            op.outputs[0], shape_type=dtypes.int32)
        dlist = empty_tensor_list(
            element_dtype=source.dtype, element_shape=elem_shape)
    stacked = gen_list_ops.tensor_list_stack(
        dlist, element_dtype=source.dtype, num_elements=num_elements)
    return stacked, None
Beispiel #14
0
def _TensorListFromTensorGrad(op, dlist):
  """Gradient for TensorListFromTensor.

  The tensor gradient is the list gradient stacked back, with the dtype of
  `op.inputs[0]`. `num_elements` is that input's static leading dimension,
  or None when unknown. If `dlist` is None, an empty list (element shape
  taken from `op.outputs[0]`) is stacked instead. The shape input receives
  None.
  """
  forward_input = op.inputs[0]
  static_dims = forward_input.shape.dims
  has_static_len = bool(static_dims) and static_dims[0].value is not None
  count = static_dims[0].value if has_static_len else None
  if dlist is None:
    dlist = empty_tensor_list(
        element_dtype=forward_input.dtype,
        element_shape=gen_list_ops.tensor_list_element_shape(
            op.outputs[0], shape_type=dtypes.int32))
  return (
      gen_list_ops.tensor_list_stack(
          dlist, element_dtype=forward_input.dtype, num_elements=count),
      None,
  )