def test_correct_proto_reader_from_string_4(self):
     result = SimpleProtoParser().parse_from_string(correct_proto_message_4)
     expected_result = {
         'train_input_reader': {'label_map_path': "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt",
                                'tf_record_input_reader': {
                                    'input_path': "PATH_TO_BE_CONFIGURED/  mscoco_train.record"}}}
     self.assertDictEqual(result, expected_result)
Ejemplo n.º 2
0
    def __init__(self, file_name: str):
        """Read a pipeline.config file and initialise model parameters from it.

        Args:
            file_name: path of the pipeline.config file to parse.

        Raises:
            Error: if the parser returns a falsy result (e.g. None on failure).
        """
        parsed = SimpleProtoParser().parse_file(file_name)
        self._raw_data_dict = parsed
        if not parsed:
            message = f'Failed to parse pipeline.config file {file_name}'
            raise Error(message)
        self._initialize_model_params()
Ejemplo n.º 3
0
    def test_proto_reader_from_non_readable_file(self):
        """parse_file must return None for a file it is not allowed to read.

        A temporary file with valid content is made unreadable with chmod.
        The file is removed even when the assertion fails or parsing raises.
        """
        with tempfile.NamedTemporaryFile('wt', delete=False) as tmp:
            tmp.write(correct_proto_message_1)
            file_name = tmp.name
        try:
            os.chmod(file_name, 0o000)
            # chmod(0) does not block reads everywhere: root bypasses it, and
            # on Windows it only sets the read-only flag. Skip instead of
            # failing when the precondition cannot be established.
            if os.access(file_name, os.R_OK):
                self.skipTest(
                    'cannot make a file unreadable in this environment')
            result = SimpleProtoParser().parse_file(file_name)
            self.assertIsNone(result)
        finally:
            # Restore permissions first; deleting a read-only file fails on
            # Windows.
            os.chmod(file_name, 0o600)
            os.unlink(file_name)
Ejemplo n.º 4
0
 def test_correct_proto_reader_from_string_with_special_characters_in_string(
         self):
     """Expect the quoted path, with its backslash and [{],} characters, verbatim."""
     result = SimpleProtoParser().parse_from_string(
         correct_proto_message_11)
     expected_result = {
         'model': {
             # Raw string: "\[" in a normal literal is an invalid escape
             # sequence (DeprecationWarning, SyntaxWarning since Python 3.12).
             # The value is unchanged: a backslash followed by "[{],}".
             'path': r"C:\[{],}",
             'other_value': [1, 2, 3, 4]
         }
     }
     self.assertDictEqual(result, expected_result)
    def test_correct_proto_reader_from_file(self):
        """Parse correct_proto_message_1 from a temporary file on disk."""
        with tempfile.NamedTemporaryFile('wt', delete=False) as tmp:
            tmp.write(correct_proto_message_1)
            file_name = tmp.name
        try:
            result = SimpleProtoParser().parse_file(file_name)
        finally:
            # Always remove the temp file, even if parsing raises.
            os.unlink(file_name)
        expected_result = {'model': {'faster_rcnn': {'num_classes': 90, 'image_resizer': {
            'keep_aspect_ratio_resizer': {'min_dimension': 600, 'max_dimension': 1024}}}}}
        self.assertDictEqual(result, expected_result)
Ejemplo n.º 6
0
 def test_correct_proto_reader_from_string_2(self):
     """Message 2 parses to a grid anchor generator with scalars and lists."""
     parsed = SimpleProtoParser().parse_from_string(correct_proto_message_2)
     anchor_generator = {
         'height_stride': 16,
         'width_stride': 16,
         'scales': [0.25, 0.5, 1.0, 2.0],
         'aspect_ratios': [0.5, 1.0, 2.0],
     }
     expected = {
         'first_stage_anchor_generator': {
             'grid_anchor_generator': anchor_generator}}
     self.assertDictEqual(parsed, expected)
Ejemplo n.º 7
0
 def test_correct_proto_reader_from_string_with_comments(self):
     """Message 5 (presumably containing comments) parses to typed scalars."""
     parsed = SimpleProtoParser().parse_from_string(correct_proto_message_5)
     initializer_fields = {
         'factor': 1.0,
         'uniform': True,
         'bla': False,
         'mode': 'FAN_AVG',
     }
     self.assertDictEqual(
         parsed,
         {'initializer': {'variance_scaling_initializer': initializer_fields}})
Ejemplo n.º 8
0
 def test_correct_proto_reader_from_string_1(self):
     """Message 1 parses to the nested faster_rcnn model configuration."""
     parsed = SimpleProtoParser().parse_from_string(correct_proto_message_1)
     resizer = {'keep_aspect_ratio_resizer': {'min_dimension': 600,
                                              'max_dimension': 1024}}
     faster_rcnn = {'num_classes': 90, 'image_resizer': resizer}
     self.assertDictEqual(parsed, {'model': {'faster_rcnn': faster_rcnn}})
Ejemplo n.º 9
0
 def test_proto_reader_from_non_existing_file(self):
     """parse_file returns None for a path that does not exist."""
     missing_path = '/non/existing/file'
     self.assertIsNone(SimpleProtoParser().parse_file(missing_path))
Ejemplo n.º 10
0
 def test_incorrect_proto_reader_from_string_7(self):
     """Malformed message 7 is rejected: parse_from_string returns None."""
     parser = SimpleProtoParser()
     self.assertIsNone(parser.parse_from_string(incorrect_proto_message_7))
Ejemplo n.º 11
0
 def test_correct_proto_reader_from_string_with_comma_trailing_list(self):
     """Message 8 (a list with a trailing comma, per the name) parses to floats."""
     parsed = SimpleProtoParser().parse_from_string(correct_proto_message_8)
     self.assertDictEqual(parsed, {'model': {'good_list': [3.0, 5.0]}})