def _TensorListSetItemGrad(op, dlist):
  """Gradient function for TensorListSetItem.

  Args:
    op: The forward op. Its inputs are unpacked as `(list, index, item)`.
    dlist: Gradient with respect to the op's output list.

  Returns:
    A tuple `(list_grad, index_grad, element_grad)`:
      * `list_grad` is `dlist` with the slot at `index` replaced by zeros
        shaped like `item`.
      * `index_grad` is always `None`.
      * `element_grad` is the element read from `dlist` at `index`.
  """
  _unused_list, index, item = op.inputs

  # The forward op overwrote this slot, so the slot's gradient is zeroed out
  # in the list gradient.
  zero_item = array_ops.zeros_like(item)
  list_grad = gen_list_ops.tensor_list_set_item(
      dlist, index=index, item=zero_item)

  # The written item receives whatever gradient sits in that slot of `dlist`.
  element_grad = gen_list_ops.tensor_list_get_item(
      dlist, index, element_dtype=item.dtype)

  return list_grad, None, element_grad
def _TensorListSetItemGrad(op, dlist):
  """Computes gradients for the inputs of TensorListSetItem.

  Args:
    op: The forward op, whose inputs unpack to `(list, index, item)`.
    dlist: Incoming gradient for the output list of `op`.

  Returns:
    `(list_grad, index_grad, element_grad)`, where `list_grad` is `dlist`
    with zeros (shaped like `item`) written at `index`, `index_grad` is
    `None`, and `element_grad` is the entry of `dlist` at `index`.
  """
  _forward_list, index, item = op.inputs

  # Zero the overwritten slot in the gradient passed back to the input list.
  zeros_for_slot = array_ops.zeros_like(item)
  grad_wrt_list = gen_list_ops.tensor_list_set_item(
      dlist, index=index, item=zeros_for_slot)

  # Read the gradient for the item out of the same slot of `dlist`.
  grad_wrt_item = gen_list_ops.tensor_list_get_item(
      dlist, index, element_dtype=item.dtype)

  grad_wrt_index = None
  return grad_wrt_list, grad_wrt_index, grad_wrt_item
def _TensorListSetItemGrad(op, dlist):
  """Gradient function for TensorListSetItem.

  Args:
    op: The forward op. Its inputs are unpacked as `(list, index, item)`.
    dlist: Gradient with respect to the op's output list.

  Returns:
    A tuple `(list_grad, index_grad, element_grad)`:
      * `list_grad` is `dlist` with the slot at `index` replaced by zeros
        shaped like `item`.
      * `index_grad` is always `None`.
      * `element_grad` is the element read from `dlist` at `index`, requested
        with the dynamic shape and dtype of `item`.
  """
  _unused_list, index, item = op.inputs

  # The forward op overwrote this slot, so the slot's gradient is zeroed out
  # in the list gradient.
  zero_item = array_ops.zeros_like(item)
  list_grad = gen_list_ops.tensor_list_set_item(
      dlist, index=index, item=zero_item)

  # Pass the item's shape explicitly when reading its gradient from `dlist`.
  item_shape = array_ops.shape(item)
  element_grad = tensor_list_get_item(
      dlist, index, element_shape=item_shape, element_dtype=item.dtype)

  return list_grad, None, element_grad
def _TensorListGetItemGrad(op, ditem):
  """Gradient for TensorListGetItem.

  Builds a new list reserved to the same length as the forward input list,
  using that list's element shape (as int32) and the dtype of `ditem`, and
  writes `ditem` into it at the index used by the forward op.

  Args:
    op: The forward op; `op.inputs[0]` is used as the list and
      `op.inputs[1]` as the index.
    ditem: Gradient with respect to the item produced by the forward op.

  Returns:
    A tuple `(list_grad, index_grad)`; `index_grad` is always `None`.
  """
  forward_list = op.inputs[0]
  forward_index = op.inputs[1]

  list_size = gen_list_ops.tensor_list_length(forward_list)
  elem_shape = gen_list_ops.tensor_list_element_shape(
      forward_list, shape_type=dtypes.int32)

  # Reserve a list matching the forward list's size and element shape; only
  # the slot that was read gets `ditem` written into it.
  reserved = gen_list_ops.tensor_list_reserve(
      elem_shape, list_size, element_dtype=ditem.dtype)
  list_grad = gen_list_ops.tensor_list_set_item(
      reserved, index=forward_index, item=ditem)

  return list_grad, None
