def transform_fuse_activations(tf_graph):
    # type: (TFGraph)->None
    """Fuse trailing ReLU / ReLU6 activations into the op that produces their input.

    Rewrites ``tf_graph`` in place with two ``matcher.replace`` passes, in order:

    1. ``tf.nn.relu`` applied to the output of a fusable op is merged into that
       op, which gets ``fused_activation_function='RELU'``.
    2. ``tf.clip_by_value`` applied to the output of a fusable op, whose inputs 1
       and 2 have ``data`` equal to ``[0]`` and ``[6]``, is merged the same way
       with ``fused_activation_function='RELU6'``.

    The fused op keeps the producer's name, attribs (plus the new key) and
    inputs, and takes over the activation's outputs. A producer that already
    has a truthy ``fused_activation_function`` attribute is left untouched.
    """
    fusable_ops = [
        "tf.add",
        "tf.subtract",
        "tf.multiply",
        "tf.divide",
        "tf.nn.conv2d",
        "tf.nn.depthwise_conv2d",
        "tf.nn.max_pool",
        "tf.nn.avg_pool",
        # "tf.nn.conv2d_transpose" is not supported yet
        "tf.matmul",
        "tf.nn.l2_normalize",
        # "tf.concat" is not supported yet
    ]

    def fuse_activation(activation_name, fused_function, activation_ok):
        # Build fresh pattern objects for each pass.
        producer_output = matcher.Tensor()
        producer = matcher.Operation(name=fusable_ops, outputs=producer_output)
        activation = matcher.Operation(name=activation_name, inputs={0: producer_output})

        def make_fused_op(m):
            source = m[producer]
            return TFOperation(
                graph=tf_graph,
                name=source.name,
                attribs=utils.dict_union(source.attribs, dict(fused_activation_function=fused_function)),
                inputs=source.inputs,
                outputs=m[activation].outputs)

        def can_fuse(m):
            # Activation-specific check first, then refuse double fusion.
            return activation_ok(m[activation]) and not m[producer].attribs.get('fused_activation_function')

        matcher.replace(graph=tf_graph, pattern=activation, replacement=make_fused_op, condition=can_fuse)

    fuse_activation("tf.nn.relu", 'RELU', lambda op: True)
    fuse_activation("tf.clip_by_value", 'RELU6',
                    lambda op: op.inputs[1].data == [0] and op.inputs[2].data == [6])
def _merge_pads(g):
    # type: (NNEFGraph)->None
    """Merge padding ops into the sliding-window ops that consume them.

    Two kinds of op can be merged:

    * a ``pad`` op;
    * a ``box`` op whose ``size``/``stride``/``dilation`` entries are all 1 and
      which is not normalized.

    Such an op is merged when its output is input 0 of an ``argmax_pool``,
    ``max_pool``, ``max_pool_with_index``, ``avg_pool`` or ``conv`` op that is its
    only consumer. The merge does three things:

    * the pad's ``padding`` pairs are added element-wise to the consumer's
      ``padding`` (for ``conv`` only dimensions 2.. are used, and dimensions 0
      and 1 must be unpadded);
    * the consumer's ``border`` becomes ``'ignore'`` for a ``-inf`` pad value and
      ``'constant'`` otherwise;
    * the pad op is removed with ``graph_utils.remove_passthrough``.

    The pad value is read from the private ``_value`` attribute (default 0.0).
    Only 0.0 and ``-inf`` are merged, and ``-inf`` only into the max-style pools.

    Raises:
        utils.NNEFToolsException: if a ``box``/``pad`` op carrying a ``_value``
            attribute is still in the graph after merging.
    """
    neg_inf = float('-inf')
    max_like_pools = ['argmax_pool', 'max_pool', 'max_pool_with_index']

    t = matcher.Tensor()
    pad = matcher.Operation(name=['box', 'pad'], outputs=t)
    sliding = matcher.Operation(name=['argmax_pool', 'max_pool', 'max_pool_with_index', 'avg_pool', 'conv'],
                                inputs={0: t})

    def is_identity_box(op):
        # A box with unit size/stride/dilation and no normalization acts as a pure pad.
        return (all(s == 1 for s in op.attribs.get('size', []))
                and all(s == 1 for s in op.attribs.get('stride', []))
                and all(s == 1 for s in op.attribs.get('dilation', []))
                and not op.attribs.get('normalize', False))

    def condition(m):
        # type: (matcher.Match)->bool
        pad_op, sliding_op, padded = m[pad], m[sliding], m[t]
        if not (pad_op.name == 'pad' or (pad_op.name == 'box' and is_identity_box(pad_op))):
            return False
        # The pad is removed as a passthrough below, which presumably rewires every
        # use of the padded tensor to the unpadded input. Other consumers (or a
        # graph output) would then silently lose the padding, so only merge into a
        # sole consumer.
        if len(padded.consumers) != 1 or padded in g.outputs:
            return False
        value = pad_op.attribs.get('_value', 0.0)
        if value not in [0.0, neg_inf]:
            return False
        if value == neg_inf and sliding_op.name not in max_like_pools:
            # -inf padding is only merged into max-style pools (mapped to border 'ignore')
            return False
        if pad_op.attribs.get('border', 'constant') != 'constant':
            return False
        if (sliding_op.attribs.get('border', 'constant') != 'constant'
                and any(p != 0 or q != 0 for p, q in sliding_op.attribs.get('padding', []))):
            return False
        if sliding_op.name == 'conv' and any(p != 0 or q != 0
                                             for p, q in pad_op.attribs.get('padding', [])[:2]):
            return False
        return True

    def action(m):
        # type: (matcher.Match)->None
        value = m[pad].attribs.get('_value', 0.0)
        # Default: one (0, 0) pair per dimension (not a single flat tuple).
        pad_padding = m[pad].attribs.get('padding', [(0, 0)] * m[t].rank)
        if m[sliding].name == 'conv':
            # conv padding excludes dimensions 0 and 1, which condition() verified are unpadded
            pad_padding = pad_padding[2:]
        # The consumer's default is "no padding" over the same dimensions being merged.
        sliding_padding = m[sliding].attribs.get('padding', [(0, 0)] * len(pad_padding))
        assert len(pad_padding) == len(sliding_padding)
        m[sliding].attribs['padding'] = [(p + pp, q + qq)
                                         for (p, q), (pp, qq) in zip(pad_padding, sliding_padding)]
        m[sliding].attribs['border'] = 'ignore' if value == neg_inf else 'constant'
        graph_utils.remove_passthrough(g, m[pad])

    matcher.for_each(graph=g, pattern=sliding, action=action, condition=condition)

    # Pads that still carry a private _value could not be merged; report them as unexportable.
    for op in g.operations:
        if op.name in ['box', 'pad'] and '_value' in op.attribs:
            raise utils.NNEFToolsException('Could not export {} with value={}'.format(op.name, op.attribs['_value']))
