def test_node_factory_validate_missing_arguments():
    """Creating a TopK node with ``None`` as its arguments must raise UserInputError."""
    topk_attrs = {"axis": 1, "mode": "max", "sort": "value"}
    node_factory = NodeFactory("opset1")

    raised = False
    try:
        node_factory.create("TopK", None, topk_attrs)
    except UserInputError:
        raised = True
    # Any other exception type propagates unchanged; only "no exception at all" is a failure here.
    if not raised:
        raise AssertionError("Validation of missing arguments has unexpectedly passed.")
def test_node_factory_empty_topk_with_args_and_attrs():
    """Build an empty TopK node, then attach inputs and attributes before validating it."""
    element_type = np.int32
    input_data = ng.parameter([2, 10], dtype=element_type, name="A")
    k_value = ng.constant(3, dtype=element_type, name="B")
    node_factory = NodeFactory("opset1")
    outputs = NodeFactory._arguments_as_outputs([input_data, k_value])

    topk = node_factory.create("TopK", None, None)
    topk.set_arguments(outputs)
    # Attributes are set one at a time, in dict insertion order: axis, mode, sort.
    for attr_name, attr_value in {"axis": 1, "mode": "max", "sort": "value"}.items():
        topk.set_attribute(attr_name, attr_value)
    topk.validate()

    assert topk.get_type_name() == "TopK"
    assert topk.get_output_size() == 2
    # k=3 along axis 1 shrinks the [2, 10] input to [2, 3].
    assert list(topk.get_output_shape(0)) == [2, 3]
def test_node_factory_add():
    """Create an Add node via ``_NodeFactory`` with explicit inputs and an empty attribute dict."""
    dims = [2, 2]
    element_type = np.int8
    lhs = ng.parameter(dims, dtype=element_type, name="A")
    rhs = ng.parameter(dims, dtype=element_type, name="B")

    raw_factory = _NodeFactory("opset1")
    add_node = raw_factory.create(
        "Add", NodeFactory._arguments_as_outputs([lhs, rhs]), {}
    )

    assert add_node.get_type_name() == "Add"
    assert add_node.get_output_size() == 1
    assert list(add_node.get_output_shape(0)) == [2, 2]
# Example #4
# 0
def test_node_factory_topk():
    """Create a TopK node via ``_NodeFactory`` with inputs and attributes in a single call."""
    element_type = np.int32
    input_data = ng.parameter([2, 10], dtype=element_type, name="A")
    k_value = ng.constant(3, dtype=element_type, name="B")
    raw_factory = _NodeFactory("opset1")
    inputs = NodeFactory._arguments_as_outputs([input_data, k_value])
    topk_attributes = {"axis": 1, "mode": "max", "sort": "value"}

    topk = raw_factory.create("TopK", inputs, topk_attributes)

    assert topk.get_type_name() == "TopK"
    assert topk.get_output_size() == 2
    # k=3 along axis 1 shrinks the [2, 10] input to [2, 3].
    assert list(topk.get_output_shape(0)) == [2, 3]
def test_node_factory_empty_topk():
    """A TopK node can be created by type name alone, with no inputs or attributes."""
    topk = NodeFactory("opset1").create("TopK")

    assert topk.get_type_name() == "TopK"
def _get_node_factory(opset_version: Optional[str] = None) -> NodeFactory:
    """Build a NodeFactory for the requested opset version.

    :param opset_version: Name of the opset to create operators from. Any falsy
        value (``None`` or ``""``) falls back to calling ``NodeFactory()`` with
        no argument.
    :return: The newly constructed NodeFactory instance.
    """
    return NodeFactory(opset_version) if opset_version else NodeFactory()