예제 #1
0
def tf_create_data_iterator(
    batch_size,
    train_set_file=None,
    val_set_file=None,
    shape=None,
    dtype=tf.float32,
    mean_img=None):
    """Build a single reinitializable iterator over TFRecord image datasets.

    A validation and/or a training dataset is created from the given TFRecord
    file(s). Both share one iterator, which is switched between them by
    running the matching initializer op.

    Args:
        batch_size: Batch size applied to both datasets. The training shuffle
            buffer holds ``batch_size * 1000`` records.
        train_set_file: File(s) for the training set, or None to skip it.
        val_set_file: File(s) for the validation set, or None to skip it.
        shape: Shape every decoded image is reshaped to. Required.
        dtype: Type the decoded uint8 image is cast to.
        mean_img: Not used by the current implementation.

    Returns:
        Tuple ``(train_init_op, val_init_op, iter_op)``. An init op is None
        when the corresponding file argument was None; ``iter_op`` is the
        iterator's ``get_next()`` result, an ``(image, label)`` pair.

    Raises:
        ValueError: If ``shape`` is None, or if neither file is given.
    """
    if shape is None:
        raise ValueError("shape cannot be None")
    if train_set_file is None and val_set_file is None:
        raise ValueError("Both train_set_file and val_set_file are not specified")

    def tf_parse_record(tf_record):
        """Decode one serialized example into an ``(image, label)`` pair."""
        parsed = tf.parse_single_example(tf_record, features=_IMAGE_TFREC_STRUCTURE)
        pixels = tf.reshape(tf.decode_raw(parsed['image'], tf.uint8), shape)
        label = tf.cast(parsed['label'], tf.int64)
        return tf.cast(pixels, dtype), label

    val_set = None
    train_set = None
    # (output_types, output_shapes) of the most recently built dataset; the
    # training set, when present, is built last and therefore wins.
    structure = None

    # Validation data: parsed and batched, never shuffled.
    if val_set_file is not None:
        val_set = TFRecordDataset(val_set_file).map(tf_parse_record).batch(batch_size)
        structure = (val_set.output_types, val_set.output_shapes)

    # Training data: parsed, shuffled, then batched.
    if train_set_file is not None:
        train_set = (TFRecordDataset(train_set_file)
                     .map(tf_parse_record)
                     .shuffle(buffer_size=batch_size * 1000)
                     .batch(batch_size))
        structure = (train_set.output_types, train_set.output_shapes)

    # One iterator defined purely by structure, so either dataset can feed it.
    iterator = Iterator.from_structure(*structure)

    train_init_op = None if train_set is None else iterator.make_initializer(train_set)
    val_init_op = None if val_set is None else iterator.make_initializer(val_set)

    return train_init_op, val_init_op, iterator.get_next()
예제 #2
0
        def dataset(tfrecord, shuffle):
            """Create a zero-padded, batched dataset of token-id sequences.

            Reads ``max_sequence``, ``buffer_size`` and ``batch_size`` from
            the enclosing scope.

            Args:
                tfrecord: File(s) passed to ``TFRecordDataset``.
                shuffle: When truthy, shuffle parsed records with a buffer of
                    ``buffer_size`` before batching.

            Returns:
                The dataset after ``padded_batch``.
            """
            records = TFRecordDataset(tfrecord)

            def parse(x):
                """Return the 'data' ids of one example as int32, each plus one."""
                features = tf.parse_single_example(
                    x, features={'data': tf.VarLenFeature(tf.int64)})
                # Ids are shifted up by one; the batches below are padded
                # with 0 (presumably so 0 only ever means padding).
                tokens = tf.cast(features['data'].values, tf.int32) + 1
                return tokens if max_sequence is None else tokens[:max_sequence]

            records = records.map(parse, num_parallel_calls=8)
            if shuffle:
                records = records.shuffle(buffer_size)
            return records.padded_batch(
                batch_size,
                padded_shapes=(tf.TensorShape([None])),
                padding_values=0)