def transform_fuse_bias_add_to_conv(tf_graph):
    # type: (TFGraph)->None
    """Fold a ``tf.nn.bias_add`` into the convolution that produces its input.

    A ``tf.nn.conv2d`` / ``tf.nn.depthwise_conv2d`` with exactly two inputs,
    followed by ``tf.nn.bias_add`` on its output, is replaced in ``tf_graph`` by a
    single op. That op keeps the conv's name and attribs, and appends the
    bias_add's second input as a third input. It also takes over the bias_add's
    outputs.
    """
    conv_result = matcher.Tensor()
    conv_op = matcher.Operation(name=["tf.nn.conv2d", "tf.nn.depthwise_conv2d"], outputs=conv_result)
    bias_add_op = matcher.Operation(name="tf.nn.bias_add", inputs={0: conv_result})

    def fused_conv(m):
        conv, bias_add = m[conv_op], m[bias_add_op]
        return TFOperation(graph=tf_graph,
                           name=conv.name,
                           attribs=conv.attribs,
                           inputs=tuple(conv.inputs) + (bias_add.inputs[1],),
                           outputs=bias_add.outputs)

    def conv_has_two_inputs(m):
        # Only fuse when the conv does not already carry a third input.
        return len(m[conv_op].inputs) == 2

    matcher.replace(graph=tf_graph, pattern=bias_add_op, replacement=fused_conv, condition=conv_has_two_inputs)
def transform_bias_add_conv(g):
    # type: (TFGraph)->None
    """Fold ``tf.nn.bias_add`` into a preceding internal convolution op.

    Matches ``_conv``, ``_planewise_conv``, ``_separable_conv``, ``_deconv`` or
    ``_planewise_deconv`` whose output 0 feeds input 0 of ``tf.nn.bias_add``.
    The pattern constrains the conv's input 2 with ``None``. Presumably this
    means the conv has no third (bias) input yet; TODO confirm the matcher's
    semantics for ``None``.

    The pair is replaced by one op with the conv's name and attribs, inputs
    ``(conv.inputs[0], conv.inputs[1], bias_add.inputs[1])`` and the bias_add's
    outputs.
    """
    conv_result = matcher.Tensor()
    conv_op = matcher.Operation(
        name=["_conv", "_planewise_conv", "_separable_conv", "_deconv", "_planewise_deconv"],
        inputs={2: None},
        outputs={0: conv_result})
    bias_add_op = matcher.Operation(name="tf.nn.bias_add", inputs={0: conv_result})

    def conv_with_bias(m):
        conv, bias_add = m[conv_op], m[bias_add_op]
        merged_inputs = (conv.inputs[0], conv.inputs[1], bias_add.inputs[1])
        return TFOperation(graph=g,
                           name=conv.name,
                           inputs=merged_inputs,
                           attribs=conv.attribs,
                           outputs=bias_add.outputs)

    matcher.replace(graph=g, pattern=bias_add_op, replacement=conv_with_bias)
