def _TensorListGatherGrad(op, dtensor):
  """Gradient function for TensorListGather.

  Scatters `dtensor` into a new tensor list at the gathered `indices`, then
  resizes that list to the length of the list the forward gather read from.

  Args:
    op: The forward TensorListGather op. Its first two inputs are the source
      list handle and the gather indices. Any further inputs (some versions of
      the op take a third one) are treated as non-differentiable.
    dtensor: The gradient with respect to the op's output.

  Returns:
    A tuple with exactly one entry per input of `op`: the gradient list for
    the source list, followed by `None` for every remaining input.
  """
  # Slice instead of unpacking the whole input list, so the function works no
  # matter how many trailing inputs the op version carries.
  input_list, indices = op.inputs[:2]
  dlist = gen_list_ops.tensor_list_scatter(
      tensor=dtensor,
      indices=indices,
      element_shape=ops.convert_to_tensor(-1, dtype=dtypes.int32))
  # TensorListScatter returns a list with size `max(indices) + 1`
  # so we manually resize it to match the size of the input list.
  input_list_size = gen_list_ops.tensor_list_length(input_list)
  dlist = gen_list_ops.tensor_list_resize(dlist, input_list_size)
  # A gradient function must return one value per forward input; only the
  # list input is differentiable, so every other input gets None.
  return (dlist,) + (None,) * (len(op.inputs) - 1)
def _TensorListGatherGrad(op, dtensor):
  """Gradient function for TensorListGather.

  The incoming gradient is scattered back into a fresh tensor list at the
  positions that were gathered. That list is then resized to the length of
  the list the gather originally read from.

  Args:
    op: The forward TensorListGather op, expected to have exactly three
      inputs: the source list, the gather indices, and a third input that
      receives no gradient.
    dtensor: The gradient with respect to the op's output.

  Returns:
    A 3-tuple: the gradient list for the source list, then `None` for the
    indices and `None` for the third input.
  """
  source_list, gather_indices, _ = op.inputs
  scattered = gen_list_ops.tensor_list_scatter(
      tensor=dtensor,
      indices=gather_indices,
      element_shape=ops.convert_to_tensor(-1, dtype=dtypes.int32))
  # Scatter only yields `max(indices) + 1` elements, so grow/trim the result
  # to the length of the list that was gathered from.
  target_length = gen_list_ops.tensor_list_length(source_list)
  grad_list = gen_list_ops.tensor_list_resize(scattered, target_length)
  return grad_list, None, None
def tensor_list_set_item(input_handle,
                         index,
                         item,
                         resize_if_index_out_of_bounds=False,
                         name=None):
  """Sets `item` at position `index` of the tensor list `input_handle`.

  Args:
    input_handle: The tensor list to write into.
    index: Position at which `item` is stored.
    item: The value to store.
    resize_if_index_out_of_bounds: When True, the list is first resized to
      `index + 1` elements if `index` is greater than or equal to its current
      length; otherwise it is used unchanged.
    name: Optional name passed through to
      `gen_list_ops.tensor_list_set_item`.

  Returns:
    The list handle returned by `gen_list_ops.tensor_list_set_item`.
  """
  target = input_handle
  if resize_if_index_out_of_bounds:
    current_length = gen_list_ops.tensor_list_length(input_handle)

    def _grown():
      return gen_list_ops.tensor_list_resize(input_handle, index + 1)

    def _unchanged():
      return input_handle

    # TODO(srbs): This could cause some slowdown. Consider fusing resize
    # functionality in the SetItem op.
    target = control_flow_ops.cond(index >= current_length, _grown, _unchanged)
  return gen_list_ops.tensor_list_set_item(
      input_handle=target, index=index, item=item, name=name)
def tensor_list_set_item(input_handle,
                         index,
                         item,
                         resize_if_index_out_of_bounds=False,
                         name=None):
  """Sets `item` at position `index` of the tensor list `input_handle`.

  Args:
    input_handle: The tensor list to write into.
    index: Position at which `item` is stored.
    item: The value to store.
    resize_if_index_out_of_bounds: When True, the list is first resized to
      `index + 1` elements if `index` is greater than or equal to its current
      length; otherwise it is used unchanged.
    name: Optional name passed through to
      `gen_list_ops.tensor_list_set_item`.

  Returns:
    The list handle returned by `gen_list_ops.tensor_list_set_item`.
  """
  target = input_handle
  if resize_if_index_out_of_bounds:
    current_length = gen_list_ops.tensor_list_length(input_handle)

    def _grown():
      return gen_list_ops.tensor_list_resize(input_handle, index + 1)

    def _unchanged():
      return input_handle

    # TODO(srbs): This could cause some slowdown. Consider fusing resize
    # functionality in the SetItem op.
    target = control_flow_ops.cond(index >= current_length, _grown, _unchanged)
  return gen_list_ops.tensor_list_set_item(
      input_handle=target, index=index, item=item, name=name)
def _TensorListResizeGrad(op, dlist):
  """Gradient function for TensorListResize.

  Resizes the incoming gradient list back to the length of the list that was
  fed into the forward resize.

  Args:
    op: The forward TensorListResize op, expected to have exactly two inputs:
      the list being resized and the requested size.
    dlist: The gradient list with respect to the op's output.

  Returns:
    A pair: the resized gradient list for the input list, and `None` for the
    size input.
  """
  source_list, _ = op.inputs
  original_length = gen_list_ops.tensor_list_length(source_list)
  grad_list = gen_list_ops.tensor_list_resize(dlist, original_length)
  return grad_list, None
def _TensorListResizeGrad(op, dlist):
  """Gradient function for TensorListResize.

  Resizes the incoming gradient list back to the length of the list that was
  fed into the forward resize.

  Args:
    op: The forward TensorListResize op, expected to have exactly two inputs:
      the list being resized and the requested size.
    dlist: The gradient list with respect to the op's output.

  Returns:
    A pair: the resized gradient list for the input list, and `None` for the
    size input.
  """
  source_list, _ = op.inputs
  original_length = gen_list_ops.tensor_list_length(source_list)
  grad_list = gen_list_ops.tensor_list_resize(dlist, original_length)
  return grad_list, None