예제 #3
0
    def _create_data_iterators(self,
                               tf_rec_list,
                               batch_size,
                               dtype=tf.float32):
        """Create one reinitializable iterator and an initializer per dataset.

        Args:
            tf_rec_list: Iterable of record specs. Each spec is wrapped in
                ``DictToAttrs`` and must provide ``file`` (TFRecord path(s)),
                ``preproc`` (callable applied to each decoded image) and
                ``shuffle`` (truthy to shuffle with a 50000-element buffer).
            batch_size: Batch size applied to every dataset.
            dtype: Not used by this method; kept so existing callers that
                pass it keep working.

        Returns:
            Tuple ``(next_element, iter_list)``: the shared iterator's
            ``get_next()`` result and a list of initializer ops, one per
            entry of ``tf_rec_list`` in the same order.

        Raises:
            ValueError: If ``tf_rec_list`` contains no entries, since no
                iterator can be built without at least one dataset.
        """
        def tf_parse_record(tf_record):
            """Decode one serialized example into a ``(uint8 image, int64 label)`` pair."""
            feature = tf.parse_single_example(tf_record,
                                              features=_IMAGE_TFREC_STRUCTURE)
            image = tf.decode_raw(feature['image'], tf.uint8)
            image = tf.reshape(image, self.shape)
            label = tf.cast(feature['label'], tf.int64)
            return image, label

        data_iter = None
        iter_list = []
        for rec in tf_rec_list:
            rec = DictToAttrs(rec)
            dataset = TFRecordDataset(rec.file)
            dataset = dataset.map(tf_parse_record)
            dataset = dataset.map(lambda image, label:
                                  (rec.preproc(image), label))
            if rec.shuffle:
                dataset = dataset.shuffle(buffer_size=50000)
            dataset = dataset.batch(batch_size)

            if data_iter is None:
                # The first dataset defines the structure of the shared
                # reinitializable iterator; later datasets reuse it.
                data_iter = Iterator.from_structure(dataset.output_types,
                                                    dataset.output_shapes)
            iter_list.append(data_iter.make_initializer(dataset))

        if data_iter is None:
            # Previously this fell through to `None.get_next()` and failed
            # with an unhelpful AttributeError.
            raise ValueError(
                "tf_rec_list is empty: at least one record spec is required")

        return data_iter.get_next(), iter_list
예제 #4
0
def get_tfrecord_loader(
        file_names: List[str],
        batch_size: int = None,
        buffer_size: int = 256,
        epochs: int = None,
        shuffle: bool = True) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
    """Create a TFRecord input pipeline and return its next-batch tensors.

    Records are parsed, optionally shuffled, batched and repeated; a
    one-shot iterator is then created and its ``get_next()`` output returned.

    :param file_names: TFRecord file name(s), passed to ``TFRecordDataset``
    :param batch_size: Number of images in each batch. Passed straight to
        ``Dataset.batch``; the ``None`` default is not handled here, so
        callers are expected to supply a value
    :param buffer_size: Shuffle buffer size, see ``Dataset.shuffle``
    :param epochs: Number of passes over the batched data, passed to
        ``Dataset.repeat``
    :param shuffle: Whether or not to shuffle the dataset
    :return: ``(images, labels)`` from the iterator's ``get_next()``;
        ``images`` is a dict with the single key ``'input_1'``
    """

    feature_map = {
        'image': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64)
    }

    def parse_example_proto(proto: tf.train.Example, image_size: int = 299, channels: int = 3) \
            -> Tuple[Dict[str, tf.Tensor], List[tf.Tensor]]:
        """Decode one serialized example into model inputs and a label.

        :param proto: Serialized example to parse with ``feature_map``
        :param image_size: Height and width the raw image bytes are reshaped to
        :param channels: Number of channels the raw image bytes are reshaped to
        :return: ``({'input_1': image}, [label])`` with the image as float32
            (after subtracting the constant 116.779) and the label as float32
        """
        features = tf.parse_single_example(proto, features=feature_map)
        image = tf.decode_raw(features['image'], tf.uint8)
        image = tf.reshape(image, [image_size, image_size, channels])
        image = tf.cast(image, tf.float32)
        image = tf.subtract(image, 116.779)
        label = tf.cast(features['label'], tf.float32)
        return {'input_1': image}, [label]

    dataset = TFRecordDataset(file_names)
    dataset = dataset.map(parse_example_proto)
    if shuffle:
        dataset = dataset.shuffle(buffer_size=buffer_size)
    dataset = dataset.batch(batch_size)
    # Dataset transformations return a new dataset rather than mutating in
    # place; the result must be kept or `epochs` is silently ignored.
    dataset = dataset.repeat(epochs)
    iterator = dataset.make_one_shot_iterator()
    images, labels = iterator.get_next()
    return images, labels