def _TensorListGetItemGrad(op, ditem):
  """Computes gradients for the inputs of TensorListGetItem.

  The list gradient is a freshly reserved list with the same length and
  element shape (queried as int32) as the forward input list, with `ditem`
  written at the index the forward op read from.

  Args:
    op: The forward op; `op.inputs[0]` is used as the list and
      `op.inputs[1]` as the index.
    ditem: Incoming gradient for the item returned by the forward op.

  Returns:
    `(list_grad, index_grad)`, with `index_grad` always `None`.
  """
  src_list, src_index = op.inputs[0], op.inputs[1]

  num_elements = gen_list_ops.tensor_list_length(src_list)
  element_shape = gen_list_ops.tensor_list_element_shape(
      src_list, shape_type=dtypes.int32)
  grad_buffer = gen_list_ops.tensor_list_reserve(
      element_shape, num_elements, element_dtype=ditem.dtype)

  # Place the incoming gradient in the slot that the forward op read.
  grad_wrt_list = gen_list_ops.tensor_list_set_item(
      grad_buffer, index=src_index, item=ditem)

  grad_wrt_index = None
  return grad_wrt_list, grad_wrt_index
def _TensorListSetItemGrad(op, dlist):
  """Computes gradients for the inputs of TensorListSetItem.

  Args:
    op: The forward op, whose inputs unpack to `(list, index, item)`.
    dlist: Incoming gradient for the output list of `op`.

  Returns:
    `(list_grad, index_grad, element_grad)`, where `list_grad` is `dlist`
    with zeros (shaped like `item`) written at `index`, `index_grad` is
    `None`, and `element_grad` is the entry of `dlist` at `index`, read with
    the dynamic shape and dtype of `item`.
  """
  _forward_list, index, item = op.inputs

  # Zero the overwritten slot in the gradient passed back to the input list.
  zeros_for_slot = array_ops.zeros_like(item)
  grad_wrt_list = gen_list_ops.tensor_list_set_item(
      dlist, index=index, item=zeros_for_slot)

  # Read the gradient for the item from the same slot, giving its shape
  # explicitly.
  shape_of_item = array_ops.shape(item)
  grad_wrt_item = tensor_list_get_item(
      dlist,
      index,
      element_shape=shape_of_item,
      element_dtype=item.dtype)

  grad_wrt_index = None
  return grad_wrt_list, grad_wrt_index, grad_wrt_item
def tensor_list_set_item(input_handle,
                         index,
                         item,
                         resize_if_index_out_of_bounds=False,
                         name=None):
  """Sets `item` at `index` in input list.

  Args:
    input_handle: Handle of the list to write into.
    index: Position at which `item` is written.
    item: Value to store at `index`.
    resize_if_index_out_of_bounds: If true, the list is first resized to
      `index + 1` whenever `index` is greater than or equal to the current
      list length. Otherwise the list is passed through unchanged.
    name: Optional name forwarded to the underlying set-item op.

  Returns:
    The result of `gen_list_ops.tensor_list_set_item` applied to the
    (possibly resized) list.
  """
  target_handle = input_handle

  if resize_if_index_out_of_bounds:
    current_size = gen_list_ops.tensor_list_length(input_handle)

    def _resized():
      # Grow the list just enough for `index` to become a valid position.
      return gen_list_ops.tensor_list_resize(input_handle, index + 1)

    def _unchanged():
      return input_handle

    # TODO(srbs): This could cause some slowdown. Consider fusing resize
    # functionality in the SetItem op.
    target_handle = control_flow_ops.cond(
        index >= current_size, _resized, _unchanged)

  return gen_list_ops.tensor_list_set_item(
      input_handle=target_handle, index=index, item=item, name=name)
def tensor_list_set_item(input_handle,
                         index,
                         item,
                         resize_if_index_out_of_bounds=False,
                         name=None):
  """Sets `item` at `index` in input list.

  When `resize_if_index_out_of_bounds` is true, a conditional is added that
  resizes the list to `index + 1` if `index` is not smaller than the current
  list length; the write is then applied to the conditional's result.

  Args:
    input_handle: Handle of the list to write into.
    index: Position at which `item` is written.
    item: Value to store at `index`.
    resize_if_index_out_of_bounds: Whether to grow the list before writing
      when `index` falls at or past the current length.
    name: Optional name forwarded to the underlying set-item op.

  Returns:
    The output of `gen_list_ops.tensor_list_set_item`.
  """
  if not resize_if_index_out_of_bounds:
    return gen_list_ops.tensor_list_set_item(
        input_handle=input_handle, index=index, item=item, name=name)

  list_length = gen_list_ops.tensor_list_length(input_handle)

  def _grow_to_fit():
    # Resize so that `index` is the last valid position.
    return gen_list_ops.tensor_list_resize(input_handle, index + 1)

  def _keep_as_is():
    return input_handle

  # TODO(srbs): This could cause some slowdown. Consider fusing resize
  # functionality in the SetItem op.
  maybe_resized = control_flow_ops.cond(
      index >= list_length, _grow_to_fit, _keep_as_is)

  return gen_list_ops.tensor_list_set_item(
      input_handle=maybe_resized, index=index, item=item, name=name)