Example #1
def _load_validation_data(validation_leveldb, width, height):
    """
    Loads all of our validation data from our LevelDB database, producing numpy
    image arrays ready to test along with their correct, expected target values.
    """

    print "\tLoading validation data..."
    input_vectors = []
    expected_targets = []

    db = plyvel.DB(validation_leveldb)
    for key, value in db:
        datum = Datum()
        datum.ParseFromString(value)

        data = np.frombuffer(datum.data, dtype=np.uint8)
        data = np.reshape(data, (3, height, width))
        # Move the color channel to the end, converting Caffe's C x H x W
        # layout into H x W x C.
        data = np.swapaxes(data, 0, 2)  # Swap channel with width.
        data = np.swapaxes(data, 0, 1)  # Swap width with height.

        input_vectors.append(data)
        expected_targets.append(datum.label)

    db.close()

    print "\t\tValidation data has %d images" % len(input_vectors)

    return {
        "input_vectors": np.asarray(input_vectors),
        "expected_targets": np.asarray(expected_targets)
    }
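A minimal usage sketch for the loader above, assuming it sits next to the standard Caffe protobuf bindings and plyvel; the import path for Datum, the database path, and the image size are illustrative assumptions, not taken from the original project:

import numpy as np
import plyvel
from caffe.proto.caffe_pb2 import Datum  # assumed source of Datum

# Hypothetical LevelDB path and image dimensions.
validation = _load_validation_data("data/val_leveldb", width=227, height=227)
print("inputs:", validation["input_vectors"].shape)      # (N, height, width, 3)
print("targets:", validation["expected_targets"].shape)  # (N,)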
Example #2
    def __getitem__(self, index):
        if index >= len(self.train_list):
            raise IndexError("index exceeds the max length of the dataset.")
        lmdb_key1, lmdb_key2, label = self.train_list[index]
        datum = Datum()
        raw_byte_buffer1 = self.txn.get(lmdb_key1.encode('utf-8'))
        raw_byte_buffer2 = self.txn.get(lmdb_key2.encode('utf-8'))
        datum.ParseFromString(raw_byte_buffer1)
        image1 = cv2.imdecode(np.frombuffer(datum.data, dtype=np.uint8), -1)
        datum.ParseFromString(raw_byte_buffer2)
        image2 = cv2.imdecode(np.frombuffer(datum.data, dtype=np.uint8), -1)

        image1 = transform(image1)
        image2 = transform(image2)
        # Stack the two decoded images along the channel dimension.
        img = torch.cat([image1, image2], dim=0)

        return img, label
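This __getitem__ assumes the surrounding Dataset keeps one long-lived read transaction in self.txn and a module-level transform. A sketch of that setup, with hypothetical class and argument names:

import lmdb
from torch.utils.data import Dataset
from caffe.proto.caffe_pb2 import Datum  # assumed; some projects generate their own datum_pb2

class PairedLMDBDataset(Dataset):  # hypothetical name
    def __init__(self, lmdb_path, train_list):
        # One environment and one read-only transaction, reused by __getitem__.
        self.env = lmdb.open(lmdb_path, readonly=True, lock=False)
        self.txn = self.env.begin(write=False)
        self.train_list = train_list  # [(lmdb_key1, lmdb_key2, label), ...]

    def __len__(self):
        return len(self.train_list)

    # __getitem__ as shown above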
Example #3
    def __getitem__(self, index):
        lmdb_key, label, db_id = self.train_list[index]
        datum = Datum()
        raw_byte_buffer = self.txns[db_id].get(lmdb_key.encode('utf-8'))
        datum.ParseFromString(raw_byte_buffer)
        cv_img = cv2.imdecode(np.frombuffer(datum.data, dtype=np.uint8), -1)
        # Random horizontal flip for data augmentation.
        if random.random() < 0.5:
            cv_img = cv2.flip(cv_img, 1)

        if cv_img.ndim == 2:
            # Grayscale image: replicate the single channel three times.
            rows, cols = cv_img.shape
            buf = np.zeros((3, rows, cols), dtype=np.uint8)
            buf[0] = buf[1] = buf[2] = cv_img
            # Scale to roughly [-1, 1]; 0.0078125 == 1 / 128.
            input_tensor = (torch.from_numpy(buf) - 127.5) * 0.0078125
        else:
            assert cv_img.ndim == 3
            # Color image: convert H x W x C to channel-first C x H x W.
            cv_img = np.transpose(cv_img, (2, 0, 1))
            input_tensor = (torch.from_numpy(cv_img) - 127.5) * 0.0078125

        return input_tensor, label
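Here each train_list entry also carries a db_id that selects one of several LMDBs. A sketch of how the self.txns list might be prepared (paths and variable names are hypothetical):

import lmdb

lmdb_paths = ["part0.lmdb", "part1.lmdb"]  # illustrative shard paths
envs = [lmdb.open(p, readonly=True, lock=False) for p in lmdb_paths]
txns = [env.begin(write=False) for env in envs]  # stored on the dataset as self.txns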
Example #4
    def get_train_mb(self, mb_size, cropped_size=227):
        env = lmdb.open(self.train_db, readonly=True)
        # print env.stat()
        samples = np.zeros([mb_size, cropped_size**2 * 3], dtype=np.float32)
        labels = np.zeros([mb_size, 1000], dtype=np.float32)
        count = 0
        with env.begin(write=False, buffers=False) as txn:
            cursor = txn.cursor()
            for key, value in cursor:
                d = Datum()
                d.ParseFromString(value)
                #print '#channels=', d.channels, 'height=', d.height, 'width=', d.width, 'label=', d.label
                # Assumes fixed 3 x 256 x 256 images; subtract the dataset mean.
                im = np.frombuffer(d.data, dtype=np.uint8).reshape(
                    [3, 256, 256]) - self.mean_data

                crop_h, crop_w = np.random.randint(256 - cropped_size, size=2)

                im_cropped = im[:, crop_h:crop_h + cropped_size,
                                crop_w:crop_w + cropped_size]
                '''
                iim = np.transpose(im_cropped.reshape(cropped_size*cropped_size*3).reshape([3, cropped_size*cropped_size])).reshape([cropped_size, cropped_size, 3])
                img = Image.fromarray(iim)
                img.save('cropimg.jpg', format='JPEG')
                exit(0)
                '''

                samples[count, :] = im_cropped.reshape(
                    cropped_size**2 * 3).astype(np.float32)
                # One-hot encode the class label.
                labels[count, d.label] = 1
                count += 1
                if count == mb_size:
                    yield (samples, labels)
                    # No need to re-zero samples: each row is fully overwritten
                    # before the next yield.
                    #samples = np.zeros([mb_size, cropped_size ** 2 * 3])
                    labels = np.zeros([mb_size, 1000], dtype=np.float32)
                    count = 0
        # Yield the final partial batch, if any samples are left over.
        if count > 0:
            delete_idx = np.arange(count, mb_size)
            yield (np.delete(samples, delete_idx, 0),
                   np.delete(labels, delete_idx, 0))
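A sketch of how this generator might be consumed; the class name, its constructor arguments, and train_step are placeholders rather than part of the original code:

# Hypothetical owner of get_train_mb, holding train_db and mean_data.
loader = ImageNetLMDB(train_db="ilsvrc12_train_lmdb", mean_file="mean.npy")
for samples, labels in loader.get_train_mb(mb_size=64, cropped_size=227):
    # samples: (64, 227*227*3) float32, mean-subtracted random crops
    # labels:  (64, 1000) one-hot float32
    train_step(samples, labels)  # placeholder for the actual training update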