Exemplo n.º 1
0
def test_ignore_internal_bound_flag():
    """Check how the ignore_internal_*_bound flags change listup_nodes output.

    Graph under test (inputs = [n1, n3], outputs = [n4, n6])::

            +-2-3-4-5-+
          1-+         +-6-
            +---------+

    For every flag combination, the returned node set must match exactly and
    the listed (before, after) pairs must appear in that relative order.
    """
    n1, n2, n3, n4, n5, n6 = (Node(f"n{k}") for k in range(1, 7))

    # Chain n1 -> n2 -> n3 -> n4 -> n5 -> n6 plus the shortcut edge n1 -> n6.
    edges = ((n2, n1), (n3, n2), (n4, n3), (n5, n4), (n6, n1), (n6, n5))
    for child, parent in edges:
        child.append_prev(parent)

    # noinspection PyTypeChecker
    graph = Graph([n1, n3], [n4, n6])

    def check(input_flag, output_flag, expected, orderings):
        # Run listup_nodes with the given flags and verify membership + order.
        result = listup_nodes(graph,
                              ignore_internal_input_bound=input_flag,
                              ignore_internal_output_bound=output_flag)
        assert len(result) == len(expected) and set(result) == expected
        for before, after in orderings:
            assert result.index(before) < result.index(after)

    check(False, False,
          {n1, n3, n4, n6},
          [(n1, n6), (n3, n4)])

    check(False, True,
          {n1, n3, n4, n5, n6},
          [(n1, n6), (n3, n4), (n4, n5), (n5, n6)])

    check(True, False,
          {n1, n2, n3, n4, n6},
          [(n1, n2), (n2, n3), (n3, n4), (n1, n6)])

    check(True, True,
          {n1, n2, n3, n4, n5, n6},
          [(n1, n2), (n2, n3), (n3, n4), (n4, n5), (n5, n6), (n1, n6)])
Exemplo n.º 2
0
    def generate(cls, graph: Graph, **kwargs):
        """Optimize ``graph`` for WebGL and package it as execution data.

        Args:
            graph: Graph to compile. It is replaced by the graph returned from
                ``WebGLOptimizeRule().optimize``.
            **kwargs: Only ``constant_encoder_name`` is read. It is passed to
                ``ConstantEncoder.get_encoder`` (``None`` when absent).

        Returns:
            ``GraphExecutionData`` holding the optimized graph, its
            ``GraphDescriptor`` and the encoded constant bytes.
        """
        optimized, _ = WebGLOptimizeRule().optimize(graph)

        if flags.DEBUG:
            # Debug aid: print the graph and write a DOT rendering to cg.dot.
            traverse.dump(optimized)
            with open("cg.dot", "w") as dot_file:
                dot_file.write(traverse.dump_dot(optimized))

        layout = allocate(optimized)

        constants_map = {
            constant.name: {
                # NOTE(review): offset presumably counts 4-byte elements; TODO confirm.
                "byte_offset": layout[constant].offset * 4,
                "size": constant.size,
            }
            for constant in traverse.filter_nodes(traverse.listup_nodes(optimized), ConstantVariable)
        }

        encoder = ConstantEncoder.get_encoder(kwargs.get("constant_encoder_name", None))
        encoded_constants = encoder.encode(layout)

        kernels = cls.generate_kernels(optimized)

        descriptor = GraphDescriptor(
            kernels=kernels,
            memory_layout=layout,
            inputs=optimized.inputs,
            outputs=optimized.outputs,
            constants_encoding=encoder.name,
            constants_map=constants_map,
            licenses=optimized.licenses
        )
        return GraphExecutionData(optimized, descriptor, encoded_constants)
