예제 #1
0
def _test_create_textjson_records(path):
    """Create 'train' TFRecords at `path` from a small in-memory text/JSON fixture.

    The fixture (three text examples, each with a `polarity` label) is dumped
    to `train.json` inside a temporary directory, which is then used as the
    input of a `TextJsonDatumGenerator`. The temporary directory is removed
    afterwards, even when writing the records fails.

    Args:
        path: Output directory for the records; created if it does not exist.
    """
    data = {
        1: {
            'text': 'this is text file',
            'label': {
                'polarity': 1
            }
        },
        2: {
            'text': 'this is json file',
            'label': {
                'polarity': 2
            }
        },
        3: {
            'text': 'this is label file',
            'label': {
                'polarity': 0
            }
        },
    }
    tempdir = tempfile.mkdtemp()
    try:
        with open(os.path.join(tempdir, 'train.json'), 'w') as f:
            json.dump(data, f)
        serializer = DatumSerializer('text')
        Path(path).mkdir(parents=True, exist_ok=True)
        textjson_gen = text.TextJsonDatumGenerator(tempdir)
        # The trailing 3 presumably is the number of examples in the split
        # (the fixture has three entries) -- confirm against TFRecordWriter.
        writer = TFRecordWriter(textjson_gen, serializer, path, 'train', 3)
        writer.create_records()
    finally:
        # Always clean up the fixture directory so failed runs do not leak it.
        rmtree(tempdir)
예제 #2
0
def _test_create_det_records(path):
    """Create 'train' and 'val' detection TFRecords at `path` from the dummy VOC data.

    The class map assigns each line of `voc2012.names` a 1-based index, in
    file order.

    Args:
        path: Output directory for the records; created if it does not exist.
    """
    # Read the label names with a context manager so the handle is closed.
    with open('tests/dummy_data/det/voc/voc2012.names') as names_file:
        class_names = names_file.read().splitlines()
    class_map = {name: idx + 1 for idx, name in enumerate(class_names)}

    serializer = DatumSerializer('image')
    Path(path).mkdir(parents=True, exist_ok=True)
    det_gen = image.DetDatumGenerator('tests/dummy_data/det/voc',
                                      gen_config=AttrDict(
                                          has_test_annotations=True,
                                          class_map=class_map))
    gen_kwargs = {'image_set': 'ImageSets'}
    sparse_features = [
        'xmin', 'xmax', 'ymin', 'ymax', 'labels', 'pose', 'is_truncated',
        'labels_difficult'
    ]
    # (split name, fifth positional argument of TFRecordWriter -- presumably
    # the number of examples in that split; confirm against TFRecordWriter).
    for split, count in (('train', 2), ('val', 1)):
        writer = TFRecordWriter(det_gen,
                                serializer,
                                path,
                                split,
                                count,
                                sparse_features=sparse_features,
                                **gen_kwargs)
        writer.create_records()
예제 #3
0
 def setUp(self):
   """Write 'train' detection TFRecords to a fixed directory and build a parser for it.

   Sets `self.tempdir`, `self.serializer`, `self.writer` and `self.parser`.
   The class map assigns each line of `voc2012.names` a 1-based index.
   """
   # Read the label names with a context manager so the handle is closed.
   with open('tests/dummy_data/det/voc/voc2012.names') as names_file:
     class_names = names_file.read().splitlines()
   class_map = {name: idx + 1 for idx, name in enumerate(class_names)}

   self.tempdir = '/tmp/test/tfrecord_det'
   self.serializer = DatumSerializer('image')
   Path(self.tempdir).mkdir(parents=True, exist_ok=True)
   det_gen = image.DetDatumGenerator('tests/dummy_data/det/voc',
                                     gen_config=AttrDict(has_test_annotations=True,
                                                         class_map=class_map))
   gen_kwargs = {'image_set': 'ImageSets'}
   sparse_features = [
       'xmin', 'xmax', 'ymin', 'ymax', 'labels', 'pose', 'is_truncated', 'labels_difficult'
   ]
   self.writer = TFRecordWriter(det_gen,
                                self.serializer,
                                self.tempdir,
                                'train',
                                2,
                                sparse_features=sparse_features,
                                **gen_kwargs)
   # Records must exist on disk before the parser is pointed at the directory.
   self.writer.create_records()
   self.parser = DatumParser(self.tempdir)
예제 #4
0
 def setUp(self):
     """Prepare a sample datum with mixed-type labels and an 'image' serializer."""
     # One image path plus list, int, float and str label values.
     sample = dict(
         image='tests/dummy_data/clf/train/image_232.jpg',
         label1=[1, 2, 3],
         label2=1,
         label3=1.1,
         label4='test',
     )
     self.datum = sample
     self.serializer = DatumSerializer('image')
예제 #5
0
def _test_create_seg_records(path):
    """Create 'train' and 'val' segmentation TFRecords at `path` from the dummy VOC data.

    Args:
        path: Output directory for the records; created if it does not exist.
    """
    serializer = DatumSerializer('image')
    out_dir = Path(path)
    out_dir.mkdir(parents=True, exist_ok=True)
    seg_gen = image.SegDatumGenerator('tests/dummy_data/seg/voc')
    # Same generator and serializer for both splits; 'train' is written first.
    for split in ('train', 'val'):
        TFRecordWriter(seg_gen, serializer, path, split, 1,
                       image_set='ImageSets').create_records()
