def input_to_feed_dict(graph: tf.Graph, input_data: Union[dict, xr.Dataset]) \
        -> Dict[Union[Union[tf.Tensor, tf.Operation], Any], Any]:
    """
    Build a feed dict for a TensorFlow session from ``input_data``.

    Only entries whose name matches a placeholder op of ``graph`` are kept. A
    placeholder op is any op whose type, lower-cased, starts with
    "placeholder". Entries without a matching placeholder are dropped.
    Placeholders without a matching entry are omitted.

    :param graph: tf.Graph object whose placeholders are matched by op name
    :param input_data: either an xr.Dataset (names taken from its
        ``variables``) or a dict-like ``{"placeholder_name": data}``
    :return: dict mapping each matched placeholder's output tensor
        (``graph.get_tensor_by_name(name + ":0")``) to ``input_data[name]``
    """
    candidate_ops = (
        op for op in graph.get_operations()
        if op.type.lower().startswith("placeholder")
    )
    placeholder_ops = {op.name: op for op in candidate_ops}

    if isinstance(input_data, xr.Dataset):
        available_names = input_data.variables.keys()
    else:
        available_names = input_data.keys()

    matched_names = set(available_names).intersection(placeholder_ops)
    return {
        graph.get_tensor_by_name(name + ":0"): input_data[name]
        for name in matched_names
    }
def summary(g: tf.Graph):
    """
    Return a one-string summary of ``g``.

    The summary reports four counts:
      - tensors, via ``all_tensors``
      - inputs, via ``num_ph`` (presumably placeholders; confirm)
      - outputs, via ``get_outputs``
      - ops, via ``g.get_operations()``

    All helpers except ``get_operations`` are defined elsewhere in the module.
    """
    tensor_count = len(all_tensors(g))
    input_count = num_ph(g)
    output_count = len(get_outputs(g))
    op_count = len(g.get_operations())
    counts = (tensor_count, input_count, output_count, op_count)
    return """ graph has %s tensors. %s inputs %s outputs %s ops """ % counts
def graph_has_op(g: tf.Graph, op_name: str):
    """
    A method that really ought to be part of `tf.Graph`.

    Returns True if the given graph contains an op whose name equals
    ``op_name``, False otherwise.
    """
    return any(operation.name == op_name for operation in g.get_operations())
def get_weights_for_mnasnet(model: tf.Graph) -> dict:
    """Extracts weights dictionary from any MNasNet model hosted at
    **www.tensorflow.org/lite/models**

    The collected weights are passed through ``_fold_weights`` to eliminate
    operations related to batch normalization.

    Parameters
    ----------
    model: tf.Graph
        A static graph from which weights data must be extracted.

    Returns
    -------
    dict:
        A dictionary containing layers' weights data (including biases)
    """
    # Layers of every intermediate cascade cell, in the order they are stored:
    # (result-key suffix, layer sub-scope, batch-norm sub-scope).
    cell_layers = (
        ("/expand", "/op_0/expand_0", "/op_0/bn0_0"),
        ("/dws", "/op_0/depthwise_0", "/op_0/bn1_0"),
        ("/project", "/op_0/project_0", "/op_0/bn2_0"),
    )

    with tf.Session(graph=model) as sess:
        # Names of all Const ops; used to resolve each cascade cell's scope.
        const_names = [
            op.name for op in model.get_operations() if op.type == "Const"
        ]

        # Stem cell, 0-th cascade cell and last (17-th) cascade cell. A dict
        # literal evaluates its entries in order, so both the _gw call order
        # and the key insertion order are kept.
        weights = {
            "stem": _gw(sess, "stem/conv", "stem/bn"),
            "lead_cell_0/dws": _gw(sess, "lead_cell_0/op_0/depthwise_0",
                                   "lead_cell_0/op_0/bn1_0"),
            "lead_cell_0/project": _gw(sess, "lead_cell_0/op_0/project_0",
                                       "lead_cell_0/op_0/bn2_0"),
            "lead_cell_17": _gw(sess, "lead_cell_17/op_0/conv2d_0",
                                "lead_cell_17/op_0/bn_0"),
        }

        # Intermediate cascade cells 1..16: expand -> dws -> project.
        for index in range(1, 17):
            scope = _get_cascade_cell_name(const_names, index)
            for key_suffix, layer_suffix, bn_suffix in cell_layers:
                weights[scope + key_suffix] = _gw(
                    sess, scope + layer_suffix, scope + bn_suffix)

        # Output cell: read with a bias scope instead of a batch-norm scope.
        weights["output/fc"] = _gw(sess, "output/fc", bias_scope="output/fc")

        return _fold_weights(weights)
def create_tensor_dict(detection_graph: tensorflow.Graph,
                       keys=('num_detections', 'detection_boxes',
                             'detection_scores', 'detection_classes')) -> dict:
    """
    Map detection output names to their tensors in ``detection_graph``.

    For each name in ``keys`` the tensor ``<name>:0`` is looked up. Names whose
    tensor does not exist in the graph are silently skipped.

    :param detection_graph: graph to search for the output tensors
    :param keys: output names to look up; defaults to ``num_detections``,
        ``detection_boxes``, ``detection_scores`` and ``detection_classes``
    :return: dict mapping each found name to its tensor
    """
    # Only the static graph is inspected, so no Session is created here.
    # Creating a Session is unnecessary for this and may reserve device
    # (e.g. GPU) memory.
    all_tensor_names = {
        output.name
        for op in detection_graph.get_operations()
        for output in op.outputs
    }
    tensor_dict = {}
    for key in keys:
        tensor_name = key + ':0'
        if tensor_name in all_tensor_names:
            tensor_dict[key] = detection_graph.get_tensor_by_name(tensor_name)
    return tensor_dict
