def get_type_embedding_net_variables_from_graph_def(graph_def: tf.GraphDef, suffix: str = "") -> Dict:
    """
    Collect the type embedding net variables stored in a tf.GraphDef object

    Parameters
    ----------
    graph_def : tf.GraphDef
        The input tf.GraphDef object
    suffix : str, optional
        The suffix of the scope

    Returns
    ----------
    Dict
        Mapping from node name to a numpy array reshaped to the node's
        recorded tensor shape
    """
    nodes = get_type_embedding_net_nodes_from_graph_def(graph_def, suffix=suffix)
    variables = {}
    for name, node in nodes.items():
        np_dtype = tf.as_dtype(node.dtype).as_numpy_dtype
        shape = tf.TensorShape(node.tensor_shape).as_list()
        if len(shape) == 1 and shape[0] == 1:
            # Tensors of shape [1] are read through get_tensor_by_type
            # instead of the raw byte buffer.
            flat_values = get_tensor_by_type(node, np_dtype)
        else:
            # Every other tensor is decoded from the serialized bytes.
            flat_values = np.frombuffer(node.tensor_content, dtype=np_dtype)
        variables[name] = np.reshape(flat_values, shape)
    return variables
def _get_matrix(self):
    """
    Gather the embedding net weight matrices, grouped by layer

    Returns
    -------
    dict
        Maps ``"layer_<n>"`` (n = 1 .. ``self.layer_size``) to a list of
        numpy arrays. With ``self.type_one_side`` set, the list holds one
        matrix per type; otherwise it holds one entry per ordered type
        pair, where pairs listed in ``self.exclude_types`` are represented
        by an empty array.
    """

    def _decode(node_name):
        # Rebuild one matrix from the serialized bytes of the named node.
        node = self.embedding_net_nodes[node_name]
        flat_values = np.frombuffer(
            node.tensor_content,
            dtype=tf.as_dtype(node.dtype).as_numpy_dtype,
        )
        shape = tf.TensorShape(node.tensor_shape).as_list()
        return np.reshape(flat_values, shape)

    matrix = {}
    for layer in range(1, self.layer_size + 1):
        layer_matrices = []
        matrix["layer_" + str(layer)] = layer_matrices

        if self.type_one_side:
            for ii in range(self.ntypes):
                layer_matrices.append(
                    _decode(f"filter_type_all{self.suffix}/matrix_{layer}_{ii}")
                )
            continue

        for pair_index in range(self.ntypes * self.ntypes):
            # Split the flat pair index into (first type, second type).
            type_i, type_j = divmod(pair_index, self.ntypes)
            if (type_i, type_j) in self.exclude_types:
                # Excluded pairs keep their slot but carry no weights.
                layer_matrices.append(np.array([]))
            else:
                layer_matrices.append(
                    _decode(
                        f"filter_type_{type_i}{self.suffix}/matrix_{layer}_{type_j}"
                    )
                )
    return matrix
def get_fitting_net_variables_from_graph_def(graph_def: tf.GraphDef) -> Dict:
    """
    Collect the fitting net variables stored in a tf.GraphDef object

    Parameters
    ----------
    graph_def
        The input tf.GraphDef object

    Returns
    ----------
    Dict
        Mapping from node name to a numpy array reshaped to the node's
        recorded tensor shape
    """
    nodes = get_fitting_net_nodes_from_graph_def(graph_def)
    variables = {}
    for name, node in nodes.items():
        np_dtype = tf.as_dtype(node.dtype).as_numpy_dtype
        shape = tf.TensorShape(node.tensor_shape).as_list()
        if len(shape) == 1 and shape[0] == 1:
            # Tensors of shape [1] are read through get_tensor_by_type
            # instead of the raw byte buffer.
            flat_values = get_tensor_by_type(node, np_dtype)
        else:
            # Every other tensor is decoded from the serialized bytes.
            flat_values = np.frombuffer(node.tensor_content, dtype=np_dtype)
        variables[name] = np.reshape(flat_values, shape)
    return variables
