Beispiel #1
0
    def test_record_num(self):
        """Check that one pass over the loader accounts for ``total_record``.

        Builds an ``AnchorLoader`` over the micro COCO test roidb with a
        batch size of 4, drains it until ``StopIteration`` and asserts that
        ``batch_size * num_batch`` equals ``loader.total_record``.
        """
        pGen, pKv, pRpn, pRoi, pBbox, pDataset, pModel, pOpt, pTest, \
        transform, data_name, label_name, metric_list = detection_config.get_config(is_train=True)
        # Use a context manager so the fixture file handle is closed
        # deterministically instead of being left to the garbage collector.
        with open("unittest/data/coco_micro_test.roidb", "rb") as roidb_file:
            roidbs = pkl.load(roidb_file, encoding="latin1")
        batch_size = 4

        loader = AnchorLoader(
            roidb=roidbs,
            transform=transform,
            data_name=data_name,
            label_name=label_name,
            batch_size=batch_size,
            shuffle=True,
            num_thread=1,
            kv=mx.kvstore.create(pKv.kvstore)
        )

        # Count batches until the loader signals the end of the epoch.
        num_batch = 0
        while True:
            try:
                loader.next()
            except StopIteration:
                break
            # Only reached when next() returned a batch.
            num_batch += 1
        self.assertEqual(batch_size * num_batch, loader.total_record)
Beispiel #2
0
    def test_empty_h_loader(self):
        """Check that the loader terminates when fed only ``h < w`` records.

        Filters the micro COCO test roidb down to the records whose ``'h'``
        is smaller than ``'w'``, builds an ``AnchorLoader`` with batch size 1
        over them, and asserts that repeatedly calling ``next()`` eventually
        raises ``StopIteration``.
        """
        pGen, pKv, pRpn, pRoi, pBbox, pDataset, pModel, pOpt, pTest, \
        transform, data_name, label_name, metric_list = detection_config.get_config(is_train=True)
        # Use a context manager so the fixture file handle is closed
        # deterministically instead of being left to the garbage collector.
        with open("unittest/data/coco_micro_test.roidb", "rb") as roidb_file:
            roidbs = pkl.load(roidb_file, encoding="latin1")
        all_h_roidbs = [roidb for roidb in roidbs if roidb['h'] < roidb['w']]

        loader = AnchorLoader(roidb=all_h_roidbs,
                              transform=transform,
                              data_name=data_name,
                              label_name=label_name,
                              batch_size=1,
                              shuffle=True,
                              num_thread=1,
                              kv=mx.kvstore.create(pKv.kvstore))
        # Drain the loader; the test passes once it signals exhaustion.
        with self.assertRaises(StopIteration):
            while True:
                loader.next()
Beispiel #3
0
            flipped_roidb.append(new_rec)
        roidb = roidb + flipped_roidb

        loader = AnchorLoader(roidb=roidb,
                              transform=transform,
                              data_name=["data", "im_info", "gt_bbox"],
                              label_name=["rpn_cls_label", "rpn_reg_target", "rpn_reg_weight"],
                              batch_size=2,
                              shuffle=False,
                              kv=None)


        tic = time.time()
        while True:
            try:
                data_batch = loader.next()
                if DEBUG:
                    import uuid
                    print(data_batch.provide_data)
                    print(data_batch.provide_label)
                    print(data_batch.data[0].shape)
                    print(data_batch.label[1].shape)
                    print(data_batch.label[2].shape)
                    data = data_batch.data[0]
                    gt_bbox = data_batch.data[2]
                    for i, (im, bbox) in enumerate(zip(data, gt_bbox)):
                        im = im.transpose((1, 2, 0))[:, :, ::-1].asnumpy()
                        im = np.uint8(im)
                        valid_instance = np.where(bbox[:, -1] != -1)[0]
                        bbox = bbox[valid_instance].asnumpy()
                        for j, bbox_j in enumerate(bbox):