def _transform_nnef_filter_grad_to_tf(g, tensor, transforms_by_name, driver):
    # type: (BaseGraph, BaseTensor, typing.Dict[str, typing.List[Transform]], DataFormatOptimizationDriver)->None
    """Fold the transpose (or reshape + transpose) that follows a conv filter-gradient op.

    ``tensor`` must be produced by one of these patterns (op names come from
    ``driver``):

      1. ``transpose(conv_grad_filter(...))``
      2. ``transpose(reshape(conv_grad_filter(...)))``

    On success the trailing transpose (and reshape) are removed from ``g``. The
    filter-gradient output then takes the transpose output's name and, if the
    transpose output was a graph output, its slot in ``g.outputs``. The removed
    ops are recorded on that tensor via ``add_transform`` in
    ``transforms_by_name``:

      * a ``Transpose`` with the inverse of the removed transpose's axes;
      * for pattern 2, additionally a ``Reshape`` to the filter-gradient
        output's own shape.

    Raises:
        _TransformException: if ``tensor`` has no producer, if neither pattern
            matches, or if the transpose/reshape outputs have more than one
            consumer, or if the filter-gradient output is itself a graph output.
    """
    assert driver.conv_grad_filter_op_names

    # Pattern 1: conv_grad_filter -> transpose
    cgf1_output = matcher.Tensor()
    cgf1 = matcher.Operation(name=driver.conv_grad_filter_op_names, outputs=cgf1_output)
    transpose1 = matcher.Operation(name=driver.transpose_op_name, inputs=cgf1_output)

    # Pattern 2: conv_grad_filter -> reshape -> transpose
    cgf2_output = matcher.Tensor()
    cgf2 = matcher.Operation(name=driver.conv_grad_filter_op_names, outputs=cgf2_output)
    reshape2_output = matcher.Tensor()
    reshape2 = matcher.Operation(name=driver.reshape_op_name, inputs=cgf2_output, outputs=reshape2_output)
    transpose2 = matcher.Operation(name=driver.transpose_op_name, inputs=reshape2_output)

    # NOTE(review): the error messages say TF_FILTER_GRAD_TO_NNEF although this is the
    # NNEF -> TF direction; confirm whether that name is intentional.
    if tensor.producer is None:
        raise _TransformException("Cannot apply TF_FILTER_GRAD_TO_NNEF")

    m = matcher.match(g, tensor.producer, matcher.OrPattern(transpose1, transpose2))

    if transpose1 in m:
        cgf = m[cgf1]  # type: BaseOperation
        transpose = m[transpose1]  # type: BaseOperation
        if not (len(transpose.output.consumers) <= 1 and cgf.output not in g.outputs):
            raise _TransformException("Cannot apply TF_FILTER_GRAD_TO_NNEF")
        # The name is copied before add_transform, which receives the renamed tensor.
        cgf.output.name = transpose.output.name
        add_transform(
            transforms_by_name,
            cgf.output,
            Transpose(
                utils.inverse_permutation(
                    driver.get_axes_from_transpose(transpose))))
        # NOTE(review): only graph outputs are rewired here; if transpose.output has its
        # one allowed consumer, confirm that remove_subgraph/the caller reconnects it.
        graph_utils.replace_tensor_in_outputs(g, transpose.output, cgf.output)
        graph_utils.remove_subgraph(g, [transpose])
    elif transpose2 in m:
        cgf = m[cgf2]  # type: BaseOperation
        reshape = m[reshape2]  # type: BaseOperation
        transpose = m[transpose2]  # type: BaseOperation
        if not (len(reshape.output.consumers) <= 1
                and len(transpose.output.consumers) <= 1
                and cgf.output not in g.outputs):
            raise _TransformException("Cannot apply TF_FILTER_GRAD_TO_NNEF")
        cgf.output.name = transpose.output.name
        # Transforms are appended in this order: Transpose first, then Reshape.
        add_transform(
            transforms_by_name,
            cgf.output,
            Transpose(
                utils.inverse_permutation(
                    driver.get_axes_from_transpose(transpose))))
        add_transform(transforms_by_name, cgf.output, Reshape(cgf.output.shape))
        graph_utils.replace_tensor_in_outputs(g, transpose.output, cgf.output)
        graph_utils.remove_subgraph(g, [transpose, reshape])
    else:
        raise _TransformException("Cannot apply TF_FILTER_GRAD_TO_NNEF")