def transform_graph(raw_graph: tf.Graph, old_graph: tf.Graph) -> tf.Graph:
    """Transform old graph into new.

    For every node of the raw graph that ``load_transform_node`` selects,
    the value stored in the same-named node of the old graph is copied
    over, converting between float16 / float32 / float64 when the two
    graphs use different precisions.

    Parameters
    ----------
    raw_graph : tf.Graph
        graph receiving parameters from the old one
    old_graph : tf.Graph
        graph providing parameters

    Returns
    -------
    tf.Graph
        new graph with parameters transferred from the old one. Note that
        the object actually returned is the graph definition obtained from
        ``raw_graph.as_graph_def()`` after its nodes have been updated.
    """
    old_graph_def = old_graph.as_graph_def()
    raw_graph_def = raw_graph.as_graph_def()
    raw_graph_node = load_transform_node(raw_graph_def)
    old_graph_node = load_transform_node(old_graph_def)

    full_precision = (np.float64, np.float32)

    for node in raw_graph_def.node:
        if node.name not in raw_graph_node:
            continue

        old_node = old_graph_node[node.name]
        raw_node = raw_graph_node[node.name]
        cp_attr = CopyNodeAttr(node)

        check_dim(raw_graph_node, old_graph_node, node.name)
        tensor_shape = [dim.size for dim in raw_node.tensor_shape.dim]
        old_graph_dtype = tf.as_dtype(old_node.dtype).as_numpy_dtype
        raw_graph_dtype = tf.as_dtype(raw_node.dtype).as_numpy_dtype
        log.info(f"{node.name} is passed from old graph({old_graph_dtype}) "
                 f"to raw graph({raw_graph_dtype})")

        # Tensors of shape [1] go through load_tensor rather than the raw
        # byte buffer (presumably their value lives in a typed value field
        # instead of tensor_content -- confirm against load_tensor).
        is_single_value = len(tensor_shape) == 1 and tensor_shape[0] == 1

        if raw_graph_dtype == np.float16:
            if old_graph_dtype in full_precision:
                if not is_single_value:
                    # The bytes were written with the OLD node's dtype, so
                    # decode them as such and only then cast to the target
                    # precision; decoding directly with the target dtype
                    # would reinterpret the bytes and produce garbage.
                    tensor = np.frombuffer(
                        old_node.tensor_content, dtype=old_graph_dtype
                    ).astype(raw_graph_dtype)
                    cp_attr.from_array(tensor, tf.float16, shape=tensor_shape)
                else:
                    tensor = load_tensor(old_node, old_graph_dtype,
                                         raw_graph_dtype)
                    cp_attr.from_array(tensor, tf.float16, [1])
            elif old_graph_dtype == np.float16:
                # float16 -> float16: values are taken from half_val.
                tensor = convertMatrix(np.array(old_node.half_val),
                                       tensor_shape)
                cp_attr.from_array(tensor, raw_graph_dtype)

        elif raw_graph_dtype in full_precision:
            if old_graph_dtype in full_precision:
                if not is_single_value:
                    # Decode with the source dtype, then cast (see above).
                    tensor = np.frombuffer(
                        old_node.tensor_content, dtype=old_graph_dtype
                    ).astype(raw_graph_dtype)
                    cp_attr.from_str(tensor)
                else:
                    tensor = load_tensor(old_node, old_graph_dtype,
                                         raw_graph_dtype)
                    cp_attr.from_array(tensor, raw_graph_dtype, shape=[1])
            elif old_graph_dtype == np.float16:
                # float16 -> float32/float64: rebuild the matrix from
                # half_val and widen it to the target precision.
                tensor = convertMatrix(
                    np.array(old_node.half_val), tensor_shape
                ).astype(raw_graph_dtype)
                if not is_single_value:
                    cp_attr.from_str(tensor)
                else:
                    cp_attr.from_array(tensor, raw_graph_dtype)

    return raw_graph_def