def replace(graph,  # type: BaseGraph
            pattern,  # type: Pattern
            replacement,  # type: typing.Callable[[Match], typing.Any]  # result is not used
            condition=None  # type: typing.Optional[typing.Callable[[Match], bool]]
            ):
    # type: (...)->None
    """Rewrite every occurrence of `pattern` in `graph` using `replacement`.

    Each operation present in the graph when the call starts is tried as a match
    root. Whenever `match(graph, op, pattern)` yields a truthy match, and
    `condition` is either absent or accepts that match:

    1. `replacement(match)` is called; its return value is ignored.
    2. The matched operations (`match.operations`) are removed from the graph
       with `graph_utils.remove_subgraph`.

    :param graph: graph to rewrite in place.
    :param pattern: pattern handed to `match` for every candidate root op.
    :param replacement: callback that builds the substitute for a match.
    :param condition: optional predicate; a match is replaced only if it
        returns a truthy value.
    """
    # Iterate over a snapshot, because replacements mutate graph.operations.
    snapshot = list(graph.operations)
    for candidate in snapshot:
        # An earlier replacement in this loop may already have removed the op.
        # This can happen when the graph is not topologically sorted.
        if candidate.graph is None:
            continue
        found = match(graph, candidate, pattern)
        if not found:
            continue
        if condition is not None and not condition(found):
            continue
        replacement(found)
        graph_utils.remove_subgraph(graph, found.operations)
def _transform_nnef_filter_grad_to_tf(g, tensor, transforms_by_name, driver):
    # type: (BaseGraph, BaseTensor, typing.Dict[str, typing.List[Transform]], DataFormatOptimizationDriver)->None
    """Fold the transpose (and optional reshape) after a filter-gradient op into recorded transforms.

    Starting from `tensor.producer`, one of two patterns is matched:

    1. conv_grad_filter -> transpose
    2. conv_grad_filter -> reshape -> transpose

    On a successful match:

    - The conv_grad_filter output takes over the transpose output's name.
    - A `Transpose` transform with the inverse of the transpose's axes is
      recorded for it via `add_transform`. For pattern 2, a
      `Reshape(cgf.output.shape)` transform is recorded afterwards.
    - It replaces the transpose output in `g`'s outputs.
    - The transpose (and reshape) ops are removed from `g`.

    :param g: graph rewritten in place.
    :param tensor: tensor whose producer is expected to be the trailing transpose.
    :param transforms_by_name: transform registry updated through `add_transform`.
    :param driver: supplies the op names to match and `get_axes_from_transpose`.
    :raises _TransformException: if `tensor` has no producer, if neither
        pattern matches, if an output of a removed op has more than one
        consumer, or if the conv_grad_filter output is already a graph output.
    """
    # Name the transform this function actually implements. The previous text,
    # "TF_FILTER_GRAD_TO_NNEF", named the opposite direction.
    error_message = "Cannot apply NNEF_FILTER_GRAD_TO_TF"

    assert driver.conv_grad_filter_op_names

    # Pattern 1: conv_grad_filter -> transpose
    cgf1_output = matcher.Tensor()
    cgf1 = matcher.Operation(name=driver.conv_grad_filter_op_names, outputs=cgf1_output)
    transpose1 = matcher.Operation(name=driver.transpose_op_name, inputs=cgf1_output)

    # Pattern 2: conv_grad_filter -> reshape -> transpose
    cgf2_output = matcher.Tensor()
    cgf2 = matcher.Operation(name=driver.conv_grad_filter_op_names, outputs=cgf2_output)
    reshape2_output = matcher.Tensor()
    reshape2 = matcher.Operation(name=driver.reshape_op_name, inputs=cgf2_output, outputs=reshape2_output)
    transpose2 = matcher.Operation(name=driver.transpose_op_name, inputs=reshape2_output)

    if tensor.producer is None:
        raise _TransformException(error_message)

    m = matcher.match(g, tensor.producer, matcher.OrPattern(transpose1, transpose2))

    if transpose1 in m:
        cgf = m[cgf1]  # type: BaseOperation
        transpose = m[transpose1]  # type: BaseOperation
        reshape = None  # type: typing.Optional[BaseOperation]
    elif transpose2 in m:
        cgf = m[cgf2]
        reshape = m[reshape2]
        transpose = m[transpose2]
    else:
        raise _TransformException(error_message)

    # Removal order matches the original: transpose first, then reshape.
    ops_to_remove = [transpose] if reshape is None else [transpose, reshape]

    # The ops listed above are deleted below, so none of their outputs may be
    # shared by several consumers. cgf.output must not already be a graph
    # output, because it takes over the transpose output's name and slot.
    if any(len(op.output.consumers) > 1 for op in ops_to_remove) or cgf.output in g.outputs:
        raise _TransformException(error_message)

    # Rename before registering transforms. add_transform presumably keys
    # transforms_by_name on the tensor's name (TODO confirm).
    cgf.output.name = transpose.output.name
    add_transform(transforms_by_name,
                  cgf.output,
                  Transpose(utils.inverse_permutation(driver.get_axes_from_transpose(transpose))))
    if reshape is not None:
        add_transform(transforms_by_name, cgf.output, Reshape(cgf.output.shape))
    graph_utils.replace_tensor_in_outputs(g, transpose.output, cgf.output)
    graph_utils.remove_subgraph(g, ops_to_remove)