def test_from_json(self):
    """``Config.from_json`` pairs input_names with input_formats into input_info.

    The ``'channels_first'`` format string becomes ``DataFormat.CHANNELS_FIRST``.
    """
    actual = Config.from_json({
        'input_names': ['input'],
        'output_names': ['output'],
        'input_formats': ['channels_first'],
    })
    expected = Config(
        input_info=[('input', DataFormat.CHANNELS_FIRST)],
        output_names=['output'],
    )
    self.assertEqual(actual, expected)
 def test_from_env(self):
     """``Config.from_env`` parses comma-separated environment values.

     An input with no matching entry in INPUT_FORMATS gets a ``None`` format.
     """
     environment = {
         'INPUT_NAMES': 'input1,input2:0',
         'OUTPUT_NAMES': 'output',
         'INPUT_FORMATS': 'channels_first',
     }
     parsed = Config.from_env(environment)
     expected_inputs = [
         ('input1', DataFormat.CHANNELS_FIRST),
         ('input2:0', None),
     ]
     self.assertEqual(
         parsed,
         Config(input_info=expected_inputs, output_names=['output']))
    def test_compile_with_variables(self):
        """Compile a saved checkpoint and check its session, inputs and outputs.

        An empty string in ``input_formats`` leaves the corresponding input
        without a data format (``None``).
        """
        with TemporaryDirectory() as directory:
            model_path = os.path.join(directory, 'model.ckpt')

            _save_tensorflow_model(model_path)

            config = Config.from_json({
                'input_names': ['x:0', 'y:0'],
                'output_names': ['z:0'],
                'input_formats': ['channels_first', '']
            })

            compiled = compiler.compile_source(
                TfModelFile(model_path=model_path), config)

            # Inspect the compiled model while the checkpoint files it was
            # built from still exist; the temporary directory is deleted as
            # soon as this ``with`` block exits.
            self.assertIsInstance(compiled.session, tf.compat.v1.Session)

            # Release the session's resources when the test finishes rather
            # than leaving it open until garbage collection.
            self.addCleanup(compiled.session.close)

            self.assertEqual(
                [model_input.tensor.name for model_input in compiled.inputs],
                ['x:0', 'y:0'])

            self.assertEqual(
                [model_input.data_format for model_input in compiled.inputs],
                [DataFormat.CHANNELS_FIRST, None])

            self.assertEqual(
                [model_output.name for model_output in compiled.outputs],
                ['z:0'])
    def test_compile_with_none_input_format(self):
        """An empty ``input_formats`` list leaves every input's format as None."""
        with TemporaryDirectory() as directory:
            model_path = os.path.join(directory, 'model.ckpt')

            _save_tensorflow_model(model_path)

            config = Config.from_json({
                'input_names': ['x:0', 'y:0'],
                'output_names': ['z:0'],
                'input_formats': []
            })

            compiled = compiler.compile_source(
                TfModelFile(model_path=model_path), config)

            # Close the session created by compilation once the test is done.
            # NOTE(review): assumes compile_source returns an object with a
            # closable ``session`` here as it does for TfModelFile elsewhere.
            self.addCleanup(compiled.session.close)

            # Check the result while the checkpoint directory still exists.
            self.assertEqual(
                [model_input.data_format for model_input in compiled.inputs],
                [None, None])