Example #1
0
    def __init__(self, file_name: str):
        """Parse a pipeline.config file and initialize model parameters from it.

        Args:
            file_name: path of the pipeline.config file handed to
                SimpleProtoParser.parse_file.

        Raises:
            Error: when parse_file returns a falsy result (e.g. None or an
                empty dict) for the given file.
        """
        parsed = SimpleProtoParser().parse_file(file_name)
        self._raw_data_dict = parsed
        if not parsed:
            message = 'Failed to parse pipeline.config file {}'.format(file_name)
            raise Error(message)
        self._initialize_model_params()
    def test_proto_reader_from_non_readable_file(self):
        """parse_file returns None for a file that exists but is unreadable.

        Skipped where permission bits do not stop reads: when running as
        root, and on Windows, where os.chmod can only toggle the read-only
        flag.
        """
        if os.name == 'nt' or (hasattr(os, 'geteuid') and os.geteuid() == 0):
            self.skipTest('file permissions do not prevent reading here')
        with tempfile.NamedTemporaryFile('wt', delete=False) as tmp:
            tmp.write(correct_proto_message_1)
            file_name = tmp.name
        try:
            os.chmod(file_name, 0o000)
            result = SimpleProtoParser().parse_file(file_name)
            self.assertIsNone(result)
        finally:
            # Restore owner permissions, then always delete the file so a
            # failed assertion never leaks the temporary file.
            os.chmod(file_name, 0o600)
            os.unlink(file_name)
 def test_correct_proto_reader_from_string_with_special_characters_in_string(
         self):
     """A string value holding brackets, braces and a comma is kept intact."""
     model = {'path': r"C:\[{],}", 'other_value': [1, 2, 3, 4]}
     parsed = SimpleProtoParser().parse_from_string(correct_proto_message_11)
     self.assertDictEqual(parsed, {'model': model})
 def test_correct_proto_reader_from_string_with_windows_path(self):
     """A Windows-style path value keeps its backslash verbatim."""
     record_reader = {
         'input_path': "PATH_TO_BE_CONFIGURED/  mscoco_train.record"
     }
     reader_config = {
         'label_map_path': r"C:\mscoco_label_map.pbtxt",
         'tf_record_input_reader': record_reader,
     }
     parsed = SimpleProtoParser().parse_from_string(correct_proto_message_10)
     self.assertDictEqual(parsed, {'train_input_reader': reader_config})
 def test_correct_proto_reader_from_string_3(self):
     """Float, boolean and string fields of a nested message are parsed."""
     variance_scaling = dict(
         factor=1.0,
         uniform=True,
         bla=False,
         mode='FAN_AVG',
     )
     expected = {
         'initializer': {'variance_scaling_initializer': variance_scaling}
     }
     parsed = SimpleProtoParser().parse_from_string(correct_proto_message_3)
     self.assertDictEqual(parsed, expected)
 def test_correct_proto_reader_from_string_2(self):
     """Repeated numeric fields are collected into Python lists."""
     grid = dict(
         height_stride=16,
         width_stride=16,
         scales=[0.25, 0.5, 1.0, 2.0],
         aspect_ratios=[0.5, 1.0, 2.0],
     )
     expected = {'first_stage_anchor_generator': {'grid_anchor_generator': grid}}
     parsed = SimpleProtoParser().parse_from_string(correct_proto_message_2)
     self.assertDictEqual(parsed, expected)
 def test_correct_proto_reader_from_string_1(self):
     """Deeply nested messages map onto nested dictionaries."""
     # Build the expected structure from the innermost message outwards.
     resizer = {'min_dimension': 600, 'max_dimension': 1024}
     image_resizer = {'keep_aspect_ratio_resizer': resizer}
     faster_rcnn = {'num_classes': 90, 'image_resizer': image_resizer}
     expected = {'model': {'faster_rcnn': faster_rcnn}}
     parsed = SimpleProtoParser().parse_from_string(correct_proto_message_1)
     self.assertDictEqual(parsed, expected)
    def test_correct_proto_reader_from_file(self):
        """parse_file reads a proto text file into the expected nested dict."""
        with tempfile.NamedTemporaryFile('wt', delete=False) as tmp:
            tmp.write(correct_proto_message_1)
            file_name = tmp.name
        try:
            result = SimpleProtoParser().parse_file(file_name)
            expected_result = {
                'model': {
                    'faster_rcnn': {
                        'num_classes': 90,
                        'image_resizer': {
                            'keep_aspect_ratio_resizer': {
                                'min_dimension': 600,
                                'max_dimension': 1024
                            }
                        }
                    }
                }
            }
            self.assertDictEqual(result, expected_result)
        finally:
            # Always remove the temp file, even when the assertion fails.
            os.unlink(file_name)
 def test_proto_reader_from_non_existing_file(self):
     """parse_file yields None for a path that does not exist."""
     missing_path = '/non/existing/file'
     self.assertIsNone(SimpleProtoParser().parse_file(missing_path))
 def test_incorrect_proto_reader_from_string_7(self):
     """parse_from_string yields None for incorrect_proto_message_7."""
     parser = SimpleProtoParser()
     self.assertIsNone(parser.parse_from_string(incorrect_proto_message_7))
 def test_correct_proto_reader_from_string_with_comma_trailing_list(self):
     """A list written with a trailing comma still parses to its elements."""
     expected = dict(model=dict(good_list=[3.0, 5.0]))
     parsed = SimpleProtoParser().parse_from_string(correct_proto_message_8)
     self.assertDictEqual(parsed, expected)