def _test_create_det_records(path):
    """Write dummy VOC detection TFRecords for the 'train' and 'val' splits.

    Class names are read from the dummy VOC ``voc2012.names`` file. Each name
    is mapped to a 1-based label id in file order. Records are then written
    with ``TFRecordWriter`` into ``path``, which is created if it is missing.

    Args:
        path: Output directory for the generated TFRecord files.
    """
    # Label ids start at 1 and follow the line order of the names file.
    # The `with` block guarantees the file handle is closed.
    with open('tests/dummy_data/det/voc/voc2012.names') as names_file:
        class_map = {
            name: idx + 1
            for idx, name in enumerate(names_file.read().splitlines())
        }
    serializer = DatumSerializer('image')
    Path(path).mkdir(parents=True, exist_ok=True)
    det_gen = image.DetDatumGenerator(
        'tests/dummy_data/det/voc',
        gen_config=AttrDict(has_test_annotations=True, class_map=class_map))
    gen_kwargs = {'image_set': 'ImageSets'}
    # Feature keys handed to TFRecordWriter via its `sparse_features` kwarg.
    sparse_features = [
        'xmin', 'xmax', 'ymin', 'ymax', 'labels', 'pose', 'is_truncated',
        'labels_difficult'
    ]
    # One writer per split, run in the same order as before (train, then val).
    # The second value is passed as TFRecordWriter's 5th positional argument;
    # presumably the number of examples for the split -- TODO confirm.
    for split, num_examples in (('train', 2), ('val', 1)):
        writer = TFRecordWriter(det_gen,
                                serializer,
                                path,
                                split,
                                num_examples,
                                sparse_features=sparse_features,
                                **gen_kwargs)
        writer.create_records()
def setUp(self):
    """Write dummy VOC 'train' detection records and build a parser over them.

    Sets ``self.tempdir``, ``self.serializer``, ``self.writer`` and
    ``self.parser``. The records are written into ``self.tempdir``, which is
    created if it is missing.
    """
    # Label ids start at 1 and follow the line order of the names file.
    # The `with` block guarantees the file handle is closed.
    with open('tests/dummy_data/det/voc/voc2012.names') as names_file:
        class_map = {
            name: idx + 1
            for idx, name in enumerate(names_file.read().splitlines())
        }
    self.tempdir = '/tmp/test/tfrecord_det'
    self.serializer = DatumSerializer('image')
    Path(self.tempdir).mkdir(parents=True, exist_ok=True)
    det_gen = image.DetDatumGenerator(
        'tests/dummy_data/det/voc',
        gen_config=AttrDict(has_test_annotations=True, class_map=class_map))
    gen_kwargs = {'image_set': 'ImageSets'}
    # Feature keys handed to TFRecordWriter via its `sparse_features` kwarg.
    sparse_features = [
        'xmin', 'xmax', 'ymin', 'ymax', 'labels', 'pose', 'is_truncated',
        'labels_difficult'
    ]
    self.writer = TFRecordWriter(det_gen,
                                 self.serializer,
                                 self.tempdir,
                                 'train',
                                 2,
                                 sparse_features=sparse_features,
                                 **gen_kwargs)
    self.writer.create_records()
    self.parser = DatumParser(self.tempdir)
def setUp(self):
    """Build ``self.det_gen``, a detection generator over the dummy VOC data.

    The generator's class map assigns 1-based label ids to the names in
    ``voc2012.names``, in file order.
    """
    # The `with` block guarantees the labels file handle is closed.
    with open('tests/dummy_data/det/voc/voc2012.names') as names_file:
        class_map = {
            name: idx + 1
            for idx, name in enumerate(names_file.read().splitlines())
        }
    self.det_gen = image.DetDatumGenerator(
        'tests/dummy_data/det/voc',
        gen_config=AttrDict(has_test_annotations=True, class_map=class_map))