예제 #5
0
def input_fn(params, sequence_schema, context_schema, part_files):
    """Build the batched input dataset from GZIP-compressed TFRecord parts.

    Args:
        params: Mapping providing 'cycle_length', 'buffer_size', 'epochs'
            and 'batch_size'.
        sequence_schema: Second positional argument bound to ``parse_example``.
        context_schema: First positional argument bound to ``parse_example``.
        part_files: Sized collection of TFRecord part file names.

    Returns:
        The dataset after file-level shuffling, interleaved reading, parsing,
        record-level shuffle-and-repeat and batching.
    """
    # Shuffle the file names themselves, using a buffer that holds all of them.
    file_names = Dataset.from_tensor_slices(part_files).shuffle(len(part_files))

    # Read several files concurrently; sloppy=True is passed through to
    # parallel_interleave unchanged.
    interleaved_reader = parallel_interleave(
        lambda file: TFRecordDataset(file, compression_type='GZIP'),
        cycle_length=params['cycle_length'],
        sloppy=True)
    records = file_names.apply(interleaved_reader)

    parser = partial(parse_example, context_schema, sequence_schema)
    examples = records.map(parser, num_parallel_calls=cpu_count())

    reshuffle = shuffle_and_repeat(params['buffer_size'], count=params['epochs'])
    return examples.apply(reshuffle).batch(params['batch_size'])
예제 #6
0
 def tfrecord2iter(self, num_epochs, batch_size):
     """Return a one-shot iterator over padded batches read from ``self.fname``.

     Each record is decoded with ``self.decode_record`` and the dataset is
     repeated ``num_epochs`` times. Every field in ``self.fields`` is padded
     along a single variable-length dimension: int64 fields with
     ``self.vocab.end_token_id``, float32 fields with 0.0. Fields of any
     other dtype get a padded shape but no explicit padding value.

     Args:
         num_epochs: Repeat count passed to ``Dataset.repeat``.
         batch_size: Batch size passed to ``Dataset.padded_batch``.

     Returns:
         The iterator created by ``make_one_shot_iterator``.
     """
     records = TFRecordDataset(self.fname).map(lambda r: self.decode_record(r))
     records = records.repeat(num_epochs)

     padded_shapes = {}
     padding_values = {}
     for name in self.fields:
         padded_shapes[name] = [None]
         field_dtype = self.fields[name].dtype
         if field_dtype == tf.int64:
             padding_values[name] = tf.constant(
                 self.vocab.end_token_id, tf.int64)
         elif field_dtype == tf.float32:
             padding_values[name] = tf.constant(0.0, tf.float32)

     batches = records.padded_batch(
         batch_size, padded_shapes=padded_shapes, padding_values=padding_values)
     return batches.make_one_shot_iterator()
