Beispiel #1
0
def _test_create_det_records(path):
    """Write detection TFRecords for the dummy VOC data set into ``path``.

    Builds a class map from ``tests/dummy_data/det/voc/voc2012.names`` (ids
    start at 1 and follow the file's line order), then writes a ``train`` and
    a ``val`` split with ``TFRecordWriter``.

    Args:
      path: Output directory; created (with parents) if it does not exist.
    """
    # Read the label names inside a context manager so the file is closed
    # promptly instead of leaking the handle until garbage collection.
    with open('tests/dummy_data/det/voc/voc2012.names') as names_file:
        class_map = {
            name: idx
            for idx, name in enumerate(names_file.read().splitlines(), start=1)
        }
    serializer = DatumSerializer('image')
    Path(path).mkdir(parents=True, exist_ok=True)
    det_gen = image.DetDatumGenerator('tests/dummy_data/det/voc',
                                      gen_config=AttrDict(
                                          has_test_annotations=True,
                                          class_map=class_map))
    gen_kwargs = {'image_set': 'ImageSets'}
    # Keys handed to the writer as ``sparse_features``; presumably the
    # variable-length per-object annotation fields -- confirm in TFRecordWriter.
    sparse_features = [
        'xmin', 'xmax', 'ymin', 'ymax', 'labels', 'pose', 'is_truncated',
        'labels_difficult'
    ]
    # (split name, fifth positional TFRecordWriter argument) pairs, written in
    # this order. The count is presumably the number of examples -- confirm.
    for split, count in (('train', 2), ('val', 1)):
        writer = TFRecordWriter(det_gen,
                                serializer,
                                path,
                                split,
                                count,
                                sparse_features=sparse_features,
                                **gen_kwargs)
        writer.create_records()
Beispiel #2
0
 def setUp(self):
   """Write a dummy VOC ``train`` TFRecord split and create a parser for it.

   Populates ``self.tempdir``, ``self.serializer``, ``self.writer`` and
   ``self.parser``. Class ids start at 1 and follow the line order of
   ``voc2012.names``.
   """
   # Read the label names inside a context manager so the file is closed
   # promptly instead of leaking the handle until garbage collection.
   with open('tests/dummy_data/det/voc/voc2012.names') as names_file:
     class_map = {
         name: idx
         for idx, name in enumerate(names_file.read().splitlines(), start=1)
     }
   # NOTE(review): fixed path shared across runs; a per-test temporary
   # directory would avoid interference between concurrent/previous runs.
   self.tempdir = '/tmp/test/tfrecord_det'
   self.serializer = DatumSerializer('image')
   Path(self.tempdir).mkdir(parents=True, exist_ok=True)
   det_gen = image.DetDatumGenerator('tests/dummy_data/det/voc',
                                     gen_config=AttrDict(has_test_annotations=True,
                                                         class_map=class_map))
   gen_kwargs = {'image_set': 'ImageSets'}
   sparse_features = [
       'xmin', 'xmax', 'ymin', 'ymax', 'labels', 'pose', 'is_truncated', 'labels_difficult'
   ]
   self.writer = TFRecordWriter(det_gen,
                                self.serializer,
                                self.tempdir,
                                'train',
                                2,
                                sparse_features=sparse_features,
                                **gen_kwargs)
   self.writer.create_records()
   self.parser = DatumParser(self.tempdir)
Beispiel #3
0
 def setUp(self):
     """Build ``self.det_gen`` over the dummy VOC detection data.

     The class map assigns ids starting at 1, following the line order of
     ``voc2012.names``.
     """
     # Read the label names inside a context manager so the file handle is
     # closed promptly rather than leaked.
     with open('tests/dummy_data/det/voc/voc2012.names') as names_file:
         class_map = {
             name: idx
             for idx, name in enumerate(names_file.read().splitlines(),
                                        start=1)
         }
     self.det_gen = image.DetDatumGenerator('tests/dummy_data/det/voc',
                                            gen_config=AttrDict(
                                                has_test_annotations=True,
                                                class_map=class_map))
