def __init__(self, file_name: str):
    """Load *file_name* with SimpleProtoParser and set up model parameters.

    Raises:
        Error: if the parser's result is falsy (e.g. None or an empty dict).
    """
    parsed = SimpleProtoParser().parse_file(file_name)
    # Stored before validation, exactly as before, so the attribute exists
    # even when construction fails.
    self._raw_data_dict = parsed
    if parsed:
        self._initialize_model_params()
        return
    raise Error(
        'Failed to parse pipeline.config file {}'.format(file_name))
def test_proto_reader_from_non_readable_file(self):
    """parse_file() returns None when the file exists but cannot be read."""
    with tempfile.NamedTemporaryFile('wt', delete=False) as tmp_file:
        tmp_file.write(correct_proto_message_1)
        file_name = tmp_file.name
    # Cleanups run LIFO and run even if an assertion fails: permissions are
    # restored first (registered after chmod below), then the file is removed.
    # Restoring first matters on platforms where read-only files cannot be
    # unlinked (e.g. Windows).
    self.addCleanup(os.unlink, file_name)
    os.chmod(file_name, 0o000)
    self.addCleanup(os.chmod, file_name, 0o600)
    if os.access(file_name, os.R_OK):
        # chmod could not revoke read access (e.g. running as root, or on
        # Windows where only the read-only flag is honoured), so the scenario
        # under test cannot be set up here.
        self.skipTest('cannot make a file unreadable for this user/platform')
    result = SimpleProtoParser().parse_file(file_name)
    self.assertIsNone(result)
def test_correct_proto_reader_from_string_with_special_characters_in_string(
        self):
    """Brackets, braces and commas inside a quoted value are kept verbatim."""
    expected = dict(
        model=dict(
            path=r"C:\[{],}",
            other_value=[1, 2, 3, 4],
        ),
    )
    parser = SimpleProtoParser()
    parsed = parser.parse_from_string(correct_proto_message_11)
    self.assertDictEqual(parsed, expected)
def test_correct_proto_reader_from_string_with_windows_path(self):
    """A backslash-containing Windows path survives parsing unchanged."""
    record_reader = {
        'input_path': "PATH_TO_BE_CONFIGURED/ mscoco_train.record",
    }
    reader_config = {
        'label_map_path': r"C:\mscoco_label_map.pbtxt",
        'tf_record_input_reader': record_reader,
    }
    expected = {'train_input_reader': reader_config}
    parsed = SimpleProtoParser().parse_from_string(correct_proto_message_10)
    self.assertDictEqual(parsed, expected)
def test_correct_proto_reader_from_string_3(self):
    """Float, boolean and enum-like scalar values are parsed to Python types."""
    scaling = dict(factor=1.0, uniform=True, bla=False, mode='FAN_AVG')
    expected = dict(
        initializer=dict(variance_scaling_initializer=scaling))
    parsed = SimpleProtoParser().parse_from_string(correct_proto_message_3)
    self.assertDictEqual(parsed, expected)
def test_correct_proto_reader_from_string_2(self):
    """Integer fields and list-valued fields parse into ints and lists."""
    grid = dict(
        height_stride=16,
        width_stride=16,
        scales=[0.25, 0.5, 1.0, 2.0],
        aspect_ratios=[0.5, 1.0, 2.0],
    )
    expected = dict(
        first_stage_anchor_generator=dict(grid_anchor_generator=grid))
    parsed = SimpleProtoParser().parse_from_string(correct_proto_message_2)
    self.assertDictEqual(parsed, expected)
def test_correct_proto_reader_from_string_1(self):
    """Deeply nested messages parse into correspondingly nested dicts."""
    resizer = {
        'keep_aspect_ratio_resizer': {
            'min_dimension': 600,
            'max_dimension': 1024,
        },
    }
    faster_rcnn = {'num_classes': 90, 'image_resizer': resizer}
    expected = {'model': {'faster_rcnn': faster_rcnn}}
    parsed = SimpleProtoParser().parse_from_string(correct_proto_message_1)
    self.assertDictEqual(parsed, expected)
def test_correct_proto_reader_from_file(self):
    """parse_file() reads a message from disk and returns the nested dict."""
    with tempfile.NamedTemporaryFile('wt', delete=False) as tmp_file:
        tmp_file.write(correct_proto_message_1)
        file_name = tmp_file.name
    # Registered before any assertion so the temp file is removed even when
    # the test fails (the original only unlinked on success).
    self.addCleanup(os.unlink, file_name)
    result = SimpleProtoParser().parse_file(file_name)
    expected_result = {
        'model': {
            'faster_rcnn': {
                'num_classes': 90,
                'image_resizer': {
                    'keep_aspect_ratio_resizer': {
                        'min_dimension': 600,
                        'max_dimension': 1024
                    }
                }
            }
        }
    }
    self.assertDictEqual(result, expected_result)
def test_proto_reader_from_non_existing_file(self):
    """parse_file() returns None for a path that does not exist."""
    missing_path = '/non/existing/file'
    self.assertIsNone(SimpleProtoParser().parse_file(missing_path))
def test_incorrect_proto_reader_from_string_7(self):
    """A malformed message string makes parse_from_string() return None."""
    parser = SimpleProtoParser()
    self.assertIsNone(parser.parse_from_string(incorrect_proto_message_7))
def test_correct_proto_reader_from_string_with_comma_trailing_list(self):
    """A list with a trailing comma parses to just its real elements."""
    expected = dict(model=dict(good_list=[3.0, 5.0]))
    parsed = SimpleProtoParser().parse_from_string(correct_proto_message_8)
    self.assertDictEqual(parsed, expected)