Example #1
    def __init__(self,
                 optimizer,
                 model,
                 training_dataloader,
                 validation_dataloader,
                 log_dir=False,
                 max_epoch=100,
                 resume=False,
                 persist_stride=1,
                 verbose=False):

        self.start_epoch = 1
        self.current_epoch = 1

        self.verbose = verbose
        self.max_epoch = max_epoch
        self.persist_stride = persist_stride

        # initialize log: resolve the log directory first (falling back to a
        # ./logs directory next to this file), create it if needed, and only
        # then point the root logger at a file inside it
        self.log_dir = log_dir
        if not self.log_dir:
            self.log_dir = os.path.join(
                os.path.dirname(os.path.realpath(__file__)), 'logs')
        if not os.path.isdir(self.log_dir):
            os.mkdir(self.log_dir)
        log_file = os.path.join(self.log_dir, 'log.txt')
        logging.basicConfig(filename=log_file, level=logging.DEBUG)

        # initialize model: warm-start from pretrained VGG16 weights;
        # strict=False skips layers whose names or shapes do not match
        self.optimizer = optimizer
        self.model = model.float().to(device)
        self.model.load_state_dict(
            model_zoo.load_url(Config.VGG16_PRETRAINED_WEIGHTS),
            strict=False)
        self.resume = str(resume) if resume else False

        self.training_dataloader = training_dataloader
        self.validation_dataloader = validation_dataloader

        # initialize anchors: one row per anchor box over the whole image
        self.anchors = np.vstack(
            [np.array(a) for a in
             generate_anchors(Config.ANCHOR_STRIDE, Config.ANCHOR_SIZE,
                              Config.IMAGE_SIZE)])
        self.anchors_coord_changed = change_coordinate(self.anchors)
        self.len_anchors = len(self.anchors)

        # resume from some model
        if self.resume:
            state_file = seek_model(self.resume)

            print("loading checkpoint {}".format(state_file))
            checkpoint = torch.load(state_file)
            self.start_epoch = self.current_epoch = checkpoint['epoch']
            self.model.load_state_dict(checkpoint['state_dict'], strict=True)
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            print("loaded checkpoint {} (epoch {})".format(
                state_file, self.current_epoch))
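
The resume branch above fixes the checkpoint layout this trainer expects. A minimal sketch of the matching save side, assuming it lives on the same class (the method name and file-name pattern are assumptions; the dict keys come straight from the loading code):

    def persist(self):
        # hypothetical counterpart to the resume branch: the keys 'epoch',
        # 'state_dict' and 'optimizer' mirror what __init__ reads back;
        # the file name pattern is an assumption
        checkpoint = {
            'epoch': self.current_epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
        }
        file_name = 'epoch_{}.pth.tar'.format(self.current_epoch)
        torch.save(checkpoint, os.path.join(self.log_dir, file_name))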
Example #2
    def forward(self, x):
        # in training mode the input is an (images, targets) pair
        if self.training:
            x, targets = x

        features = []
        features.extend(self.backbone(x))

        cls_heads = [self.cls_head(t) for t in features]
        box_heads = [self.box_head(t) for t in features]

        if self.training:
            return self.loss(x, cls_heads, box_heads, targets.float())

        cls_heads = [cls_head.sigmoid() for cls_head in cls_heads]

        # inference post-processing
        decoded = []
        for cls_head, box_head in zip(cls_heads, box_heads):
            stride = x.shape[-1] // cls_head.shape[-1]
            # generate each level's anchors
            if stride not in self.anchors:
                self.anchors[stride] = generate_anchors(
                    stride, self.ratios, self.scales)
            # decode and filter boxes
            decoded.append(
                decode(cls_head, box_head, stride, self.threshold, self.top_n,
                       self.anchors[stride]))
        # merge detections from all pyramid levels, then run NMS
        decoded = [torch.cat(tensors, 1) for tensors in zip(*decoded)]
        return nms(*decoded, self.nms, self.detections)
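
A hedged sketch of the two call conventions this forward implements; model, images and targets are placeholders for objects built elsewhere:

    # training path: the input is an (images, targets) pair, the output is the loss
    model.train()
    loss = model((images, targets))

    # inference path: the input is just the image batch; the output has been
    # decoded per pyramid level and run through NMS
    model.eval()
    with torch.no_grad():
        detections = model(images)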
Example #3
    def extract_targets(self, targets, stride, size):
        cls_target, box_target, depth = [], [], []
        for target in targets:
            # ignore the padding targets (class id -1) added by the dataset
            target = target[target[:, -1] > -1]
            if stride not in self.anchors:
                self.anchors[stride] = generate_anchors(
                    stride, self.ratios, self.scales)
            snapped = snap_to_anchors(
                target, [s * stride for s in size[::-1]], stride,
                self.anchors[stride].to(targets.device), self.classes)
            for l, s in zip((cls_target, box_target, depth), snapped):
                l.append(s)
        return (torch.stack(cls_target), torch.stack(box_target),
                torch.stack(depth))
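
The target[:, -1] > -1 filter assumes the dataset pads every image's annotation list to a fixed length with rows whose class id is -1. A self-contained illustration (the (x1, y1, x2, y2, class) column layout is an assumption):

    import torch

    target = torch.tensor([
        [10., 20., 50., 60., 0.],     # a real box
        [-1., -1., -1., -1., -1.],    # padding row added by the dataset
    ])
    real = target[target[:, -1] > -1]  # keeps only the first row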
Example #4
    def infer(self, image):
        image = cv2.imread(image)
        # keep the scale factors so boxes can be mapped back to the
        # original resolution after inference on the resized image
        scale = (image.shape[0] / self.image_size,
                 image.shape[1] / self.image_size)

        image = cv2.resize(image, (self.image_size,) * 2)
        # HWC -> CHW and add a batch dimension
        _input = (torch.tensor(image).permute(2, 0, 1)
                  .unsqueeze(0).float().to(device))

        predictions = self.model(_input)
        # flatten each level's output to rows of 6 values:
        # 4 box offsets followed by 2 class logits
        for index, prediction in enumerate(predictions):
            predictions[index] = prediction.view(6, -1).permute(1, 0)
        predictions = torch.cat(predictions)

        # score each row by the margin between the positive and negative
        # class logits, sort descending, and keep rows above the threshold
        diff = predictions[:, 5] - predictions[:, 4]
        scores, sorted_indices = torch.sort(diff, descending=True)
        valid_indices = scores > self.threshold
        scores = scores[valid_indices]

        predictions = predictions[sorted_indices][valid_indices]
        # generate anchors then sort and slice
        anchor_configs = (Config.ANCHOR_STRIDE, Config.ANCHOR_SIZE,
                          Config.IMAGE_SIZE)
        anchors = change_coordinate(
            np.vstack([np.array(a) for a in
                       generate_anchors(*anchor_configs)]))
        anchors = torch.tensor(
            anchors)[sorted_indices][valid_indices].float().to(device)

        # invert the anchor-relative encoding: center offsets are scaled by
        # the anchor size, width/height deltas are exponentiated; the scale
        # factors map boxes back to the original image resolution
        x = (predictions[:, 0] * anchors[:, 2] + anchors[:, 0]) * scale[1]
        y = (predictions[:, 1] * anchors[:, 3] + anchors[:, 1]) * scale[0]
        w = (torch.exp(predictions[:, 2]) * anchors[:, 2]) * scale[1]
        h = (torch.exp(predictions[:, 3]) * anchors[:, 3]) * scale[0]

        bounding_boxes = torch.stack((x, y, w, h), dim=1).cpu().data.numpy()
        bounding_boxes = change_coordinate_inv(bounding_boxes)
        scores = scores.cpu().data.numpy()
        bboxes_scores = np.hstack((bounding_boxes, np.array([scores]).T))

        # nms
        keep = nms(bboxes_scores)

        return bounding_boxes[keep]
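
A hypothetical end-to-end call, assuming this infer method belongs to the detector class whose constructor appears in Example #5 below; the class name and file paths are placeholders:

    detector = Detector('epoch_100.pth.tar')  # name and path are assumptions
    boxes = detector.infer('photo.jpg')       # corner-format boxes after NMS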
Example #5
    def __init__(self,
                 model,
                 image_size=Config.IMAGE_SIZE,
                 threshold=Config.PREDICTION_THRESHOLD):
        checkpoint = torch.load(seek_model(model))
        self.model = Net().to(device)
        self.model.load_state_dict(checkpoint['state_dict'], strict=True)
        self.threshold = threshold
        self.image_size = image_size

        # precompute anchors once for the whole image
        anchor_configs = (Config.ANCHOR_STRIDE, Config.ANCHOR_SIZE,
                          Config.IMAGE_SIZE)
        self.anchors = torch.tensor(
            change_coordinate(
                np.vstack([np.array(a) for a in
                           generate_anchors(*anchor_configs)]))).float()
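
Design note: this constructor precomputes the same anchor tensor that Example #4 rebuilds inside infer on every call. Under that assumption, the per-call regeneration could reuse the cached tensor instead, e.g.:

    # sketch: reuse the anchors cached in __init__ instead of regenerating them
    anchors = self.anchors[sorted_indices][valid_indices].to(device)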