示例#1
0
def _resample_image_gradient(op: tf.Operation, mask_grad, value_grad):
    """
    Gradient function for the resample-image op.

    The forward op has inputs ``(feature, coordinate, kernel)`` and outputs
    ``(mask, value)``. Only a gradient w.r.t. ``value`` is accepted; no gradient
    is expected to flow into ``mask``.

    :param op: the forward resample-image op
    :param mask_grad: incoming gradient for the ``mask`` output; must be None
    :param value_grad: incoming gradient for the ``value`` output
    :return: ``(feature_grad, coordinate_grad, kernel_grad)``, ordered like
             ``op.inputs``. ``kernel_grad`` is always None; ``feature_grad`` /
             ``coordinate_grad`` are None when the op's ``grad_feature`` /
             ``grad_coordinate`` attribute is false.
    """
    assert resample_image_ops is not None
    assert mask_grad is None

    grad_feature = op.get_attr("grad_feature")
    grad_coordinate = op.get_attr("grad_coordinate")
    onebased = op.get_attr('onebased')
    feature, coordinate, kernel = op.inputs
    mask, value = op.outputs

    feature_grad, coordinate_grad = \
      resample_image_gradient_op(
        feature, coordinate, kernel, mask, value, value_grad,
        grad_feature=grad_feature,
        grad_coordinate=grad_coordinate,
        onebased=onebased)

    if not grad_feature:
        feature_grad = None

    if not grad_coordinate:
        coordinate_grad = None

    # TensorFlow expects exactly one gradient per input, in the same order as
    # op.inputs == (feature, coordinate, kernel); the kernel gets no gradient.
    return feature_grad, coordinate_grad, None
示例#2
0
    def get_epsilon(bn_op: tf.Operation) -> float:
        """
        Return the epsilon value used by a batch-norm op.

        :param bn_op: BN op from the connected graph: either a fused
            ``FusedBatchNorm`` / ``FusedBatchNormV3`` op, or, for non-fused BN,
            the ``Mul`` op (e.g. ``mul_1``) inside the BN scope
        :return: epsilon value
        :raises AssertionError: if ``bn_op`` has an unsupported type or the
            expected non-fused op chain is not present
        """

        if bn_op.type in ['Mul']:
            # Non-fused BN: walk bn_op -> (input 1) mul -> (input 0) rsqrt ->
            # (input 0) add, whose second input is the epsilon constant, i.e.
            # the epsilon in rsqrt(x + epsilon).
            assert len(bn_op.inputs) >= 2, _BN_STRUCTURE_ERROR_MSG
            mul = bn_op.inputs[1].op
            assert len(mul.inputs) >= 1, _BN_STRUCTURE_ERROR_MSG
            rsqrt = mul.inputs[0].op
            assert len(rsqrt.inputs) >= 1, _BN_STRUCTURE_ERROR_MSG
            add = rsqrt.inputs[0].op
            assert len(add.inputs) >= 2, _BN_STRUCTURE_ERROR_MSG
            epsilon = add.inputs[1].op
            # Read the scalar straight from the constant's 'value' TensorProto.
            numpy_epsilon = epsilon.get_attr('value').float_val[0]
        elif bn_op.type in ['FusedBatchNormV3', 'FusedBatchNorm']:
            # epsilon can be derived as attribute value
            numpy_epsilon = bn_op.get_attr("epsilon")
        else:
            logger.error("Error, unknown BN op")
            # Raise explicitly instead of `assert False`: asserts are stripped
            # under `python -O`, which would otherwise fall through to an
            # UnboundLocalError on the return below.
            raise AssertionError(f"Unknown BN op type: {bn_op.type}")

        return numpy_epsilon
