Example #1
0
def replace(graph,  # type: BaseGraph
            pattern,  # type: Pattern
            replacement,  # type: typing.Callable[[Match], typing.Any] # result is not used
            condition=None  # type: typing.Optional[typing.Callable[[Match], bool]]
            ):
    # type: (...)->None
    """Rewrite every accepted match of `pattern` in `graph`, in place.

    Walks a snapshot of ``graph.operations``. Each operation that is still
    attached to the graph is passed to ``match`` as the starting op. When the
    match succeeds and ``condition`` (if given) accepts it, ``replacement`` is
    called with the match. The matched operations are then removed with
    ``graph_utils.remove_subgraph``.

    Args:
        graph: graph to rewrite; modified in place.
        pattern: pattern handed to ``match`` together with each candidate op.
        replacement: callback invoked with each accepted match; its return
            value is ignored. It runs before the matched ops are removed.
        condition: optional predicate on a match; a match is rewritten only
            when it returns a truthy value. ``None`` accepts every match.
    """
    # Iterate over a copy: the loop body removes operations from the graph.
    snapshot = list(graph.operations)
    for root in snapshot:
        # An earlier rewrite in this loop may already have removed this op
        # (possible when the graph is not topologically sorted); removed ops
        # are recognized by their `graph` attribute being None.
        if root.graph is None:
            continue
        found = match(graph, root, pattern)
        if not found:
            continue
        if condition is not None and not condition(found):
            continue
        replacement(found)
        graph_utils.remove_subgraph(graph, found.operations)
Example #2
0
def _transform_nnef_filter_grad_to_tf(g, tensor, transforms_by_name, driver):
    # type: (BaseGraph, BaseTensor, typing.Dict[str, typing.List[Transform]], DataFormatOptimizationDriver)->None
    """Fold the layout ops after a conv-grad-filter op into recorded transforms.

    The producer of ``tensor`` is matched against two chains:

    * ``conv_grad_filter -> transpose``
    * ``conv_grad_filter -> reshape -> transpose``

    On a match, the conv-grad-filter output is renamed to the transpose
    output's name. A Transpose by the inverse permutation is registered for it
    in ``transforms_by_name``. In the reshape case, a Reshape back to the
    conv-grad-filter output shape is registered after it. The conv-grad-filter
    output then replaces the transpose output via
    ``graph_utils.replace_tensor_in_outputs``. Finally, the matched transpose
    (and reshape) ops are removed from ``g``.

    Args:
        g: graph to modify in place.
        tensor: tensor whose producer chain is inspected.
        transforms_by_name: transform lists, updated through ``add_transform``.
        driver: supplies the conv-grad-filter / transpose / reshape op names
            and extracts the permutation of a transpose op.

    Raises:
        _TransformException: if ``tensor`` has no producer, neither chain
            matches, an output of a to-be-removed op has more than one
            consumer, or the conv-grad-filter output is already a graph output.
    """
    # All failure paths report the transform this function implements.
    error_message = "Cannot apply NNEF_FILTER_GRAD_TO_TF"

    assert driver.conv_grad_filter_op_names

    # Pattern 1: conv_grad_filter -> transpose
    cgf1_output = matcher.Tensor()
    cgf1 = matcher.Operation(name=driver.conv_grad_filter_op_names,
                             outputs=cgf1_output)
    transpose1 = matcher.Operation(name=driver.transpose_op_name,
                                   inputs=cgf1_output)

    # Pattern 2: conv_grad_filter -> reshape -> transpose
    cgf2_output = matcher.Tensor()
    cgf2 = matcher.Operation(name=driver.conv_grad_filter_op_names,
                             outputs=cgf2_output)
    reshape2_output = matcher.Tensor()
    reshape2 = matcher.Operation(name=driver.reshape_op_name,
                                 inputs=cgf2_output,
                                 outputs=reshape2_output)
    transpose2 = matcher.Operation(name=driver.transpose_op_name,
                                   inputs=reshape2_output)

    if tensor.producer is None:
        raise _TransformException(error_message)

    m = matcher.match(g, tensor.producer,
                      matcher.OrPattern(transpose1, transpose2))

    # Only the choice of matched ops differs between the two patterns; the
    # rewrite below is shared.
    if transpose1 in m:
        cgf = m[cgf1]  # type: BaseOperation
        reshape = None  # type: typing.Optional[BaseOperation]
        transpose = m[transpose1]  # type: BaseOperation
    elif transpose2 in m:
        cgf = m[cgf2]  # type: BaseOperation
        reshape = m[reshape2]
        transpose = m[transpose2]  # type: BaseOperation
    else:
        raise _TransformException(error_message)

    # Ops deleted at the end, in the order passed to remove_subgraph.
    removed_ops = [transpose] if reshape is None else [transpose, reshape]

    # Each removed op's output may have at most one consumer, and the
    # conv-grad-filter output must not already be a graph output.
    # NOTE(review): presumably this guards against other users of the
    # tensors being removed or renamed — confirm.
    if not (all(len(op.output.consumers) <= 1 for op in removed_ops)
            and cgf.output not in g.outputs):
        raise _TransformException(error_message)

    # Rename before registering transforms: transforms_by_name is keyed by
    # name (see its type), so add_transform presumably uses the new name.
    cgf.output.name = transpose.output.name
    add_transform(
        transforms_by_name, cgf.output,
        Transpose(
            utils.inverse_permutation(
                driver.get_axes_from_transpose(transpose))))
    if reshape is not None:
        add_transform(transforms_by_name, cgf.output,
                      Reshape(cgf.output.shape))

    graph_utils.replace_tensor_in_outputs(g, transpose.output, cgf.output)
    graph_utils.remove_subgraph(g, removed_ops)