def test_compile_with_no_input_formats(self):
        # Saves the model built by _make_onnx_model() to a temporary .onnx
        # file, compiles it with a config created from an empty JSON object,
        # and expects the compiled model to report [None, None] as its input
        # data formats.
        with NamedTemporaryFile(suffix='.onnx') as onnx_file:
            model_path = onnx_file.name
            onnx.save_model(_make_onnx_model().model_proto, model_path)

            empty_config = Config.from_json({})
            model_source = ONNXModelFile(model_path)
            result = onnx_compiler.compile_source(source=model_source,
                                                  config=empty_config)

        self.assertEqual(result.input_data_formats, [None, None])
    def test_with_data_formats(self):
        # Two-stage compile: the ONNX file is first compiled with explicit
        # input formats, then the resulting model is passed to the
        # ONNX-to-TFLite compiler. The final result must carry a bytes
        # tflite_model and the same two formats as DataFormat members.
        with NamedTemporaryFile(suffix='.onnx') as onnx_file:
            model_path = onnx_file.name
            onnx.save_model(_make_onnx_model().model_proto, model_path)

            format_config = Config.from_json(
                {'input_formats': ['channels_first', 'channels_last']})
            model_source = ONNXModelFile(model_path)
            intermediate_model = onnx_compiler.compile_source(
                source=model_source, config=format_config)

        # The second stage runs after the temporary file has been closed.
        tflite_result = onnx_to_tflite_compiler.compile_source(
            source=intermediate_model)

        self.assertIsInstance(tflite_result.tflite_model, bytes)

        actual_formats = tflite_result.input_formats
        expected_formats = [
            DataFormat.CHANNELS_FIRST,
            DataFormat.CHANNELS_LAST,
        ]
        self.assertEqual(actual_formats, expected_formats)
    def test_compile_with_input_formats_is_none(self):
        # Compiling with an empty 'input_formats' list must raise ValueError
        # whose args report the mismatch between 0 formats and 2 inputs.
        # NOTE(review): despite the method name, the config supplies an empty
        # list here, not None.
        with NamedTemporaryFile(suffix='.onnx') as onnx_file:
            model_path = onnx_file.name
            onnx.save_model(_make_onnx_model().model_proto, model_path)

            empty_formats_config = Config.from_json({'input_formats': []})

            with self.assertRaises(ValueError) as raised:
                # The source object is built inside the assertRaises scope,
                # exactly where the compile call happens.
                model_source = ONNXModelFile(model_path)
                onnx_compiler.compile_source(source=model_source,
                                             config=empty_formats_config)

        expected_args = (
            'Number of input formats (0) does not match number of inputs (2)',
        )
        self.assertEqual(raised.exception.args, expected_args)
    def test_compile_with_variables(self):
        # Compiles the saved model with a partially specified format list
        # ('channels_first' for the first input, None for the second) and
        # checks both the graph input names and the resulting data formats.
        with NamedTemporaryFile(suffix='.onnx') as onnx_file:
            model_path = onnx_file.name
            onnx.save_model(_make_onnx_model().model_proto, model_path)

            partial_config = Config.from_json(
                {'input_formats': ['channels_first', None]})
            model_source = ONNXModelFile(model_path)
            result = onnx_compiler.compile_source(source=model_source,
                                                  config=partial_config)

        input_names = [entry.name for entry in result.model_proto.graph.input]
        self.assertEqual(input_names, ['x:0', 'y:0'])

        actual_formats = result.input_data_formats
        self.assertEqual(actual_formats, [DataFormat.CHANNELS_FIRST, None])
 def test_from_env(self):
     # Config.from_env with INPUT_FORMATS='channels_last' must equal a Config
     # constructed directly with a single CHANNELS_LAST input format.
     environment = {'INPUT_FORMATS': 'channels_last'}
     parsed = Config.from_env(environment)
     expected = Config(input_formats=[DataFormat.CHANNELS_LAST])
     self.assertEqual(parsed, expected)
 def test_from_json(self):
     # Config.from_json with input_formats=['channels_first'] must equal a
     # Config constructed directly with a single CHANNELS_FIRST input format.
     payload = {'input_formats': ['channels_first']}
     parsed = Config.from_json(payload)
     expected = Config(input_formats=[DataFormat.CHANNELS_FIRST])
     self.assertEqual(parsed, expected)