def get_tensors_from_tensor_names(graph, tensor_names):
  """Looks up the Tensor objects named by `tensor_names` in `graph`.

  Args:
    graph: TensorFlow Graph.
    tensor_names: List of strings that represent names of tensors in the graph.

  Returns:
    A list of Tensor objects, in the same order as `tensor_names`.

  Raises:
    ValueError: tensor_names contains an invalid tensor name.
  """
  # Index every output tensor of every op in the graph by its name.
  lookup = {}
  for operation in graph.get_operations():
    for candidate in operation.values():
      lookup[tensor_name(candidate)] = candidate

  # Resolve each requested name, remembering the ones that do not exist so
  # they can all be reported together.
  found = []
  missing = []
  for requested in tensor_names:
    match = lookup.get(requested)
    if match is not None:
      found.append(match)
    else:
      missing.append(requested)

  if missing:
    raise ValueError("Invalid tensors '{}' were found.".format(
        ",".join(missing)))
  return found
def set_tensor_shapes(tensors, shapes):
  """Validates `shapes` against `tensors` and applies each defined shape.

  Args:
    tensors: TensorFlow ops.Tensor objects to update.
    shapes: Dict mapping tensor names (strings) to lists of integers
      representing input shapes (e.g., {"foo": [1, 16, 16, 3]}). An entry
      whose shape is None is still checked by name but otherwise skipped.

  Raises:
    ValueError: `shapes` names a tensor that is not among `tensors`, or
      `Tensor.set_shape` rejects a shape given for a valid tensor.
  """
  if not shapes:
    return

  by_name = {tensor_name(t): t for t in tensors}
  for name, shape in shapes.items():
    if name not in by_name:
      raise ValueError("Invalid tensor \'{}\' found in tensor shapes "
                       "map.".format(name))
    if shape is None:
      continue
    target = by_name[name]
    try:
      target.set_shape(shape)
    except ValueError as error:
      # Re-raise with the tensor name and both shapes so the user can tell
      # which entry of `shapes` was incompatible.
      current = target.get_shape()
      raise ValueError(
          "The shape of tensor '{0}' cannot be changed from {1} to "
          "{2}. {3}".format(name, current, shape, str(error)))
def testTensorName(self):
  """Checks convert.tensor_name on every output of a 4-way split."""
  in_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.float32)
  # out_tensors should have names: "split:0", "split:1", "split:2", "split:3".
  out_tensors = array_ops.split(
      value=in_tensor, num_or_size_splits=[1, 1, 1, 1], axis=0)
  # convert.tensor_name is expected to drop the ":0" suffix of the first
  # output while keeping the explicit index on the others.
  expect_names = ["split", "split:1", "split:2", "split:3"]
  # Check the count first so that zip() below cannot silently skip outputs.
  self.assertEqual(len(out_tensors), len(expect_names))
  for out_tensor, expect_name in zip(out_tensors, expect_names):
    got_name = convert.tensor_name(out_tensor)
    self.assertEqual(got_name, expect_name)
def set_tensor_shapes(tensors, shapes):
  """Applies user-supplied shapes to the matching tensors.

  A tensor is left untouched when its name has no entry in `shapes` or the
  entry is None. Names in `shapes` that match none of `tensors` are ignored.
  Errors raised by `Tensor.set_shape` propagate unchanged.

  Args:
    tensors: TensorFlow ops.Tensor objects to update.
    shapes: Dict mapping tensor names (strings) to lists of integers
      representing input shapes (e.g., {"foo": [1, 16, 16, 3]}).
  """
  if not shapes:
    return

  for tensor in tensors:
    new_shape = shapes.get(tensor_name(tensor))
    if new_shape is None:
      continue
    tensor.set_shape(new_shape)
def set_tensor_shapes(tensors, shapes):
  """Checks that every name in `shapes` is a known tensor, then sets shapes.

  Entries whose shape is None are still checked by name but are not applied.
  Errors raised by `Tensor.set_shape` propagate unchanged.

  Args:
    tensors: TensorFlow ops.Tensor objects to update.
    shapes: Dict mapping tensor names (strings) to lists of integers
      representing input shapes (e.g., {"foo": [1, 16, 16, 3]}).

  Raises:
    ValueError: `shapes` names a tensor that is not among `tensors`.
  """
  if not shapes:
    return

  by_name = {}
  for tensor in tensors:
    by_name[tensor_name(tensor)] = tensor

  for name, shape in shapes.items():
    if name not in by_name:
      raise ValueError(
          "Invalid tensor \'{}\' found in tensor shapes "
          "map.".format(name))
    if shape is None:
      continue
    by_name[name].set_shape(shape)