def get_default_write_configs(
        problem_type: str,
        label_names_file: Optional[str] = None) -> ConfigBase:
    """Returns default write configs for a problem.

    Args:
        problem_type: Type of the problem, any one from `datum.problems.types`.
        label_names_file: Path to the label name file, required for `IMAGE_DET` problem.

    Returns:
        A `ConfigBase` config object.

    Raises:
        ValueError: If label_names_file is not valid, raised only for `IMAGE_DET` problem.
    """
    serializer = DatumSerializer(
        problem.PROBLEM_PARAMS[problem_type]["serializer"],
        datum_name_to_encoder_fn=datum_name_to_encoder)
    if problem_type == problem.IMAGE_DET:
        if not label_names_file:
            raise ValueError(
                "label_names_file must be provided for `IMAGE_DET` problem.")
        if not Path(label_names_file).is_file():
            raise ValueError(
                f"Input {label_names_file} does not exist or not a file.")
        # One class name per line; label ids are 1-based in file order.
        # The `with` block guarantees the file handle is closed.
        with open(label_names_file) as names_file:
            class_map = {
                name: idx + 1
                for idx, name in enumerate(names_file.read().splitlines())
            }
        gen_config = AttrDict(has_test_annotations=True, class_map=class_map)
    else:
        # Non-detection generators are built without a generator config.
        gen_config = None
    # Bind gen_config now; the remaining generator arguments come later.
    generator = partial(problem.PROBLEM_PARAMS[problem_type]["generator"],
                        gen_config=gen_config)
    write_configs = TFRWriteConfigs()
    write_configs.serializer = serializer
    write_configs.generator = generator
    return write_configs
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

from datum.encoder.encoder import datum_name_to_encoder
from datum.generator.image import SegDatumGenerator
from datum.serializer.serializer import DatumSerializer
from datum.utils.common_utils import AttrDict

# Write configuration for the segmentation dummy dataset: generator factory,
# serializer, split names and per-split example counts.
# NOTE(review): 'sprase_features' looks like a typo of 'sparse_features'. It is
# kept verbatim because consumers may read this exact key; verify before renaming.
config = dict(
    generator=partial(SegDatumGenerator, gen_config=None),
    sprase_features=[],
    serializer=DatumSerializer('image',
                               datum_name_to_encoder_fn=datum_name_to_encoder),
    splits=['train', 'val'],
    num_examples=dict(train=1, val=1),
)

# Attribute-access view over the same configuration.
cnf = AttrDict(config)
from functools import partial

from datum.encoder.encoder import datum_name_to_encoder
from datum.generator.image import DetDatumGenerator
from datum.serializer.serializer import DatumSerializer
from datum.utils.common_utils import AttrDict

LABEL_NAMES_FILE = 'tests/dummy_data/det/voc/voc2012.names'


def _read_class_map(names_file):
    """Map each line of ``names_file`` to a 1-based label id, in file order."""
    # Previously built with os.path.join, but `os` was never imported here
    # (NameError on import). The file handle was also left open; the `with`
    # block now closes it.
    with open(names_file) as f:
        return {name: idx + 1 for idx, name in enumerate(f.read().splitlines())}


class_map = _read_class_map(LABEL_NAMES_FILE)

# Write configuration for the detection dummy dataset: generator factory,
# serializer, split names and per-split example counts.
# NOTE(review): 'sprase_features' looks like a typo of 'sparse_features'. It is
# kept verbatim because consumers may read this exact key; verify before renaming.
config = {
    'generator': partial(DetDatumGenerator,
                         gen_config=AttrDict(has_test_annotations=True,
                                             class_map=class_map)),
    'sprase_features': [],
    'serializer': DatumSerializer('image',
                                  datum_name_to_encoder_fn=datum_name_to_encoder),
    'splits': ['train', 'val'],
    'num_examples': {
        'train': 1,
        'val': 1,
    }
}
cnf = AttrDict(config)