コード例 #1
0
ファイル: test_graph.py プロジェクト: wenxcs/tvm
def test_create_cascader_graph(TwoConv2DWithSliceTE):
    """Verify the cascader graph built for the conv2d -> slice -> conv2d TE fixture.

    Starting from the graph's only output tensor, walk back through each
    producing part. For every tensor, check its shape, number of producers
    and constness. For every part, check its type and number of input tensors.
    """

    def expect_tensor(tensor, shape, producer_count, constant):
        # Checked in the order: shape, producer count, constness.
        assert tensor.shape == shape
        assert len(tensor.producers) == producer_count
        assert bool(tensor.is_constant) is constant

    def expect_part(part, part_cls, input_count):
        assert isinstance(part, part_cls)
        assert len(part.input_tensors) == input_count

    _, te_graph, const_dict = TwoConv2DWithSliceTE
    graph = cs.create_cascader_graph(
        te_graph, const_dict, cs.EthosuDeviceConfig("ethos-u55-256")
    )

    # Graph output -> the second convolution part.
    out = graph.output_tensors[0]
    expect_tensor(out, [1, 6, 1, 6, 16], 1, False)
    second_conv = out.producers[0]
    expect_part(second_conv, cs.EthosuPart, 3)
    expect_tensor(second_conv.input_tensors[0], [1, 6, 6, 64], 1, False)
    expect_tensor(second_conv.input_tensors[1], [16, 3, 3, 64], 0, True)
    expect_tensor(second_conv.input_tensors[2], [16, 10], 0, True)

    # First input of the second convolution -> the slice (inline) part.
    slice_part = second_conv.input_tensors[0].producers[0]
    expect_part(slice_part, cs.InlinePart, 1)
    expect_tensor(slice_part.input_tensors[0], [1, 12, 12, 64], 1, False)

    # Input of the slice -> the first convolution part. Its first input has
    # no producer and is non-constant; the other two inputs are constants.
    first_conv = slice_part.input_tensors[0].producers[0]
    expect_part(first_conv, cs.EthosuPart, 3)
    expect_tensor(first_conv.input_tensors[0], [1, 12, 12, 8], 0, False)
    expect_tensor(first_conv.input_tensors[1], [64, 1, 1, 8], 0, True)
    expect_tensor(first_conv.input_tensors[2], [64, 10], 0, True)
コード例 #2
0
ファイル: test_graph.py プロジェクト: wenxcs/tvm
def test_create_diamond_graph(MobileNetv2DiamondTE):
    """Check the output tensor and final part of the MobileNetv2 diamond cascader graph.

    Builds the cascader graph with an ethos-u55-256 device config, then checks:
    - the single output tensor;
    - its producing part (an EthosuPart with two inputs and part id 0);
    - both of that part's input tensors, each addressed by its own index.
    """
    _, te_graph, const_dict = MobileNetv2DiamondTE
    device_config = cs.EthosuDeviceConfig("ethos-u55-256")
    graph = cs.create_cascader_graph(te_graph, const_dict, device_config)

    output_tensor = graph.output_tensors[0]
    assert output_tensor.shape == [1, 56, 56, 24]
    assert len(output_tensor.producers) == 1
    assert not output_tensor.is_constant

    add1_part = output_tensor.producers[0]
    assert isinstance(add1_part, cs.EthosuPart)
    assert len(add1_part.input_tensors) == 2
    assert graph.get_part_id(add1_part) == 0

    # Bind each operand to its own name so every check targets the intended
    # tensor (the second operand must not be validated via index 0).
    first_input = add1_part.input_tensors[0]
    assert first_input.shape == [1, 56, 56, 24]
    assert len(first_input.producers) == 1
    assert not first_input.is_constant

    second_input = add1_part.input_tensors[1]
    assert second_input.shape == [1, 56, 56, 24]
    assert len(second_input.producers) == 1
    assert not second_input.is_constant
コード例 #3
0
ファイル: conftest.py プロジェクト: zjppoet/tvm
 def TwoConv2DGraph():
     """Return the cascader graph for make_TwoConv2DTE() with an ethos-u55-256 config.

     The first element returned by make_TwoConv2DTE() is not used.
     """
     _ignored, graph_te, constants = make_TwoConv2DTE()
     return cs.create_cascader_graph(
         graph_te, constants, cs.EthosuDeviceConfig("ethos-u55-256")
     )
コード例 #4
0
 def MobileNetv1Graph():
     """Return the cascader graph for make_MobileNetv1TE() with an ethos-u55-256 config.

     The first element returned by make_MobileNetv1TE() is not used.
     """
     _ignored, graph_te, constants = make_MobileNetv1TE()
     return cs.create_cascader_graph(
         graph_te, constants, cs.EthosuDeviceConfig("ethos-u55-256")
     )