예제 #7
0
    def __init__(self,
                 root,
                 transform=None,
                 target_transform=None,
                 is_test=False,
                 keep_difficult=False,
                 batch_size=32,
                 shuffle=True):
        """Set up a TFRecord-backed detection dataset rooted at ``root``.

        Args:
            root: Directory holding the record files and metadata:
                label_map.txt, train*.record, val.record, num_train.txt,
                num_val.txt.
            transform: Stored on the instance as ``self.transform``.
            target_transform: Stored as ``self.target_transform``.
            is_test: When true, read ``val.record`` / ``num_val.txt``;
                otherwise read every ``train*.record`` file and
                ``num_train.txt``.
            keep_difficult: Stored as ``self.keep_difficult``.
            batch_size: Stored as ``self.batch_size``; also used to compute
                ``self.num_batches``.
            shuffle: When truthy, shuffle the order of the record files
                in place before building the dataset.

        Raises:
            SystemExit: If the required record-count file is missing.
        """
        self.root = pathlib.Path(root)
        self.batch_size = batch_size
        self.transform = transform
        self.target_transform = target_transform
        self.keep_difficult = keep_difficult

        if is_test:
            image_sets_files = [os.path.join(root, "val.record")]
            count_file = self.root / "num_val.txt"
            if not os.path.isfile(count_file):
                logging.critical(
                    "No num_val.txt found. Please create one with the total number of validation images"
                )
                raise SystemExit(-1)
            with open(count_file, 'r') as f:
                self.num_records = int(f.read())
            # Only the validation split gets an explicit id list.
            self.ids = list(range(self.num_records))
        else:
            image_sets_files = [
                os.path.join(root, file_name)
                for file_name in os.listdir(self.root)
                if file_name.startswith('train') and file_name.endswith('.record')
            ]
            count_file = self.root / "num_train.txt"
            if not os.path.isfile(count_file):
                logging.critical(
                    "No num_train.txt found. Please create one with the total number of training images"
                )
                raise SystemExit(-1)
            with open(count_file, 'r') as f:
                self.num_records = int(f.read())

        # Parsing spec handed to the example parser for every record.
        self.keys_to_features = {
            'image/encoded': FixedLenFeature((), tf.string, default_value=''),
            'image/format': FixedLenFeature((), tf.string, default_value='jpeg'),
            'image/filename': FixedLenFeature((), tf.string, default_value=''),
            'image/source_id': FixedLenFeature((), tf.string, default_value=''),
            'image/height': FixedLenFeature((), tf.int64, default_value=1),
            'image/width': FixedLenFeature((), tf.int64, default_value=1),
            # Per-object boxes and classes (variable number per image).
            'image/object/bbox/xmin': VarLenFeature(tf.float32),
            'image/object/bbox/xmax': VarLenFeature(tf.float32),
            'image/object/bbox/ymin': VarLenFeature(tf.float32),
            'image/object/bbox/ymax': VarLenFeature(tf.float32),
            'image/object/class/label': VarLenFeature(tf.int64),
            'image/object/class/text': VarLenFeature(tf.string),
            'image/object/difficult': VarLenFeature(tf.int64),
        }

        if shuffle:
            np.random.shuffle(image_sets_files)

        # Read files in parallel only when there is more than one of them.
        parallel_reads = 3 if len(image_sets_files) > 1 else 1
        raw_records = TFRecordDataset(image_sets_files,
                                      num_parallel_reads=parallel_reads)
        self.dataset = raw_records
        self.dataset = self.dataset.map(self.parse_sample)
        self.num_batches = self.num_records // self.batch_size

        # Class names come from label_map.txt when present, else defaults.
        self._get_classes(self.root / "label_map.txt")
        self.class_dict = dict(
            zip(self.class_names, range(len(self.class_names))))
