示例#1
0
 def test_question_answer_demo(self, opt, args, kwargs):
     """Checks that `question_answer` forwards its CLI arguments to the runner.

     The three positional CLI tokens must reach the object yielded by
     `patch_qa()` in order, followed by `*args` and `**kwargs`, which are
     the arguments the extra `opt` tokens are expected to produce.
     """
     positional = ['train_data', 'validation_data', 'export']
     # NOTE(review): sys.argv is reassigned and not restored here; presumably
     # the enclosing test class resets it between tests -- confirm.
     sys.argv = [BIN, 'question_answer'] + positional + opt
     with patch_qa() as run:
         cli.main()
         run.assert_called_once_with(*positional, *args, **kwargs)
示例#2
0
 def test_text_classification_demo(self, opt, args, kwargs):
     """Checks that `text_classification` forwards its CLI arguments.

     The four positional CLI tokens must reach the object yielded by
     `patch_text()` in order, followed by `*args` and `**kwargs`, which are
     the arguments the extra `opt` tokens are expected to produce.
     """
     positional = ['data', 'lite', 'label', 'vocab']
     # NOTE(review): sys.argv is not restored here; presumably handled by
     # the enclosing test class -- confirm.
     sys.argv = [BIN, 'text_classification'] + positional + opt
     with patch_text() as run:
         cli.main()
         run.assert_called_once_with(*positional, *args, **kwargs)
示例#3
0
 def test_init(self, tf_opt, expected_tf):
     """Checks the setup hook and the runner call for `image_classification`.

     The object yielded by `patch_setup()` must be called once with
     `expected_tf`, given the extra `tf_opt` tokens. The object yielded by
     `patch_image()` must receive the three positionals plus
     'efficientnet_b0' (presumably the default model spec -- confirm).
     """
     positional = ['data', 'lite', 'label']
     sys.argv = [BIN, 'image_classification'] + positional + tf_opt
     with patch_image() as run:
         with patch_setup() as setup:
             cli.main()
             setup.assert_called_once_with(expected_tf)
             run.assert_called_once_with(*positional, 'efficientnet_b0')
示例#4
0
 def test_text_classification_demo_lack_param(self):
     """Checks that too few positionals abort before the runner is invoked.

     `cli.main()` must raise `fire.core.FireExit` whose message matches '2'
     (presumably Fire's usage-error status). The object yielded by
     `patch_text()` must never be called.
     """
     sys.argv = [BIN, 'text_classification', 'data']
     with patch_text() as run:
         self.assertRaisesRegex(fire.core.FireExit, '2', cli.main)
         run.assert_not_called()
示例#5
0
 def test_image_classification_demo(self, opt, args, kwargs):
     """Checks that `image_classification` forwards its CLI arguments.

     The two positional CLI tokens must reach the object yielded by
     `patch_image()` in order, followed by `*args` and `**kwargs` derived
     from the extra `opt` tokens.
     """
     positional = ['data', 'export']
     sys.argv = [BIN, 'image_classification'] + positional + opt
     with patch_image() as run:
         cli.main()
         run.assert_called_once_with(*positional, *args, **kwargs)
示例#6
0
 def test_help(self, opt):
     """Checks that running the CLI with argv `opt` exits via FireExit.

     The raised `fire.core.FireExit` must have a message matching the
     regex '0'. Note that this is a search, so any message containing '0'
     passes.
     """
     sys.argv = opt
     self.assertRaisesRegex(fire.core.FireExit, '0', cli.main)
示例#7
0
 def test_invalid_command(self):
     """Checks that an unknown subcommand makes `cli.main()` raise FireExit.

     The exception message must match the regex '2'.
     """
     sys.argv = [BIN, 'invalid_command']
     self.assertRaisesRegex(fire.core.FireExit, '2', cli.main)
示例#8
0
 def test_question_answer_lack_param(self):
     """Checks that `question_answer` without its third positional aborts.

     `cli.main()` must raise `fire.core.FireExit` matching '2', and the
     object yielded by `patch_qa()` must never be called.
     """
     sys.argv = [BIN, 'question_answer', 'train_data', 'validation_data']
     with patch_qa() as run:
         self.assertRaisesRegex(fire.core.FireExit, '2', cli.main)
         run.assert_not_called()