def tensor_list_reserve(element_shape, num_elements, element_dtype, name=None):
  """Calls the generated `tensor_list_reserve` op and attaches handle data.

  Args:
    element_shape: Element shape specification; it is normalized with
      `_build_element_shape` before being passed to the generated op, and
      passed as given to `_set_handle_data`.
    num_elements: Forwarded unchanged as the op's `num_elements` argument.
    element_dtype: Forwarded as the op's `element_dtype` argument and also
      passed to `_set_handle_data`.
    name: Optional name forwarded to the generated op.

  Returns:
    The value returned by `gen_list_ops.tensor_list_reserve`, after
    `_set_handle_data` has been called on it.
  """
  normalized_shape = _build_element_shape(element_shape)
  list_handle = gen_list_ops.tensor_list_reserve(
      element_shape=normalized_shape,
      num_elements=num_elements,
      element_dtype=element_dtype,
      name=name)
  # TODO(b/169968286): gen_ops needs to ensure the metadata is properly
  # populated for eager operations.
  _set_handle_data(list_handle, element_shape, element_dtype)
  return list_handle
def _TensorListGetItemGrad(op, ditem):
  """Gradient for TensorListGetItem.

  Reserves a new list with the same length and element shape as the op's
  first input, then writes `ditem` into it at the index given by the op's
  second input.

  Args:
    op: The forward op; `op.inputs[0]` is the list and `op.inputs[1]` is the
      index that was read.
    ditem: Incoming gradient with respect to the retrieved item.

  Returns:
    A `(list_grad, index_grad)` tuple, where `index_grad` is always None.
  """
  source_list = op.inputs[0]
  # Keep the op-creation order of the original nested expression:
  # length -> element_shape -> reserve -> set_item.
  num_items = gen_list_ops.tensor_list_length(source_list)
  item_shape = gen_list_ops.tensor_list_element_shape(
      source_list, shape_type=dtypes.int32)
  reserved_list = gen_list_ops.tensor_list_reserve(
      item_shape, num_items, element_dtype=ditem.dtype)
  list_grad = gen_list_ops.tensor_list_set_item(
      reserved_list, index=op.inputs[1], item=ditem)
  # No gradient is returned for the index input.
  return list_grad, None
def tensor_list_reserve(element_shape, num_elements, element_dtype, name=None):
  """Calls the generated `tensor_list_reserve` op with a normalized shape.

  NOTE(review): another definition of `tensor_list_reserve` in this file also
  calls `_set_handle_data` on the result; this one does not. Confirm which
  variant is intended to be live.

  Args:
    element_shape: Element shape specification; normalized with
      `_build_element_shape` before being passed to the generated op.
    num_elements: Forwarded unchanged as the op's `num_elements` argument.
    element_dtype: Forwarded unchanged as the op's `element_dtype` argument.
    name: Optional name forwarded to the generated op.

  Returns:
    The value returned by `gen_list_ops.tensor_list_reserve`.
  """
  normalized_shape = _build_element_shape(element_shape)
  return gen_list_ops.tensor_list_reserve(
      element_shape=normalized_shape,
      num_elements=num_elements,
      element_dtype=element_dtype,
      name=name)