def _resample_image_gradient(op: tf.Operation, mask_grad, value_grad):
    """
    Gradient function for the custom resample_image op. The mask output carries no
    gradient; feature / coordinate gradients are only kept when the corresponding
    grad_* attributes are set on the op.
    """
    assert resample_image_ops is not None
    assert mask_grad is None

    grad_feature = op.get_attr("grad_feature")
    grad_coordinate = op.get_attr("grad_coordinate")
    onebased = op.get_attr('onebased')

    feature, coordinate, kernel = op.inputs
    mask, value = op.outputs

    feature_grad, coordinate_grad = \
        resample_image_gradient_op(feature, coordinate, kernel, mask, value, value_grad,
                                   grad_feature=grad_feature,
                                   grad_coordinate=grad_coordinate,
                                   onebased=onebased)

    if not grad_feature:
        feature_grad = None
    if not grad_coordinate:
        coordinate_grad = None

    return feature_grad, None, coordinate_grad
def get_epsilon(bn_op: tf.Operation) -> float:
    """
    Returns the epsilon extracted from a given BN op.
    :param bn_op: BN op obtained from the connected graph using get_modules
                  (a mul_1 op inside the BN scope for non-fused batchnorm)
    :return: epsilon value
    """
    if bn_op.type in ['Mul']:
        assert len(bn_op.inputs) >= 2, _BN_STRUCTURE_ERROR_MSG
        mul = bn_op.inputs[1].op
        assert len(mul.inputs) >= 1, _BN_STRUCTURE_ERROR_MSG
        rsqrt = mul.inputs[0].op
        assert len(rsqrt.inputs) >= 1, _BN_STRUCTURE_ERROR_MSG
        add = rsqrt.inputs[0].op
        assert len(add.inputs) >= 2, _BN_STRUCTURE_ERROR_MSG
        epsilon = add.inputs[1].op
        numpy_epsilon = epsilon.get_attr('value').float_val[0]
    elif bn_op.type in ['FusedBatchNormV3', 'FusedBatchNorm']:
        # epsilon can be read directly as an attribute of the fused op
        numpy_epsilon = bn_op.get_attr("epsilon")
    else:
        logger.error("Error, unknown BN op")
        assert False

    return numpy_epsilon
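# A hedged usage sketch (hypothetical helper name, not part of the original module):
# reading epsilon off a fused BN op. Assumes tf.compat.v1.layers.batch_normalization
# with fused=True emits a FusedBatchNormV3 / FusedBatchNorm op; the emitted op type
# can vary across TF versions.
def _example_get_epsilon():
    graph = tf.Graph()
    with graph.as_default():
        inp = tf.compat.v1.placeholder(tf.float32, [1, 8, 8, 3], name='input')
        _ = tf.compat.v1.layers.batch_normalization(inp, epsilon=1e-3, fused=True,
                                                    training=False)
        bn_ops = [op for op in graph.get_operations()
                  if op.type in ('FusedBatchNormV3', 'FusedBatchNorm')]
    # epsilon is read straight from the fused op's attribute, no session needed
    assert abs(get_epsilon(bn_ops[0]) - 1e-3) < 1e-9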
def get_conv2d_activation_shape(sess: tf.compat.v1.Session, op: tf.Operation,
                                input_op_names: List[str],
                                input_shape: Union[Tuple, List[Tuple]],
                                input_activation: bool) -> List:
    """
    :param sess: TensorFlow Session
    :param op: TensorFlow op
    :param input_op_names: list of input op names of the model
    :param input_shape: tuple or list of tuples of the model's input shape
    :param input_activation: True to return the input activation shape, False for the output
    :return: input / output activation shape in common format [N, C, H, W]
    """
    # use the static shape of the input / output activation
    if input_activation:
        activation_shape = op.inputs[0].get_shape().as_list()
    else:
        activation_shape = op.outputs[0].get_shape().as_list()

    data_format = op.get_attr('data_format')

    # convert the activation shape to common format [N, C, H, W] if channels_last
    if str(data_format.decode("utf-8")) == "NHWC":
        activation_shape = [activation_shape[0], activation_shape[3], activation_shape[1],
                            activation_shape[2]]

    # if the static shape is undefined, find the dynamic shape of the input / output activation
    if activation_shape[2] is None:
        # get input data
        input_data = create_rand_tensors_given_shapes(input_shape=input_shape)

        # create feed_dict
        feed_dict = create_input_feed_dict(graph=op.graph, input_op_names_list=input_op_names,
                                           input_data=input_data, training=False)

        if input_activation:
            # get the input activation shape by evaluating the input tensor
            input_tensor = op.inputs[0]
            activation_shape = input_tensor.eval(feed_dict=feed_dict, session=sess).shape
        else:
            # get the output activation shape by evaluating the output tensor
            output_tensor = op.outputs[0]
            activation_shape = output_tensor.eval(feed_dict=feed_dict, session=sess).shape

        # convert the evaluated activation shape to common format [N, C, H, W] if channels_last
        if str(data_format.decode("utf-8")) == "NHWC":
            activation_shape = [activation_shape[0], activation_shape[3], activation_shape[1],
                                activation_shape[2]]

    return activation_shape
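# A hedged usage sketch (hypothetical helper name): with fully defined static shapes the
# function never evaluates tensors, so the session and feed arguments are just passed
# through. The 'input' name and (8, 8, 3) shape are made up for this demo.
def _example_get_conv2d_activation_shape():
    graph = tf.Graph()
    with graph.as_default():
        inp = tf.compat.v1.placeholder(tf.float32, [1, 8, 8, 3], name='input')
        conv = tf.compat.v1.nn.conv2d(inp, tf.ones([3, 3, 3, 16]),
                                      strides=[1, 2, 2, 1], padding='SAME')
    with tf.compat.v1.Session(graph=graph) as sess:
        out_shape = get_conv2d_activation_shape(sess=sess, op=conv.op,
                                                input_op_names=['input'],
                                                input_shape=(8, 8, 3),
                                                input_activation=False)
    # NHWC output [1, 4, 4, 16] is reported in common [N, C, H, W] format
    assert out_shape == [1, 16, 4, 4]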
def change_out_act_shape_to_channels_first(op: tf.Operation) -> List:
    """
    Convert a TensorFlow Conv2D output shape from 'channels_last' to 'channels_first'
    :param op: TensorFlow Conv2D op
    :return: shape [N, C, H, W]
    """
    data_format = op.get_attr('data_format')
    shape = op.outputs[0].get_shape().as_list()

    if str(data_format.decode("utf-8")) == "NHWC":
        shape = [shape[0], shape[3], shape[1], shape[2]]

    return shape
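# A hedged usage sketch (hypothetical helper name): an NHWC Conv2D's static output
# shape is reported as [N, C, H, W].
def _example_change_out_act_shape_to_channels_first():
    graph = tf.Graph()
    with graph.as_default():
        inp = tf.compat.v1.placeholder(tf.float32, [1, 8, 8, 3])
        conv = tf.compat.v1.nn.conv2d(inp, tf.ones([3, 3, 3, 16]),
                                      strides=[1, 1, 1, 1], padding='SAME')
    assert change_out_act_shape_to_channels_first(conv.op) == [1, 16, 8, 8]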
def get_strides_for_split_conv_ops(op: tf.Operation) -> (List, List):
    """
    Get the strides to use for the two conv ops that a Conv2D op is split into
    :param op: TensorFlow Op
    :return: (conv_a_strides, conv_b_strides)
    """
    if not op.type == 'Conv2D':
        raise ValueError("Only Conv2d op can be split")

    strides = op.get_attr("strides")
    data_format = op.get_attr("data_format")

    if str(data_format.decode("utf-8")) == "NHWC":
        conv_a_strides = [strides[1], 1]
        conv_b_strides = [1, strides[2]]
    elif str(data_format.decode("utf-8")) == "NCHW":
        conv_a_strides = [strides[2], 1]
        conv_b_strides = [1, strides[3]]
    else:
        raise ValueError("Unknown data format!")

    return conv_a_strides, conv_b_strides
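# A hedged usage sketch (hypothetical helper name): for an NHWC Conv2D with strides
# [1, 2, 3, 1], the vertical stride goes to the first split conv and the horizontal
# stride to the second.
def _example_get_strides_for_split_conv_ops():
    graph = tf.Graph()
    with graph.as_default():
        inp = tf.compat.v1.placeholder(tf.float32, [1, 12, 12, 3])
        conv = tf.compat.v1.nn.conv2d(inp, tf.ones([3, 3, 3, 8]),
                                      strides=[1, 2, 3, 1], padding='SAME')
    assert get_strides_for_split_conv_ops(conv.op) == ([2, 1], [1, 3])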
def get_conv2d_op_params(op: tf.Operation) -> (Tuple, Tuple, Tuple):
    """
    Get Conv2D op's parameters
    :param op: TensorFlow Op
    :return: (strides, padding, groups)
    """
    strides = op.get_attr('strides')
    data_format = op.get_attr('data_format')
    padding = op.get_attr('padding')

    if str(data_format.decode("utf-8")) == "NHWC":
        strides = (strides[1], strides[2])
    elif str(data_format.decode("utf-8")) == "NCHW":
        strides = (strides[2], strides[3])
    else:
        raise ValueError("unknown data format")

    # For a Conv2D op, groups should be 1
    groups = 1

    return strides, padding, groups
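# A hedged usage sketch (hypothetical helper name): strides come back as an (H, W)
# tuple, padding as the raw attribute bytes (e.g. b'SAME'), and groups is always 1.
def _example_get_conv2d_op_params():
    graph = tf.Graph()
    with graph.as_default():
        inp = tf.compat.v1.placeholder(tf.float32, [1, 8, 8, 3])
        conv = tf.compat.v1.nn.conv2d(inp, tf.ones([3, 3, 3, 8]),
                                      strides=[1, 2, 2, 1], padding='SAME')
    strides, padding, groups = get_conv2d_op_params(conv.op)
    assert strides == (2, 2) and padding == b'SAME' and groups == 1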
def op_not_in_loop_control_flow_context(graph: tf.Graph, input_op: tf.Operation) -> bool:
    """
    Checks whether the given op is outside any loop control flow context
    :param graph: tf.Graph, the active graph
    :param input_op: op as tf.Operation
    :return: True if op is not in a loop control flow context, False otherwise
    """
    # pylint: disable=protected-access
    active_ctxt = graph._get_control_flow_context()
    input_ctxt = input_op._get_control_flow_context()

    if not input_ctxt or input_ctxt is active_ctxt:
        # input_op isn't in a loop control flow context, or
        # input_op is in the same context as the currently active one
        return True

    return False
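# A hedged usage sketch (hypothetical helper name): an op created outside any
# tf.while_loop has no loop control flow context, so the check returns True.
def _example_op_not_in_loop_control_flow_context():
    graph = tf.Graph()
    with graph.as_default():
        const_op = tf.constant(1.0).op
    assert op_not_in_loop_control_flow_context(graph, const_op)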
def get_training(bn_op: tf.Operation) -> Union[None, bool, tf.Tensor]:
    """
    Returns either a boolean of whether the BN op training mode is True or False, or the
    is_training tensor feeding into the BN op if it uses a tensor to determine the mode
    dynamically.
    :param bn_op: bn_op obtained from the connected graph
    :return: True or False for training mode, or the tf.Tensor that determines the mode dynamically
    """
    assert bn_op.type in ['FusedBatchNormV3', 'FusedBatchNorm', 'Mul']

    if bn_op.type in ('FusedBatchNormV3', 'FusedBatchNorm'):
        if 'FusedBatchNormV3_1' in bn_op.name:
            # training mode is selected dynamically through a Switch op's predicate
            switch_op = bn_op.inputs[0].op
            pred_id_op = switch_op.inputs[1].op
            training = pred_id_op.inputs[0]
        else:
            training = bn_op.get_attr('is_training')
        return training

    # Non fused batchnorm case
    mul_op = bn_op.inputs[1].op
    assert mul_op.type == 'Mul'
    rsqrt_op = mul_op.inputs[0].op
    assert rsqrt_op.type == 'Rsqrt'
    add_op = rsqrt_op.inputs[0].op
    assert add_op.type == 'AddV2'
    add_input_op = add_op.inputs[0].op

    if add_input_op.type == 'Squeeze':
        return True
    if add_input_op.type == 'ReadVariableOp':
        return False
    if add_input_op.type == 'Merge':
        switch_op = add_input_op.inputs[1].op
        assert switch_op.type == 'Switch'
        pred_id_op = switch_op.inputs[1].op
        assert pred_id_op.type == 'Identity'
        return pred_id_op.inputs[0]

    logger.error('Error, unknown BN structure')
    return None
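# A hedged usage sketch (hypothetical helper name): for a fused BN built with a static
# Python bool, the helper simply returns the op's is_training attribute. Assumes
# tf.compat.v1.layers.batch_normalization(fused=True) emits a FusedBatchNormV3 /
# FusedBatchNorm op; a BN driven by a placeholder flag would return a tf.Tensor instead.
def _example_get_training():
    graph = tf.Graph()
    with graph.as_default():
        inp = tf.compat.v1.placeholder(tf.float32, [1, 8, 8, 3])
        _ = tf.compat.v1.layers.batch_normalization(inp, fused=True, training=False)
        bn_op = [op for op in graph.get_operations()
                 if op.type in ('FusedBatchNormV3', 'FusedBatchNorm')][0]
    assert get_training(bn_op) is False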
def get_layer_attributes(sess: tf.compat.v1.Session, op: tf.Operation, input_op_names: List[str],
                         input_shape: Union[Tuple, List[Tuple]]) -> (Tuple, Tuple, Tuple):
    """
    Get attributes (kernel_size, stride, padding) of a tf.nn.conv2d Op
    :param sess: TensorFlow Session
    :param op: TensorFlow Operation
    :param input_op_names: list of input op names of the model
    :param input_shape: tuple or list of tuples of the model's input shape
    :return: (kernel_size, stride, padding)
    """
    # pylint: disable=too-many-locals
    assert op.type == 'Conv2D'
    stride = op.get_attr('strides')
    data_format = op.get_attr('data_format')

    output_activation_shape = get_conv2d_activation_shape(sess=sess, op=op,
                                                          input_op_names=input_op_names,
                                                          input_shape=input_shape,
                                                          input_activation=False)

    input_activation_shape = get_conv2d_activation_shape(sess=sess, op=op,
                                                         input_op_names=input_op_names,
                                                         input_shape=input_shape,
                                                         input_activation=True)

    _, _, activation_h, activation_w = output_activation_shape
    output_shape = (activation_h, activation_w)

    _, _, activation_h, activation_w = input_activation_shape
    input_shape = (activation_h, activation_w)

    # 'channels_last' format
    if str(data_format.decode("utf-8")) == "NHWC":
        stride = (int(stride[1]), int(stride[2]))
    # 'channels_first' format
    elif str(data_format.decode("utf-8")) == "NCHW":
        stride = (int(stride[2]), int(stride[3]))
    else:
        raise ValueError("Unknown data format!")

    # Conv2D weight shape in TensorFlow is [kh, kw, Nic, Noc]
    weight_index = WeightTensorUtils.get_tensor_index_in_given_op(input_op=op)
    weight_shape = op.inputs[weight_index].shape
    kernel_size = (int(weight_shape[0]), int(weight_shape[1]))

    # get the padding for the (height, width) dimensions
    padding = get_padding(input_shape=input_shape, output_shape=output_shape,
                          kernel_size=kernel_size, stride=stride)

    return kernel_size, stride, padding
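# A hedged usage sketch (hypothetical helper name): kernel size and stride are read
# from the graph, and padding is inferred from the input/output activation sizes via
# get_padding. Assumes WeightTensorUtils resolves the Conv2D weight input at index 1;
# shapes are fully defined here, so nothing is evaluated through the session.
def _example_get_layer_attributes():
    graph = tf.Graph()
    with graph.as_default():
        inp = tf.compat.v1.placeholder(tf.float32, [1, 8, 8, 3], name='input')
        conv = tf.compat.v1.nn.conv2d(inp, tf.ones([5, 5, 3, 16]),
                                      strides=[1, 2, 2, 1], padding='SAME')
    with tf.compat.v1.Session(graph=graph) as sess:
        kernel_size, stride, _ = get_layer_attributes(sess=sess, op=conv.op,
                                                      input_op_names=['input'],
                                                      input_shape=(8, 8, 3))
    assert kernel_size == (5, 5) and stride == (2, 2)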