Exemplo n.º 3
0
def allocate(graph: Graph) -> WebGLMemoryLayout:
    """Compute the WebGL memory layout for every variable listed in ``graph``.

    Steps: give each unnamed variable a generated name, reject constants whose
    size is not fully resolved, compute allocations, apply buffer reuse to all
    of them, then lay out the constant data and combine everything into a
    ``WebGLMemoryLayout``.

    Args:
        graph: Graph whose nodes are collected with ``traverse.listup_nodes``.

    Returns:
        ``WebGLMemoryLayout`` over the allocations of all variables
        (non-constant entries first, then constants) together with the
        constant data returned by ``_update_constant_offset``.

    Raises:
        AssertionError: If a ``ConstantVariable`` has a size for which
            ``Placeholder.check_resolved`` is false (only while assertions
            are enabled, i.e. not under ``python -O``).
    """
    nodes = traverse.listup_nodes(graph)
    operators = traverse.filter_nodes(nodes, Operator)  # type: List[Operator]
    variables = traverse.filter_nodes(nodes, Variable)  # type: List[Variable]

    # Assign generated names to variables that do not have one yet.
    for v in variables:
        if v.name is None:
            v.name = _name("v")

    # Constants whose size still depends on an unresolved placeholder cannot be allocated here.
    dynamic_constants = traverse.filter_nodes([v for v in variables if not Placeholder.check_resolved(v.size)], ConstantVariable)
    # Fixed: the message previously contained a stray literal "f" before the interpolated list.
    assert len(dynamic_constants) == 0, f"ConstantVariable with unresolved placeholder shape is detected: {dynamic_constants}"

    allocations = _get_allocations(graph, operators, variables)
    _optimize_buffer_reuse(allocations)

    variable_allocations = {v: allocations[v] for v in variables if not isinstance(v, ConstantVariable)}
    constant_allocations = {v: allocations[v] for v in variables if isinstance(v, ConstantVariable)}

    data = _update_constant_offset(constant_allocations)

    # Recombine: non-constant variables first, then constants.
    allocations = variable_allocations
    allocations.update(constant_allocations)

    return WebGLMemoryLayout(allocations, data)
Exemplo n.º 4
0
    def generate(cls, graph: Graph, **kwargs):
        """Compile ``graph`` once for each supported WebGL max texture size.

        For each size in (4096, 8192, 16384), this method:

        * sets ``config.WEBGL_MAX_TEXTURE_SIZE``;
        * optimizes the graph with ``WebGLOptimizeRule``;
        * allocates memory;
        * encodes constants;
        * generates kernels.

        The resulting descriptor and constant bytes are stored under that size.

        Args:
            graph: Graph to compile.
            **kwargs: Only ``constant_encoder_name`` is read. It is passed to
                ``ConstantEncoder.get_encoder`` (``None`` when absent).

        Returns:
            ``GraphExecutionData`` built from the graph as optimized in the
            last pass, plus a dict mapping each texture size to its
            ``(GraphDescriptor, bytes)`` pair.
        """
        data_dict = {}  # type: Dict[int, Tuple[GraphDescriptor, bytes]]
        encoder_name = kwargs.get("constant_encoder_name", None)

        for texture_size in (4096, 8192, 16384):
            # NOTE(review): config.WEBGL_MAX_TEXTURE_SIZE is left at the last
            # size after this method returns. Also, each pass re-optimizes the
            # graph produced by the previous pass rather than the caller's
            # original graph. Confirm both are intended.
            config.WEBGL_MAX_TEXTURE_SIZE = texture_size
            graph, _ = WebGLOptimizeRule().optimize(graph)

            layout = allocate(graph)

            constants_map = {
                constant.name: {
                    # NOTE(review): offset presumably counts 4-byte elements; TODO confirm.
                    "byte_offset": layout[constant].offset * 4,
                    "size": constant.size,
                }
                for constant in traverse.filter_nodes(traverse.listup_nodes(graph), ConstantVariable)
            }

            encoder = ConstantEncoder.get_encoder(encoder_name)
            encoded_constants = encoder.encode(layout)

            kernels = cls.generate_kernels(graph)

            data_dict[texture_size] = (
                GraphDescriptor(
                    kernels=kernels,
                    memory_layout=layout,
                    inputs=graph.inputs,
                    outputs=graph.outputs,
                    constants_encoding=encoder.name,
                    constants_map=constants_map,
                    licenses=graph.licenses
                ),
                encoded_constants,
            )

        return GraphExecutionData(graph, data_dict)
