Code example #1
File: net.py  Project: lematt1991/ssd.pytorch
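Both excerpts below rely on module-level imports of net.py that the listing omits. The header below is a reconstruction for readability, not taken from the project: the standard-library and third-party imports are straightforward, while `Retina`, `NAME`, `al`, and the exact `imread` source are project-internal assumptions rather than confirmed paths.

# Reconstructed module header (assumed; not shown in the original listing).
import hashlib
import json
import os
import time
from zipfile import ZipFile

import cv2
import numpy as np
import pandas
import torch
from PIL import Image
from skimage.io import imread            # assumption: any RGB image loader works here
from torch.autograd import Variable
from torchvision import transforms

from model import Retina                 # assumption: project-internal network definition
from model import NAME                   # assumption: project-internal model name constant
from utils.annolist import AnnotationLib as al   # assumption: TensorBox-style annotation parser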
class SSD:
    name = NAME

    @classmethod
    def mk_hash(cls, path):
        '''
        Create an MD5 hash from a model's weight file.
        Arguments:
            path : str - path to ssd.pytorch checkpoint
        '''
        dirs = path.split('/')
        if 'ssd.pytorch' in dirs:
            dirs = dirs[dirs.index('ssd.pytorch'):]
            path = '/'.join(dirs)
        else:
            path = os.path.join('ssd.pytorch', path)

        md5 = hashlib.md5()
        md5.update(path.encode('utf-8'))
        return md5.hexdigest()

    @classmethod
    def zip_weights(cls, path, base_dir='./'):
        if os.path.splitext(path)[1] != '.pth':
            raise ValueError('Invalid checkpoint')

        dirs = path.split('/')

        res = {
            'name': 'SSD',
            'instance': '_'.join(dirs[-2:]),
            'id': cls.mk_hash(path)
        }

        zipfile = os.path.join(base_dir, res['id'] + '.zip')

        if os.path.exists(zipfile):
            os.remove(zipfile)

        with ZipFile(zipfile, 'w') as z:
            z.write(path, os.path.join(res['id'], os.path.basename(path)))

        return zipfile

    def __init__(self, weights, classes=['building']):
        self.net = Retina(classes).eval().cuda()
        chkpnt = torch.load(weights)
        self.net.load_state_dict(chkpnt['state_dict'])
        self.transform = transforms.Compose([
            transforms.Resize((300, 300)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def predict_image(self, image, threshold, eval_mode=False):
        """
        Infer buildings for a single image.
        Inputs:
            image :: PIL.Image - RGB image (self.transform and the width/height
                     accesses below require a PIL image, not a raw ndarray)
        """

        t0 = time.time()
        img = self.transform(image)
        out = self.net(Variable(img.unsqueeze(0).cuda(),
                                volatile=True)).squeeze().data.cpu()
        total_time = time.time() - t0

        scores = out[:, :, 0]  # out layout: class x top-K x (score, minx, miny, maxx, maxy)

        max_scores, inds = scores.max(dim=0)

        linear = torch.arange(0, out.shape[1]).long()
        boxes = out[inds, linear].numpy()
        boxes[:, (1, 3)] = np.clip(boxes[:, (1, 3)] * image.width,
                                   a_min=0,
                                   a_max=image.width)
        boxes[:, (2, 4)] = np.clip(boxes[:, (2, 4)] * image.height,
                                   a_min=0,
                                   a_max=image.height)

        df = pandas.DataFrame(boxes, columns=['score', 'x1', 'y1', 'x2', 'y2'])

        if eval_mode:
            return df[df['score'] > threshold], df, total_time
        else:
            return df[df['score'] > threshold]

    def predict_all(self, test_boxes_file, threshold, data_dir=None):
        with open(test_boxes_file) as f:
            test_boxes = json.load(f)
        true_annolist = al.parse(test_boxes_file)
        if data_dir is None:
            data_dir = os.path.dirname(test_boxes_file)

        total_time = 0.0

        for i in range(len(true_annolist)):
            true_anno = true_annolist[i]

            # predict_image's transform expects a PIL image, so convert the loaded RGB array.
            orig_img = Image.fromarray(
                imread(os.path.join(data_dir, true_anno.imageName))[:, :, :3])

            pred, all_rects, elapsed = self.predict_image(orig_img,
                                                          threshold,
                                                          eval_mode=True)
            total_time += elapsed

            pred['image_id'] = i
            all_rects['image_id'] = i

            yield pred, all_rects, test_boxes[i]
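A minimal single-image usage sketch for the SSD wrapper above, assuming a CUDA device and a checkpoint saved as a dict with a 'state_dict' key, as the constructor expects. The checkpoint path, image path, and 0.5 threshold are placeholders for illustration, not values from the original project.

# Hypothetical usage of the SSD wrapper (paths and threshold are placeholders).
from PIL import Image

model = SSD('weights/ssd_buildings.pth')
image = Image.open('tiles/tile_0001.jpg').convert('RGB')

# Returns a DataFrame of detections whose score exceeds the threshold.
detections = model.predict_image(image, threshold=0.5)
print(detections[['score', 'x1', 'y1', 'x2', 'y2']])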
Code example #2
class RetinaNet:
    name = NAME

    @classmethod
    def mk_hash(cls, path):
        '''
        Create an MD5 hash from a model's weight file.
        Arguments:
            path : str - path to RetinaNet checkpoint
        '''
        dirs = path.split('/')
        if 'retina_net' in dirs:
            dirs = dirs[dirs.index('retina_net'):]
            path = '/'.join(dirs)
        else:
            path = os.path.join('retina_net', path)

        md5 = hashlib.md5()
        md5.update(path.encode('utf-8'))
        return md5.hexdigest()

    @classmethod
    def zip_weights(cls, path, base_dir='./'):
        if os.path.splitext(path)[1] != '.pth':
            raise ValueError('Invalid checkpoint')

        dirs = path.split('/')

        res = {
            'name': 'RetinaNet',
            'instance': '_'.join(dirs[-2:]),
            'id': cls.mk_hash(path)
        }

        zipfile = os.path.join(base_dir, res['id'] + '.zip')

        if os.path.exists(zipfile):
            os.remove(zipfile)

        with ZipFile(zipfile, 'w') as z:
            z.write(path, os.path.join(res['id'], os.path.basename(path)))

        return zipfile

    def __init__(self, weights, classes=['building'], cuda=True):
        chkpnt = torch.load(weights)
        self.config = chkpnt['args']
        self.net = Retina(self.config).eval()
        self.net.load_state_dict(chkpnt['state_dict'])
        self.transform = transforms.Compose([
            transforms.Resize((self.config.model_input_size, self.config.model_input_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        self.cuda = cuda
        if cuda:
            # Move the network and its anchor boxes to the GPU only when requested.
            self.net = self.net.cuda()
            self.net.anchors.anchors = self.net.anchors.anchors.cuda()
            torch.set_default_tensor_type('torch.cuda.FloatTensor')

    def predict_image(self, image, eval_mode=False):
        """
        Infer buildings for a single image.
        Inputs:
            image :: PIL.Image - RGB image (self.transform and the width/height
                     accesses below require a PIL image, not a raw ndarray)
        """

        t0 = time.time()
        img = self.transform(image)
        if self.cuda:
            img = img.cuda()

        out = self.net(Variable(img.unsqueeze(0), requires_grad=False)).squeeze().data.cpu().numpy()
        total_time = time.time() - t0
        
        out = out[1] # ignore background class

        out[:, (1, 3)] = np.clip(out[:, (1, 3)] * image.width, a_min=0, a_max=image.width)
        out[:, (2, 4)] = np.clip(out[:, (2, 4)] * image.height, a_min=0, a_max=image.height)

        out = out[out[:, 0] > 0]

        return pandas.DataFrame(out, columns=['score', 'x1', 'y1', 'x2', 'y2'])

    def predict_all(self, test_boxes_file, batch_size=8, data_dir=None):
        if data_dir is None:
            data_dir = os.path.dirname(test_boxes_file)

        with open(test_boxes_file) as f:
            annos = json.load(f)

        for batch in range(0, len(annos), batch_size):
            images, sizes = [], []
            for i in range(min(batch_size, len(annos) - batch)):
                img = Image.open(os.path.join(data_dir, annos[batch + i]['image_path']))
                images.append(self.transform(img))
                sizes.append(torch.Tensor([img.width, img.height]))

            images = torch.stack(images)
            sizes = torch.stack(sizes)

            if self.cuda:
                images = images.cuda()
                sizes = sizes.cuda()

            out = self.net(Variable(images, requires_grad=False)).data

            # Broadcast per-image (w, h, w, h) so normalized boxes scale to pixel coordinates.
            hws = torch.cat([sizes, sizes], dim=1).view(-1, 1, 1, 4).expand(-1, out.shape[1], out.shape[2], -1)

            out[:, :, :, 1:] *= hws
            out = out[:, 1, :, :].cpu().numpy()  # keep class 1 (building), dropping background

            for i, detections in enumerate(out):
                anno = annos[batch + i]
                # Load the image for visualization from the same directory as the annotations.
                pred = cv2.imread(os.path.join(data_dir, anno['image_path']))

                detections = detections[detections[:, 0] > 0]
                df = pandas.DataFrame(detections, columns=['score', 'x1', 'y1', 'x2', 'y2'])
                df['image_id'] = anno['image_path']

                truth = pred.copy()

                for box in df[['x1', 'y1', 'x2', 'y2']].values.round().astype(int):
                    cv2.rectangle(pred, tuple(box[:2]), tuple(box[2:4]), (0,0,255))

                for r in anno['rects']:
                    box = list(map(lambda x: int(r[x]), ['x1', 'y1', 'x2', 'y2']))
                    cv2.rectangle(truth, tuple(box[:2]), tuple(box[2:]), (0, 0, 255))

                data = np.concatenate([pred, truth], axis=1)
                cv2.imwrite('samples/image_%d.jpg' % (batch + i), data)

                yield df
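Finally, a hypothetical driver for the RetinaNet wrapper: package a checkpoint with zip_weights, then stream per-image detection DataFrames from predict_all. The checkpoint and annotation paths are placeholders; note that predict_all writes its prediction/ground-truth comparison images into a samples/ directory, which must exist.

# Hypothetical driver for the RetinaNet wrapper (all paths are placeholders).
import os

model = RetinaNet('weights/retinanet_buildings.pth', cuda=True)

# Package the checkpoint into an id-stamped zip in the working directory.
archive = RetinaNet.zip_weights('retina_net/runs/latest/model.pth', base_dir='./')
print('wrote', archive)

# predict_all saves side-by-side visualizations into samples/, so create it first.
os.makedirs('samples', exist_ok=True)

for df in model.predict_all('data/test_boxes.json', batch_size=8):
    if len(df):
        print(df['image_id'].iloc[0], '-', len(df), 'detections')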