def run(input_path: str,
        dest_dir_path: str,
        project_name: str,
        activate_hard_quantization: bool,
        threshold_skipping: bool = False,
        debug: bool = False,
        cache_dma: bool = False):
    """Import a TensorFlow pb file, optimize its graph and generate code.

    Args:
        input_path: Path of the pb file, read with ``TensorFlowIO``.
        dest_dir_path: Output directory. It is created (via
            ``util.make_dirs``) if needed. The test-dir, optimized-pb and
            project paths handed to ``Config`` are placed inside it.
        project_name: Base name for the ``<name>.test``, ``<name>.pb`` and
            ``<name>.prj`` paths stored in ``Config``.
        activate_hard_quantization: Forwarded to ``Config`` unchanged.
        threshold_skipping: Forwarded to ``Config`` unchanged.
        debug: Forwarded to ``Config`` unchanged.
        cache_dma: Forwarded to ``Config`` unchanged.
    """
    # NOTE(review): these paths are derived from dest_dir_path exactly as the
    # caller passed it (possibly relative to the current working directory);
    # only the directory creation below uses the absolute form.
    output_dlk_test_dir = path.join(dest_dir_path, f'{project_name}.test')
    optimized_pb_path = path.join(dest_dir_path, f'{project_name}.pb')
    output_project_path = path.join(dest_dir_path, f'{project_name}.prj')

    config = Config(activate_hard_quantization=activate_hard_quantization,
                    threshold_skipping=threshold_skipping,
                    test_dir=output_dlk_test_dir,
                    optimized_pb_path=optimized_pb_path,
                    output_pj_path=output_project_path,
                    debug=debug,
                    cache_dma=cache_dma)

    util.make_dirs(path.abspath(dest_dir_path))

    click.echo('import pb file')
    # Named tf_io (not io) so the stdlib ``io`` module is not shadowed.
    tf_io = TensorFlowIO()
    graph: Graph = tf_io.read(input_path)

    click.echo('optimize graph step: start')
    optimize_graph_step(graph, config)
    click.echo('optimize graph step: done!')

    click.echo('generate code step: start')
    generate_code_step(graph, config)
    click.echo('generate code step: done!')
def test_import_group_convolution_classification(self) -> None:
    """Import the group-convolution classification fixture via TensorFlowIO.

    The resulting graph must expose exactly one output, shaped ``[1, 10]``.
    """
    fixture_parts = ('tests', 'fixtures', 'classification',
                     'lmnet_v1_group_conv', 'minimal_graph_with_shape.pb')
    pb_file = path.join(*fixture_parts)

    reader = TensorFlowIO()
    imported: Graph = reader.read(pb_file)
    graph_outputs = imported.get_outputs()

    self.assertEqual(len(graph_outputs), 1)
    first_output = graph_outputs[0]
    self.assertEqual(first_output.shape, [1, 10])

    print("TF file import test passed for group convolution!")
def test_import_classification(self) -> None:
    """Import the quantized CIFAR-10 classification fixture via TensorFlowIO.

    The resulting graph must expose exactly one output, shaped ``[1, 10]``.
    """
    fixture_parts = ('tests', 'fixtures', 'classification',
                     'lmnet_quantize_cifar10_stride_2.20180523.3x3',
                     'minimal_graph_with_shape.pb')
    pb_file = path.join(*fixture_parts)

    reader = TensorFlowIO()
    imported: Graph = reader.read(pb_file)
    graph_outputs = imported.get_outputs()

    self.assertEqual(len(graph_outputs), 1)
    first_output = graph_outputs[0]
    self.assertEqual(first_output.shape, [1, 10])

    print("TF file import test passed for classification!")
def test_import_object_detection_lack_shape_info(self) -> None:
    """Import an object-detection fixture via TensorFlowIO.

    Some operators in this pb file lack shape information. The import must
    still yield exactly one output, shaped ``[1, 10, 10, 125]``.
    """
    fixture_parts = ('tests', 'fixtures', 'object_detection',
                     'fyolo_quantize_17_v14_some_1x1_wide_pact_add_conv',
                     'minimal_graph_with_shape.pb')
    pb_file = path.join(*fixture_parts)

    reader = TensorFlowIO()
    imported: Graph = reader.read(pb_file)
    graph_outputs = imported.get_outputs()

    self.assertEqual(len(graph_outputs), 1)
    first_output = graph_outputs[0]
    self.assertEqual(first_output.shape, [1, 10, 10, 125])

    print("TF file import test passed for object detection!")