예제 #6
0
 def setUp(self):
   """Write 'train' classification TFRecords to a fixed directory and build a parser for it.

   Sets `self.tempdir`, `self.serializer`, `self.writer` and `self.parser`.
   """
   self.tempdir = '/tmp/test/tfrecord_clf'
   self.serializer = DatumSerializer('image')
   Path(self.tempdir).mkdir(parents=True, exist_ok=True)
   clf_gen = image.ClfDatumGenerator('tests/dummy_data/clf')
   # NOTE(review): these look like detection feature names passed to a
   # classification generator -- kept as-is; confirm they are intended.
   sparse = ['xmin', 'xmax', 'ymin', 'ymax', 'labels', 'pose', 'is_truncated', 'is_difficult']
   self.writer = TFRecordWriter(clf_gen,
                                self.serializer,
                                self.tempdir,
                                'train',
                                1,
                                sparse_features=sparse,
                                image_set='ImageSets')
   # Records must exist on disk before the parser is pointed at the directory.
   self.writer.create_records()
   self.parser = DatumParser(self.tempdir)
예제 #7
0
def _test_create_clf_records(path):
    """Create 'train' and 'val' classification TFRecords at `path` from the dummy data.

    Args:
        path: Output directory for the records; created if it does not exist.
    """
    serializer = DatumSerializer('image')
    Path(path).mkdir(parents=True, exist_ok=True)
    clf_gen = image.ClfDatumGenerator('tests/dummy_data/clf')
    # Same generator and serializer for both splits; 'train' is written first.
    for split in ('train', 'val'):
        split_writer = TFRecordWriter(clf_gen,
                                      serializer,
                                      path,
                                      split,
                                      1,
                                      sparse_features=None,
                                      image_set='ImageSets')
        split_writer.create_records()
예제 #8
0
def get_default_write_configs(
        problem_type: str,
        label_names_file: Optional[str] = None) -> ConfigBase:
    """Returns default write configs for a problem.

    For the `IMAGE_DET` problem, each line of `label_names_file` is mapped to
    a 1-based class index (in file order) and handed to the generator through
    its `gen_config`. For every other problem type the generator gets
    `gen_config=None`.

    Args:
        problem_type: Type of the problem, any one from `datum.problems.types`.
        label_names_file: Path to the label name file, required for the
            `IMAGE_DET` problem.

    Returns:
        A `ConfigBase` config object with `serializer` and `generator` set.

    Raises:
        ValueError: If label_names_file is missing or is not an existing file;
            raised only for the `IMAGE_DET` problem.
    """
    serializer = DatumSerializer(
        problem.PROBLEM_PARAMS[problem_type]["serializer"],
        datum_name_to_encoder_fn=datum_name_to_encoder)
    if problem_type == problem.IMAGE_DET:
        if not label_names_file:
            raise ValueError(
                "label_names_file must be provided for `IMAGE_DET` problem.")
        if not Path(label_names_file).is_file():
            raise ValueError(
                f"Input {label_names_file} does not exist or not a file.")

        # Use a context manager so the label file handle is closed promptly.
        with open(label_names_file) as names_file:
            class_names = names_file.read().splitlines()
        class_map = {name: idx + 1 for idx, name in enumerate(class_names)}
        gen_config = AttrDict(has_test_annotations=True, class_map=class_map)
    else:
        gen_config = None
    # Bind gen_config now; the remaining generator arguments are supplied by
    # whoever calls the partial.
    generator = partial(problem.PROBLEM_PARAMS[problem_type]["generator"],
                        gen_config=gen_config)
    write_configs = TFRWriteConfigs()
    write_configs.serializer = serializer
    write_configs.generator = generator
    return write_configs
예제 #9
0
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

from datum.encoder.encoder import datum_name_to_encoder
from datum.generator.image import SegDatumGenerator
from datum.serializer.serializer import DatumSerializer
from datum.utils.common_utils import AttrDict

# Write configuration for segmentation data: generator factory, serializer,
# the splits to write and an example count per split.
config = {
    'generator': partial(SegDatumGenerator, gen_config=None),
    # NOTE(review): the key is spelled 'sprase_features' (likely meant
    # 'sparse_features'); kept verbatim because consumers may read this key.
    'sprase_features': [],
    'serializer': DatumSerializer(
        'image',
        datum_name_to_encoder_fn=datum_name_to_encoder,
    ),
    'splits': ['train', 'val'],
    'num_examples': {'train': 1, 'val': 1},
}

# Attribute-style wrapper around the same settings.
cnf = AttrDict(config)
예제 #10
0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

from datum.encoder.encoder import datum_name_to_encoder
from datum.generator.text import TextJsonDatumGenerator
from datum.serializer.serializer import DatumSerializer
from datum.utils.common_utils import AttrDict

# Not referenced by the code visible in this module.
_SEED = 5052020

# Write configuration for text/JSON data: generator factory, serializer,
# the splits to write and an example count per split.
config = {
    'generator': partial(TextJsonDatumGenerator, gen_config=None),
    # NOTE(review): the key is spelled 'sprase_features' (likely meant
    # 'sparse_features'); kept verbatim because consumers may read this key.
    'sprase_features': [],
    'serializer': DatumSerializer(
        'text',
        datum_name_to_encoder_fn=datum_name_to_encoder,
    ),
    # NOTE(review): 'splits' lists only 'test' while 'num_examples' has
    # entries for 'train' and 'val' -- confirm which is intended.
    'splits': ['test'],
    'num_examples': {'train': 1, 'val': 1},
}

# Attribute-style wrapper around the same settings.
cnf = AttrDict(config)
예제 #11
0
 def setUp(self):
   """Create an 'image' serializer and a fresh temporary working directory."""
   self.serializer = DatumSerializer('image')
   workdir = tempfile.mkdtemp()
   self.tempdir = workdir
   # mkdtemp() has already created the directory; with exist_ok=True this
   # call does not fail when it is present.
   Path(workdir).mkdir(parents=True, exist_ok=True)