Example #1
import tensorflow as tf
from tensorflow.contrib import graph_editor as ge  # TF 1.x contrib graph editor


def insert_bias_add_op(sess: tf.compat.v1.Session,
                       conv_op_out_tensor: tf.Tensor,
                       new_bias_tensor: tf.Variable,
                       bias_name="bias_value") -> None:
    """
    Insert a bias-add op after the given conv op.
    :param sess: model as tf.compat.v1.Session
    :param conv_op_out_tensor: output tensor of the conv op that should feed into the new bias op, as tf.Tensor
    :param new_bias_tensor: bias tensor to be added, as tf.Variable
    :param bias_name: name string for the bias op
    :return: None
    Note: a higher-level API needs to perform a save and load to get an updated session after using this API.
    """

    assert conv_op_out_tensor is not None, 'Error, insert_bias_add_op() : conv op output tensor must be provided'
    with sess.graph.as_default():
        if conv_op_out_tensor.consumers():
            # snapshot the current consumers of the conv output before rerouting
            consumer_list = list(conv_op_out_tensor.consumers())

            # create the new bias-add op, fed by the conv output
            bias_add_op = tf.nn.bias_add(value=conv_op_out_tensor,
                                         bias=new_bias_tensor,
                                         name=bias_name)

            # reroute: swap the conv output for the bias-add output in all existing consumers
            ge.reroute_ts(bias_add_op,
                          conv_op_out_tensor,
                          can_modify=consumer_list)

            # initialize the new bias variable once it has been added to the graph
            sess.run(tf.compat.v1.variables_initializer([new_bias_tensor]))
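A minimal usage sketch, with hypothetical graph and variable names. It assumes TF 1.x graph mode and a conv output that already has a consumer, since the helper only reroutes when consumers exist:

import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    inp = tf.compat.v1.placeholder(tf.float32, [1, 8, 8, 3], name='input')
    kernel = tf.Variable(tf.random.normal([3, 3, 3, 16]), name='kernel')
    conv_out = tf.nn.conv2d(inp, kernel, strides=[1, 1, 1, 1], padding='SAME')
    relu = tf.nn.relu(conv_out)  # existing consumer that will be rerouted
    new_bias = tf.Variable(tf.zeros([16]), name='new_bias')

sess = tf.compat.v1.Session(graph=graph)
sess.run(tf.compat.v1.global_variables_initializer())
insert_bias_add_op(sess, conv_out, new_bias)
# Per the docstring, save and reload the session to pick up the edited graph.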
Example #2
from typing import Callable, Union

import tensorflow as tf


def _get_kernel_regularizer(kernel_tensor: tf.Tensor) -> Union[None, Callable]:
    """
    Get a kernel regularizer of the same kind as the one attached to kernel_tensor
    :param kernel_tensor: kernel tensor to check for regularization
    :return: a new kernel regularizer if kernel_tensor has regularization, None otherwise
    """
    kernel_regularizer = None
    for consumer in kernel_tensor.consumers():
        if consumer.type == 'L2Loss':
            # tf.contrib.layers.l2_regularizer builds Mul(scale, L2Loss(kernel));
            # walk from the L2Loss output to the Mul to recover the scale constant
            try:
                l2_regularizer_mul = consumer.outputs[0].consumers()[0]
                scale_op = l2_regularizer_mul.inputs[0].op
                scale_val = scale_op.get_attr('value').float_val[0]
                kernel_regularizer = tf.contrib.layers.l2_regularizer(scale_val)
            except:     # pylint: disable=bare-except
                # no recoverable scale constant; fall back to a plain L2 loss function
                kernel_regularizer = tf.nn.l2_loss      # pylint: disable=no-member
    return kernel_regularizer
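A hedged usage sketch (TF 1.x only, since tf.contrib is used; names are illustrative). Passing a regularizer to tf.compat.v1.get_variable wires an L2Loss followed by a Mul with the scale constant into the graph, which is the pattern the helper walks:

import tensorflow as tf

with tf.Graph().as_default():
    inp = tf.compat.v1.placeholder(tf.float32, [1, 8, 8, 3])
    kernel = tf.compat.v1.get_variable(
        'kernel', shape=[3, 3, 3, 16],
        regularizer=tf.contrib.layers.l2_regularizer(0.5))
    conv = tf.nn.conv2d(inp, kernel, strides=[1, 1, 1, 1], padding='SAME')

    # conv.op.inputs[1] is the kernel's read tensor; the regularizer's
    # L2Loss op consumes the same tensor, so the helper can find it.
    regularizer = _get_kernel_regularizer(conv.op.inputs[1])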
Example #3
from typing import Dict

import tensorflow as tf

# Graph, Tensor and iterate_tf_graph_from_op are project-local helpers
# assumed to be importable from the surrounding module.


def iterate_tf_graph_from_tensor(
    graph: Graph,
    current_tensor: tf.Tensor,
    visited_tensors: Dict[tf.Tensor, int],
    visited_ops: Dict[tf.Operation, int],
):
    """Recursively wrap current_tensor and every op and tensor reachable from it."""
    if current_tensor not in visited_tensors:
        # shape and dtype are only recorded when TF can express them statically
        shape = (current_tensor.shape.as_list()
                 if current_tensor.shape.dims is not None else None)
        dtype = (current_tensor.dtype.as_numpy_dtype
                 if current_tensor.dtype.is_numpy_compatible else None)
        wrapped_tensor = Tensor(graph, current_tensor.name, shape=shape, dtype=dtype)

        # mark this tensor visited before recursing, so cycles terminate
        visited_tensors[current_tensor] = wrapped_tensor.id

        # visit the producing op, then link the wrapped tensor back to it
        iterate_tf_graph_from_op(graph, current_tensor.op, visited_tensors,
                                 visited_ops)
        wrapped_tensor.op_id = visited_ops[current_tensor.op]

        # visit every consumer op and record it as an output edge
        for op in current_tensor.consumers():
            iterate_tf_graph_from_op(graph, op, visited_tensors, visited_ops)
            wrapped_tensor.add_output(visited_ops[op])
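A hypothetical driver for the traversal, reusing the imports above. The Graph wrapper, its no-argument constructor, and iterate_tf_graph_from_op are assumptions here; only the recursion contract shown above is given by the source:

def wrap_graph_from_output(output_tensor: tf.Tensor) -> Graph:
    graph = Graph()  # project-local wrapper, not tf.Graph (assumed constructor)
    visited_tensors: Dict[tf.Tensor, int] = {}
    visited_ops: Dict[tf.Operation, int] = {}
    iterate_tf_graph_from_tensor(graph, output_tensor, visited_tensors, visited_ops)
    return graph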