def test_create_cascader_graph(TwoConv2DWithSliceTE):
    """Walk the cascader graph built from the two-conv2d-with-slice network.

    Starting at the single graph output, this test steps backwards through
    three parts: an EthosuPart (conv2), an InlinePart (the slice) and another
    EthosuPart (conv1). At each step it checks every input tensor's shape,
    number of producers and whether it is constant.
    """
    _, graph_te, constants = TwoConv2DWithSliceTE
    config = cs.EthosuDeviceConfig("ethos-u55-256")
    graph = cs.create_cascader_graph(graph_te, constants, config)

    def expect_tensor(tensor, shape, producer_count, constant):
        # Same three checks the original made inline, in the same order:
        # shape, producer count, then constness (compared by truthiness).
        assert tensor.shape == shape
        assert len(tensor.producers) == producer_count
        assert bool(tensor.is_constant) == constant

    out = graph.output_tensors[0]
    expect_tensor(out, [1, 6, 1, 6, 16], 1, False)

    # Second convolution: feature map + two constant tensors.
    conv_b = out.producers[0]
    assert isinstance(conv_b, cs.EthosuPart)
    assert len(conv_b.input_tensors) == 3
    expect_tensor(conv_b.input_tensors[0], [1, 6, 6, 64], 1, False)
    expect_tensor(conv_b.input_tensors[1], [16, 3, 3, 64], 0, True)
    expect_tensor(conv_b.input_tensors[2], [16, 10], 0, True)

    # Slice between the two convolutions is an inline part with one input.
    slicer = conv_b.input_tensors[0].producers[0]
    assert isinstance(slicer, cs.InlinePart)
    assert len(slicer.input_tensors) == 1
    expect_tensor(slicer.input_tensors[0], [1, 12, 12, 64], 1, False)

    # First convolution: its feature-map input is a graph input (no producer).
    conv_a = slicer.input_tensors[0].producers[0]
    assert isinstance(conv_a, cs.EthosuPart)
    assert len(conv_a.input_tensors) == 3
    expect_tensor(conv_a.input_tensors[0], [1, 12, 12, 8], 0, False)
    expect_tensor(conv_a.input_tensors[1], [64, 1, 1, 8], 0, True)
    expect_tensor(conv_a.input_tensors[2], [64, 10], 0, True)
def test_create_diamond_graph(MobileNetv2DiamondTE):
    """Check the output end of the cascader graph built for the MobileNetV2 diamond.

    The graph's single output must be produced by one EthosuPart
    (``add1_part``, presumably the residual add joining the diamond). That part
    must have part id 0 and two non-constant 1x56x56x24 inputs, each with
    exactly one producer.
    """
    _, te_graph, const_dict = MobileNetv2DiamondTE
    device_config = cs.EthosuDeviceConfig("ethos-u55-256")
    graph = cs.create_cascader_graph(te_graph, const_dict, device_config)

    output_tensor = graph.output_tensors[0]
    assert output_tensor.shape == [1, 56, 56, 24]
    assert len(output_tensor.producers) == 1
    assert not output_tensor.is_constant

    add1_part = output_tensor.producers[0]
    assert isinstance(add1_part, cs.EthosuPart)
    assert len(add1_part.input_tensors) == 2
    assert graph.get_part_id(add1_part) == 0

    # Check both operands the same way. The original version re-tested
    # input_tensors[0] for the second operand, so input_tensors[1]'s
    # producer count and constness were never verified.
    for input_tensor in add1_part.input_tensors:
        assert input_tensor.shape == [1, 56, 56, 24]
        assert len(input_tensor.producers) == 1
        assert not input_tensor.is_constant
def TwoConv2DGraph():
    """Return the cascader graph for the two-conv2d TE network on an ethos-u55-256."""
    _func, graph_te, constants = make_TwoConv2DTE()
    config = cs.EthosuDeviceConfig("ethos-u55-256")
    graph = cs.create_cascader_graph(graph_te, constants, config)
    return graph
def MobileNetv1Graph():
    """Return the cascader graph for the MobileNetV1 TE network on an ethos-u55-256."""
    _func, graph_te, constants = make_MobileNetv1TE()
    config = cs.EthosuDeviceConfig("ethos-u55-256")
    graph = cs.create_cascader_graph(graph_te, constants, config)
    return graph