Example #1
0
 def __init__(self,
              rec_dir,
              rec_prefix,
              batch_size=0,
              shuffle=True,
              aug_list=None,
              devices=None):
     """Store iterator settings, open the record handler and load the index.

     Args:
         rec_dir: directory handed to ``MXRec`` as ``rec_dir``.
         rec_prefix: record file stem handed to ``MXRec`` as ``prefix``.
         batch_size: number of records per batch (default 0).
         shuffle: stored as ``self.shuffle``; not otherwise used here.
         aug_list: optional augmenters, stored as given.
         devices: MXNet context; ``mx.cpu()`` is used when omitted.
     """
     # Fall back to the CPU context when the caller did not pick one.
     ctx = mx.cpu() if devices is None else devices

     self.rec_dir, self.rec_prefix = rec_dir, rec_prefix
     self.batch_size, self.shuffle = batch_size, shuffle
     self.aug_list = aug_list
     self.devices = ctx
     self.cursor = 0  # read position into idx_list

     # parse_idx_file() is called with an empty list in place; presumably it
     # appends record keys to self.idx_list -- confirm against its definition.
     self.idx_list = []
     self.rec_handler = MXRec(rec_dir=self.rec_dir, prefix=self.rec_prefix)
     self.parse_idx_file()
     self.max_index = len(self.idx_list)
Example #2
0
class RecDataIterV1(object):
    """Iterate over an MXNet record file in fixed-size batches.

    Record keys come from ``<rec_dir>/<rec_prefix>.idx``. Each key is read
    through ``MXRec.read_rec``, whose result is indexed with ``'image'`` and
    ``'points'``. The image is converted BGR -> RGB, transposed with axes
    ``(2, 0, 1)`` (HWC -> CHW) and passed through the optional augmenters;
    images and points are stacked into the batch data and label arrays.

    Only full batches are produced: trailing records that cannot fill a
    whole batch are skipped, so ``getpad()`` is always 0.
    """

    def __init__(self,
                 rec_dir,
                 rec_prefix,
                 batch_size=0,
                 shuffle=True,
                 aug_list=None,
                 devices=None):
        """Open the record file and load its index.

        Args:
            rec_dir: directory containing the record and ``.idx`` files.
            rec_prefix: file stem; the index file is ``rec_prefix + '.idx'``.
            batch_size: records per batch. A value <= 0 yields no batches.
            shuffle: if true, record order is shuffled (NumPy global RNG)
                after loading and again on every ``reset()``.
            aug_list: optional sequence of objects exposing
                ``process(image)``, applied in order to each CHW image.
            devices: MXNet context for the output arrays; defaults to
                ``mx.cpu()``.
        """
        self.rec_dir = rec_dir
        self.rec_prefix = rec_prefix
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.cursor = 0
        self.aug_list = aug_list
        self.devices = devices if devices is not None else mx.cpu()
        self.idx_list = []
        self.rec_handler = MXRec(rec_dir=self.rec_dir, prefix=self.rec_prefix)
        self.parse_idx_file()
        self.max_index = len(self.idx_list)
        if self.shuffle:
            self._shuffle_indices()

    def __iter__(self):
        return self

    def _shuffle_indices(self):
        """Shuffle the record keys in place using NumPy's global RNG."""
        np.random.shuffle(self.idx_list)

    def reset(self):
        """Rewind to the first batch, reshuffling first if ``shuffle`` is set."""
        self.cursor = 0
        if self.shuffle:
            self._shuffle_indices()

    def next(self):
        """Return the next ``mx.io.DataBatch``.

        Raises:
            StopIteration: when no full batch remains.
        """
        if not self.iter_next():
            raise StopIteration
        data, target = self.getdata()
        return mx.io.DataBatch(data=[data],
                               label=[target],
                               pad=self.getpad(),
                               index=self.getindex())

    def __next__(self):
        return self.next()

    def iter_next(self):
        """Return True if a full batch remains at the current cursor.

        A non-positive ``batch_size`` never advances the cursor, so it is
        treated as "no batches" instead of looping forever.
        """
        if self.batch_size <= 0:
            return False
        return self.cursor + self.batch_size <= self.max_index

    def _prepare_image(self, image):
        """Convert one decoded BGR HWC image to RGB CHW and apply augmenters."""
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = np.transpose(image, (2, 0, 1))
        for aug in self.aug_list or ():
            image = aug.process(image)
        return image

    def getdata(self):
        """Read and preprocess the next full batch, advancing the cursor.

        Returns:
            ``(data, target)`` NDArrays on ``self.devices``: the stacked
            preprocessed images and the stacked ``'points'`` entries.

        Raises:
            StopIteration: when no full batch remains.
        """
        if not self.iter_next():
            raise StopIteration
        keys = self.idx_list[self.cursor:self.cursor + self.batch_size]
        images, targets = [], []
        for key in keys:
            record = self.rec_handler.read_rec(key)
            # Only 'image' and 'points' are consumed; other fields are ignored.
            images.append(self._prepare_image(record['image']))
            targets.append(record['points'])
        data = mx.nd.array(np.array(images), ctx=self.devices)
        target = mx.nd.array(np.array(targets), ctx=self.devices)
        self.cursor += self.batch_size
        return data, target

    def getlabel(self):
        """Labels are returned by ``getdata()``; this hook returns None."""
        return None

    def getindex(self):
        """Per-sample indices are not tracked; returns None."""
        return None

    def getpad(self):
        """Return the number of padded samples in the batch.

        Partial batches are never produced (see ``iter_next``), so this is
        always 0. ``DataIter.getpad`` is expected to return an int, and
        consumers subtract ``DataBatch.pad`` arithmetically, so None is not
        a safe value here.
        """
        return 0

    def parse_idx_file(self):
        """(Re)load record keys from ``<rec_dir>/<rec_prefix>.idx``.

        Each non-blank line must hold exactly two tab-separated fields; the
        first, an integer record key, is kept and the second is ignored.
        Blank lines (e.g. a trailing newline) are skipped. Replaces the
        contents of ``self.idx_list`` in place and updates ``self.max_index``.

        Raises:
            OSError: if the index file cannot be opened.
            ValueError: on a malformed line.
        """
        self.idx_list.clear()
        idx_file = os.path.join(self.rec_dir, self.rec_prefix + '.idx')
        with open(idx_file, 'r') as lf:
            for line in lf:
                line = line.strip()
                if not line:
                    continue
                key, _ = line.split('\t')
                self.idx_list.append(int(key))
        self.max_index = len(self.idx_list)