示例#3
0
def get_conv2d_activation_shape(sess: tf.compat.v1.Session, op: tf.Operation,
                                input_op_names: List[str],
                                input_shape: Union[Tuple, List[Tuple]],
                                input_activation: bool) -> List:
    """
    Get the input or output activation shape of a Conv2D op in Common format [NCHW].

    The static shape is used when both spatial dimensions (H, W) are known.
    Otherwise the shape is found dynamically by evaluating the activation tensor
    in ``sess`` with random input data.

    :param sess: TensorFlow Session
    :param op: TensorFlow op with a rank-4 activation and a ``data_format`` attribute
    :param input_op_names: list of input op names of model
    :param input_shape: tuple or list of tuple of input shape of model
    :param input_activation: True for the input activation shape, False for the output
    :return: activation shape as a list in Common format [NCHW]
    """
    # op.inputs[0] is the input activation, op.outputs[0] the output activation.
    activation = op.inputs[0] if input_activation else op.outputs[0]
    channels_last = op.get_attr('data_format').decode("utf-8") == "NHWC"

    def _to_nchw(shape) -> List:
        # Normalize to a list and reorder [N, H, W, C] -> [N, C, H, W] if channels_last.
        shape = list(shape)
        if channels_last:
            return [shape[0], shape[3], shape[1], shape[2]]
        return shape

    # use static shape for input / output activations
    activation_shape = _to_nchw(activation.get_shape().as_list())

    # If either spatial dimension is undefined statically, find the dynamic shape.
    if activation_shape[2] is None or activation_shape[3] is None:
        input_data = create_rand_tensors_given_shapes(input_shape=input_shape)
        feed_dict = create_input_feed_dict(graph=op.graph,
                                           input_op_names_list=input_op_names,
                                           input_data=input_data,
                                           training=False)
        # The evaluated array's .shape is a tuple; _to_nchw returns a list in
        # both layouts so the return type is consistent.
        activation_shape = _to_nchw(
            activation.eval(feed_dict=feed_dict, session=sess).shape)

    return activation_shape
示例#4
0
def change_out_act_shape_to_channels_first(op: tf.Operation) -> List:
    """
    Return the static output shape of a Conv2D op in channels-first order.

    Shapes of ops that are already channels-first are returned unchanged.

    :param op: TensorFlow Conv2D op
    :return: output shape as [N, C, H, W]
    """
    layout = op.get_attr('data_format').decode("utf-8")
    out_shape = op.outputs[0].get_shape().as_list()

    if layout != "NHWC":
        return out_shape

    # [N, H, W, C] -> [N, C, H, W]
    return [out_shape[axis] for axis in (0, 3, 1, 2)]
示例#5
0
def get_strides_for_split_conv_ops(op: tf.Operation) -> (List, List):
    """
    Compute strides for the two convolutions a Conv2D op is split into.

    The first conv keeps only the vertical (H) stride and the second only the
    horizontal (W) stride.

    :param op: TensorFlow Op; must be of type 'Conv2D'
    :return: (conv_a_strides, conv_b_strides), each as [stride_h, stride_w]
    :raises ValueError: if op is not Conv2D or its data_format is neither NHWC nor NCHW
    """
    if op.type != 'Conv2D':
        raise ValueError("Only Conv2d op can be split")

    full_strides = op.get_attr("strides")
    layout = op.get_attr("data_format").decode("utf-8")

    # Positions of the (H, W) entries in the 4-element strides attribute.
    spatial_axes = {"NHWC": (1, 2), "NCHW": (2, 3)}.get(layout)
    if spatial_axes is None:
        raise ValueError("Unknown data format!")

    h_axis, w_axis = spatial_axes
    return [full_strides[h_axis], 1], [1, full_strides[w_axis]]
示例#6
0
def get_conv2d_op_params(op: tf.Operation) -> (Tuple, Tuple, Tuple):
    """
    Read the stride, padding and group count of a Conv2D op.

    :param op: TensorFlow Op with 'strides', 'data_format' and 'padding' attributes
    :return: (strides, padding, groups), where strides is (stride_h, stride_w),
             padding is the op's raw 'padding' attribute and groups is always 1
    :raises ValueError: if data_format is neither NHWC nor NCHW
    """
    attrs = {name: op.get_attr(name) for name in ('strides', 'data_format', 'padding')}

    layout = attrs['data_format'].decode("utf-8")
    if layout == "NHWC":
        h_axis = 1
    elif layout == "NCHW":
        h_axis = 2
    else:
        raise ValueError("unknown data format")

    # W immediately follows H in both supported layouts.
    spatial_strides = (attrs['strides'][h_axis], attrs['strides'][h_axis + 1])

    # For Conv2D op groups should be 1
    return spatial_strides, attrs['padding'], 1
示例#7
0
def op_not_in_loop_control_flow_context(graph: tf.Graph,
                                        input_op: tf.Operation) -> bool:
    """
    Check whether an op sits outside any foreign control-flow context.

    The check compares whatever ``_get_control_flow_context()`` reports for the
    graph and for the op. Any such context counts, not only loop contexts.

    :param graph: tf.Graph is the active graph
    :param input_op: op as tf.Operation
    :return: True if input_op has no control-flow context, or has the same context
             object as the graph's active one; False otherwise.
    """
    # pylint: disable=protected-access
    graph_context = graph._get_control_flow_context()
    op_context = input_op._get_control_flow_context()

    # Either no context at all, or identical to the graph's active context.
    return not op_context or op_context is graph_context
