def test_from_env(self):
    """Config.from_env splits comma-separated names and maps the input format string to a DataFormat."""
    environment = {
        'INPUT_NAMES': 'input1,input2:0',
        'OUTPUT_NAMES': 'output',
        'INPUT_FORMATS': 'channels_last',
    }
    actual = Config.from_env(environment)
    expected = Config(
        input_names=['input1', 'input2:0'],
        data_formats=[DataFormat.CHANNELS_LAST],
        output_names=['output'],
    )
    self.assertEqual(actual, expected)
 def test_from_json(self):
     """Config.from_json reads name lists directly and converts 'input_formats' strings to DataFormat members."""
     json_value = {
         'input_names': ['input'],
         'output_names': ['output'],
         'input_formats': ['channels_first'],
     }
     actual = Config.from_json(json_value)
     expected = Config(
         input_names=['input'],
         data_formats=[DataFormat.CHANNELS_FIRST],
         output_names=['output'],
     )
     self.assertEqual(actual, expected)
    def test_compile_with_no_output_name(self):
        """Without 'output_names' in the config, the compiled model still reports 'z:0' as its only output."""
        with NamedTemporaryFile(suffix='.pb') as model_file:
            _save_frozen_graph_model(model_file)

            # Note: no 'output_names' key — the compiler must work them out itself.
            config = Config.from_json({
                'input_names': ['x:0', 'y:0'],
                'input_formats': ['channels_first', 'channels_last'],
            })
            source = FrozenGraphFile(model_path=model_file.name)
            compiled = compiler.compile_source(source, config)

        output_names = [output.name for output in compiled.outputs]
        self.assertEqual(output_names, ['z:0'])
    def test_compile_with_no_input_formats(self):
        """Without 'input_formats' in the config, each compiled input has a data_format of None."""
        with NamedTemporaryFile(suffix='.pb') as model_file:
            _save_frozen_graph_model(model_file)

            # Note: no 'input_formats' key in this config.
            config = Config.from_json({
                'input_names': ['x:0'],
                'output_names': ['z:0'],
            })
            source = FrozenGraphFile(model_path=model_file.name)
            compiled = compiler.compile_source(source, config)

        data_formats = [model_input.data_format for model_input in compiled.inputs]
        self.assertEqual(data_formats, [None])
    def test_compile_with_variables(self):
        """Compile a frozen-graph file with two inputs and check the compiled model.

        Verifies that the compiled model exposes a ``tf.compat.v1.Session``,
        that the inputs keep the configured order and data formats, and that
        the single output is ``'z:0'``.
        """
        with NamedTemporaryFile(suffix='.pb') as model_file:
            _save_frozen_graph_model(model_file)
            config = Config.from_json({
                'input_names': ['x:0', 'y:0'],
                'output_names': ['z:0'],
                'input_formats': ['channels_first', 'channels_last']
            })
            compiled = compiler.compile_source(
                FrozenGraphFile(model_path=model_file.name), config)

        self.assertIsInstance(compiled.session, tf.compat.v1.Session)
        # A TF1 session owns graph/device resources until it is closed; release
        # it when the test finishes rather than leaking it for the whole run.
        self.addCleanup(compiled.session.close)

        self.assertEqual(
            [model_input.tensor.name for model_input in compiled.inputs],
            ['x:0', 'y:0'])
        self.assertEqual(
            [model_input.data_format for model_input in compiled.inputs],
            [DataFormat.CHANNELS_FIRST, DataFormat.CHANNELS_LAST])

        self.assertEqual(
            [model_output.name for model_output in compiled.outputs], ['z:0'])
 def test_from_env_no_names(self):
     """An empty environment mapping yields a Config with no names and an empty data-format list."""
     actual = Config.from_env({})
     expected = Config(input_names=None, data_formats=[], output_names=None)
     self.assertEqual(actual, expected)