Exemplo n.º 5
0
def allocate(graph: Graph) -> MemoryLayout:
    """Compute the memory layout for every variable listed in ``graph``.

    Steps:

    1. Name unnamed variables.
    2. Reject constants with unresolved sizes.
    3. Compute allocations and apply in-place optimization.
    4. Assign offsets to non-constant variables and apply buffer reuse to them.
    5. Lay out the constant data and shift all non-constant allocations past it.

    Args:
        graph: Graph whose nodes are collected with ``traverse.listup_nodes``.

    Returns:
        ``MemoryLayout`` over the allocations of all variables (non-constant
        entries first, then constants) together with the constant data
        returned by ``_update_constant_offset``.

    Raises:
        AssertionError: If a ``ConstantVariable`` has a size for which
            ``Placeholder.check_resolved`` is false (only while assertions
            are enabled, i.e. not under ``python -O``).
    """
    nodes = traverse.listup_nodes(graph)
    operators = traverse.filter_nodes(nodes, Operator)  # type: List[Operator]
    variables = traverse.filter_nodes(nodes, Variable)  # type: List[Variable]

    # Assign generated names to variables that do not have one yet.
    for v in variables:
        if v.name is None:
            v.name = _name("v")

    # Constants whose size still depends on an unresolved placeholder cannot be allocated here.
    dynamic_constants = traverse.filter_nodes([v for v in variables if not Placeholder.check_resolved(v.size)], ConstantVariable)
    # Fixed: the message previously contained a stray literal "f" before the interpolated list.
    assert len(dynamic_constants) == 0, f"ConstantVariable with unresolved placeholder shape is detected: {dynamic_constants}"

    allocations = _get_allocations(graph, operators, variables)
    _optimize_inplace(operators, allocations)

    variable_allocations = {v: allocations[v] for v in variables if not isinstance(v, ConstantVariable)}
    constant_allocations = {v: allocations[v] for v in variables if isinstance(v, ConstantVariable)}

    # Offsets and buffer reuse are computed for non-constant variables only.
    _update_offset(variable_allocations)
    _optimize_buffer_reuse(variable_allocations)

    data = _update_constant_offset(constant_allocations)

    # Shift non-constant allocations past the constant data. set() makes sure an
    # allocation object shared by several variables is shifted only once.
    for allocation in set(variable_allocations.values()):
        allocation.offset += data.size

    # Recombine: non-constant variables first, then constants.
    allocations = variable_allocations
    allocations.update(constant_allocations)

    layout = MemoryLayout(allocations, data)

    if flags.VISUALIZE_MEMORY_ALLOCATION:
        _visualize_allocation(operators, variables, layout)

    return layout
Exemplo n.º 6
0
    def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
        """Attach channel-mode attributes to every node that lacks one.

        Nodes that already carry ``ChannelMode`` are skipped, as are the
        converter operators ``ConvertRtoRGBA`` / ``ConvertRGBAtoR``.

        Every other node gets ``ChannelMode(R)``. Operators also get
        ``SupportedChannelMode(R)``. They additionally get
        ``SupportedChannelMode(RGBA)`` when both hold:

        * their class is listed in ``_rgba_support_operators``;
        * all of their input and output variables share the same order and shape.

        Returns:
            The same ``graph`` object, and ``True`` if any node was annotated.
        """
        modified = False
        for node in traverse.listup_nodes(graph):
            if node.has_attribute(ChannelMode) or isinstance(node, (ConvertRtoRGBA, ConvertRGBAtoR)):
                continue

            modified = True
            node.attributes.add(ChannelMode(node, ChannelModeEnum.R))

            if not isinstance(node, Operator):
                continue

            node.attributes.add(SupportedChannelMode(node, ChannelModeEnum.R))

            if node.__class__ not in _rgba_support_operators:
                continue

            # Every connected variable must match the first one in order and in
            # shape; with no connected variables both checks pass vacuously.
            connected = [*node.inputs.values(), *node.outputs.values()]
            reference = connected[0] if connected else None
            if all(v.order == reference.order for v in connected) and \
                    all(v.shape == reference.shape for v in connected):
                node.attributes.add(SupportedChannelMode(node, ChannelModeEnum.RGBA))

        return graph, modified
