Example no. 1
    def __init__(self, classes):
        super(SSD, self).__init__()
        self.size = cfg.TRAIN.COMMON.INPUT_SIZE
        self.classes = classes
        self.num_classes = len(self.classes)
        self.priors_cfg = self._init_prior_cfg()
        self.priorbox = PriorBox(self.priors_cfg)
        self.priors_xywh = Variable(self.priorbox.forward(), volatile=True)
        self.priors = torch.cat([
            self.priors_xywh[:, 0:1] - 0.5 * self.priors_xywh[:, 2:3],
            self.priors_xywh[:, 1:2] - 0.5 * self.priors_xywh[:, 3:4],
            self.priors_xywh[:, 0:1] + 0.5 * self.priors_xywh[:, 2:3],
            self.priors_xywh[:, 1:2] + 0.5 * self.priors_xywh[:, 3:4]
        ], 1)

        self.priors = self.priors * self.size
        # Layer learns to scale the l2 normalized features from conv4_3
        self.L2Norm = L2Norm(512, 20)

        self.softmax = nn.Softmax(dim=-1)

        self._isex = cfg.TRAIN.VMRN.ISEX
        self.VMRN_rel_op2l = _OP2L(cfg.VMRN.OP2L_POOLING_SIZE, cfg.VMRN.OP2L_POOLING_SIZE, 1.0/8.0, True)

        self._train_iter_counter = 0

        self.criterion = MultiBoxLoss(self.num_classes)
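
Each of these constructors converts the priors from center-size (cx, cy, w, h) form to corner (x1, y1, x2, y2) form with the same four-slice torch.cat. A minimal, self-contained sketch of that conversion (the toy values below are illustrative, not from the repository):

import torch

# priors in center-size form: one row per default box, columns (cx, cy, w, h)
priors_xywh = torch.tensor([[0.5, 0.5, 0.2, 0.4],
                            [0.1, 0.2, 0.3, 0.1]])

# corner form: (x1, y1) = center - half size, (x2, y2) = center + half size
priors = torch.cat([
    priors_xywh[:, 0:1] - 0.5 * priors_xywh[:, 2:3],
    priors_xywh[:, 1:2] - 0.5 * priors_xywh[:, 3:4],
    priors_xywh[:, 0:1] + 0.5 * priors_xywh[:, 2:3],
    priors_xywh[:, 1:2] + 0.5 * priors_xywh[:, 3:4]
], 1)

print(priors)  # [[0.40, 0.30, 0.60, 0.70], [-0.05, 0.15, 0.25, 0.25]]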
Example no. 2
    def __init__(self, classes):
        super(_SSD, self).__init__()

        self.size = cfg.TRAIN.COMMON.INPUT_SIZE
        self.classes = classes
        self.num_classes = len(self.classes)
        self.priors_cfg = self._init_prior_cfg()
        self.priorbox = PriorBox(self.priors_cfg)
        self.priors_xywh = Variable(self.priorbox.forward())
        self.priors_xywh.detach_()  # in-place; a bare .detach() would be a no-op

        self.priors = torch.cat([
            self.priors_xywh[:, 0:1] - 0.5 * self.priors_xywh[:, 2:3],
            self.priors_xywh[:, 1:2] - 0.5 * self.priors_xywh[:, 3:4],
            self.priors_xywh[:, 0:1] + 0.5 * self.priors_xywh[:, 2:3],
            self.priors_xywh[:, 1:2] + 0.5 * self.priors_xywh[:, 3:4]
        ], 1)

        self.priors = self.priors * self.size
        # Layer learns to scale the l2 normalized features from conv4_3
        self.L2Norm = L2Norm(512, 20)
        self.softmax = nn.Softmax(dim=-1)

        self.criterion = MultiBoxLoss(self.num_classes)
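
One pitfall worth flagging: Tensor.detach() returns a new tensor and leaves its receiver untouched, so a bare self.priors_xywh.detach() changes nothing, which is why the constructors above detach in place (detach_()) or reassign. A quick standalone check:

import torch

t = torch.ones(3, requires_grad=True) * 2.0  # non-leaf tensor with grad history
t.detach()                                   # no-op: the detached copy is discarded
print(t.requires_grad)                       # True - still part of the graph

t = t.detach()                               # reassignment (or t.detach_()) works
print(t.requires_grad)                       # False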
Example no. 3
class SSD(nn.Module):
    """Single Shot Multibox Architecture
    The network is composed of a base VGG network followed by the
    added multibox conv layers.  Each multibox layer branches into
        1) conv2d for class conf scores
        2) conv2d for localization predictions
        3) associated priorbox layer to produce default bounding
           boxes specific to the layer's feature map size.
    See: https://arxiv.org/pdf/1512.02325.pdf for more details.
    Args:
        classes: (list) class names; the number of detection classes is
            inferred from its length.
    """

    def __init__(self, classes):
        super(SSD, self).__init__()
        self.size = cfg.TRAIN.COMMON.INPUT_SIZE
        self.classes = classes
        self.num_classes = len(self.classes)
        self.priors_cfg = self._init_prior_cfg()
        self.priorbox = PriorBox(self.priors_cfg)
        self.priors_xywh = Variable(self.priorbox.forward(), volatile=True)
        self.priors = torch.cat([
            self.priors_xywh[:, 0:1] - 0.5 * self.priors_xywh[:, 2:3],
            self.priors_xywh[:, 1:2] - 0.5 * self.priors_xywh[:, 3:4],
            self.priors_xywh[:, 0:1] + 0.5 * self.priors_xywh[:, 2:3],
            self.priors_xywh[:, 1:2] + 0.5 * self.priors_xywh[:, 3:4]
        ], 1)

        self.priors = self.priors * self.size
        # Layer learns to scale the l2 normalized features from conv4_3
        self.L2Norm = L2Norm(512, 20)

        self.softmax = nn.Softmax(dim=-1)

        self._isex = cfg.TRAIN.VMRN.ISEX
        self.VMRN_rel_op2l = _OP2L(cfg.VMRN.OP2L_POOLING_SIZE, cfg.VMRN.OP2L_POOLING_SIZE, 1.0/8.0, True)

        self._train_iter_counter = 0

        self.criterion = MultiBoxLoss(self.num_classes)

    def forward(self, x, im_info, gt_boxes, num_boxes, rel_mat):
        """Applies network layers and ops on input image(s) x.
        Args:
            x: input image or batch of images. Shape: [batch,3,300,300].
        Return:
            Depending on phase:
            test:
                Variable(tensor) of output class label predictions,
                confidence score, and corresponding location predictions for
                each object detected. Shape: [batch,topk,7]
            train:
                list of concat outputs from:
                    1: confidence layers, Shape: [batch*num_priors,num_classes]
                    2: localization layers, Shape: [batch,num_priors*4]
                    3: priorbox layers, Shape: [2,num_priors*4]
        """

        self._train_iter_counter += 1

        sources = list()
        loc = list()
        conf = list()

        self.batch_size = x.size(0)

        # apply vgg up to conv4_3 relu
        if isinstance(self.base, nn.ModuleList):
            for layer in self.base:
                x = layer(x)
        else:
            x = self.base(x)

        s = self.L2Norm(x)
        sources.append(s)
        base_feat = s

        # apply vgg up to fc7
        if isinstance(self.conv5, nn.ModuleList):
            for layer in self.conv5:
                x = layer(x)
        else:
            x = self.conv5(x)
        sources.append(x)

        # apply extra layers and cache source layer outputs
        for k, v in enumerate(self.extras):
            x = F.relu(v(x), inplace=True)
            if k % 2 == 1:
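                # standard SSD extras alternate 1x1 and 3x3 convs, so only every
                # second output (the 3x3 conv, odd k) is kept as a detection source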
                sources.append(x)

        # apply multibox head to source layers
        for (x, l, c) in zip(sources, self.loc, self.conf):
            loc.append(l(x).permute(0, 2, 3, 1).contiguous())
            conf.append(c(x).permute(0, 2, 3, 1).contiguous())

        loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)
        conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)

        loc = loc.view(loc.size(0), -1, 4)
        conf = conf.view(conf.size(0), -1, self.num_classes)
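        # loc:  [batch, num_priors, 4]           box offsets for every default box
        # conf: [batch, num_priors, num_classes] raw class scores for every default box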

        SSD_loss_cls = 0
        SSD_loss_bbox = 0
        if self.training:
            predictions = (
                loc,
                conf,
                self.priors.type_as(loc)
            )
            # targets = torch.cat([gt_boxes[:,:,:4] / self.size, gt_boxes[:,:,4:5]],dim=2)
            targets = gt_boxes
            SSD_loss_bbox, SSD_loss_cls = self.criterion(predictions, targets, num_boxes)

        conf = self.softmax(conf)

        # online data
        if self.training:
            if self._train_iter_counter > cfg.TRAIN.VMRN.ONLINEDATA_BEGIN_ITER:
                obj_rois, obj_num = self._obj_det(conf, loc, self.batch_size, im_info)
                obj_rois = obj_rois.type_as(gt_boxes)
                obj_num = obj_num.type_as(num_boxes)
            else:
                obj_rois = torch.FloatTensor([]).type_as(gt_boxes)
                obj_num = torch.LongTensor([]).type_as(num_boxes)
            obj_labels = None
        else:
            # when testing, this is object detection results
            # TODO: SUPPORT MULTI-IMAGE BATCH
            obj_rois, obj_num = self._obj_det(conf, loc, self.batch_size, im_info)
            if obj_rois.numel() > 0:
                obj_labels = obj_rois[:, 5]
                obj_rois = obj_rois[:, :5]
                obj_rois = obj_rois.type_as(gt_boxes)
                obj_num = obj_num.type_as(num_boxes)
            else:
                # no objects were detected
                obj_labels = torch.Tensor([]).type_as(gt_boxes).long()
                obj_rois = obj_rois.type_as(gt_boxes)
                obj_num = obj_num.type_as(num_boxes)

        if self.training:
            # offline data
            for i in range(self.batch_size):
                obj_rois = torch.cat([obj_rois,
                                      torch.cat([(i * torch.ones(num_boxes[i].item(), 1)).type_as(gt_boxes),
                                                 (gt_boxes[i][:num_boxes[i]][:, 0:4])], 1)
                                      ])
                obj_num = torch.cat([obj_num, torch.Tensor([num_boxes[i]]).type_as(obj_num)])


        obj_rois = Variable(obj_rois)

        VMRN_rel_loss_cls = 0
        rel_cls_prob = torch.Tensor([]).type_as(obj_rois)
        if (obj_num > 1).sum().item() > 0:

            obj_pair_feat = self.VMRN_rel_op2l(base_feat, obj_rois, self.batch_size, obj_num)
            # obj_pair_feat = obj_pair_feat.detach()
            obj_pair_feat = self._rel_head_to_tail(obj_pair_feat)
            rel_cls_score = self.VMRN_rel_cls_score(obj_pair_feat)

            rel_cls_prob = F.softmax(rel_cls_score, dim=1)

            self.rel_batch_size = obj_pair_feat.size(0)

            if self.training:
                obj_pair_rel_label = self._generate_rel_labels(obj_rois, gt_boxes, obj_num, rel_mat)
                obj_pair_rel_label = obj_pair_rel_label.type_as(gt_boxes).long()

                rel_not_keep = (obj_pair_rel_label == 0)
                # proceed only if at least one pair has a non-zero relation label
                if (rel_not_keep == 0).sum().item() > 0:
                    rel_keep = torch.nonzero(rel_not_keep == 0).view(-1)

                    rel_cls_score = rel_cls_score[rel_keep]

                    obj_pair_rel_label = obj_pair_rel_label[rel_keep]
                    obj_pair_rel_label -= 1
                    VMRN_rel_loss_cls = F.cross_entropy(rel_cls_score, obj_pair_rel_label)
            else:
                if (not cfg.TEST.VMRN.ISEX) and cfg.TRAIN.VMRN.ISEX:
                    rel_cls_prob = rel_cls_prob[::2, :]

        rel_result = None
        if not self.training:
            if obj_rois.numel() > 0:
                pred_boxes = obj_rois.data[:,1:5]
                pred_boxes[:, 0::2] /= im_info[0][3].item()
                pred_boxes[:, 1::2] /= im_info[0][2].item()
                rel_result = (pred_boxes, obj_labels, rel_cls_prob.data)
            else:
                rel_result = (obj_rois.data, obj_labels, rel_cls_prob.data)

        return loc, conf, rel_result, SSD_loss_bbox, SSD_loss_cls, VMRN_rel_loss_cls

    def _generate_rel_labels(self, obj_rois, gt_boxes, obj_num, rel_mat):

        obj_pair_rel_label = torch.Tensor(self.rel_batch_size).type_as(gt_boxes).zero_().long()
        # generate online data labels
        cur_pair = 0
        for i in range(obj_num.size(0)):
            img_index = i % self.batch_size
            if obj_num[i] <= 1:
                continue
            begin_ind = torch.sum(obj_num[:i])
            overlaps = bbox_overlaps(obj_rois[begin_ind:begin_ind + obj_num[i]][:, 1:5],
                                     gt_boxes[img_index][:, 0:4])
            max_overlaps, max_inds = torch.max(overlaps, 1)
            for o1ind in range(obj_num[i]):
                for o2ind in range(o1ind + 1, obj_num[i]):
                    o1_gt = int(max_inds[o1ind].item())
                    o2_gt = int(max_inds[o2ind].item())
                    if o1_gt == o2_gt:
                        # skip invalid pairs
                        if self._isex:
                            cur_pair += 2
                        else:
                            cur_pair += 1
                        continue
                    # some relations are left unlabeled; recover the missing direction from its inverse
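                    # (assumed) rel_mat encodes 1 = parent-of, 2 = child-of, 3 = no relation;
                    # "3 - x" swaps 1 and 2, and the symmetric label 3 is copied directly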
                    if rel_mat[img_index][o1_gt, o2_gt].item() == 0:
                        if rel_mat[img_index][o2_gt, o1_gt].item() == 3:
                            rel_mat[img_index][o1_gt, o2_gt] = rel_mat[img_index][o2_gt, o1_gt]
                        else:
                            rel_mat[img_index][o1_gt, o2_gt] = 3 - rel_mat[img_index][o2_gt, o1_gt]
                    obj_pair_rel_label[cur_pair] = rel_mat[img_index][o1_gt, o2_gt]

                    cur_pair += 1
                    if self._isex:
                        # some relations are left unlabeled; recover the missing direction from its inverse
                        if rel_mat[img_index][o2_gt, o1_gt].item() == 0:
                            if rel_mat[img_index][o1_gt, o2_gt].item() == 3:
                                rel_mat[img_index][o2_gt, o1_gt] = rel_mat[img_index][o1_gt, o2_gt]
                            else:
                                rel_mat[img_index][o2_gt, o1_gt] = 3 - rel_mat[img_index][o1_gt, o2_gt]
                        obj_pair_rel_label[cur_pair] = rel_mat[img_index][o2_gt, o1_gt]
                        cur_pair += 1

        return obj_pair_rel_label

    def load_weights(self, base_file):
        other, ext = os.path.splitext(base_file)
        if ext in ('.pkl', '.pth'):
            print('Loading weights into state dict...')
            self.load_state_dict(torch.load(base_file,
                                 map_location=lambda storage, loc: storage))
            print('Finished!')
        else:
            print('Sorry only .pth and .pkl files supported.')

    def _obj_det(self, conf, loc, batch_size, im_info):
        det_results = torch.Tensor([]).type_as(loc)
        obj_num = []
        if not self.training:
            det_labels = torch.Tensor([]).type_as(loc).long()

        for i in range(batch_size):
            cur_cls_prob = conf[i:i + 1]
            cur_bbox_pred = loc[i:i + 1]
            cur_im_info = im_info[i:i + 1]
            obj_boxes = self._get_single_obj_det_results(cur_cls_prob, cur_bbox_pred, cur_im_info)
            obj_num.append(obj_boxes.size(0))
            if obj_num[-1] > 0:
                det_results = torch.cat([det_results,
                                         torch.cat([i * torch.ones(obj_boxes.size(0), 1).type_as(det_results),
                                                    obj_boxes], 1)
                                         ], 0)
        return det_results, torch.LongTensor(obj_num)

    def _get_single_obj_det_results(self, cls_prob, bbox_pred, im_info):

        scores = cls_prob.data
        thresh = 0.05  # filter out low confidence boxes for acceleration
        results = []

        if cfg.TEST.COMMON.BBOX_REG:
            # Apply bounding-box regression deltas
            box_deltas = bbox_pred.data
            if cfg.TRAIN.COMMON.BBOX_NORMALIZE_TARGETS_PRECOMPUTED:
                # Optionally normalize targets by a precomputed mean and stdev
                box_deltas = box_deltas.view(-1, 4) * torch.FloatTensor(cfg.TRAIN.COMMON.BBOX_NORMALIZE_STDS).type_as(box_deltas) \
                                + torch.FloatTensor(cfg.TRAIN.COMMON.BBOX_NORMALIZE_MEANS).type_as(box_deltas)
                box_deltas = box_deltas.view(1, -1, 4)
            pred_boxes = bbox_transform_inv(self.priors.type_as(bbox_pred).data, box_deltas, 1)
            pred_boxes = clip_boxes(pred_boxes, im_info.data, 1)
        else:
            # Simply repeat the boxes, once for each class
            pred_boxes = self.priors.data.repeat(1, scores.shape[-1])

        scores = scores.squeeze()
        pred_boxes = pred_boxes.squeeze()

        for j in range(1, self.num_classes):
            inds = torch.nonzero(scores[:, j] > thresh).view(-1)
            # proceed only if some detections survive the score threshold
            if inds.numel() > 0:
                cls_scores = scores[:, j][inds]
                _, order = torch.sort(cls_scores, 0, True)
                cls_boxes = pred_boxes[inds, :]

                cls_dets = torch.cat((cls_boxes, cls_scores.unsqueeze(1)), 1)
                # cls_dets = torch.cat((cls_boxes, cls_scores), 1)
                cls_dets = cls_dets[order]
                keep = nms(cls_dets, cfg.TEST.COMMON.NMS)
                cls_dets = cls_dets[keep.view(-1).long()]

                final_keep = torch.nonzero(cls_dets[:, -1] > cfg.TEST.COMMON.OBJ_DET_THRESHOLD).squeeze()
                result = cls_dets[final_keep]

                if result.numel() > 0 and result.dim() == 1:
                    result = result.unsqueeze(0)
                # in testing, concat object labels
                if final_keep.numel() > 0:
                    if self.training:
                        result = result[:, :4]
                    else:
                        result = torch.cat([result[:, :4],
                                            j * torch.ones(result.size(0), 1).type_as(result)], 1)
                if result.numel() > 0:
                    results.append(result)

        if len(results):
            final = torch.cat(results, 0)
        else:
            final = torch.Tensor([]).type_as(bbox_pred)
        return final

    def create_architecture(self):
        self._init_modules()
        def weights_init(m):
            def xavier(param):
                init.xavier_uniform_(param)
            if isinstance(m, nn.Conv2d):
                xavier(m.weight.data)
                m.bias.data.zero_()
        # initialize newly added layers' weights with xavier method
        self.extras.apply(weights_init)
        self.loc.apply(weights_init)
        self.conf.apply(weights_init)

    def _init_prior_cfg(self):
        prior_cfg = {
            'min_dim': self.size,
            'feature_maps': cfg.SSD.FEATURE_MAPS,
            'min_sizes': cfg.SSD.PRIOR_MIN_SIZE,
            'max_sizes': cfg.SSD.PRIOR_MAX_SIZE,
            'steps': cfg.SSD.PRIOR_STEP,
            'aspect_ratios': cfg.SSD.PRIOR_ASPECT_RATIO,
            'clip': cfg.SSD.PRIOR_CLIP
        }
        return prior_cfg

    def resume_iter(self, epoch, iter_per_epoch):
        self._train_iter_counter = epoch * iter_per_epoch
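
_get_single_obj_det_results above follows the usual SSD test-time recipe: a per-class score threshold, class-wise NMS, then a final confidence cut. The repository's nms comes from its compiled extensions; a minimal equivalent of the per-class loop using torchvision.ops.nms (the function name and thresholds below are illustrative assumptions, not repository code):

import torch
from torchvision.ops import nms

def per_class_detections(scores, boxes, score_thresh=0.05, nms_thresh=0.45):
    """scores: [num_priors, num_classes]; boxes: [num_priors, 4] as (x1, y1, x2, y2)."""
    results = []
    for j in range(1, scores.size(1)):  # class 0 is background
        inds = torch.nonzero(scores[:, j] > score_thresh).view(-1)
        if inds.numel() == 0:
            continue  # nothing survives the score threshold for this class
        cls_scores = scores[inds, j]
        cls_boxes = boxes[inds]
        keep = nms(cls_boxes, cls_scores, nms_thresh)  # kept indices, sorted by score
        results.append(torch.cat([cls_boxes[keep], cls_scores[keep].unsqueeze(1)], 1))
    return torch.cat(results, 0) if results else torch.empty(0, 5)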
Example no. 4
File: SSD.py Project: WHGang/python
    def __init__(self, classes, class_agnostic, feat_name, feat_list=('conv3', 'conv4'), pretrained=True):
        super(SSD, self).__init__(classes, class_agnostic, feat_name, feat_list, pretrained)
        self.FeatExt.feat_layer["conv3"][0].ceil_mode = True
        # run a dry forward pass in eval mode to probe each source's channel count
        self.FeatExt.eval()
        rand_img = torch.Tensor(1, 3, 300, 300)
        rand_feat = self.FeatExt(rand_img)
        self.FeatExt.train()
        n_channels = [f.size(1) for f in rand_feat]

        self.size = cfg.SCALES[0]
        self.priors_cfg = self._init_prior_cfg()
        self.priorbox = PriorBox(self.priors_cfg)
        self.priors_xywh = Variable(self.priorbox.forward())
        self.priors_xywh.detach_()  # in-place; a bare .detach() would be a no-op

        self.priors = torch.cat([
            self.priors_xywh[:, 0:1] - 0.5 * self.priors_xywh[:, 2:3],
            self.priors_xywh[:, 1:2] - 0.5 * self.priors_xywh[:, 3:4],
            self.priors_xywh[:, 0:1] + 0.5 * self.priors_xywh[:, 2:3],
            self.priors_xywh[:, 1:2] + 0.5 * self.priors_xywh[:, 3:4]
        ], 1)

        self.priors = self.priors * self.size
        # Layer learns to scale the l2 normalized features from conv4_3
        self.L2Norm = L2Norm(512, 20)
        self.softmax = nn.Softmax(dim=-1)
        self.criterion = MultiBoxLoss(self.n_classes)

        mbox_cfg = []
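        # boxes per location: one for each listed aspect ratio and its reciprocal,
        # plus two for aspect ratio 1 (scales s_k and sqrt(s_k * s_{k+1}))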
        for i in cfg.SSD.PRIOR_ASPECT_RATIO:
            mbox_cfg.append(2 * len(i) + 2)

        self.extra_conv = nn.ModuleList()
        self.loc = nn.ModuleList()
        self.conf = nn.ModuleList()

        # conv 4_3 detector
        self.loc.append(
            nn.Conv2d(n_channels[0], mbox_cfg[0] * 4 if self.class_agnostic else mbox_cfg[0] * 4 * self.n_classes,
                      kernel_size=3, padding=1))
        self.conf.append(nn.Conv2d(n_channels[0], mbox_cfg[0] * self.n_classes, kernel_size=3, padding=1))

        # conv 7 detector
        self.extra_conv.append(nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024, 1024, kernel_size=1),
            nn.ReLU(inplace=True)))
        self.loc.append(nn.Conv2d(1024, mbox_cfg[1] * 4 if self.class_agnostic else mbox_cfg[1] * 4 * self.n_classes,
                                  kernel_size=3, padding=1))
        self.conf.append(nn.Conv2d(1024, mbox_cfg[1] * self.n_classes, kernel_size=3, padding=1))

        def add_extra_conv(extra_conv, loc, conf, in_c, mid_c, out_c, downsamp, mbox, n_cls, cag):
            extra_conv.append(nn.Sequential(
                nn.Conv2d(in_c, mid_c, kernel_size=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(mid_c, out_c, kernel_size=3, stride=2 if downsamp else 1, padding=1 if downsamp else 0),
                nn.ReLU(inplace=True),
            ))
            loc.append(nn.Conv2d(out_c, mbox * 4 if cag else mbox * 4 * n_cls, kernel_size=3, padding=1))
            conf.append(nn.Conv2d(out_c, mbox * n_cls, kernel_size=3, padding=1))

        add_extra_conv(self.extra_conv, self.loc, self.conf, 1024, 256, 512, True, mbox_cfg[2], self.n_classes, self.class_agnostic)
        add_extra_conv(self.extra_conv, self.loc, self.conf, 512, 128, 256, True, mbox_cfg[3], self.n_classes, self.class_agnostic)
        add_extra_conv(self.extra_conv, self.loc, self.conf, 256, 128, 256, False, mbox_cfg[4], self.n_classes, self.class_agnostic)
        add_extra_conv(self.extra_conv, self.loc, self.conf, 256, 128, 256, False, mbox_cfg[5], self.n_classes, self.class_agnostic)

        self.iter_counter = 0
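
Each loc/conf head above is a 3x3 conv whose channel dimension packs mbox * 4 (or mbox * num_classes) values per spatial location; the permute(0, 2, 3, 1) + view sequence in the forward pass unpacks them into one row per default box. A standalone sketch with made-up sizes:

import torch
import torch.nn as nn

batch, n_channels, mbox = 2, 512, 4
head = nn.Conv2d(n_channels, mbox * 4, kernel_size=3, padding=1)  # a loc head

feat = torch.randn(batch, n_channels, 38, 38)  # e.g. the conv4_3 source map
out = head(feat)                               # [2, 16, 38, 38]

# channels-last, then flatten: every feature-map cell contributes mbox boxes
out = out.permute(0, 2, 3, 1).contiguous().view(batch, -1, 4)
print(out.shape)                               # torch.Size([2, 5776, 4])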
Example no. 5
File: SSD.py Project: WHGang/python
class SSD(objectDetector):

    def __init__(self, classes, class_agnostic, feat_name, feat_list=('conv3', 'conv4'), pretrained=True):
        super(SSD, self).__init__(classes, class_agnostic, feat_name, feat_list, pretrained)
        self.FeatExt.feat_layer["conv3"][0].ceil_mode = True
        # run a dry forward pass in eval mode to probe each source's channel count
        self.FeatExt.eval()
        rand_img = torch.Tensor(1, 3, 300, 300)
        rand_feat = self.FeatExt(rand_img)
        self.FeatExt.train()
        n_channels = [f.size(1) for f in rand_feat]

        self.size = cfg.SCALES[0]
        self.priors_cfg = self._init_prior_cfg()
        self.priorbox = PriorBox(self.priors_cfg)
        self.priors_xywh = Variable(self.priorbox.forward())
        self.priors_xywh.detach_()  # in-place; a bare .detach() would be a no-op

        self.priors = torch.cat([
            self.priors_xywh[:, 0:1] - 0.5 * self.priors_xywh[:, 2:3],
            self.priors_xywh[:, 1:2] - 0.5 * self.priors_xywh[:, 3:4],
            self.priors_xywh[:, 0:1] + 0.5 * self.priors_xywh[:, 2:3],
            self.priors_xywh[:, 1:2] + 0.5 * self.priors_xywh[:, 3:4]
        ], 1)

        self.priors = self.priors * self.size
        # Layer learns to scale the l2 normalized features from conv4_3
        self.L2Norm = L2Norm(512, 20)
        self.softmax = nn.Softmax(dim=-1)
        self.criterion = MultiBoxLoss(self.n_classes)

        mbox_cfg = []
        for i in cfg.SSD.PRIOR_ASPECT_RATIO:
            mbox_cfg.append(2 * len(i) + 2)

        self.extra_conv = nn.ModuleList()
        self.loc = nn.ModuleList()
        self.conf = nn.ModuleList()

        # conv 4_3 detector
        self.loc.append(
            nn.Conv2d(n_channels[0], mbox_cfg[0] * 4 if self.class_agnostic else mbox_cfg[0] * 4 * self.n_classes,
                      kernel_size=3, padding=1))
        self.conf.append(nn.Conv2d(n_channels[0], mbox_cfg[0] * self.n_classes, kernel_size=3, padding=1))

        # conv 7 detector
        self.extra_conv.append(nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024, 1024, kernel_size=1),
            nn.ReLU(inplace=True)))
        self.loc.append(nn.Conv2d(1024, mbox_cfg[1] * 4 if self.class_agnostic else mbox_cfg[1] * 4 * self.n_classes,
                                  kernel_size=3, padding=1))
        self.conf.append(nn.Conv2d(1024, mbox_cfg[1] * self.n_classes, kernel_size=3, padding=1))

        def add_extra_conv(extra_conv, loc, conf, in_c, mid_c, out_c, downsamp, mbox, n_cls, cag):
            extra_conv.append(nn.Sequential(
                nn.Conv2d(in_c, mid_c, kernel_size=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(mid_c, out_c, kernel_size=3, stride=2 if downsamp else 1, padding=1 if downsamp else 0),
                nn.ReLU(inplace=True),
            ))
            loc.append(nn.Conv2d(out_c, mbox * 4 if cag else mbox * 4 * n_cls, kernel_size=3, padding=1))
            conf.append(nn.Conv2d(out_c, mbox * n_cls, kernel_size=3, padding=1))

        add_extra_conv(self.extra_conv, self.loc, self.conf, 1024, 256, 512, True, mbox_cfg[2], self.n_classes, self.class_agnostic)
        add_extra_conv(self.extra_conv, self.loc, self.conf, 512, 128, 256, True, mbox_cfg[3], self.n_classes, self.class_agnostic)
        add_extra_conv(self.extra_conv, self.loc, self.conf, 256, 128, 256, False, mbox_cfg[4], self.n_classes, self.class_agnostic)
        add_extra_conv(self.extra_conv, self.loc, self.conf, 256, 128, 256, False, mbox_cfg[5], self.n_classes, self.class_agnostic)

        self.iter_counter = 0

    def _get_obj_det_result(self, sources):
        loc = []
        conf = []
        # apply multibox head to source layers
        for (x, l, c) in zip(sources, self.loc, self.conf):
            loc.append(l(x).permute(0, 2, 3, 1).contiguous())
            conf.append(c(x).permute(0, 2, 3, 1).contiguous())

        loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)
        conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)

        loc = loc.view(loc.size(0), -1, 4)
        conf = conf.view(conf.size(0), -1, self.n_classes)
        return loc, conf

    def forward(self, data_batch):
        """Applies network layers and ops on input image(s) x.
        Args:
            x: input image or batch of images. Shape: [batch,3,300,300].
        Return:
            Depending on phase:
            test:
                Variable(tensor) of output class label predictions,
                confidence score, and corresponding location predictions for
                each object detected. Shape: [batch,topk,7]
            train:
                list of concat outputs from:
                    1: confidence layers, Shape: [batch*num_priors,num_classes]
                    2: localization layers, Shape: [batch,num_priors*4]
                    3: priorbox layers, Shape: [2,num_priors*4]
        """
        x = data_batch[0]
        im_info = data_batch[1]
        gt_boxes = data_batch[2]
        num_boxes = data_batch[3]

        if self.training:
            self.iter_counter += 1

        sources = []

        s0, x = self.FeatExt(x)
        s0 = self.L2Norm(s0)
        sources.append(s0)

        for m in self.extra_conv:
            x = m(x)
            sources.append(x)

        loc, conf = self._get_obj_det_result(sources)
        SSD_loss_cls, SSD_loss_bbox = 0, 0
        if self.training:
            predictions = (
                loc,
                conf,
                self.priors.type_as(loc)
            )
            SSD_loss_bbox, SSD_loss_cls = self.criterion(predictions, gt_boxes, num_boxes)
        conf = self.softmax(conf)

        return loc, conf, SSD_loss_bbox, SSD_loss_cls

    def create_architecture(self):
        self._init_modules()
        self._init_weights()

    def _init_modules(self):
        pass

    def _init_weights(self):
        def weights_init(m):
            def xavier(param):
                init.xavier_uniform_(param)

            if isinstance(m, nn.Conv2d):
                xavier(m.weight.data)
                m.bias.data.zero_()

        # initialize newly added layers' weights with xavier method
        self.extra_conv.apply(weights_init)
        self.loc.apply(weights_init)
        self.conf.apply(weights_init)

    def _init_prior_cfg(self):
        prior_cfg = {
            'min_dim': self.size,
            'feature_maps': cfg.SSD.FEATURE_MAPS,
            'min_sizes': cfg.SSD.PRIOR_MIN_SIZE,
            'max_sizes': cfg.SSD.PRIOR_MAX_SIZE,
            'steps': cfg.SSD.PRIOR_STEP,
            'aspect_ratios': cfg.SSD.PRIOR_ASPECT_RATIO,
            'clip': cfg.SSD.PRIOR_CLIP
        }
        return prior_cfg
class _SSD(nn.Module):
    """Single Shot Multibox Architecture
    The network is composed of a base VGG network followed by the
    added multibox conv layers.  Each multibox layer branches into
        1) conv2d for class conf scores
        2) conv2d for localization predictions
        3) associated priorbox layer to produce default bounding
           boxes specific to the layer's feature map size.
    See: https://arxiv.org/pdf/1512.02325.pdf for more details.
    Args:
        classes: (list) class names; the number of detection classes is
            inferred from its length.
    """

    def __init__(self, classes):
        super(_SSD, self).__init__()

        self.size = cfg.TRAIN.COMMON.INPUT_SIZE
        self.classes = classes
        self.num_classes = len(self.classes)
        self.priors_cfg = self._init_prior_cfg()
        self.priorbox = PriorBox(self.priors_cfg)
        self.priors_xywh = Variable(self.priorbox.forward())
        self.priors_xywh.detach_()  # in-place; a bare .detach() would be a no-op

        self.priors = torch.cat([
            self.priors_xywh[:, 0:1] - 0.5 * self.priors_xywh[:, 2:3],
            self.priors_xywh[:, 1:2] - 0.5 * self.priors_xywh[:, 3:4],
            self.priors_xywh[:, 0:1] + 0.5 * self.priors_xywh[:, 2:3],
            self.priors_xywh[:, 1:2] + 0.5 * self.priors_xywh[:, 3:4]
        ], 1)

        self.priors = self.priors * self.size
        # Layer learns to scale the l2 normalized features from conv4_3
        self.L2Norm = L2Norm(512, 20)
        self.softmax = nn.Softmax(dim=-1)

        self.criterion = MultiBoxLoss(self.num_classes)

    def forward(self, x, im_info, gt_boxes, num_boxes):
        """Applies network layers and ops on input image(s) x.
        Args:
            x: input image or batch of images. Shape: [batch,3,300,300].
        Return:
            Depending on phase:
            test:
                Variable(tensor) of output class label predictions,
                confidence score, and corresponding location predictions for
                each object detected. Shape: [batch,topk,7]
            train:
                list of concat outputs from:
                    1: confidence layers, Shape: [batch*num_priors,num_classes]
                    2: localization layers, Shape: [batch,num_priors*4]
                    3: priorbox layers, Shape: [2,num_priors*4]
        """
        sources = list()
        loc = list()
        conf = list()

        # apply vgg up to conv4_3 relu
        if isinstance(self.base, nn.ModuleList):
            for layer in self.base:
                x = layer(x)
        else:
            x = self.base(x)

        s = self.L2Norm(x)
        sources.append(s)

        # apply vgg up to fc7
        for conv in self.SSD_feat_layers:
            x = conv(x)
        sources.append(x)

        # apply extra layers and cache source layer outputs
        for k, v in enumerate(self.extras):
            x = F.relu(v(x), inplace=True)
            if k % 2 == 1:
                sources.append(x)

        # apply multibox head to source layers
        for (x, l, c) in zip(sources, self.loc, self.conf):
            loc.append(l(x).permute(0, 2, 3, 1).contiguous())
            conf.append(c(x).permute(0, 2, 3, 1).contiguous())

        loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)
        conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)

        loc = loc.view(loc.size(0), -1, 4)
        conf = conf.view(conf.size(0), -1, self.num_classes)

        if self.training:
            predictions = (
                loc,
                conf,
                self.priors.type_as(loc)
            )
            # targets = torch.cat([gt_boxes[:,:,:4] / self.size, gt_boxes[:,:,4:5]],dim=2)
            targets = gt_boxes
            SSD_loss_bbox, SSD_loss_cls = self.criterion(predictions, targets, num_boxes)
        else:
            SSD_loss_cls = 0
            SSD_loss_bbox = 0

        conf = self.softmax(conf)

        return loc, conf, SSD_loss_bbox, SSD_loss_cls

    def create_architecture(self):
        self._init_modules()
        def weights_init(m):
            def xavier(param):
                init.xavier_uniform_(param)
            if isinstance(m, nn.Conv2d):
                xavier(m.weight.data)
                m.bias.data.zero_()
        # initialize newly added layers' weights with xavier method
        self.extras.apply(weights_init)
        self.loc.apply(weights_init)
        self.conf.apply(weights_init)

    def _init_prior_cfg(self):
        prior_cfg = {
            'min_dim': self.size,
            'feature_maps': cfg.SSD.FEATURE_MAPS,
            'min_sizes': cfg.SSD.PRIOR_MIN_SIZE,
            'max_sizes': cfg.SSD.PRIOR_MAX_SIZE,
            'steps': cfg.SSD.PRIOR_STEP,
            'aspect_ratios': cfg.SSD.PRIOR_ASPECT_RATIO,
            'clip': cfg.SSD.PRIOR_CLIP
        }
        return prior_cfg
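
_init_prior_cfg only assembles the dictionary consumed by PriorBox, whose implementation is not shown in these listings. For reference, a minimal generator in the spirit of the standard SSD recipe (the field names match the config above; the loop structure is an assumption about PriorBox, not its actual code):

import math
from itertools import product
import torch

def make_priors(cfg):
    """cfg: dict with min_dim, feature_maps, steps, min_sizes, max_sizes, aspect_ratios, clip."""
    priors = []
    for k, f in enumerate(cfg['feature_maps']):
        for i, j in product(range(f), repeat=2):
            f_k = cfg['min_dim'] / cfg['steps'][k]
            cx, cy = (j + 0.5) / f_k, (i + 0.5) / f_k   # box center in relative coords

            s_k = cfg['min_sizes'][k] / cfg['min_dim']  # aspect ratio 1, scale s_k
            priors.append([cx, cy, s_k, s_k])

            # aspect ratio 1, intermediate scale sqrt(s_k * s_{k+1})
            s_k_prime = math.sqrt(s_k * (cfg['max_sizes'][k] / cfg['min_dim']))
            priors.append([cx, cy, s_k_prime, s_k_prime])

            for ar in cfg['aspect_ratios'][k]:          # each ratio and its reciprocal
                priors.append([cx, cy, s_k * math.sqrt(ar), s_k / math.sqrt(ar)])
                priors.append([cx, cy, s_k / math.sqrt(ar), s_k * math.sqrt(ar)])

    out = torch.tensor(priors)                          # center-size (cx, cy, w, h) form
    return out.clamp(0.0, 1.0) if cfg['clip'] else out

The output is in the center-size form that the constructors above convert to corner coordinates.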