Beispiel #4
0
def get_default_write_configs(
        problem_type: str,
        label_names_file: Optional[str] = None) -> ConfigBase:
    """Returns default write configs for a problem.

    Args:
      problem_type: Type of the problem, any one from `datum.problems.types`.
      label_names_file: Path to the label name file (one label name per line),
        required for `IMAGE_DET` problem. Label ids are assigned starting at 1
        in the file's line order.

    Returns:
      A `ConfigBase` config object (a `TFRWriteConfigs`) whose `serializer`
      and `generator` fields are set.

    Raises:
      ValueError: If label_names_file is not valid, raised only for `IMAGE_DET`
        problem.
    """
    serializer = DatumSerializer(
        problem.PROBLEM_PARAMS[problem_type]["serializer"],
        datum_name_to_encoder_fn=datum_name_to_encoder)
    if problem_type == problem.IMAGE_DET:
        if not label_names_file:
            raise ValueError(
                "label_names_file must be provided for `IMAGE_DET` problem.")
        if not Path(label_names_file).is_file():
            raise ValueError(
                f"Input {label_names_file} does not exist or not a file.")

        # Read inside a context manager so the file handle is always closed.
        with open(label_names_file) as names_file:
            label_names = names_file.read().splitlines()
        class_map = {
            name: idx for idx, name in enumerate(label_names, start=1)
        }
        gen_config = AttrDict(has_test_annotations=True, class_map=class_map)
    else:
        # Only detection problems get a generator config here.
        gen_config = None
    generator = partial(problem.PROBLEM_PARAMS[problem_type]["generator"],
                        gen_config=gen_config)
    write_configs = TFRWriteConfigs()
    write_configs.serializer = serializer
    write_configs.generator = generator
    return write_configs
Beispiel #5
0
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

from datum.encoder.encoder import datum_name_to_encoder
from datum.generator.image import SegDatumGenerator
from datum.serializer.serializer import DatumSerializer
from datum.utils.common_utils import AttrDict

# Writer configuration built around SegDatumGenerator (no generator config),
# with an 'image' DatumSerializer and 'train'/'val' splits whose
# 'num_examples' entries are both 1.
# NOTE(review): the key 'sprase_features' looks like a misspelling of
# 'sparse_features'; it is kept verbatim because consumers may read it.
config = dict(
    generator=partial(SegDatumGenerator, gen_config=None),
    sprase_features=[],
    serializer=DatumSerializer('image',
                               datum_name_to_encoder_fn=datum_name_to_encoder),
    splits=['train', 'val'],
    num_examples=dict(train=1, val=1),
)

# The same mapping wrapped in AttrDict (presumably for attribute access).
cnf = AttrDict(config)
Beispiel #6
0
from functools import partial

from datum.encoder.encoder import datum_name_to_encoder
from datum.generator.image import DetDatumGenerator
from datum.serializer.serializer import DatumSerializer
from datum.utils.common_utils import AttrDict

LABEL_NAMES_FILE = 'tests/dummy_data/det/voc/voc2012.names'


def _read_class_map(label_names_file):
    """Map each line of ``label_names_file`` to an integer id starting at 1.

    The file is read inside a context manager so its handle is closed
    promptly.
    """
    with open(label_names_file) as names_file:
        return {
            name: idx
            for idx, name in enumerate(names_file.read().splitlines(), start=1)
        }


class_map = _read_class_map(LABEL_NAMES_FILE)
config = {
    'generator':
    partial(DetDatumGenerator,
            gen_config=AttrDict(has_test_annotations=True,
                                class_map=class_map)),
    # NOTE(review): 'sprase_features' looks like a misspelling of
    # 'sparse_features'; kept verbatim because consumers may read this key.
    'sprase_features': [],
    'serializer':
    DatumSerializer('image', datum_name_to_encoder_fn=datum_name_to_encoder),
    'splits': ['train', 'val'],
    'num_examples': {
        'train': 1,
        'val': 1,
    }
}

cnf = AttrDict(config)