예제 #8
0
class RecordDataset(Sequence):
    """Batch-serving sequence over object-detection TFRecord files.

    Records are parsed with ``self.keys_to_features``. ``__getitem__``
    returns one batch of images and targets; ``get_image`` and
    ``get_annotation`` give per-record access by position.
    """

    def __init__(self,
                 root,
                 transform=None,
                 target_transform=None,
                 is_test=False,
                 keep_difficult=False,
                 batch_size=32,
                 shuffle=True):
        """
        Dataset for TFRecord data.
        Args:
            root: the root of the TFRecord, the directory contains the following files:
                label_map.txt, train.record, val.record, num_train.txt, num_val.txt
            transform: optional callable ``(image, boxes, labels)`` applied per sample
            target_transform: optional callable ``(boxes, labels)`` applied per sample
            is_test: read val.record/num_val.txt instead of the train*.record
                files and num_train.txt
            keep_difficult: when false, objects flagged as difficult are dropped
            batch_size: number of samples returned by ``__getitem__``
            shuffle: shuffle the order of the record files before reading
        Raises:
            SystemExit: if the record-count file for the chosen split is missing
        """
        self.root = pathlib.Path(root)
        self.batch_size = batch_size
        self.transform = transform
        self.target_transform = target_transform
        if is_test:
            image_sets_files = [os.path.join(root, "val.record")]
            if os.path.isfile(self.root / "num_val.txt"):
                with open(self.root / "num_val.txt", 'r') as f:
                    self.num_records = int(f.read())
            else:
                logging.critical(
                    "No num_val.txt found. Please create one with the total number of validation images"
                )
                raise SystemExit(-1)
            # Only the validation split gets an explicit id list.
            self.ids = list(range(self.num_records))
        else:
            image_sets_files = []
            for file_name in os.listdir(self.root):
                if file_name.startswith('train') and file_name.endswith(
                        '.record'):
                    image_sets_files.append(os.path.join(root, file_name))
            if os.path.isfile(self.root / "num_train.txt"):
                with open(self.root / "num_train.txt", 'r') as f:
                    self.num_records = int(f.read())
            else:
                logging.critical(
                    "No num_train.txt found. Please create one with the total number of training images"
                )
                raise SystemExit(-1)

        # Parsing spec applied to every serialized record.
        self.keys_to_features = {
            'image/encoded': FixedLenFeature((), tf.string, default_value=''),
            'image/format': FixedLenFeature((),
                                            tf.string,
                                            default_value='jpeg'),
            'image/filename': FixedLenFeature((), tf.string, default_value=''),
            'image/source_id': FixedLenFeature((), tf.string,
                                               default_value=''),
            'image/height': FixedLenFeature((), tf.int64, default_value=1),
            'image/width': FixedLenFeature((), tf.int64, default_value=1),
            # Object boxes and classes.
            'image/object/bbox/xmin': VarLenFeature(tf.float32),
            'image/object/bbox/xmax': VarLenFeature(tf.float32),
            'image/object/bbox/ymin': VarLenFeature(tf.float32),
            'image/object/bbox/ymax': VarLenFeature(tf.float32),
            'image/object/class/label': VarLenFeature(tf.int64),
            'image/object/class/text': VarLenFeature(tf.string),
            'image/object/difficult': VarLenFeature(tf.int64),
        }

        if shuffle:
            np.random.shuffle(image_sets_files)
        # Read files in parallel only when there is more than one of them.
        self.dataset = TFRecordDataset(
            image_sets_files,
            num_parallel_reads=3 if len(image_sets_files) > 1 else 1)
        self.dataset = self.dataset.map(self.parse_sample)
        self.keep_difficult = keep_difficult
        self.num_batches = self.num_records // self.batch_size

        # if the labels file exists, read in the class names
        label_file_name = self.root / "label_map.txt"

        self._get_classes(label_file_name)
        self.class_dict = {
            class_name: i
            for i, class_name in enumerate(self.class_names)
        }

    def _get_classes(self, label_file_name):
        """Populate ``self.class_names`` from a label file or the VOC defaults.

        Two file layouts are recognised: a plain text file with one class
        name per line, or a label map whose first line is ``item {`` and
        whose class names appear on ``name: '...'`` lines. ``'BACKGROUND'``
        is inserted at index 0 when the file does not list it. When the
        file does not exist, the default VOC class tuple is used.

        Args:
            label_file_name: path of the label file.
        """
        if os.path.isfile(label_file_name):
            classes = []
            with open(label_file_name, 'r') as infile:
                first_line = infile.readline().rstrip()
                if first_line != 'item {':
                    # Plain text: one class label per line. The first line
                    # is stripped like every other line; otherwise its
                    # trailing newline ends up in the class name and a
                    # leading 'BACKGROUND' entry is not recognised below.
                    classes.append(first_line)
                    classes.extend(line.rstrip() for line in infile)
                else:
                    # Label-map layout: pick out the "name: '...'" lines.
                    for line in infile:
                        line = line.strip()
                        if line.startswith('name: '):
                            line = line.replace('\'', '')
                            classes.append(line[5:].strip())

            if 'BACKGROUND' not in classes:
                classes.insert(0, 'BACKGROUND')

            self.class_names = tuple(classes)
            logging.info("VOC Labels read from file: " + str(self.class_names))
        else:
            logging.info("No labels file, using default VOC classes.")
            self.class_names = ('BACKGROUND', 'aeroplane', 'bicycle', 'bird',
                                'boat', 'bottle', 'bus', 'car', 'cat', 'chair',
                                'cow', 'diningtable', 'dog', 'horse',
                                'motorbike', 'person', 'pottedplant', 'sheep',
                                'sofa', 'train', 'tvmonitor')

    def __getitem__(self, idx):
        """Return one batch as ``(inputs, targets)`` NumPy arrays.

        ``idx`` is not used to select records: every call draws
        ``self.batch_size`` samples from a freshly shuffled view of the
        dataset (shuffle buffer of 500).

        The boxes and labels produced by the transforms must expose
        ``.numpy()``. Assumes they have the same size for every sample so
        they can be stacked and concatenated along axis 2 -- TODO confirm
        against the target_transform in use.
        """
        inputs, target1, target2 = [], [], []
        for sample in self.dataset.shuffle(500).take(self.batch_size):
            boxes, labels, is_difficult = self._get_annotation(sample)
            # Test `.size` rather than the array's truth value, which is
            # ambiguous for several elements and rejected for empty arrays
            # by recent NumPy releases.
            if not self.keep_difficult and is_difficult.size:
                keep = is_difficult == 0
                boxes = boxes[keep]
                labels = labels[keep]
            image = self._read_image(sample)
            if self.transform:
                image, boxes, labels = self.transform(image, boxes, labels)
            if self.target_transform:
                boxes, labels = self.target_transform(boxes, labels)
            inputs.append(image)
            target1.append(boxes.numpy())
            target2.append(labels.numpy())

        tmp_inputs = np.array(inputs, dtype=np.float32)
        tmp_target1 = np.array(target1)
        tmp_target2 = np.array(target2)
        # Give labels a trailing axis so they can be appended to the boxes.
        tmp_target2 = np.expand_dims(tmp_target2, 2)
        tmp_target = np.concatenate([tmp_target1, tmp_target2], axis=2)
        return tmp_inputs, tmp_target

    def __len__(self):
        """Return the number of batches, counting a final partial batch."""
        return int(np.ceil(self.num_records / float(self.batch_size)))

    def get_annotation(self, idx):
        """Return ``(source_id, (boxes, labels, is_difficult))`` for record ``idx``.

        Returns None implicitly when the dataset has no record at ``idx``.
        """
        for sample in self.dataset.skip(idx).take(1):
            return sample['image/source_id'].numpy().decode(
                'utf-8'), self._get_annotation(sample)

    def _get_annotation(self, sample):
        """Extract pixel-space boxes and labels from a parsed record.

        Returns:
            Tuple ``(boxes, labels, is_difficult)`` of NumPy arrays with
            dtypes float32, int64 and uint8. Each box is
            ``[x_min, y_min, x_max, y_max]``. ``is_difficult`` is currently
            always empty because the difficult flag is not read.
        """
        # Get info about each object in image
        num_objects = sample['image/object/bbox/xmax'].shape[0]
        boxes = []
        labels = []
        is_difficult = []
        height = sample['image/height'].numpy()
        width = sample['image/width'].numpy()
        for i in range(num_objects):
            # Undo bbox coord normalization
            x_max = int(sample['image/object/bbox/xmax'].values[i].numpy() *
                        width)
            x_min = int(sample['image/object/bbox/xmin'].values[i].numpy() *
                        width)
            y_max = int(sample['image/object/bbox/ymax'].values[i].numpy() *
                        height)
            y_min = int(sample['image/object/bbox/ymin'].values[i].numpy() *
                        height)
            boxes.append([x_min, y_min, x_max, y_max])

            labels.append(sample['image/object/class/label'].values[i].numpy())
            # is_difficult.append(sample['image/object/difficult'].values[i].numpy())

        return (np.array(boxes,
                         dtype=np.float32), np.array(labels, dtype=np.int64),
                np.array(is_difficult, dtype=np.uint8))

    def _read_image(self, sample):
        """Decode the encoded image bytes of a parsed record to a NumPy array."""
        return tf.image.decode_image(sample['image/encoded']).numpy()

    def get_image(self, idx):
        """Return the decoded image of record ``idx`` as a NumPy array.

        Skips ``idx`` records first, in the same way as ``get_annotation``,
        so the two methods address the same record. Returns None implicitly
        when the dataset has no record at ``idx``.
        """
        for sample in self.dataset.skip(idx).take(1):
            return self._read_image(sample)

    def parse_sample(self, data_record):
        """Parse one serialized record using ``self.keys_to_features``."""
        sample = parse_single_example(data_record, self.keys_to_features)
        return sample