def _get_quant_ops_from_tf_graph(self, gr: tf.Graph):
    """
    Collect the names of the quantization ops in a graph.

    :param gr: tf.Graph to scan
    :return: list of names of the ops whose type is ``QcQuantize`` or
        ``QcQuantizeRecurrentParam``, in ``gr.get_operations()`` order
    """
    quant_op_types = ("QcQuantize", "QcQuantizeRecurrentParam")
    names = []
    for operation in gr.get_operations():
        if operation.type in quant_op_types:
            names.append(operation.name)
    return names
def get_weights_for_effnetb0(model: tf.Graph) -> dict:
    """Extracts weights dictionary from an EfficientNet-B0 model graph.

    The graph must use the scope names read below: ``Conv``, ``expanded_conv``,
    the intermediate cascade cells resolved by ``_get_cascade_cell_name``,
    ``Conv_1`` and ``Logits/Conv2d_1c_1x1``. The collected weights are passed
    through ``_fold_weights`` to eliminate operations related to batch
    normalization.

    Parameters
    ----------
    model: tf.Graph
        A static graph from which weights data must be extracted.

    Returns
    -------
    dict:
        A dictionary containing layers' weights data (including biases)
    """
    # Layers of every intermediate cascade cell, in the order they are stored:
    # (result-key suffix, layer sub-scope, True if followed by BatchNorm,
    # False if read with its own bias scope instead).
    cell_layers = (
        ("/expand", "/expand", True),
        ("/dws", "/depthwise", True),
        ("/se_1", "/se_1", False),
        ("/se_2", "/se_2", False),
        ("/project", "/project", True),
    )

    with tf.Session(graph=model) as sess:
        # Names of Const ops whose name does not contain 'Mean'; used to
        # resolve each cascade cell's scope.
        operation_names = [
            op.name for op in model.get_operations()
            if op.type == "Const" and "Mean" not in op.name
        ]

        def read(scope, has_batch_norm):
            # Layer at `scope`, with its batch norm at `scope/BatchNorm`
            # or, for bias-only layers, its bias under the same scope.
            if has_batch_norm:
                return _gw(sess, scope, scope + "/BatchNorm")
            return _gw(sess, scope, bias_scope=scope)

        weights = {}

        # stem cell
        weights["Conv"] = read("Conv", True)

        # 0-th cascade cell (no expand layer)
        weights["expanded_conv/dws"] = read("expanded_conv/depthwise", True)
        weights["expanded_conv/se_1"] = read("expanded_conv/se_1", False)
        weights["expanded_conv/se_2"] = read("expanded_conv/se_2", False)
        weights["expanded_conv/project"] = read("expanded_conv/project", True)

        # intermediate cascade cells 1..15: expand -> dws -> se -> project
        for cell_index in range(1, 16):
            cell_scope = _get_cascade_cell_name(operation_names, cell_index)
            for key_suffix, layer_suffix, has_batch_norm in cell_layers:
                weights[cell_scope + key_suffix] = read(
                    cell_scope + layer_suffix, has_batch_norm)

        # output cells
        weights["Conv_1"] = read("Conv_1", True)
        weights["output"] = read("Logits/Conv2d_1c_1x1", False)

        return _fold_weights(weights)
def get_training_tensors(graph: tf.Graph):
    """
    Return the tensors in the graph used to set training mode.

    Only one pattern is recognised. A ``FusedBatchNormV3`` op with
    ``is_training`` set must have a first input produced by a ``Switch`` op.
    That Switch's second input must be produced by an ``Identity`` op. The
    Identity's input is reported as the training tensor. Ops that do not fully
    match this pattern are skipped.

    :param graph: Graph to search for training tensors in
    :return: Set of tensors in the graph used to set training mode
    """
    training_tensors = set()
    for op in graph.get_operations():
        # Currently the only training tensors we know of are attached to FusedBatchNorm blocks
        if op.type != 'FusedBatchNormV3' or not op.get_attr('is_training'):
            continue
        # Explicit checks rather than `assert` + bare `except`: asserts are
        # stripped under `python -O`, which would make every op match, and a
        # bare except would also hide unrelated errors.
        if len(op.inputs) < 1:
            continue
        switch_op = op.inputs[0].op
        if switch_op.type != 'Switch' or len(switch_op.inputs) < 2:
            continue
        pred_id_op = switch_op.inputs[1].op
        if pred_id_op.type != 'Identity' or len(pred_id_op.inputs) < 1:
            continue
        training_tensors.add(pred_id_op.inputs[0])
    return training_tensors
def print_graph(graph: tf.Graph):
    """Print each operation of ``graph`` on its own ``print`` call, in
    ``graph.get_operations()`` order."""
    operations = graph.get_operations()
    for operation in operations:
        print(operation)
def log_entry_points(g: tf.Graph):
    """Log at INFO level the list of names of all ops in ``g`` whose name
    contains 'entry_point'."""
    entry_point_names = [
        operation.name
        for operation in g.get_operations()
        if 'entry_point' in operation.name
    ]
    logging.info('Entry points: %s', entry_point_names)