Example #1
from typing import Any, Dict, Union

import tensorflow as tf
import xarray as xr


def input_to_feed_dict(graph: tf.Graph, input_data: Union[dict, xr.Dataset]) \
        -> Dict[tf.Tensor, Any]:
    """
    Converts input data to a feed dict for TensorFlow sessions, based on the placeholders in a tf.Graph.

    :param graph: tf.Graph object
    :param input_data: either an xr.Dataset or a dict{"placeholder_name": data}
    :return: dict{placeholder tensor: data} for every placeholder name in `input_data` that also exists in the graph
    """
    placeholders = {
        op.name: op
        for op in graph.get_operations()
        if op.type.lower().startswith("placeholder")
    }

    if isinstance(input_data, xr.Dataset):
        keys = input_data.variables.keys()
    else:
        keys = input_data.keys()
    keys = set(keys).intersection(placeholders.keys())

    retval = {}
    for k in keys:
        retval[graph.get_tensor_by_name(k + ":0")] = input_data[k]

    return retval
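A minimal usage sketch for context (TF1 graph mode assumed; the placeholder name "x" and the dummy data are illustrative, not part of the original snippet):

import numpy as np

graph = tf.Graph()
with graph.as_default():
    x = tf.placeholder(tf.float32, shape=(None, 3), name="x")
    y = x * 2.0

# The dict key "x" matches the placeholder's op name, so it survives the
# intersection above and is resolved to the tensor "x:0".
feed = input_to_feed_dict(graph, {"x": np.ones((4, 3), dtype=np.float32)})
with tf.Session(graph=graph) as sess:
    print(sess.run(y, feed_dict=feed))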
Example #2
File: tf.py, Project: zenna/wacacore
def summary(g: tf.Graph):
    return """
    graph has %s tensors.
    %s inputs
    %s outputs
    %s ops
    """ % (len(all_tensors(g)), num_ph(g), len(
        get_outputs(g)), len(g.get_operations()))
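This snippet assumes three project helpers. The sketches below are hypothetical reconstructions based only on how summary() uses them; the actual wacacore implementations may differ:

def all_tensors(g: tf.Graph):
    # every output tensor of every op in the graph
    return [t for op in g.get_operations() for t in op.outputs]

def num_ph(g: tf.Graph):
    # number of placeholder ops, i.e. the graph's feedable inputs
    return len([op for op in g.get_operations() if op.type == "Placeholder"])

def get_outputs(g: tf.Graph):
    # tensors that no op consumes, treated as graph outputs
    return [t for t in all_tensors(g) if not t.consumers()]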
Example #3
def graph_has_op(g: tf.Graph, op_name: str):
  """
  A method that really ought to be part of `tf.Graph`. Returns True if the
  indicated graph has an op with the given name.
  """
  all_ops_in_graph = g.get_operations()
  names_of_all_ops_in_graph = [o.name for o in all_ops_in_graph]
  return op_name in names_of_all_ops_in_graph
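A quick usage sketch (the op name "my_const" is illustrative):

graph = tf.Graph()
with graph.as_default():
    tf.constant(1.0, name="my_const")

assert graph_has_op(graph, "my_const")
assert not graph_has_op(graph, "no_such_op")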
Example #4
def get_weights_for_mnasnet(model: tf.Graph) -> dict:
    """Extracts weights dictionary from any MNasNet model hosted at
    **www.tensorflow.org/lite/models**

    Weights are preprocessed to eliminate operations related to batch normalization.

    Parameters
    ----------
    model: tf.Graph
        A static graph, from which weights data must be extracted.

    Returns
    -------
    dict:
        A dictionary containing layers' weights data (including biases)
    """
    with tf.Session(graph=model) as sess:

        operation_names = [
            op.name for op in model.get_operations() if op.type == "Const"
        ]

        weights = {}

        # stem cell
        weights["stem"] = _gw(sess, "stem/conv", "stem/bn")

        # 0-th cascade cell
        weights["lead_cell_0/dws"] = _gw(sess, "lead_cell_0/op_0/depthwise_0",
                                         "lead_cell_0/op_0/bn1_0")
        weights["lead_cell_0/project"] = _gw(sess,
                                             "lead_cell_0/op_0/project_0",
                                             "lead_cell_0/op_0/bn2_0")

        # last cascade cell
        weights["lead_cell_17"] = _gw(sess, "lead_cell_17/op_0/conv2d_0",
                                      "lead_cell_17/op_0/bn_0")

        # intermediate cascade cells
        for cell_index in range(1, 17):
            cell_scope = _get_cascade_cell_name(operation_names, cell_index)
            # expand -> dws -> project
            weights[cell_scope + "/expand"] = _gw(
                sess, cell_scope + "/op_0/expand_0",
                cell_scope + "/op_0/bn0_0")
            weights[cell_scope + "/dws"] = _gw(
                sess, cell_scope + "/op_0/depthwise_0",
                cell_scope + "/op_0/bn1_0")
            weights[cell_scope + "/project"] = _gw(
                sess, cell_scope + "/op_0/project_0",
                cell_scope + "/op_0/bn2_0")

        # output cell
        weights["output/fc"] = _gw(sess, "output/fc", bias_scope="output/fc")

    return _fold_weights(weights)
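A minimal TF1-style loading sketch for context; the .pb filename is a placeholder, and the helpers _gw, _get_cascade_cell_name and _fold_weights must already be defined:

def load_frozen_graph(pb_path: str) -> tf.Graph:
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    graph = tf.Graph()
    with graph.as_default():
        # import under an empty name scope so op names match the scopes used above
        tf.import_graph_def(graph_def, name="")
    return graph

# weights = get_weights_for_mnasnet(load_frozen_graph("mnasnet.pb"))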
Example #5
def create_tensor_dict(detection_graph: tensorflow.Graph) -> dict:
    """Map the standard object-detection output names to their tensors, when present in the graph."""
    tensor_dict = {}
    with tensorflow.compat.v1.Session(graph=detection_graph):
        ops = detection_graph.get_operations()
        all_tensor_names = {output.name for op in ops for output in op.outputs}
        for key in ['num_detections', 'detection_boxes', 'detection_scores', 'detection_classes']:
            tensor_name = key + ':0'
            if tensor_name in all_tensor_names:
                tensor_dict[key] = detection_graph.get_tensor_by_name(tensor_name)
    return tensor_dict
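A usage sketch under stated assumptions: the .pb path is a placeholder, and the input name "image_tensor:0" matches typical TF1 object-detection exports but should be verified for your graph:

import numpy
import tensorflow

graph_def = tensorflow.compat.v1.GraphDef()
with tensorflow.compat.v1.gfile.GFile("frozen_inference_graph.pb", "rb") as f:
    graph_def.ParseFromString(f.read())
detection_graph = tensorflow.Graph()
with detection_graph.as_default():
    tensorflow.compat.v1.import_graph_def(graph_def, name="")

tensor_dict = create_tensor_dict(detection_graph)
with tensorflow.compat.v1.Session(graph=detection_graph) as sess:
    image = numpy.zeros((1, 300, 300, 3), dtype=numpy.uint8)
    outputs = sess.run(tensor_dict, feed_dict={"image_tensor:0": image})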
Example #6
    def _get_quant_ops_from_tf_graph(self, gr: tf.Graph):
        """
        Utility to get quantization op names in the given graph.
        :param gr: tf.Graph to search
        :return: list of names of quantization ops found in the graph
        """
        ops = gr.get_operations()
        quantized_graph_op_names = [
            op.name for op in ops
            if op.type in ["QcQuantize", "QcQuantizeRecurrentParam"]
        ]

        return quantized_graph_op_names
Example #7
def get_weights_for_effnetb0(model: tf.Graph) -> dict:
    """Extracts weights dictionary from any MNasNet model hosted at
    **www.tensorflow.org/lite/models**

    Weights are preprocessed in order to eliminate operations, related to batch normalization.

    Parameters
    ----------
    model: tf.Graph
        A static graph, from which weights data must be extracted.

    Returns
    -------
    dict:
        A dictionary containing layers' weights data (including biases)
    """
    with tf.Session(graph=model) as sess:
    
        operation_names = [
            op.name for op in model.get_operations()
            if op.type == "Const" and "Mean" not in op.name
        ]

        weights = {}

        # stem cell
        weights["Conv"] = _gw(sess, "Conv", "Conv/BatchNorm")

        # 0-th cascade cell
        weights["expanded_conv/dws"] = _gw(sess, "expanded_conv/depthwise", "expanded_conv/depthwise/BatchNorm")
        weights["expanded_conv/se_1"] = _gw(sess, "expanded_conv/se_1", bias_scope="expanded_conv/se_1")
        weights["expanded_conv/se_2"] = _gw(sess, "expanded_conv/se_2", bias_scope="expanded_conv/se_2")
        weights["expanded_conv/project"] = _gw(sess, "expanded_conv/project", "expanded_conv/project/BatchNorm")

        # intermediate cascade cells
        for cell_index in range(1, 16):
            cell_scope = _get_cascade_cell_name(operation_names, cell_index)
            # expand -> dws -> project
            weights[cell_scope + "/expand"] = _gw(sess, cell_scope + "/expand", cell_scope + "/expand/BatchNorm")
            weights[cell_scope + "/dws"] = _gw(sess, cell_scope + "/depthwise", cell_scope + "/depthwise/BatchNorm")
            weights[cell_scope + "/se_1"] = _gw(sess, cell_scope + "/se_1", bias_scope=cell_scope + "/se_1")
            weights[cell_scope + "/se_2"] = _gw(sess, cell_scope + "/se_2", bias_scope=cell_scope + "/se_2")
            weights[cell_scope + "/project"] = _gw(sess, cell_scope + "/project", cell_scope + "/project/BatchNorm")

        # output cell
        weights["Conv_1"] = _gw(sess, "Conv_1", "Conv_1/BatchNorm")
        weights["output"] = _gw(sess, "Logits/Conv2d_1c_1x1", bias_scope="Logits/Conv2d_1c_1x1")
    
    return _fold_weights(weights)
Example #8
def get_training_tensors(graph: tf.Graph):
    """
    Return the set of tensors in the graph used to set training mode
    :param graph: Graph to search for training tensors in
    :return: Set of tensors in the graph used to set training mode
    """
    training_tensors = set()
    for op in graph.get_operations():
        # Currently the only training tensors we know of are attached to FusedBatchNorm blocks
        if op.type == 'FusedBatchNormV3' and op.get_attr('is_training'):
            try:
                switch_op = op.inputs[0].op
                assert switch_op.type == 'Switch'
                pred_id_op = switch_op.inputs[1].op
                assert pred_id_op.type == 'Identity'
                training_tensor = pred_id_op.inputs[0]
                training_tensors.add(training_tensor)
            # pylint: disable=bare-except
            except:
                continue
    return training_tensors
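A hypothetical helper built on top of this (the name run_in_inference_mode and its signature are assumptions, not part of the original code):

def run_in_inference_mode(sess: tf.compat.v1.Session, fetches, feed_dict: dict):
    # Pin every discovered training tensor to False so FusedBatchNormV3 uses
    # its moving statistics rather than per-batch statistics.
    training_feed = {t: False for t in get_training_tensors(sess.graph)}
    return sess.run(fetches, feed_dict={**feed_dict, **training_feed})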
Example #9
def print_graph(graph: tf.Graph):
    for op in graph.get_operations():
        print(op)
Example #10
def log_entry_points(g: tf.Graph):
    logging.info('Entry points: %s',
                 [o.name for o in g.get_operations() if 'entry_point' in o.name])