Exemplo n.º 7
0
    def optimize(self, graph: Graph):
        """Split every node selected by filtering the graph on ``SplitTarget``.

        For each selected node (expected to be a ``Variable``), the axis is
        chosen by ``_choose_split_axis`` and the split is applied by
        ``_split_axis``.

        Returns:
            The same ``graph`` object, and ``True`` if at least one node was split.
        """
        # The target list is materialized before any split mutates the graph.
        targets = traverse.filter_nodes(traverse.listup_nodes(graph), SplitTarget)

        changed = False
        for target in targets:
            _split_axis(target, _choose_split_axis(target), graph)
            changed = True

        return graph, changed
Exemplo n.º 8
0
def test_listup_nodes_hidden_output():
    """listup_nodes keeps chain order when an intermediate variable is also an output.

    Chain: v0 -> op1 -> v1 -> op2 -> v2, with graph outputs [v1, v2].
    """
    v0 = Variable((1, 1), OrderNC)
    op1 = Operator("op1")
    v1 = Variable((1, 2), OrderNC)
    op2 = TestOperator("op2")
    v2 = Variable((1, 3), OrderNC)

    wiring = (
        (op1, ("v0", v0), ("v1", v1)),
        (op2, ("v1", v1), ("v2", v2)),
    )
    for op, (in_name, in_var), (out_name, out_var) in wiring:
        op.append_input(in_name, in_var)
        op.append_output(out_name, out_var)

    # v1 is exposed as a graph output even though op2 also consumes it.
    graph = Graph([v0], [v1, v2])

    nodes = listup_nodes(graph)
    expected = (v0, op1, v1, op2, v2)
    assert tuple(nodes) == expected, str(nodes)
    def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
        """Remove scalar operators whose parameters match their identity values.

        Does nothing unless both ``flags.optimize.OPTIMIZE`` and
        ``flags.optimize.CONCAT_SCALAR_OPERATION`` are set. Otherwise each of
        the following is passed to ``remove_operator``:

        * ``ScalarAffine`` with ``scale == 1`` and ``bias == 0``
        * ``ScalarAdd`` with ``value == 0``
        * ``ScalarMul`` with ``value == 1``
        * ``ScalarPow`` with ``value == 1``

        Returns:
            The same ``graph`` object, and ``True`` if any operator was removed.
        """
        if not flags.optimize.OPTIMIZE or not flags.optimize.CONCAT_SCALAR_OPERATION:
            return graph, False

        # (operator class, identity predicate), processed in this order.
        identity_rules = (
            (ScalarAffine, lambda op: op.scale == 1 and op.bias == 0),
            (ScalarAdd, lambda op: op.value == 0),
            (ScalarMul, lambda op: op.value == 1),
            (ScalarPow, lambda op: op.value == 1),
        )

        # The node list is collected once and reused for every operator class.
        nodes = traverse.listup_nodes(graph)

        changed = False
        for op_class, is_identity in identity_rules:
            # Candidates are visited from the back of the filtered list.
            for op in reversed(traverse.filter_nodes(nodes, op_class)):
                if is_identity(op):
                    remove_operator(op)
                    changed = True

        return graph, changed
Exemplo n.º 10
0
def test_listup_nodes_residual():
    """listup_nodes returns the module-level fixture graph's nodes in chain order.

    ``graph`` and the node objects are module globals built outside this view.
    """
    global graph, op1, op2, op3, v0, v1, v2, v3
    nodes = listup_nodes(graph)

    expected_order = (v0, op1, v1, op2, v2, op3, v3)
    assert tuple(nodes) == expected_order