def tensor_list_scatter(tensor, indices, element_shape, name=None):
    """Creates a new TensorList by scattering `tensor` at `indices`.

    Thin wrapper over the `TensorListScatterV2` op.

    Args:
      tensor: Value whose slices are scattered into the list.
      indices: Positions in the new list that receive those slices.
      element_shape: Element shape, normalized through
        `_build_element_shape` before being handed to the op.
      name: Optional op name.

    Returns:
      The handle produced by `tensor_list_scatter_v2`.
    """
    shape_arg = _build_element_shape(element_shape)
    # NOTE(review): -1 presumably asks the op to infer the list length from
    # `indices`; confirm against the TensorListScatterV2 op docs.
    return gen_list_ops.tensor_list_scatter_v2(
        tensor=tensor,
        indices=indices,
        element_shape=shape_arg,
        num_elements=-1,
        name=name,
    )
def tensor_list_scatter(tensor, indices, element_shape, name=None):
    """Creates a new TensorList by scattering `tensor` at `indices`.

    Thin wrapper over the `TensorListScatterV2` op.

    Args:
      tensor: Value whose slices are scattered into the list.
      indices: Positions in the new list that receive those slices.
      element_shape: Element shape, normalized through
        `_build_element_shape` before being handed to the op.
      name: Optional op name.

    Returns:
      The handle produced by `tensor_list_scatter_v2`.
    """
    shape_arg = _build_element_shape(element_shape)
    # NOTE(review): -1 presumably asks the op to infer the list length from
    # `indices`; confirm against the TensorListScatterV2 op docs.
    return gen_list_ops.tensor_list_scatter_v2(
        tensor=tensor,
        indices=indices,
        element_shape=shape_arg,
        num_elements=-1,
        name=name,
    )
def _TensorListGatherGrad(op, dtensor):
    """Gradient function for TensorListGather.

    Scatters the incoming gradient back into a list that has the same length
    and element shape as the gathered-from input list.

    Args:
      op: The forward op; its three inputs are unpacked as
        (input list, indices, <unused third input>).
      dtensor: Gradient flowing into the gathered output.

    Returns:
      A 3-tuple: the gradient list for the input list, then `None` for the
      indices input and for the third input (presumably the element-shape
      input -- confirm against the op definition).
    """
    source_list, gather_indices, _ = op.inputs
    # Query shape before length to keep the original op-creation order.
    list_elem_shape = gen_list_ops.tensor_list_element_shape(
        source_list, shape_type=dtypes.int32)
    list_len = gen_list_ops.tensor_list_length(source_list)
    grad_list = gen_list_ops.tensor_list_scatter_v2(
        tensor=dtensor,
        indices=gather_indices,
        element_shape=list_elem_shape,
        num_elements=list_len,
    )
    return grad_list, None, None
def _TensorListGatherGrad(op, dtensor):
    """Gradient function for TensorListGather.

    Scatters the incoming gradient back into a list that has the same length
    and element shape as the gathered-from input list.

    Args:
      op: The forward op; its three inputs are unpacked as
        (input list, indices, <unused third input>).
      dtensor: Gradient flowing into the gathered output.

    Returns:
      A 3-tuple: the gradient list for the input list, then `None` for the
      indices input and for the third input (presumably the element-shape
      input -- confirm against the op definition).
    """
    source_list, gather_indices, _ = op.inputs
    # Query shape before length to keep the original op-creation order.
    list_elem_shape = gen_list_ops.tensor_list_element_shape(
        source_list, shape_type=dtypes.int32)
    list_len = gen_list_ops.tensor_list_length(source_list)
    grad_list = gen_list_ops.tensor_list_scatter_v2(
        tensor=dtensor,
        indices=gather_indices,
        element_shape=list_elem_shape,
        num_elements=list_len,
    )
    return grad_list, None, None
def tensor_list_scatter(tensor,
                        indices,
                        element_shape=None,
                        input_handle=None,
                        name=None):
    """Scatters `tensor` at `indices` into a new or an existing TensorList.

    Args:
      tensor: Value whose slices are scattered.
      indices: Target positions for those slices.
      element_shape: Element shape for a newly created list; only consulted
        when `input_handle` is None.
      input_handle: Optional existing list handle to scatter into.
      name: Optional op name.

    Returns:
      The list handle produced by the underlying scatter op.
    """
    if input_handle is None:
        # No existing list supplied: build a fresh one via TensorListScatterV2.
        return gen_list_ops.tensor_list_scatter_v2(
            tensor=tensor,
            indices=indices,
            element_shape=_build_element_shape(element_shape),
            num_elements=-1,
            name=name,
        )
    # Existing list supplied: `element_shape` is ignored on this path.
    return gen_list_ops.tensor_list_scatter_into_existing_list(
        input_handle=input_handle,
        tensor=tensor,
        indices=indices,
        name=name,
    )
def tensor_list_scatter(tensor,
                        indices,
                        element_shape=None,
                        input_handle=None,
                        name=None):
    """Scatters `tensor` at `indices` into a new or an existing TensorList.

    Args:
      tensor: Value whose slices are scattered.
      indices: Target positions for those slices.
      element_shape: Element shape for a newly created list; only consulted
        when `input_handle` is None.
      input_handle: Optional existing list handle to scatter into.
      name: Optional op name.

    Returns:
      The list handle produced by the underlying scatter op.
    """
    if input_handle is None:
        # No existing list supplied: build a fresh one via TensorListScatterV2.
        return gen_list_ops.tensor_list_scatter_v2(
            tensor=tensor,
            indices=indices,
            element_shape=_build_element_shape(element_shape),
            num_elements=-1,
            name=name,
        )
    # Existing list supplied: `element_shape` is ignored on this path.
    return gen_list_ops.tensor_list_scatter_into_existing_list(
        input_handle=input_handle,
        tensor=tensor,
        indices=indices,
        name=name,
    )
def tensor_list_scatter(tensor,
                        indices,
                        element_shape=None,
                        input_handle=None,
                        name=None):
    """Returns a TensorList made by scattering `tensor` at `indices`.

    When `input_handle` is given, the scatter updates that existing list and
    the result inherits its handle data. Otherwise a new list is created and
    its handle data is set from `element_shape` and the tensor's dtype.

    Args:
      tensor: Value (converted with `ops.convert_to_tensor`) to scatter.
      indices: Target positions for the slices of `tensor`.
      element_shape: Element shape for a newly created list; unused when
        `input_handle` is provided.
      input_handle: Optional existing list handle to scatter into.
      name: Optional op name.

    Returns:
      The resulting list handle.
    """
    tensor = ops.convert_to_tensor(tensor)
    if input_handle is None:
        new_list = gen_list_ops.tensor_list_scatter_v2(
            tensor=tensor,
            indices=indices,
            element_shape=_build_element_shape(element_shape),
            num_elements=-1,
            name=name,
        )
        # Record element shape/dtype on the fresh handle.
        _set_handle_data(new_list, element_shape, tensor.dtype)
        return new_list
    updated_list = gen_list_ops.tensor_list_scatter_into_existing_list(
        input_handle=input_handle,
        tensor=tensor,
        indices=indices,
        name=name,
    )
    # Carry the existing list's handle data over to the output handle.
    handle_data_util.copy_handle_data(input_handle, updated_list)
    return updated_list