示例#8
0
    def get_training(bn_op: tf.Operation) -> Union[None, bool, tf.Tensor]:
        """
        Determine the training mode of a batch-norm op.

        The result is either a static bool, or the tensor that selects the mode at
        run time.

        :param bn_op: BN op from the connected graph (FusedBatchNorm, FusedBatchNormV3 or Mul)
        :return: True / False for a fixed mode, the mode-selecting tf.Tensor when the
                 mode is chosen dynamically, or None if the structure is not recognized.
        """
        assert bn_op.type in ['FusedBatchNormV3', 'FusedBatchNorm', 'Mul']

        if bn_op.type in ('FusedBatchNormV3', 'FusedBatchNorm'):
            if 'FusedBatchNormV3_1' not in bn_op.name:
                return bn_op.get_attr('is_training')
            # Follow input 0 (presumably a Switch) -> its input 1 (presumably the
            # predicate Identity) -> that op's input 0, the predicate tensor.
            return bn_op.inputs[0].op.inputs[1].op.inputs[0]

        # Non fused batchnorm case: walk back bn_op -> Mul -> Rsqrt -> AddV2,
        # checking each op type along the way.
        node = bn_op
        for input_idx, expected_type in ((1, 'Mul'), (0, 'Rsqrt'), (0, 'AddV2')):
            node = node.inputs[input_idx].op
            assert node.type == expected_type
        source_op = node.inputs[0].op

        if source_op.type == 'Squeeze':
            return True
        if source_op.type == 'ReadVariableOp':
            return False
        if source_op.type == 'Merge':
            # Mode chosen by a cond: Merge input 1 -> Switch -> pred Identity -> predicate.
            switch_op = source_op.inputs[1].op
            assert switch_op.type == 'Switch'
            pred_id_op = switch_op.inputs[1].op
            assert pred_id_op.type == 'Identity'
            return pred_id_op.inputs[0]

        logger.error('Error, unknown BN structure')
        return None
示例#9
0
def get_layer_attributes(
        sess: tf.compat.v1.Session, op: tf.Operation,
        input_op_names: List[str],
        input_shape: Union[Tuple, List[Tuple]]) -> (Tuple, Tuple, Tuple):
    """
    Get attributes (kernel_size, stride, padding) of tf.nn.Conv2d Op
    :param sess: TensorFlow Session
    :param op: TensorFlow Conv2D Operation
    :param input_op_names: List of input op names of model
    :param input_shape: tuple or list of tuple of input shape of model
    :return: (kernel_size, stride, padding); kernel_size and stride are (height, width)
             int pairs, padding is whatever get_padding returns for (height, width)
    :raises ValueError: if the op's data_format is neither NHWC nor NCHW
    """
    # pylint: disable=too-many-locals
    assert op.type == 'Conv2D'

    # Resolve the (H, W) stride first so an unknown data format fails fast, before
    # the shape queries below (which may evaluate tensors in the session).
    strides = op.get_attr('strides')
    data_format = op.get_attr('data_format').decode("utf-8")
    if data_format == "NHWC":      # 'channels_last'
        stride = (int(strides[1]), int(strides[2]))
    elif data_format == "NCHW":    # 'channels_first'
        stride = (int(strides[2]), int(strides[3]))
    else:
        raise ValueError("Unknown data format!")

    # Activation shapes are returned in [N, C, H, W]; only (H, W) is needed.
    # Distinct names keep the model-level `input_shape` argument unshadowed.
    _, _, out_h, out_w = get_conv2d_activation_shape(
        sess=sess,
        op=op,
        input_op_names=input_op_names,
        input_shape=input_shape,
        input_activation=False)

    _, _, in_h, in_w = get_conv2d_activation_shape(
        sess=sess,
        op=op,
        input_op_names=input_op_names,
        input_shape=input_shape,
        input_activation=True)

    output_hw = (out_h, out_w)
    input_hw = (in_h, in_w)

    # Conv2d weight shape in TensorFlow  [kh, kw, Nic, Noc]
    weight_index = WeightTensorUtils.get_tensor_index_in_given_op(input_op=op)
    weight_shape = op.inputs[weight_index].shape
    kernel_size = (int(weight_shape[0]), int(weight_shape[1]))

    # get the padding for (height, width) dimension
    padding = get_padding(input_shape=input_hw,
                          output_shape=output_hw,
                          kernel_size=kernel_size,
                          stride=stride)

    return kernel_size, stride, padding